add terms

commit a40da62413, parent 58ac6e3024
@@ -2,9 +2,11 @@ import pickle
 import re
 import numpy as np
 import os
-from typing import Optional
+from typing import Optional, List, Dict, Any
 import requests
 import asyncio
+import hashlib
+import json
 
 def encode_texts_via_api(texts, batch_size=32):
     """Encode texts via the API endpoint"""
@@ -706,6 +708,249 @@ if __name__ == "__main__":
                        max_chunk_size=800,  # smaller chunk size
                        overlap=100)
 
+
+def cache_terms_embeddings(bot_id: str, terms_list: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Process the terms list, generating and caching embeddings
+
+    Args:
+        bot_id: bot ID, used for the cache key
+        terms_list: list of terms; each term has name, description, synonyms, etc.
+
+    Returns:
+        Dict: dictionary holding the embedding data
+    """
+    if not terms_list:
+        return {}
+
+    cache_key = f"{bot_id}_terms"
+    cache_file = f"projects/cache/{cache_key}.pkl"
+
+    # Make sure the cache directory exists
+    os.makedirs("projects/cache", exist_ok=True)
+
+    # Check whether a valid cache already exists
+    if os.path.exists(cache_file):
+        try:
+            with open(cache_file, 'rb') as f:
+                cached_data = pickle.load(f)
+
+            # Verify the cached data matches the current terms
+            current_hash = _generate_terms_hash(terms_list)
+            if cached_data.get('hash') == current_hash:
+                print(f"Using cached terms embeddings for {cache_key}")
+                return cached_data
+        except Exception as e:
+            print(f"Error loading cache: {e}")
+
+    # Prepare the texts to encode
+    term_texts = []
+    term_info = []
+
+    for term in terms_list:
+        # Build the term's full text for embedding
+        term_text_parts = []
+
+        if 'name' in term and term['name']:
+            term_text_parts.append(f"Name: {term['name']}")
+
+        if 'description' in term and term['description']:
+            term_text_parts.append(f"Description: {term['description']}")
+
+        # Handle synonyms
+        synonyms = []
+        if 'synonyms' in term and term['synonyms']:
+            if isinstance(term['synonyms'], list):
+                synonyms = term['synonyms']
+            elif isinstance(term['synonyms'], str):
+                synonyms = [s.strip() for s in term['synonyms'].split(',') if s.strip()]
+
+        if synonyms:
+            term_text_parts.append(f"Synonyms: {', '.join(synonyms)}")
+
+        term_text = " | ".join(term_text_parts)
+        term_texts.append(term_text)
+
+        # Keep the original info
+        term_info.append({
+            'name': term.get('name', ''),
+            'description': term.get('description', ''),
+            'synonyms': synonyms
+        })
+
+    # Generate the embeddings
+    try:
+        embeddings = encode_texts_via_api(term_texts, batch_size=16)
+
+        # Prepare the cache payload
+        cache_data = {
+            'hash': _generate_terms_hash(terms_list),
+            'term_info': term_info,
+            'embeddings': embeddings,
+            'texts': term_texts
+        }
+
+        # Save to cache
+        with open(cache_file, 'wb') as f:
+            pickle.dump(cache_data, f)
+
+        print(f"Cached {len(term_texts)} terms embeddings to {cache_file}")
+        return cache_data
+
+    except Exception as e:
+        print(f"Error generating terms embeddings: {e}")
+        return {}
+
+
+def search_similar_terms(query_text: str, cached_terms_data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    """
+    Search the cached terms for entries similar to the query text
+
+    Args:
+        query_text: the query text
+        cached_terms_data: the cached terms data
+
+    Returns:
+        List[Dict]: matching terms, sorted by similarity in descending order
+    """
+    if not cached_terms_data or not query_text or 'embeddings' not in cached_terms_data:
+        return []
+
+    try:
+        # Generate the embedding for the query text
+        query_embedding = encode_texts_via_api([query_text], batch_size=1)
+        if len(query_embedding) == 0:
+            return []
+
+        query_vector = query_embedding[0]
+        term_embeddings = cached_terms_data['embeddings']
+        term_info = cached_terms_data['term_info']
+
+        # Debug output
+        print(f"DEBUG: Query text: '{query_text}'")
+        print(f"DEBUG: Query vector shape: {query_vector.shape}, norm: {np.linalg.norm(query_vector)}")
+
+        # Compute cosine similarity
+        similarities = _cosine_similarity(query_vector, term_embeddings)
+
+        print(f"DEBUG: Similarities: {similarities}")
+        print(f"DEBUG: Max similarity: {np.max(similarities):.3f}, Mean similarity: {np.mean(similarities):.3f}")
+
+        # Collect the similarity of every term
+        matches = []
+        for i, similarity in enumerate(similarities):
+            match = {
+                'term_info': term_info[i],
+                'similarity': float(similarity),
+                'index': i
+            }
+            matches.append(match)
+
+        # Sort by similarity, descending
+        matches.sort(key=lambda x: x['similarity'], reverse=True)
+
+        # Return only the top-5 results
+        return matches[:5]
+
+    except Exception as e:
+        print(f"Error in similarity search: {e}")
+        return []
+
+
+def format_terms_analysis(similar_terms: List[Dict[str, Any]]) -> str:
+    """
+    Format similar terms into the expected string layout
+
+    Args:
+        similar_terms: list of similar terms
+
+    Returns:
+        str: the formatted terms analysis
+    """
+    if not similar_terms:
+        return ""
+
+    formatted_terms = []
+
+    for i, match in enumerate(similar_terms, 1):
+        term_info = match['term_info']
+        similarity = match['similarity']
+
+        name = term_info.get('name', '')
+        description = term_info.get('description', '')
+        synonyms = term_info.get('synonyms', [])
+
+        # Format the synonyms
+        synonyms_str = ', '.join(synonyms) if synonyms else 'N/A'
+
+        formatted_term = f"{i}) Name: {name}, Description: {description}, Synonyms: {synonyms_str} (Similarity: {similarity:.3f})"
+        formatted_terms.append(formatted_term)
+
+    return "\n".join(formatted_terms)
+
+
+def _generate_terms_hash(terms_list: List[Dict[str, Any]]) -> str:
+    """Generate a hash of the terms list for cache validation"""
+    # Serialize the terms list into a normalized string
+    terms_str = json.dumps(terms_list, sort_keys=True, ensure_ascii=False)
+    return hashlib.md5(terms_str.encode('utf-8')).hexdigest()
+
+
+def _cosine_similarity(query_vector: np.ndarray, term_embeddings: np.ndarray) -> np.ndarray:
+    """
+    Compute the cosine similarity between the query vector and all term embeddings.
+    Follows the implementation in semantic_search_server.py and assumes the vectors are already normalized.
+
+    Args:
+        query_vector: the query vector (shape: [embedding_dim])
+        term_embeddings: the term embeddings matrix (shape: [n_terms, embedding_dim])
+
+    Returns:
+        np.ndarray: similarity array (shape: [n_terms])
+    """
+    # Same algorithm as semantic_search_server.py
+    if len(term_embeddings.shape) > 1:
+        cos_scores = np.dot(term_embeddings, query_vector) / (
+            np.linalg.norm(term_embeddings, axis=1) * np.linalg.norm(query_vector) + 1e-8
+        )
+    else:
+        cos_scores = np.array([0.0] * len(term_embeddings))
+
+    return cos_scores
+
+
+def process_terms_with_embedding(terms_list: List[Dict[str, Any]], bot_id: str, query_text: str) -> str:
+    """
+    The full terms pipeline: cache, similarity search, formatted output
+
+    Args:
+        terms_list: list of terms
+        bot_id: bot ID
+        query_text: the user's query text
+
+    Returns:
+        str: the formatted terms analysis result
+    """
+    if not terms_list or not query_text:
+        return ""
+
+    # 1. Cache the terms embeddings
+    cached_data = cache_terms_embeddings(bot_id, terms_list)
+
+    if not cached_data:
+        return ""
+
+    # 2. Search for similar terms (top-5)
+    similar_terms = search_similar_terms(query_text, cached_data)
+
+    # 3. Format the output
+    if similar_terms:
+        return format_terms_analysis(similar_terms)
+    else:
+        # When no similar terms are found this returns an empty string
+        # and lets the caller decide how to handle it
+        return ""
+
+
 # Other example calls (commented out):
 # split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
 # embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run
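
A quick usage sketch of the pipeline added above. The bot id, terms, and query are made-up values, and the similarity figure in the output comment is illustrative; it assumes the embedding API behind `encode_texts_via_api` is reachable.

```python
from embedding.embedding import process_terms_with_embedding

terms = [
    {"name": "GPU", "description": "Graphics processing unit",
     "synonyms": ["graphics card", "video card"]},
    # Synonyms may also be given as a comma-separated string:
    {"name": "RAM", "description": "Main memory", "synonyms": "memory, DRAM"},
]

analysis = process_terms_with_embedding(terms, "demo_bot", "Which video card does it ship with?")
print(analysis)
# 1) Name: GPU, Description: Graphics processing unit, Synonyms: graphics card, video card (Similarity: 0.812)
# ... up to five lines, sorted by similarity. The first call also writes
# projects/cache/demo_bot_terms.pkl, so repeat calls with the same terms
# skip re-encoding (the MD5 hash of the terms list validates the cache).
```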
@@ -137,6 +137,8 @@ Examples of Guideline Match Evaluations:
 }
 ```
 
+Terms:
+{terms}
 
 Chat History:
 {chat_history}
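
The new `{terms}` placeholder is filled the same way as `{chat_history}` and `{guidelines_text}`: `call_guideline_llm` (later in this diff) does a plain string replace. A minimal sketch with made-up values and an abbreviated template:

```python
guideline_template = "Terms:\n{terms}\n\nChat History:\n{chat_history}"  # abbreviated
system_prompt = (guideline_template
                 .replace('{chat_history}', "user: which video card?")
                 .replace('{terms}', "1) Name: GPU, Description: graphics processor, Synonyms: graphics card"))
```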
routes/chat.py (101 changed lines)
@@ -12,7 +12,7 @@ from utils import (
 from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
-    process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
+    process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
     call_guideline_llm, _get_optimal_batch_size, process_guideline_batch, get_content_from_messages
 )
@@ -26,6 +26,34 @@ agent_manager = init_global_sharded_agent_manager(
 )
 
 
+def get_user_last_message_content(messages: list) -> str:
+    """Return the content of the last message if it is a user message, otherwise ''"""
+    if not messages or len(messages) == 0:
+        return ""
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        return last_message["content"]
+    return ""
+
+
+def append_user_last_message(messages: list, content: str) -> list:
+    """Append content to the last message if it is a user message
+
+    Args:
+        messages: the message list
+        content: the content to append
+
+    Returns:
+        list: the (possibly updated) message list
+    """
+    if not messages or len(messages) == 0:
+        return messages
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        messages[-1]['content'] += content
+    return messages
+
+
 async def generate_stream_response(agent, messages, thought_list, tool_response: bool, model: str):
     """Generate a streaming response"""
     accumulated_content = ""
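
A small behavior sketch for the two new helpers (made-up messages; both leave a non-user last message untouched):

```python
msgs = [{"role": "user", "content": "hello"}]
assert get_user_last_message_content(msgs) == "hello"

msgs = append_user_last_message(msgs, "\n\nRelevant Terms:\n1) Name: GPU, ...")
assert msgs[-1]["content"].endswith("1) Name: GPU, ...")

# A non-user last message is ignored: content comes back as "" and nothing is appended.
assert get_user_last_message_content([{"role": "assistant", "content": "hi"}]) == ""
```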
@@ -134,15 +162,41 @@ async def create_agent_and_generate_response(
     if generate_cfg is None:
         generate_cfg = {}
 
-    # 1. Extract the guideline content from the system_prompt
-    system_prompt, guidelines_text = extract_guidelines_from_system_prompt(system_prompt)
-    print(f"guidelines_text: {guidelines_text}")
+    # 1. Extract the guideline and terms content from the system_prompt
+    system_prompt, guidelines_list, terms_list = extract_block_from_system_prompt(system_prompt)
 
-    # 2. If there is guideline content, process it concurrently
+    # 2. If there are terms, embed them first (the embeddings need to be cached; a tmp-file cache keyed by {bot_id}_terms works; see @embedding/embedding.py for the embedding implementation, which can live there). With the embeddings, run a similarity search, cosine similarity to start; keep matches above the 0.7 similarity threshold and reassemble them into terms_analysis in the format: 1) Name: term_name1, Description: desc, Synonyms: syn1, syn2
+    terms_analysis = ""
+    if terms_list:
+        print(f"terms_list: {terms_list}")
+        # Extract the user's query text from messages for the similarity search
+        query_text = get_user_last_message_content(messages)
+        # Process the terms with embeddings
+        try:
+            from embedding.embedding import process_terms_with_embedding
+            terms_analysis = process_terms_with_embedding(terms_list, bot_id, query_text)
+            if terms_analysis:
+                # Also append the terms analysis to the messages
+                messages = append_user_last_message(messages, f"\n\nRelevant Terms:\n{terms_analysis}")
+                print(f"Generated terms analysis: {terms_analysis[:200]}...")  # print only the first 200 chars
+        except Exception as e:
+            print(f"Error processing terms with embedding: {e}")
+            terms_analysis = ""
+    else:
+        # When terms_list is empty, delete the corresponding pkl cache file
+        try:
+            import os
+            cache_file = f"projects/cache/{bot_id}_terms.pkl"
+            if os.path.exists(cache_file):
+                os.remove(cache_file)
+                print(f"Removed empty terms cache file: {cache_file}")
+        except Exception as e:
+            print(f"Error removing terms cache file: {e}")
+
+    # 3. If there is guideline content, process it concurrently
     guideline_analysis = ""
-    if guidelines_text:
-        # Split guidelines on newlines
-        guidelines_list = [g.strip() for g in guidelines_text.split('\n') if g.strip()]
+    if guidelines_list:
+        print(f"guidelines_list: {guidelines_list}")
         guidelines_count = len(guidelines_list)
 
         if guidelines_count > 0:
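
Note that the step-2 comment asks for a >0.7 cosine threshold, while `search_similar_terms` as added returns the top-5 matches regardless of score. If the threshold behavior is wanted, a minimal post-filter (a sketch, not part of this commit) would be:

```python
similar_terms = [m for m in search_similar_terms(query_text, cached_data)
                 if m['similarity'] > 0.7]
```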
@@ -152,11 +206,16 @@ async def create_agent_and_generate_response(
             # Compute how many guidelines each batch should contain
             guidelines_per_batch = max(1, guidelines_count // batch_count)
 
-            # Process the guidelines in batches
+            # Process the guidelines in batches; convert the list of dicts to a list of strings for processing
             batches = []
             for i in range(0, guidelines_count, guidelines_per_batch):
-                batch = guidelines_list[i:i + guidelines_per_batch]
-                batches.append(batch)
+                batch_guidelines = guidelines_list[i:i + guidelines_per_batch]
+                # Format each guideline as a string, keeping the original layout for the LLM
+                batch_strings = []
+                for guideline in batch_guidelines:
+                    guideline_str = f"{guideline['id']}) Condition: {guideline['condition']} Action: {guideline['action']}"
+                    batch_strings.append(guideline_str)
+                batches.append(batch_strings)
 
             # Make sure the number of batches does not exceed the requested concurrency
             while len(batches) > batch_count:
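
For reference, one guideline dict (made-up values) and the flattened string form sent to the LLM:

```python
g = {'id': 3, 'condition': 'user asks for a refund', 'action': 'link the refund policy', 'priority': 1}
print(f"{g['id']}) Condition: {g['condition']} Action: {g['action']}")
# -> 3) Condition: user asks for a refund Action: link the refund policy
```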
@@ -177,6 +236,7 @@ async def create_agent_and_generate_response(
             task = process_guideline_batch(
                 guidelines_batch=batch,
                 chat_history=chat_history,
+                terms=terms_analysis,
                 model_name=model_name,
                 api_key=api_key,
                 model_server=model_server
@@ -228,10 +288,9 @@ async def create_agent_and_generate_response(
         print(f"Merged guideline analysis result: {guideline_analysis}")
 
         # Append the analysis result to the content of the last message
-        if guideline_analysis and messages:
-            last_message = messages[-1]
-            if last_message.get('role') == 'user':
-                messages[-1]['content'] += f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response."
+        if guideline_analysis:
+            messages = append_user_last_message(messages, f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response.")
     else:
         # 3. Get or create the assistant instance from the global manager
         agent = await agent_manager.get_or_create_agent(
@@ -248,6 +307,18 @@ async def create_agent_and_generate_response(
             user_identifier=user_identifier
         )
 
+    if language:
+        # Append the reply-language instruction to the end of the last message
+        language_map = {
+            'zh': '请用中文回复',
+            'en': 'Please reply in English',
+            'ja': '日本語で回答してください',
+            'jp': '日本語で回答してください'
+        }
+        language_instruction = language_map.get(language.lower(), '')
+        if language_instruction:
+            messages = append_user_last_message(messages, f"\n\nlanguage:{language_instruction}")
+
     thought_list = []
     if guideline_analysis != '':
         thought_list = [{"role": "assistant", "reasoning_content": guideline_analysis}]
@@ -255,44 +255,9 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
             # Non-assistant messages, or messages that don't contain [TOOL_RESPONSE], are appended directly
             final_messages.append(msg)
 
-    # Append the reply-language instruction to the end of the last message
-    if final_messages and language:
-        language_map = {
-            'zh': '请用中文回复',
-            'en': 'Please reply in English',
-            'ja': '日本語で回答してください',
-            'jp': '日本語で回答してください'
-        }
-        language_instruction = language_map.get(language.lower(), '')
-        if language_instruction:
-            # Append the language instruction to the end of the last message
-            final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\nlanguage:\n{language_instruction}。"
-
     return final_messages
 
 
-def extract_guidelines_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, str]:
-    """Extract the ```guideline content from the system_prompt and clean the original prompt
-
-    Returns:
-        tuple[str, str]: (cleaned system_prompt, extracted guidelines content)
-    """
-    if not system_prompt:
-        return "", ""
-
-    # Use a regex to extract the content wrapped in ```guideline``` fences
-    pattern = r'```guideline\s*\n(.*?)\n```'
-    matches = re.findall(pattern, system_prompt, re.DOTALL)
-
-    # If no guidelines match, return the original prompt and an empty string
-    if not matches:
-        return system_prompt, ""
-
-    guidelines_text = "\n".join(matches).strip()
-
-    # Remove the ```guideline``` blocks from the original system_prompt
-    cleaned_prompt = re.sub(pattern, '', system_prompt, flags=re.DOTALL)
-    return cleaned_prompt, guidelines_text
-
-
 def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
@@ -397,7 +362,7 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
 )
 
 
-async def call_guideline_llm(chat_history: str, guidelines_text: str, model_name: str, api_key: str, model_server: str) -> str:
+async def call_guideline_llm(chat_history: str, guidelines_text: str, terms: str, model_name: str, api_key: str, model_server: str) -> str:
     """Call the LLM to run the guideline analysis
 
     Args:
@@ -419,7 +384,7 @@ async def call_guideline_llm(chat_history: str, guidelines_text: str, model_name
         return ""
 
     # Replace the placeholders in the template
-    system_prompt = guideline_template.replace('{chat_history}', chat_history).replace('{guidelines_text}', guidelines_text)
+    system_prompt = guideline_template.replace('{chat_history}', chat_history).replace('{guidelines_text}', guidelines_text).replace('{terms}', terms)
 
     # Configure the LLM
     llm_config = {
@@ -473,6 +438,7 @@ def _get_optimal_batch_size(guidelines_count: int) -> int:
 async def process_guideline_batch(
     guidelines_batch: List[str],
     chat_history: str,
+    terms: str,
     model_name: str,
     api_key: str,
     model_server: str
@@ -481,7 +447,7 @@ async def process_guideline_batch(
     try:
         # Call the LLM to analyze this batch of guidelines
        batch_guidelines_text = "\n".join(guidelines_batch)
-        batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, model_name, api_key, model_server)
+        batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, terms, model_name, api_key, model_server)
 
         # Extract the content wrapped in ```json ... ``` from the response
         json_pattern = r'```json\s*\n(.*?)\n```'
@@ -500,3 +466,212 @@ async def process_guideline_batch(
     except Exception as e:
         print(f"Error processing guideline batch: {e}")
         return ""
+
+
+def extract_block_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]:
+    """
+    Extract the guideline and terms content from the system prompt
+
+    Args:
+        system_prompt: the system prompt
+
+    Returns:
+        tuple[str, List[Dict], List[Dict]]: (cleaned system_prompt, guidelines_list, terms_list)
+    """
+    if not system_prompt:
+        return "", [], []
+
+    guidelines_list = []
+    terms_list = []
+
+    # First split out all fenced blocks
+    block_pattern = r'```(\w+)\s*\n(.*?)\n```'
+    blocks_to_remove = []
+
+    for match in re.finditer(block_pattern, system_prompt, re.DOTALL):
+        block_type, content = match.groups()
+
+        if block_type == 'guideline':
+            try:
+                guidelines = parse_guidelines_text(content.strip())
+                guidelines_list.extend(guidelines)
+                blocks_to_remove.append(match.group(0))
+            except Exception as e:
+                print(f"Error parsing guidelines: {e}")
+
+        elif block_type == 'terms':
+            try:
+                terms = parse_terms_text(content.strip())
+                terms_list.extend(terms)
+                blocks_to_remove.append(match.group(0))
+            except Exception as e:
+                print(f"Error parsing terms: {e}")
+
+    # Remove the parsed blocks from the system_prompt
+    cleaned_prompt = system_prompt
+    for block in blocks_to_remove:
+        cleaned_prompt = cleaned_prompt.replace(block, '', 1)
+
+    # Clean up redundant blank lines
+    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()
+
+    return cleaned_prompt, guidelines_list, terms_list
+
+
+def parse_guidelines_text(text: str) -> List[Dict[str, Any]]:
+    """
+    Parse guidelines text; several formats are supported
+
+    Args:
+        text: the guidelines text content
+
+    Returns:
+        List[Dict]: list of guidelines
+    """
+    guidelines = []
+
+    # Try JSON format first
+    if text.strip().startswith('[') or text.strip().startswith('{'):
+        try:
+            data = json.loads(text)
+            if isinstance(data, list):
+                for item in data:
+                    if isinstance(item, dict):
+                        guidelines.append(item)
+            elif isinstance(data, dict):
+                guidelines.append(data)
+            return guidelines
+        except json.JSONDecodeError:
+            pass
+
+    # Fall back to line format; several separators are supported
+    lines = [line.strip() for line in text.split('\n') if line.strip()]
+
+    for line in lines:
+        # Skip comment lines
+        if line.startswith('#') or line.startswith('//'):
+            continue
+
+        # Try the "id) Condition: ... Action: ..." format
+        id_condition_action_pattern = r'(\d+)\)\s*Condition:\s*(.*?)\s*Action:\s*(.*?)(?:\s*Priority:\s*(\d+))?$'
+        match = re.match(id_condition_action_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': int(match.group(1)),
+                'condition': match.group(2).strip(),
+                'action': match.group(3).strip(),
+                'priority': int(match.group(4)) if match.group(4) else 1
+            })
+            continue
+
+        # Try the "condition -> action" format
+        arrow_pattern = r'(?:\d+\)\s*)?(.*?)\s*->\s*(.*?)(?:\s*\[(\d+)\])?$'
+        match = re.match(arrow_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': len(guidelines) + 1,
+                'condition': match.group(1).strip(),
+                'action': match.group(2).strip(),
+                'priority': int(match.group(3)) if match.group(3) else 1
+            })
+            continue
+
+        # Try the "if condition then action" format
+        if_then_pattern = r'(?:\d+\)\s*)?if\s+(.*?)\s+then\s+(.*?)(?:\s*\[(\d+)\])?$'
+        match = re.match(if_then_pattern, line, re.IGNORECASE)
+        if match:
+            guidelines.append({
+                'id': len(guidelines) + 1,
+                'condition': match.group(1).strip(),
+                'action': match.group(2).strip(),
+                'priority': int(match.group(3)) if match.group(3) else 1
+            })
+            continue
+
+        # Default: the whole line is the action and the condition stays empty
+        guidelines.append({
+            'id': len(guidelines) + 1,
+            'condition': '',
+            'action': line.strip(),
+            'priority': 1
+        })
+
+    return guidelines
+
+
+def parse_terms_text(text: str) -> List[Dict[str, Any]]:
+    """
+    Parse terms text; several formats are supported
+
+    Args:
+        text: the terms text content
+
+    Returns:
+        List[Dict]: list of terms
+    """
+    terms = []
+
+    # Try JSON format first
+    if text.strip().startswith('[') or text.strip().startswith('{'):
+        try:
+            data = json.loads(text)
+            if isinstance(data, list):
+                for item in data:
+                    if isinstance(item, dict):
+                        terms.append(item)
+            elif isinstance(data, dict):
+                terms.append(data)
+            return terms
+        except json.JSONDecodeError:
+            pass
+
+    # Fall back to line format; several separators are supported
+    lines = [line.strip() for line in text.split('\n') if line.strip()]
+
+    current_term = {}
+
+    for line in lines:
+        # Skip comment lines
+        if line.startswith('#') or line.startswith('//'):
+            continue
+
+        # Try the "1) Name: term_name1, Description: desc, Synonyms: syn1, syn2" format
+        numbered_term_pattern = r'(?:\d+\)\s*)?Name:\s*([^,]+)(?:,\s*Description:\s*([^,]+))?(?:,\s*Synonyms:\s*(.+))?'
+        match = re.match(numbered_term_pattern, line, re.IGNORECASE)
+        if match:
+            name = match.group(1).strip()
+            description = match.group(2).strip() if match.group(2) else ''
+            synonyms_text = match.group(3).strip() if match.group(3) else ''
+
+            # Build the term object
+            term_data = {'name': name}
+            if description:
+                term_data['description'] = description
+            if synonyms_text:
+                synonyms = re.split(r'[,;|]', synonyms_text)
+                term_data['synonyms'] = [s.strip() for s in synonyms if s.strip()]
+
+            if current_term:  # save the previous term
+                terms.append(current_term)
+            current_term = term_data
+            continue
+
+        # Try the "| value" format (simplified)
+        if line.startswith('|'):
+            parts = [p.strip() for p in line[1:].split('|', 2)]  # split into at most 3 parts
+            if len(parts) >= 1:
+                if current_term:
+                    terms.append(current_term)
+                current_term = {'name': parts[0]}
+                if len(parts) >= 2:
+                    current_term['description'] = parts[1]
+                if len(parts) >= 3:
+                    synonyms = re.split(r'[,;|]', parts[2])
+                    current_term['synonyms'] = [s.strip() for s in synonyms if s.strip()]
+            continue
+
+    # Append the last term
+    if current_term:
+        terms.append(current_term)
+
+    return terms
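
A round-trip check of the new parser with the functions above in scope (the prompt is a hypothetical example):

```python
prompt = ("You are a support bot.\n\n"
          "```guideline\n1) Condition: user asks price Action: quote the list price\n```\n\n"
          "```terms\n1) Name: GPU, Description: graphics processor, Synonyms: graphics card; video card\n```")

cleaned, guidelines, terms = extract_block_from_system_prompt(prompt)
# cleaned    -> "You are a support bot."
# guidelines -> [{'id': 1, 'condition': 'user asks price', 'action': 'quote the list price', 'priority': 1}]
# terms      -> [{'name': 'GPU', 'description': 'graphics processor', 'synonyms': ['graphics card', 'video card']}]
```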