catalog-agent/fastapi_app.py
2025-10-16 21:06:02 +08:00

387 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
from typing import AsyncGenerator, Dict, List, Optional, Union
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
# Custom version: takes no `text` parameter and does not print to the terminal.
def get_content_from_messages(messages: List[dict]) -> str:
    """Flatten a list of agent messages into one tagged text transcript.

    An assistant message contributes up to three sections, in this order:
    ``[THINK]`` for ``reasoning_content``, ``[ANSWER]`` for ``content`` and
    ``[TOOL_CALL]`` for the name and arguments of ``function_call``. A
    function message contributes one ``[TOOL_RESPONSE]`` section. Empty or
    missing fields of an assistant message are skipped.

    Args:
        messages: Messages that support ``msg['role']`` and ``msg.get(...)``.

    Returns:
        All sections joined with newlines, or ``''`` if there are none.

    Raises:
        TypeError: If a message role is neither ASSISTANT nor FUNCTION, or if
            the ``reasoning_content``/``content`` of an assistant message is
            present but not a string.
    """
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    content = []
    for msg in messages:
        role = msg['role']
        if role == ASSISTANT:
            reasoning = msg.get('reasoning_content')
            if reasoning:
                # Explicit raise instead of `assert`: asserts vanish under -O.
                if not isinstance(reasoning, str):
                    raise TypeError('Now only supports text messages')
                content.append(f'{THOUGHT_S}\n{reasoning}')
            answer = msg.get('content')
            if answer:
                if not isinstance(answer, str):
                    raise TypeError('Now only supports text messages')
                content.append(f'{ANSWER_S}\n{answer}')
            function_call = msg.get('function_call')
            if function_call:
                content.append(
                    f'{TOOL_CALL_S} {function_call["name"]}\n{function_call["arguments"]}'
                )
        elif role == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {role!r}')
    # Joining an empty list yields '', which matches the empty-input result.
    return '\n'.join(content)
from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager
from gbase_agent import update_agent_llm
from zip_project_handler import zip_handler
def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]:
    """Look up the zip_url registered for ``unique_id`` in unique_map.json.

    The file is opened by relative path, so it is resolved against the
    current working directory. Any failure while reading or parsing it is
    printed and reported to the caller as ``None``.

    Args:
        unique_id: Key to look up in the JSON mapping.

    Returns:
        The mapped value, or ``None`` if the key is absent or the file could
        not be read.
    """
    zip_url = None
    try:
        with open('unique_map.json', 'r', encoding='utf-8') as map_file:
            zip_url = json.load(map_file).get(unique_id)
    except Exception as e:
        print(f"Error reading unique_map.json: {e}")
    return zip_url
# Global agent-manager configuration: the cache size is read from the
# MAX_CACHED_AGENTS environment variable (default "20").
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))
# Initialize the global agent manager.
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)
app = FastAPI(title="Database Assistant API", version="1.0.0")
# Mount the `public` folder as a static-file service under /public.
app.mount("/public", StaticFiles(directory="public"), name="static")
# Add CORS middleware so that front-end pages can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production this should be set to the specific front-end domain(s)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Message(BaseModel):
    """One chat message inside a ChatRequest."""
    # Speaker of the message; chat_completions special-cases "assistant".
    role: str
    # Message text.
    content: str
class ChatRequest(BaseModel):
    """Request body of the chat completions endpoint.

    Fields that are not declared here are accepted and kept on the model;
    chat_completions forwards them to the agent as ``generate_cfg``.
    """
    # pydantic v2 style configuration (the file already relies on v2 via
    # ``model_dump``); the nested ``class Config`` form is deprecated there.
    # ``protected_namespaces=()`` silences the "model_" namespace warning that
    # some pydantic v2 releases emit for the ``model_server`` field.
    model_config = {"extra": "allow", "protected_namespaces": ()}

    # Conversation history, oldest first as supplied by the client.
    messages: List[Message]
    # Model name passed to the agent manager and echoed in stream chunks.
    model: str = "qwen3-next"
    # Passed to the agent manager as ``model_server``.
    model_server: str = ""
    # Location of the project ZIP; overridden by the unique_id lookup.
    zip_url: Optional[str] = None
    # When set, the zip_url is looked up by this id instead.
    unique_id: Optional[str] = None
    # True selects the streaming (server-sent events) response.
    stream: Optional[bool] = False
class ChatResponse(BaseModel):
    """Non-streaming response body returned by chat_completions."""
    # OpenAI-style choice dicts (index, message, finish_reason).
    choices: List[Dict]
    # Usage counters; chat_completions fills them with character counts.
    usage: Optional[Dict] = None
class ChatStreamResponse(BaseModel):
    """Model with the same fields as ChatResponse.

    NOTE(review): not referenced by the visible code in this file; the
    streaming path builds plain dicts instead - confirm whether it is used.
    """
    choices: List[Dict]
    usage: Optional[Dict] = None
def _stream_chunk(chunk_id: int, model: str, delta: dict, finish_reason: Optional[str]) -> str:
    """Serialize one OpenAI-style ``chat.completion.chunk`` as an SSE ``data:`` event."""
    import time

    payload = {
        "id": f"chatcmpl-{chunk_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason
        }]
    }
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Stream the agent's output as OpenAI-style server-sent events.

    Each item produced by ``agent.run`` is flattened with
    ``get_content_from_messages``; only the text that was not part of the
    previous flattened transcript is sent. When the new transcript does not
    extend the previous one, the whole new transcript is sent.

    Args:
        agent: Object whose ``run(messages=...)`` returns an iterable of
            message lists.
        messages: Conversation passed through to ``agent.run``.
        request: Object whose ``model`` attribute is echoed in every chunk.

    Yields:
        ``data: ...`` events: one per non-empty increment, then a final chunk
        with ``finish_reason`` "stop", then ``data: [DONE]``. If anything
        raises, a single ``data:`` event carrying an ``error`` object is
        yielded instead and the stream ends without ``[DONE]``.
    """
    # NOTE(review): agent.run is consumed with a plain `for`; if it blocks
    # between items the event loop is blocked as well - confirm whether this
    # should run in a worker thread.
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)
            # Work out the newly added text.
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
            # Only emit a chunk when there is something new to send.
            if new_content:
                chunk_id += 1
                yield _stream_chunk(chunk_id, request.model, {"content": new_content}, None)
        # Final completion marker, followed by the end-of-stream sentinel.
        yield _stream_chunk(chunk_id + 1, request.model, {}, "stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")
        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
def _extract_api_key(authorization: Optional[str]) -> Optional[str]:
    """Return the API key from an Authorization header, without a leading "Bearer "."""
    if not authorization:
        return None
    if authorization.startswith("Bearer "):
        return authorization[len("Bearer "):]
    return authorization


def _to_agent_messages(request_messages) -> List[dict]:
    """Convert request messages to plain dicts, trimming assistant history."""
    agent_messages = []
    for msg in request_messages:
        content = msg.content
        if msg.role == "assistant":
            # Keep only the text after the last "[ANSWER]" marker, stripped.
            # str.split always returns at least one part, so this also strips
            # assistant messages that contain no marker at all.
            content = content.split("[ANSWER]")[-1].strip()
        agent_messages.append({"role": msg.role, "content": content})
    return agent_messages


@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    The project is located through ``request.unique_id`` (looked up in
    unique_map.json) or, when no unique_id is given, ``request.zip_url``.
    Request fields other than messages/model/model_server/zip_url/unique_id/
    stream are forwarded to the agent as ``generate_cfg``.

    Args:
        request: ChatRequest containing messages, model, zip_url, etc.
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream.
        The ``usage`` numbers of the non-streaming response are character
        counts of the message texts, not tokenizer counts.

    Raises:
        HTTPException: 400 when no zip_url can be determined or the project
            cannot be loaded; 500 when the agent returns nothing or any
            other error occurs.
    """
    try:
        # Extract the API key from the Authorization header.
        api_key = _extract_api_key(authorization)

        # zip_url and unique_id come from the top level of the request body.
        zip_url = request.zip_url
        unique_id = request.unique_id
        # A unique_id takes precedence: its zip_url is read from unique_map.json.
        if unique_id:
            zip_url = get_zip_url_from_unique_id(unique_id)
            if not zip_url:
                raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}")
        if not zip_url:
            raise HTTPException(status_code=400, detail="zip_url is required")

        # Fetch the project data from the ZIP URL.
        print(f"从ZIP URL加载项目: {zip_url}")
        project_dir = zip_handler.get_project_from_zip(zip_url, unique_id or None)
        if not project_dir:
            raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")

        # Collect every document.txt file under the project directory.
        document_files = zip_handler.collect_document_files(project_dir)
        if not document_files:
            print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")

        # Everything that is not a known request field becomes generate_cfg.
        exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

        # Get or create the file-preloaded agent from the global manager.
        agent = await agent_manager.get_or_create_agent(
            zip_url=zip_url,
            files=document_files,
            project_dir=project_dir,
            model_name=request.model,
            api_key=api_key,
            model_server=request.model_server,
            generate_cfg=generate_cfg
        )

        messages = _to_agent_messages(request.messages)

        # Choose streaming or non-streaming according to the stream flag.
        if request.stream:
            return StreamingResponse(
                generate_stream_response(agent, messages, request),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )

        # Non-streaming response: use the last message the agent produced.
        final_responses = agent.run_nonstream(messages)
        if not final_responses:
            raise HTTPException(status_code=500, detail="No response from agent")
        final_response = final_responses[-1]
        # Message objects are converted to dicts before reading the content.
        if hasattr(final_response, 'model_dump'):
            final_response = final_response.model_dump()
        elif hasattr(final_response, 'dict'):
            final_response = final_response.dict()
        # A None content would break len() below, so treat it as empty text.
        content = final_response.get("content", "") or ""

        prompt_chars = sum(len(msg.content) for msg in request.messages)
        completion_chars = len(content)
        # Build the OpenAI-format response.
        return ChatResponse(
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": prompt_chars,
                "completion_tokens": completion_chars,
                "total_tokens": prompt_chars + completion_chars
            }
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 400s above) must reach the
        # client with their own status code instead of being rewrapped as 500.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/api/health")
async def health_check():
    """Liveness endpoint: always answers with a fixed status message."""
    payload = {"message": "Database Assistant API is running"}
    return payload
@app.get("/system/status")
async def system_status():
    """Report the service status together with agent-cache statistics."""
    stats = agent_manager.get_cache_stats()
    # Re-key the manager statistics into the public response shape.
    agent_cache = {
        "total_cached_agents": stats["total_cached_agents"],
        "max_cached_agents": stats["max_cached_agents"],
        "cached_agents": stats["agents"],
    }
    return dict(
        status="running",
        storage_type="File-Loaded Agent Manager",
        max_cached_agents=max_cached_agents,
        agent_cache=agent_cache,
    )
@app.post("/system/cleanup-cache")
async def cleanup_cache():
    """Clear the ZIP file cache and the cached agent instances.

    Raises:
        HTTPException: 500 if either cleanup step raises.
    """
    try:
        # ZIP file cache first, then the agent instances.
        zip_handler.cleanup_cache()
        removed_agents = agent_manager.clear_cache()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(exc)}")
    return {
        "message": "缓存清理成功",
        "cleared_zip_files": True,
        "cleared_agent_instances": removed_agents,
    }
@app.post("/system/cleanup-agent-cache")
async def cleanup_agent_cache():
    """Clear only the cached agent instances.

    Raises:
        HTTPException: 500 if clearing the cache raises.
    """
    try:
        removed_agents = agent_manager.clear_cache()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"助手实例缓存清理失败: {str(exc)}")
    return {
        "message": "助手实例缓存清理成功",
        "cleared_agent_instances": removed_agents,
    }
@app.get("/system/cached-projects")
async def get_cached_projects():
    """List the cached project ZIP URLs along with the cache statistics.

    Raises:
        HTTPException: 500 if the agent manager raises.
    """
    try:
        projects = agent_manager.list_cached_zip_urls()
        stats = agent_manager.get_cache_stats()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取缓存项目信息失败: {str(exc)}")
    return {"cached_projects": projects, "cache_stats": stats}
@app.post("/system/remove-project-cache")
async def remove_project_cache(zip_url: str):
    """Remove the cached agent for one project.

    Args:
        zip_url: ZIP URL identifying the cached project to drop.

    Returns:
        A dict with a ``message`` and a boolean ``removed`` flag. The flag is
        present in both outcomes so clients can branch on it instead of
        parsing the message text.

    Raises:
        HTTPException: 500 if the agent manager raises.
    """
    try:
        success = agent_manager.remove_cache_by_url(zip_url)
        if success:
            return {"message": f"项目缓存移除成功: {zip_url}", "removed": True}
        return {"message": f"未找到项目缓存: {zip_url}", "removed": False}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
# Script entry point: serve the app with uvicorn on all interfaces, port 8001.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)