227 lines
8.1 KiB
Python
227 lines
8.1 KiB
Python
import json
|
|
import os
|
|
import uuid
|
|
import time
|
|
import multiprocessing
|
|
import sys
|
|
from contextlib import asynccontextmanager
|
|
|
|
# ========== Suppress Pydantic warnings from third-party libraries ==========
# Libraries such as langgraph-checkpoint-postgres use typing.NotRequired, which
# triggers the UserWarning matched below. Only that one message is hidden;
# every other UserWarning is still reported.
import warnings

warnings.filterwarnings(
    action="ignore",
    # The leading ".*" is needed because the filter matches from the start of
    # the warning text.
    message=".*typing.NotRequired is not a Python type.*",
    category=UserWarning,
)
# ========== End warning suppression ==========
|
|
|
|
# ========== Monkey patch: must run before all other imports ==========
# Replace mem0's remove_code_blocks function with a json_repair-based one.
# This must run before the application imports any mem0 modules, so that code
# binding the function by name picks up the replacement.
import logging
import re

_patch_logger = logging.getLogger('app')

# Matches text that consists of exactly one fenced block:
#   ```<optional language tag>\n<body>\n```
_FENCED_BLOCK_RE = re.compile(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$")


def _strip_code_fence(text: str) -> str:
    """Return the body of a single Markdown code fence, or *text* unchanged.

    Args:
        text: Already-stripped model output.

    Returns:
        The stripped fence body when *text* is exactly one fenced block,
        otherwise *text* itself.
    """
    fenced = _FENCED_BLOCK_RE.match(text)
    if fenced:
        return fenced.group(1).strip()
    return text


try:
    import json_repair

    def _remove_code_blocks_with_repair(content: str) -> str:
        """Replace mem0's remove_code_blocks function with json_repair.

        Args:
            content: Raw text to clean up.

        Returns:
            * a compact JSON string when json_repair yields a dict or list;
            * the fence-stripped text when json_repair yields an empty string
              for non-empty input, or raises;
            * ``str(result)`` for any other parsed value.
        """
        content_stripped = content.strip()
        try:
            result = json_repair.loads(content_stripped)
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            if result == "" and content_stripped != "":
                # Nothing JSON-like was recovered: fall back to fence removal.
                return _strip_code_fence(content_stripped)
            return str(result)
        except Exception:
            # Best effort: never let the repair step itself break the caller.
            return _strip_code_fence(content_stripped)

    # Patch mem0.memory.utils (the module that defines the function).
    import mem0.memory.utils
    mem0.memory.utils.remove_code_blocks = _remove_code_blocks_with_repair

    # mem0.memory.main keeps its own reference to the function, which can only
    # be replaced once that module has been imported.
    if 'mem0.memory.main' in sys.modules:
        import mem0.memory.main
        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        _patch_logger.info("Successfully patched mem0.memory.main.remove_code_blocks")
    else:
        # No import hook is installed here. Presumably mem0.memory.main binds
        # the patched function when it is imported later -- verify against
        # the installed mem0 version.
        _patch_logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError as e:
    # json_repair or mem0 is not installed: the patch is optional, but leave a
    # trace so a missing dependency can be diagnosed.
    _patch_logger.debug("Skipping mem0 remove_code_blocks patch: %s", e)
except Exception as e:
    _patch_logger.warning("Failed to patch mem0 remove_code_blocks: %s", e)
# ========== End Monkey patch ==========
|
|
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from routes.file_manager import router as file_manager_router
|
|
import logging
|
|
|
|
from utils.log_util.logger import init_with_fastapi
|
|
|
|
# Initialize logger
|
|
# Shared application logger: the same 'app' logger the monkey-patch section uses.
logger = logging.getLogger("app")
|
|
|
|
# Import route modules
|
|
from routes import chat, files, projects, system, skill_manager, database, memory
|
|
from routes.webdav import wsgidav_app
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the FastAPI application lifespan.

    On startup this initializes, in order: the shared database pool, the
    checkpointer and the chat history manager (both handed the shared pool),
    Mem0 when ``MEM0_ENABLED`` is set, the checkpoint cleanup scheduler when
    ``CHECKPOINT_CLEANUP_ENABLED`` is set, and the schedule executor when
    ``SCHEDULE_ENABLED`` is set. Failures of Mem0 initialization, of the
    one-off startup cleanup and of the schedule executor are logged as
    warnings and startup continues.

    On shutdown the resources are released in reverse order. The shutdown
    sequence runs in a ``finally`` block, and every core close step is
    attempted even when an earlier one raises, so one failing component
    cannot leave the database pool open. The first close error, if any, is
    re-raised once all steps have run.

    Args:
        app: The FastAPI application being started (not used directly).
    """
    logger.info("Starting up...")

    # Imported inside the function, as in the original module layout
    # (presumably to defer loading these modules until startup).
    from agent.db_pool_manager import (
        init_global_db_pool,
        close_global_db_pool,
    )
    from agent.checkpoint_manager import (
        init_global_checkpointer,
        close_global_checkpointer,
    )
    from agent.chat_history_manager import (
        init_chat_history_manager,
        close_chat_history_manager,
    )
    from agent.mem0_manager import (
        init_global_mem0,
        close_global_mem0,
    )
    from utils.settings import CHECKPOINT_CLEANUP_ENABLED, MEM0_ENABLED
    from utils.settings import SCHEDULE_ENABLED

    # 1. Initialize the shared database connection pool
    db_pool_manager = await init_global_db_pool()
    logger.info("Global DB pool initialized")

    # 2. Initialize checkpointing (using the shared connection pool)
    await init_global_checkpointer(db_pool_manager.pool)
    logger.info("Global checkpointer initialized")

    # 3. Initialize chat history (using the shared connection pool)
    await init_chat_history_manager(db_pool_manager.pool)
    logger.info("Chat history manager initialized")

    # 4. Initialize the Mem0 long-term memory system (if enabled)
    if MEM0_ENABLED:
        try:
            await init_global_mem0(sync_pool=db_pool_manager.sync_pool)
            logger.info("Mem0 long-term memory initialized")
        except Exception as e:
            logger.warning("Mem0 initialization failed (continuing without): %s", e)

    # 5. Start the checkpoint cleanup scheduler
    if CHECKPOINT_CLEANUP_ENABLED:
        # Run cleanup immediately on startup
        try:
            result = await db_pool_manager.cleanup_old_checkpoints()
            logger.info("Startup cleanup completed: %s", result)
        except Exception as e:
            logger.warning("Startup cleanup failed (non-fatal): %s", e)
        # Start the scheduled cleanup task
        db_pool_manager.start_cleanup_scheduler()
        logger.info("Checkpoint cleanup scheduler started")

    # 6. Start the scheduled task executor
    schedule_executor = None
    if SCHEDULE_ENABLED:
        try:
            from services.schedule_executor import get_schedule_executor
            schedule_executor = get_schedule_executor()
            schedule_executor.start()
            logger.info("Schedule executor started")
        except Exception as e:
            logger.warning("Schedule executor start failed (non-fatal): %s", e)

    try:
        yield
    finally:
        # Clean up on shutdown in reverse order. Running this in ``finally``
        # means it also happens when an exception is thrown in at ``yield``.
        logger.info("Shutting down...")

        # Stop the scheduled task executor
        if schedule_executor:
            try:
                await schedule_executor.stop()
                logger.info("Schedule executor stopped")
            except Exception as e:
                logger.warning("Schedule executor stop failed (non-fatal): %s", e)

        # Close Mem0
        if MEM0_ENABLED:
            try:
                await close_global_mem0()
                logger.info("Mem0 long-term memory closed")
            except Exception as e:
                logger.warning("Mem0 close failed (non-fatal): %s", e)

        # Core resources: attempt every step so a failure in one does not
        # skip the others; remember the first error and re-raise it at the end.
        shutdown_steps = (
            (close_chat_history_manager, "Chat history manager closed"),
            (close_global_checkpointer, "Global checkpointer closed"),
            (close_global_db_pool, "Global DB pool closed"),
        )
        first_error = None
        for close_step, done_message in shutdown_steps:
            try:
                await close_step()
            except Exception as exc:
                logger.exception("Shutdown step %s failed", close_step.__name__)
                if first_error is None:
                    first_error = exc
            else:
                logger.info(done_message)
        if first_error is not None:
            raise first_error
|
|
|
|
|
|
app = FastAPI(
    title="Database Assistant API",
    version="1.0.0",
    lifespan=lifespan,
)

init_with_fastapi(app)

# Static file mounts: the public directory, then the robot projects directory
# (html=True; serves HTML/CSS/JS/images).
app.mount("/public", StaticFiles(directory="public"), name="static")
app.mount(
    "/robots",
    StaticFiles(directory="projects/robot", html=True),
    name="robots",
)

# CORS middleware for frontend pages.
# NOTE(review): a wildcard origin is combined with allow_credentials=True.
# In production, allow_origins should list the specific frontend domains.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=[
        "GET",
        "POST",
        "PUT",
        "DELETE",
        "OPTIONS",
        "HEAD",
        "PATCH",
    ],
    allow_headers=[
        "Authorization",
        "Content-Type",
        "Accept",
        "Origin",
        "User-Agent",
        "DNT",
        "Cache-Control",
        "Range",
        "X-Requested-With",
    ],
)

# Register the route modules; the tuple order is the registration order.
for _route_module in (
    chat,
    files,
    projects,
    system,
    skill_manager,
    database,
    memory,
):
    app.include_router(_route_module.router)

# File management API routes come after the module routers.
app.include_router(file_manager_router)

# Mount WsgiDAV: the WSGI app is wrapped so the ASGI application can serve it.
# NOTE(review): Starlette's docs mark this middleware as deprecated in favour
# of a2wsgi -- confirm against the installed Starlette version.
from starlette.middleware.wsgi import WSGIMiddleware

app.mount("/webdav", WSGIMiddleware(wsgidav_app))
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the FastAPI application.
    # The port is defined once so the advertised URLs below cannot drift from
    # the port uvicorn actually binds.
    host = "0.0.0.0"
    port = 8001
    base_url = f"http://localhost:{port}"

    logger.info("Starting FastAPI server...")
    logger.info("File Manager API available at: %s/api/v1/files", base_url)
    logger.info("Web Interface available at: %s/public/file-manager.html", base_url)
    uvicorn.run(app, host=host, port=port)
|