qwen_agent/utils/fastapi_utils.py
2026-01-28 23:32:34 +08:00

989 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import hashlib
import json
import logging
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Optional, Union, Any

import aiohttp
from fastapi import HTTPException
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, convert_to_openai_messages
from langchain.chat_models import init_chat_model

from utils.settings import MASTERKEY, BACKEND_HOST
from agent.agent_config import AgentConfig
# Role name constants used throughout message processing.
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"

# Shared application logger.
logger = logging.getLogger('app')

# Global thread pool executor for running synchronous HTTP calls off the event loop.
thread_pool = ThreadPoolExecutor(max_workers=10)

# Concurrency semaphore limiting simultaneous LLM API calls (at most 8 at once).
api_semaphore = asyncio.Semaphore(8)
def detect_provider(model_name, model_server):
    """Infer the provider type and base URL from a model name.

    Claude/Anthropic models use the "anthropic" provider with the "/v1"
    suffix stripped from the server URL; GPT/OpenAI/o1 models — and any
    unrecognized model — fall back to the OpenAI-compatible format.

    Args:
        model_name: Model identifier (matched case-insensitively).
        model_server: Base server URL configured for the model.

    Returns:
        tuple[str, str]: (provider name, base URL to use).
    """
    lowered = model_name.lower()
    if any(marker in lowered for marker in ("claude", "anthropic")):
        return "anthropic", model_server.replace("/v1", "")
    # "gpt"/"openai"/"o1" and every unknown model use the OpenAI format.
    return "openai", model_server
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """Compute a versioned filename, removing any earlier versions on disk.

    If neither the original file nor any "<name>_<n><ext>" variant exists,
    the original name with version 1 is returned. Otherwise every existing
    variant (including the original) is deleted and the next version number
    (max existing + 1, or 2 when only the original existed) is used.

    Args:
        upload_dir: Directory containing the uploads.
        name_without_ext: File name without its extension.
        file_extension: Extension including the leading dot.

    Returns:
        tuple[str, int]: (final file name, version number).
    """
    base_name = name_without_ext + file_extension
    original_exists = os.path.exists(os.path.join(upload_dir, base_name))

    # Match "<name>_<digits><ext>" exactly.
    version_pattern = re.compile(
        re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$')

    existing_versions = []
    files_to_delete = []
    for entry in os.listdir(upload_dir):
        if entry == base_name:
            files_to_delete.append(entry)
        else:
            m = version_pattern.match(entry)
            if m:
                existing_versions.append(int(m.group(1)))
                files_to_delete.append(entry)

    # Nothing related on disk: first version keeps the plain name.
    if not original_exists and not existing_versions:
        return base_name, 1

    # Purge every previous file (original and versioned variants).
    for entry in files_to_delete:
        file_to_delete = os.path.join(upload_dir, entry)
        try:
            os.remove(file_to_delete)
            logger.info(f"已删除文件: {file_to_delete}")
        except OSError as e:
            logger.error(f"删除文件失败 {file_to_delete}: {e}")

    next_version = max(existing_versions) + 1 if existing_versions else 2
    return f"{name_without_ext}_{next_version}{file_extension}", next_version
def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, finish_reason: str = None) -> dict:
    """Build one OpenAI-compatible ``chat.completion.chunk`` payload.

    Args:
        chunk_id: Stream identifier echoed into every chunk.
        model_name: Model name reported to the client.
        content: Delta text for this chunk; the delta is empty when None.
        finish_reason: Terminal reason (e.g. "stop") or None mid-stream.

    Returns:
        dict: A single-choice streaming chunk dict.
    """
    # Was `int(__import__('time').time())` — a real top-level import is clearer.
    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "delta": {"content": content} if content is not None else {},
            "finish_reason": finish_reason
        }]
    }
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
# """Extract content from qwen-agent messages with special formatting"""
# full_text = ''
# content = []
# TOOL_CALL_S = '[TOOL_CALL]'
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
# THOUGHT_S = '[THINK]'
# ANSWER_S = '[ANSWER]'
# PREAMBLE_S = '[PREAMBLE]'
# for msg in messages:
# if msg['role'] == ASSISTANT:
# if msg.get('reasoning_content'):
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
# if msg.get('content'):
# assert isinstance(msg['content'], str), 'Now only supports text messages'
# # 过滤掉流式输出中的不完整 tool_call 文本
# content_text = msg["content"]
# # 使用正则表达式替换不完整的 tool_call 模式为空字符串
# # 匹配并替换不完整的 tool_call 模式
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# # 只有在处理后内容不为空时才添加
# if content_text.strip():
# content.append(f'{ANSWER_S}\n{content_text}')
# if msg.get('function_call'):
# content_text = msg["function_call"]["arguments"]
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
# if content_text.strip():
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
# elif msg['role'] == FUNCTION:
# if tool_response:
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
# elif msg['role'] == "preamble":
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
# else:
# raise TypeError
# if content:
# full_text = '\n'.join(content)
# return full_text
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Filter tag-annotated assistant history and rebuild tool-call structure.

    This is the inverse of ``get_content_from_messages``: assistant content is
    split on [THINK]/[PREAMBLE]/[TOOL_CALL]/[TOOL_RESPONSE]/[ANSWER] markers.
    First pass trims history: tool-call/response segments are only kept for
    the most recent assistant turns (long responses are truncated to
    head/middle/tail excerpts), [ANSWER] segments are always kept, and
    [THINK]/[PREAMBLE] segments are dropped. Second pass re-expands surviving
    [TOOL_CALL]/[TOOL_RESPONSE] segments into OpenAI-style assistant
    ``tool_calls`` entries and role='tool' result messages.

    Args:
        messages: Message objects exposing ``.role`` and ``.content``
            attributes (not plain dicts) — confirmed by attribute access below.
        language: Currently unused; kept for interface compatibility.

    Returns:
        List[Dict[str, str]]: Rebuilt OpenAI-style message dicts.
    """
    # Only tool calls whose function name contains one of these substrings
    # are re-expanded into structured tool_calls / tool messages.
    include_function_name = ['find', 'get']
    processed_messages = []
    # Indices of all assistant messages, used to locate each one's position.
    assistant_indices = [i for i, msg in enumerate(messages) if msg.role == ASSISTANT]
    total_assistant_messages = len(assistant_indices)
    # Assistant turns before this cutoff are "old" — their tool traffic is dropped.
    # (Code keeps the last 5 assistant turns; the original comments said 10.)
    cutoff_point = max(0, total_assistant_messages - 5)
    # First pass: filter each message's content.
    for i, msg in enumerate(messages):
        if msg.role == ASSISTANT:
            # Position of this assistant message among all assistant messages (0-based).
            assistant_position = assistant_indices.index(i)
            # re.split with a capture group yields [text, tag, text, tag, ...]:
            # even indices are text segments, odd indices are the tag names.
            parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content)
            filtered_content = ""
            current_tag = None
            is_recent_message = assistant_position >= cutoff_point  # recent assistant turn
            # NOTE(review): this inner loop reuses `i`, shadowing the outer
            # enumerate index. Harmless here because `i` is not read again in
            # this iteration, but worth renaming.
            for i in range(0, len(parts)):
                if i % 2 == 0:  # text segment
                    text = parts[i].strip()
                    if not text:
                        continue
                    # Old tool-call text is not propagated downstream.
                    if current_tag == "TOOL_RESPONSE":
                        if is_recent_message:
                            # Recent assistant turns keep the full TOOL_RESPONSE
                            # (short) or a condensed excerpt (long).
                            if len(text) <= 1000:
                                filtered_content += f"[TOOL_RESPONSE] {text}\n"
                            else:
                                # Keep three 250-char excerpts: start, middle, end.
                                first_part = text[:250]
                                middle_start = len(text) // 2 - 125
                                middle_part = text[middle_start:middle_start + 250]
                                last_part = text[-250:]
                                # Count of omitted characters for the marker text.
                                omitted_count = len(text) - 750
                                omitted_text = f"...此处省略{omitted_count}字..."
                                truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"
                                filtered_content += f"[TOOL_RESPONSE] {truncated_text}\n"
                        # Older turns: TOOL_RESPONSE data is dropped entirely.
                    elif current_tag == "TOOL_CALL":
                        if is_recent_message:
                            # Recent assistant turns keep the TOOL_CALL text.
                            filtered_content += f"[TOOL_CALL] {text}\n"
                        # Older turns: TOOL_CALL data is dropped entirely.
                    elif current_tag == "ANSWER":
                        # ANSWER segments are preserved for every assistant turn.
                        filtered_content += f"[ANSWER] {text}\n"
                    elif current_tag != "THINK" and current_tag != "PREAMBLE":
                        # Untagged (leading) text is kept verbatim.
                        filtered_content += text + "\n"
                else:  # tag segment
                    current_tag = parts[i]
            # Use the filtered content, trimmed of surrounding whitespace.
            final_content = filtered_content.strip()
            if final_content:
                processed_messages.append({"role": msg.role, "content": final_content})
            else:
                # Everything was filtered out; fall back to the raw content.
                processed_messages.append({"role": msg.role, "content": msg.content})
        else:
            processed_messages.append({"role": msg.role, "content": msg.content})
    # Second pass (inverse of get_content_from_messages): expand surviving
    # [TOOL_CALL]/[TOOL_RESPONSE] segments back into structured messages —
    # role='tool' results and assistant 'tool_calls' entries.
    final_messages = []
    for msg in processed_messages:
        if msg["role"] == ASSISTANT:
            # Split the message content on the same tag markers.
            parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
            current_tag = None
            tool_id_counter = 0  # unique counter for generated tool-call ids
            tool_id_list = []  # FIFO of ids awaiting their matching tool response
            for i in range(0, len(parts)):
                if i % 2 == 0:  # text segment
                    text = parts[i].strip()
                    if not text:
                        continue
                    if current_tag == "TOOL_RESPONSE":
                        # TOOL_RESPONSE format: "[TOOL_RESPONSE] function_name\ncontent"
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        response_content = lines[1].strip() if len(lines) > 1 else ""
                        # Keep only functions whose name contains an allowed keyword.
                        should_include = False
                        if function_name:
                            for exclude_name in include_function_name:
                                if exclude_name in function_name:
                                    should_include = True
                                    break
                        if should_include and len(tool_id_list) > 0:
                            # Pair with the oldest outstanding tool-call id.
                            tool_id = tool_id_list.pop(0)
                            # Wrap the TOOL_RESPONSE as a tool-result message,
                            # immediately following its tool_use.
                            final_messages.append({
                                "role": TOOL,
                                "tool_call_id": tool_id,  # matches the preceding tool_use id
                                "name": function_name,
                                "content": response_content
                            })
                    elif current_tag == "TOOL_CALL":
                        # TOOL_CALL format: "[TOOL_CALL] function_name\narguments"
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        arguments = lines[1].strip() if len(lines) > 1 else ""
                        # Keep only functions whose name contains an allowed keyword.
                        should_include = False
                        if function_name:
                            for exclude_name in include_function_name:
                                if exclude_name in function_name:
                                    should_include = True
                                    break
                        if should_include:
                            tool_id = f"tool_id_{tool_id_counter}"  # unique id
                            tool_id_list.append(tool_id)
                            tool_id_counter += 1  # advance the counter
                            final_messages.append({
                                "role": ASSISTANT,
                                "content": "",
                                "tool_calls": [{
                                    "id": tool_id,
                                    "function": {
                                        "name": function_name,
                                        "arguments": arguments
                                    }
                                }]
                            })
                    elif current_tag != "THINK" and current_tag != "PREAMBLE":
                        final_messages.append({
                            "role": ASSISTANT,
                            "content": text
                        })
                else:  # tag segment
                    current_tag = parts[i]
        else:
            # Non-assistant messages pass through unchanged.
            final_messages.append(msg)
    return final_messages
def get_user_last_message_content(messages: list) -> str:
    """Return the content of the last message if it came from the user.

    Args:
        messages: List of OpenAI-style message dicts.

    Returns:
        str: The last message's content when its role is 'user', else "".
        (The previous annotation claimed Optional[dict], but every return
        path yields the content value or "".)
    """
    if not messages:
        return ""
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        return last_message["content"]
    return ""
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
"""将messages格式化为纯文本聊天记录
Args:
messages: 消息列表
Returns:
str: 格式化的聊天记录
"""
# 只取最后的15句消息
chat_history = []
for message in messages:
role = message.get('role', '')
content = message.get('content', '')
name = message.get('name', '')
if role == USER:
chat_history.append(f"user: {content}")
elif role == TOOL:
chat_history.append(f"{name} response: {content}")
elif role == ASSISTANT:
if len(content) >0:
chat_history.append(f"assistant: {content}")
if message.get('tool_calls'):
for tool_call in message.get('tool_calls'):
function_name = tool_call.get('function').get('name')
arguments = tool_call.get('function').get('arguments')
chat_history.append(f"{function_name} call: {arguments}")
recent_chat_history = chat_history[-16:-1] if len(chat_history) > 16 else chat_history[:-1]
return "\n".join(recent_chat_history)
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent", skills: Optional[List[str]] = None) -> Optional[str]:
    """Shared helper for creating a robot project directory.

    No directory is created for the general agent type; for other robot
    types the project is created via utils.multi_project_manager, with an
    empty dataset list when none was provided.

    Returns:
        Optional[str]: The created project path, or None when skipped/failed.
    """
    # general_agent bots never get a project directory.
    if robot_type == "general_agent":
        return None
    if not dataset_ids:
        dataset_ids = []
    try:
        from utils.multi_project_manager import create_robot_project
        from pathlib import Path
        return create_robot_project(dataset_ids, bot_id, skills=skills, robot_type=robot_type)
    except Exception as e:
        logger.error(f"Error creating project directory: {e}")
        return None
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Pull the raw API key out of an Authorization header value.

    Accepts either "Bearer <key>" or a bare key; returns None when the
    header is missing or empty.
    """
    if not authorization:
        return None
    prefix = "Bearer "
    if authorization.startswith(prefix):
        return authorization[len(prefix):]
    return authorization
def generate_v2_auth_token(bot_id: str) -> str:
    """Derive the v2 API auth token: md5 hex digest of "MASTERKEY:bot_id".

    NOTE: md5 here acts as a shared-secret token derivation, not a
    general-purpose cryptographic hash.
    """
    return hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch the bot configuration from the backend API.

    Args:
        bot_id: The bot's user-facing id, also used to derive the auth token.

    Returns:
        Dict[str, Any]: The "data" payload of the backend response.

    Raises:
        HTTPException: 400 on non-200 status or an unsuccessful payload,
            500 on connection or unexpected errors.
    """
    try:
        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        # aiohttp expects a ClientTimeout object; passing a bare int for the
        # per-request timeout is deprecated and rejected by newer versions.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=timeout) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )
                response_data = await response.json()
                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )
                return response_data.get("data", {})
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        )
    except HTTPException:
        # Re-raise our own HTTP errors untouched (replaces the previous
        # isinstance check inside a generic except).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        )
async def fetch_bot_config_from_db(bot_user_id: str) -> Dict[str, Any]:
    """Load a bot's configuration from the local database.

    Args:
        bot_user_id: The bot's user-facing id (the 'bot_id' column, not the UUID).

    Returns:
        Dict[str, Any]: Config dict compatible with fetch_bot_config's format,
        including model credentials, dataset_ids, skills and mcp_settings.

    Raises:
        HTTPException: 404 when the bot does not exist, 500 on any other failure.
    """
    try:
        from agent.db_pool_manager import get_db_pool_manager
        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Resolve the bot's UUID from its user-facing bot_id first.
                await cursor.execute(
                    "SELECT id, name FROM bots WHERE bot_id = %s",
                    (bot_user_id,)
                )
                bot_row = await cursor.fetchone()
                if not bot_row:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Bot with bot_id '{bot_user_id}' not found"
                    )
                bot_uuid = bot_row[0]
                # Load the bot_settings row.
                await cursor.execute(
                    """
                    SELECT model_id,
                    language, robot_type, dataset_ids, system_prompt, user_identifier,
                    enable_memori, tool_response, skills
                    FROM bot_settings WHERE bot_id = %s
                    """,
                    (bot_uuid,)
                )
                settings_row = await cursor.fetchone()
                if not settings_row:
                    # No settings row at all: return hard-coded defaults.
                    logger.warning(f"No settings found for bot {bot_user_id}, using defaults")
                    return {
                        "model": "qwen3-next",
                        "api_key": "",
                        "model_server": "",
                        "language": "zh",
                        "robot_type": "general_agent",
                        "dataset_ids": [],
                        "system_prompt": "",
                        "user_identifier": "",
                        "enable_memori": False,
                        "tool_response": True,
                        "skills": []
                    }
                # Map the positional row into a dict keyed by column name.
                columns = [
                    'model_id',
                    'language', 'robot_type', 'dataset_ids', 'system_prompt', 'user_identifier',
                    'enable_memori', 'tool_response', 'skills'
                ]
                config = dict(zip(columns, settings_row))
                # Resolve model credentials from the referenced models row.
                model_id = config['model_id']
                if model_id:
                    await cursor.execute(
                        """
                        SELECT model, server, api_key
                        FROM models WHERE id = %s
                        """,
                        (model_id,)
                    )
                    model_row = await cursor.fetchone()
                    if model_row:
                        config['model'] = model_row[0]
                        config['model_server'] = model_row[1]
                        config['api_key'] = model_row[2]
                    else:
                        # Dangling model_id reference: fall back to defaults.
                        logger.warning(f"Model with id {model_id} not found, using defaults")
                        config['model'] = "qwen3-next"
                        config['model_server'] = ""
                        config['api_key'] = ""
                else:
                    # No model selected: use defaults.
                    config['model'] = "qwen3-next"
                    config['model_server'] = ""
                    config['api_key'] = ""
                # Normalize dataset_ids: may be a JSON array string or a
                # comma-separated string (non-string values pass through).
                dataset_ids = config['dataset_ids']
                if dataset_ids:
                    if isinstance(dataset_ids, str):
                        if dataset_ids.startswith('['):
                            import json
                            try:
                                config['dataset_ids'] = json.loads(dataset_ids)
                            except json.JSONDecodeError:
                                config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
                        else:
                            config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
                else:
                    config['dataset_ids'] = []
                # Normalize skills (comma-separated string).
                # NOTE(review): a truthy NON-string skills value is reset to [],
                # unlike dataset_ids above which is passed through — confirm intended.
                skills = config.get('skills', '')
                if skills:
                    if isinstance(skills, str):
                        config['skills'] = [s.strip() for s in skills.split(',') if s.strip()]
                    else:
                        config['skills'] = []
                else:
                    config['skills'] = []
                # Load this bot's enabled MCP server rows.
                await cursor.execute(
                    """
                    SELECT name, type, config, enabled
                    FROM mcp_servers WHERE bot_id = %s AND enabled = true
                    """,
                    (bot_uuid,)
                )
                mcp_rows = await cursor.fetchall()
                mcp_servers = []
                for mcp_row in mcp_rows:
                    mcp_name = mcp_row[0]
                    mcp_type = mcp_row[1]
                    mcp_config = mcp_row[2]
                    # config may arrive as a JSONB/string payload; parse it.
                    if isinstance(mcp_config, str):
                        try:
                            mcp_config = json.loads(mcp_config)
                        except json.JSONDecodeError:
                            mcp_config = {}
                    mcp_servers.append({
                        "name": mcp_name,
                        "type": mcp_type,
                        "config": mcp_config
                    })
                # Re-shape into the mcp_settings structure (v2 API compatible).
                if mcp_servers:
                    mcp_settings_value = []
                    for server in mcp_servers:
                        server_config = server.get("config", {})
                        # Prefer an explicit server_type key inside the config,
                        # falling back to the row's type column.
                        server_type = server_config.pop("server_type", server["type"])
                        mcp_settings_value.append({
                            "mcpServers": {
                                server_type: server_config
                            }
                        })
                    config["mcp_settings"] = mcp_settings_value
                else:
                    config["mcp_settings"] = []
                return config
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching bot config from database: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config from database: {str(e)}"
        )
async def _sync_call_llm(llm_config, messages) -> str:
    """Invoke an LLM via LangChain and return its text response.

    NOTE(review): despite the name and the old docstring ("sync helper run in
    the thread pool"), this is a fully async call — it awaits ``ainvoke`` and
    never touches the thread pool. Consider renaming.

    Args:
        llm_config: Dict with 'model', 'model_server' and 'api_key' keys.
        messages: OpenAI-style message dicts (system/user/assistant roles;
            other roles are silently skipped).

    Returns:
        str: The model's text content, or "" on any error.
    """
    try:
        # Pull the connection parameters out of the config dict.
        model_name = llm_config.get('model')
        model_server = llm_config.get('model_server')
        api_key = llm_config.get('api_key')
        # Detect the provider (and base URL) from the model name.
        model_provider,base_url = detect_provider(model_name,model_server)
        # Build the LangChain chat-model parameters.
        model_kwargs = {
            "model": model_name,
            "model_provider": model_provider,
            "temperature": 0.8,
            "base_url":base_url,
            "api_key":api_key
        }
        llm_instance = init_chat_model(**model_kwargs)
        # Convert OpenAI-style dicts into LangChain message objects.
        langchain_messages = []
        for msg in messages:
            if msg['role'] == 'system':
                langchain_messages.append(SystemMessage(content=msg['content']))
            elif msg['role'] == 'user':
                langchain_messages.append(HumanMessage(content=msg['content']))
            elif msg['role'] == 'assistant':
                langchain_messages.append(AIMessage(content=msg['content']))
        # Asynchronously invoke the model.
        response = await llm_instance.ainvoke(langchain_messages)
        # Normalize falsy content to "".
        return response.content if response.content else ""
    except Exception as e:
        logger.error(f"Error calling guideline LLM with LangChain: {e}")
        return ""
def get_language_text(language: str):
    """Return a reply-language directive for the given language code.

    Matching is case-insensitive and 'jp' is accepted as an alias for 'ja'.
    Unknown, empty, or None codes yield "".
    """
    if not language:
        return ''
    # Normalize BEFORE aliasing so 'JP'/'Jp' also map to Japanese — previously
    # the alias check ran on the raw value while the lookup lower-cased it,
    # so 'JP' silently returned ''.
    language = language.lower()
    if language == "jp":
        language = "ja"
    language_map = {
        'zh': '请用中文回复',
        'en': 'Please reply in English',
        'ja': '日本語で回答してください',
    }
    return language_map.get(language, '')
def get_preamble_text(language: str, system_prompt: str):
    """Resolve the preamble candidate text for a bot.

    If the system prompt embeds a ``<preamble>...</preamble>`` block, its
    content is used and the block is stripped from the prompt. Otherwise a
    built-in list of short acknowledgement phrases for the given language is
    returned (newline-joined).

    Args:
        language: Language code ('zh'/'en'/'ja'; 'jp' is accepted as an
            alias, matched case-insensitively).
        system_prompt: The bot system prompt, possibly containing a
            ``<preamble>`` block.

    Returns:
        tuple[str, str]: (preamble text, system prompt with the preamble
        block removed — unchanged when no block was found).
    """
    # Prefer an explicit <preamble> block inside the system prompt.
    if system_prompt:
        preamble_pattern = r'<preamble>\s*(.*?)\s*</preamble>'
        preamble_matches = re.findall(preamble_pattern, system_prompt, re.DOTALL)
        if preamble_matches:
            preamble_content = preamble_matches[0].strip()
            if preamble_content:
                # Strip the consumed <preamble> block from the prompt.
                cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
                return preamble_content, cleaned_system_prompt
    # No usable <preamble> block: fall back to the built-in phrase lists.
    # Normalize case before aliasing so 'JP'/'Ja' etc. resolve correctly
    # (previously only lowercase 'jp' was aliased to 'ja').
    language = (language or '').lower()
    if language == "jp":
        language = "ja"
    preamble_choices_map = {
        'zh': [
            "好的,让我来帮您看看。",
            "明白了,请稍等。",
            "好的,我理解了。",
            "没问题,我来处理。",
            "收到,正在为您查询。",
            "了解,让我想想。",
            "好的,我来帮您解答。",
            "明白了,稍等片刻。",
            "好的,正在处理中。",
            "了解了,让我为您分析。"
        ],
        'en': [
            "Just a moment.",
            "Got it.",
            "Let me check that for you.",
            "Sorry to hear that.",
            "Thanks for your patience.",
            "I understand.",
            "Let me help you with that.",
            "Please wait a moment.",
            "I'll look into that for you.",
            "Gotcha, let me see.",
            "Understood, one moment please.",
            "I'll help you with this.",
            "Let me figure that out.",
            "Thanks for waiting.",
            "I'll check on that."
        ],
        'ja': [
            "少々お待ちください。",
            "承知いたしました。",
            "わかりました。",
            "確認いたします。",
            "少々お時間をください。",
            "了解しました。",
            "調べてみますね。",
            "お待たせしました。",
            "対応いたします。",
            "わかりましたね。",
            "承知いたしました。",
            "確認させてください。",
            "少々お待ちいただけますか。",
            "お調べいたします。",
            "対応いたしますね。"
        ]
    }
    default_preamble = "\n".join(preamble_choices_map.get(language, []))
    return default_preamble, system_prompt
async def call_preamble_llm(config: AgentConfig) -> str:
    """Ask the LLM to pick/produce a preamble for the current turn.

    Reads ``./prompt/preamble_prompt.md`` as the template, fills in the
    preamble candidates, recent chat history, last user message and language
    directive, then expects the model to answer with a ```json fenced object
    containing a "preamble" field.

    Args:
        config: AgentConfig carrying model credentials ('model_name',
            'api_key', 'model_server'), 'language', 'preamble_text', the
            current 'messages' and the session history.

    Returns:
        str: The chosen preamble, or "" on any failure.
    """
    try:
        with open('./prompt/preamble_prompt.md', 'r', encoding='utf-8') as f:
            preamble_template = f.read()
    except Exception as e:
        # Log text fixed: previously said "guideline prompt template".
        logger.error(f"Error reading preamble prompt template: {e}")
        return ""
    last_message = get_user_last_message_content(config.messages)
    chat_history = format_messages_to_chat_history(convert_to_openai_messages(config._session_history))
    # Fill in the template placeholders.
    system_prompt = (
        preamble_template
        .replace('{preamble_choices_text}', config.preamble_text)
        .replace('{chat_history}', chat_history)
        .replace('{last_message}', last_message)
        .replace('{language}', get_language_text(config.language))
    )
    llm_config = {
        'model': config.model_name,
        'api_key': config.api_key,
        'model_server': config.model_server,
    }
    messages = [{'role': 'user', 'content': system_prompt}]
    try:
        # Bound concurrent API calls; release the semaphore before parsing.
        async with api_semaphore:
            response = await _sync_call_llm(llm_config, messages)
        # Extract the ```json fenced block from the response.
        json_matches = re.findall(r'```json\s*\n(.*?)\n```', response, re.DOTALL)
        if not json_matches:
            logger.warning("No JSON format found in preamble analysis")
            return ""
        try:
            json_data = json.loads(json_matches[0])
            logger.info("Successfully processed preamble")
            # .get avoids a KeyError when the model omits the field.
            return json_data.get("preamble", "")
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON from preamble analysis: {e}")
            return ""
    except Exception as e:
        # Log text fixed: previously said "guideline LLM".
        logger.error(f"Error calling preamble LLM: {e}")
        return ""
async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_name: str, api_key: str, model_server: str) -> str:
    """Run the guideline-analysis prompt through the LLM.

    Args:
        chat_history: Chat history (currently unused; the prompt is expected
            to already embed it).
        guidelines_prompt: Fully rendered guideline prompt sent as the user turn.
        model_name: Model name.
        api_key: API key.
        model_server: Model server base URL.

    Returns:
        str: The model's raw response, or "" on any error.
    """
    llm_config = {
        'model': model_name,
        'api_key': api_key,
        'model_server': model_server,
    }
    prompt_messages = [{'role': 'user', 'content': guidelines_prompt}]
    try:
        # Bound concurrent API calls with the module-level semaphore.
        async with api_semaphore:
            return await _sync_call_llm(llm_config, prompt_messages)
    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
def _get_optimal_batch_size(guidelines_count: int) -> int:
"""根据guidelines数量决定最优批次数量并发数"""
if guidelines_count <= 10:
return 1
elif guidelines_count <= 20:
return 2
elif guidelines_count <= 30:
return 3
else:
return 5
def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
    """Extract tagged blocks (guidelines/tools/scenarios/terms) from a system prompt.

    XML-style blocks are parsed out and removed from the prompt; leftover
    triple blank lines are collapsed.

    Args:
        system_prompt: The raw system prompt (may be empty/None).

    Returns:
        tuple[str, str, str, str, List]: (cleaned system_prompt, guidelines,
        tools, scenarios, terms_list).
    """
    # BUG FIX: the empty-input path previously returned a 3-tuple
    # ("", [], []) while every other path (and the signature) returns a
    # 5-tuple, crashing callers that unpack five values.
    if not system_prompt:
        return "", "", "", "", []
    blocks_to_remove = []
    # Parse the three plain-text sections with the same <tag>...</tag> shape.
    sections = {"guidelines": "", "tools": "", "scenarios": ""}
    for tag in sections:
        match = re.search(rf'<{tag}>\s*(.*?)\s*</{tag}>', system_prompt, re.DOTALL)
        if match:
            sections[tag] = match.group(1).strip()
            blocks_to_remove.append(match.group(0))
    # Parse <terms>, which needs structured parsing.
    terms_list = []
    match = re.search(r'<terms>\s*(.*?)\s*</terms>', system_prompt, re.DOTALL)
    if match:
        try:
            terms_list.extend(parse_terms_text(match.group(1).strip()))
            blocks_to_remove.append(match.group(0))
        except Exception as e:
            logger.error(f"Error parsing terms: {e}")
    # Remove each consumed block from the prompt (first occurrence only).
    cleaned_prompt = system_prompt
    for block in blocks_to_remove:
        cleaned_prompt = cleaned_prompt.replace(block, '', 1)
    # Collapse runs of three-or-more newlines left behind by removals.
    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()
    return cleaned_prompt, sections["guidelines"], sections["tools"], sections["scenarios"], terms_list
def parse_terms_text(text: str) -> List[Dict[str, Any]]:
    """Parse a terms block in one of several supported formats.

    Accepts either a JSON array/object, or line-based entries shaped like
    "1) Name: foo, Description: bar, Synonyms: a; b". Lines starting with
    '#' or '//' are treated as comments and skipped.

    Args:
        text: Raw terms text.

    Returns:
        List[Dict]: Term dicts with 'name' and optional
        'description'/'synonyms' keys.
    """
    stripped = text.strip()
    # JSON fast path: a top-level array or object.
    if stripped.startswith(('[', '{')):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            data = None  # fall through to line-based parsing
        if isinstance(data, list):
            return [item for item in data if isinstance(item, dict)]
        if isinstance(data, dict):
            return [data]
    # Line-based parsing: "Name: ..., Description: ..., Synonyms: ..." with
    # an optional "<n>)" prefix, case-insensitive.
    term_pattern = re.compile(
        r'(?:\d+\)\s*)?Name:\s*([^,]+)(?:,\s*Description:\s*([^,]+))?(?:,\s*Synonyms:\s*(.+))?',
        re.IGNORECASE)
    terms: List[Dict[str, Any]] = []
    pending: Dict[str, Any] = {}
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        # Skip blanks and comment lines.
        if not line or line.startswith('#') or line.startswith('//'):
            continue
        m = term_pattern.match(line)
        if not m:
            continue
        name = m.group(1).strip()
        description = m.group(2).strip() if m.group(2) else ''
        synonyms_text = m.group(3).strip() if m.group(3) else ''
        entry: Dict[str, Any] = {'name': name}
        if description:
            entry['description'] = description
        if synonyms_text:
            # Synonyms may be separated by comma, semicolon, or pipe.
            entry['synonyms'] = [s.strip() for s in re.split(r'[,;|]', synonyms_text) if s.strip()]
        if pending:  # flush the previously collected term
            terms.append(pending)
        pending = entry
    if pending:  # flush the trailing term
        terms.append(pending)
    return terms