831 lines
32 KiB
Python
831 lines
32 KiB
Python
import os
|
||
import re
|
||
import hashlib
|
||
import json
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List, Dict, Optional, Union, Any
|
||
import aiohttp
|
||
from fastapi import HTTPException
|
||
import logging
|
||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||
from langchain.chat_models import init_chat_model
|
||
|
||
# Role-name constants shared by all message-processing helpers below.
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"

logger = logging.getLogger('app')

# Global thread pool executor used to run synchronous HTTP calls.
thread_pool = ThreadPoolExecutor(max_workers=10)

# Concurrency semaphore limiting simultaneous outbound API calls.
api_semaphore = asyncio.Semaphore(8)  # at most 8 API calls at once
|
||
|
||
def detect_provider(model_name, model_server):
    """Infer the LLM provider type from the model name.

    Returns a ``(provider, base_url)`` pair. Anthropic-style endpoints
    drop the trailing "/v1" segment; everything else (including unknown
    models) is treated as OpenAI-compatible.
    """
    lowered = model_name.lower()
    if any(marker in lowered for marker in ("claude", "anthropic")):
        return "anthropic", model_server.replace("/v1", "")
    # "gpt"/"openai"/"o1" models and unrecognized models both use the
    # OpenAI-compatible format, so a single fallback branch suffices.
    return "openai", model_server
|
||
|
||
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """Compute a versioned filename, deleting any stale copies on disk.

    Scans *upload_dir* for the original file and any ``name_N.ext``
    variants, removes them all, and returns the next filename/version
    pair to write.

    Args:
        upload_dir: directory holding the uploads
        name_without_ext: base filename without its extension
        file_extension: extension including the leading dot

    Returns:
        tuple[str, int]: (filename to use, version number)
    """
    base_name = name_without_ext + file_extension
    original_exists = os.path.exists(os.path.join(upload_dir, base_name))

    # Matches "<name>_<N><ext>" exactly, capturing the version number.
    version_re = re.compile(
        re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$'
    )

    versions_found = []
    stale_entries = []
    for entry in os.listdir(upload_dir):
        if entry == base_name:
            stale_entries.append(entry)
        else:
            hit = version_re.match(entry)
            if hit:
                versions_found.append(int(hit.group(1)))
                stale_entries.append(entry)

    # Nothing related on disk yet: keep the plain name, call it version 1.
    if not original_exists and not versions_found:
        return base_name, 1

    # Remove every stale copy (the original plus all versioned variants).
    for entry in stale_entries:
        stale_path = os.path.join(upload_dir, entry)
        try:
            os.remove(stale_path)
            logger.info(f"已删除文件: {stale_path}")
        except OSError as e:
            logger.error(f"删除文件失败 {stale_path}: {e}")

    # Next version: one past the highest seen, or 2 when only the
    # original (un-numbered) file existed.
    next_version = max(versions_found) + 1 if versions_found else 2
    return f"{name_without_ext}_{next_version}{file_extension}", next_version
|
||
|
||
def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, finish_reason: str = None) -> dict:
    """Create a standardized OpenAI-compatible streaming response chunk.

    Args:
        chunk_id: identifier shared by all chunks of one completion
        model_name: model name echoed back to the client
        content: delta text; omitted from the delta entirely when None
        finish_reason: e.g. "stop" on the final chunk, otherwise None

    Returns:
        dict: a ``chat.completion.chunk`` payload ready for JSON encoding
    """
    # Proper local import replaces the original __import__('time') hack.
    import time

    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            # An empty delta (not {"content": None}) signals "no new text".
            "delta": {"content": content} if content is not None else {},
            "finish_reason": finish_reason,
        }],
    }
|
||
|
||
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
|
||
# """Extract content from qwen-agent messages with special formatting"""
|
||
# full_text = ''
|
||
# content = []
|
||
# TOOL_CALL_S = '[TOOL_CALL]'
|
||
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
|
||
# THOUGHT_S = '[THINK]'
|
||
# ANSWER_S = '[ANSWER]'
|
||
# PREAMBLE_S = '[PREAMBLE]'
|
||
|
||
# for msg in messages:
|
||
# if msg['role'] == ASSISTANT:
|
||
# if msg.get('reasoning_content'):
|
||
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
|
||
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
|
||
# if msg.get('content'):
|
||
# assert isinstance(msg['content'], str), 'Now only supports text messages'
|
||
# # 过滤掉流式输出中的不完整 tool_call 文本
|
||
# content_text = msg["content"]
|
||
|
||
# # 使用正则表达式替换不完整的 tool_call 模式为空字符串
|
||
|
||
# # 匹配并替换不完整的 tool_call 模式
|
||
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
|
||
# # 只有在处理后内容不为空时才添加
|
||
# if content_text.strip():
|
||
# content.append(f'{ANSWER_S}\n{content_text}')
|
||
# if msg.get('function_call'):
|
||
# content_text = msg["function_call"]["arguments"]
|
||
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
|
||
# if content_text.strip():
|
||
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
|
||
# elif msg['role'] == FUNCTION:
|
||
# if tool_response:
|
||
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
|
||
# elif msg['role'] == "preamble":
|
||
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
|
||
# else:
|
||
# raise TypeError
|
||
|
||
# if content:
|
||
# full_text = '\n'.join(content)
|
||
|
||
# return full_text
|
||
|
||
|
||
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Filter assistant messages by tag, then re-expand them into tool-call form.

    Pass 1 splits each assistant message on [THINK]/[PREAMBLE]/[TOOL_CALL]/
    [TOOL_RESPONSE]/[ANSWER] tags and drops or truncates sections depending on
    how recent the message is. Pass 2 is the inverse of
    get_content_from_messages: it turns the surviving [TOOL_CALL] /
    [TOOL_RESPONSE] sections back into assistant tool_calls entries and
    role-"tool" messages.

    Args:
        messages: message objects exposing .role and .content attributes
            (e.g. pydantic models — NOTE(review): not plain dicts; confirm
            against callers)
        language: accepted for interface compatibility; not used in this body

    Returns:
        List[Dict[str, str]]: plain message dicts suitable for an
        OpenAI-style chat API
    """
    # Hard-coded whitelist: only tool names containing one of these
    # substrings survive re-expansion in pass 2.
    include_function_name = ['find', 'get']

    processed_messages = []

    # Indices of all assistant messages, used to decide "recent" vs "old".
    assistant_indices = [i for i, msg in enumerate(messages) if msg.role == ASSISTANT]
    total_assistant_messages = len(assistant_indices)
    # NOTE(review): comments below say "last 10" but this keeps the last 5
    # assistant messages — confirm which is intended.
    cutoff_point = max(0, total_assistant_messages - 5)

    # ---- Pass 1: per-message tag filtering / truncation ----
    for i, msg in enumerate(messages):
        if msg.role == ASSISTANT:
            # Position of this assistant message among all assistant messages.
            assistant_position = assistant_indices.index(i)

            # Split on the section tags; re.split with a capture group yields
            # alternating [text, tag, text, tag, ...].
            parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content)

            # Reassemble, keeping/dropping sections based on recency.
            filtered_content = ""
            current_tag = None
            is_recent_message = assistant_position >= cutoff_point  # recent assistant messages

            # NOTE(review): this inner loop reuses the name `i`, shadowing the
            # outer enumerate index inside this iteration's body.
            for i in range(0, len(parts)):
                if i % 2 == 0:  # text content (even indices)
                    text = parts[i].strip()
                    if not text:
                        continue

                    # Old tool chatter is not forwarded downstream.
                    if current_tag == "TOOL_RESPONSE":
                        if is_recent_message:
                            # Recent assistant messages: keep TOOL_RESPONSE,
                            # abbreviated when long.
                            if len(text) <= 1000:
                                filtered_content += f"[TOOL_RESPONSE] {text}\n"
                            else:
                                # Keep three 250-char excerpts: head, middle, tail.
                                first_part = text[:250]
                                middle_start = len(text) // 2 - 125
                                middle_part = text[middle_start:middle_start + 250]
                                last_part = text[-250:]

                                # Count of characters elided.
                                omitted_count = len(text) - 750
                                omitted_text = f"...此处省略{omitted_count}字..."

                                # Stitch the excerpts together.
                                truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"
                                filtered_content += f"[TOOL_RESPONSE] {truncated_text}\n"
                        # Older messages: TOOL_RESPONSE dropped entirely.
                    elif current_tag == "TOOL_CALL":
                        if is_recent_message:
                            # Recent assistant messages: keep TOOL_CALL.
                            filtered_content += f"[TOOL_CALL] {text}\n"
                        # Older messages: TOOL_CALL dropped entirely.
                    elif current_tag == "ANSWER":
                        # ANSWER sections are kept for every assistant message.
                        filtered_content += f"[ANSWER] {text}\n"
                    elif current_tag != "THINK" and current_tag != "PREAMBLE":
                        # Untagged leading text; THINK/PREAMBLE are dropped.
                        filtered_content += text + "\n"
                else:  # tag (odd indices)
                    current_tag = parts[i]

            # Use the filtered content, trimmed; fall back to the original
            # content when filtering removed everything.
            final_content = filtered_content.strip()
            if final_content:
                processed_messages.append({"role": msg.role, "content": final_content})
            else:
                processed_messages.append({"role": msg.role, "content": final_content if final_content else msg.content}) if False else processed_messages.append({"role": msg.role, "content": msg.content})
        else:
            processed_messages.append({"role": msg.role, "content": msg.content})

    # ---- Pass 2: inverse of get_content_from_messages ----
    # Re-expand [TOOL_CALL]/[TOOL_RESPONSE] sections into tool_calls
    # entries and role-"tool" messages.
    final_messages = []
    for msg in processed_messages:
        if msg["role"] == ASSISTANT:
            # Split the message content on section tags again.
            parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])

            current_tag = None
            assistant_content = ""   # NOTE(review): accumulated nowhere below
            function_calls = []      # NOTE(review): unused
            tool_responses = []      # NOTE(review): unused
            tool_id = ""             # id of the most recent emitted tool_call
            for i in range(0, len(parts)):
                if i % 2 == 0:  # text content
                    text = parts[i].strip()
                    if not text:
                        continue
                    # Old tool chatter is not forwarded downstream.

                    if current_tag == "TOOL_RESPONSE":
                        # TOOL_RESPONSE format: [TOOL_RESPONSE] function_name\ncontent
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        response_content = lines[1].strip() if len(lines) > 1 else ""

                        # Keep only whitelisted function names.
                        should_include = False
                        if function_name:
                            for exclude_name in include_function_name:
                                if exclude_name in function_name:
                                    should_include = True
                                    break

                        if should_include:
                            # Wrap as a tool-result message, paired with the
                            # preceding tool_use via tool_call_id.
                            final_messages.append({
                                "role": TOOL,
                                "tool_call_id": tool_id,  # matches the preceding tool_use id
                                "name": function_name,
                                "content": response_content
                            })
                    elif current_tag == "TOOL_CALL":
                        # TOOL_CALL format: [TOOL_CALL] function_name\narguments
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        arguments = lines[1].strip() if len(lines) > 1 else ""

                        # Keep only whitelisted function names.
                        should_include = False
                        if function_name:
                            for exclude_name in include_function_name:
                                if exclude_name in function_name:
                                    should_include = True
                                    break

                        if should_include:
                            # Synthetic id derived from the part index; the
                            # following TOOL_RESPONSE reuses it.
                            tool_id = f"tool_id_{i}"
                            final_messages.append({
                                "role": ASSISTANT,
                                "content": "",
                                "tool_calls": [{
                                    "id":tool_id,
                                    "function": {
                                        "name": function_name,
                                        "arguments": arguments
                                    }
                                }]
                            })
                    elif current_tag != "THINK" and current_tag != "PREAMBLE":
                        # Plain assistant text (untagged or [ANSWER]).
                        final_messages.append({
                            "role": ASSISTANT,
                            "content": text
                        })
                else:  # tag
                    current_tag = parts[i]
        else:
            # Non-assistant messages pass through unchanged.
            final_messages.append(msg)
    return final_messages
|
||
|
||
|
||
def get_user_last_message_content(messages: list) -> Any:
    """Return the content of the trailing message when it was sent by the user.

    Args:
        messages: conversation messages as dicts with "role"/"content" keys

    Returns:
        The last message's content if the list is non-empty and its final
        entry has role "user"; otherwise an empty string. (Annotation fixed:
        the original declared Optional[dict] but always returned content.)
    """
    if not messages:  # covers both None and []
        return ""
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        return last_message["content"]
    return ""
|
||
|
||
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
    """Render messages as a plain-text chat transcript.

    Only the last 15 rendered lines are kept.

    Args:
        messages: message dicts with role/content (tool messages carry a
            "name"; assistant messages may carry "tool_calls")

    Returns:
        str: newline-joined transcript of the most recent lines
    """
    transcript = []
    for entry in messages:
        role = entry.get('role', '')
        content = entry.get('content', '')
        name = entry.get('name', '')
        if role == USER:
            transcript.append(f"user: {content}")
        elif role == TOOL:
            transcript.append(f"{name} response: {content}")
        elif role == ASSISTANT:
            if len(content) > 0:
                transcript.append(f"assistant: {content}")
            # Render each tool invocation as "name call: arguments".
            for call in entry.get('tool_calls') or []:
                fn = call.get('function')
                transcript.append(f"{fn.get('name')} call: {fn.get('arguments')}")

    # Keep only the 15 most recent lines.
    return "\n".join(transcript[-15:])
|
||
|
||
|
||
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent") -> Optional[str]:
    """Shared logic for creating a robot project directory.

    A directory is created only for catalog agents that have at least one
    dataset; every other combination is a no-op returning None.
    """
    # Only robot_type == "catalog_agent" with a non-empty dataset_ids
    # qualifies for project creation.
    if not (robot_type == "catalog_agent" and dataset_ids):
        return None

    try:
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids, bot_id)
    except Exception as e:
        logger.error(f"Error creating project directory: {e}")
        return None
|
||
|
||
|
||
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the API key from an Authorization header value.

    A leading "Bearer " prefix is stripped when present; a missing or
    empty header yields None.
    """
    if not authorization:
        return None

    prefix = "Bearer "
    if authorization.startswith(prefix):
        return authorization[len(prefix):]
    return authorization
|
||
|
||
|
||
def generate_v2_auth_token(bot_id: str) -> str:
    """Build the auth token for the v2 endpoints.

    The token is md5("<MASTERKEY>:<bot_id>"), where MASTERKEY falls back
    to "master" when the environment variable is unset.
    """
    secret = os.getenv("MASTERKEY", "master")
    return hashlib.md5(f"{secret}:{bot_id}".encode()).hexdigest()
|
||
|
||
|
||
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch a bot's configuration from the backend API.

    Calls ``{BACKEND_HOST}/v1/agent_bot_config/{bot_id}`` with a bearer
    token derived from the master key.

    Raises:
        HTTPException: 400 for a non-200 status or an unsuccessful
            payload, 500 for connection problems or unexpected errors.
    """
    try:
        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"

        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {generate_v2_auth_token(bot_id)}",
        }

        # Async HTTP request against the backend.
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=30) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                # Parse and validate the response payload.
                response_data = await response.json()
                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                return response_data.get("data", {})

    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        )
    except Exception as e:
        # Let our own HTTPExceptions pass through untouched.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        )
|
||
|
||
|
||
async def _sync_call_llm(llm_config, messages) -> str:
    """Invoke the configured LLM through LangChain and return its text.

    NOTE(review): despite the legacy "_sync_" name this coroutine awaits
    the model asynchronously via ``ainvoke``.

    Args:
        llm_config: dict with 'model', 'model_server' and 'api_key'
        messages: role/content dicts ('system' / 'user' / 'assistant')

    Returns:
        str: the response content, or "" on any failure.
    """
    try:
        model_name = llm_config.get('model')
        model_server = llm_config.get('model_server')
        api_key = llm_config.get('api_key')

        # Resolve provider and base URL from the model name.
        model_provider, base_url = detect_provider(model_name, model_server)

        # Build the LangChain chat-model instance.
        llm_instance = init_chat_model(
            model=model_name,
            model_provider=model_provider,
            temperature=0.8,
            base_url=base_url,
            api_key=api_key,
        )

        # Map plain role/content dicts onto LangChain message objects;
        # unknown roles are skipped, matching the original behavior.
        message_types = {
            'system': SystemMessage,
            'user': HumanMessage,
            'assistant': AIMessage,
        }
        langchain_messages = [
            message_types[m['role']](content=m['content'])
            for m in messages
            if m['role'] in message_types
        ]

        # Await the model and unwrap its text content.
        response = await llm_instance.ainvoke(langchain_messages)
        return response.content if response.content else ""

    except Exception as e:
        logger.error(f"Error calling guideline LLM with LangChain: {e}")
        return ""
|
||
|
||
def get_language_text(language: str):
    """Return the reply-language instruction for a language code.

    "jp" is normalized to "ja"; unknown codes yield an empty string.
    """
    code = "ja" if language == "jp" else language
    instructions = {
        'zh': '请用中文回复',
        'en': 'Please reply in English',
        'ja': '日本語で回答してください',
    }
    return instructions.get(code.lower(), '')
|
||
|
||
def get_preamble_text(language: str, system_prompt: str):
    """Resolve the preamble text for a conversation.

    If *system_prompt* contains a ```preamble fenced block, its content
    wins: the block's text is returned and the block is stripped from the
    prompt. Otherwise a newline-joined list of canned per-language
    preamble choices is returned and the prompt is left untouched.

    Args:
        language: language code ("jp" is normalized to "ja")
        system_prompt: raw system prompt, possibly containing a
            ```preamble block

    Returns:
        tuple[str, str]: (preamble text, system prompt with any preamble
        block removed)
    """
    # An explicit ```preamble block inside the system prompt wins.
    if system_prompt:
        preamble_pattern = r'```preamble\s*\n(.*?)\n```'
        preamble_matches = re.findall(preamble_pattern, system_prompt, re.DOTALL)
        if preamble_matches:
            # Use the first block's content and strip the block itself.
            preamble_content = preamble_matches[0].strip()
            cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
            return preamble_content, cleaned_system_prompt

    # No preamble block found: fall back to the per-language defaults.
    if language == "jp":
        language = "ja"
    preamble_choices_map = {
        'zh': [
            "好的,让我来帮您看看。",
            "明白了,请稍等。",
            "好的,我理解了。",
            "没问题,我来处理。",
            "收到,正在为您查询。",
            "了解,让我想想。",
            "好的,我来帮您解答。",
            "明白了,稍等片刻。",
            "好的,正在处理中。",
            "了解了,让我为您分析。"
        ],
        'en': [
            "Just a moment.",
            "Got it.",
            "Let me check that for you.",
            "Sorry to hear that.",
            "Thanks for your patience.",
            "I understand.",
            "Let me help you with that.",
            "Please wait a moment.",
            "I'll look into that for you.",
            "Gotcha, let me see.",
            "Understood, one moment please.",
            "I'll help you with this.",
            "Let me figure that out.",
            "Thanks for waiting.",
            "I'll check on that."
        ],
        'ja': [
            "少々お待ちください。",
            "承知いたしました。",
            "わかりました。",
            "確認いたします。",
            "少々お時間をください。",
            "了解しました。",
            "調べてみますね。",
            "お待たせしました。",
            "対応いたします。",
            "わかりましたね。",
            "承知いたしました。",
            "確認させてください。",
            "少々お待ちいただけますか。",
            "お調べいたします。",
            "対応いたしますね。"
        ]
    }  # stray trailing semicolon removed
    default_preamble = "\n".join(preamble_choices_map.get(language.lower(), []))
    return default_preamble, system_prompt  # default preamble, original prompt
|
||
|
||
|
||
async def call_preamble_llm(chat_history: str, last_message: str, preamble_choices_text: str, language: str, model_name: str, api_key: str, model_server: str) -> str:
    """Ask the LLM to pick a conversation preamble.

    Fills the preamble prompt template, invokes the model (throttled by
    the global API semaphore), and extracts the "preamble" field from the
    ```json fenced payload in the reply.

    Args:
        chat_history: chat transcript text
        last_message: the user's latest message
        preamble_choices_text: candidate preambles offered to the model
        language: language code used for the reply-language instruction
        model_name: model to invoke
        api_key: API key for the model server
        model_server: model server address

    Returns:
        str: the chosen preamble, or "" on any failure.
    """
    # Load the preamble prompt template from disk.
    try:
        with open('./prompt/preamble_prompt.md', 'r', encoding='utf-8') as f:
            preamble_template = f.read()
    except Exception as e:
        logger.error(f"Error reading guideline prompt template: {e}")
        return ""

    # Substitute the template placeholders (same order as the original
    # chained replaces).
    substitutions = (
        ('{preamble_choices_text}', preamble_choices_text),
        ('{chat_history}', chat_history),
        ('{last_message}', last_message),
        ('{language}', get_language_text(language)),
    )
    system_prompt = preamble_template
    for placeholder, value in substitutions:
        system_prompt = system_prompt.replace(placeholder, value)

    # LLM configuration from the caller-supplied parameters.
    llm_config = {
        'model': model_name,
        'api_key': api_key,
        'model_server': model_server,  # caller-supplied model server
    }

    # Single-turn prompt for the model.
    messages = [{'role': 'user', 'content': system_prompt}]

    try:
        # The semaphore caps the number of concurrent API calls.
        async with api_semaphore:
            response = await _sync_call_llm(llm_config, messages)

        # Pull the ```json fenced payload out of the reply.
        json_matches = re.findall(r'```json\s*\n(.*?)\n```', response, re.DOTALL)

        if not json_matches:
            logger.warning(f"No JSON format found in preamble analysis")
            return ""

        try:
            # Parse the first JSON object found.
            json_data = json.loads(json_matches[0])
            logger.info(f"Successfully processed preamble")
            return json_data["preamble"]  # the selected preamble
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON from preamble analysis: {e}")
            return ""

    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
|
||
|
||
|
||
|
||
async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_name: str, api_key: str, model_server: str) -> str:
    """Run the guideline analysis prompt through the LLM.

    Args:
        chat_history: chat transcript (not referenced in this body; kept
            for interface compatibility)
        guidelines_prompt: the fully rendered guideline prompt
        model_name: model to invoke
        api_key: API key for the model server
        model_server: model server address

    Returns:
        str: the raw model response, or "" on failure.
    """
    # LLM configuration from the caller-supplied parameters.
    llm_config = {
        'model': model_name,
        'api_key': api_key,
        'model_server': model_server,  # caller-supplied model server
    }

    # Single-turn prompt for the model.
    prompt_messages = [{'role': 'user', 'content': guidelines_prompt}]

    try:
        # The semaphore caps the number of concurrent API calls.
        async with api_semaphore:
            return await _sync_call_llm(llm_config, prompt_messages)
    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
|
||
|
||
|
||
def _get_optimal_batch_size(guidelines_count: int) -> int:
|
||
"""根据guidelines数量决定最优批次数量(并发数)"""
|
||
if guidelines_count <= 10:
|
||
return 1
|
||
elif guidelines_count <= 20:
|
||
return 2
|
||
elif guidelines_count <= 30:
|
||
return 3
|
||
else:
|
||
return 5
|
||
|
||
def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
    """Extract guideline/tools/scenarios/terms blocks from a system prompt.

    Fenced ```guideline (or ```guidelines) and ```terms blocks are parsed
    and removed from the prompt; ```tools and ```scenarios blocks are
    captured but intentionally left in place.

    Args:
        system_prompt: the raw system prompt

    Returns:
        tuple[str, str, str, str, List]: (cleaned system prompt,
        guidelines text, tools text, scenarios text, parsed terms list)
    """
    if not system_prompt:
        # Bug fix: this path previously returned a 3-tuple ("", [], []),
        # breaking callers that unpack five values.
        return "", "", "", "", []

    guidelines = ""
    tools = ""
    scenarios = ""

    terms_list = []

    # Walk every fenced code block in the prompt.
    block_pattern = r'```(\w+)\s*\n(.*?)\n```'
    blocks_to_remove = []

    for match in re.finditer(block_pattern, system_prompt, re.DOTALL):
        block_type, content = match.groups()

        if block_type == 'guideline' or block_type == 'guidelines':
            guidelines = content.strip()
            blocks_to_remove.append(match.group(0))
        elif block_type == 'tools':
            tools = content.strip()
        elif block_type == 'scenarios':
            scenarios = content.strip()
        elif block_type == 'terms':
            try:
                terms = parse_terms_text(content.strip())
                terms_list.extend(terms)
                blocks_to_remove.append(match.group(0))
            except Exception as e:
                logger.error(f"Error parsing terms: {e}")

    # Remove the parsed guideline/terms blocks from the prompt.
    cleaned_prompt = system_prompt
    for block in blocks_to_remove:
        cleaned_prompt = cleaned_prompt.replace(block, '', 1)

    # Collapse runs of blank lines left behind by the removals.
    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()
    return cleaned_prompt, guidelines, tools, scenarios, terms_list
|
||
|
||
|
||
def parse_guidelines_text(text: str) -> List[Dict[str, Any]]:
    """Parse guidelines text in JSON or numbered-line form.

    Accepts either a JSON list/object of guideline dicts, or lines shaped
    like ``1) Condition: ... Action: ... Priority: 2``.

    Args:
        text: raw guidelines text

    Returns:
        List[Dict]: parsed guideline dicts
    """
    stripped = text.strip()

    # JSON form: a list of dicts or a single dict.
    if stripped.startswith(('[', '{')):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            pass  # fall through to line-based parsing
        else:
            if isinstance(data, list):
                return [item for item in data if isinstance(item, dict)]
            if isinstance(data, dict):
                return [data]
            return []

    # Line form: "id) Condition: ... Action: ... [Priority: N]".
    line_re = re.compile(
        r'(\d+)\)\s*Condition:\s*(.*?)\s*Action:\s*(.*?)(?:\s*Priority:\s*(\d+))?$',
        re.IGNORECASE,
    )

    parsed = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        # Skip blanks and comment lines.
        if not line or line.startswith(('#', '//')):
            continue

        hit = line_re.match(line)
        if hit:
            parsed.append({
                'guideline_id': int(hit.group(1)),
                'condition': hit.group(2).strip(),
                'action': hit.group(3).strip(),
                # Priority defaults to 1 when absent.
                'priority': int(hit.group(4)) if hit.group(4) else 1,
            })

    return parsed
|
||
|
||
|
||
def parse_terms_text(text: str) -> List[Dict[str, Any]]:
    """Parse terms text in JSON or "Name: ..." line form.

    Accepts either a JSON list/object of term dicts, or lines shaped like
    ``1) Name: foo, Description: bar, Synonyms: a, b`` (the numbered
    prefix, description and synonyms are all optional).

    Args:
        text: raw terms text

    Returns:
        List[Dict]: parsed term dicts with name/description/synonyms
    """
    stripped = text.strip()

    # JSON form: a list of dicts or a single dict.
    if stripped.startswith(('[', '{')):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            pass  # fall through to line-based parsing
        else:
            if isinstance(data, list):
                return [item for item in data if isinstance(item, dict)]
            if isinstance(data, dict):
                return [data]
            return []

    # Line form: "N) Name: x, Description: y, Synonyms: a, b".
    term_re = re.compile(
        r'(?:\d+\)\s*)?Name:\s*([^,]+)(?:,\s*Description:\s*([^,]+))?(?:,\s*Synonyms:\s*(.+))?',
        re.IGNORECASE,
    )

    parsed = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        # Skip blanks and comment lines.
        if not line or line.startswith(('#', '//')):
            continue

        hit = term_re.match(line)
        if not hit:
            continue

        term = {'name': hit.group(1).strip()}
        # Optional fields are only attached when non-empty after stripping.
        description = hit.group(2).strip() if hit.group(2) else ''
        if description:
            term['description'] = description
        synonyms_text = hit.group(3).strip() if hit.group(3) else ''
        if synonyms_text:
            synonyms = re.split(r'[,;|]', synonyms_text)
            term['synonyms'] = [s.strip() for s in synonyms if s.strip()]
        parsed.append(term)

    return parsed
|