maxkb/apps/knowledge/vector/pg_vector.py
朱潮 9d65e181eb
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run
add log
2025-08-27 01:41:47 +08:00

344 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file pg_vector.py
@date: 2023/10/19 15:28
@desc:
"""
import json
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import uuid_utils.compat as uuid
from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.utils.common import get_file_content
from common.utils.ts_vecto_util import to_ts_vector, to_query
from knowledge.models import Embedding, SearchMode, SourceType
from knowledge.vector.base_vector import BaseVectorStore
from maxkb.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
    """pgvector-backed implementation of BaseVectorStore.

    Rows live in the ``embedding`` table (the ``Embedding`` model); similarity /
    keyword / blended retrieval is delegated to the ``ISearch`` strategies
    registered in the module-level ``search_handle_list``.
    """

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete every embedding whose source id is in *source_ids* with the given type."""
        if len(source_ids) == 0:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Apply the field/value pairs in *instance* to all rows matching *source_ids*."""
        QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

    def vector_is_create(self) -> bool:
        # The embedding table is created at project startup, so it always exists here.
        return True

    def vector_create(self):
        return True

    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """Embed *text* with the given embedding model and persist it as a single row."""
        text_embedding = [float(x) for x in embedding.embed_query(text)]
        # Distinct local name so the `embedding` (Embeddings) parameter is not shadowed.
        # NOTE(review): _batch_save stores search_vector via SearchVector(Value(...)) while
        # this path assigns the raw ts-vector string — confirm both end up identical in the DB.
        record = Embedding(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            document_id=document_id,
            is_active=is_active,
            paragraph_id=paragraph_id,
            source_id=source_id,
            embedding=text_embedding,
            source_type=source_type,
            search_vector=to_ts_vector(text)
        )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Embed all rows of *text_list* and bulk-insert them, unless the task was interrupted."""
        from common.utils.logger import maxkb_logger
        texts = [row.get('text') for row in text_list]
        maxkb_logger.info(f"PGVector batch_save: Processing {len(texts)} texts")
        # Log details of first few items for debugging
        for i, item in enumerate(text_list[:3]):
            maxkb_logger.debug(f"Item {i}: document_id={item.get('document_id')}, "
                               f"paragraph_id={item.get('paragraph_id')}, "
                               f"is_active={item.get('is_active', True)}, "
                               f"text_preview='{item.get('text', '')[:50]}...'")
        embeddings = embedding.embed_documents(texts)
        embedding_list = [
            Embedding(
                id=uuid.uuid7(),
                document_id=text_list[index].get('document_id'),
                paragraph_id=text_list[index].get('paragraph_id'),
                knowledge_id=text_list[index].get('knowledge_id'),
                is_active=text_list[index].get('is_active', True),
                source_id=text_list[index].get('source_id'),
                source_type=text_list[index].get('source_type'),
                embedding=[float(x) for x in embeddings[index]],
                search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))
            ) for index in range(0, len(texts))]
        maxkb_logger.info(f"PGVector batch_save: Created {len(embedding_list)} embedding objects")
        # Check interruption once, right before the write, so a cancelled task inserts nothing.
        if not is_the_task_interrupted():
            if len(embedding_list) > 0:
                QuerySet(Embedding).bulk_create(embedding_list)
                maxkb_logger.info(f"PGVector batch_save: Successfully saved {len(embedding_list)} embeddings to database")
            else:
                maxkb_logger.warning("PGVector batch_save: No embeddings to save")
        else:
            maxkb_logger.warning("PGVector batch_save: Task interrupted, embeddings not saved")
        return True

    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Embed *query_text* and dispatch to the strategy matching *search_mode*.

        Returns [] when there is nothing to search or when no registered strategy
        supports the mode (previously this fell off the loop and returned None).
        """
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity,
                                            search_mode)
        return []

    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """Search with a pre-computed *query_embedding*, honouring the exclusion lists and *is_active*."""
        from common.utils.logger import maxkb_logger
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            maxkb_logger.warning("Vector query: knowledge_id_list is empty")
            return []
        maxkb_logger.info(f"Vector query starting: query_text='{query_text[:50]}...', knowledge_ids={knowledge_id_list}, "
                          f"is_active={is_active}, top_n={top_n}, similarity={similarity}, search_mode={search_mode.value}")
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
            query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
        # The previous version issued three COUNT(*) queries here purely for logging;
        # removed to avoid extra database round-trips on every search.
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                maxkb_logger.info(f"Using search handler: {search_handle.__class__.__name__}")
                results = search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
                maxkb_logger.info(f"Search results: {len(results)} items found")
                if len(results) > 0:
                    maxkb_logger.info(f"Top result similarity: {results[0].get('similarity', 'N/A')}")
                return results
        maxkb_logger.warning("No suitable search handler found")
        return []

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Update the row(s) with this *source_id*."""
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Update the row(s) belonging to this paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict):
        """Update the rows belonging to any of the given paragraphs.

        NOTE(review): the parameter is a list despite its singular name; renaming it
        would break keyword callers, so the name is kept.
        """
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        """Delete all embeddings for one knowledge base."""
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Delete all embeddings for several knowledge bases."""
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        """Delete all embeddings for one document."""
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Delete all embeddings for the listed documents; no-op on an empty list."""
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete the embedding(s) for one source id / type pair."""
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete the embedding(s) for one paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Delete the embeddings for the listed paragraphs."""
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
class ISearch(ABC):
    """Strategy interface: one implementation per SearchMode."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        """Return True when this strategy implements *search_mode*."""

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        """Execute the search over *query_set* and return the matching rows."""
class EmbeddingSearch(ISearch):
    """Vector-similarity search (pgvector ``<=>`` cosine-distance operator)."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the thresholded embedding search and return the matching rows."""
        import logging
        from common.utils.logger import maxkb_logger
        maxkb_logger.info(f"EmbeddingSearch: Executing search with similarity threshold={similarity}, top_n={top_number}")
        # The unthresholded diagnostic pass doubles the database work, so it now
        # only runs when debug logging is enabled instead of on every search.
        if maxkb_logger.isEnabledFor(logging.DEBUG):
            self._log_unthresholded_diagnostics(query_set, query_embedding, similarity)
        # Normal query (with similarity threshold).
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
                                                                            'embedding_search.sql')),
                                                           with_table_name=True)
        maxkb_logger.debug(f"EmbeddingSearch SQL params count: {len(exec_params)}")
        embedding_model = select_list(exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            *exec_params,
            similarity,
            top_number
        ])
        maxkb_logger.info(f"EmbeddingSearch results: {len(embedding_model)} embeddings found (with threshold)")
        if len(embedding_model) > 0:
            similarities = [e.get('similarity', 0) for e in embedding_model[:3]]
            maxkb_logger.info(f"Top 3 similarities above threshold: {similarities}")
        return embedding_model

    @staticmethod
    def _log_unthresholded_diagnostics(query_set, query_embedding, similarity):
        """Debug helper: rerun the similarity ranking WITHOUT a threshold and log the top hits.

        Shows what the best achievable similarities are when the thresholded query
        returns nothing, including a preview of the top paragraph's content.
        """
        from common.utils.logger import maxkb_logger
        # Same ranking as embedding_search.sql but without the similarity cut-off.
        test_sql = """
        SELECT
        paragraph_id,
        comprehensive_score,
        comprehensive_score as similarity
        FROM
        (
        SELECT DISTINCT ON
        ("paragraph_id") ( 1 - distince ),* ,(1 - distince) AS comprehensive_score
        FROM
        ( SELECT *, ( embedding.embedding::vector(%s) <=> %s ) AS distince FROM embedding ${embedding_query} ORDER BY distince) TEMP
        ORDER BY
        paragraph_id,
        distince
        ) DISTINCT_TEMP
        ORDER BY comprehensive_score DESC
        LIMIT %s
        """
        test_exec_sql, test_exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                                     select_string=test_sql,
                                                                     with_table_name=True)
        test_results = select_list(test_exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            *test_exec_params,
            10  # top 10 rows are enough for diagnostics
        ])
        if len(test_results) > 0:
            test_similarities = [r.get('similarity', 0) for r in test_results[:5]]
            maxkb_logger.info(f"Actual similarities (no threshold): {test_similarities}")
            maxkb_logger.info(f"Highest similarity: {test_similarities[0] if test_similarities else 0}, Required threshold: {similarity}")
            if test_similarities[0] < similarity:
                maxkb_logger.warning(f"Best similarity {test_similarities[0]} is below threshold {similarity}")
            # Preview the top-ranked paragraph to judge relevance by eye.
            paragraph_id = test_results[0].get('paragraph_id')
            from knowledge.models import Paragraph
            para = QuerySet(Paragraph).filter(id=paragraph_id).first()
            if para:
                maxkb_logger.info(f"Top paragraph content preview (first 200 chars): {para.content[:200]}...")
                maxkb_logger.info(f"Paragraph title: {para.title}, length: {len(para.content)}")
        else:
            maxkb_logger.warning("No embeddings found even without similarity threshold")

    def support(self, search_mode: SearchMode):
        """Only handles the pure embedding search mode."""
        return search_mode.value == SearchMode.embedding.value
class KeywordsSearch(ISearch):
    """Full-text keyword search backed by the ts-vector column."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Render the keyword SQL template for *query_set* and return the matching rows."""
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'keywords_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'keywords_query': query_set},
            select_string=sql_template,
            with_table_name=True)
        # Parameter order must match the template: ts-query, filters, threshold, limit.
        sql_params = [to_query(query_text), *exec_params, similarity, top_number]
        return select_list(exec_sql, sql_params)

    def support(self, search_mode: SearchMode):
        """Only handles the keywords search mode."""
        return search_mode.value == SearchMode.keywords.value
class BlendSearch(ISearch):
    """Hybrid search combining vector similarity with keyword matching."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Render the blend SQL template for *query_set* and return the matching rows."""
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'blend_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'embedding_query': query_set},
            select_string=sql_template,
            with_table_name=True)
        # Parameter order must match the template:
        # vector dimension, vector literal, ts-query, filters, threshold, limit.
        sql_params = [
            len(query_embedding),
            json.dumps(query_embedding),
            to_query(query_text),
            *exec_params,
            similarity,
            top_number,
        ]
        return select_list(exec_sql, sql_params)

    def support(self, search_mode: SearchMode):
        """Only handles the blend search mode."""
        return search_mode.value == SearchMode.blend.value
# Registered search strategies; dispatchers pick the first one whose support() matches.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]