# qwen_agent/routes/chat.py
import json
import os
import time
import asyncio
import traceback
from typing import Union, Optional

from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from utils import (
    Message, ChatRequest, ChatResponse,
    get_global_agent_manager, init_global_sharded_agent_manager
)
from utils.api_models import ChatRequestV2
from utils.fastapi_utils import (
    process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
    create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
    call_guideline_llm, _get_optimal_batch_size, process_guideline_batch, get_content_from_messages
)
from utils.logger import logger

router = APIRouter()

# Initialize the global agent manager
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
    shard_count=int(os.getenv("SHARD_COUNT", "16"))
)
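# Hypothetical deployment overrides (the defaults above apply when these
# environment variables are unset):
#   export MAX_CACHED_AGENTS=100
#   export SHARD_COUNT=32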


async def generate_stream_response(agent, messages, tool_response: bool, model: str):
    """Generate a streaming response in the OpenAI SSE chunk format."""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)
            # Compute the newly added content since the last chunk
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                # The accumulated output was rewritten; resend it in full
                new_content = accumulated_content
            # Only send a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build an OpenAI-style streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        # Send the stream terminator
        yield "data: [DONE]\n\n"
    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error in generate_stream_response: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"


async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: list,
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[list],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[dict] = None,
    user_identifier: Optional[str] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared logic for creating an agent and generating a response."""
    if generate_cfg is None:
        generate_cfg = {}
    # 1. Extract guideline content from the system prompt
    system_prompt, guidelines_text = extract_guidelines_from_system_prompt(system_prompt)
    print(f"guidelines_text: {guidelines_text}")
    # 2. If there is guideline content, process the guidelines concurrently.
    # Splitting first (instead of only truthiness-checking guidelines_text)
    # guarantees the else branch below still creates an agent when the extracted
    # text contains no usable lines.
    guideline_analysis = ""
    guidelines_list = []
    if guidelines_text:
        # Split the guidelines on newlines
        guidelines_list = [g.strip() for g in guidelines_text.split('\n') if g.strip()]
    if guidelines_list:
        guidelines_count = len(guidelines_list)
        # Get the optimal number of batches (the concurrency level)
        batch_count = _get_optimal_batch_size(guidelines_count)
        # Work out how many guidelines each batch should contain
        guidelines_per_batch = max(1, guidelines_count // batch_count)
        # Split the guidelines into batches
        batches = []
        for i in range(0, guidelines_count, guidelines_per_batch):
            batch = guidelines_list[i:i + guidelines_per_batch]
            batches.append(batch)
        # Make sure the number of batches does not exceed the requested concurrency
        while len(batches) > batch_count:
            # Merge the last batch into the second-to-last one
            batches[-2].extend(batches[-1])
            batches.pop()
        print(f"Processing {guidelines_count} guidelines in {len(batches)} batches with {batch_count} concurrent batches")
        # Prepare the chat history
        chat_history = format_messages_to_chat_history(messages)
        # Run everything concurrently: guideline batch processing plus agent creation
        tasks = []
        # Add one task per guideline batch
        for batch in batches:
            task = process_guideline_batch(
                guidelines_batch=batch,
                chat_history=chat_history,
                model_name=model_name,
                api_key=api_key,
                model_server=model_server
            )
            tasks.append(task)
        # Add the agent-creation task
        agent_task = agent_manager.get_or_create_agent(
            bot_id=bot_id,
            project_dir=project_dir,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            language=language,
            system_prompt=system_prompt,
            mcp_settings=mcp_settings,
            robot_type=robot_type,
            user_identifier=user_identifier
        )
        tasks.append(agent_task)
        # Wait for all tasks to finish
        all_results = await asyncio.gather(*tasks, return_exceptions=True)
        # The last result is the agent; the ones before it are the guideline batch results
        agent = all_results[-1]
        batch_results = all_results[:-1]
        # return_exceptions=True would otherwise swallow an agent-creation failure
        if isinstance(agent, Exception):
            raise agent
        # Merge the guideline analysis results using a JSON "checks" array
        all_checks = []
        for i, result in enumerate(batch_results):
            if isinstance(result, Exception):
                print(f"Guideline batch {i} failed: {result}")
                continue
            if result and isinstance(result, dict) and 'checks' in result:
                # JSON object with a checks array: keep only the checks whose applies flag is true
                applicable_checks = [check for check in result['checks'] if check.get('applies') is True]
                all_checks.extend(applicable_checks)
            elif result and isinstance(result, str) and result.strip():
                # Plain-text result: keep the original behaviour and just log it
                print(f"Non-JSON result from batch {i}: {result}")
        if all_checks:
            # Serialize the checks array into a JSON string
            guideline_analysis = json.dumps({"checks": all_checks}, ensure_ascii=False)
            print(f"Merged guideline analysis result: {guideline_analysis}")
        # Append the analysis result to the content of the last message
        if guideline_analysis and messages:
            last_message = messages[-1]
            if last_message.get('role') == 'user':
                messages[-1]['content'] += f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response."
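        # Hypothetical shape of the injected suffix (field names other than
        # "applies" depend on what process_guideline_batch returns):
        #
        #   Active Guidelines:
        #   {"checks": [{"guideline": "Reply in Japanese", "applies": true}]}
        #   Please follow these guidelines in your response.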
    else:
        # 3. Get or create the agent instance from the global manager
        agent = await agent_manager.get_or_create_agent(
            bot_id=bot_id,
            project_dir=project_dir,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            language=language,
            system_prompt=system_prompt,
            mcp_settings=mcp_settings,
            robot_type=robot_type,
            user_identifier=user_identifier
        )
    # Return a streaming or non-streaming response depending on the stream flag
    if stream:
        return StreamingResponse(
            generate_stream_response(agent, messages, tool_response, model_name),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # Non-streaming response
        final_responses = agent.run_nonstream(messages)
        if final_responses and len(final_responses) > 0:
            # Use get_content_from_messages to extract the content (honours tool_response)
            content = get_content_from_messages(final_responses, tool_response=tool_response)
            # Build an OpenAI-style response; the usage numbers are character
            # counts used as a rough stand-in for token counts
            return ChatResponse(
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "finish_reason": "stop"
                }],
                usage={
                    "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
                    "completion_tokens": len(content),
                    "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
                }
            )
        else:
            raise HTTPException(status_code=500, detail="No response from agent")
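
# Hypothetical shape of the non-streaming reply produced by
# create_agent_and_generate_response above (usage values are character counts):
#
#   {"choices": [{"index": 0,
#                 "message": {"role": "assistant", "content": "..."},
#                 "finish_reason": "stop"}],
#    "usage": {"prompt_tokens": 5, "completion_tokens": 42, "total_tokens": 47}}
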
@router.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
"""
Chat completions API similar to OpenAI, supports both streaming and non-streaming
Args:
request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
Notes:
- dataset_ids: 可选参数当提供时必须是项目ID列表单个项目也使用数组格式
- bot_id: 必需参数机器人ID
- 只有当 robot_type == "catalog_agent" 且 dataset_ids 为非空数组时才会创建机器人项目目录projects/robot/{bot_id}/
- robot_type 为其他值(包括默认的 "agent")时不创建任何目录
- dataset_ids 为空数组 []、None 或未提供时不创建任何目录
- 支持多知识库合并,自动处理文件夹重名冲突
Required Parameters:
- bot_id: str - 目标机器人ID
- messages: List[Message] - 对话消息列表
Optional Parameters:
- dataset_ids: List[str] - 源知识库项目ID列表单个项目也使用数组格式
- robot_type: str - 机器人类型,默认为 "agent"
Example:
{"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
"""
    try:
        # v1 API: extract the API key from the Authorization header and use it
        # as the model API key
        api_key = extract_api_key_from_auth(authorization)
        # Get bot_id (required parameter)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # Create the project directory (when dataset_ids is set and the type is not a plain agent)
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)
        # Collect the remaining parameters as generate_cfg
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=True,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            project_dir=project_dir,
            generate_cfg=generate_cfg,
            user_identifier=request.user_identifier
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of wrapping them in a 500
        raise
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@router.post("/api/v2/chat/completions")
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
"""
Chat completions API v2 with simplified parameters.
Only requires messages, stream, tool_response, bot_id, and language parameters.
Other parameters are fetched from the backend bot configuration API.
Args:
request: ChatRequestV2 containing only essential parameters
authorization: Authorization header for authentication (different from v1)
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
Required Parameters:
- bot_id: str - 目标机器人ID
- messages: List[Message] - 对话消息列表
Optional Parameters:
- stream: bool - 是否流式输出默认false
- tool_response: bool - 是否包含工具响应默认false
- language: str - 回复语言,默认"ja"
Authentication:
- Requires valid MD5 hash token: MD5(MASTERKEY:bot_id)
- Authorization header should contain: Bearer {token}
- Uses MD5 hash of MASTERKEY:bot_id for backend API authentication
- Optionally uses API key from bot config for model access
"""
    try:
        # Get bot_id (required parameter)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # v2 API authentication check
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)
        if not provided_token:
            raise HTTPException(
                status_code=401,
                detail="Authorization header is required for v2 API"
            )
        if provided_token != expected_token:
            raise HTTPException(
                status_code=403,
                detail=f"Invalid authentication token. Expected: {expected_token[:8]}..., Provided: {provided_token[:8]}..."
            )
        # Fetch the bot configuration from the backend API (using v2 authentication)
        bot_config = await fetch_bot_config(bot_id)
        # v2 API: the API key comes from the backend configuration. Note that the
        # Authorization header has already been consumed for authentication and is
        # no longer used as the API key.
        api_key = bot_config.get("api_key")
        # Create the project directory (dataset_ids come from the backend configuration)
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("robot_type", "general_agent")
        )
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=request.tool_response,
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=request.language or bot_config.get("language", "ja"),
            system_prompt=bot_config.get("system_prompt"),
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "general_agent"),
            project_dir=project_dir,
            generate_cfg={},  # the v2 API does not pass an extra generate_cfg
            user_identifier=request.user_identifier
        )
    except HTTPException:
        raise
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error in chat_completions_v2: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")