# coding=utf-8
"""
@project: maxkb
@Author: 虎
@file: pg_vector.py
@date: 2023/10/19 15:28
@desc: PostgreSQL (pgvector) backed vector store implementation
"""
import json
|
||
import os
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, List
|
||
|
||
import uuid_utils.compat as uuid
|
||
from django.contrib.postgres.search import SearchVector
|
||
from django.db.models import QuerySet, Value
|
||
from langchain_core.embeddings import Embeddings
|
||
|
||
from common.db.search import generate_sql_by_query_dict
|
||
from common.db.sql_execute import select_list
|
||
from common.utils.common import get_file_content
|
||
from common.utils.ts_vecto_util import to_ts_vector, to_query
|
||
from knowledge.models import Embedding, SearchMode, SourceType
|
||
from knowledge.vector.base_vector import BaseVectorStore
|
||
from maxkb.conf import PROJECT_DIR
|
||
|
||
|
||
class PGVector(BaseVectorStore):
    """Vector store backed by PostgreSQL + pgvector.

    Rows live in the ``embedding`` table (the ``Embedding`` model); similarity
    queries are delegated to the first handler in ``search_handle_list`` that
    supports the requested ``SearchMode``.
    """

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete all embeddings for the given source ids of one source type."""
        # No-op on an empty id list to avoid a pointless DELETE round trip.
        if len(source_ids) == 0:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Bulk-update embedding rows matched by source id."""
        QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

    def vector_is_create(self) -> bool:
        # The embedding table is created at project startup, so it always exists.
        return True

    def vector_create(self):
        return True

    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """Embed a single text and persist it as one Embedding row."""
        text_embedding = [float(x) for x in embedding.embed_query(text)]
        # Distinct name for the ORM instance so it does not shadow the
        # `embedding` (Embeddings client) parameter.
        record = Embedding(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            document_id=document_id,
            is_active=is_active,
            paragraph_id=paragraph_id,
            source_id=source_id,
            embedding=text_embedding,
            source_type=source_type,
            search_vector=to_ts_vector(text)
        )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Embed a batch of texts and bulk-insert the resulting rows.

        ``is_the_task_interrupted`` is a zero-argument callable checked right
        before the write so a cancelled task does not persist stale vectors.
        """
        from common.utils.logger import maxkb_logger
        texts = [row.get('text') for row in text_list]
        maxkb_logger.info(f"PGVector batch_save: Processing {len(texts)} texts")
        embeddings = embedding.embed_documents(texts)
        embedding_list = [
            Embedding(
                id=uuid.uuid7(),
                document_id=row.get('document_id'),
                paragraph_id=row.get('paragraph_id'),
                knowledge_id=row.get('knowledge_id'),
                is_active=row.get('is_active', True),
                source_id=row.get('source_id'),
                source_type=row.get('source_type'),
                embedding=[float(x) for x in vector],
                search_vector=SearchVector(Value(to_ts_vector(row['text'])))
            ) for row, vector in zip(text_list, embeddings)]
        if is_the_task_interrupted():
            maxkb_logger.warning("PGVector batch_save: Task interrupted, embeddings not saved")
            return True
        if len(embedding_list) > 0:
            QuerySet(Embedding).bulk_create(embedding_list)
            maxkb_logger.info(f"PGVector batch_save: Successfully saved {len(embedding_list)} embeddings to database")
        else:
            maxkb_logger.warning("PGVector batch_save: No embeddings to save")
        return True

    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Embed ``query_text`` and run a hit test over the given knowledge bases."""
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
        # Fix: previously fell through and returned None when no handler
        # supported the requested mode; keep the return type a list.
        return []

    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """Search with an already-computed query embedding.

        Returns at most ``top_n`` rows above the ``similarity`` threshold, or
        an empty list when no knowledge base or search handler matches.
        """
        from common.utils.logger import maxkb_logger
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            maxkb_logger.warning("Vector query: knowledge_id_list is empty")
            return []
        # NOTE(review): the previous version issued three extra COUNT queries
        # purely for logging and applied an always-empty exclude dict; both
        # were removed to keep a single round trip per search.
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
            query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
        maxkb_logger.warning("No suitable search handler found")
        return []

    def update_by_source_id(self, source_id: str, instance: Dict):
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_id: List[str], instance: Dict):
        # Annotation fixed: this receives a list of ids (used with __in below).
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
||
|
||
class ISearch(ABC):
    """Strategy interface for one retrieval mode over the embedding table."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        """Return a truthy value when this handler implements ``search_mode``."""
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        """Run the search over ``query_set`` and return the matching rows.

        ``query_embedding`` is the pre-computed vector for ``query_text``;
        ``similarity`` is the minimum score and ``top_number`` the row limit.
        """
        pass
|
||
|
||
|
||
class EmbeddingSearch(ISearch):
    """Pure vector-similarity search using the ``embedding_search.sql`` template."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Return up to ``top_number`` rows whose cosine similarity to
        ``query_embedding`` is at least ``similarity``.

        NOTE(review): the previous version executed a full second,
        threshold-free ANN query (with a hand-inlined SQL string containing
        the typo ``distince``) plus a Paragraph fetch on every call, purely
        for debug logging — doubling database work. That scaffolding was
        removed; only the real templated query remains.
        """
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
                                                                            'embedding_search.sql')),
                                                           with_table_name=True)
        # Parameter order matches the SQL template: vector dimension, query
        # vector, filter params, similarity threshold, then row limit.
        embedding_model = select_list(exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            *exec_params,
            similarity,
            top_number
        ])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.embedding.value
|
||
|
||
|
||
class KeywordsSearch(ISearch):
    """Full-text keyword search over the ts_vector column."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the ``keywords_search.sql`` template against ``query_set``."""
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'keywords_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'keywords_query': query_set},
            select_string=sql_template,
            with_table_name=True)
        # ts_query string first, then the filter params, threshold and limit.
        params = [
            to_query(query_text),
            *exec_params,
            similarity,
            top_number
        ]
        return select_list(exec_sql, params)

    def support(self, search_mode: SearchMode):
        return SearchMode.keywords.value == search_mode.value
|
||
|
||
|
||
class BlendSearch(ISearch):
    """Hybrid search combining vector similarity and keyword matching."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the ``blend_search.sql`` template against ``query_set``."""
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'blend_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'embedding_query': query_set},
            select_string=sql_template,
            with_table_name=True)
        # Vector dimension and JSON-encoded vector precede the ts_query term.
        params = [
            len(query_embedding),
            json.dumps(query_embedding),
            to_query(query_text),
            *exec_params,
            similarity,
            top_number
        ]
        return select_list(exec_sql, params)

    def support(self, search_mode: SearchMode):
        return SearchMode.blend.value == search_mode.value
|
||
|
||
|
||
# Handler chain consulted in order by PGVector.hit_test / PGVector.query;
# the first handler whose support() matches the SearchMode runs the query.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]
|