qwen_agent/utils/fastapi_utils.py
2026-06-18 14:56:27 +08:00

1146 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import hashlib
import json
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Optional, Union, Any
import aiohttp
from fastapi import HTTPException
import logging
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, convert_to_openai_messages
from langchain.chat_models import init_chat_model
from utils.settings import MASTERKEY, BACKEND_HOST
from agent.agent_config import AgentConfig
# Message role identifiers used throughout this module.
USER, ASSISTANT, TOOL = "user", "assistant", "tool"

logger = logging.getLogger('app')

# Global pool of 10 worker threads intended for synchronous HTTP calls.
thread_pool = ThreadPoolExecutor(max_workers=10)

# Caps how many LLM API calls may be in flight at the same time (8).
# NOTE(review): created at import time; on Python < 3.10 an asyncio.Semaphore
# binds to the event loop current at construction, which may differ from the
# server's loop — confirm the runtime is 3.10+.
api_semaphore = asyncio.Semaphore(8)
def detect_provider(model_name, model_server):
    """Detect the provider type and base URL to use for a model.

    Args:
        model_name: Model identifier, matched case-insensitively against
            provider keywords. ``None`` is treated as an empty name.
        model_server: Configured server URL (may end in ``/v1``); may be None.

    Returns:
        tuple: ``("anthropic", base_url)`` for Claude/Anthropic models, where a
        single trailing ``/v1`` path segment (optionally followed by ``/``) is
        removed from ``model_server`` — presumably because the Anthropic client
        adds its own versioned path (TODO confirm). Every other model, including
        unknown ones, gets ``("openai", model_server)`` unchanged, i.e. the
        OpenAI-compatible format.
    """
    model_name_lower = (model_name or "").lower()
    if any(keyword in model_name_lower for keyword in ("claude", "anthropic")):
        # Only strip a trailing "/v1"; a blanket str.replace would also corrupt
        # hosts such as "https://v1.example.com" or paths such as "/v10".
        if model_server:
            trimmed = model_server.rstrip("/")
            if trimmed.endswith("/v1"):
                return "anthropic", trimmed[:-len("/v1")]
        return "anthropic", model_server
    # GPT / OpenAI / o1 models and all unrecognized models use the OpenAI-compatible format.
    return "openai", model_server
def is_anthropic_opus_model(model_name: Optional[str]) -> bool:
    """Return True when ``model_name`` contains "opus" (case-insensitive).

    ``None`` or an empty name yields False.
    """
    if not model_name:
        return False
    return "opus" in model_name.lower()
def sanitize_model_kwargs(
    model_name: str,
    model_provider: str,
    base_url: Optional[str],
    api_key: Optional[str],
    generate_cfg: Optional[Dict[str, Any]] = None,
    source: str = "agent"
) -> tuple[Dict[str, Any], List[str], bool]:
    """Build ``init_chat_model`` kwargs, dropping parameters the provider cannot accept.

    Args:
        model_name: Model identifier.
        model_provider: Provider name; ``"anthropic"`` additionally filters out
            OpenAI-only sampling parameters.
        base_url: Server base URL passed through to the model.
        api_key: API key passed through to the model.
        generate_cfg: Extra generation parameters merged on top of the base
            kwargs after filtering (they may override base keys).
        source: Label included in the log line to identify the caller.

    Returns:
        tuple: ``(model_kwargs, dropped_params, default_temperature_applied)``.
        ``dropped_params`` lists the ``generate_cfg`` keys that were removed.
        ``default_temperature_applied`` is True only when the 0.8 default
        temperature is what ends up in ``model_kwargs`` (no temperature was
        supplied and the model is not an Anthropic Opus model, for which
        temperature is never sent).
    """
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "base_url": base_url,
        "api_key": api_key
    }
    # Project-internal settings that must never reach the provider SDK.
    internal_params = {
        'tool_output_max_length',
        'tool_output_truncation_strategy',
        'tool_output_filters',
        'tool_output_exclude',
        'preserve_code_blocks',
        'preserve_json',
    }
    # OpenAI-specific parameters that are filtered out for Anthropic.
    openai_only_params = {
        'n',
        'presence_penalty',
        'frequency_penalty',
        'logprobs',
        'top_logprobs',
        'logit_bias',
        'seed',
        'suffix',
        'best_of',
        'echo',
        'user',
    }
    params_to_filter = set(internal_params)
    is_opus_model = model_provider == 'anthropic' and is_anthropic_opus_model(model_name)
    if model_provider == 'anthropic':
        params_to_filter.update(openai_only_params)
    if is_opus_model:
        params_to_filter.add('temperature')

    cfg = generate_cfg or {}
    original_keys = list(cfg)
    filtered_cfg = {k: v for k, v in cfg.items() if k not in params_to_filter}
    dropped_params = [k for k in original_keys if k in params_to_filter]

    # Only fall back to the default when the caller did not supply a temperature;
    # otherwise the flag would claim a default that the update below overrides.
    default_temperature_applied = not is_opus_model and 'temperature' not in filtered_cfg
    if default_temperature_applied:
        model_kwargs["temperature"] = 0.8
    model_kwargs.update(filtered_cfg)

    logger.info(
        "sanitize_model_kwargs source=%s provider=%s model=%s original_keys=%s dropped_keys=%s default_temperature_applied=%s",
        source,
        model_provider,
        model_name,
        original_keys,
        dropped_params,
        default_temperature_applied
    )
    return model_kwargs, dropped_params, default_temperature_applied
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """Pick the filename for a new upload, removing earlier copies of it.

    When neither ``<name><ext>`` nor any ``<name>_<N><ext>`` exists in
    ``upload_dir``, the plain name is returned with version 1. Otherwise every
    such existing file is deleted (failures are logged, not raised) and the name
    for the next version is returned: one above the highest ``N`` seen, or 2 if
    only the plain file existed.

    Args:
        upload_dir: Directory that holds the uploads.
        name_without_ext: Base filename without its extension.
        file_extension: Extension including the leading dot.

    Returns:
        tuple[str, int]: ``(filename, version_number)``.
    """
    base_name = f"{name_without_ext}{file_extension}"
    had_original = os.path.exists(os.path.join(upload_dir, base_name))
    version_re = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$')

    stale_files = []
    versions_seen = []
    for entry in os.listdir(upload_dir):
        if entry == base_name:
            stale_files.append(entry)
            continue
        found = version_re.match(entry)
        if found:
            versions_seen.append(int(found.group(1)))
            stale_files.append(entry)

    if not (had_original or versions_seen):
        return base_name, 1

    for entry in stale_files:
        stale_path = os.path.join(upload_dir, entry)
        try:
            os.remove(stale_path)
            logger.info(f"Deleted file: {stale_path}")
        except OSError as e:
            logger.error(f"Failed to delete file {stale_path}: {e}")

    next_version = max(versions_seen) + 1 if versions_seen else 2
    return f"{name_without_ext}_{next_version}{file_extension}", next_version
def create_stream_chunk(chunk_id: str, model_name: str, content: Optional[str] = None,
                        finish_reason: Optional[str] = None) -> dict:
    """Build one OpenAI-compatible ``chat.completion.chunk`` payload for streaming.

    Args:
        chunk_id: Identifier shared by all chunks of one response.
        model_name: Model name reported in the chunk.
        content: Text delta; when None the delta is an empty dict (an empty
            string still produces ``{"content": ""}``).
        finish_reason: Finish reason for the final chunk, otherwise None.

    Returns:
        dict: The chunk, with ``created`` set to the current Unix time in seconds.
    """
    import time

    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "delta": {"content": content} if content is not None else {},
            "finish_reason": finish_reason
        }]
    }
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
# """Extract content from qwen-agent messages with special formatting"""
# full_text = ''
# content = []
# TOOL_CALL_S = '[TOOL_CALL]'
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
# THOUGHT_S = '[THINK]'
# ANSWER_S = '[ANSWER]'
# PREAMBLE_S = '[PREAMBLE]'
# for msg in messages:
# if msg['role'] == ASSISTANT:
# if msg.get('reasoning_content'):
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
# if msg.get('content'):
# assert isinstance(msg['content'], str), 'Now only supports text messages'
# # Filter out incomplete tool_call text from streaming output
# content_text = msg["content"]
# # Use regex to replace incomplete tool_call patterns with empty string
# # Match and replace incomplete tool_call patterns
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# # Only add if content is not empty after processing
# if content_text.strip():
# content.append(f'{ANSWER_S}\n{content_text}')
# if msg.get('function_call'):
# content_text = msg["function_call"]["arguments"]
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
# if content_text.strip():
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
# elif msg['role'] == FUNCTION:
# if tool_response:
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
# elif msg['role'] == "preamble":
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
# else:
# raise TypeError
# if content:
# full_text = '\n'.join(content)
# return full_text
def normalize_content_blocks(content: Union[str, List[Dict[str, Any]]]) -> Union[str, List[Dict[str, Any]]]:
    """Convert multimodal content blocks to LangChain standard content blocks.

    Non-list content (e.g. a plain string) is returned as-is. For list content:

    - bare strings become ``{"type": "text", "text": ...}``;
    - other non-dict items are dropped;
    - ``text`` blocks are rebuilt with only ``type``/``text``;
    - OpenAI ``image_url`` blocks become ``{"type": "image", ...}``: ``data:``
      URLs are split into ``base64`` + ``mime_type`` (defaulting to
      ``image/jpeg``), other URLs are kept as ``url``. Blocks without a URL are
      skipped, and data URLs without a comma are skipped with a warning;
    - ``image`` blocks and unknown block types pass through unchanged.

    Emitting the LangChain standard form lets the provider's block translator
    handle either OpenAI or Anthropic.
    """
    if not isinstance(content, list):
        return content

    result: List[Dict[str, Any]] = []
    for item in content:
        if isinstance(item, str):
            result.append({"type": "text", "text": item})
            continue
        if not isinstance(item, dict):
            continue

        kind = item.get("type")
        if kind == "text":
            result.append({"type": "text", "text": item.get("text", "")})
        elif kind == "image_url":
            raw = item.get("image_url")
            url = raw.get("url") if isinstance(raw, dict) else raw
            if not url:
                continue
            if not (isinstance(url, str) and url.startswith("data:")):
                result.append({"type": "image", "url": url})
                continue
            # Data URL layout: data:<mime_type>;base64,<data>
            try:
                header, payload = url.split(",", 1)
                mime = header.split(";", 1)[0].removeprefix("data:") or "image/jpeg"
                result.append({"type": "image", "base64": payload, "mime_type": mime})
            except ValueError:
                logger.warning("Skipping malformed data URL in image_url block")
        else:
            # Already-standard "image" blocks and unknown types pass through untouched.
            result.append(item)
    return result
# Section markers used inside assistant history content. The capture group makes
# re.split return the marker names at the odd indexes of its result.
_HISTORY_TAG_PATTERN = re.compile(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]')


def _iter_tagged_sections(content: str):
    """Yield ``(tag, text)`` for each non-blank section of tagged assistant content.

    ``tag`` is the most recent marker name before the section, or None for text
    that appears before the first marker; ``text`` is stripped.
    """
    current_tag = None
    for index, part in enumerate(_HISTORY_TAG_PATTERN.split(content)):
        if index % 2 == 1:
            current_tag = part
            continue
        text = part.strip()
        if text:
            yield current_tag, text


def _abbreviate_tool_response(text: str) -> str:
    """Shorten text longer than 1000 chars to 250-char head, middle and tail excerpts."""
    if len(text) <= 1000:
        return text
    first_part = text[:250]
    middle_start = len(text) // 2 - 125
    middle_part = text[middle_start:middle_start + 250]
    last_part = text[-250:]
    omitted_text = f"...{len(text) - 750} chars omitted..."
    return f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"


def _compress_assistant_content(content: str, is_recent: bool) -> str:
    """First pass: rebuild one assistant history string with only the forwarded sections.

    THINK and PREAMBLE sections are dropped; ANSWER sections and untagged text
    are kept; TOOL_CALL and TOOL_RESPONSE sections (the latter abbreviated) are
    kept only when ``is_recent`` is true. Returns "" when nothing is kept.
    """
    kept = []
    for tag, text in _iter_tagged_sections(content):
        if tag == "TOOL_RESPONSE":
            if is_recent:
                kept.append(f"[TOOL_RESPONSE] {_abbreviate_tool_response(text)}")
        elif tag == "TOOL_CALL":
            if is_recent:
                kept.append(f"[TOOL_CALL] {text}")
        elif tag == "ANSWER":
            kept.append(f"[ANSWER] {text}")
        elif tag is None:
            kept.append(text)
    return "\n".join(kept).strip()


def _matches_any_keyword(function_name: str, keywords: List[str]) -> bool:
    """Return True when ``function_name`` is non-empty and contains any keyword as a substring."""
    return bool(function_name) and any(keyword in function_name for keyword in keywords)


def _split_name_and_body(text: str) -> tuple[str, str]:
    """Split text at its first newline into (first line, remainder), both stripped."""
    name, _, body = text.partition('\n')
    return name.strip(), body.strip()


def _expand_assistant_content(content: str, include_function_name: List[str], tool_id_counter) -> List[Dict[str, Any]]:
    """Second pass: turn one assistant history string into separate chat messages.

    Untagged and ANSWER text becomes plain assistant messages. A TOOL_CALL whose
    function name (its first line) contains one of ``include_function_name``
    becomes an assistant ``tool_calls`` message with the next id drawn from
    ``tool_id_counter``. A matching TOOL_RESPONSE becomes a ``tool`` message
    linked to the oldest still-unanswered call id from the same string, and is
    dropped if there is none. THINK/PREAMBLE sections are dropped.
    """
    expanded: List[Dict[str, Any]] = []
    pending_tool_ids: List[str] = []
    for tag, text in _iter_tagged_sections(content):
        if tag == "TOOL_RESPONSE":
            function_name, response_content = _split_name_and_body(text)
            if _matches_any_keyword(function_name, include_function_name) and pending_tool_ids:
                expanded.append({
                    "role": TOOL,
                    "tool_call_id": pending_tool_ids.pop(0),  # pairs with the preceding tool call
                    "name": function_name,
                    "content": response_content
                })
        elif tag == "TOOL_CALL":
            function_name, arguments = _split_name_and_body(text)
            if _matches_any_keyword(function_name, include_function_name):
                tool_id = f"tool_id_{next(tool_id_counter)}"
                pending_tool_ids.append(tool_id)
                expanded.append({
                    "role": ASSISTANT,
                    "content": "",
                    "tool_calls": [{
                        "id": tool_id,
                        "function": {
                            "name": function_name,
                            "arguments": arguments
                        }
                    }]
                })
        elif tag not in ("THINK", "PREAMBLE"):
            expanded.append({"role": ASSISTANT, "content": text})
    return expanded


def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Convert client chat history into messages to forward to the agent model.

    This is the inverse of the (commented-out) get_content_from_messages, which
    flattened messages into ``[THINK]``/``[PREAMBLE]``/``[TOOL_CALL]``/
    ``[TOOL_RESPONSE]``/``[ANSWER]`` tagged text.

    Pass 1: each assistant message with string content is compressed.
    THINK/PREAMBLE sections are dropped. TOOL_CALL and TOOL_RESPONSE sections
    are kept, with long responses abbreviated, only for the 5 most recent
    assistant messages. If nothing remains, the original content is used. All
    other messages, including assistant messages with list content, have their
    content passed through normalize_content_blocks.

    Pass 2: each assistant string is expanded back into separate messages.
    Plain and ANSWER text becomes assistant messages. Tool calls/responses whose
    function name contains 'find' or 'get' become assistant ``tool_calls`` and
    ``tool`` messages linked by id. Ids are numbered across the whole returned
    list, so they are unique per conversation.

    NOTE(review): a kept TOOL_CALL whose response is filtered out or missing
    leaves a tool_calls message with no tool reply; OpenAI-compatible APIs
    usually reject that — confirm whether such calls should be dropped.

    Args:
        messages: Objects exposing ``.role`` and ``.content`` attributes.
        language: Currently unused; kept for interface compatibility.

    Returns:
        List of message dicts (``role``/``content``, plus ``tool_calls`` or
        ``tool_call_id``/``name`` for tool traffic).
    """
    import itertools

    include_function_name = ['find', 'get']
    total_assistant_messages = sum(1 for msg in messages if msg.role == ASSISTANT)
    # Assistant messages at or after this position count as "recent".
    cutoff_point = max(0, total_assistant_messages - 5)

    processed_messages = []
    assistant_position = 0
    for msg in messages:
        if msg.role == ASSISTANT:
            is_recent_message = assistant_position >= cutoff_point
            assistant_position += 1
            if isinstance(msg.content, str):
                final_content = _compress_assistant_content(msg.content, is_recent_message)
                # Fall back to the original text if compression removed everything.
                processed_messages.append({"role": msg.role, "content": final_content or msg.content})
                continue
        processed_messages.append({"role": msg.role, "content": normalize_content_blocks(msg.content)})

    tool_id_counter = itertools.count()
    final_messages = []
    for msg in processed_messages:
        if msg["role"] == ASSISTANT and isinstance(msg["content"], str):
            final_messages.extend(_expand_assistant_content(msg["content"], include_function_name, tool_id_counter))
        else:
            final_messages.append(msg)
    return final_messages
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Flatten message content into plain text.

    A string is returned unchanged. For a list, the ``text`` of every
    ``{"type": "text"}`` block (``""`` if missing) and every bare string are
    joined with newlines; other blocks are ignored. Any other value yields "".
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    return "\n".join(
        block if isinstance(block, str) else block.get("text", "")
        for block in content
        if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
    )
def get_user_last_message_content(messages: list) -> Optional[str]:
    """Return the plain text of the final message if it is from the user, else "".

    List (multimodal) content is flattened with extract_text_from_content, so
    downstream consumers such as terms embedding always receive a string.
    Empty or None ``messages`` and a falsy last entry also give "".
    """
    last_message = messages[-1] if messages else None
    if not last_message or last_message.get('role') != 'user':
        return ""
    return extract_text_from_content(last_message.get("content", ""))
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
    """Format OpenAI-style message dicts as plain-text chat history.

    Each message becomes one or more lines: ``user: ...``, ``<name> response:
    ...`` for tool messages, ``assistant: ...`` for non-empty assistant text,
    and ``<function> call: <arguments>`` for each assistant tool call. The very
    last line (normally the current user turn) is excluded, and at most the 15
    lines before it are kept.

    Content that is a list of blocks is flattened to its text blocks, so
    image/base64 payloads never end up in the prompt; None content is treated
    as empty.

    Args:
        messages: Message list.

    Returns:
        str: Formatted chat history joined with newlines.
    """
    chat_history = []
    for message in messages:
        role = message.get('role', '')
        content = extract_text_from_content(message.get('content') or '')
        name = message.get('name', '')
        if role == USER:
            chat_history.append(f"user: {content}")
        elif role == TOOL:
            chat_history.append(f"{name} response: {content}")
        elif role == ASSISTANT:
            if content:
                chat_history.append(f"assistant: {content}")
            for tool_call in message.get('tool_calls') or []:
                function = tool_call.get('function') or {}
                chat_history.append(f"{function.get('name')} call: {function.get('arguments')}")
    # [-16:-1] drops the last entry and keeps up to 15 before it (Python clamps short lists).
    return "\n".join(chat_history[-16:-1])
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, skills: Optional[List[str]] = None) -> Optional[str]:
    """Create the project directory for a bot via ``create_robot_project``.

    The project is created even when ``dataset_ids`` is empty or None; in that
    case an empty list is passed.

    Args:
        dataset_ids: Dataset IDs to include; None/empty means none.
        bot_id: Bot identifier.
        skills: Optional skill names forwarded unchanged.

    Returns:
        Whatever ``create_robot_project`` returns, or None if it (or its import)
        raised; the error is logged.
    """
    try:
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids or [], bot_id, skills=skills)
    except Exception as e:
        logger.error(f"Error creating project directory: {e}")
        return None
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the API key from an ``Authorization`` header value.

    A leading ``Bearer`` scheme is removed and the remaining credentials are
    stripped. The scheme is matched case-insensitively, since HTTP auth schemes
    are case-insensitive (RFC 7235 §2.1). Any other value is returned unchanged.

    Args:
        authorization: Raw header value, possibly None.

    Returns:
        The key, or None when the header is missing or empty.
    """
    if not authorization:
        return None
    scheme, separator, credentials = authorization.partition(" ")
    if separator and scheme.lower() == "bearer":
        return credentials.strip()
    return authorization
def generate_v2_auth_token(bot_id: str) -> str:
    """Return the v2 API auth token for a bot.

    The token is the hex MD5 digest of ``"<MASTERKEY>:<bot_id>"``.
    NOTE(review): MD5 of a concatenated secret is not an HMAC; it is kept as-is
    because the backend presumably verifies this exact digest — confirm before
    changing.
    """
    digest = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode())
    return digest.hexdigest()
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch a bot's configuration from the backend API.

    Sends ``GET {BACKEND_HOST}/v1/agent_bot_config/{bot_id}`` with the v2 auth
    token and a 30 second total timeout. Returns the ``data`` field of a
    successful response, or ``{}`` when that field is absent.

    Raises:
        HTTPException: 400 when the backend answers with a non-200 status or a
            falsy ``success`` flag; 500 on timeout, connection/protocol errors,
            a non-JSON response body, or any other unexpected failure.
    """
    try:
        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        async with aiohttp.ClientSession() as session:
            # ClientTimeout is the documented type for aiohttp's timeout argument.
            async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )
                response_data = await response.json()
                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )
                return response_data.get("data", {})
    except HTTPException:
        raise
    except asyncio.TimeoutError as e:
        # str() of a TimeoutError is empty, so give the client an explicit reason.
        raise HTTPException(
            status_code=500,
            detail="Failed to connect to backend API: request timed out after 30 seconds"
        ) from e
    except aiohttp.ContentTypeError as e:
        # The backend responded, but not with JSON; this is not a connection failure.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: invalid JSON response ({e})"
        ) from e
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        ) from e
def _decode_json_dict(raw: Any) -> Optional[Dict[str, Any]]:
    """Return ``raw`` as a dict, JSON-decoding it first when it is a string.

    Returns None when a string is not valid JSON or when the (decoded) value is
    not a dict, so each caller can choose its own fallback.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return None
    return raw if isinstance(raw, dict) else None


def _normalize_dataset_ids(dataset_ids: Any) -> Any:
    """Coerce the ``dataset_ids`` setting into a list.

    Falsy values become ``[]``. A string starting with ``[`` is parsed as JSON,
    falling back to comma splitting if that fails; any other string is split on
    commas. Any other non-empty value is returned unchanged.
    """
    if not dataset_ids:
        return []
    if not isinstance(dataset_ids, str):
        return dataset_ids
    if dataset_ids.startswith('['):
        try:
            return json.loads(dataset_ids)
        except json.JSONDecodeError:
            pass
    return [d.strip() for d in dataset_ids.split(',')]


def _split_setting_list(value: Any, separator: str) -> List[Any]:
    """Coerce a list-valued setting.

    Strings are split on ``separator`` with blank items dropped, lists are kept,
    and anything else becomes ``[]``.
    """
    if isinstance(value, str):
        return [item.strip() for item in value.split(separator) if item.strip()]
    return value if isinstance(value, list) else []


def _build_mcp_settings(mcp_rows) -> List[Dict[str, Any]]:
    """Convert ``agent_mcp_servers`` rows into the v2-API ``mcp_settings`` list.

    Rows are read positionally as (name, type, config, enabled). A config that
    is invalid JSON or not a dict (e.g. NULL) is treated as ``{}``. The server
    key is the config's ``server_type`` entry (removed from the config) when
    present, otherwise the row's ``name``.
    """
    mcp_settings = []
    for mcp_row in mcp_rows:
        server_config = _decode_json_dict(mcp_row[2]) or {}
        server_type = server_config.pop("server_type", mcp_row[0])
        mcp_settings.append({"mcpServers": {server_type: server_config}})
    return mcp_settings


async def fetch_bot_config_from_db(bot_user_id: str, user_identifier: Optional[str] = None) -> Dict[str, Any]:
    """Fetch a bot's configuration from the local database.

    Args:
        bot_user_id: The bot's ID, matched against ``agent_bots.id``.
            NOTE(review): the original docstring said this is the bot_id field
            and not a UUID, yet the selected ``id`` is later used as
            ``bot_uuid`` for the MCP lookup — confirm which it is.
        user_identifier: Username of the currently logged-in user; when empty,
            the bot's ``owner_id`` is used instead.

    Returns:
        Dict[str, Any]: All configuration parameters, in a format compatible
        with ``fetch_bot_config``.

    Raises:
        HTTPException: 404 when no bot matches ``bot_user_id``; 500 for any
            other failure (the traceback is logged).
    """
    try:
        from agent.db_pool_manager import get_db_pool_manager
        from utils.settings import NEW_API_BASE_URL

        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Bot row from agent_bots: id, name, settings (JSON) and owner_id.
                await cursor.execute(
                    """
                    SELECT id, name, settings, owner_id
                    FROM agent_bots WHERE id = %s
                    """,
                    (bot_user_id,)
                )
                bot_row = await cursor.fetchone()
                if not bot_row:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Bot with id '{bot_user_id}' not found"
                    )
                bot_uuid = bot_row[0]
                bot_name = bot_row[1]
                settings_json = bot_row[2]
                owner_id = bot_row[3]

                # Parse the settings field; unparsable or non-object settings
                # fall back to {} instead of failing the request.
                settings_data: Dict[str, Any] = {}
                if settings_json:
                    decoded_settings = _decode_json_dict(settings_json)
                    if decoded_settings is None:
                        logger.warning(f"Failed to parse settings JSON for bot {bot_user_id}")
                    else:
                        settings_data = decoded_settings

                # Use the bot owner's new_api_token as the API key.
                api_key = ""
                if owner_id:
                    await cursor.execute(
                        """
                        SELECT new_api_token
                        FROM agent_user WHERE id = %s
                        """,
                        (owner_id,)
                    )
                    user_row = await cursor.fetchone()
                    if user_row and user_row[0]:
                        api_key = user_row[0]

                # model_id comes from New API in "Provider/ModelName" form and is
                # used unchanged as the model name.
                model_name = settings_data.get("model_id", "")
                model_server = NEW_API_BASE_URL.rstrip('/') + "/v1" if NEW_API_BASE_URL else ""

                config = {
                    "name": bot_name,
                    "model": model_name,
                    "api_key": api_key,
                    "model_server": model_server,
                    "language": settings_data.get("language", "zh"),
                    "dataset_ids": _normalize_dataset_ids(settings_data.get("dataset_ids", [])),
                    "system_prompt": settings_data.get("system_prompt", ""),
                    "user_identifier": user_identifier if user_identifier else owner_id,
                    "enable_memori": settings_data.get("enable_memori", False),
                    "tool_response": settings_data.get("tool_response", True),
                    "enable_thinking": settings_data.get("enable_thinking", False),
                    "skills": _split_setting_list(settings_data.get("skills", []), ','),
                    "description": settings_data.get("description", ""),
                    "suggestions": _split_setting_list(settings_data.get("suggestions", []), '\n'),
                    "shell_env": settings_data.get("shell_env") or {},
                    "voice_speaker": settings_data.get("voice_speaker", ""),
                    "voice_system_role": settings_data.get("voice_system_role", ""),
                    "voice_speaking_style": settings_data.get("voice_speaking_style", ""),
                }

                # Enabled MCP servers for this bot, formatted as mcp_settings (v2 API compatible).
                await cursor.execute(
                    """
                    SELECT name, type, config, enabled
                    FROM agent_mcp_servers WHERE bot_id = %s AND enabled = true
                    """,
                    (bot_uuid,)
                )
                config["mcp_settings"] = _build_mcp_settings(await cursor.fetchall())

        logger.info(f"Fetched bot config for {bot_user_id}: model={config['model']}, api_key={'*' + config['api_key'][-4:] if config['api_key'] else 'N/A'}")
        return config
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error fetching bot config from database: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config from database: {str(e)}"
        ) from e
async def _sync_call_llm(llm_config, messages) -> str:
    """Call the configured chat model once through LangChain and return its text.

    Despite the name, this is a coroutine that awaits ``ainvoke`` directly; it
    does not use ``thread_pool``.

    Args:
        llm_config: Dict with ``model``, ``model_server`` and ``api_key`` keys.
        messages: Dicts with ``role``/``content``. Only ``system``, ``user`` and
            ``assistant`` roles are forwarded; other roles are skipped.

    Returns:
        The response text. List-shaped content (content blocks) is flattened to
        its text blocks, so callers always receive a ``str``. Returns "" for
        empty content or on any error, which is logged.
    """
    try:
        model_name = llm_config.get('model')
        model_server = llm_config.get('model_server')
        api_key = llm_config.get('api_key')
        model_provider, base_url = detect_provider(model_name, model_server)
        # No generate_cfg is passed, so no parameters can be dropped here;
        # sanitize_model_kwargs already logs its own decisions.
        model_kwargs, _, _ = sanitize_model_kwargs(
            model_name=model_name,
            model_provider=model_provider,
            base_url=base_url,
            api_key=api_key,
            source="_sync_call_llm"
        )
        llm_instance = init_chat_model(**model_kwargs)

        message_classes = {'system': SystemMessage, 'user': HumanMessage, 'assistant': AIMessage}
        langchain_messages = [
            message_classes[msg['role']](content=msg['content'])
            for msg in messages
            if msg['role'] in message_classes
        ]

        response = await llm_instance.ainvoke(langchain_messages)
        if not response.content:
            return ""
        # Some providers return a list of content blocks instead of a string.
        return extract_text_from_content(response.content)
    except Exception as e:
        logger.error(f"Error calling guideline LLM with LangChain: {e}")
        return ""
_LANGUAGE_DIRECTIVES = {
    'zh': '请用中文回复',
    'en': 'Please reply in English',
    'ja': '日本語で回答してください',
}


def get_language_text(language: str):
    """Return the "reply in <language>" directive for a language code.

    The code is matched case-insensitively, and "jp" (any case) is treated as
    "ja". Unknown codes and None return "".
    """
    normalized = (language or "").lower()
    # Lower-case first so "JP" is aliased too, not only "jp".
    if normalized == "jp":
        normalized = "ja"
    return _LANGUAGE_DIRECTIVES.get(normalized, '')
_DEFAULT_PREAMBLE_CHOICES = {
    'zh': [
        "好的,让我来帮您看看。",
        "明白了,请稍等。",
        "好的,我理解了。",
        "没问题,我来处理。",
        "收到,正在为您查询。",
        "了解,让我想想。",
        "好的,我来帮您解答。",
        "明白了,稍等片刻。",
        "好的,正在处理中。",
        "了解了,让我为您分析。"
    ],
    'en': [
        "Just a moment.",
        "Got it.",
        "Let me check that for you.",
        "Sorry to hear that.",
        "Thanks for your patience.",
        "I understand.",
        "Let me help you with that.",
        "Please wait a moment.",
        "I'll look into that for you.",
        "Gotcha, let me see.",
        "Understood, one moment please.",
        "I'll help you with this.",
        "Let me figure that out.",
        "Thanks for waiting.",
        "I'll check on that."
    ],
    'ja': [
        "少々お待ちください。",
        "承知いたしました。",
        "わかりました。",
        "確認いたします。",
        "少々お時間をください。",
        "了解しました。",
        "調べてみますね。",
        "お待たせしました。",
        "対応いたします。",
        "わかりましたね。",
        "承知いたしました。",
        "確認させてください。",
        "少々お待ちいただけますか。",
        "お調べいたします。",
        "対応いたしますね。"
    ]
}


def get_preamble_text(language: str, system_prompt: str):
    """Determine the preamble choices text and the system prompt to use.

    If ``system_prompt`` contains ``<preamble>...</preamble>`` blocks, they are
    removed from the prompt. When the first block is non-empty, its stripped
    content is the preamble. Otherwise the built-in choices for ``language``
    (case-insensitive, "jp" treated as "ja") are joined with newlines. Unknown
    languages or None yield "".

    Returns:
        tuple: ``(preamble_text, system_prompt)``.
    """
    preamble_pattern = r'<preamble>\s*(.*?)\s*</preamble>'
    if system_prompt:
        preamble_matches = re.findall(preamble_pattern, system_prompt, re.DOTALL)
        if preamble_matches:
            cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
            preamble_content = preamble_matches[0].strip()
            if preamble_content:
                return preamble_content, cleaned_system_prompt
            # The first <preamble> block is empty: still strip the tag(s) from the
            # prompt, then fall back to the default choices.
            system_prompt = cleaned_system_prompt

    normalized_language = (language or "").lower()
    if normalized_language == "jp":
        normalized_language = "ja"
    default_preamble = "\n".join(_DEFAULT_PREAMBLE_CHOICES.get(normalized_language, []))
    return default_preamble, system_prompt
_PREAMBLE_PLACEHOLDER_PATTERN = re.compile(r'\{(preamble_choices_text|chat_history|last_message|language)\}')


async def call_preamble_llm(config: AgentConfig) -> str:
    """Ask the LLM for the preamble to show for the current turn.

    Reads ``./prompt/preamble_prompt.md``, fills its ``{preamble_choices_text}``,
    ``{chat_history}``, ``{last_message}`` and ``{language}`` placeholders from
    ``config``, and sends the result as a single user message (under
    ``api_semaphore``).

    Args:
        config: Agent configuration. Reads ``api_key``, ``model_name``,
            ``model_server``, ``language``, ``preamble_text``, ``messages`` and
            ``_session_history``.

    Returns:
        The ``"preamble"`` value from the first fenced ```json block in the
        response, or "" when the template cannot be read, the call fails, or no
        usable JSON/key is found.
    """
    try:
        with open('./prompt/preamble_prompt.md', 'r', encoding='utf-8') as f:
            preamble_template = f.read()
    except Exception as e:
        logger.error(f"Error reading preamble prompt template: {e}")
        return ""

    replacements = {
        'preamble_choices_text': config.preamble_text or '',
        'chat_history': format_messages_to_chat_history(convert_to_openai_messages(config._session_history)),
        'last_message': get_user_last_message_content(config.messages) or '',
        'language': get_language_text(config.language),
    }
    # Single-pass substitution: text inserted for one placeholder (e.g. user-written
    # chat history containing "{language}") is never re-scanned and replaced.
    system_prompt = _PREAMBLE_PLACEHOLDER_PATTERN.sub(lambda m: replacements[m.group(1)], preamble_template)

    llm_config = {
        'model': config.model_name,
        'api_key': config.api_key,
        'model_server': config.model_server,
    }
    messages = [{'role': 'user', 'content': system_prompt}]
    try:
        # Limit concurrent API calls.
        async with api_semaphore:
            response = await _sync_call_llm(llm_config, messages)

        # The answer is expected inside a fenced ```json ... ``` block.
        json_matches = re.findall(r'```json\s*\n(.*?)\n```', response, re.DOTALL)
        if not json_matches:
            logger.warning("No JSON format found in preamble analysis")
            return ""
        try:
            json_data = json.loads(json_matches[0])
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON from preamble analysis: {e}")
            return ""
        preamble = json_data.get("preamble") if isinstance(json_data, dict) else None
        if preamble is None:
            logger.warning("Preamble analysis JSON has no 'preamble' field")
            return ""
        logger.info("Successfully processed preamble")
        return preamble
    except Exception as e:
        logger.error(f"Error calling preamble LLM: {e}")
        return ""
async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_name: str, api_key: str, model_server: str) -> str:
    """Send a fully rendered guideline prompt to the LLM and return its raw text.

    Args:
        chat_history: Accepted for interface compatibility; not used here.
        guidelines_prompt: Complete prompt, sent as a single user message.
        model_name: Model name.
        api_key: API key.
        model_server: Model server URL.

    Returns:
        str: The model response, or "" if the call raised (the error is logged).
    """
    request = [{'role': 'user', 'content': guidelines_prompt}]
    llm_config = {
        'model': model_name,
        'api_key': api_key,
        'model_server': model_server,
    }
    try:
        # Hold the shared semaphore so at most a bounded number of LLM calls run at once.
        async with api_semaphore:
            return await _sync_call_llm(llm_config, request)
    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
def _get_optimal_batch_size(guidelines_count: int) -> int:
"""Determine the optimal batch size (concurrency) based on guidelines count."""
if guidelines_count <= 10:
return 1
elif guidelines_count <= 20:
return 2
elif guidelines_count <= 30:
return 3
else:
return 5
def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
    """Pull the tagged sections out of a system prompt.

    Looks for the first ``<guidelines>``, ``<tools>``, ``<scenarios>`` and
    ``<terms>`` block and returns each one's stripped inner text. ``<terms>`` is
    parsed with ``parse_terms_text``. If that parsing raises, the error is
    logged and the terms block is left in the prompt. Every extracted block is
    removed from the prompt, runs of blank lines are collapsed, and the result
    is stripped.

    Args:
        system_prompt: System prompt text (may be empty or None).

    Returns:
        tuple: ``(cleaned_prompt, guidelines, tools, scenarios, terms_list)``.
        This is always a 5-tuple, including for empty input, so callers can
        unpack it unconditionally.
    """
    if not system_prompt:
        return "", "", "", "", []

    blocks_to_remove = []
    sections = {}
    for tag in ("guidelines", "tools", "scenarios"):
        match = re.search(rf'<{tag}>\s*(.*?)\s*</{tag}>', system_prompt, re.DOTALL)
        if match:
            sections[tag] = match.group(1).strip()
            blocks_to_remove.append(match.group(0))
        else:
            sections[tag] = ""

    terms_list = []
    match = re.search(r'<terms>\s*(.*?)\s*</terms>', system_prompt, re.DOTALL)
    if match:
        try:
            terms_list.extend(parse_terms_text(match.group(1).strip()))
            blocks_to_remove.append(match.group(0))
        except Exception as e:
            logger.error(f"Error parsing terms: {e}")

    cleaned_prompt = system_prompt
    for block in blocks_to_remove:
        cleaned_prompt = cleaned_prompt.replace(block, '', 1)
    # Collapse three or more consecutive (blank-ish) newlines into one blank line.
    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()
    return cleaned_prompt, sections["guidelines"], sections["tools"], sections["scenarios"], terms_list
def parse_terms_text(text: str) -> List[Dict[str, Any]]:
    """Parse a terms block into a list of term dicts.

    Text that starts with ``[`` or ``{`` (ignoring surrounding whitespace) is
    first tried as JSON: a list yields its dict items and an object yields a
    one-item list. If the JSON is invalid, or the text is not JSON-shaped, each
    non-blank line not starting with ``#`` or ``//`` is matched against
    ``[N)] Name: <name>[, Description: <desc>][, Synonyms: <s1>, <s2>...]``
    (case-insensitive). Synonyms may be separated by ``,``, ``;`` or ``|``.
    Lines that do not match are ignored.

    Args:
        text: Terms text content.

    Returns:
        List[Dict]: Terms with ``name`` and optional ``description``/``synonyms``.
    """
    stripped = text.strip()
    if stripped[:1] in ('[', '{'):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(data, list):
                return [item for item in data if isinstance(item, dict)]
            if isinstance(data, dict):
                return [data]
            return []

    term_line = re.compile(
        r'(?:\d+\)\s*)?Name:\s*([^,]+)(?:,\s*Description:\s*([^,]+))?(?:,\s*Synonyms:\s*(.+))?',
        re.IGNORECASE
    )
    terms: List[Dict[str, Any]] = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line or line.startswith(('#', '//')):
            continue
        found = term_line.match(line)
        if found is None:
            continue
        name, description, synonyms_text = (group.strip() if group else '' for group in found.groups())
        term: Dict[str, Any] = {'name': name}
        if description:
            term['description'] = description
        if synonyms_text:
            term['synonyms'] = [s.strip() for s in re.split(r'[,;|]', synonyms_text) if s.strip()]
        terms.append(term)
    return terms