import json
import os
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from qwen_agent.llm.schema import ASSISTANT, FUNCTION


# Custom version of the message-rendering helper: takes no `text` parameter
# and does not print to the terminal.
def get_content_from_messages(messages: List[dict]) -> str:
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {msg["role"]}')

    if content:
        full_text = '\n'.join(content)

    return full_text
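

# Illustrative example (not executed by the service; `run_sql` is a hypothetical
# tool name): a tool-calling turn followed by its result and a final answer is
# flattened into one tagged transcript string.
#
#   _example = [
#       {'role': 'assistant', 'content': '',
#        'function_call': {'name': 'run_sql', 'arguments': '{"query": "SELECT 1"}'}},
#       {'role': 'function', 'name': 'run_sql', 'content': '1'},
#       {'role': 'assistant', 'content': 'The result is 1.'},
#   ]
#   get_content_from_messages(_example) would return:
#   '[TOOL_CALL] run_sql\n{"query": "SELECT 1"}\n[TOOL_RESPONSE] run_sql\n1\n[ANSWER]\nThe result is 1.'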


from agent_pool import (get_agent_from_pool, init_global_agent_pool,
                        release_agent_to_pool)
from gbase_agent import init_agent_service_universal, update_agent_llm
from zip_project_handler import zip_handler

# Global agent instance pool, initialized at application startup.
agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1"))


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle management."""
    # Initialize the agent instance pool on startup.
    print(f"Starting FastAPI app, initializing agent pool (size: {agent_pool_size})...")

    try:
        def agent_factory():
            return init_agent_service_universal()

        await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory)
        print("Agent pool initialized!")
    except Exception as e:
        print(f"Agent pool initialization failed: {e}")
        raise

    try:
        yield
    finally:
        # Clean up the pool on shutdown; `finally` ensures this also runs if
        # the server exits abnormally.
        print("Shutting down, cleaning up the agent pool...")
        from agent_pool import get_agent_pool
        pool = get_agent_pool()
        if pool:
            await pool.shutdown()
        print("Agent pool cleaned up!")


app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    model: str = "qwen3-next"
    api_key: Optional[str] = None
    model_server: Optional[str] = None
    zip_url: Optional[str] = None
    generate_cfg: Optional[Dict] = None
    stream: Optional[bool] = False
    extra_prompt: Optional[str] = None
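

# Illustrative request body (hypothetical values) for /chat/completions:
#
#   {
#       "messages": [{"role": "user", "content": "How many tables does this project have?"}],
#       "model": "qwen3-next",
#       "zip_url": "https://example.com/project.zip",
#       "stream": true
#   }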


class ChatResponse(BaseModel):
    choices: List[Dict]
    usage: Optional[Dict] = None


class ChatStreamResponse(BaseModel):
    choices: List[Dict]
    usage: Optional[Dict] = None


async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Generate a streaming response in the OpenAI SSE chunk format."""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)

            # Compute the content added since the previous iteration.
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content

            # Emit a chunk only when there is new content.
            if new_content:
                chunk_id += 1
                # Build a streaming chunk in the OpenAI format.
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }

                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send the final completion chunk.
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the stream terminator.
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
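

# The generator above emits Server-Sent Events of the following shape
# (illustrative timestamps and content; second chunk abbreviated):
#
#   data: {"id": "chatcmpl-1", "object": "chat.completion.chunk", "created": 1700000000,
#          "model": "qwen3-next", "choices": [{"index": 0, "delta": {"content": "..."},
#          "finish_reason": null}]}
#   data: {"id": "chatcmpl-2", ..., "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
#   data: [DONE]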
@app.post("/chat/completions")
|
||
async def chat_completions(request: ChatRequest):
|
||
"""
|
||
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
||
|
||
Args:
|
||
request: ChatRequest containing messages, model, project_id in extra field, etc.
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
"""
|
||
agent = None
|
||
try:
|
||
# 从最外层获取zip_url参数
|
||
zip_url = request.zip_url
|
||
|
||
if not zip_url:
|
||
raise HTTPException(status_code=400, detail="zip_url is required")
|
||
|
||
# 使用ZIP URL获取项目数据
|
||
print(f"从ZIP URL加载项目: {zip_url}")
|
||
project_dir = zip_handler.get_project_from_zip(zip_url)
|
||
if not project_dir:
|
||
raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")
|
||
|
||
# 从实例池获取助手实例
|
||
agent = await get_agent_from_pool(timeout=30.0)
|
||
|
||
# 动态设置请求的模型,支持从接口传入api_key、model_server和extra参数
|
||
update_agent_llm(agent, request.model, request.api_key, request.model_server, request.generate_cfg)
|
||
|
||
extra_prompt = request.extra_prompt if request.extra_prompt else ""
|
||
# 构建包含项目信息的消息上下文
|
||
messages = [
|
||
# 项目信息系统消息
|
||
{
|
||
"role": "user",
|
||
"content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}\n"+ extra_prompt
|
||
},
|
||
# 用户消息批量转换
|
||
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||
]
|
||
|
||
# 根据stream参数决定返回流式还是非流式响应
|
||
if request.stream:
|
||
return StreamingResponse(
|
||
generate_stream_response(agent, messages, request),
|
||
media_type="text/event-stream",
|
||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||
)
|
||
else:
|
||
# 非流式响应
|
||
final_responses = agent.run_nonstream(messages)
|
||
|
||
if final_responses and len(final_responses) > 0:
|
||
# 取最后一个响应
|
||
final_response = final_responses[-1]
|
||
|
||
# 如果返回的是Message对象,需要转换为字典
|
||
if hasattr(final_response, 'model_dump'):
|
||
final_response = final_response.model_dump()
|
||
elif hasattr(final_response, 'dict'):
|
||
final_response = final_response.dict()
|
||
|
||
content = final_response.get("content", "")
|
||
|
||
# 构造OpenAI格式的响应
|
||
return ChatResponse(
|
||
choices=[{
|
||
"index": 0,
|
||
"message": {
|
||
"role": "assistant",
|
||
"content": content
|
||
},
|
||
"finish_reason": "stop"
|
||
}],
|
||
usage={
|
||
"prompt_tokens": sum(len(msg.content) for msg in request.messages),
|
||
"completion_tokens": len(content),
|
||
"total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
|
||
}
|
||
)
|
||
else:
|
||
raise HTTPException(status_code=500, detail="No response from agent")
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_details = traceback.format_exc()
|
||
print(f"Error in chat_completions: {str(e)}")
|
||
print(f"Full traceback: {error_details}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
finally:
|
||
# 确保释放助手实例回池
|
||
if agent is not None:
|
||
await release_agent_to_pool(agent)
|
||
|
||
|
||
|
||
|
||
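

# Illustrative client call (hypothetical zip_url), sketched with the
# `requests` library:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/chat/completions",
#       json={
#           "messages": [{"role": "user", "content": "List the tables"}],
#           "zip_url": "https://example.com/project.zip",
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
#
# For streaming, add "stream": true to the payload, pass stream=True to
# requests.post, and read the SSE lines via resp.iter_lines().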
@app.get("/")
|
||
async def root():
|
||
"""Health check endpoint"""
|
||
return {"message": "Database Assistant API is running"}
|
||
|
||
|
||
@app.get("/system/status")
|
||
async def system_status():
|
||
"""获取系统状态信息"""
|
||
from agent_pool import get_agent_pool
|
||
|
||
pool = get_agent_pool()
|
||
pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0}
|
||
|
||
return {
|
||
"status": "running",
|
||
"storage_type": "Agent Pool API",
|
||
"agent_pool": {
|
||
"pool_size": pool_stats["pool_size"],
|
||
"available_agents": pool_stats["available_agents"],
|
||
"total_agents": pool_stats["total_agents"],
|
||
"in_use_agents": pool_stats["in_use_agents"]
|
||
}
|
||
}
|
||
|
||
|
||
@app.post("/system/cleanup-cache")
|
||
async def cleanup_cache():
|
||
"""清理ZIP文件缓存"""
|
||
try:
|
||
zip_handler.cleanup_cache()
|
||
return {"message": "缓存清理成功"}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}")
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__":
|
||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
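
# Illustrative ways to run the service locally (module filename assumed to be
# api_server.py):
#   AGENT_POOL_SIZE=2 python api_server.py
#   uvicorn api_server:app --host 0.0.0.0 --port 8000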