1146 lines
45 KiB
Python
1146 lines
45 KiB
Python
import os
|
||
import re
|
||
import hashlib
|
||
import json
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List, Dict, Optional, Union, Any
|
||
import aiohttp
|
||
from fastapi import HTTPException
|
||
import logging
|
||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, convert_to_openai_messages
|
||
from langchain.chat_models import init_chat_model
|
||
from utils.settings import MASTERKEY, BACKEND_HOST
|
||
from agent.agent_config import AgentConfig
|
||
|
||
# Role identifiers compared against message roles throughout this module.
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"

# Shared application logger used by every helper in this module.
logger = logging.getLogger('app')

# Create a global thread pool executor for synchronous HTTP calls
# NOTE(review): not referenced in the visible part of this file — confirm it
# is still used elsewhere before relying on it.
thread_pool = ThreadPoolExecutor(max_workers=10)

# Create a concurrency semaphore to limit simultaneous API calls
# (acquired by call_preamble_llm and call_guideline_llm around each LLM call).
api_semaphore = asyncio.Semaphore(8) # Maximum 8 concurrent API calls
|
||
|
||
def detect_provider(model_name, model_server):
    """Detect the provider type based on the model name.

    Args:
        model_name: Model identifier, matched case-insensitively. ``None`` or
            an empty string is treated as an unrecognized model.
        model_server: Base URL of the model server, or ``None``.

    Returns:
        tuple: ``(provider, base_url)``. For names containing "claude" or
        "anthropic" the provider is ``"anthropic"`` and every ``/v1``
        occurrence is removed from ``model_server``. Every other model
        (GPT/OpenAI/o1 and unrecognized names alike) is ``"openai"`` with
        ``model_server`` returned unchanged.
    """
    model_name_lower = (model_name or "").lower()
    if any(marker in model_name_lower for marker in ("claude", "anthropic")):
        # A missing server URL is passed through instead of raising
        # AttributeError on ``None.replace``.
        base_url = model_server.replace("/v1", "") if model_server else model_server
        return "anthropic", base_url
    # Known OpenAI names and unknown names share the OpenAI-compatible format.
    return "openai", model_server
|
||
|
||
|
||
def is_anthropic_opus_model(model_name: Optional[str]) -> bool:
    """Return True when ``model_name`` contains "opus" (case-insensitive).

    ``None`` and the empty string yield False.
    """
    if not model_name:
        return False
    return "opus" in model_name.lower()
|
||
|
||
|
||
def sanitize_model_kwargs(
    model_name: str,
    model_provider: str,
    base_url: Optional[str],
    api_key: Optional[str],
    generate_cfg: Optional[Dict[str, Any]] = None,
    source: str = "agent"
) -> tuple[Dict[str, Any], List[str], bool]:
    """Build model-init kwargs, dropping parameters the provider must not receive.

    Args:
        model_name: Model identifier.
        model_provider: Provider name (``'anthropic'`` enables extra filtering).
        base_url: Model server base URL.
        api_key: API key for the model server.
        generate_cfg: Optional generation settings merged into the result
            after filtering; its values override the defaults set here.
        source: Label of the caller, used only in the log line.

    Returns:
        tuple: ``(model_kwargs, dropped_params, default_temperature_applied)``.
        ``dropped_params`` lists the ``generate_cfg`` keys that were removed,
        in their original order. ``default_temperature_applied`` is True only
        when the 0.8 default temperature is the value that ends up in
        ``model_kwargs`` (i.e. the model is not an Anthropic Opus model and
        ``generate_cfg`` did not supply its own temperature).
    """
    cfg = generate_cfg or {}

    # Settings consumed by this application; never forwarded to the model.
    internal_params = {
        'tool_output_max_length',
        'tool_output_truncation_strategy',
        'tool_output_filters',
        'tool_output_exclude',
        'preserve_code_blocks',
        'preserve_json',
    }

    # Parameters that are dropped when the provider is Anthropic.
    openai_only_params = {
        'n',
        'presence_penalty',
        'frequency_penalty',
        'logprobs',
        'top_logprobs',
        'logit_bias',
        'seed',
        'suffix',
        'best_of',
        'echo',
        'user',
    }

    is_opus_model = model_provider == 'anthropic' and is_anthropic_opus_model(model_name)

    params_to_filter = set(internal_params)
    if model_provider == 'anthropic':
        params_to_filter |= openai_only_params
    if is_opus_model:
        # Opus models get no temperature at all: neither default nor custom.
        params_to_filter.add('temperature')

    original_keys = list(cfg)
    filtered_cfg = {k: v for k, v in cfg.items() if k not in params_to_filter}
    dropped_params = [k for k in original_keys if k in params_to_filter]

    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "base_url": base_url,
        "api_key": api_key,
    }
    if not is_opus_model:
        model_kwargs["temperature"] = 0.8
    model_kwargs.update(filtered_cfg)

    # The default only counts as "applied" when nothing overrode it.
    default_temperature_applied = not is_opus_model and "temperature" not in filtered_cfg

    logger.info(
        "sanitize_model_kwargs source=%s provider=%s model=%s original_keys=%s dropped_keys=%s default_temperature_applied=%s",
        source,
        model_provider,
        model_name,
        original_keys,
        dropped_params,
        default_temperature_applied
    )

    return model_kwargs, dropped_params, default_temperature_applied
|
||
|
||
|
||
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """
    Get a versioned filename, deleting earlier copies and incrementing the version.

    Files in ``upload_dir`` named exactly ``<name><ext>`` or ``<name>_<N><ext>``
    are treated as earlier versions. If none exist, the plain filename is
    returned as version 1 and nothing is deleted. Otherwise all of them are
    removed (failures are logged, not raised) and the next version name is
    returned.

    Args:
        upload_dir: Upload directory path
        name_without_ext: Filename without extension
        file_extension: File extension (including dot)

    Returns:
        tuple[str, int]: (Final filename, version number)
    """
    original_name = name_without_ext + file_extension
    original_exists = os.path.exists(os.path.join(upload_dir, original_name))

    # fullmatch (rather than match + '$') so that a name with a trailing
    # newline is not mistaken for a version and deleted.
    version_pattern = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension))

    existing_versions = []
    files_to_delete = []
    for filename in os.listdir(upload_dir):
        if filename == original_name:
            files_to_delete.append(filename)
            continue
        match = version_pattern.fullmatch(filename)
        if match:
            existing_versions.append(int(match.group(1)))
            files_to_delete.append(filename)

    # If no related files exist, use the original filename (version 1)
    if not original_exists and not existing_versions:
        return original_name, 1

    # Delete all existing files (original and versioned); best effort.
    for filename in files_to_delete:
        file_to_delete = os.path.join(upload_dir, filename)
        try:
            os.remove(file_to_delete)
            logger.info("Deleted file: %s", file_to_delete)
        except OSError as e:
            logger.error("Failed to delete file %s: %s", file_to_delete, e)

    # Only the unversioned original existed -> the replacement is version 2.
    next_version = max(existing_versions) + 1 if existing_versions else 2
    return f"{name_without_ext}_{next_version}{file_extension}", next_version
|
||
|
||
def create_stream_chunk(chunk_id: str, model_name: str, content: Optional[str] = None, finish_reason: Optional[str] = None) -> dict:
    """Create a standardized streaming response chunk.

    Args:
        chunk_id: Value for the chunk's ``id`` field.
        model_name: Value for the chunk's ``model`` field.
        content: Text for ``delta.content``. ``None`` produces an empty
            delta; an empty string is kept as ``{"content": ""}``.
        finish_reason: Value for ``finish_reason`` (``None`` while streaming).

    Returns:
        dict: A ``chat.completion.chunk`` object with a single choice and
        ``created`` set to the current Unix time in whole seconds.
    """
    import time  # local import replaces the inline __import__('time') call

    delta = {} if content is None else {"content": content}
    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    }
|
||
|
||
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
|
||
# """Extract content from qwen-agent messages with special formatting"""
|
||
# full_text = ''
|
||
# content = []
|
||
# TOOL_CALL_S = '[TOOL_CALL]'
|
||
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
|
||
# THOUGHT_S = '[THINK]'
|
||
# ANSWER_S = '[ANSWER]'
|
||
# PREAMBLE_S = '[PREAMBLE]'
|
||
|
||
# for msg in messages:
|
||
# if msg['role'] == ASSISTANT:
|
||
# if msg.get('reasoning_content'):
|
||
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
|
||
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
|
||
# if msg.get('content'):
|
||
# assert isinstance(msg['content'], str), 'Now only supports text messages'
|
||
# # Filter out incomplete tool_call text from streaming output
|
||
# content_text = msg["content"]
|
||
|
||
# # Use regex to replace incomplete tool_call patterns with empty string
|
||
|
||
# # Match and replace incomplete tool_call patterns
|
||
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
|
||
# # Only add if content is not empty after processing
|
||
# if content_text.strip():
|
||
# content.append(f'{ANSWER_S}\n{content_text}')
|
||
# if msg.get('function_call'):
|
||
# content_text = msg["function_call"]["arguments"]
|
||
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
|
||
# if content_text.strip():
|
||
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
|
||
# elif msg['role'] == FUNCTION:
|
||
# if tool_response:
|
||
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
|
||
# elif msg['role'] == "preamble":
|
||
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
|
||
# else:
|
||
# raise TypeError
|
||
|
||
# if content:
|
||
# full_text = '\n'.join(content)
|
||
|
||
# return full_text
|
||
|
||
|
||
def normalize_content_blocks(content: Union[str, List[Dict[str, Any]]]) -> Union[str, List[Dict[str, Any]]]:
    """Convert multimodal content blocks to LangChain standard content blocks.

    Anything that is not a list (e.g. a plain string) is returned unchanged.
    Within a list:

    * bare strings become ``{"type": "text", "text": ...}``;
    * items that are neither dict nor string are dropped;
    * ``text`` blocks are rebuilt with only ``type`` and ``text``;
    * OpenAI-style ``image_url`` blocks become ``{"type": "image", ...}``
      with either ``url`` or, for ``data:`` URLs, ``base64`` + ``mime_type``
      (``image/jpeg`` when the data URL names no MIME type); blocks with an
      empty URL or a data URL without a comma are skipped;
    * ``image`` blocks and unknown block types pass through untouched.
    """
    if not isinstance(content, list):
        return content

    result: List[Dict[str, Any]] = []
    for item in content:
        if isinstance(item, str):
            result.append({"type": "text", "text": item})
            continue
        if not isinstance(item, dict):
            continue

        kind = item.get("type")
        if kind == "text":
            result.append({"type": "text", "text": item.get("text", "")})
            continue
        if kind != "image_url":
            # LangChain "image" blocks and unknown types are forwarded as-is.
            result.append(item)
            continue

        # OpenAI-style image block; "image_url" may be a dict or the URL itself.
        raw_image = item.get("image_url")
        url = raw_image.get("url") if isinstance(raw_image, dict) else raw_image
        if not url:
            continue
        if not (isinstance(url, str) and url.startswith("data:")):
            result.append({"type": "image", "url": url})
            continue

        # data:<mime_type>;base64,<data>
        header, separator, payload = url.partition(",")
        if not separator:
            logger.warning("Skipping malformed data URL in image_url block")
            continue
        mime_type = header.split(";", 1)[0].removeprefix("data:") or "image/jpeg"
        result.append({"type": "image", "base64": payload, "mime_type": mime_type})

    return result
|
||
|
||
|
||
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Filter tagged assistant history and rebuild tool-call / tool-result messages.

    Assistant text is expected to use the section tags ``[THINK]``,
    ``[PREAMBLE]``, ``[TOOL_CALL]``, ``[TOOL_RESPONSE]`` and ``[ANSWER]``.
    Processing happens in two passes:

    1. Each string-valued assistant message is filtered. ``[ANSWER]`` and
       untagged text are always kept; ``[THINK]``/``[PREAMBLE]`` are dropped;
       ``[TOOL_CALL]``/``[TOOL_RESPONSE]`` are kept only for the five most
       recent assistant messages, with responses longer than 1000 characters
       abbreviated to three 250-character excerpts. If nothing survives, the
       original content is used. All other messages pass through with their
       content normalized by ``normalize_content_blocks``.
    2. Each filtered assistant message is split again and expanded into
       separate messages: an assistant ``tool_calls`` message per tool call,
       a ``tool`` message per tool response, and a plain assistant message
       for remaining text. Only tools whose name contains ``find`` or ``get``
       are expanded; a tool response is emitted only when an earlier call in
       the same message is still waiting for one.

    Tool-call ids are numbered across the whole conversation so they never
    repeat between assistant messages.

    Args:
        messages: Message objects exposing ``.role`` and ``.content``
            attributes (attribute access, not dict lookup).
        language: Currently unused by this function.

    Returns:
        List of message dicts in OpenAI-style format.
    """
    # Only tools whose name contains one of these keywords are reconstructed.
    include_function_name = ['find', 'get']
    tag_pattern = r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]'

    # ---- Pass 1: filter / abbreviate tagged assistant content -------------
    total_assistant_messages = sum(1 for msg in messages if msg.role == ASSISTANT)
    cutoff_point = max(0, total_assistant_messages - 5)

    processed_messages = []
    assistant_position = -1  # 0-based index among ALL assistant messages
    for msg in messages:
        if msg.role == ASSISTANT:
            assistant_position += 1

        if not (msg.role == ASSISTANT and isinstance(msg.content, str)):
            # User/other messages (or assistant messages carrying multimodal
            # list content) pass through with blocks normalized.
            processed_messages.append({"role": msg.role, "content": normalize_content_blocks(msg.content)})
            continue

        is_recent_message = assistant_position >= cutoff_point

        # re.split with one capture group alternates text (even index) and
        # tag name (odd index).
        filtered_content = ""
        current_tag = None
        for part_index, part in enumerate(re.split(tag_pattern, msg.content)):
            if part_index % 2 == 1:
                current_tag = part
                continue

            text = part.strip()
            if not text:
                continue

            if current_tag == "TOOL_RESPONSE":
                # Older messages drop tool responses entirely.
                if is_recent_message:
                    if len(text) <= 1000:
                        filtered_content += f"[TOOL_RESPONSE] {text}\n"
                    else:
                        # Keep the first, middle and last 250 characters.
                        first_part = text[:250]
                        middle_start = len(text) // 2 - 125
                        middle_part = text[middle_start:middle_start + 250]
                        last_part = text[-250:]

                        omitted_count = len(text) - 750
                        omitted_text = f"...{omitted_count} chars omitted..."

                        truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"
                        filtered_content += f"[TOOL_RESPONSE] {truncated_text}\n"
            elif current_tag == "TOOL_CALL":
                # Older messages drop tool calls entirely.
                if is_recent_message:
                    filtered_content += f"[TOOL_CALL] {text}\n"
            elif current_tag == "ANSWER":
                filtered_content += f"[ANSWER] {text}\n"
            elif current_tag != "THINK" and current_tag != "PREAMBLE":
                filtered_content += text + "\n"

        final_content = filtered_content.strip()
        # If filtering removed everything, fall back to the original content.
        processed_messages.append({"role": msg.role, "content": final_content or msg.content})

    # ---- Pass 2: expand tagged sections into structured messages ----------
    final_messages = []
    tool_id_counter = 0  # conversation-wide so ids stay unique across messages
    for msg in processed_messages:
        if not (msg["role"] == ASSISTANT and isinstance(msg["content"], str)):
            final_messages.append(msg)
            continue

        current_tag = None
        pending_tool_ids = []  # ids of calls in this message awaiting a response
        for part_index, part in enumerate(re.split(tag_pattern, msg["content"])):
            if part_index % 2 == 1:
                current_tag = part
                continue

            text = part.strip()
            if not text:
                continue

            if current_tag in ("TOOL_CALL", "TOOL_RESPONSE"):
                # Section format: "<function_name>\n<arguments or result>"
                head, _, tail = text.partition('\n')
                function_name = head.strip()
                payload = tail.strip()
                should_include = bool(function_name) and any(
                    keyword in function_name for keyword in include_function_name
                )

                if current_tag == "TOOL_CALL":
                    if should_include:
                        tool_id = f"tool_id_{tool_id_counter}"
                        tool_id_counter += 1
                        pending_tool_ids.append(tool_id)
                        final_messages.append({
                            "role": ASSISTANT,
                            "content": "",
                            "tool_calls": [{
                                "id": tool_id,
                                "function": {
                                    "name": function_name,
                                    "arguments": payload
                                }
                            }]
                        })
                elif should_include and pending_tool_ids:
                    # Pair the response with the oldest unanswered call.
                    final_messages.append({
                        "role": TOOL,
                        "tool_call_id": pending_tool_ids.pop(0),
                        "name": function_name,
                        "content": payload
                    })
            elif current_tag != "THINK" and current_tag != "PREAMBLE":
                final_messages.append({
                    "role": ASSISTANT,
                    "content": text
                })

    return final_messages
|
||
|
||
|
||
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Flatten message content to plain text.

    A string is returned unchanged. For a list, the ``text`` of every
    ``{"type": "text"}`` block and every bare string item is collected and
    joined with newlines; other blocks are ignored. Any other input type
    yields an empty string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    pieces = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, dict) and item.get("type") == "text":
            pieces.append(item.get("text", ""))
    return "\n".join(pieces)
|
||
|
||
|
||
def get_user_last_message_content(messages: list) -> Optional[str]:
    """Return the plain text of the final message when it is a user message.

    The last entry of ``messages`` is inspected (dict-style access). If it
    has role ``'user'``, its content is flattened with
    ``extract_text_from_content`` so callers always receive a string even
    for multimodal list content. An empty/missing list or a non-user final
    message yields an empty string.
    """
    if not messages:
        return ""

    latest = messages[-1]
    if not latest or latest.get('role') != 'user':
        return ""
    return extract_text_from_content(latest.get("content", ""))
|
||
|
||
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
    """Format messages as plain text chat history.

    Each message contributes one or more lines:

    * user -> ``user: <content>``
    * tool -> ``<name> response: <content>``
    * assistant -> ``assistant: <content>`` (only when content is non-empty)
      followed by ``<function> call: <arguments>`` for each tool call.

    The final line is then dropped and at most the 15 lines before it are
    kept. (Presumably the dropped line is the latest user message, which
    callers insert separately — confirm against the prompt template.)

    Args:
        messages: Message list (dicts with ``role``/``content`` and
            optionally ``name``/``tool_calls``).

    Returns:
        str: Formatted chat history, newline-joined.
    """
    chat_history = []
    for message in messages:
        role = message.get('role', '')
        content = message.get('content', '')
        if role == USER:
            chat_history.append(f"user: {content}")
        elif role == TOOL:
            name = message.get('name', '')
            chat_history.append(f"{name} response: {content}")
        elif role == ASSISTANT:
            # Truthiness check: tool-call messages may carry content=None,
            # on which len() would raise TypeError.
            if content:
                chat_history.append(f"assistant: {content}")
            for tool_call in message.get('tool_calls') or []:
                function = tool_call.get('function')
                if not function:
                    # Malformed tool call without a function payload: skip.
                    continue
                function_name = function.get('name')
                arguments = function.get('arguments')
                chat_history.append(f"{function_name} call: {arguments}")

    # Drop the last line, then keep the 15 most recent of what remains.
    recent_chat_history = chat_history[:-1][-15:]
    return "\n".join(recent_chat_history)
|
||
|
||
|
||
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, skills: Optional[List[str]] = None) -> Optional[str]:
    """Common logic for creating project directories.

    Args:
        dataset_ids: Dataset identifiers; ``None`` or an empty list is
            passed on as ``[]`` (the project is still created).
        bot_id: Bot identifier forwarded to ``create_robot_project``.
        skills: Optional skills forwarded as the ``skills`` keyword.

    Returns:
        Whatever ``create_robot_project`` returns, or ``None`` if importing
        or calling it raises (the error is logged).
    """
    try:
        # Imported lazily inside the function, as in the original code.
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids or [], bot_id, skills=skills)
    except Exception as e:
        logger.error(f"Error creating project directory: {e}")
        return None
|
||
|
||
|
||
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the API key from an Authorization header value.

    Returns ``None`` for a missing/empty header. A leading ``"Bearer "``
    (exact case, with the trailing space) is stripped; any other value is
    returned unchanged.
    """
    if not authorization:
        return None

    bearer_prefix = "Bearer "
    has_prefix = authorization.startswith(bearer_prefix)
    return authorization[len(bearer_prefix):] if has_prefix else authorization
|
||
|
||
|
||
def generate_v2_auth_token(bot_id: str) -> str:
    """Generate the authentication token for the v2 API.

    The token is the hexadecimal MD5 digest of ``"<MASTERKEY>:<bot_id>"``.

    NOTE(review): MD5 is not a strong hash for authentication purposes;
    presumably the backend computes the same value, so changing it would
    need a coordinated update — confirm before touching.
    """
    digest = hashlib.md5()
    digest.update(f"{MASTERKEY}:{bot_id}".encode())
    return digest.hexdigest()
|
||
|
||
|
||
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch bot config from the backend API.

    Sends ``GET {BACKEND_HOST}/v1/agent_bot_config/{bot_id}`` with a bearer
    token from ``generate_v2_auth_token`` and a 30-second total timeout.

    Args:
        bot_id: Bot identifier used in the URL and the auth token.

    Returns:
        Dict[str, Any]: The ``data`` field of the JSON response (``{}`` when
        absent).

    Raises:
        HTTPException: 400 when the backend answers with a non-200 status or
            ``success`` is falsy; 500 on connection errors, timeouts or any
            other failure.
    """
    try:
        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {generate_v2_auth_token(bot_id)}"
        }
        # aiohttp expects a ClientTimeout object; a bare number is the
        # deprecated form.
        timeout = aiohttp.ClientTimeout(total=30)

        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=timeout) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                response_data = await response.json()

                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                return response_data.get("data", {})

    except HTTPException:
        # Already the error the caller should see.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        ) from e
    except asyncio.TimeoutError as e:
        # str(e) is empty for timeouts, so give an explicit message.
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch bot config: request timed out"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        ) from e
|
||
|
||
|
||
async def fetch_bot_config_from_db(bot_user_id: str, user_identifier: Optional[str] = None) -> Dict[str, Any]:
    """
    Fetch a bot's configuration from the local database.

    Args:
        bot_user_id: Value matched against ``agent_bots.id``.
            NOTE(review): the original note described this as the bot_id
            field rather than the UUID, but the query filters on ``id`` —
            confirm which identifier callers pass.
        user_identifier: Username of the currently logged-in user. When
            empty, the bot's ``owner_id`` is used instead.

    Returns:
        Dict[str, Any]: Dictionary with all configuration parameters, in a
        format compatible with fetch_bot_config.

    Raises:
        HTTPException: 404 when no bot row matches; 500 on any other failure.
    """
    try:
        from agent.db_pool_manager import get_db_pool_manager
        from utils.settings import NEW_API_BASE_URL

        pool = get_db_pool_manager().pool

        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Load the bot row (including settings and owner_id).
                await cursor.execute(
                    """
                    SELECT id, name, settings, owner_id
                    FROM agent_bots WHERE id = %s
                    """,
                    (bot_user_id,)
                )
                bot_row = await cursor.fetchone()

                if not bot_row:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Bot with id '{bot_user_id}' not found"
                    )

                bot_uuid = bot_row[0]
                bot_name = bot_row[1]
                settings_json = bot_row[2]
                owner_id = bot_row[3]

                # Parse the settings column (may arrive as a JSON string or
                # as an already-decoded object).
                settings_data = {}
                if settings_json:
                    if isinstance(settings_json, str):
                        try:
                            settings_data = json.loads(settings_json)
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse settings JSON for bot {bot_user_id}")
                            settings_data = {}
                    else:
                        settings_data = settings_json
                if not isinstance(settings_data, dict):
                    # e.g. JSON "null" or a list: fall back to defaults
                    # instead of failing on .get() below.
                    logger.warning("Settings for bot %s are not a JSON object; using defaults", bot_user_id)
                    settings_data = {}

                # model_id comes from New API, formatted "Provider/ModelName";
                # it is used as the model name without splitting.
                model_id = settings_data.get("model_id", "")

                # The bot owner's new_api_token is used as the API key.
                api_key = ""
                if owner_id:
                    await cursor.execute(
                        """
                        SELECT new_api_token
                        FROM agent_user WHERE id = %s
                        """,
                        (owner_id,)
                    )
                    user_row = await cursor.fetchone()
                    if user_row and user_row[0]:
                        api_key = user_row[0]

                model_name = model_id
                model_server = NEW_API_BASE_URL.rstrip('/') + "/v1" if NEW_API_BASE_URL else ""

                config = {
                    "name": bot_name,
                    "model": model_name,
                    "api_key": api_key,
                    "model_server": model_server,
                    "language": settings_data.get("language", "zh"),
                    "dataset_ids": settings_data.get("dataset_ids", []),
                    "system_prompt": settings_data.get("system_prompt", ""),
                    "user_identifier": user_identifier if user_identifier else owner_id,
                    "enable_memori": settings_data.get("enable_memori", False),
                    "tool_response": settings_data.get("tool_response", True),
                    "enable_thinking": settings_data.get("enable_thinking", False),
                    "skills": settings_data.get("skills", []),
                    "description": settings_data.get("description", ""),
                    "suggestions": settings_data.get("suggestions", []),
                    "shell_env": settings_data.get("shell_env") or {},
                    "voice_speaker": settings_data.get("voice_speaker", ""),
                    "voice_system_role": settings_data.get("voice_system_role", ""),
                    "voice_speaking_style": settings_data.get("voice_speaking_style", ""),
                }

                # dataset_ids: accept a JSON array string or a
                # comma-separated string; any falsy value becomes [].
                dataset_ids = config['dataset_ids']
                if dataset_ids:
                    if isinstance(dataset_ids, str):
                        if dataset_ids.startswith('['):
                            try:
                                config['dataset_ids'] = json.loads(dataset_ids)
                            except json.JSONDecodeError:
                                config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
                        else:
                            config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
                else:
                    config['dataset_ids'] = []

                # skills: comma-separated string -> list; other non-lists -> [].
                skills = config.get('skills', [])
                if isinstance(skills, str):
                    config['skills'] = [s.strip() for s in skills.split(',') if s.strip()]
                elif not isinstance(skills, list):
                    config['skills'] = []

                # suggestions: newline-separated string -> list; other
                # non-lists -> [].
                suggestions = config.get('suggestions', [])
                if isinstance(suggestions, str):
                    config['suggestions'] = [s.strip() for s in suggestions.split('\n') if s.strip()]
                elif not isinstance(suggestions, list):
                    config['suggestions'] = []

                # Load the enabled MCP server configurations for this bot.
                await cursor.execute(
                    """
                    SELECT name, type, config, enabled
                    FROM agent_mcp_servers WHERE bot_id = %s AND enabled = true
                    """,
                    (bot_uuid,)
                )
                mcp_rows = await cursor.fetchall()

                # Format as mcp_settings (compatible with the v2 API): one
                # {"mcpServers": {<server_type>: <config>}} entry per row.
                mcp_settings_value = []
                for mcp_row in mcp_rows:
                    mcp_name = mcp_row[0]
                    mcp_config = mcp_row[2]

                    # The config column may arrive as a JSON string.
                    if isinstance(mcp_config, str):
                        try:
                            mcp_config = json.loads(mcp_config)
                        except json.JSONDecodeError:
                            mcp_config = {}
                    if not isinstance(mcp_config, dict):
                        # NULL column or non-object JSON: use an empty config
                        # rather than failing the whole fetch on .pop().
                        mcp_config = {}

                    # Prefer server_type from the config; fall back to the
                    # row's name column.
                    server_type = mcp_config.pop("server_type", mcp_name)
                    mcp_settings_value.append({
                        "mcpServers": {
                            server_type: mcp_config
                        }
                    })
                config["mcp_settings"] = mcp_settings_value

                logger.info(f"Fetched bot config for {bot_user_id}: model={config['model']}, api_key={'*' + config['api_key'][-4:] if config['api_key'] else 'N/A'}")
                return config

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching bot config from database: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config from database: {str(e)}"
        ) from e
|
||
|
||
|
||
async def _sync_call_llm(llm_config, messages) -> str:
    """Invoke the chat model once through LangChain and return its text.

    Despite the historical name, this is a coroutine that callers await
    directly; the model is invoked with ``ainvoke``.

    Args:
        llm_config: Dict with ``model``, ``model_server`` and ``api_key``.
        messages: Dicts with ``role`` and ``content``. Only ``system``,
            ``user`` and ``assistant`` roles are forwarded; other roles are
            skipped.

    Returns:
        str: The response text. List-style content blocks are flattened to
        text; an empty response or any error (which is logged) yields ``""``.
    """
    try:
        model_name = llm_config.get('model')
        model_server = llm_config.get('model_server')
        api_key = llm_config.get('api_key')

        model_provider, base_url = detect_provider(model_name, model_server)

        model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
            model_name=model_name,
            model_provider=model_provider,
            base_url=base_url,
            api_key=api_key,
            source="_sync_call_llm"
        )
        if dropped_params:
            logger.info(
                "_sync_call_llm dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
                dropped_params,
                model_name,
                model_provider,
                default_temperature_applied
            )
        llm_instance = init_chat_model(**model_kwargs)

        # Convert role/content dicts to LangChain message objects.
        message_types = {
            'system': SystemMessage,
            'user': HumanMessage,
            'assistant': AIMessage,
        }
        langchain_messages = []
        for msg in messages:
            message_cls = message_types.get(msg['role'])
            if message_cls is not None:
                langchain_messages.append(message_cls(content=msg['content']))

        response = await llm_instance.ainvoke(langchain_messages)

        # Content may be a list of blocks for some providers; callers apply
        # string operations to the result, so always hand back plain text.
        return extract_text_from_content(response.content)

    except Exception as e:
        logger.error(f"Error calling guideline LLM with LangChain: {e}")
        return ""
|
||
|
||
def get_language_text(language: str) -> str:
    """Return the reply-language instruction for a language code.

    Args:
        language: Language code, matched case-insensitively. ``jp`` is
            accepted as an alias of ``ja``. ``None``/empty is allowed.

    Returns:
        str: The instruction sentence for ``zh``, ``en`` or ``ja``; an empty
        string for any other (or missing) code.
    """
    # Normalize case first so the "jp" alias also works for "JP"/"Jp".
    language_key = (language or "").lower()
    if language_key == "jp":
        language_key = "ja"
    language_map = {
        'zh': '请用中文回复',
        'en': 'Please reply in English',
        'ja': '日本語で回答してください',
    }
    return language_map.get(language_key, '')
|
||
|
||
def get_preamble_text(language: str, system_prompt: str) -> tuple[str, Optional[str]]:
    """Resolve the preamble choices text and the system prompt to use.

    If ``system_prompt`` contains a non-empty ``<preamble>...</preamble>``
    block, the first block's content is returned together with the prompt
    with all such blocks removed. Otherwise the built-in preamble choices
    for ``language`` are returned, newline-joined, with the prompt unchanged.

    Args:
        language: Language code, matched case-insensitively (``jp`` is an
            alias of ``ja``). Unknown or missing codes give an empty default.
        system_prompt: System prompt text; may be empty or ``None``.

    Returns:
        tuple: ``(preamble_text, system_prompt)``.
    """
    # First check if system_prompt has a preamble tag
    if system_prompt:
        preamble_pattern = r'<preamble>\s*(.*?)\s*</preamble>'
        preamble_matches = re.findall(preamble_pattern, system_prompt, re.DOTALL)
        if preamble_matches:
            preamble_content = preamble_matches[0].strip()
            if preamble_content:
                # Remove the preamble tag(s) from the prompt that is sent on.
                cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
                return preamble_content, cleaned_system_prompt

    # No usable preamble block: fall back to the built-in choices.
    # Normalize case first so the "jp" alias also works for "JP"/"Jp".
    language_key = (language or "").lower()
    if language_key == "jp":
        language_key = "ja"
    preamble_choices_map = {
        'zh': [
            "好的,让我来帮您看看。",
            "明白了,请稍等。",
            "好的,我理解了。",
            "没问题,我来处理。",
            "收到,正在为您查询。",
            "了解,让我想想。",
            "好的,我来帮您解答。",
            "明白了,稍等片刻。",
            "好的,正在处理中。",
            "了解了,让我为您分析。"
        ],
        'en': [
            "Just a moment.",
            "Got it.",
            "Let me check that for you.",
            "Sorry to hear that.",
            "Thanks for your patience.",
            "I understand.",
            "Let me help you with that.",
            "Please wait a moment.",
            "I'll look into that for you.",
            "Gotcha, let me see.",
            "Understood, one moment please.",
            "I'll help you with this.",
            "Let me figure that out.",
            "Thanks for waiting.",
            "I'll check on that."
        ],
        'ja': [
            "少々お待ちください。",
            "承知いたしました。",
            "わかりました。",
            "確認いたします。",
            "少々お時間をください。",
            "了解しました。",
            "調べてみますね。",
            "お待たせしました。",
            "対応いたします。",
            "わかりましたね。",
            "承知いたしました。",
            "確認させてください。",
            "少々お待ちいただけますか。",
            "お調べいたします。",
            "対応いたしますね。"
        ]
    }
    default_preamble = "\n".join(preamble_choices_map.get(language_key, []))
    return default_preamble, system_prompt  # default preamble and original system_prompt
|
||
|
||
|
||
async def call_preamble_llm(config: AgentConfig) -> str:
    """Ask the LLM to pick a preamble for the current conversation.

    Reads the template ``./prompt/preamble_prompt.md``, fills in its
    ``{preamble_choices_text}``, ``{chat_history}``, ``{last_message}`` and
    ``{language}`` placeholders, sends it as a single user message and
    extracts the ``preamble`` field from the first fenced JSON block
    (```` ```json ... ``` ````) in the reply.

    Args:
        config: Agent configuration. The fields read here are ``api_key``,
            ``model_name``, ``model_server``, ``language``,
            ``preamble_text``, ``messages`` and ``_session_history``.

    Returns:
        str: The selected preamble, or ``""`` when the template cannot be
        read, the model call fails, or the reply contains no usable JSON.
    """
    # Read preamble prompt template
    try:
        with open('./prompt/preamble_prompt.md', 'r', encoding='utf-8') as f:
            preamble_template = f.read()
    except Exception as e:
        logger.error(f"Error reading guideline prompt template: {e}")
        return ""

    try:
        last_message = get_user_last_message_content(config.messages)
        chat_history = format_messages_to_chat_history(convert_to_openai_messages(config._session_history))
        language_text = get_language_text(config.language) if config.language else ""

        # Replace placeholders in order; missing values become empty strings
        # so that str.replace never receives None.
        replacements = (
            ('{preamble_choices_text}', config.preamble_text),
            ('{chat_history}', chat_history),
            ('{last_message}', last_message),
            ('{language}', language_text),
        )
        system_prompt = preamble_template
        for placeholder, value in replacements:
            system_prompt = system_prompt.replace(placeholder, value or "")

        llm_config = {
            'model': config.model_name,
            'api_key': config.api_key,
            'model_server': config.model_server,
        }
        messages = [{'role': 'user', 'content': system_prompt}]

        # Use semaphore to control concurrent API calls
        async with api_semaphore:
            response = await _sync_call_llm(llm_config, messages)

        # Extract content wrapped in ```json and ``` from the response.
        json_pattern = r'```json\s*\n(.*?)\n```'
        json_matches = re.findall(json_pattern, response, re.DOTALL)
        if not json_matches:
            logger.warning("No JSON format found in preamble analysis")
            return ""

        try:
            # Only the first JSON block is considered.
            json_data = json.loads(json_matches[0])
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON from preamble analysis: {e}")
            return ""

        preamble = json_data.get("preamble") if isinstance(json_data, dict) else None
        if not isinstance(preamble, str):
            logger.error("Preamble analysis JSON has no string 'preamble' field")
            return ""

        logger.info("Successfully processed preamble")
        return preamble

    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
|
||
|
||
|
||
|
||
async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_name: str, api_key: str, model_server: str) -> str:
    """Send the guideline prompt to the LLM and return the raw reply.

    Args:
        chat_history: Chat history (part of the signature; not read by this function)
        guidelines_prompt: Prompt text sent as the single user message
        model_name: Model name
        api_key: API key
        model_server: Model server URL

    Returns:
        str: Model response result, or an empty string if the call raised
    """
    # Single user message carrying the whole prompt
    request_messages = [{'role': 'user', 'content': guidelines_prompt}]

    # LLM connection settings
    llm_settings = dict(model=model_name, api_key=api_key, model_server=model_server)

    try:
        # The shared semaphore caps the number of concurrent API calls
        async with api_semaphore:
            return await _sync_call_llm(llm_settings, request_messages)
    except Exception as e:
        logger.error(f"Error calling guideline LLM: {e}")
        return ""
|
||
|
||
|
||
def _get_optimal_batch_size(guidelines_count: int) -> int:
|
||
"""Determine the optimal batch size (concurrency) based on guidelines count."""
|
||
if guidelines_count <= 10:
|
||
return 1
|
||
elif guidelines_count <= 20:
|
||
return 2
|
||
elif guidelines_count <= 30:
|
||
return 3
|
||
else:
|
||
return 5
|
||
|
||
def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
    """
    Extract the guidelines, tools, scenarios and terms blocks from the system prompt.

    Each block is delimited by an XML-style tag pair (``<guidelines>``,
    ``<tools>``, ``<scenarios>``, ``<terms>``). Only the first occurrence of
    each tag is parsed. Parsed blocks are removed from the prompt; a
    ``<terms>`` block whose content fails to parse is logged and left in place.

    Args:
        system_prompt: System prompt text

    Returns:
        tuple[str, str, str, str, List]: (cleaned system_prompt, guidelines,
        tools, scenarios, terms_list). Text blocks that are absent are
        returned as empty strings and terms_list as an empty list. An empty
        or missing system_prompt yields ``("", "", "", "", [])``.
    """
    if not system_prompt:
        # Same 5-tuple shape as the normal path so callers can always unpack.
        return "", "", "", "", []

    blocks_to_remove = []

    # Plain-text blocks: keep the stripped inner text and remember the whole
    # tagged span so it can be cut out of the prompt afterwards.
    text_blocks = {}
    for tag in ("guidelines", "tools", "scenarios"):
        match = re.search(rf'<{tag}>\s*(.*?)\s*</{tag}>', system_prompt, re.DOTALL)
        if match:
            text_blocks[tag] = match.group(1).strip()
            blocks_to_remove.append(match.group(0))
        else:
            text_blocks[tag] = ""

    # Parse <terms>
    terms_list = []
    match = re.search(r'<terms>\s*(.*?)\s*</terms>', system_prompt, re.DOTALL)
    if match:
        try:
            terms_list.extend(parse_terms_text(match.group(1).strip()))
            blocks_to_remove.append(match.group(0))
        except Exception as e:
            logger.error(f"Error parsing terms: {e}")

    # Remove parsed blocks from system_prompt (first occurrence of each only)
    cleaned_prompt = system_prompt
    for block in blocks_to_remove:
        cleaned_prompt = cleaned_prompt.replace(block, '', 1)

    # Clean up excess blank lines left behind by the removed blocks
    cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip()

    return (
        cleaned_prompt,
        text_blocks["guidelines"],
        text_blocks["tools"],
        text_blocks["scenarios"],
        terms_list,
    )
|
||
|
||
|
||
def parse_terms_text(text: str) -> List[Dict[str, Any]]:
    """
    Parse terms text, supporting multiple formats.

    Supported formats:

    * JSON: a list of objects (non-object items are ignored) or a single
      object. Tried when the text starts with ``[`` or ``{``; if it is not
      valid JSON, parsing falls back to the line format.
    * Line format, one term per line, optionally numbered:
      ``1) Name: term, Description: free text, Synonyms: syn1, syn2``.
      ``Description`` and ``Synonyms`` are optional, the keywords are matched
      case-insensitively, and synonyms may be separated by ``,``, ``;`` or
      ``|``. The description may itself contain commas. Blank lines, lines
      starting with ``#`` or ``//``, and lines that do not match are skipped.

    Args:
        text: Terms text content

    Returns:
        List[Dict]: List of terms. Line-format terms have a ``name`` key and,
        when present and non-empty, ``description`` and ``synonyms`` (a list
        of strings). JSON terms are returned as parsed.
    """
    # Try to parse as JSON format
    if text.strip().startswith(('[', '{')):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            pass  # Not valid JSON after all; fall through to the line format
        else:
            if isinstance(data, list):
                return [item for item in data if isinstance(item, dict)]
            if isinstance(data, dict):
                return [data]
            return []

    # "1) Name: term_name1, Description: desc, Synonyms: syn1, syn2"
    # The description is matched lazily up to ", Synonyms:" or end of line so
    # that commas inside it neither truncate it nor hide the synonyms.
    term_line_pattern = re.compile(
        r'(?:\d+\)\s*)?Name:\s*([^,]+)'
        r'(?:,\s*Description:\s*(.*?)(?=,\s*Synonyms:|$))?'
        r'(?:,\s*Synonyms:\s*(.+))?',
        re.IGNORECASE,
    )

    terms = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()

        # Skip blank and comment lines
        if not line or line.startswith(('#', '//')):
            continue

        match = term_line_pattern.match(line)
        if not match:
            continue

        # Build term object; optional fields are only added when non-empty
        term_data = {'name': match.group(1).strip()}

        description = (match.group(2) or '').strip()
        if description:
            term_data['description'] = description

        synonyms_text = (match.group(3) or '').strip()
        if synonyms_text:
            term_data['synonyms'] = [
                s.strip() for s in re.split(r'[,;|]', synonyms_text) if s.strip()
            ]

        terms.append(term_data)

    return terms
|