# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: pg_vector.py
@date:2023/10/19 15:28
@desc:
"""
import json
import os
from abc import ABC, abstractmethod
from typing import Dict, List

import uuid_utils.compat as uuid
from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings

from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.utils.common import get_file_content
from common.utils.ts_vecto_util import to_ts_vector, to_query
from knowledge.models import Embedding, SearchMode, SourceType
from knowledge.vector.base_vector import BaseVectorStore
from maxkb.conf import PROJECT_DIR


class PGVector(BaseVectorStore):
    """Vector store backed by the pgvector-enabled ``Embedding`` table.

    Persists text embeddings plus a full-text ``search_vector`` column, and
    dispatches similarity queries to one of the ``ISearch`` handlers
    (embedding / keywords / blend) declared at the bottom of this module.
    """

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete all embeddings matching any of ``source_ids`` and ``source_type``.

        No-op (returns ``None``) when ``source_ids`` is empty, avoiding a
        pointless ``IN ()`` query.
        """
        if not source_ids:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Bulk-update embedding rows whose source_id is in ``source_ids``.

        ``instance`` maps Embedding field names to new values.
        """
        QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

    def vector_is_create(self) -> bool:
        # The vector table is created at project startup; nothing to check here.
        return True

    def vector_create(self):
        # Table already exists (see vector_is_create); creation is a no-op.
        return True

    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str, is_active: bool, embedding: Embeddings):
        """Embed a single ``text`` and persist one Embedding row.

        :param embedding: the embedding model used to vectorize ``text``.
        :return: always ``True``.
        """
        text_embedding = [float(x) for x in embedding.embed_query(text)]
        # NOTE(review): _batch_save wraps the ts-vector in SearchVector(Value(...))
        # while this path stores the raw string — confirm both produce the same
        # column value before unifying.
        record = Embedding(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            document_id=document_id,
            is_active=is_active,
            paragraph_id=paragraph_id,
            source_id=source_id,
            embedding=text_embedding,
            source_type=source_type,
            search_vector=to_ts_vector(text)
        )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Embed and bulk-insert a batch of rows described by ``text_list``.

        Each dict carries ``text`` plus the Embedding foreign keys
        (``knowledge_id``/``document_id``/``paragraph_id``/``source_id``/...).
        The insert is skipped entirely when ``is_the_task_interrupted()``
        reports a cancelled task.

        :return: always ``True``.
        """
        from common.utils.logger import maxkb_logger
        texts = [row.get('text') for row in text_list]
        maxkb_logger.info(f"PGVector batch_save: Processing {len(texts)} texts")
        # Log details of the first few items for debugging.
        for i, item in enumerate(text_list[:3]):
            maxkb_logger.debug(f"Item {i}: document_id={item.get('document_id')}, "
                               f"paragraph_id={item.get('paragraph_id')}, "
                               f"is_active={item.get('is_active', True)}, "
                               f"text_preview='{item.get('text', '')[:50]}...'")
        embeddings = embedding.embed_documents(texts)
        embedding_list = [
            Embedding(
                id=uuid.uuid7(),
                document_id=row.get('document_id'),
                paragraph_id=row.get('paragraph_id'),
                knowledge_id=row.get('knowledge_id'),
                is_active=row.get('is_active', True),
                source_id=row.get('source_id'),
                source_type=row.get('source_type'),
                embedding=[float(x) for x in vector],
                search_vector=SearchVector(Value(to_ts_vector(row['text'])))
            )
            for row, vector in zip(text_list, embeddings)]
        maxkb_logger.info(f"PGVector batch_save: Created {len(embedding_list)} embedding objects")
        if not is_the_task_interrupted():
            if len(embedding_list) > 0:
                QuerySet(Embedding).bulk_create(embedding_list)
                maxkb_logger.info(
                    f"PGVector batch_save: Successfully saved {len(embedding_list)} embeddings to database")
            else:
                maxkb_logger.warning("PGVector batch_save: No embeddings to save")
        else:
            maxkb_logger.warning("PGVector batch_save: Task interrupted, embeddings not saved")
        return True

    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
                 top_number: int, similarity: float, search_mode: SearchMode, embedding: Embeddings):
        """Run a similarity search over active embeddings of the given knowledge bases.

        :return: handler results, or ``[]`` when the knowledge list is empty or
                 no handler supports ``search_mode``.
        """
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity,
                                            search_mode)
        # BUGFIX: previously fell through and returned None when no handler
        # supported the mode; return [] for consistency with query().
        return []

    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str], exclude_paragraph_list: list[str], is_active: bool,
              top_n: int, similarity: float, search_mode: SearchMode):
        """Search with a precomputed ``query_embedding`` (hit_test embeds internally).

        Filters by knowledge ids and ``is_active``, applies document/paragraph
        exclusions, then delegates to the first handler supporting
        ``search_mode``. Heavy info-level logging (including extra count()
        queries) is retained for diagnostics.

        :return: handler results, or ``[]`` when no knowledge ids / no handler.
        """
        from common.utils.logger import maxkb_logger
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            maxkb_logger.warning("Vector query: knowledge_id_list is empty")
            return []
        maxkb_logger.info(
            f"Vector query starting: query_text='{query_text[:50]}...', knowledge_ids={knowledge_id_list}, "
            f"is_active={is_active}, top_n={top_n}, similarity={similarity}, search_mode={search_mode.value}")
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
        initial_count = query_set.count()
        maxkb_logger.info(f"Initial embedding count: {initial_count}")
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
            maxkb_logger.info(f"After excluding documents: {query_set.count()} embeddings")
        if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
            query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
            maxkb_logger.info(f"After excluding paragraphs: {query_set.count()} embeddings")
        # (removed dead code: an always-empty exclude_dict was applied here)
        final_count = query_set.count()
        maxkb_logger.info(f"Final embedding count before search: {final_count}")
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                maxkb_logger.info(f"Using search handler: {search_handle.__class__.__name__}")
                results = search_handle.handle(query_set, query_text, query_embedding, top_n, similarity,
                                               search_mode)
                maxkb_logger.info(f"Search results: {len(results)} items found")
                if len(results) > 0:
                    maxkb_logger.info(f"Top result similarity: {results[0].get('similarity', 'N/A')}")
                return results
        maxkb_logger.warning("No suitable search handler found")
        return []

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Update all embedding rows with the given ``source_id``."""
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Update all embedding rows belonging to one paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_id: List[str], instance: Dict):
        """Update embedding rows for several paragraphs.

        NOTE: the parameter is a list despite its singular name; the name is
        kept for backward compatibility with keyword callers (annotation was
        previously, incorrectly, ``str``).
        """
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        """Delete every embedding of one knowledge base."""
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Delete every embedding of the listed knowledge bases."""
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        """Delete every embedding of one document. Returns ``True``."""
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Delete embeddings of the listed documents.

        Returns ``True`` on empty input, otherwise the Django delete() result.
        """
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete embeddings matching one source id and type. Returns ``True``."""
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete every embedding of one paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Delete every embedding of the listed paragraphs."""
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()


class ISearch(ABC):
    """Strategy interface: one implementation per ``SearchMode``."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        """Return truthy when this handler implements ``search_mode``."""
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        """Execute the search over ``query_set`` and return result dicts."""
        pass


class EmbeddingSearch(ISearch):
    """Pure vector-similarity search (pgvector cosine distance ``<=>``)."""

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        from common.utils.logger import maxkb_logger
        maxkb_logger.info(f"EmbeddingSearch: Executing search with similarity threshold={similarity}, "
                          f"top_n={top_number}")
        # Diagnostic pass: run the query WITHOUT a similarity threshold to see
        # the actual similarity scores. NOTE(review): this costs one extra DB
        # round-trip per search and should be removed or gated once the
        # threshold tuning is done. ("distince" typo is intentional — it is the
        # column alias used consistently inside this SQL.)
        test_sql = """ SELECT paragraph_id, comprehensive_score, comprehensive_score as similarity FROM ( SELECT DISTINCT ON ("paragraph_id") ( 1 - distince ),* ,(1 - distince) AS comprehensive_score FROM ( SELECT *, ( embedding.embedding::vector(%s) <=> %s ) AS distince FROM embedding ${embedding_query} ORDER BY distince) TEMP ORDER BY paragraph_id, distince ) DISTINCT_TEMP ORDER BY comprehensive_score DESC LIMIT %s """
        test_exec_sql, test_exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                                     select_string=test_sql,
                                                                     with_table_name=True)
        # Execute the unthresholded diagnostic query.
        test_results = select_list(test_exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            *test_exec_params,
            10  # fetch the top 10 results
        ])
        if len(test_results) > 0:
            test_similarities = [r.get('similarity', 0) for r in test_results[:5]]
            maxkb_logger.info(f"Actual similarities (no threshold): {test_similarities}")
            maxkb_logger.info(
                f"Highest similarity: {test_similarities[0] if test_similarities else 0}, Required threshold: {similarity}")
            if test_similarities[0] < similarity:
                maxkb_logger.warning(f"Best similarity {test_similarities[0]} is below threshold {similarity}")
            # Peek at the top paragraph's content for context.
            if len(test_results) > 0:
                paragraph_id = test_results[0].get('paragraph_id')
                from knowledge.models import Paragraph
                para = QuerySet(Paragraph).filter(id=paragraph_id).first()
                if para:
                    maxkb_logger.info(
                        f"Top paragraph content preview (first 200 chars): {para.content[:200]}...")
                    maxkb_logger.info(f"Paragraph title: {para.title}, length: {len(para.content)}")
        else:
            maxkb_logger.warning("No embeddings found even without similarity threshold")
        # The real query, with the similarity threshold applied.
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "knowledge",
                                                                            'sql', 'embedding_search.sql')),
                                                           with_table_name=True)
        maxkb_logger.debug(f"EmbeddingSearch SQL params count: {len(exec_params)}")
        embedding_model = select_list(exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            *exec_params,
            similarity,
            top_number
        ])
        maxkb_logger.info(
            f"EmbeddingSearch results: {len(embedding_model)} embeddings found (with threshold)")
        if len(embedding_model) > 0:
            similarities = [e.get('similarity', 0) for e in embedding_model[:3]]
            maxkb_logger.info(f"Top 3 similarities above threshold: {similarities}")
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.embedding.value


class KeywordsSearch(ISearch):
    """Full-text (ts_vector) keyword search."""

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "knowledge",
                                                                            'sql', 'keywords_search.sql')),
                                                           with_table_name=True)
        embedding_model = select_list(exec_sql, [
            to_query(query_text),
            *exec_params,
            similarity,
            top_number
        ])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.keywords.value


class BlendSearch(ISearch):
    """Hybrid search combining vector similarity and keyword matching."""

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "knowledge",
                                                                            'sql', 'blend_search.sql')),
                                                           with_table_name=True)
        embedding_model = select_list(exec_sql, [
            len(query_embedding),
            json.dumps(query_embedding),
            to_query(query_text),
            *exec_params,
            similarity,
            top_number
        ])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.blend.value


# Handler registry consumed by PGVector.hit_test / PGVector.query.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]