import json import os from typing import AsyncGenerator, Dict, List, Optional, Union import uvicorn from fastapi import BackgroundTasks, FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel from qwen_agent.llm.schema import ASSISTANT, FUNCTION # 自定义版本,不需要text参数,不打印到终端 def get_content_from_messages(messages: List[dict]) -> str: full_text = '' content = [] TOOL_CALL_S = '[TOOL_CALL]' TOOL_RESULT_S = '[TOOL_RESPONSE]' THOUGHT_S = '[THINK]' ANSWER_S = '[ANSWER]' for msg in messages: if msg['role'] == ASSISTANT: if msg.get('reasoning_content'): assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages' content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}') if msg.get('content'): assert isinstance(msg['content'], str), 'Now only supports text messages' content.append(f'{ANSWER_S}\n{msg["content"]}') if msg.get('function_call'): content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}') elif msg['role'] == FUNCTION: content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}') else: raise TypeError if content: full_text = '\n'.join(content) return full_text from agent_pool import (get_agent_from_pool, init_global_agent_pool, release_agent_to_pool) from gbase_agent import init_agent_service_universal, update_agent_llm from project_config import project_manager app = FastAPI(title="Database Assistant API", version="1.0.0") # 全局助手实例池,在应用启动时初始化 agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1")) class Message(BaseModel): role: str content: str class ChatRequest(BaseModel): messages: List[Message] model: str = "qwen3-next" api_key: Optional[str] = None extra: Optional[Dict] = None stream: Optional[bool] = False file_url: Optional[str] = None class ChatResponse(BaseModel): choices: List[Dict] usage: Optional[Dict] = None class ChatStreamResponse(BaseModel): choices: List[Dict] usage: Optional[Dict] = None async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]: """生成流式响应""" accumulated_content = "" accumulated_args = "" chunk_id = 0 try: for response in agent.run(messages=messages): previous_content = accumulated_content accumulated_content = get_content_from_messages(response) # 计算新增的内容 if accumulated_content.startswith(previous_content): new_content = accumulated_content[len(previous_content):] else: new_content = accumulated_content previous_content = "" # 只有当有新内容时才发送chunk if new_content: chunk_id += 1 # 构造OpenAI格式的流式响应 chunk_data = { "id": f"chatcmpl-{chunk_id}", "object": "chat.completion.chunk", "created": int(__import__('time').time()), "model": request.model, "choices": [{ "index": 0, "delta": { "content": new_content }, "finish_reason": None }] } yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" # 发送最终完成标记 final_chunk = { "id": f"chatcmpl-{chunk_id + 1}", "object": "chat.completion.chunk", "created": int(__import__('time').time()), "model": request.model, "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop" }] } yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" # 发送结束标记 yield "data: [DONE]\n\n" except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in generate_stream_response: {str(e)}") print(f"Full traceback: {error_details}") error_data = { "error": { "message": f"Stream error: {str(e)}", "type": "internal_error" } } yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" @app.post("/chat/completions") async def chat_completions(request: ChatRequest): """ Chat completions API similar to OpenAI, supports both streaming and non-streaming Args: request: ChatRequest containing messages, model, project_id in extra field, etc. Returns: Union[ChatResponse, StreamingResponse]: Chat completion response or stream """ agent = None try: # 从extra字段中获取project_id if not request.extra or 'project_id' not in request.extra: raise HTTPException(status_code=400, detail="project_id is required in extra field") project_id = request.extra['project_id'] # 验证项目访问权限 if not project_manager.validate_project_access(project_id): raise HTTPException(status_code=404, detail=f"Project {project_id} not found or inactive") # 获取项目数据目录 project_dir = project_manager.get_project_dir(project_id) # 从实例池获取助手实例 agent = await get_agent_from_pool(timeout=30.0) # 准备LLM配置,从extra字段中移除project_id llm_extra = request.extra.copy() if request.extra else {} llm_extra.pop('project_id', None) # 移除project_id,不传递给LLM # 动态设置请求的模型,支持从接口传入api_key和extra参数 update_agent_llm(agent, request.model, request.api_key, llm_extra) # 构建包含项目信息的消息上下文 messages = [ # 项目信息系统消息 { "role": "user", "content": f"当前项目ID: {project_id},数据目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}" }, # 用户消息批量转换 *[{"role": msg.role, "content": msg.content} for msg in request.messages] ] # 根据stream参数决定返回流式还是非流式响应 if request.stream: return StreamingResponse( generate_stream_response(agent, messages, request), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) else: # 非流式响应 final_responses = agent.run_nonstream(messages) if final_responses and len(final_responses) > 0: # 取最后一个响应 final_response = final_responses[-1] # 如果返回的是Message对象,需要转换为字典 if hasattr(final_response, 'model_dump'): final_response = final_response.model_dump() elif hasattr(final_response, 'dict'): final_response = final_response.dict() content = final_response.get("content", "") # 构造OpenAI格式的响应 return ChatResponse( choices=[{ "index": 0, "message": { "role": "assistant", "content": content }, "finish_reason": "stop" }], usage={ "prompt_tokens": sum(len(msg.content) for msg in request.messages), "completion_tokens": len(content), "total_tokens": sum(len(msg.content) for msg in request.messages) + len(content) } ) else: raise HTTPException(status_code=500, detail="No response from agent") except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in chat_completions: {str(e)}") print(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") finally: # 确保释放助手实例回池 if agent is not None: await release_agent_to_pool(agent) @app.get("/") async def root(): """Health check endpoint""" return {"message": "Database Assistant API is running"} @app.get("/system/status") async def system_status(): """获取系统状态信息""" from agent_pool import get_agent_pool pool = get_agent_pool() pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0} return { "status": "running", "storage_type": "Agent Pool API", "agent_pool": { "pool_size": pool_stats["pool_size"], "available_agents": pool_stats["available_agents"], "total_agents": pool_stats["total_agents"], "in_use_agents": pool_stats["in_use_agents"] } } @app.on_event("startup") async def startup_event(): """应用启动时初始化助手实例池""" print(f"正在启动FastAPI应用,初始化助手实例池(大小: {agent_pool_size})...") try: def agent_factory(): return init_agent_service_universal() await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory) print("助手实例池初始化完成!") except Exception as e: print(f"助手实例池初始化失败: {e}") raise @app.on_event("shutdown") async def shutdown_event(): """应用关闭时清理实例池""" print("正在关闭应用,清理助手实例池...") from agent_pool import get_agent_pool pool = get_agent_pool() if pool: await pool.shutdown() print("助手实例池清理完成!") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)