增加bot_manager
This commit is contained in:
parent
f1107ea35a
commit
4c70857ff6
@ -223,6 +223,86 @@ class AgentConfig:
|
|||||||
config.safe_print()
|
config.safe_print()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_v3_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, language: Optional[str] = None):
|
||||||
|
"""从v3请求创建配置 - 从数据库读取所有配置"""
|
||||||
|
# 延迟导入避免循环依赖
|
||||||
|
from .logging_handler import LoggingCallbackHandler
|
||||||
|
from utils.fastapi_utils import get_preamble_text
|
||||||
|
from utils.settings import (
|
||||||
|
MEM0_ENABLED,
|
||||||
|
MEM0_SEMANTIC_SEARCH_TOP_K,
|
||||||
|
)
|
||||||
|
from .checkpoint_utils import prepare_checkpoint_message
|
||||||
|
from .checkpoint_manager import get_checkpointer_manager
|
||||||
|
from utils.log_util.context import g
|
||||||
|
|
||||||
|
if messages is None:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# 从全局上下文获取 trace_id
|
||||||
|
trace_id = None
|
||||||
|
try:
|
||||||
|
trace_id = getattr(g, 'trace_id', None)
|
||||||
|
except LookupError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 从数据库配置获取语言(如果没有传递)
|
||||||
|
if language is None:
|
||||||
|
language = bot_config.get("language", "zh")
|
||||||
|
|
||||||
|
# 处理 system_prompt 和 preamble
|
||||||
|
system_prompt_from_db = bot_config.get("system_prompt", "")
|
||||||
|
preamble_text, system_prompt = get_preamble_text(language, system_prompt_from_db)
|
||||||
|
|
||||||
|
# 获取 robot_type
|
||||||
|
robot_type = bot_config.get("robot_type", "general_agent")
|
||||||
|
if robot_type == "catalog_agent":
|
||||||
|
robot_type = "deep_agent"
|
||||||
|
|
||||||
|
# 从数据库配置获取其他参数
|
||||||
|
enable_thinking = bot_config.get("enable_thinking", False)
|
||||||
|
enable_memori = bot_config.get("enable_memori", False)
|
||||||
|
|
||||||
|
config = cls(
|
||||||
|
bot_id=request.bot_id,
|
||||||
|
api_key=bot_config.get("api_key", ""),
|
||||||
|
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
|
||||||
|
model_server=bot_config.get("model_server", ""),
|
||||||
|
language=language,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
mcp_settings=bot_config.get("mcp_settings", []),
|
||||||
|
robot_type=robot_type,
|
||||||
|
user_identifier=bot_config.get("user_identifier", ""),
|
||||||
|
session_id=request.session_id,
|
||||||
|
enable_thinking=enable_thinking,
|
||||||
|
project_dir=project_dir,
|
||||||
|
stream=request.stream,
|
||||||
|
tool_response=bot_config.get("tool_response", True),
|
||||||
|
generate_cfg={}, # v3接口不传递额外的generate_cfg
|
||||||
|
logging_handler=LoggingCallbackHandler(),
|
||||||
|
messages=messages,
|
||||||
|
_origin_messages=messages,
|
||||||
|
preamble_text=preamble_text,
|
||||||
|
dataset_ids=bot_config.get("dataset_ids", []),
|
||||||
|
enable_memori=enable_memori,
|
||||||
|
memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
|
||||||
|
trace_id=trace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在创建 config 时尽早准备 checkpoint 消息
|
||||||
|
if config.session_id:
|
||||||
|
try:
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
checkpointer = manager.checkpointer
|
||||||
|
if checkpointer:
|
||||||
|
await prepare_checkpoint_message(config, checkpointer)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load checkpointer: {e}")
|
||||||
|
|
||||||
|
config.safe_print()
|
||||||
|
return config
|
||||||
|
|
||||||
def invoke_config(self):
|
def invoke_config(self):
|
||||||
"""返回Langchain需要的配置字典"""
|
"""返回Langchain需要的配置字典"""
|
||||||
config = {}
|
config = {}
|
||||||
|
|||||||
@ -86,6 +86,9 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
|
|||||||
"""
|
"""
|
||||||
from agent.config_cache import config_cache
|
from agent.config_cache import config_cache
|
||||||
|
|
||||||
|
# 初始化 prompt 为空字符串,避免未定义错误
|
||||||
|
prompt = ""
|
||||||
|
|
||||||
# 获取语言显示名称
|
# 获取语言显示名称
|
||||||
language_display_map = {
|
language_display_map = {
|
||||||
'zh': '中文',
|
'zh': '中文',
|
||||||
|
|||||||
@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi
|
|||||||
logger = logging.getLogger('app')
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
# Import route modules
|
# Import route modules
|
||||||
from routes import chat, files, projects, system, skill_manager, database
|
from routes import chat, files, projects, system, skill_manager, database, bot_manager
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -118,7 +118,14 @@ async def lifespan(app: FastAPI):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mem0 initialization failed (continuing without): {e}")
|
logger.warning(f"Mem0 initialization failed (continuing without): {e}")
|
||||||
|
|
||||||
# 5. 启动 checkpoint 清理调度器
|
# 5. 初始化 Bot Manager 表
|
||||||
|
try:
|
||||||
|
await bot_manager.init_bot_manager_tables()
|
||||||
|
logger.info("Bot Manager tables initialized")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Bot Manager table initialization failed (non-fatal): {e}")
|
||||||
|
|
||||||
|
# 6. 启动 checkpoint 清理调度器
|
||||||
if CHECKPOINT_CLEANUP_ENABLED:
|
if CHECKPOINT_CLEANUP_ENABLED:
|
||||||
# 启动时立即执行一次清理
|
# 启动时立即执行一次清理
|
||||||
try:
|
try:
|
||||||
@ -175,6 +182,7 @@ app.include_router(projects.router)
|
|||||||
app.include_router(system.router)
|
app.include_router(system.router)
|
||||||
app.include_router(skill_manager.router)
|
app.include_router(skill_manager.router)
|
||||||
app.include_router(database.router)
|
app.include_router(database.router)
|
||||||
|
app.include_router(bot_manager.router)
|
||||||
|
|
||||||
# 注册文件管理API路由
|
# 注册文件管理API路由
|
||||||
app.include_router(file_manager_router)
|
app.include_router(file_manager_router)
|
||||||
|
|||||||
1250
routes/bot_manager.py
Normal file
1250
routes/bot_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -10,10 +10,10 @@ logger = logging.getLogger('app')
|
|||||||
from utils import (
|
from utils import (
|
||||||
Message, ChatRequest, ChatResponse, BatchSaveChatRequest, BatchSaveChatResponse
|
Message, ChatRequest, ChatResponse, BatchSaveChatRequest, BatchSaveChatResponse
|
||||||
)
|
)
|
||||||
from utils.api_models import ChatRequestV2
|
from utils.api_models import ChatRequestV2, ChatRequestV3
|
||||||
from utils.fastapi_utils import (
|
from utils.fastapi_utils import (
|
||||||
process_messages,
|
process_messages,
|
||||||
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
|
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, fetch_bot_config_from_db,
|
||||||
call_preamble_llm,
|
call_preamble_llm,
|
||||||
create_stream_chunk
|
create_stream_chunk
|
||||||
)
|
)
|
||||||
@ -654,6 +654,97 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
|||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/v3/chat/completions")
|
||||||
|
async def chat_completions_v3(request: ChatRequestV3, authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
Chat completions API v3 - 从数据库读取配置
|
||||||
|
|
||||||
|
与 v2 相比,v3 从本地数据库读取所有配置参数,而不是从后端 API。
|
||||||
|
前端只需要传递 bot_id 和 messages,其他配置从数据库自动读取。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: ChatRequestV3 包含 bot_id, messages, stream, session_id
|
||||||
|
authorization: 可选的认证头
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||||||
|
|
||||||
|
Required Parameters:
|
||||||
|
- bot_id: str - 目标机器人ID(用户创建时填写的ID)
|
||||||
|
- messages: List[Message] - 对话消息列表
|
||||||
|
|
||||||
|
Optional Parameters:
|
||||||
|
- stream: bool - 是否流式输出,默认false
|
||||||
|
- session_id: str - 会话ID,用于保存聊天历史
|
||||||
|
|
||||||
|
Configuration (from database):
|
||||||
|
- model: 模型名称
|
||||||
|
- api_key: API密钥
|
||||||
|
- model_server: 模型服务器地址
|
||||||
|
- language: 回复语言
|
||||||
|
- tool_response: 是否包含工具响应
|
||||||
|
- system_prompt: 系统提示词
|
||||||
|
- robot_type: 机器人类型
|
||||||
|
- dataset_ids: 数据集ID列表
|
||||||
|
- mcp_settings: MCP服务器配置
|
||||||
|
- user_identifier: 用户标识符
|
||||||
|
|
||||||
|
Authentication:
|
||||||
|
- 可选的 Authorization header(如果需要验证)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取bot_id(必需参数)
|
||||||
|
bot_id = request.bot_id
|
||||||
|
if not bot_id:
|
||||||
|
raise HTTPException(status_code=400, detail="bot_id is required")
|
||||||
|
|
||||||
|
# 可选的鉴权验证(如果传递了 authorization header)
|
||||||
|
if authorization:
|
||||||
|
expected_token = generate_v2_auth_token(bot_id)
|
||||||
|
provided_token = extract_api_key_from_auth(authorization)
|
||||||
|
if provided_token and provided_token != expected_token:
|
||||||
|
logger.warning(f"Invalid auth token provided for v3 API, but continuing anyway")
|
||||||
|
|
||||||
|
# 从数据库获取机器人配置
|
||||||
|
bot_config = await fetch_bot_config_from_db(bot_id)
|
||||||
|
|
||||||
|
# 构造类 v2 的请求格式
|
||||||
|
# 从数据库配置中提取参数
|
||||||
|
language = bot_config.get("language", "zh")
|
||||||
|
# 创建项目目录(从数据库配置获取)
|
||||||
|
project_dir = create_project_directory(
|
||||||
|
bot_config.get("dataset_ids", []),
|
||||||
|
bot_id,
|
||||||
|
bot_config.get("robot_type", "general_agent"),
|
||||||
|
bot_config.get("skills", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理消息
|
||||||
|
messages = process_messages(request.messages, language)
|
||||||
|
|
||||||
|
# 创建 AgentConfig 对象
|
||||||
|
# 需要构造一个兼容 v2 的配置对象
|
||||||
|
config = await AgentConfig.from_v3_request(
|
||||||
|
request,
|
||||||
|
bot_config,
|
||||||
|
project_dir,
|
||||||
|
messages,
|
||||||
|
language
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用公共的agent创建和响应生成逻辑
|
||||||
|
return await create_agent_and_generate_response(config)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_details = traceback.format_exc()
|
||||||
|
logger.error(f"Error in chat_completions_v3: {str(e)}")
|
||||||
|
logger.error(f"Full traceback: {error_details}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 聊天历史查询接口
|
# 聊天历史查询接口
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@ -68,6 +68,21 @@ class ChatRequestV2(BaseModel):
|
|||||||
session_id: Optional[str] = None
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequestV3(BaseModel):
|
||||||
|
"""
|
||||||
|
v3 API 请求模型 - 从数据库读取配置
|
||||||
|
|
||||||
|
所有配置参数从数据库读取,前端只需传递:
|
||||||
|
- bot_id: Bot 的用户ID(用于从数据库查找配置)
|
||||||
|
- messages: 对话消息列表
|
||||||
|
- session_id: 可选的会话ID
|
||||||
|
"""
|
||||||
|
messages: List[Message]
|
||||||
|
bot_id: str
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class FileProcessRequest(BaseModel):
|
class FileProcessRequest(BaseModel):
|
||||||
unique_id: str
|
unique_id: str
|
||||||
files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
|
files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
|
||||||
|
|||||||
@ -446,6 +446,184 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_bot_config_from_db(bot_user_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
从本地数据库获取机器人配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot_user_id: Bot 的用户ID(bot_id 字段,不是 UUID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 包含所有配置参数的字典,格式与 fetch_bot_config 兼容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from agent.db_pool_manager import get_db_pool_manager
|
||||||
|
|
||||||
|
pool = get_db_pool_manager().pool
|
||||||
|
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
# 首先根据 bot_user_id 查找 bot 的 UUID
|
||||||
|
await cursor.execute(
|
||||||
|
"SELECT id, name FROM bots WHERE bot_id = %s",
|
||||||
|
(bot_user_id,)
|
||||||
|
)
|
||||||
|
bot_row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if not bot_row:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Bot with bot_id '{bot_user_id}' not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
bot_uuid = bot_row[0]
|
||||||
|
|
||||||
|
# 查询 bot_settings
|
||||||
|
await cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT model_id,
|
||||||
|
language, robot_type, dataset_ids, system_prompt, user_identifier,
|
||||||
|
enable_memori, tool_response, skills
|
||||||
|
FROM bot_settings WHERE bot_id = %s
|
||||||
|
""",
|
||||||
|
(bot_uuid,)
|
||||||
|
)
|
||||||
|
settings_row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if not settings_row:
|
||||||
|
# 没有设置,使用默认值
|
||||||
|
logger.warning(f"No settings found for bot {bot_user_id}, using defaults")
|
||||||
|
return {
|
||||||
|
"model": "qwen3-next",
|
||||||
|
"api_key": "",
|
||||||
|
"model_server": "",
|
||||||
|
"language": "zh",
|
||||||
|
"robot_type": "general_agent",
|
||||||
|
"dataset_ids": [],
|
||||||
|
"system_prompt": "",
|
||||||
|
"user_identifier": "",
|
||||||
|
"enable_memori": False,
|
||||||
|
"tool_response": True,
|
||||||
|
"skills": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
columns = [
|
||||||
|
'model_id',
|
||||||
|
'language', 'robot_type', 'dataset_ids', 'system_prompt', 'user_identifier',
|
||||||
|
'enable_memori', 'tool_response', 'skills'
|
||||||
|
]
|
||||||
|
config = dict(zip(columns, settings_row))
|
||||||
|
|
||||||
|
# 根据 model_id 查询模型信息
|
||||||
|
model_id = config['model_id']
|
||||||
|
if model_id:
|
||||||
|
await cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT model, server, api_key
|
||||||
|
FROM models WHERE id = %s
|
||||||
|
""",
|
||||||
|
(model_id,)
|
||||||
|
)
|
||||||
|
model_row = await cursor.fetchone()
|
||||||
|
if model_row:
|
||||||
|
config['model'] = model_row[0]
|
||||||
|
config['model_server'] = model_row[1]
|
||||||
|
config['api_key'] = model_row[2]
|
||||||
|
else:
|
||||||
|
logger.warning(f"Model with id {model_id} not found, using defaults")
|
||||||
|
config['model'] = "qwen3-next"
|
||||||
|
config['model_server'] = ""
|
||||||
|
config['api_key'] = ""
|
||||||
|
else:
|
||||||
|
# 没有选择模型,使用默认值
|
||||||
|
config['model'] = "qwen3-next"
|
||||||
|
config['model_server'] = ""
|
||||||
|
config['api_key'] = ""
|
||||||
|
|
||||||
|
# 处理 dataset_ids (可能是 JSON 数组字符串或逗号分隔字符串)
|
||||||
|
dataset_ids = config['dataset_ids']
|
||||||
|
if dataset_ids:
|
||||||
|
if isinstance(dataset_ids, str):
|
||||||
|
if dataset_ids.startswith('['):
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
config['dataset_ids'] = json.loads(dataset_ids)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
|
||||||
|
else:
|
||||||
|
config['dataset_ids'] = [d.strip() for d in dataset_ids.split(',')]
|
||||||
|
else:
|
||||||
|
config['dataset_ids'] = []
|
||||||
|
|
||||||
|
# 处理 skills (逗号分隔字符串)
|
||||||
|
skills = config.get('skills', '')
|
||||||
|
if skills:
|
||||||
|
if isinstance(skills, str):
|
||||||
|
config['skills'] = [s.strip() for s in skills.split(',') if s.strip()]
|
||||||
|
else:
|
||||||
|
config['skills'] = []
|
||||||
|
else:
|
||||||
|
config['skills'] = []
|
||||||
|
|
||||||
|
# 查询 MCP 服务器配置
|
||||||
|
await cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT name, type, config, enabled
|
||||||
|
FROM mcp_servers WHERE bot_id = %s AND enabled = true
|
||||||
|
""",
|
||||||
|
(bot_uuid,)
|
||||||
|
)
|
||||||
|
mcp_rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
mcp_servers = []
|
||||||
|
for mcp_row in mcp_rows:
|
||||||
|
mcp_name = mcp_row[0]
|
||||||
|
mcp_type = mcp_row[1]
|
||||||
|
mcp_config = mcp_row[2]
|
||||||
|
|
||||||
|
# 如果 config 是 JSONB/字符串,解析它
|
||||||
|
if isinstance(mcp_config, str):
|
||||||
|
try:
|
||||||
|
mcp_config = json.loads(mcp_config)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
mcp_config = {}
|
||||||
|
|
||||||
|
mcp_servers.append({
|
||||||
|
"name": mcp_name,
|
||||||
|
"type": mcp_type,
|
||||||
|
"config": mcp_config
|
||||||
|
})
|
||||||
|
|
||||||
|
# 格式化为 mcp_settings 格式 (兼容 v2 API)
|
||||||
|
if mcp_servers:
|
||||||
|
mcp_settings_value = []
|
||||||
|
for server in mcp_servers:
|
||||||
|
server_config = server.get("config", {})
|
||||||
|
server_type = server_config.pop("server_type", server["type"])
|
||||||
|
mcp_settings_value.append({
|
||||||
|
"mcpServers": {
|
||||||
|
server_type: server_config
|
||||||
|
}
|
||||||
|
})
|
||||||
|
config["mcp_settings"] = mcp_settings_value
|
||||||
|
else:
|
||||||
|
config["mcp_settings"] = []
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching bot config from database: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to fetch bot config from database: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _sync_call_llm(llm_config, messages) -> str:
|
async def _sync_call_llm(llm_config, messages) -> str:
|
||||||
"""同步调用LLM的辅助函数,在线程池中执行 - 使用LangChain"""
|
"""同步调用LLM的辅助函数,在线程池中执行 - 使用LangChain"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user