refactor: improve embedding handling and ensure float conversion for embeddings

--bug=1058254 --user=刘瑞斌 [Knowledge Base] Enterprise Edition: when an LDAP user creates a general knowledge base, paragraphs cannot be added to the blank document it creates. https://www.tapd.cn/62980211/s/1725378
CaptainB 2025-07-08 14:18:38 +08:00
parent 90abe70e2e
commit 8f00184122


@@ -12,7 +12,6 @@ from abc import ABC, abstractmethod
 from typing import Dict, List
 
 import uuid_utils.compat as uuid
-from common.utils.ts_vecto_util import to_ts_vector, to_query
 from django.contrib.postgres.search import SearchVector
 from django.db.models import QuerySet, Value
 from langchain_core.embeddings import Embeddings
@@ -20,6 +19,7 @@ from langchain_core.embeddings import Embeddings
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.utils.common import get_file_content
+from common.utils.ts_vecto_util import to_ts_vector, to_query
 from knowledge.models import Embedding, SearchMode, SourceType
 from knowledge.vector.base_vector import BaseVectorStore
 from maxkb.conf import PROJECT_DIR
@@ -42,36 +42,40 @@ class PGVector(BaseVectorStore):
     def vector_create(self):
         return True
 
-    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
+    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
+              source_id: str,
               is_active: bool,
               embedding: Embeddings):
-        text_embedding = embedding.embed_query(text)
-        embedding = Embedding(id=uuid.uuid7(),
-                              knowledge_id=knowledge_id,
-                              document_id=document_id,
-                              is_active=is_active,
-                              paragraph_id=paragraph_id,
-                              source_id=source_id,
-                              embedding=text_embedding,
-                              source_type=source_type,
-                              search_vector=to_ts_vector(text))
+        text_embedding = [float(x) for x in embedding.embed_query(text)]
+        embedding = Embedding(
+            id=uuid.uuid7(),
+            knowledge_id=knowledge_id,
+            document_id=document_id,
+            is_active=is_active,
+            paragraph_id=paragraph_id,
+            source_id=source_id,
+            embedding=text_embedding,
+            source_type=source_type,
+            search_vector=to_ts_vector(text)
+        )
         embedding.save()
         return True
 
     def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
-        embedding_list = [Embedding(id=uuid.uuid7(),
-                                    document_id=text_list[index].get('document_id'),
-                                    paragraph_id=text_list[index].get('paragraph_id'),
-                                    knowledge_id=text_list[index].get('knowledge_id'),
-                                    is_active=text_list[index].get('is_active', True),
-                                    source_id=text_list[index].get('source_id'),
-                                    source_type=text_list[index].get('source_type'),
-                                    embedding=embeddings[index],
-                                    search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
-                          index in
-                          range(0, len(texts))]
+        embedding_list = [
+            Embedding(
+                id=uuid.uuid7(),
+                document_id=text_list[index].get('document_id'),
+                paragraph_id=text_list[index].get('paragraph_id'),
+                knowledge_id=text_list[index].get('knowledge_id'),
+                is_active=text_list[index].get('is_active', True),
+                source_id=text_list[index].get('source_id'),
+                source_type=text_list[index].get('source_type'),
+                embedding=[float(x) for x in embeddings[index]],
+                search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
+            ) for index in range(0, len(texts))]
         if not is_the_task_interrupted():
             QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True
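
Note on the float conversion: embedding backends frequently return numpy scalar types (for example numpy.float32) rather than built-in Python floats, and a plausible reading of this commit is that such values failed to serialize when the Embedding rows were written. The change coerces every vector component with float(x) before constructing the model instances. The following is a minimal, self-contained sketch of that pattern only; the embed_query stub is hypothetical and stands in for a real Embeddings implementation, it is not part of the commit.

import numpy as np

def embed_query(text: str) -> list:
    # Hypothetical stand-in: many real backends hand back numpy.float32 values.
    return list(np.asarray([0.1, 0.2, 0.3], dtype=np.float32))

raw = embed_query("example paragraph")
print(type(raw[0]))                      # <class 'numpy.float32'>

# The pattern applied in _save / _batch_save: coerce each component to a
# built-in float so the vector serializes cleanly when the row is persisted.
text_embedding = [float(x) for x in raw]
print(type(text_embedding[0]))           # <class 'float'>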