refactor: improve embedding handling and ensure float conversion for embeddings
--bug=1058254 --user=刘瑞斌 [Knowledge Base] Enterprise Edition - LDAP user creates a general knowledge base - paragraphs cannot be added to the newly created blank document https://www.tapd.cn/62980211/s/1725378
This commit is contained in:
parent 90abe70e2e
commit 8f00184122
@@ -12,7 +12,6 @@ from abc import ABC, abstractmethod
 from typing import Dict, List
 
 import uuid_utils.compat as uuid
-from common.utils.ts_vecto_util import to_ts_vector, to_query
 from django.contrib.postgres.search import SearchVector
 from django.db.models import QuerySet, Value
 from langchain_core.embeddings import Embeddings
@@ -20,6 +19,7 @@ from langchain_core.embeddings import Embeddings
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.utils.common import get_file_content
+from common.utils.ts_vecto_util import to_ts_vector, to_query
 from knowledge.models import Embedding, SearchMode, SourceType
 from knowledge.vector.base_vector import BaseVectorStore
 from maxkb.conf import PROJECT_DIR
@@ -42,36 +42,40 @@ class PGVector(BaseVectorStore):
     def vector_create(self):
         return True
 
-    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
+    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
+              source_id: str,
               is_active: bool,
               embedding: Embeddings):
-        text_embedding = embedding.embed_query(text)
-        embedding = Embedding(id=uuid.uuid7(),
-                              knowledge_id=knowledge_id,
-                              document_id=document_id,
-                              is_active=is_active,
-                              paragraph_id=paragraph_id,
-                              source_id=source_id,
-                              embedding=text_embedding,
-                              source_type=source_type,
-                              search_vector=to_ts_vector(text))
+        text_embedding = [float(x) for x in embedding.embed_query(text)]
+        embedding = Embedding(
+            id=uuid.uuid7(),
+            knowledge_id=knowledge_id,
+            document_id=document_id,
+            is_active=is_active,
+            paragraph_id=paragraph_id,
+            source_id=source_id,
+            embedding=text_embedding,
+            source_type=source_type,
+            search_vector=to_ts_vector(text)
+        )
         embedding.save()
         return True
 
     def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
-        embedding_list = [Embedding(id=uuid.uuid7(),
-                                    document_id=text_list[index].get('document_id'),
-                                    paragraph_id=text_list[index].get('paragraph_id'),
-                                    knowledge_id=text_list[index].get('knowledge_id'),
-                                    is_active=text_list[index].get('is_active', True),
-                                    source_id=text_list[index].get('source_id'),
-                                    source_type=text_list[index].get('source_type'),
-                                    embedding=embeddings[index],
-                                    search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
-                          index in
-                          range(0, len(texts))]
+        embedding_list = [
+            Embedding(
+                id=uuid.uuid7(),
+                document_id=text_list[index].get('document_id'),
+                paragraph_id=text_list[index].get('paragraph_id'),
+                knowledge_id=text_list[index].get('knowledge_id'),
+                is_active=text_list[index].get('is_active', True),
+                source_id=text_list[index].get('source_id'),
+                source_type=text_list[index].get('source_type'),
+                embedding=[float(x) for x in embeddings[index]],
+                search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
+            ) for index in range(0, len(texts))]
         if not is_the_task_interrupted():
            QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True
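
The substantive change in both _save and _batch_save is wrapping the embedding model output in [float(x) for x in ...] before it is persisted. Below is a minimal standalone sketch of that idea, not part of the commit, assuming the embedding backend can return non-built-in numeric types such as numpy.float32 (an assumption; the bug report does not name the exact type). Casting each component to a plain Python float keeps the stored vector in a type the ORM and the vector column handle uniformly.

    # Minimal sketch (not part of the commit): normalize an embedding vector to
    # built-in Python floats before handing it to the ORM / vector storage layer.
    # Assumption: numpy is installed and stands in for an embedding backend that
    # returns numpy.float32 scalars.
    from typing import List

    import numpy as np


    def to_plain_floats(vector) -> List[float]:
        # float() accepts numpy scalars, Decimals and plain floats alike,
        # so the result is always a list of built-in Python floats.
        return [float(x) for x in vector]


    simulated = np.array([0.1, 0.2, 0.3], dtype=np.float32)  # stand-in for embed_query() output
    converted = to_plain_floats(simulated)
    print(type(simulated[0]).__name__)  # float32
    print(type(converted[0]).__name__)  # float

Whatever the exact type the backend returned in the reported case, the conversion makes the persisted element type explicit instead of depending on what embed_query / embed_documents happens to produce.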