feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-19 16:44:53 +08:00
parent f372095f51
commit a9b8bdd365
2 changed files with 30 additions and 4 deletions

View File

@ -159,7 +159,7 @@ class ListenerManagement:
@param embedding_model 向量模型 @param embedding_model 向量模型
:return: None :return: None
""" """
if not try_lock('embedding' + document_id): if not try_lock('embedding' + str(document_id)):
return return
max_kb.info(f"开始--->向量化文档:{document_id}") max_kb.info(f"开始--->向量化文档:{document_id}")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
@ -186,7 +186,7 @@ class ListenerManagement:
**{'status': status, 'update_time': datetime.datetime.now()}) **{'status': status, 'update_time': datetime.datetime.now()})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
max_kb.info(f"结束--->向量化文档:{document_id}") max_kb.info(f"结束--->向量化文档:{document_id}")
un_lock('embedding' + document_id) un_lock('embedding' + str(document_id))
@staticmethod @staticmethod
@embedding_poxy @embedding_poxy

View File

@ -6,7 +6,7 @@
@date2024/7/12 15:02 @date2024/7/12 15:02
@desc: @desc:
""" """
from typing import Dict from typing import Dict, List
from langchain_community.embeddings import OllamaEmbeddings from langchain_community.embeddings import OllamaEmbeddings
@ -16,7 +16,33 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings): class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OllamaEmbeddings( return OllamaEmbedding(
model=model_name, model=model_name,
base_url=model_credential.get('api_base'), base_url=model_credential.get('api_base'),
) )
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Return one embedding vector per entry of ``texts``.

    Every entry is rendered through an f-string before being sent on, so
    the payload handed to ``self._embed`` is always a list of ``str`` and
    carries no extra instruction prefix.

    Args:
        texts: The texts to embed.

    Returns:
        Whatever ``self._embed`` returns for the rendered texts
        (expected to be one vector per input text).
    """
    return self._embed([f"{doc}" for doc in texts])
def embed_query(self, text: str) -> List[float]:
    """Return the embedding vector for a single query string.

    The query is rendered through an f-string, wrapped in a one-element
    list and passed to ``self._embed``; the first (only) vector of the
    result is returned.

    Args:
        text: The query text to embed.

    Returns:
        The first element of the result of ``self._embed``.
    """
    vectors = self._embed([f"{text}"])
    return vectors[0]