diff --git a/embedding/embedding.py b/embedding/embedding.py
index 956048a..4ffaf2c 100644
--- a/embedding/embedding.py
+++ b/embedding/embedding.py
@@ -2,9 +2,11 @@ import pickle
 import re
 import numpy as np
 import os
-from typing import Optional
+from typing import Optional, List, Dict, Any
 import requests
 import asyncio
+import hashlib
+import json
 
 def encode_texts_via_api(texts, batch_size=32):
     """Encode texts via the embedding API."""
@@ -706,6 +708,249 @@ if __name__ == "__main__":
                        max_chunk_size=800,  # smaller chunk size
                        overlap=100)
 
+def cache_terms_embeddings(bot_id: str, terms_list: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Embed a list of terms and cache the result.
+
+    Args:
+        bot_id: Bot ID, used to build the cache key
+        terms_list: List of terms; each term may contain name, description, synonyms, etc.
+
+    Returns:
+        Dict: Dictionary containing the embedding data
+    """
+    if not terms_list:
+        return {}
+
+    cache_key = f"{bot_id}_terms"
+    cache_file = f"projects/cache/{cache_key}.pkl"
+
+    # Make sure the cache directory exists
+    os.makedirs("projects/cache", exist_ok=True)
+
+    # Check whether a valid cache already exists
+    if os.path.exists(cache_file):
+        try:
+            with open(cache_file, 'rb') as f:
+                cached_data = pickle.load(f)
+
+            # Verify that the cached data matches the current terms
+            current_hash = _generate_terms_hash(terms_list)
+            if cached_data.get('hash') == current_hash:
+                print(f"Using cached terms embeddings for {cache_key}")
+                return cached_data
+        except Exception as e:
+            print(f"Error loading cache: {e}")
+
+    # Prepare the texts to encode
+    term_texts = []
+    term_info = []
+
+    for term in terms_list:
+        # Build the full text of the term for embedding
+        term_text_parts = []
+
+        if 'name' in term and term['name']:
+            term_text_parts.append(f"Name: {term['name']}")
+
+        if 'description' in term and term['description']:
+            term_text_parts.append(f"Description: {term['description']}")
+
+        # Handle synonyms given either as a list or as a comma-separated string
+        synonyms = []
+        if 'synonyms' in term and term['synonyms']:
+            if isinstance(term['synonyms'], list):
+                synonyms = term['synonyms']
+            elif isinstance(term['synonyms'], str):
+                synonyms = [s.strip() for s in term['synonyms'].split(',') if s.strip()]
+
+        if synonyms:
+            term_text_parts.append(f"Synonyms: {', '.join(synonyms)}")
+
+        term_text = " | ".join(term_text_parts)
+        term_texts.append(term_text)
+
+        # Keep the original term information
+        term_info.append({
+            'name': term.get('name', ''),
+            'description': term.get('description', ''),
+            'synonyms': synonyms
+        })
+
+    # Generate the embeddings
+    try:
+        embeddings = encode_texts_via_api(term_texts, batch_size=16)
+
+        # Prepare the cache payload
+        cache_data = {
+            'hash': _generate_terms_hash(terms_list),
+            'term_info': term_info,
+            'embeddings': embeddings,
+            'texts': term_texts
+        }
+
+        # Write it to the cache file
+        with open(cache_file, 'wb') as f:
+            pickle.dump(cache_data, f)
+
+        print(f"Cached {len(term_texts)} terms embeddings to {cache_file}")
+        return cache_data
+
+    except Exception as e:
+        print(f"Error generating terms embeddings: {e}")
+        return {}
+
+
+def search_similar_terms(query_text: str, cached_terms_data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    """
+    Search the cached terms for those most similar to the query text.
+
+    Args:
+        query_text: Query text
+        cached_terms_data: Cached terms data
+
+    Returns:
+        List[Dict]: Matching terms, sorted by similarity in descending order
+    """
+    if not cached_terms_data or not query_text or 'embeddings' not in cached_terms_data:
+        return []
+
+    try:
+        # Embed the query text
+        query_embedding = encode_texts_via_api([query_text], batch_size=1)
+        if len(query_embedding) == 0:
+            return []
+
+        query_vector = query_embedding[0]
+        term_embeddings = cached_terms_data['embeddings']
+        term_info = cached_terms_data['term_info']
+
+        # Debug output
+        print(f"DEBUG: Query text: '{query_text}'")
+        print(f"DEBUG: Query vector shape: {query_vector.shape}, norm: {np.linalg.norm(query_vector)}")
+
+        # Compute cosine similarity
+        similarities = _cosine_similarity(query_vector, term_embeddings)
+
+        print(f"DEBUG: Similarities: {similarities}")
+        print(f"DEBUG: Max similarity: {np.max(similarities):.3f}, Mean similarity: {np.mean(similarities):.3f}")
+
+        # Collect the similarity for every term
+        matches = []
+        for i, similarity in enumerate(similarities):
+            match = {
+                'term_info': term_info[i],
+                'similarity': float(similarity),
+                'index': i
+            }
+            matches.append(match)
+
+        # Sort by similarity in descending order
+        matches.sort(key=lambda x: x['similarity'], reverse=True)
+
+        # Return only the top-5 results
+        return matches[:5]
+
+    except Exception as e:
+        print(f"Error in similarity search: {e}")
+        return []
+
+
+def format_terms_analysis(similar_terms: List[Dict[str, Any]]) -> str:
+    """
+    Format the similar terms into the expected output layout.
+
+    Args:
+        similar_terms: List of similar terms
+
+    Returns:
+        str: Formatted terms analysis
+    """
+    if not similar_terms:
+        return ""
+
+    formatted_terms = []
+
+    for i, match in enumerate(similar_terms, 1):
+        term_info = match['term_info']
+        similarity = match['similarity']
+
+        name = term_info.get('name', '')
+        description = term_info.get('description', '')
+        synonyms = term_info.get('synonyms', [])
+
+        # Format the synonyms
+        synonyms_str = ', '.join(synonyms) if synonyms else 'N/A'
+
+        formatted_term = f"{i}) Name: {name}, Description: {description}, Synonyms: {synonyms_str} (Similarity: {similarity:.3f})"
+        formatted_terms.append(formatted_term)
+
+    return "\n".join(formatted_terms)
+
+
+def _generate_terms_hash(terms_list: List[Dict[str, Any]]) -> str:
+    """Generate a hash of the terms list for cache validation."""
+    # Serialize the terms list into a canonical string
+    terms_str = json.dumps(terms_list, sort_keys=True, ensure_ascii=False)
+    return hashlib.md5(terms_str.encode('utf-8')).hexdigest()
+
+
+def _cosine_similarity(query_vector: np.ndarray, term_embeddings: np.ndarray) -> np.ndarray:
+    """
+    Compute the cosine similarity between the query vector and all term embeddings.
+    Follows the approach used in semantic_search_server.py; the norms are computed
+    explicitly, so the vectors do not need to be pre-normalized.
+
+    Args:
+        query_vector: Query vector (shape: [embedding_dim])
+        term_embeddings: Term embeddings matrix (shape: [n_terms, embedding_dim])
+
+    Returns:
+        np.ndarray: Similarity array (shape: [n_terms])
+    """
+    if len(term_embeddings.shape) > 1:
+        cos_scores = np.dot(term_embeddings, query_vector) / (
+            np.linalg.norm(term_embeddings, axis=1) * np.linalg.norm(query_vector) + 1e-8
+        )
+    else:
+        # A single embedding stored as a 1-D vector: return a one-element array
+        score = np.dot(term_embeddings, query_vector) / (
+            np.linalg.norm(term_embeddings) * np.linalg.norm(query_vector) + 1e-8
+        )
+        cos_scores = np.array([score])
+
+    return cos_scores
+
+
+def process_terms_with_embedding(terms_list: List[Dict[str, Any]], bot_id: str, query_text: str) -> str:
+    """
+    Full terms pipeline: cache the embeddings, run the similarity search, format the output.
+
+    Args:
+        terms_list: List of terms
+        bot_id: Bot ID
+        query_text: User query text
+
+    Returns:
+        str: Formatted terms analysis
+    """
+    if not terms_list or not query_text:
+        return ""
+
+    # 1. Cache the embeddings of the terms
+    cached_data = cache_terms_embeddings(bot_id, terms_list)
+
+    if not cached_data:
+        return ""
+
+    # 2. Search for similar terms (top 5)
+    similar_terms = search_similar_terms(query_text, cached_data)
+
+    # 3. Format the output
+    if similar_terms:
+        return format_terms_analysis(similar_terms)
+    else:
+        # No similar terms found: return an empty string and let the caller decide how to handle it
+        return ""
+
+
 # Other example calls (commented out):
 # split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
 # embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # Uncomment to run
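For reviewers, a minimal usage sketch of how the new helpers in embedding/embedding.py are expected to chain together. It assumes the embedding service behind encode_texts_via_api is reachable and that projects/cache/ is writable; the bot_id, terms, and query values below are invented for illustration.

from embedding.embedding import process_terms_with_embedding

# Hypothetical terms, showing synonyms both as a list and as a comma-separated string.
terms = [
    {"name": "SLA", "description": "Service-level agreement", "synonyms": ["service level agreement"]},
    {"name": "RMA", "description": "Return merchandise authorization", "synonyms": "RMA request, return authorization"},
]

# Embeds and caches the terms under projects/cache/demo_bot_terms.pkl, then runs the cosine search.
analysis = process_terms_with_embedding(terms, bot_id="demo_bot", query_text="How do I start an RMA?")
print(analysis)
# Expected shape of the output (one line per match, top 5 by cosine similarity), e.g.:
# 1) Name: RMA, Description: Return merchandise authorization, Synonyms: RMA request, return authorization (Similarity: 0.8xx)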
diff --git a/prompt/guideline_prompt.md b/prompt/guideline_prompt.md
index ac34c2d..32f5f54 100644
--- a/prompt/guideline_prompt.md
+++ b/prompt/guideline_prompt.md
@@ -137,6 +137,8 @@ Examples of Guideline Match Evaluations:
 }
 ```
 
+Terms:
+{terms}
 
 Chat History:
 {chat_history}
diff --git a/routes/chat.py b/routes/chat.py
index 8019d6f..e6f81c0 100644
--- a/routes/chat.py
+++ b/routes/chat.py
@@ -12,7 +12,7 @@ from utils import (
 from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
-    process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
+    process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token,
     fetch_bot_config, call_guideline_llm, _get_optimal_batch_size, process_guideline_batch, get_content_from_messages
 )
@@ -26,6 +26,34 @@ agent_manager = init_global_sharded_agent_manager(
 )
 
+
+def get_user_last_message_content(messages: list) -> str:
+    """Return the content of the last message if it was sent by the user, otherwise an empty string."""
+    if not messages:
+        return ""
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        return last_message["content"]
+    return ""
+
+def append_user_last_message(messages: list, content: str) -> list:
+    """Append content to the last message if it was sent by the user.
+
+    Args:
+        messages: Message list
+        content: Content to append
+
+    Returns:
+        list: The (possibly updated) message list
+    """
+    if not messages:
+        return messages
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        messages[-1]['content'] += content
+    return messages
+
+
 async def generate_stream_response(agent, messages, thought_list, tool_response: bool, model: str):
     """Generate the streaming response."""
     accumulated_content = ""
@@ -134,15 +162,41 @@ async def create_agent_and_generate_response(
     if generate_cfg is None:
         generate_cfg = {}
 
-    # 1. Extract the guideline content from the system_prompt
-    system_prompt, guidelines_text = extract_guidelines_from_system_prompt(system_prompt)
-    print(f"guidelines_text: {guidelines_text}")
+    # 1. Extract the guideline and terms content from the system_prompt
+    system_prompt, guidelines_list, terms_list = extract_block_from_system_prompt(system_prompt)
 
-    # 2. If there is guideline content, process it concurrently
+    # 2. If there are terms, embed them first. The embeddings need to be cached (a temp/pickle file
+    #    keyed by {bot_id}_terms is enough); see @embedding/embedding.py for the implementation. With
+    #    the embeddings, run a similarity search (cosine similarity for now), keep matches above the
+    #    0.7 similarity threshold, and reassemble them into terms_analysis in the format:
+    #    1) Name: term_name1, Description: desc, Synonyms: syn1, syn2
+    terms_analysis = ""
+    if terms_list:
+        print(f"terms_list: {terms_list}")
+        # Pull the user's query text out of the messages for the similarity search
+        query_text = get_user_last_message_content(messages)
+        # Process the terms with embeddings
+        try:
+            from embedding.embedding import process_terms_with_embedding
+            terms_analysis = process_terms_with_embedding(terms_list, bot_id, query_text)
+            if terms_analysis:
+                # Also append the terms analysis to the message
+                messages = append_user_last_message(messages, f"\n\nRelevant Terms:\n{terms_analysis}")
+                print(f"Generated terms analysis: {terms_analysis[:200]}...")  # Print only the first 200 characters
+        except Exception as e:
+            print(f"Error processing terms with embedding: {e}")
+            terms_analysis = ""
+    else:
+        # When terms_list is empty, delete the corresponding pkl cache file
+        try:
+            import os
+            cache_file = f"projects/cache/{bot_id}_terms.pkl"
+            if os.path.exists(cache_file):
+                os.remove(cache_file)
+                print(f"Removed empty terms cache file: {cache_file}")
+        except Exception as e:
+            print(f"Error removing terms cache file: {e}")
+
+    # 3. If there is guideline content, process it concurrently
     guideline_analysis = ""
-    if guidelines_text:
-        # Split the guidelines on newlines
-        guidelines_list = [g.strip() for g in guidelines_text.split('\n') if g.strip()]
+    if guidelines_list:
+        print(f"guidelines_list: {guidelines_list}")
 
         guidelines_count = len(guidelines_list)
         if guidelines_count > 0:
@@ -152,11 +206,16 @@
             # Work out how many guidelines each batch should contain
             guidelines_per_batch = max(1, guidelines_count // batch_count)
 
-            # Process the guidelines in batches
+            # Process the guidelines in batches - convert the list of dicts into strings
             batches = []
             for i in range(0, guidelines_count, guidelines_per_batch):
-                batch = guidelines_list[i:i + guidelines_per_batch]
-                batches.append(batch)
+                batch_guidelines = guidelines_list[i:i + guidelines_per_batch]
+                # Format each guideline as a string, keeping the original layout so the LLM can process it
+                batch_strings = []
+                for guideline in batch_guidelines:
+                    guideline_str = f"{guideline['id']}) Condition: {guideline['condition']} Action: {guideline['action']}"
+                    batch_strings.append(guideline_str)
+                batches.append(batch_strings)
 
             # Make sure the number of batches does not exceed the requested concurrency
             while len(batches) > batch_count:
@@ -177,6 +236,7 @@
                 task = process_guideline_batch(
                     guidelines_batch=batch,
                     chat_history=chat_history,
+                    terms=terms_analysis,
                     model_name=model_name,
                     api_key=api_key,
                     model_server=model_server
@@ -228,10 +288,9 @@
             print(f"Merged guideline analysis result: {guideline_analysis}")
 
             # Append the analysis result to the content of the last message
-            if guideline_analysis and messages:
-                last_message = messages[-1]
-                if last_message.get('role') == 'user':
-                    messages[-1]['content'] += f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response."
+            if guideline_analysis:
+                messages = append_user_last_message(messages, f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response.")
+
     else:
         # 3. Get or create the agent instance from the global manager
         agent = await agent_manager.get_or_create_agent(
@@ -248,6 +307,18 @@
         user_identifier=user_identifier
     )
 
+    if language:
+        # Append the reply-language instruction to the end of the last message
+        language_map = {
+            'zh': '请用中文回复',
+            'en': 'Please reply in English',
+            'ja': '日本語で回答してください',
+            'jp': '日本語で回答してください'
+        }
+        language_instruction = language_map.get(language.lower(), '')
+        if language_instruction:
+            messages = append_user_last_message(messages, f"\n\nlanguage:{language_instruction}")
+
     thought_list = []
     if guideline_analysis != '':
         thought_list = [{"role": "assistant","reasoning_content": guideline_analysis}]
diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py
index 6aabcd9..59284fb 100644
--- a/utils/fastapi_utils.py
+++ b/utils/fastapi_utils.py
@@ -255,44 +255,9 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict]:
         # Append non-assistant messages, or messages without [TOOL_RESPONSE], directly
         final_messages.append(msg)
 
-    # Append the reply-language instruction to the end of the last message
-    if final_messages and language:
-        language_map = {
-            'zh': '请用中文回复',
-            'en': 'Please reply in English',
-            'ja': '日本語で回答してください',
-            'jp': '日本語で回答してください'
-        }
-        language_instruction = language_map.get(language.lower(), '')
-        if language_instruction:
-            # Append the language instruction to the end of the last message
-            final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\nlanguage:\n{language_instruction}。"
-
     return final_messages
 
 
-def extract_guidelines_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, str]:
-    """Extract the ```guideline content from the system_prompt and clean the original prompt
-
-    Returns:
-        tuple[str, str]: (cleaned system_prompt, extracted guidelines content)
-    """
-    if not system_prompt:
-        return "", ""
-
-    # Use a regex to extract the content wrapped in ```guideline```
-    pattern = r'```guideline\s*\n(.*?)\n```'
-    matches = re.findall(pattern, system_prompt, re.DOTALL)
-
-    # If no guidelines matched, return the original prompt and an empty string
-    if not matches:
-        return system_prompt, ""
-
-    guidelines_text = "\n".join(matches).strip()
-
-    # Remove the ```guideline``` blocks from the original system_prompt
-    cleaned_prompt = re.sub(pattern, '', system_prompt, flags=re.DOTALL)
-    return cleaned_prompt, guidelines_text
 
 
 def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
@@ -397,7 +362,7 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
     )
 
 
-async def call_guideline_llm(chat_history: str, guidelines_text: str, model_name: str, api_key: str, model_server: str) -> str:
+async def call_guideline_llm(chat_history: str, guidelines_text: str, terms: str, model_name: str, api_key: str, model_server: str) -> str:
     """Call the LLM to run the guideline analysis
 
     Args:
@@ -419,7 +384,7 @@
         return ""
 
     # Fill in the placeholders in the template
-    system_prompt = guideline_template.replace('{chat_history}', chat_history).replace('{guidelines_text}', guidelines_text)
+    system_prompt = guideline_template.replace('{chat_history}', chat_history).replace('{guidelines_text}', guidelines_text).replace('{terms}', terms)
 
     # Configure the LLM
     llm_config = {
@@ -473,6 +438,7 @@ def _get_optimal_batch_size(guidelines_count: int) -> int:
 async def process_guideline_batch(
     guidelines_batch: List[str],
     chat_history: str,
+    terms: str,
     model_name: str,
     api_key: str,
     model_server: str
@@ -481,7 +447,7 @@
     try:
         # Call the LLM to analyze this batch of guidelines
         batch_guidelines_text = "\n".join(guidelines_batch)
-        batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, model_name, api_key, model_server)
+        batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, terms, model_name, api_key, model_server)
 
         # Extract the content wrapped between ```json and ``` in the response
         json_pattern = r'```json\s*\n(.*?)\n```'
@@ -500,3 +466,212 @@
     except Exception as e:
         print(f"Error processing guideline batch: {e}")
         return ""
+
+
+def extract_block_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]:
+    """
+    Extract the guideline and terms content from the system prompt.
+
+    Args:
+        system_prompt: System prompt
+
+    Returns:
+        tuple[str, List[Dict], List[Dict]]: (cleaned system_prompt, guidelines_list, terms_list)
+    """
+    if not system_prompt:
+        return "", [], []
+
+    guidelines_list = []
+    terms_list = []
+
+    # First, find all fenced code blocks
+    block_pattern = r'```(\w+)\s*\n(.*?)\n```'
+    blocks_to_remove = []
+
+    for match in re.finditer(block_pattern, system_prompt, re.DOTALL):
+        block_type, content = match.groups()
+
+        if block_type == 'guideline':
+            try:
+                guidelines = parse_guidelines_text(content.strip())
+                guidelines_list.extend(guidelines)
+                blocks_to_remove.append(match.group(0))
+            except Exception as e:
+                print(f"Error parsing guidelines: {e}")
+
+        elif block_type == 'terms':
+            try:
+                terms = parse_terms_text(content.strip())
+                terms_list.extend(terms)
+                blocks_to_remove.append(match.group(0))
+            except Exception as e:
+                print(f"Error parsing terms: {e}")
+
+    # Remove the parsed blocks from the system_prompt
+    cleaned_prompt = system_prompt
+    for block in blocks_to_remove:
+        cleaned_prompt = cleaned_prompt.replace(block, '', 1)
+
+    # Collapse redundant blank lines
+    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()
+
+    return cleaned_prompt, guidelines_list, terms_list
+
+
+def parse_guidelines_text(text: str) -> List[Dict[str, Any]]:
+    """
+    Parse the guidelines text; several formats are supported.
+
+    Args:
+        text: Guidelines text
+
+    Returns:
+        List[Dict]: List of guidelines
+    """
+    guidelines = []
+
+    # Try JSON first
+    if text.strip().startswith('[') or text.strip().startswith('{'):
+        try:
+            data = json.loads(text)
+            if isinstance(data, list):
+                for item in data:
+                    if isinstance(item, dict):
+                        guidelines.append(item)
+            elif isinstance(data, dict):
+                guidelines.append(data)
+            return guidelines
+        except json.JSONDecodeError:
+            pass
+
+    # Fall back to line-based parsing with several supported separators
+    lines = [line.strip() for line in text.split('\n') if line.strip()]
+
+    for line in lines:
+        # Skip comment lines
+        if line.startswith('#') or line.startswith('//'):
+            continue
+
+        # Try the "id) Condition: ... Action: ..." format
+        id_condition_action_pattern = r'(\d+)\)\s*Condition:\s*(.*?)\s*Action:\s*(.*?)(?:\s*Priority:\s*(\d+))?$'
+        match = re.match(id_condition_action_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': int(match.group(1)),
+                'condition': match.group(2).strip(),
+                'action': match.group(3).strip(),
+                'priority': int(match.group(4)) if match.group(4) else 1
+            })
+            continue
+
+        # Try the "condition -> action" format
+        arrow_pattern = r'(?:\d+\)\s*)?(.*?)\s*->\s*(.*?)(?:\s*\[(\d+)\])?$'
+        match = re.match(arrow_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': len(guidelines) + 1,
+                'condition': match.group(1).strip(),
+                'action': match.group(2).strip(),
+                'priority': int(match.group(3)) if match.group(3) else 1
+            })
+            continue
+
+        # Try the "if condition then action" format
+        if_then_pattern = r'(?:\d+\)\s*)?if\s+(.*?)\s+then\s+(.*?)(?:\s*\[(\d+)\])?$'
+        match = re.match(if_then_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': len(guidelines) + 1,
+                'condition': match.group(1).strip(),
+                'action': match.group(2).strip(),
+                'priority': int(match.group(3)) if match.group(3) else 1
+            })
+            continue
+
+        # Default: treat the whole line as the action with an empty condition
+        guidelines.append({
+            'id': len(guidelines) + 1,
+            'condition': '',
+            'action': line.strip(),
+            'priority': 1
+        })
+
+    return guidelines
+
+
+def parse_terms_text(text: str) -> List[Dict[str, Any]]:
+    """
+    Parse the terms text; several formats are supported.
+
+    Args:
+        text: Terms text
+
+    Returns:
+        List[Dict]: List of terms
+    """
+    terms = []
+
+    # Try JSON first
+    if text.strip().startswith('[') or text.strip().startswith('{'):
+        try:
+            data = json.loads(text)
+            if isinstance(data, list):
+                for item in data:
+                    if isinstance(item, dict):
+                        terms.append(item)
+            elif isinstance(data, dict):
+                terms.append(data)
+            return terms
+        except json.JSONDecodeError:
+            pass
+
+    # Fall back to line-based parsing with several supported separators
+    lines = [line.strip() for line in text.split('\n') if line.strip()]
+
+    current_term = {}
+
+    for line in lines:
+        # Skip comment lines
+        if line.startswith('#') or line.startswith('//'):
+            continue
+
+        # Try the "1) Name: term_name1, Description: desc, Synonyms: syn1, syn2" format
+        numbered_term_pattern = r'(?:\d+\)\s*)?Name:\s*([^,]+)(?:,\s*Description:\s*([^,]+))?(?:,\s*Synonyms:\s*(.+))?'
+        match = re.match(numbered_term_pattern, line, re.IGNORECASE)
+        if match:
+            name = match.group(1).strip()
+            description = match.group(2).strip() if match.group(2) else ''
+            synonyms_text = match.group(3).strip() if match.group(3) else ''
+
+            # Build the term object
+            term_data = {'name': name}
+            if description:
+                term_data['description'] = description
+            if synonyms_text:
+                synonyms = re.split(r'[,;|]', synonyms_text)
+                term_data['synonyms'] = [s.strip() for s in synonyms if s.strip()]
+
+            if current_term:  # Save the previous term
+                terms.append(current_term)
+            current_term = term_data
+            continue
+
+        # Try the "| value" format (simplified format)
+        if line.startswith('|'):
+            parts = [p.strip() for p in line[1:].split('|', 2)]  # Split into at most 3 parts
+            if len(parts) >= 1:
+                if current_term:
+                    terms.append(current_term)
+                current_term = {'name': parts[0]}
+                if len(parts) >= 2:
+                    current_term['description'] = parts[1]
+                if len(parts) >= 3:
+                    synonyms = re.split(r'[,;|]', parts[2])
+                    current_term['synonyms'] = [s.strip() for s in synonyms if s.strip()]
+            continue
+
+    # Append the last term
+    if current_term:
+        terms.append(current_term)
+
+    return terms
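For reviewers, a sketch of the block syntax extract_block_from_system_prompt is meant to handle, together with the output the parsers above should produce for it. The sample prompt is invented for illustration and assumes the module is importable as utils.fastapi_utils.

from utils.fastapi_utils import extract_block_from_system_prompt

# A hypothetical system prompt containing one terms block and one guideline block.
sample_prompt = """You are a support assistant.

```terms
1) Name: RMA, Description: Return merchandise authorization, Synonyms: return authorization
```

```guideline
1) Condition: user asks about refunds Action: link the refund policy
```
"""

cleaned, guidelines, terms = extract_block_from_system_prompt(sample_prompt)
# guidelines -> [{'id': 1, 'condition': 'user asks about refunds', 'action': 'link the refund policy', 'priority': 1}]
# terms      -> [{'name': 'RMA', 'description': 'Return merchandise authorization', 'synonyms': ['return authorization']}]
# cleaned    -> the prompt text with both fenced blocks stripped out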