fix: xinference rerank error
--bug=1054256 --user=王孝刚 【模型】添加硅基流动的重排序模型失败 https://www.tapd.cn/57709429/s/1679612
This commit is contained in:
parent
6cf91098d6
commit
2686e76c8a
@ -16,7 +16,6 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
|
|||||||
|
|
||||||
|
|
||||||
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
||||||
client: Any
|
|
||||||
server_url: Optional[str]
|
server_url: Optional[str]
|
||||||
"""URL of the xinference server"""
|
"""URL of the xinference server"""
|
||||||
model_uid: Optional[str]
|
model_uid: Optional[str]
|
||||||
@ -30,10 +29,13 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
|||||||
|
|
||||||
top_n: Optional[int] = 3
|
top_n: Optional[int] = 3
|
||||||
|
|
||||||
def __init__(
|
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
|
Sequence[Document]:
|
||||||
api_key: Optional[str] = None
|
if documents is None or len(documents) == 0:
|
||||||
):
|
return []
|
||||||
|
client: Any
|
||||||
|
if documents is None or len(documents) == 0:
|
||||||
|
return []
|
||||||
try:
|
try:
|
||||||
from xinference.client import RESTfulClient
|
from xinference.client import RESTfulClient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -45,29 +47,8 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
|||||||
" with `pip install xinference` or `pip install xinference_client`."
|
" with `pip install xinference` or `pip install xinference_client`."
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
super().__init__()
|
client = RESTfulClient(self.server_url, self.api_key)
|
||||||
|
model: RESTfulRerankModelHandle = client.get_model(self.model_uid)
|
||||||
if server_url is None:
|
|
||||||
raise ValueError("Please provide server URL")
|
|
||||||
|
|
||||||
if model_uid is None:
|
|
||||||
raise ValueError("Please provide the model UID")
|
|
||||||
|
|
||||||
self.server_url = server_url
|
|
||||||
|
|
||||||
self.model_uid = model_uid
|
|
||||||
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
self.client = RESTfulClient(server_url, api_key)
|
|
||||||
|
|
||||||
self.top_n = top_n
|
|
||||||
|
|
||||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
|
||||||
Sequence[Document]:
|
|
||||||
if documents is None or len(documents) == 0:
|
|
||||||
return []
|
|
||||||
model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
|
|
||||||
res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
|
res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
|
||||||
return [Document(page_content=d.get('document', {}).get('text'),
|
return [Document(page_content=d.get('document', {}).get('text'),
|
||||||
metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]
|
metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user