refactor: ollama support rerank
This commit is contained in:
parent
f02b40b830
commit
9185515660
@ -64,4 +64,3 @@ class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.TextInputField('API Key', required=True)
|
|
||||||
|
|||||||
@ -1,82 +1,48 @@
|
|||||||
from typing import Sequence, Optional, Any, Dict
|
from typing import Sequence, Optional, Any, Dict
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
from langchain_core.documents import Document
|
||||||
import requests
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||||
api_base: Optional[str]
|
top_n: Optional[int] = Field(3, description="Number of top documents to return")
|
||||||
"""URL of the Ollama server"""
|
|
||||||
model_name: Optional[str]
|
|
||||||
"""The model name to use for reranking"""
|
|
||||||
api_key: Optional[str]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))
|
return OllamaReranker(
|
||||||
|
model=model_name,
|
||||||
top_n: Optional[int] = 3
|
base_url=model_credential.get('api_base'),
|
||||||
|
**optional_params
|
||||||
def __init__(
|
)
|
||||||
self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if api_base is None:
|
|
||||||
raise ValueError("Please provide server URL")
|
|
||||||
|
|
||||||
if model_name is None:
|
|
||||||
raise ValueError("Please provide the model name")
|
|
||||||
|
|
||||||
self.api_base = api_base
|
|
||||||
self.model_name = model_name
|
|
||||||
self.api_key = api_key
|
|
||||||
self.top_n = top_n
|
|
||||||
|
|
||||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||||
Sequence[Document]:
|
Sequence[Document]:
|
||||||
"""
|
"""Rank documents based on their similarity to the query.
|
||||||
Given a query and a set of documents, rerank them using Ollama API.
|
|
||||||
"""
|
|
||||||
if not documents or len(documents) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Prepare the data to send to Ollama API
|
Args:
|
||||||
headers = {
|
query: The query text.
|
||||||
'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required
|
documents: The list of document texts to rank.
|
||||||
}
|
|
||||||
|
|
||||||
# Format the documents to be sent in a format understood by Ollama's API
|
Returns:
|
||||||
documents_text = [document.page_content for document in documents]
|
List of documents sorted by relevance to the query.
|
||||||
|
"""
|
||||||
|
# 获取查询和文档的嵌入
|
||||||
|
query_embedding = self.embed_query(query)
|
||||||
|
documents = [doc.page_content for doc in documents]
|
||||||
|
document_embeddings = self.embed_documents(documents)
|
||||||
|
# 计算相似度
|
||||||
|
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
|
||||||
|
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
||||||
|
return [
|
||||||
|
Document(
|
||||||
|
page_content=doc, # 第一个值是文档内容
|
||||||
|
metadata={'relevance_score': score} # 第二个值是相似度分数
|
||||||
|
)
|
||||||
|
for doc, score in ranked_docs
|
||||||
|
]
|
||||||
|
|
||||||
# Make a POST request to Ollama's rerank API endpoint
|
|
||||||
payload = {
|
|
||||||
'model': self.model_name, # Specify the model
|
|
||||||
'query': query,
|
|
||||||
'documents': documents_text,
|
|
||||||
'top_n': self.top_n
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
res = response.json()
|
|
||||||
|
|
||||||
# Ensure the response contains expected results
|
|
||||||
if 'results' not in res:
|
|
||||||
raise ValueError("The API response did not contain rerank results.")
|
|
||||||
|
|
||||||
# Convert the API response into a list of Document objects with relevance scores
|
|
||||||
ranked_documents = [
|
|
||||||
Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
|
|
||||||
for d in res['results']
|
|
||||||
]
|
|
||||||
return ranked_documents
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
print(f"Error during API request: {e}")
|
|
||||||
return [] # Return an empty list if the request failed
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user