# Listing metadata carried over from the source viewer: 199 lines, 6.8 KiB, Python
import json
|
||
import os
|
||
import uuid
|
||
import time
|
||
import multiprocessing
|
||
import sys
|
||
from contextlib import asynccontextmanager
|
||
|
||
# ========== 抑制第三方库的 Pydantic 警告 ==========
|
||
# langgraph-checkpoint-postgres 等库使用 typing.NotRequired 导致的警告
|
||
import warnings
|
||
# Silence the Pydantic UserWarning about typing.NotRequired that third-party
# libraries (e.g. langgraph-checkpoint-postgres) trigger on import.
# `message` is a regular expression matched against the start of the warning
# text, so the dot in "typing.NotRequired" is escaped to match literally.
warnings.filterwarnings(
    "ignore",
    message=r".*typing\.NotRequired is not a Python type.*",
    category=UserWarning,
)
|
||
# ========== End 抑制警告 ==========
|
||
|
||
# ========== Monkey patch: 必须在所有其他导入之前执行 ==========
|
||
# 使用 json_repair 替换 mem0 的 remove_code_blocks 函数
|
||
# 这必须在导入任何 mem0 模块之前执行
|
||
import logging
import re

_patch_logger = logging.getLogger('app')

# Matches text that consists of exactly one fenced code block (with an optional
# alphanumeric language tag) and captures what lies between the fences.
_CODE_FENCE_PATTERN = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"


def _strip_code_fence(text: str) -> str:
    """Return the payload of a single fenced code block, or ``text`` unchanged.

    Args:
        text: Already-stripped text that may be wrapped in one code fence.

    Returns:
        The stripped text between the fences when ``text`` is exactly one
        fenced block; otherwise ``text`` itself.
    """
    match = re.match(_CODE_FENCE_PATTERN, text)
    return match.group(1).strip() if match else text


try:
    import json_repair

    def _remove_code_blocks_with_repair(content: str) -> str:
        """Replacement for mem0's ``remove_code_blocks`` backed by json_repair.

        Args:
            content: Raw text to clean up.

        Returns:
            * the re-serialized JSON text when json_repair yields a dict or list;
            * the fence-stripped text when json_repair yields an empty string
              for non-empty input, or when any exception is raised;
            * ``str(result)`` for any other parsed value.
        """
        content_stripped = content.strip()
        try:
            result = json_repair.loads(content_stripped)
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            # An empty-string result for non-empty input is treated as
            # "nothing parseable" (assumption about json_repair - confirm
            # against its documentation); fall back to fence stripping.
            if result == "" and content_stripped != "":
                return _strip_code_fence(content_stripped)
            return str(result)
        except Exception:
            return _strip_code_fence(content_stripped)

    # Patch the defining module first.
    import mem0.memory.utils
    mem0.memory.utils.remove_code_blocks = _remove_code_blocks_with_repair

    # If mem0.memory.main is already loaded, rebind the name it holds as well.
    if 'mem0.memory.main' in sys.modules:
        import mem0.memory.main
        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        _patch_logger.info("Successfully patched mem0.memory.main.remove_code_blocks")
    else:
        # No import hook is installed here: only mem0.memory.utils has been
        # patched.  Presumably mem0.memory.main picks the function up from
        # mem0.memory.utils when it is imported later - verify against mem0.
        _patch_logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError as exc:
    # json_repair or mem0 is not importable: skipping the patch is deliberate,
    # but leave a trace instead of failing silently.
    _patch_logger.debug("Skipping mem0 remove_code_blocks patch: %s", exc)
except Exception as e:
    _patch_logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
|
||
# ========== End Monkey patch ==========
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from routes.file_manager import router as file_manager_router
|
||
import logging
|
||
|
||
from utils.log_util.logger import init_with_fastapi
|
||
|
||
# Initialize logger
|
||
# NOTE: 'app' is the same logger name that _patch_logger uses earlier in this
# file, so both names refer to the same logging.Logger instance.
logger = logging.getLogger('app')
|
||
|
||
# Import route modules
|
||
from routes import chat, files, projects, system, skill_manager, database
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage startup and shutdown of the FastAPI application.

    Startup, in order: the shared DB pool, the checkpointer, the chat history
    manager, optionally Mem0 (when ``MEM0_ENABLED``), and optionally the
    checkpoint cleanup (when ``CHECKPOINT_CLEANUP_ENABLED``).  Mem0 and
    startup-cleanup failures are logged and do not abort startup.

    Shutdown releases the same resources in reverse order.  It runs in a
    ``finally`` clause, and each close step is chained with ``finally``, so a
    failure in one step does not prevent the remaining ones from running; the
    exception still propagates afterwards.

    Args:
        app: The FastAPI application instance (not used in the body).
    """
    logger.info("Starting up...")
    # Imported lazily inside the function, as in the original code.
    from agent.db_pool_manager import (
        init_global_db_pool,
        close_global_db_pool,
    )
    from agent.checkpoint_manager import (
        init_global_checkpointer,
        close_global_checkpointer,
    )
    from agent.chat_history_manager import (
        init_chat_history_manager,
        close_chat_history_manager,
    )
    from agent.mem0_manager import (
        init_global_mem0,
        close_global_mem0,
    )
    from utils.settings import CHECKPOINT_CLEANUP_ENABLED, MEM0_ENABLED

    # 1. Initialize the shared database connection pool.
    db_pool_manager = await init_global_db_pool()
    logger.info("Global DB pool initialized")

    # 2. Initialize the checkpointer on the shared pool.
    await init_global_checkpointer(db_pool_manager.pool)
    logger.info("Global checkpointer initialized")

    # 3. Initialize the chat history manager on the shared pool.
    await init_chat_history_manager(db_pool_manager.pool)
    logger.info("Chat history manager initialized")

    # 4. Initialize the Mem0 long-term memory system (if enabled).
    if MEM0_ENABLED:
        try:
            await init_global_mem0(sync_pool=db_pool_manager.sync_pool)
            logger.info("Mem0 long-term memory initialized")
        except Exception as e:
            logger.warning(f"Mem0 initialization failed (continuing without): {e}")

    # 5. Start the checkpoint cleanup scheduler (if enabled).
    if CHECKPOINT_CLEANUP_ENABLED:
        # Run one cleanup immediately at startup.
        try:
            result = await db_pool_manager.cleanup_old_checkpoints()
            logger.info(f"Startup cleanup completed: {result}")
        except Exception as e:
            logger.warning(f"Startup cleanup failed (non-fatal): {e}")
        # Then start the periodic cleanup scheduler.
        # NOTE(review): nothing in this function stops the scheduler again;
        # presumably close_global_db_pool() takes care of it - confirm.
        db_pool_manager.start_cleanup_scheduler()
        logger.info("Checkpoint cleanup scheduler started")

    try:
        yield
    finally:
        # Shutdown: release resources in reverse order of initialization.
        logger.info("Shutting down...")
        if MEM0_ENABLED:
            try:
                await close_global_mem0()
                logger.info("Mem0 long-term memory closed")
            except Exception as e:
                logger.warning(f"Mem0 close failed (non-fatal): {e}")
        try:
            await close_chat_history_manager()
            logger.info("Chat history manager closed")
        finally:
            try:
                await close_global_checkpointer()
                logger.info("Global checkpointer closed")
            finally:
                await close_global_db_pool()
                logger.info("Global DB pool closed")
|
||
|
||
|
||
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)

init_with_fastapi(app)

# Serve the "public" directory as static files under /public.
app.mount("/public", StaticFiles(directory="public"), name="static")

# CORS middleware so that browser front-ends can call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# origin make credentialed requests; restrict allow_origins to the concrete
# front-end origins before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # should be the specific front-end domains in production
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    allow_headers=[
        "Authorization", "Content-Type", "Accept", "Origin", "User-Agent",
        "DNT", "Cache-Control", "Range", "X-Requested-With"
    ],
)

# Register all route modules.
app.include_router(chat.router)
app.include_router(files.router)
app.include_router(projects.router)
app.include_router(system.router)
app.include_router(skill_manager.router)
app.include_router(database.router)

# Register the file manager API router.
app.include_router(file_manager_router)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the FastAPI app directly with uvicorn when executed as a script.
    # The URLs logged below must stay in sync with the port passed to uvicorn.
    logger.info("Starting FastAPI server...")
    logger.info("File Manager API available at: http://localhost:8001/api/v1/files")
    logger.info("Web Interface available at: http://localhost:8001/public/file-manager.html")
    uvicorn.run(app, host="0.0.0.0", port=8001)
|