添加datetime & process_message逆运算
This commit is contained in:
parent
c1a06aae35
commit
0ac0fcbfb3
@ -13,6 +13,7 @@ import re
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form
|
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form
|
||||||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||||
|
from utils.logger import logger
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from file_manager_api import router as file_manager_router
|
from file_manager_api import router as file_manager_router
|
||||||
@ -249,8 +250,8 @@ async def generate_stream_response(agent, messages, tool_response: bool, model:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
error_details = traceback.format_exc()
|
error_details = traceback.format_exc()
|
||||||
print(f"Error in generate_stream_response: {str(e)}")
|
logger.error(f"Error in generate_stream_response: {str(e)}")
|
||||||
print(f"Full traceback: {error_details}")
|
logger.error(f"Full traceback: {error_details}")
|
||||||
|
|
||||||
error_data = {
|
error_data = {
|
||||||
"error": {
|
"error": {
|
||||||
@ -784,7 +785,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
|
|||||||
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
|
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
|
||||||
"""获取机器人配置从后端API"""
|
"""获取机器人配置从后端API"""
|
||||||
try:
|
try:
|
||||||
backend_host = os.getenv("BACKEND_HOST", "http://127.0.0.1:8000")
|
backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
|
||||||
url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
|
url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
|
||||||
|
|
||||||
auth_token = generate_v2_auth_token(bot_id)
|
auth_token = generate_v2_auth_token(bot_id)
|
||||||
@ -827,13 +828,18 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]:
|
def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]:
|
||||||
"""处理消息列表,包括[TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER]分割和语言指令添加"""
|
"""处理消息列表,包括[TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER]分割和语言指令添加
|
||||||
|
|
||||||
|
这是 get_content_from_messages 的逆运算,将包含 [TOOL_RESPONSE] 的消息重新组装回
|
||||||
|
msg['role'] == 'function' 和 msg.get('function_call') 的格式。
|
||||||
|
"""
|
||||||
processed_messages = []
|
processed_messages = []
|
||||||
|
|
||||||
# 收集所有ASSISTANT消息的索引
|
# 收集所有ASSISTANT消息的索引
|
||||||
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
|
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
|
||||||
total_assistant_messages = len(assistant_indices)
|
total_assistant_messages = len(assistant_indices)
|
||||||
cutoff_point = max(0, total_assistant_messages - 5)
|
cutoff_point = max(0, total_assistant_messages - 5)
|
||||||
|
|
||||||
# 处理每条消息
|
# 处理每条消息
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
if msg.role == "assistant":
|
if msg.role == "assistant":
|
||||||
@ -898,8 +904,72 @@ def process_messages(messages: List[Message], language: Optional[str] = None) ->
|
|||||||
else:
|
else:
|
||||||
processed_messages.append({"role": msg.role, "content": msg.content})
|
processed_messages.append({"role": msg.role, "content": msg.content})
|
||||||
|
|
||||||
|
# 逆运算:将包含 [TOOL_RESPONSE] 的消息重新组装回 msg['role'] == 'function' 和 msg.get('function_call')
|
||||||
|
# 这是 get_content_from_messages 的逆运算
|
||||||
|
final_messages = []
|
||||||
|
for msg in processed_messages:
|
||||||
|
if msg["role"] == ASSISTANT and "[TOOL_RESPONSE]" in msg["content"]:
|
||||||
|
# 分割消息内容
|
||||||
|
parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
|
||||||
|
|
||||||
|
current_tag = None
|
||||||
|
assistant_content = ""
|
||||||
|
function_calls = []
|
||||||
|
tool_responses = []
|
||||||
|
|
||||||
|
for i in range(0, len(parts)):
|
||||||
|
if i % 2 == 0: # 文本内容
|
||||||
|
text = parts[i].strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_tag == "TOOL_RESPONSE":
|
||||||
|
# 解析 TOOL_RESPONSE 格式:[TOOL_RESPONSE] function_name\ncontent
|
||||||
|
lines = text.split('\n', 1)
|
||||||
|
function_name = lines[0].strip() if lines else ""
|
||||||
|
response_content = lines[1].strip() if len(lines) > 1 else ""
|
||||||
|
|
||||||
|
tool_responses.append({
|
||||||
|
"role": FUNCTION,
|
||||||
|
"name": function_name,
|
||||||
|
"content": response_content
|
||||||
|
})
|
||||||
|
elif current_tag == "TOOL_CALL":
|
||||||
|
# 解析 TOOL_CALL 格式:[TOOL_CALL] function_name\narguments
|
||||||
|
lines = text.split('\n', 1)
|
||||||
|
function_name = lines[0].strip() if lines else ""
|
||||||
|
arguments = lines[1].strip() if len(lines) > 1 else ""
|
||||||
|
|
||||||
|
function_calls.append({
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": arguments
|
||||||
|
})
|
||||||
|
elif current_tag == "ANSWER":
|
||||||
|
assistant_content += text + "\n"
|
||||||
|
else:
|
||||||
|
# 第一个标签之前的内容也属于 assistant
|
||||||
|
assistant_content += text + "\n"
|
||||||
|
else: # 标签
|
||||||
|
current_tag = parts[i]
|
||||||
|
|
||||||
|
# 添加 assistant 消息(如果有内容)
|
||||||
|
if assistant_content.strip() or function_calls:
|
||||||
|
assistant_msg = {"role": ASSISTANT}
|
||||||
|
if assistant_content.strip():
|
||||||
|
assistant_msg["content"] = assistant_content.strip()
|
||||||
|
if function_calls:
|
||||||
|
# 如果有多个 function_call,只取第一个(兼容原有逻辑)
|
||||||
|
assistant_msg["function_call"] = function_calls[0]
|
||||||
|
final_messages.append(assistant_msg)
|
||||||
|
|
||||||
|
# 添加所有 tool_responses 作为 function 消息
|
||||||
|
final_messages.extend(tool_responses)
|
||||||
|
else:
|
||||||
|
# 非 assistant 消息或不包含 [TOOL_RESPONSE] 的消息直接添加
|
||||||
|
final_messages.append(msg)
|
||||||
|
|
||||||
# 在最后一条消息的末尾追加回复语言
|
# 在最后一条消息的末尾追加回复语言
|
||||||
if processed_messages and language:
|
if final_messages and language:
|
||||||
language_map = {
|
language_map = {
|
||||||
'zh': '请用中文回复',
|
'zh': '请用中文回复',
|
||||||
'en': 'Please reply in English',
|
'en': 'Please reply in English',
|
||||||
@ -909,9 +979,9 @@ def process_messages(messages: List[Message], language: Optional[str] = None) ->
|
|||||||
language_instruction = language_map.get(language.lower(), '')
|
language_instruction = language_map.get(language.lower(), '')
|
||||||
if language_instruction:
|
if language_instruction:
|
||||||
# 在最后一条消息末尾追加语言指令
|
# 在最后一条消息末尾追加语言指令
|
||||||
processed_messages[-1]['content'] = processed_messages[-1]['content'] + f"\n\n{language_instruction}。"
|
final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\n{language_instruction}。"
|
||||||
|
|
||||||
return processed_messages
|
return final_messages
|
||||||
|
|
||||||
|
|
||||||
async def create_agent_and_generate_response(
|
async def create_agent_and_generate_response(
|
||||||
|
|||||||
273
mcp/datetime_server.py
Normal file
273
mcp/datetime_server.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MCP Server for date and time operations.
|
||||||
|
Provides functions to:
|
||||||
|
1. Get current date and time
|
||||||
|
2. Get current date
|
||||||
|
3. Get current time
|
||||||
|
4. Format date and time
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from mcp_common import (
|
||||||
|
load_tools_from_json,
|
||||||
|
create_error_response,
|
||||||
|
create_success_response,
|
||||||
|
create_initialize_response,
|
||||||
|
create_ping_response,
|
||||||
|
create_tools_list_response,
|
||||||
|
handle_mcp_streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Handle MCP request"""
|
||||||
|
try:
|
||||||
|
method = request.get("method")
|
||||||
|
params = request.get("params", {})
|
||||||
|
request_id = request.get("id")
|
||||||
|
|
||||||
|
if method == "initialize":
|
||||||
|
return create_initialize_response(request_id, "datetime-server")
|
||||||
|
|
||||||
|
elif method == "ping":
|
||||||
|
return create_ping_response(request_id)
|
||||||
|
|
||||||
|
elif method == "tools/list":
|
||||||
|
# 从 JSON 文件加载工具定义
|
||||||
|
tools = load_tools_from_json("datetime_tools.json")
|
||||||
|
return create_tools_list_response(request_id, tools)
|
||||||
|
|
||||||
|
elif method == "tools/call":
|
||||||
|
tool_name = params.get("name")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
|
||||||
|
if tool_name == "get_current_datetime":
|
||||||
|
return await get_current_datetime(arguments, request_id)
|
||||||
|
|
||||||
|
elif tool_name == "get_current_date":
|
||||||
|
return await get_current_date(arguments, request_id)
|
||||||
|
|
||||||
|
elif tool_name == "get_current_time":
|
||||||
|
return await get_current_time(arguments, request_id)
|
||||||
|
|
||||||
|
elif tool_name == "format_datetime":
|
||||||
|
return await format_datetime(arguments, request_id)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return create_error_response(
|
||||||
|
request_id,
|
||||||
|
-32601,
|
||||||
|
f"Unknown tool: {tool_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return create_error_response(
|
||||||
|
request_id,
|
||||||
|
-32601,
|
||||||
|
f"Unknown method: {method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(
|
||||||
|
request.get("id"),
|
||||||
|
-32603,
|
||||||
|
f"Internal error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_datetime(arguments: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
||||||
|
"""获取当前日期时间"""
|
||||||
|
try:
|
||||||
|
timezone_str = arguments.get("timezone", "local")
|
||||||
|
|
||||||
|
if timezone_str == "UTC":
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
elif timezone_str == "local":
|
||||||
|
now = datetime.now()
|
||||||
|
else:
|
||||||
|
# 支持常见的时区名称
|
||||||
|
try:
|
||||||
|
import pytz
|
||||||
|
tz = pytz.timezone(timezone_str)
|
||||||
|
now = datetime.now(tz)
|
||||||
|
except ImportError:
|
||||||
|
# 如果没有pytz库,回退到本地时间
|
||||||
|
now = datetime.now()
|
||||||
|
except Exception:
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"datetime": now.isoformat(),
|
||||||
|
"year": now.year,
|
||||||
|
"month": now.month,
|
||||||
|
"day": now.day,
|
||||||
|
"hour": now.hour,
|
||||||
|
"minute": now.minute,
|
||||||
|
"second": now.second,
|
||||||
|
"weekday": now.weekday(), # 0=Monday, 6=Sunday
|
||||||
|
"timezone": timezone_str
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将结果包装在 content 字段中
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(request_id, -32603, f"获取日期时间失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_date(arguments: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
||||||
|
"""获取当前日期"""
|
||||||
|
try:
|
||||||
|
timezone_str = arguments.get("timezone", "local")
|
||||||
|
|
||||||
|
if timezone_str == "UTC":
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
elif timezone_str == "local":
|
||||||
|
now = datetime.now()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import pytz
|
||||||
|
tz = pytz.timezone(timezone_str)
|
||||||
|
now = datetime.now(tz)
|
||||||
|
except ImportError:
|
||||||
|
now = datetime.now()
|
||||||
|
except Exception:
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"date": now.date().isoformat(),
|
||||||
|
"year": now.year,
|
||||||
|
"month": now.month,
|
||||||
|
"day": now.day,
|
||||||
|
"weekday": now.weekday(),
|
||||||
|
"timezone": timezone_str
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将结果包装在 content 字段中
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(request_id, -32603, f"获取日期失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_time(arguments: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
||||||
|
"""获取当前时间"""
|
||||||
|
try:
|
||||||
|
timezone_str = arguments.get("timezone", "local")
|
||||||
|
|
||||||
|
if timezone_str == "UTC":
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
elif timezone_str == "local":
|
||||||
|
now = datetime.now()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import pytz
|
||||||
|
tz = pytz.timezone(timezone_str)
|
||||||
|
now = datetime.now(tz)
|
||||||
|
except ImportError:
|
||||||
|
now = datetime.now()
|
||||||
|
except Exception:
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"time": now.time().isoformat(),
|
||||||
|
"hour": now.hour,
|
||||||
|
"minute": now.minute,
|
||||||
|
"second": now.second,
|
||||||
|
"timezone": timezone_str
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将结果包装在 content 字段中
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(request_id, -32603, f"获取时间失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def format_datetime(arguments: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
||||||
|
"""格式化日期时间"""
|
||||||
|
try:
|
||||||
|
format_string = arguments.get("format", "%Y-%m-%d %H:%M:%S")
|
||||||
|
timezone_str = arguments.get("timezone", "local")
|
||||||
|
|
||||||
|
if timezone_str == "UTC":
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
elif timezone_str == "local":
|
||||||
|
now = datetime.now()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import pytz
|
||||||
|
tz = pytz.timezone(timezone_str)
|
||||||
|
now = datetime.now(tz)
|
||||||
|
except ImportError:
|
||||||
|
now = datetime.now()
|
||||||
|
except Exception:
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
formatted_datetime = now.strftime(format_string)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"formatted_datetime": formatted_datetime,
|
||||||
|
"format": format_string,
|
||||||
|
"original_datetime": now.isoformat(),
|
||||||
|
"timezone": timezone_str
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将结果包装在 content 字段中
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(request_id, -32603, f"格式化日期时间失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(handle_mcp_streaming(handle_request))
|
||||||
@ -14,6 +14,12 @@
|
|||||||
"./mcp/multi_keyword_search_server.py",
|
"./mcp/multi_keyword_search_server.py",
|
||||||
"{dataset_dir}"
|
"{dataset_dir}"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"datetime": {
|
||||||
|
"command": "python",
|
||||||
|
"args": [
|
||||||
|
"./mcp/datetime_server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,12 @@
|
|||||||
"./mcp/rag_retrieve_server.py",
|
"./mcp/rag_retrieve_server.py",
|
||||||
"{bot_id}"
|
"{bot_id}"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"datetime": {
|
||||||
|
"command": "python",
|
||||||
|
"args": [
|
||||||
|
"./mcp/datetime_server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,6 +14,12 @@
|
|||||||
"./mcp/multi_keyword_search_server.py",
|
"./mcp/multi_keyword_search_server.py",
|
||||||
"{dataset_dir}"
|
"{dataset_dir}"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"datetime": {
|
||||||
|
"command": "python",
|
||||||
|
"args": [
|
||||||
|
"./mcp/datetime_server.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
67
mcp/tools/datetime_tools.json
Normal file
67
mcp/tools/datetime_tools.json
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_current_datetime",
|
||||||
|
"description": "获取当前的日期和时间,返回详细的时间信息",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"timezone": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "时区设置,支持 'local'(本地时间), 'UTC'(协调世界时), 或者其他时区名称如 'Asia/Shanghai'",
|
||||||
|
"default": "Asia/Tokyo",
|
||||||
|
"enum": ["UTC", "Asia/Shanghai", "America/New_York", "Europe/London", "Asia/Tokyo"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "get_current_date",
|
||||||
|
"description": "获取当前的日期信息",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"timezone": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "时区设置,支持 'local'(本地时间), 'UTC'(协调世界时), 或者其他时区名称",
|
||||||
|
"default": "Asia/Tokyo",
|
||||||
|
"enum": ["UTC", "Asia/Shanghai", "America/New_York", "Europe/London", "Asia/Tokyo"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "get_current_time",
|
||||||
|
"description": "获取当前的时间信息",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"timezone": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "时区设置,支持 'local'(本地时间), 'UTC'(协调世界时), 或者其他时区名称",
|
||||||
|
"default": "Asia/Tokyo",
|
||||||
|
"enum": ["UTC", "Asia/Shanghai", "America/New_York", "Europe/London", "Asia/Tokyo"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "format_datetime",
|
||||||
|
"description": "按指定格式获取当前日期时间",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "日期时间格式字符串,例如: '%Y-%m-%d %H:%M:%S', '%Y年%m月%d日', '%H:%M:%S' 等",
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
"timezone": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "时区设置,支持 'local'(本地时间), 'UTC'(协调世界时), 或者其他时区名称",
|
||||||
|
"default": "Asia/Tokyo",
|
||||||
|
"enum": ["UTC", "Asia/Shanghai", "America/New_York", "Europe/London", "Asia/Tokyo"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@ -21,6 +21,7 @@ from typing import Dict, Iterator, List, Literal, Optional, Union
|
|||||||
from qwen_agent.agents import Assistant
|
from qwen_agent.agents import Assistant
|
||||||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
||||||
from qwen_agent.llm.oai import TextChatAtOAI
|
from qwen_agent.llm.oai import TextChatAtOAI
|
||||||
|
from utils.logger import tool_logger
|
||||||
|
|
||||||
class ModifiedAssistant(Assistant):
|
class ModifiedAssistant(Assistant):
|
||||||
"""
|
"""
|
||||||
@ -55,6 +56,44 @@ class ModifiedAssistant(Assistant):
|
|||||||
]
|
]
|
||||||
return any(indicator in error_str for indicator in retryable_indicators)
|
return any(indicator in error_str for indicator in retryable_indicators)
|
||||||
|
|
||||||
|
def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
|
||||||
|
"""重写工具调用方法,添加调试信息"""
|
||||||
|
if tool_name not in self.function_map:
|
||||||
|
error_msg = f'Tool {tool_name} does not exist. Available tools: {list(self.function_map.keys())}'
|
||||||
|
tool_logger.error(error_msg)
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
tool = self.function_map[tool_name]
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_logger.info(f"开始调用工具: {tool_name} {tool_args}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 调用父类的_call_tool方法
|
||||||
|
tool_result = super()._call_tool(tool_name, tool_args, **kwargs)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
tool_logger.info(f"工具 {tool_name} 执行完成,耗时: {end_time - start_time:.2f}秒 结果长度: {len(tool_result) if tool_result else 0}")
|
||||||
|
|
||||||
|
# 打印部分结果内容(避免过长)
|
||||||
|
if tool_result and len(tool_result) > 0:
|
||||||
|
preview = tool_result[:200] if len(tool_result) > 200 else tool_result
|
||||||
|
tool_logger.debug(f"工具 {tool_name} 结果预览: {preview}...")
|
||||||
|
|
||||||
|
return tool_result
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
end_time = time.time()
|
||||||
|
tool_logger.error(f"工具调用异常,耗时: {end_time - start_time:.2f}秒 异常类型: {type(ex).__name__} {str(ex)}")
|
||||||
|
|
||||||
|
# 打印完整的堆栈跟踪
|
||||||
|
import traceback
|
||||||
|
tool_logger.error(f"堆栈跟踪:\n{traceback.format_exc()}")
|
||||||
|
|
||||||
|
# 返回详细的错误信息
|
||||||
|
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
|
||||||
|
return error_message
|
||||||
|
|
||||||
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
||||||
"""带重试机制的LLM调用
|
"""带重试机制的LLM调用
|
||||||
|
|
||||||
@ -77,13 +116,13 @@ class ModifiedAssistant(Assistant):
|
|||||||
# 检查是否为可重试的错误
|
# 检查是否为可重试的错误
|
||||||
if self._is_retryable_error(e) and attempt < max_retries - 1:
|
if self._is_retryable_error(e) and attempt < max_retries - 1:
|
||||||
delay = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
delay = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
||||||
print(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}),{delay}秒后重试: {str(e)}")
|
tool_logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}),{delay}秒后重试: {str(e)}")
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# 不可重试的错误或已达到最大重试次数
|
# 不可重试的错误或已达到最大重试次数
|
||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
print(f"LLM调用重试失败,已达到最大重试次数 {max_retries}")
|
tool_logger.error(f"LLM调用重试失败,已达到最大重试次数 {max_retries}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:
|
def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:
|
||||||
@ -118,6 +157,14 @@ class ModifiedAssistant(Assistant):
|
|||||||
use_tool, tool_name, tool_args, _ = self._detect_tool(out)
|
use_tool, tool_name, tool_args, _ = self._detect_tool(out)
|
||||||
if use_tool:
|
if use_tool:
|
||||||
tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)
|
tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)
|
||||||
|
|
||||||
|
# 验证工具结果
|
||||||
|
if not tool_result:
|
||||||
|
tool_logger.warning(f"工具 {tool_name} 返回空结果")
|
||||||
|
tool_result = f"Tool {tool_name} completed execution but returned empty result"
|
||||||
|
elif tool_result.startswith('An error occurred when calling tool') or tool_result.startswith('工具调用失败'):
|
||||||
|
tool_logger.error(f"工具 {tool_name} 调用失败: {tool_result}")
|
||||||
|
|
||||||
fn_msg = Message(role=FUNCTION,
|
fn_msg = Message(role=FUNCTION,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
content=tool_result,
|
content=tool_result,
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
请仔细按照所有系统说明进行下一次用户查询:
|
请仔细按照所有系统说明进行下一次用户查询:
|
||||||
1.在适当的时候执行`rag_retrieve`工具调用,以检索准确的信息
|
1.在适当的时候执行`rag_retrieve`工具调用,以检索准确的信息。
|
||||||
2.遵守指定的输出格式和响应结构
|
2.在处理和时间有关的问题时,必须先调用`datetime`工具获取当前时间再进行处理。
|
||||||
3.逐步遵循既定的处理流程
|
3.遵守指定的输出格式和响应结构。
|
||||||
4.使用系统提示中定义的正确工具调用程序
|
4.逐步遵循既定的处理流程。
|
||||||
5.保持与既定角色和行为准则的一致性
|
5.使用系统提示中定义的正确工具调用程序。
|
||||||
|
6.保持与既定角色和行为准则的一致性。
|
||||||
|
|
||||||
{extra_prompt}
|
{extra_prompt}
|
||||||
|
|
||||||
|
|||||||
46
utils/logger.py
Normal file
46
utils/logger.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
项目日志工具
|
||||||
|
参考 qwen_agent 的日志实现方式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(name='qwen_agent_project', level=None):
|
||||||
|
"""
|
||||||
|
设置日志记录器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: logger名称
|
||||||
|
level: 日志级别,默认根据环境变量QWEN_AGENT_DEBUG决定
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 配置好的logger实例
|
||||||
|
"""
|
||||||
|
if level is None:
|
||||||
|
level = logging.DEBUG
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
# 使用与qwen_agent相同的格式
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(filename)s - %(lineno)d - %(levelname)s - %(message)s')
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
_logger = logging.getLogger(name)
|
||||||
|
_logger.setLevel(level)
|
||||||
|
|
||||||
|
# 避免重复添加handler
|
||||||
|
if not _logger.handlers:
|
||||||
|
_logger.addHandler(handler)
|
||||||
|
|
||||||
|
return _logger
|
||||||
|
|
||||||
|
|
||||||
|
# 创建项目主logger
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
# 创建专用的工具调用logger,便于过滤
|
||||||
|
tool_logger = setup_logger('qwen_agent_tools')
|
||||||
|
|
||||||
|
# 创建任务队列相关的logger
|
||||||
|
queue_logger = setup_logger('qwen_agent_queue')
|
||||||
@ -160,6 +160,17 @@ def load_mcp_settings(project_dir: str, mcp_settings: list=None, bot_id: str="",
|
|||||||
print(f"Failed to load default mcp_settings_{robot_type}: {str(e)}")
|
print(f"Failed to load default mcp_settings_{robot_type}: {str(e)}")
|
||||||
default_mcp_settings = []
|
default_mcp_settings = []
|
||||||
|
|
||||||
|
# 遍历mcpServers工具,给每个工具增加env参数
|
||||||
|
if default_mcp_settings and len(default_mcp_settings) > 0:
|
||||||
|
mcp_servers = default_mcp_settings[0].get('mcpServers', {})
|
||||||
|
for server_name, server_config in mcp_servers.items():
|
||||||
|
if isinstance(server_config, dict):
|
||||||
|
# 如果还没有env字段,则创建一个
|
||||||
|
if 'env' not in server_config:
|
||||||
|
server_config['env'] = {}
|
||||||
|
# 添加必要的环境变量
|
||||||
|
server_config['env']['BACKEND_HOST'] = os.environ.get('BACKEND_HOST', 'https://api-dev.gptbase.ai')
|
||||||
|
|
||||||
# 2. 处理传入的mcp_settings参数
|
# 2. 处理传入的mcp_settings参数
|
||||||
input_mcp_settings = []
|
input_mcp_settings = []
|
||||||
if mcp_settings is not None:
|
if mcp_settings is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user