refactor: improve embedding handling and ensure float conversion for embeddings

--bug=1058254 --user=刘瑞斌 【知识库】企业版-ldap用户创建通用知识库-创建的空白文档中无法添加分段 https://www.tapd.cn/62980211/s/1725378
This commit is contained in:
CaptainB 2025-07-08 14:18:38 +08:00
parent 90abe70e2e
commit 8f00184122

View File

@ -12,7 +12,6 @@ from abc import ABC, abstractmethod
from typing import Dict, List from typing import Dict, List
import uuid_utils.compat as uuid import uuid_utils.compat as uuid
from common.utils.ts_vecto_util import to_ts_vector, to_query
from django.contrib.postgres.search import SearchVector from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -20,6 +19,7 @@ from langchain_core.embeddings import Embeddings
from common.db.search import generate_sql_by_query_dict from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.utils.common import get_file_content from common.utils.common import get_file_content
from common.utils.ts_vecto_util import to_ts_vector, to_query
from knowledge.models import Embedding, SearchMode, SourceType from knowledge.models import Embedding, SearchMode, SourceType
from knowledge.vector.base_vector import BaseVectorStore from knowledge.vector.base_vector import BaseVectorStore
from maxkb.conf import PROJECT_DIR from maxkb.conf import PROJECT_DIR
@ -42,36 +42,40 @@ class PGVector(BaseVectorStore):
def vector_create(self):
    """Report that the vector store is available.

    This implementation performs no setup work and always returns True.
    """
    return True
def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
          source_id: str,
          is_active: bool,
          embedding: Embeddings):
    """Embed ``text`` and persist it as a single ``Embedding`` row.

    :param text: content to embed; also used to build the full-text search vector
    :param source_type: value stored in the row's ``source_type`` field
    :param knowledge_id: value stored in the row's ``knowledge_id`` field
    :param document_id: value stored in the row's ``document_id`` field
    :param paragraph_id: value stored in the row's ``paragraph_id`` field
    :param source_id: value stored in the row's ``source_id`` field
    :param is_active: value stored in the row's ``is_active`` field
    :param embedding: embedding client; only ``embed_query`` is used here
    :return: True once the row has been saved
    """
    # Coerce every component to a built-in float before handing it to the model.
    # NOTE(review): presumably guards against providers returning non-float
    # numeric types (e.g. numpy scalars) - confirm against the linked bug.
    vector = [float(component) for component in embedding.embed_query(text)]
    field_values = dict(
        id=uuid.uuid7(),
        knowledge_id=knowledge_id,
        document_id=document_id,
        is_active=is_active,
        paragraph_id=paragraph_id,
        source_id=source_id,
        embedding=vector,
        source_type=source_type,
        search_vector=to_ts_vector(text),
    )
    # Use a distinct local name so the ``embedding`` client parameter is not rebound.
    record = Embedding(**field_values)
    record.save()
    return True
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
texts = [row.get('text') for row in text_list] texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts) embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid7(), embedding_list = [
document_id=text_list[index].get('document_id'), Embedding(
paragraph_id=text_list[index].get('paragraph_id'), id=uuid.uuid7(),
knowledge_id=text_list[index].get('knowledge_id'), document_id=text_list[index].get('document_id'),
is_active=text_list[index].get('is_active', True), paragraph_id=text_list[index].get('paragraph_id'),
source_id=text_list[index].get('source_id'), knowledge_id=text_list[index].get('knowledge_id'),
source_type=text_list[index].get('source_type'), is_active=text_list[index].get('is_active', True),
embedding=embeddings[index], source_id=text_list[index].get('source_id'),
search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for source_type=text_list[index].get('source_type'),
index in embedding=[float(x) for x in embeddings[index]],
range(0, len(texts))] search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
) for index in range(0, len(texts))]
if not is_the_task_interrupted(): if not is_the_task_interrupted():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True return True