# qwen_agent/routes/chat.py
import json
import os
import asyncio
from typing import Union, Optional

from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import logging

logger = logging.getLogger('app')

from utils import (
    Message, ChatRequest, ChatResponse
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
from utils.api_models import ChatRequestV2
from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
    process_messages, format_messages_to_chat_history,
    create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
    call_preamble_llm, get_preamble_text, get_user_last_message_content,
    create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage

router = APIRouter()

# Initialize the global agent manager
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
    shard_count=int(os.getenv("SHARD_COUNT", "16"))
)
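
# Both cache knobs above are environment-driven; e.g. (hypothetical values):
#   export MAX_CACHED_AGENTS=100
#   export SHARD_COUNT=32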


def append_user_last_message(messages: list, content: str) -> list:
    """Append content to the last message when it is a user message.

    Args:
        messages: List of message dicts.
        content: Text to append.

    Returns:
        list: The (possibly modified) message list.
    """
    if not messages:
        return messages
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        messages[-1]['content'] += content
    return messages

def append_assistant_last_message(messages: list, content: str) -> list:
    """Append content to the last message when it is an assistant message;
    otherwise append a new assistant message.

    Args:
        messages: List of message dicts.
        content: Text to append.

    Returns:
        list: The (possibly modified) message list.
    """
    if not messages:
        return messages
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'assistant':
        messages[-1]['content'] += content
    else:
        messages.append({"role": "assistant", "content": content})
    return messages
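
# A minimal usage sketch for the two helpers above (values are illustrative):
#
#   msgs = [{"role": "user", "content": "Hello"}]
#   append_user_last_message(msgs, "\n(answer in Japanese)")
#   # msgs[-1]["content"] == "Hello\n(answer in Japanese)"
#   append_assistant_last_message(msgs, "Konnichiwa")
#   # appends a new assistant message, since the last message is from the user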

async def enhanced_generate_stream_response(
    agent_manager,
    bot_id: str,
    api_key: str,
    messages: list,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: str,
    mcp_settings: Optional[list],
    robot_type: str,
    project_dir: Optional[str],
    generate_cfg: Optional[dict],
    user_identifier: Optional[str]
):
    """Enhanced progressive streaming response generator."""
    try:
        # Phase 1: prepare the query text and chat history for preamble generation
        query_text = get_user_last_message_content(messages)
        chat_history = format_messages_to_chat_history(messages)
        # Resolve the preamble template and the (possibly adjusted) system prompt
        preamble_text, system_prompt = get_preamble_text(language, system_prompt)
        # Generate the preamble text via the LLM
        try:
            preamble_text = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
            # Emit the preamble only when it is non-empty and not the "<empty>" sentinel
            if preamble_text and preamble_text.strip() and preamble_text != "<empty>":
                preamble_content = f"[PREAMBLE]\n{preamble_text}\n"
                chunk_data = create_stream_chunk("chatcmpl-preamble", model_name, preamble_content)
                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
                logger.info(f"Stream mode: Generated preamble text ({len(preamble_text)} chars)")
            else:
                logger.info("Stream mode: Skipped empty preamble text")
        except Exception as e:
            logger.error(f"Error generating preamble text: {e}")
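        # A sketch of the SSE frame emitted above, assuming create_stream_chunk builds
        # an OpenAI-style chat.completion.chunk payload (field names are an assumption):
        #
        #   data: {"id": "chatcmpl-preamble", "object": "chat.completion.chunk",
        #          "model": "<model_name>",
        #          "choices": [{"index": 0, "delta": {"content": "[PREAMBLE]\n...\n"},
        #                       "finish_reason": null}]}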
        # Phase 2: fetch (or create and cache) the agent for this bot
        agent = await agent_manager.get_or_create_agent(
            bot_id=bot_id,
            project_dir=project_dir,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            language=language,
            system_prompt=system_prompt,
            mcp_settings=mcp_settings,
            robot_type=robot_type,
            user_identifier=user_identifier
        )
        # Phase 3: stream the agent response
        logger.info("Starting agent stream response")
        chunk_id = 0
        message_tag = ""
        function_name = ""
        tool_args = ""
        async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages"):
            new_content = ""
            if isinstance(msg, AIMessageChunk):
                # Tool-call chunks: accumulate the function name and argument fragments
                if msg.tool_call_chunks:
                    if message_tag != "TOOL_CALL":
                        message_tag = "TOOL_CALL"
                    if msg.tool_call_chunks[0]["name"]:
                        function_name = msg.tool_call_chunks[0]["name"]
                    if msg.tool_call_chunks[0]["args"]:
                        tool_args += msg.tool_call_chunks[0]["args"]
                elif len(msg.content) > 0:
                    # Plain answer content: prefix the first chunk with its tag
                    if message_tag != "ANSWER":
                        message_tag = "ANSWER"
                        new_content = f"[{message_tag}]\n{msg.text}"
                    elif message_tag == "ANSWER":
                        new_content = msg.text
                elif message_tag == "TOOL_CALL" and (
                    msg.response_metadata.get("finish_reason") == "tool_calls"
                    or msg.response_metadata.get("stop_reason") == "tool_use"
                ):
                    # The tool call is complete: flush the accumulated name and args, then
                    # reset the accumulators so a later call does not concatenate stale args
                    new_content = f"[{message_tag}] {function_name}\n{tool_args}"
                    function_name = ""
                    tool_args = ""
            elif isinstance(msg, ToolMessage) and len(msg.content) > 0:
                message_tag = "TOOL_RESPONSE"
                new_content = f"[{message_tag}] {msg.name}\n{msg.text}"
            elif isinstance(msg, AIMessage) and msg.additional_kwargs and "thinking" in msg.additional_kwargs:
                new_content = "[THINK]\n" + msg.additional_kwargs["thinking"] + "\n"
            # Emit a chunk only when there is new content
            if new_content:
                if chunk_id == 0:
                    logger.info("First agent token generated, starting streaming output")
                chunk_id += 1
                chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", model_name, new_content)
                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
        final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", model_name, finish_reason="stop")
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        # Send the stream termination marker
        yield "data: [DONE]\n\n"
        logger.info(f"Enhanced stream response completed, total chunks: {chunk_id}")
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in enhanced_generate_stream_response: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
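
# The stream above interleaves plain-text segments, each introduced by a tag line:
#   [PREAMBLE]       conversational lead-in generated before the agent runs
#   [THINK]          model "thinking" content, when the backend exposes it
#   [ANSWER]         the assistant's visible reply
#   [TOOL_CALL]      a tool invocation: the tool name on the tag line, JSON args after it
#   [TOOL_RESPONSE]  the tool's output, with the tool name on the tag line
# Every segment is wrapped in an OpenAI-style SSE chunk, and the stream ends with "data: [DONE]".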

async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: list,
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[list],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[dict] = None,
    user_identifier: Optional[str] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared logic for creating an agent and generating a response."""
    if generate_cfg is None:
        generate_cfg = {}
    # Streaming mode: delegate to the enhanced streaming response generator
    if stream:
        return StreamingResponse(
            enhanced_generate_stream_response(
                agent_manager=agent_manager,
                bot_id=bot_id,
                api_key=api_key,
                messages=messages,
                tool_response=tool_response,
                model_name=model_name,
                model_server=model_server,
                language=language,
                system_prompt=system_prompt or "",
                mcp_settings=mcp_settings,
                robot_type=robot_type,
                project_dir=project_dir,
                generate_cfg=generate_cfg,
                user_identifier=user_identifier
            ),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    _, system_prompt = get_preamble_text(language, system_prompt)
    # Fetch (or create and cache) the agent for this bot
    agent = await agent_manager.get_or_create_agent(
        bot_id=bot_id,
        project_dir=project_dir,
        model_name=model_name,
        api_key=api_key,
        model_server=model_server,
        generate_cfg=generate_cfg,
        language=language,
        system_prompt=system_prompt,
        mcp_settings=mcp_settings,
        robot_type=robot_type,
        user_identifier=user_identifier
    )
    # Prepare the final message list
    final_messages = messages.copy()
    # Non-streaming response: run the agent to completion and keep only the newly appended messages
    agent_responses = await agent.ainvoke({"messages": final_messages})
    append_messages = agent_responses["messages"][len(final_messages):]
    response_text = ""
    for msg in append_messages:
        if isinstance(msg, AIMessage):
            if msg.additional_kwargs and "thinking" in msg.additional_kwargs:
                response_text += "[THINK]\n" + msg.additional_kwargs["thinking"] + "\n"
            elif len(msg.text) > 0:
                response_text += "[ANSWER]\n" + msg.text + "\n"
            if len(msg.tool_calls) > 0:
                response_text += "".join([
                    f"[TOOL_CALL] {tool['name']}\n"
                    f"{json.dumps(tool['args']) if isinstance(tool['args'], dict) else tool['args']}\n"
                    for tool in msg.tool_calls
                ])
        elif isinstance(msg, ToolMessage) and tool_response:
            response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"
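    # Illustrative shape of the assembled response_text (content is hypothetical):
    #
    #   [ANSWER]
    #   Let me check the weather.
    #   [TOOL_CALL] get_weather
    #   {"city": "Tokyo"}
    #   [TOOL_RESPONSE] get_weather
    #   {"temp_c": 21}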
    if len(response_text) > 0:
        # Build an OpenAI-style response; note that the "token" counts below are
        # character counts, used only as a rough usage approximation
        return ChatResponse(
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_text
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
                "completion_tokens": len(response_text),
                "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text)
            }
        )
    else:
        raise HTTPException(status_code=500, detail="No response from agent")


@router.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI's; supports both streaming and non-streaming.

    Args:
        request: ChatRequest containing messages, model, an optional dataset_ids list, the required bot_id, system_prompt, mcp_settings, and files
        authorization: Authorization header containing the API key (Bearer <API_KEY>)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream

    Notes:
        - dataset_ids: optional; when provided it must be a list of project IDs (a single project also uses the array format)
        - bot_id: required; the target robot ID
        - A robot project directory (projects/robot/{bot_id}/) is created only when robot_type == "catalog_agent" and dataset_ids is a non-empty array
        - For any other robot_type (including the default "agent"), no directory is created
        - When dataset_ids is an empty array [], None, or omitted, no directory is created
        - Supports merging multiple knowledge bases and automatically resolves folder-name conflicts

    Required Parameters:
        - bot_id: str - target robot ID
        - messages: List[Message] - list of conversation messages

    Optional Parameters:
        - dataset_ids: List[str] - list of source knowledge-base project IDs (a single project also uses the array format)
        - robot_type: str - robot type, defaults to "agent"

    Example:
        {"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
    """
    try:
        # v1: extract the API key from the Authorization header and use it as the model API key
        api_key = extract_api_key_from_auth(authorization)
        # bot_id is required
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # Create the project directory (only when dataset_ids is provided and the type is not "agent")
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)
        # Collect any extra request fields as generate_cfg
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Delegate to the shared agent creation and response generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=True,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            project_dir=project_dir,
            generate_cfg=generate_cfg,
            user_identifier=request.user_identifier
        )
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 400 above) instead of re-wrapping them as 500s
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in chat_completions: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
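
# A hypothetical v1 client call (host and key are placeholders):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/v1/chat/completions",
#       headers={"Authorization": "Bearer <MODEL_API_KEY>"},
#       json={"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])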


@router.post("/api/v2/chat/completions")
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
    """
    Chat completions API v2 with simplified parameters.

    Only requires the messages, stream, tool_response, bot_id, and language parameters.
    Other parameters are fetched from the backend bot configuration API.

    Args:
        request: ChatRequestV2 containing only the essential parameters
        authorization: Authorization header used for authentication (different from v1)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream

    Required Parameters:
        - bot_id: str - target robot ID
        - messages: List[Message] - list of conversation messages

    Optional Parameters:
        - stream: bool - whether to stream the output, defaults to false
        - tool_response: bool - whether to include tool responses, defaults to false
        - language: str - reply language, defaults to "ja"

    Authentication:
        - Requires a valid MD5 hash token: MD5(MASTERKEY:bot_id)
        - The Authorization header should contain: Bearer {token}
        - The MD5 hash of MASTERKEY:bot_id is used for backend API authentication
        - Optionally uses the API key from the bot config for model access
    """
    try:
        # bot_id is required
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # v2 authentication check
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)
        if not provided_token:
            raise HTTPException(
                status_code=401,
                detail="Authorization header is required for v2 API"
            )
        if provided_token != expected_token:
            # Avoid echoing the expected token back to the caller
            raise HTTPException(
                status_code=403,
                detail=f"Invalid authentication token (provided: {provided_token[:8]}...)"
            )
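        # A minimal sketch of how a client could derive the token checked above, assuming
        # generate_v2_auth_token() implements the MD5(MASTERKEY:bot_id) scheme from the docstring:
        #
        #   import hashlib
        #   token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        #   headers = {"Authorization": f"Bearer {token}"}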
        # Fetch the bot configuration from the backend API (using the v2 authentication scheme)
        bot_config = await fetch_bot_config(bot_id)
        # v2: the API key comes from the backend configuration; the Authorization header
        # has already been consumed for authentication and is not reused as an API key
        api_key = bot_config.get("api_key")
        # Create the project directory (dataset_ids come from the backend configuration)
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("robot_type", "general_agent")
        )
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Delegate to the shared agent creation and response generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=request.tool_response,
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=request.language or bot_config.get("language", "ja"),
            system_prompt=bot_config.get("system_prompt"),
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "general_agent"),
            project_dir=project_dir,
            generate_cfg={},  # v2 does not pass any extra generate_cfg
            user_identifier=request.user_identifier
        )
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in chat_completions_v2: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
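
# A hypothetical v2 client call (host is a placeholder; `token` is the
# MD5(MASTERKEY:bot_id) value sketched near the top of chat_completions_v2):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/v2/chat/completions",
#       headers={"Authorization": f"Bearer {token}"},
#       json={"bot_id": "my-bot-001", "stream": False,
#             "messages": [{"role": "user", "content": "Hello"}]},
#   )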