import json
import os
import time
from typing import AsyncGenerator, Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from qwen_agent.llm.schema import ASSISTANT, FUNCTION
# Custom variant: takes no `text` parameter and prints nothing to the terminal.
def get_content_from_messages(messages: List[dict]) -> str:
    """Flatten a list of agent messages into one annotated text blob.

    Assistant messages contribute their reasoning ([THINK]), answer
    ([ANSWER]) and tool-call ([TOOL_CALL]) sections; function messages
    contribute their tool results ([TOOL_RESPONSE]).

    Args:
        messages: Message dicts with at least a 'role' key; the role must
            be ASSISTANT or FUNCTION.

    Returns:
        All sections joined by newlines, or '' when there is nothing to show.

    Raises:
        TypeError: If a message has an unsupported role.
    """
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            # Fail loudly with context instead of a bare TypeError.
            raise TypeError(f'Unsupported message role: {msg["role"]!r}')
    if content:
        full_text = '\n'.join(content)

    return full_text
from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager
|
||
from gbase_agent import update_agent_llm
|
||
from zip_project_handler import zip_handler
|
||
|
||
|
||
def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]:
    """Look up the zip_url registered for *unique_id* in unique_map.json.

    Args:
        unique_id: The project identifier to resolve.

    Returns:
        The mapped zip_url, or None when the map file is missing/unreadable,
        contains malformed JSON, or the id is not present.
    """
    try:
        with open('unique_map.json', 'r', encoding='utf-8') as f:
            unique_map = json.load(f)
        return unique_map.get(unique_id)
    except (OSError, ValueError) as e:
        # OSError: file missing/unreadable; ValueError covers
        # json.JSONDecodeError for malformed content.
        print(f"Error reading unique_map.json: {e}")
        return None
# Global agent-manager configuration (cache size comes from the environment).
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))

# Initialize the global agent manager.
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)

app = FastAPI(title="Database Assistant API", version="1.0.0")

# Serve the `public` folder as static files.
app.mount("/public", StaticFiles(directory="public"), name="static")

# Add CORS middleware so browser front-ends can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production this should list the concrete front-end origins.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Message(BaseModel):
    """A single chat message in the OpenAI-style request payload."""

    # Sender role, e.g. "user" or "assistant".
    role: str
    # Plain-text message body.
    content: str
class ChatRequest(BaseModel):
    """OpenAI-style chat request, extended with project-loading fields."""

    # Conversation history.
    messages: List[Message]
    # LLM model name.
    model: str = "qwen3-next"
    # Optional model server endpoint; empty string means the default.
    model_server: str = ""
    # Direct URL of the project ZIP archive.
    zip_url: Optional[str] = None
    # Alternative to zip_url: an id resolved through unique_map.json.
    unique_id: Optional[str] = None
    # Whether to stream the response as SSE chunks.
    stream: Optional[bool] = False

    class Config:
        # Allow extra fields; chat_completions forwards them as generate_cfg.
        extra = 'allow'
class ChatResponse(BaseModel):
    """OpenAI-style non-streaming chat completion response."""

    # OpenAI-format choice objects.
    choices: List[Dict]
    # Token usage summary (optional).
    usage: Optional[Dict] = None
class ChatStreamResponse(BaseModel):
    """OpenAI-style streaming chunk model.

    NOTE(review): not referenced elsewhere in this chunk — streaming chunks
    are built as plain dicts in generate_stream_response; confirm whether
    this model is still needed.
    """

    # OpenAI-format chunk choice objects.
    choices: List[Dict]
    # Token usage summary (optional).
    usage: Optional[Dict] = None
async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Stream the agent's reply as OpenAI-style SSE chunks.

    Each yielded string is a ``data: <json>\\n\\n`` server-sent event carrying
    the newly produced text delta. The stream ends with a finish chunk
    (finish_reason "stop") followed by the literal ``data: [DONE]`` sentinel.
    Errors are reported as a final ``error`` event instead of raising.
    """
    accumulated_content = ""
    chunk_id = 0

    def _sse(payload: dict) -> str:
        # Serialize one payload as a server-sent-event line.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    def _chunk(cid: int, delta: dict, finish_reason) -> dict:
        # Build one OpenAI-format chat.completion.chunk object.
        return {
            "id": f"chatcmpl-{cid}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }

    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)

            # Compute only the newly added suffix; if the new snapshot does
            # not extend the previous one, resend it wholesale.
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content

            # Only emit a chunk when there is fresh content.
            if new_content:
                chunk_id += 1
                yield _sse(_chunk(chunk_id, {"content": new_content}, None))

        # Emit the final completion chunk.
        yield _sse(_chunk(chunk_id + 1, {}, "stop"))

        # End-of-stream sentinel.
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        yield _sse({
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        })
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    Args:
        request: ChatRequest containing messages, model, zip_url, etc.
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream
    """
    try:
        # Extract the API key from the Authorization header.
        api_key = None
        if authorization:
            # Strip the "Bearer " prefix if present.
            if authorization.startswith("Bearer "):
                api_key = authorization[7:]
            else:
                api_key = authorization

        # Read zip_url and unique_id from the top-level request.
        zip_url = request.zip_url
        unique_id = request.unique_id

        # When unique_id is given, resolve zip_url through unique_map.json.
        if unique_id:
            zip_url = get_zip_url_from_unique_id(unique_id)
            if not zip_url:
                raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}")

        if not zip_url:
            raise HTTPException(status_code=400, detail="zip_url is required")

        # Fetch the project data from the ZIP URL.
        print(f"从ZIP URL加载项目: {zip_url}")
        project_dir = zip_handler.get_project_from_zip(zip_url, unique_id if unique_id else None)
        if not project_dir:
            raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")

        # Collect every document.txt file under the project directory.
        document_files = zip_handler.collect_document_files(project_dir)

        if not document_files:
            print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")

        # Treat any extra request fields as generate_cfg for the agent.
        exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

        # Get (or create) a file-preloaded agent from the global manager.
        agent = await agent_manager.get_or_create_agent(
            zip_url=zip_url,
            files=document_files,
            project_dir=project_dir,
            model_name=request.model,
            api_key=api_key,
            model_server=request.model_server,
            generate_cfg=generate_cfg
        )
        # Build the message context handed to the agent.
        messages = []
        for msg in request.messages:
            if msg.role == "assistant":
                # Split assistant messages on [ANSWER] and keep only the last segment.
                content_parts = msg.content.split("[ANSWER]")
                if content_parts:
                    # Use the last (stripped) segment.
                    last_part = content_parts[-1].strip()
                    messages.append({"role": msg.role, "content": last_part})
                else:
                    messages.append({"role": msg.role, "content": msg.content})
            else:
                messages.append({"role": msg.role, "content": msg.content})

        # Stream or not, depending on the request flag.
        if request.stream:
            return StreamingResponse(
                generate_stream_response(agent, messages, request),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            # Non-streaming response.
            final_responses = agent.run_nonstream(messages)

            if final_responses and len(final_responses) > 0:
                # Take the last response.
                final_response = final_responses[-1]

                # Convert Message-like objects to plain dicts when needed.
                if hasattr(final_response, 'model_dump'):
                    final_response = final_response.model_dump()
                elif hasattr(final_response, 'dict'):
                    final_response = final_response.dict()

                content = final_response.get("content", "")

                # Build the OpenAI-format response.
                # NOTE(review): "token" counts below are character counts,
                # not tokenizer counts — confirm whether callers rely on them.
                return ChatResponse(
                    choices=[{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content
                        },
                        "finish_reason": "stop"
                    }],
                    usage={
                        "prompt_tokens": sum(len(msg.content) for msg in request.messages),
                        "completion_tokens": len(content),
                        "total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
                    }
                )
            else:
                raise HTTPException(status_code=500, detail="No response from agent")

    except HTTPException:
        # Bug fix: re-raise deliberate HTTP errors (e.g. the 400s above)
        # instead of letting the generic handler rewrap them as 500s.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/api/health")
async def health_check():
    """Liveness probe: confirms the API process is up and serving."""
    status_payload = {"message": "Database Assistant API is running"}
    return status_payload
@app.get("/system/status")
async def system_status():
    """Report runtime status together with agent-cache statistics."""
    # Pull the current cache statistics from the agent manager.
    stats = agent_manager.get_cache_stats()

    cache_summary = {
        "total_cached_agents": stats["total_cached_agents"],
        "max_cached_agents": stats["max_cached_agents"],
        "cached_agents": stats["agents"],
    }
    return {
        "status": "running",
        "storage_type": "File-Loaded Agent Manager",
        "max_cached_agents": max_cached_agents,
        "agent_cache": cache_summary,
    }
@app.post("/system/cleanup-cache")
async def cleanup_cache():
    """Clear both the ZIP file cache and the agent instance cache."""
    try:
        # Drop cached ZIP downloads first.
        zip_handler.cleanup_cache()
        # Then evict every cached agent instance.
        removed_agents = agent_manager.clear_cache()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}")
    return {
        "message": "缓存清理成功",
        "cleared_zip_files": True,
        "cleared_agent_instances": removed_agents,
    }
@app.post("/system/cleanup-agent-cache")
async def cleanup_agent_cache():
    """Clear only the cached agent instances."""
    try:
        removed_agents = agent_manager.clear_cache()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"助手实例缓存清理失败: {str(e)}")
    return {
        "message": "助手实例缓存清理成功",
        "cleared_agent_instances": removed_agents,
    }
@app.get("/system/cached-projects")
async def get_cached_projects():
    """List every cached project together with cache statistics."""
    try:
        project_urls = agent_manager.list_cached_zip_urls()
        stats = agent_manager.get_cache_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取缓存项目信息失败: {str(e)}")
    return {
        "cached_projects": project_urls,
        "cache_stats": stats,
    }
@app.post("/system/remove-project-cache")
async def remove_project_cache(zip_url: str):
    """Evict the cached agent for one project, identified by its zip_url."""
    try:
        was_removed = agent_manager.remove_cache_by_url(zip_url)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
    if was_removed:
        return {"message": f"项目缓存移除成功: {zip_url}"}
    return {"message": f"未找到项目缓存: {zip_url}", "removed": False}
if __name__ == "__main__":
    # Run the API under uvicorn when executed directly.
    uvicorn.run(app, host="0.0.0.0", port=8001)