qwen_agent/fastapi_app.py
2025-10-07 14:35:07 +08:00

312 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
from typing import AsyncGenerator, Dict, List, Optional, Union
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
# Customized version: takes no `text` argument and prints nothing to the terminal.
def get_content_from_messages(messages: List[dict]) -> str:
    """Flatten agent response messages into a single tagged text transcript.

    Each message contributes one or more sections, joined with newlines:

    * assistant ``reasoning_content``: ``[THINK]``, a newline, the reasoning
    * assistant ``content``: ``[ANSWER]``, a newline, the content
    * assistant ``function_call``: ``[TOOL_CALL] <name>``, a newline, the arguments
    * function (tool) message: ``[TOOL_RESPONSE] <name>``, a newline, the content

    Args:
        messages: Messages whose ``role`` is ``ASSISTANT`` or ``FUNCTION``;
            accessed with ``msg[...]`` and ``msg.get(...)``.

    Returns:
        The joined transcript, or ``''`` when no message contributes text.

    Raises:
        TypeError: If an assistant ``reasoning_content`` or ``content`` value
            is not a ``str`` (only text messages are supported), or if a
            message has any other role.
    """
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    def _require_text(value) -> str:
        # Explicit check rather than `assert`, so it is not stripped under `python -O`.
        if not isinstance(value, str):
            raise TypeError('Now only supports text messages')
        return value

    content = []
    for msg in messages:
        role = msg['role']
        if role == ASSISTANT:
            if msg.get('reasoning_content'):
                content.append(f'{THOUGHT_S}\n{_require_text(msg["reasoning_content"])}')
            if msg.get('content'):
                content.append(f'{ANSWER_S}\n{_require_text(msg["content"])}')
            if msg.get('function_call'):
                call = msg['function_call']
                content.append(f'{TOOL_CALL_S} {call["name"]}\n{call["arguments"]}')
        elif role == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {role!r}')
    # '\n'.join([]) is '', so an empty input needs no special case.
    return '\n'.join(content)
from agent_pool import (get_agent_from_pool, init_global_agent_pool,
release_agent_to_pool)
from gbase_agent import init_agent_service_universal, update_agent_llm
from zip_project_handler import zip_handler
app = FastAPI(title="Database Assistant API", version="1.0.0")

# Size of the global agent instance pool. The pool itself is initialized when
# the application starts; this value is read once, at import time, from the
# AGENT_POOL_SIZE environment variable (default "1").
agent_pool_size = int(os.environ.get("AGENT_POOL_SIZE", "1"))
class Message(BaseModel):
    """One chat message from the client request (OpenAI-style role/content pair)."""
    # Author role; not validated here and forwarded unchanged to the agent.
    role: str
    # Message body; this model only accepts plain string content.
    content: str
class ChatRequest(BaseModel):
    """Request body for ``POST /chat/completions`` (OpenAI-like chat request)."""
    # Conversation turns, sent to the agent after a project-context message.
    messages: List[Message]
    # Model name applied to the borrowed agent via update_agent_llm.
    model: str = "qwen3-next"
    # Optional API key / model server, also passed to update_agent_llm.
    api_key: Optional[str] = None
    model_server: Optional[str] = None
    # Project ZIP URL. Typed Optional, but chat_completions rejects a missing/empty value with HTTP 400.
    zip_url: Optional[str] = None
    # NOTE(review): not read by the chat_completions code shown here — confirm whether still used.
    extra: Optional[Dict] = None
    # Truthy -> Server-Sent Events stream; False/None -> a single JSON response.
    stream: Optional[bool] = False
    # NOTE(review): not read by the chat_completions code shown here — confirm whether still used.
    file_url: Optional[str] = None
    # Extra text appended to the project-context message that precedes the conversation.
    extra_prompt: Optional[str] = None
class ChatResponse(BaseModel):
    """Non-streaming response body of the chat completions endpoint (OpenAI-like shape)."""
    # OpenAI-style choices; each entry is a free-form dict (index, message, finish_reason, ...).
    choices: List[Dict]
    # Optional usage counters (prompt/completion/total).
    usage: Optional[Dict] = None
class ChatStreamResponse(BaseModel):
    """Same shape as ChatResponse; presumably meant to describe streaming chunks.

    NOTE(review): not referenced by the code shown here (the stream generator
    serializes plain dicts) — confirm whether this model is still needed.
    """
    # OpenAI-style choices; each entry is a free-form dict.
    choices: List[Dict]
    # Optional usage counters.
    usage: Optional[Dict] = None
async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Stream an agent run to the client as OpenAI-style SSE chunks.

    Each item yielded by ``agent.run(messages=messages)`` is flattened with
    ``get_content_from_messages``. Only the part of that text that extends the
    previously seen text is sent as a ``chat.completion.chunk`` delta. If the
    new text does not start with the previous text, the whole text is sent
    again as the delta.

    All chunks of one stream share the same ``id`` and ``created`` values, as
    in OpenAI's streaming format; the id is freshly generated for every call.

    Args:
        agent: Object whose ``run(messages=...)`` returns a synchronous
            iterable of message lists.
        messages: Message dicts passed to ``agent.run``.
        request: The originating ``ChatRequest``; only ``request.model`` is
            read (it is echoed in every chunk).

    Yields:
        SSE ``data:`` lines: the content deltas, then a final chunk with
        ``finish_reason="stop"``, then ``data: [DONE]``. If the run raises,
        an ``{"error": {...}}`` event is yielded instead and the stream ends
        without ``[DONE]``.
    """
    import time
    import uuid

    # One id/timestamp for the whole completion; uuid keeps ids unique across requests.
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def make_chunk(delta: dict, finish_reason: Optional[str]) -> str:
        # Serialize one chat.completion.chunk as an SSE event.
        chunk_data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }
        return f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

    accumulated_content = ""
    try:
        # NOTE(review): agent.run is iterated synchronously, so each step blocks
        # the event loop until it returns.
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)
            # Compute the newly added text.
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                # The text was rewritten rather than extended; resend all of it.
                new_content = accumulated_content
            # Only emit a chunk when there is new text.
            if new_content:
                yield make_chunk({"content": new_content}, None)
        # Final chunk carrying finish_reason, then the SSE end marker.
        yield make_chunk({}, "stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")
        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
def _make_pool_releaser(agent):
    """Return an async callable that returns *agent* to the pool at most once.

    A streaming response can finish through more than one path (the stream
    body's ``finally`` and the response's background task); this guard keeps
    the agent from being released twice.
    """
    released = False

    async def release() -> None:
        nonlocal released
        if released:
            return
        released = True
        await release_agent_to_pool(agent)

    return release


async def _stream_and_release(agent, messages, request, release):
    """Relay ``generate_stream_response`` chunks, then release the agent.

    The inner stream is closed before the agent goes back to the pool, so the
    agent run is shut down while this request still owns the agent.
    """
    stream = generate_stream_response(agent, messages, request)
    try:
        async for chunk in stream:
            yield chunk
    finally:
        try:
            await stream.aclose()
        finally:
            await release()


@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """OpenAI-like chat completions endpoint (streaming and non-streaming).

    Loads the project referenced by ``request.zip_url``, borrows an agent from
    the pool, applies the requested model settings to it, prepends a
    project-context message to the conversation and runs the agent.

    Args:
        request: Chat request; ``zip_url`` is required and ``stream`` selects
            the response type.

    Returns:
        A ``StreamingResponse`` of SSE chunks when ``request.stream`` is
        truthy, otherwise a ``ChatResponse`` built from the agent's last
        response message.

    Raises:
        HTTPException: 400 when ``zip_url`` is missing or the project cannot
            be loaded; 500 when the agent returns nothing or anything else
            fails.

    The borrowed agent is always returned to the pool. For non-streaming
    requests this happens when this function exits. For streaming requests it
    happens only after the stream has finished, so no other request can
    reconfigure or reuse the agent mid-stream.
    """
    agent = None
    try:
        # zip_url is taken from the top level of the request.
        zip_url = request.zip_url
        if not zip_url:
            raise HTTPException(status_code=400, detail="zip_url is required")

        # Load the project data referenced by the ZIP URL.
        print(f"从ZIP URL加载项目: {zip_url}")
        project_dir = zip_handler.get_project_from_zip(zip_url)
        if not project_dir:
            raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")

        # Borrow an agent and apply the model / api_key / model_server from the request.
        agent = await get_agent_from_pool(timeout=30.0)
        update_agent_llm(agent, request.model, request.api_key, request.model_server)

        extra_prompt = request.extra_prompt or ""
        messages = [
            # Project-context message. NOTE(review): sent with role "user" even
            # though it carries system-style instructions — confirm intent.
            {
                "role": "user",
                "content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}\n"+ extra_prompt
            },
            # The client's conversation, converted to plain dicts.
            *[{"role": msg.role, "content": msg.content} for msg in request.messages]
        ]

        if request.stream:
            release = _make_pool_releaser(agent)
            body = _stream_and_release(agent, messages, request, release)

            async def finish_stream() -> None:
                # Runs after the response ends (also on client disconnect):
                # close the body and release the agent even if the body never
                # started iterating.
                try:
                    await body.aclose()
                finally:
                    await release()

            background = BackgroundTasks()
            background.add_task(finish_stream)
            response = StreamingResponse(
                body,
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
                background=background,
            )
            # The stream now owns the agent; the finally below must not release it.
            agent = None
            return response

        # NOTE(review): run_nonstream is synchronous and blocks the event loop for the whole run.
        final_responses = agent.run_nonstream(messages)
        if not final_responses:
            raise HTTPException(status_code=500, detail="No response from agent")

        # Use the last response message; convert Message objects to dicts
        # (model_dump-style first, then dict-style).
        final_response = final_responses[-1]
        if hasattr(final_response, 'model_dump'):
            final_response = final_response.model_dump()
        elif hasattr(final_response, 'dict'):
            final_response = final_response.dict()
        # A missing or None content becomes "" so len() below cannot fail.
        content = final_response.get("content") or ""

        # NOTE: these "token" counts are character counts of the client
        # messages and the reply, not model tokens.
        prompt_tokens = sum(len(msg.content) for msg in request.messages)
        completion_tokens = len(content)
        return ChatResponse(
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 400) instead of turning them into 500s.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
    finally:
        # Return the agent to the pool unless ownership moved to a stream.
        if agent is not None:
            await release_agent_to_pool(agent)
@app.get("/")
async def root():
    """Health check: return a fixed message confirming the API is up."""
    payload = {"message": "Database Assistant API is running"}
    return payload
@app.get("/system/status")
async def system_status():
    """Report service status together with agent-pool statistics.

    When ``get_agent_pool()`` returns no pool, every statistic is reported
    as 0.
    """
    from agent_pool import get_agent_pool

    stat_keys = ("pool_size", "available_agents", "total_agents", "in_use_agents")
    pool = get_agent_pool()
    if pool:
        pool_stats = pool.get_pool_stats()
    else:
        pool_stats = dict.fromkeys(stat_keys, 0)
    # Re-expose exactly these four statistics, in this order.
    return {
        "status": "running",
        "storage_type": "Agent Pool API",
        "agent_pool": {key: pool_stats[key] for key in stat_keys},
    }
@app.post("/system/cleanup-cache")
async def cleanup_cache():
    """Clear the ZIP project cache; any failure is reported as HTTP 500."""
    try:
        zip_handler.cleanup_cache()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(exc)}")
    return {"message": "缓存清理成功"}
@app.on_event("startup")
async def startup_event():
    """Build the global agent pool when the application starts.

    Initialization errors are printed and re-raised, so startup fails.
    """
    print(f"正在启动FastAPI应用初始化助手实例池大小: {agent_pool_size}...")

    def make_agent():
        # Factory the pool calls to create each agent instance.
        return init_agent_service_universal()

    try:
        await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=make_agent)
        print("助手实例池初始化完成!")
    except Exception as exc:
        print(f"助手实例池初始化失败: {exc}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the global agent pool, if one exists, when the app stops."""
    print("正在关闭应用,清理助手实例池...")
    from agent_pool import get_agent_pool

    active_pool = get_agent_pool()
    if active_pool:
        await active_pool.shutdown()
    print("助手实例池清理完成!")
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")