refactor: improve embedding handling and ensure float conversion for embeddings
--bug=1058254 --user=刘瑞斌 [Knowledge Base] Enterprise Edition - generic knowledge base created by an LDAP user - segments cannot be added to a newly created blank document https://www.tapd.cn/62980211/s/1725378
This commit is contained in:
parent 90abe70e2e
commit 8f00184122
@@ -12,7 +12,6 @@ from abc import ABC, abstractmethod
 from typing import Dict, List
 
 import uuid_utils.compat as uuid
-from common.utils.ts_vecto_util import to_ts_vector, to_query
 from django.contrib.postgres.search import SearchVector
 from django.db.models import QuerySet, Value
 from langchain_core.embeddings import Embeddings
@@ -20,6 +19,7 @@ from langchain_core.embeddings import Embeddings
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.utils.common import get_file_content
+from common.utils.ts_vecto_util import to_ts_vector, to_query
 from knowledge.models import Embedding, SearchMode, SourceType
 from knowledge.vector.base_vector import BaseVectorStore
 from maxkb.conf import PROJECT_DIR
@@ -42,36 +42,40 @@ class PGVector(BaseVectorStore):
     def vector_create(self):
         return True
 
-    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
+    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
+              source_id: str,
               is_active: bool,
               embedding: Embeddings):
-        text_embedding = embedding.embed_query(text)
-        embedding = Embedding(id=uuid.uuid7(),
-                              knowledge_id=knowledge_id,
-                              document_id=document_id,
-                              is_active=is_active,
-                              paragraph_id=paragraph_id,
-                              source_id=source_id,
-                              embedding=text_embedding,
-                              source_type=source_type,
-                              search_vector=to_ts_vector(text))
+        text_embedding = [float(x) for x in embedding.embed_query(text)]
+        embedding = Embedding(
+            id=uuid.uuid7(),
+            knowledge_id=knowledge_id,
+            document_id=document_id,
+            is_active=is_active,
+            paragraph_id=paragraph_id,
+            source_id=source_id,
+            embedding=text_embedding,
+            source_type=source_type,
+            search_vector=to_ts_vector(text)
+        )
         embedding.save()
         return True
 
     def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
-        embedding_list = [Embedding(id=uuid.uuid7(),
-                                    document_id=text_list[index].get('document_id'),
-                                    paragraph_id=text_list[index].get('paragraph_id'),
-                                    knowledge_id=text_list[index].get('knowledge_id'),
-                                    is_active=text_list[index].get('is_active', True),
-                                    source_id=text_list[index].get('source_id'),
-                                    source_type=text_list[index].get('source_type'),
-                                    embedding=embeddings[index],
-                                    search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
-                          index in
-                          range(0, len(texts))]
+        embedding_list = [
+            Embedding(
+                id=uuid.uuid7(),
+                document_id=text_list[index].get('document_id'),
+                paragraph_id=text_list[index].get('paragraph_id'),
+                knowledge_id=text_list[index].get('knowledge_id'),
+                is_active=text_list[index].get('is_active', True),
+                source_id=text_list[index].get('source_id'),
+                source_type=text_list[index].get('source_type'),
+                embedding=[float(x) for x in embeddings[index]],
+                search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
+            ) for index in range(0, len(texts))]
         if not is_the_task_interrupted():
            QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True
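Note on the change itself: below is a minimal sketch (not part of this commit) of why casting embedding values to built-in Python floats can matter, assuming the embedding backend returns numpy scalar types such as numpy.float32 (an assumption; the diff does not show the model's output type). Plain floats serialize cleanly when the vector column is written through the Django ORM, whereas numpy scalars may not be adapted by the database driver.

# Hypothetical illustration, not part of the commit: cast numpy scalars to
# built-in floats before handing the vector to the ORM / DB driver.
import numpy as np

def to_plain_floats(vector):
    """Return the vector as a list of built-in Python floats."""
    return [float(x) for x in vector]

raw = np.asarray([0.12, -0.34, 0.56], dtype=np.float32)  # what embed_query() might return
safe = to_plain_floats(raw)

print(type(raw[0]))   # <class 'numpy.float32'>
print(type(safe[0]))  # <class 'float'>

The commit applies the same per-element cast in both _save and _batch_save, so single-document and batch inserts build their vectors the same way.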