diff --git a/apps/knowledge/vector/pg_vector.py b/apps/knowledge/vector/pg_vector.py index de912448..d8235d2c 100644 --- a/apps/knowledge/vector/pg_vector.py +++ b/apps/knowledge/vector/pg_vector.py @@ -12,7 +12,6 @@ from abc import ABC, abstractmethod from typing import Dict, List import uuid_utils.compat as uuid -from common.utils.ts_vecto_util import to_ts_vector, to_query from django.contrib.postgres.search import SearchVector from django.db.models import QuerySet, Value from langchain_core.embeddings import Embeddings @@ -20,6 +19,7 @@ from langchain_core.embeddings import Embeddings from common.db.search import generate_sql_by_query_dict from common.db.sql_execute import select_list from common.utils.common import get_file_content +from common.utils.ts_vecto_util import to_ts_vector, to_query from knowledge.models import Embedding, SearchMode, SourceType from knowledge.vector.base_vector import BaseVectorStore from maxkb.conf import PROJECT_DIR @@ -42,36 +42,40 @@ class PGVector(BaseVectorStore): def vector_create(self): return True - def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str, + def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, + source_id: str, is_active: bool, embedding: Embeddings): - text_embedding = embedding.embed_query(text) - embedding = Embedding(id=uuid.uuid7(), - knowledge_id=knowledge_id, - document_id=document_id, - is_active=is_active, - paragraph_id=paragraph_id, - source_id=source_id, - embedding=text_embedding, - source_type=source_type, - search_vector=to_ts_vector(text)) + text_embedding = [float(x) for x in embedding.embed_query(text)] + embedding = Embedding( + id=uuid.uuid7(), + knowledge_id=knowledge_id, + document_id=document_id, + is_active=is_active, + paragraph_id=paragraph_id, + source_id=source_id, + embedding=text_embedding, + source_type=source_type, + search_vector=to_ts_vector(text) + ) embedding.save() return True def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) - embedding_list = [Embedding(id=uuid.uuid7(), - document_id=text_list[index].get('document_id'), - paragraph_id=text_list[index].get('paragraph_id'), - knowledge_id=text_list[index].get('knowledge_id'), - is_active=text_list[index].get('is_active', True), - source_id=text_list[index].get('source_id'), - source_type=text_list[index].get('source_type'), - embedding=embeddings[index], - search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for - index in - range(0, len(texts))] + embedding_list = [ + Embedding( + id=uuid.uuid7(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + knowledge_id=text_list[index].get('knowledge_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + embedding=[float(x) for x in embeddings[index]], + search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text']))) + ) for index in range(0, len(texts))] if not is_the_task_interrupted(): QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True