add AgentConfig

This commit is contained in:
朱潮 2025-12-16 16:06:47 +08:00
parent 73b87bd2eb
commit 9525c0f883
6 changed files with 282 additions and 257 deletions

View File

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 import sqlite3
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 from langchain.chat_models import init_chat_model
 # from deepagents import create_deep_agent
 from langchain.agents import create_agent
@@ -15,6 +15,8 @@ from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
+from utils.agent_config import AgentConfig
+

 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that logs through the project's logger."""
@@ -22,17 +24,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
     def __init__(self, logger_name: str = 'app'):
         self.logger = logging.getLogger(logger_name)

-    # def on_llm_start(
-    #     self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any
-    # ) -> None:
-    #     """Called when the LLM starts."""
-    #     self.logger.info("🤖 LLM Start - Input Messages:")
-    #     if prompts:
-    #         for i, prompt in enumerate(prompts):
-    #             self.logger.info(f"  Message {i+1}:\n{prompt}")
-    #     else:
-    #         self.logger.info("  No prompts")
-
     def on_llm_end(self, response, **kwargs: Any) -> None:
         """Called when the LLM finishes."""
         self.logger.info("✅ LLM End - Output:")
@@ -78,7 +69,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
         """Called when a tool call fails."""
         self.logger.error(f"❌ Tool Error: {error}")
-
     def on_agent_action(self, action, **kwargs: Any) -> None:
         """Called when the agent takes an action."""
         self.logger.info(f"🎯 Agent Action: {action.log}")
@@ -97,6 +87,7 @@ def read_mcp_settings():
         mcp_settings_json = json.load(f)
     return mcp_settings_json

+
 async def get_tools_from_mcp(mcp):
     """Extract tools from an MCP configuration."""
     # Defensive handling: make sure mcp is a non-empty list that contains mcpServers
@@ -123,66 +114,60 @@ async def get_tools_from_mcp(mcp):
         # Return an empty list on exceptions so upstream callers don't fail
         return []

-async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
-                     model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None,
-                     session_id=None):
+
+async def init_agent(config: AgentConfig):
     """
     Initialize the Agent, with support for persistent memory and conversation summarization.

     Args:
-        bot_id: bot ID
-        model_name: model name
-        api_key: API key
-        model_server: model server address
-        generate_cfg: generation config
-        system_prompt: system prompt
-        mcp: MCP configuration
-        robot_type: robot type
-        language: language
-        user_identifier: user identifier
-        session_id: session ID; persistent memory is disabled when None
+        config: AgentConfig object containing all initialization parameters
+        mcp: MCP configuration; when None, mcp_settings from the config is used
     """
-    system = system_prompt if system_prompt else read_system_prompt()
-    mcp = mcp if mcp else read_mcp_settings()
+    # If no mcp is provided, use mcp_settings from the config
+    mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
+    system = config.system_prompt if config.system_prompt else read_system_prompt()
     mcp_tools = await get_tools_from_mcp(mcp)

     # Detect the provider, or use the one specified
-    model_provider, base_url = detect_provider(model_name, model_server)
+    model_provider, base_url = detect_provider(config.model_name, config.model_server)

     # Build the model parameters
     model_kwargs = {
-        "model": model_name,
+        "model": config.model_name,
         "model_provider": model_provider,
         "temperature": 0.8,
         "base_url": base_url,
-        "api_key": api_key
+        "api_key": config.api_key
     }
-    if generate_cfg:
-        model_kwargs.update(generate_cfg)
+    if config.generate_cfg:
+        model_kwargs.update(config.generate_cfg)
     llm_instance = init_chat_model(**model_kwargs)

     # Create the custom logging handler
     logging_handler = LoggingCallbackHandler()

     # Build the middleware list
-    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
+    middleware = []
+    # Only add GuidelineMiddleware when enable_thinking is True
+    if config.enable_thinking:
+        middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier))

     # Add the tool-output length-control middleware
     tool_output_middleware = ToolOutputLengthMiddleware(
-        max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH,
-        truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'),
-        tool_filters=getattr(generate_cfg, 'tool_output_filters', None),  # specific tools are configurable
-        exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []),  # excluded tools
-        preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True),
-        preserve_json=getattr(generate_cfg, 'preserve_json', True)
+        max_length=(getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None) or TOOL_OUTPUT_MAX_LENGTH,
+        truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
+        tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,  # specific tools are configurable
+        exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],  # excluded tools
+        preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
+        preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
     )
     middleware.append(tool_output_middleware)

     # Initialize the checkpointer and middleware
     checkpointer = None
-    if session_id:
+    if config.session_id:
         checkpointer = MemorySaver()
         summarization_middleware = SummarizationMiddleware(
             model=llm_instance,
@@ -203,6 +188,6 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
     # Store the handler and checkpointer as attributes on the agent so they are available at call time
     agent.logging_handler = logging_handler
     agent.checkpointer = checkpointer
-    agent.bot_id = bot_id
-    agent.session_id = session_id
+    agent.bot_id = config.bot_id
+    agent.session_id = config.session_id
     return agent

View File

@@ -26,6 +26,7 @@ logger = logging.getLogger('app')
 from agent.deep_assistant import init_agent
 from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
+from utils.agent_config import AgentConfig


 class ShardedAgentManager:
@@ -67,7 +68,8 @@
     def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
                        model_server: str = None, generate_cfg: Dict = None,
-                       system_prompt: str = None, mcp_settings: List[Dict] = None) -> str:
+                       system_prompt: str = None, mcp_settings: List[Dict] = None,
+                       enable_thinking: bool = False) -> str:
         """Build a cache key by hashing all relevant parameters."""
         cache_data = {
             'bot_id': bot_id,
@@ -76,7 +78,8 @@
             'model_server': model_server or '',
             'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
             'system_prompt': system_prompt or '',
-            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True)
+            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
+            'enable_thinking': enable_thinking
         }
         cache_str = json.dumps(cache_data, sort_keys=True)
@@ -116,20 +119,12 @@ class ShardedAgentManager:
         if removed_count > 0:
             logger.info(f"Shard cleaned up {removed_count} expired agent instance cache entries")

-    async def get_or_create_agent(self,
-                                  bot_id: str,
-                                  project_dir: Optional[str],
-                                  model_name: str = "qwen3-next",
-                                  api_key: Optional[str] = None,
-                                  model_server: Optional[str] = None,
-                                  generate_cfg: Optional[Dict] = None,
-                                  language: Optional[str] = None,
-                                  system_prompt: Optional[str] = None,
-                                  mcp_settings: Optional[List[Dict]] = None,
-                                  robot_type: Optional[str] = "general_agent",
-                                  user_identifier: Optional[str] = None,
-                                  session_id: Optional[str] = None):
-        """Get or create a preloaded agent instance."""
+    async def get_or_create_agent(self, config: AgentConfig):
+        """Get or create a preloaded agent instance.
+
+        Args:
+            config: AgentConfig object containing all initialization parameters
+        """

         # Update request statistics
         with self._stats_lock:
@@ -137,14 +132,16 @@
         # Load the config files asynchronously (with caching)
         final_system_prompt = await load_system_prompt_async(
-            project_dir, language, system_prompt, robot_type, bot_id, user_identifier
+            config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
         )
         final_mcp_settings = await load_mcp_settings_async(
-            project_dir, mcp_settings, bot_id, robot_type
+            config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
         )
+        config.system_prompt = final_system_prompt
+        config.mcp_settings = final_mcp_settings

-        cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
-                                        generate_cfg, final_system_prompt, final_mcp_settings)
+        cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
+                                        config.generate_cfg, final_system_prompt, final_mcp_settings,
+                                        config.enable_thinking)

         # Get the shard
         shard_index = self._get_shard_index(cache_key)
@@ -160,7 +157,7 @@
                 with self._stats_lock:
                     self._global_stats['cache_hits'] += 1
-                logger.info(f"Shard reusing existing agent instance cache: {cache_key} (bot_id: {bot_id}, shard: {shard_index})")
+                logger.info(f"Shard reusing existing agent instance cache: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
                 return agent

         # Update cache-miss statistics
@@ -188,27 +185,15 @@
             self._cleanup_old_agents(shard)

         # Create a new agent instance
-        logger.info(f"Shard creating new agent instance cache: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")
+        logger.info(f"Shard creating new agent instance cache: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
         current_time = time.time()
-        agent = await init_agent(
-            bot_id=bot_id,
-            model_name=model_name,
-            api_key=api_key,
-            model_server=model_server,
-            generate_cfg=generate_cfg,
-            system_prompt=final_system_prompt,
-            mcp=final_mcp_settings,
-            robot_type=robot_type,
-            language=language,
-            user_identifier=user_identifier,
-            session_id=session_id
-        )
+        agent = await init_agent(config)

         # Cache the instance
         async with shard['lock']:
             shard['agents'][cache_key] = agent
-            shard['unique_ids'][cache_key] = bot_id
+            shard['unique_ids'][cache_key] = config.bot_id
             shard['access_times'][cache_key] = current_time
             shard['creation_times'][cache_key] = current_time
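
A standalone sketch of the cache-key scheme _get_cache_key uses: serialize the relevant fields with sort_keys=True, then hash. The md5 digest here is an assumption for illustration; the point is that enable_thinking now participates, so agents with and without GuidelineMiddleware cache separately:

import hashlib
import json

def cache_key(bot_id: str, model_name: str, enable_thinking: bool = False) -> str:
    cache_data = {
        'bot_id': bot_id,
        'model_name': model_name or '',
        'enable_thinking': enable_thinking,  # newly included field
    }
    # sort_keys makes the key independent of dict insertion order
    cache_str = json.dumps(cache_data, sort_keys=True)
    return hashlib.md5(cache_str.encode()).hexdigest()

# Flipping enable_thinking yields a different key, so the two variants never collide.
assert cache_key("b1", "qwen3-next", True) != cache_key("b1", "qwen3-next", False)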

View File

@@ -14,13 +14,14 @@ from utils import (
 from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
-    process_messages, format_messages_to_chat_history,
+    process_messages,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    call_preamble_llm, get_preamble_text, get_user_last_message_content,
+    call_preamble_llm, get_preamble_text,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
 from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
+from utils.agent_config import AgentConfig

 router = APIRouter()
@@ -69,30 +70,45 @@ def append_assistant_last_message(messages: list, content: str) -> bool:
     messages.append({"role":"assistant","content":content})
     return messages

+def get_user_last_message_content(messages: list) -> str:
+    """Return the content of the most recent user message."""
+    if not messages:
+        return ""
+    for msg in reversed(messages):
+        if msg.get('role') == 'user':
+            return msg.get('content', '')
+    return ""
+
+def format_messages_to_chat_history(messages: list) -> str:
+    """Format the messages into a chat-history string."""
+    chat_history = ""
+    for msg in messages:
+        role = msg.get('role', '')
+        content = msg.get('content', '')
+        if role == 'user':
+            chat_history += f"用户: {content}\n"
+        elif role == 'assistant':
+            chat_history += f"助手: {content}\n"
+    return chat_history
+
 async def enhanced_generate_stream_response(
     agent_manager,
-    bot_id: str,
-    api_key: str,
     messages: list,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: str,
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str],
-    generate_cfg: Optional[dict],
-    user_identifier: Optional[str],
-    session_id: Optional[str] = None,
+    config: AgentConfig
 ):
-    """Enhanced progressive streaming response generator - concurrency-optimized version"""
-    try:
-        # Prepare the parameters
-        query_text = get_user_last_message_content(messages)
-        chat_history = format_messages_to_chat_history(messages)
-        preamble_text, system_prompt = get_preamble_text(language, system_prompt)
+    """Enhanced progressive streaming response generator - concurrency-optimized version
+
+    Args:
+        agent_manager: the agent manager
+        messages: message list
+        config: AgentConfig object containing all parameters
+    """
+    try:
         # Create the output queue and control events
         output_queue = asyncio.Queue()
         preamble_completed = asyncio.Event()
@@ -100,11 +116,11 @@ async def enhanced_generate_stream_response(
         # Preamble task
         async def preamble_task():
             try:
-                preamble_result = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
+                preamble_result = await call_preamble_llm(messages, config)
                 # Only emit when the preamble text is non-empty and not "<empty>"
                 if preamble_result and preamble_result.strip() and preamble_result != "<empty>":
                     preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
-                    chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content)
+                    chunk_data = create_stream_chunk(f"chatcmpl-preamble", config.model_name, preamble_content)
                     await output_queue.put(("preamble", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
                     logger.info(f"Stream mode: Generated preamble text ({len(preamble_result)} chars)")
                 else:
@@ -124,32 +140,19 @@ async def enhanced_generate_stream_response(
         async def agent_task():
             try:
                 # Prepare the agent
-                agent = await agent_manager.get_or_create_agent(
-                    bot_id=bot_id,
-                    project_dir=project_dir,
-                    model_name=model_name,
-                    api_key=api_key,
-                    model_server=model_server,
-                    generate_cfg=generate_cfg,
-                    language=language,
-                    system_prompt=system_prompt,
-                    mcp_settings=mcp_settings,
-                    robot_type=robot_type,
-                    user_identifier=user_identifier,
-                    session_id=session_id,
-                )
+                agent = await agent_manager.get_or_create_agent(config)

                 # Start streaming
                 logger.info(f"Starting agent stream response")
                 chunk_id = 0
                 message_tag = ""
-                config = {}
-                if session_id:
-                    config["configurable"] = {"thread_id": session_id}
+                stream_config = {}
+                if config.session_id:
+                    stream_config["configurable"] = {"thread_id": config.session_id}
                 if hasattr(agent, 'logging_handler'):
-                    config["callbacks"] = [agent.logging_handler]
+                    stream_config["callbacks"] = [agent.logging_handler]

-                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
+                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS):
                     new_content = ""
                     if isinstance(msg, AIMessageChunk):
@@ -174,7 +177,7 @@
                             new_content += msg.text
                     # Handle tool responses
-                    elif isinstance(msg, ToolMessage) and tool_response and msg.content:
+                    elif isinstance(msg, ToolMessage) and config.tool_response and msg.content:
                         message_tag = "TOOL_RESPONSE"
                         new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n"
@@ -183,11 +186,11 @@
                         if chunk_id == 0:
                             logger.info(f"Agent produced its first token, starting streaming output")
                         chunk_id += 1
-                        chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", model_name, new_content)
+                        chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", config.model_name, new_content)
                         await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))

                 # Send the final chunk
-                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", model_name, finish_reason="stop")
+                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
                 await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
                 await output_queue.put(("agent_done", None))
@@ -196,7 +199,15 @@
                 await output_queue.put(("agent_done", None))

         # Run the tasks concurrently
-        preamble_task_handle = asyncio.create_task(preamble_task())
+        # Only run the preamble task when enable_thinking is True
+        if config.enable_thinking:
+            preamble_task_handle = asyncio.create_task(preamble_task())
+        else:
+            # When thinking is disabled, create an empty, already-completed task
+            preamble_task_handle = asyncio.create_task(asyncio.sleep(0))
+            # Mark the preamble as done immediately
+            preamble_completed.set()
         agent_task_handle = asyncio.create_task(agent_task())

         # Output controller: make sure the preamble is emitted first, then the agent stream
@@ -214,7 +225,7 @@
                     preamble_output_done = True

                 elif item_type == "preamble_done":
-                    # The preamble has finished; mark it and continue processing
+                    # The preamble has finished; mark it and continue
                     preamble_output_done = True

                 elif item_type == "agent":
@@ -259,77 +270,43 @@

 async def create_agent_and_generate_response(
-    bot_id: str,
-    api_key: str,
     messages: list,
-    stream: bool,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: Optional[str],
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str] = None,
-    generate_cfg: Optional[dict] = None,
-    user_identifier: Optional[str] = None,
-    session_id: Optional[str] = None
+    config: AgentConfig
 ) -> Union[ChatResponse, StreamingResponse]:
-    """Shared logic for creating an agent and generating the response"""
-    if generate_cfg is None:
-        generate_cfg = {}
+    """Shared logic for creating an agent and generating the response
+
+    Args:
+        messages: message list
+        config: AgentConfig object containing all parameters
+    """
+    config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)

     # In streaming mode, use the enhanced streaming response generator
-    if stream:
+    if config.stream:
         return StreamingResponse(
             enhanced_generate_stream_response(
                 agent_manager=agent_manager,
-                bot_id=bot_id,
-                api_key=api_key,
                 messages=messages,
-                tool_response=tool_response,
-                model_name=model_name,
-                model_server=model_server,
-                language=language,
-                system_prompt=system_prompt or "",
-                mcp_settings=mcp_settings,
-                robot_type=robot_type,
-                project_dir=project_dir,
-                generate_cfg=generate_cfg,
-                user_identifier=user_identifier,
-                session_id=session_id
+                config=config
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
         )

-    _, system_prompt = get_preamble_text(language, system_prompt)
     # Use the shared helper to handle all the logic
-    agent = await agent_manager.get_or_create_agent(
-        bot_id=bot_id,
-        project_dir=project_dir,
-        model_name=model_name,
-        api_key=api_key,
-        model_server=model_server,
-        generate_cfg=generate_cfg,
-        language=language,
-        system_prompt=system_prompt,
-        mcp_settings=mcp_settings,
-        robot_type=robot_type,
-        user_identifier=user_identifier,
-        session_id=session_id,
-    )
+    agent = await agent_manager.get_or_create_agent(config)

     # Prepare the final messages
     final_messages = messages.copy()

     # Non-streaming response
-    config = {}
-    if session_id:
-        config["configurable"] = {"thread_id": session_id}
+    agent_config = {}
+    if config.session_id:
+        agent_config["configurable"] = {"thread_id": config.session_id}
     if hasattr(agent, 'logging_handler'):
-        config["callbacks"] = [agent.logging_handler]
+        agent_config["callbacks"] = [agent.logging_handler]

-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
@@ -340,7 +317,7 @@ async def create_agent_and_generate_response(
                 response_text += f"[{meta_message_tag}]\n" + output_text + "\n"
             if len(msg.tool_calls) > 0:
                 response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool['args']) if isinstance(tool['args'], dict) else tool['args']}\n" for tool in msg.tool_calls])
-        elif isinstance(msg, ToolMessage) and tool_response:
+        elif isinstance(msg, ToolMessage) and config.tool_response:
             response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"

     if len(response_text) > 0:
@@ -410,29 +387,16 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
         project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)

         # Collect the extra parameters as generate_cfg
-        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'}
+        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking'}
         generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

         # Process the messages
         messages = process_messages(request.messages, request.language)

+        # Create the AgentConfig object
+        config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg)
+
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=request.model,
-            model_server=request.model_server,
-            language=request.language,
-            system_prompt=request.system_prompt,
-            mcp_settings=request.mcp_settings,
-            robot_type=request.robot_type,
-            project_dir=project_dir,
-            generate_cfg=generate_cfg,
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )

     except Exception as e:
@@ -496,38 +460,20 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
         # Fetch the bot configuration from the backend API (using the v2 authentication scheme)
         bot_config = await fetch_bot_config(bot_id)

-        # v2 API: the API key is taken from the backend config first, and only then from the Authorization header
-        # Note: the Authorization header here has already been used for authentication and is no longer used as the API key
-        api_key = bot_config.get("api_key")
-
         # Create the project directory (dataset_ids come from the backend config)
         project_dir = create_project_directory(
             bot_config.get("dataset_ids", []),
             bot_id,
             bot_config.get("robot_type", "general_agent")
         )

         # Process the messages
         messages = process_messages(request.messages, request.language)

+        # Create the AgentConfig object
+        config = AgentConfig.from_v2_request(request, bot_config, project_dir)
+
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
-            model_server=bot_config.get("model_server", ""),
-            language=request.language or bot_config.get("language", "zh"),
-            system_prompt=bot_config.get("system_prompt"),
-            mcp_settings=bot_config.get("mcp_settings", []),
-            robot_type=bot_config.get("robot_type", "general_agent"),
-            project_dir=project_dir,
-            generate_cfg={},  # the v2 API does not pass an extra generate_cfg
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )

     except HTTPException:
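
The enable_thinking gate above relies on a small asyncio pattern: when thinking is off, a no-op task plus an immediately-set Event lets the downstream output controller proceed unchanged. A self-contained sketch of that pattern (the sleep stands in for the preamble LLM call):

import asyncio

async def demo(enable_thinking: bool) -> None:
    preamble_completed = asyncio.Event()

    async def preamble_task():
        await asyncio.sleep(0.1)  # stand-in for call_preamble_llm(...)
        preamble_completed.set()

    if enable_thinking:
        handle = asyncio.create_task(preamble_task())
    else:
        # Empty, already-finished task; mark the preamble done immediately.
        handle = asyncio.create_task(asyncio.sleep(0))
        preamble_completed.set()

    await preamble_completed.wait()  # the output controller can proceed either way
    await handle

asyncio.run(demo(False))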

utils/agent_config.py (new file, 97 additions)
View File

@@ -0,0 +1,97 @@
+"""Agent configuration class, used to manage all Agent-related parameters."""
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+from dataclasses import dataclass
+
+
+@dataclass
+class AgentConfig:
+    """Agent configuration class containing all parameters needed to create and manage an Agent."""
+
+    # Basic parameters
+    bot_id: str
+    api_key: Optional[str] = None
+    model_name: str = "qwen3-next"
+    model_server: Optional[str] = None
+    language: Optional[str] = "jp"
+
+    # Configuration parameters
+    system_prompt: Optional[str] = None
+    mcp_settings: Optional[List[Dict]] = None
+    robot_type: Optional[str] = "general_agent"
+    generate_cfg: Optional[Dict] = None
+    enable_thinking: bool = True
+
+    # Context parameters
+    project_dir: Optional[str] = None
+    user_identifier: Optional[str] = None
+    session_id: Optional[str] = None
+
+    # Response-control parameters
+    stream: bool = False
+    tool_response: bool = True
+    preamble_text: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to a dict, for passing to functions that expect **kwargs."""
+        return {
+            'bot_id': self.bot_id,
+            'api_key': self.api_key,
+            'model_name': self.model_name,
+            'model_server': self.model_server,
+            'language': self.language,
+            'system_prompt': self.system_prompt,
+            'mcp_settings': self.mcp_settings,
+            'robot_type': self.robot_type,
+            'generate_cfg': self.generate_cfg,
+            'enable_thinking': self.enable_thinking,
+            'project_dir': self.project_dir,
+            'user_identifier': self.user_identifier,
+            'session_id': self.session_id,
+            'stream': self.stream,
+            'tool_response': self.tool_response,
+            'preamble_text': self.preamble_text
+        }
+
+    @classmethod
+    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None):
+        """Create a config from a v1 request."""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=api_key,
+            model_name=request.model,
+            model_server=request.model_server,
+            language=request.language,
+            system_prompt=request.system_prompt,
+            mcp_settings=request.mcp_settings,
+            robot_type=request.robot_type,
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg=generate_cfg
+        )
+
+    @classmethod
+    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None):
+        """Create a config from a v2 request."""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=bot_config.get("api_key"),
+            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
+            model_server=bot_config.get("model_server", ""),
+            language=request.language or bot_config.get("language", "zh"),
+            system_prompt=bot_config.get("system_prompt"),
+            mcp_settings=bot_config.get("mcp_settings", []),
+            robot_type=bot_config.get("robot_type", "general_agent"),
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg={}  # the v2 API does not pass an extra generate_cfg
+        )
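
A hedged usage sketch for the new class: the routers go through from_v1_request/from_v2_request, so direct construction like this is mainly useful in tests, and all values below are placeholders:

from utils.agent_config import AgentConfig

cfg = AgentConfig(bot_id="demo-bot", stream=True, session_id="s-1")
assert cfg.model_name == "qwen3-next"    # dataclass default
assert cfg.to_dict()["stream"] is True   # dict form for **kwargs-style callees
cfg.preamble_text = "..."                # mutated later by get_preamble_text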

View File

@@ -53,6 +53,7 @@ class ChatRequest(BaseModel):
     robot_type: Optional[str] = "general_agent"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False


 class ChatRequestV2(BaseModel):
@@ -63,6 +64,7 @@
     language: Optional[str] = "zh"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False


 class FileProcessRequest(BaseModel):
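
With the new field in place, a v1 request body can opt into the preamble/guideline ("thinking") path. Field names follow ChatRequest; the values below are assumed for illustration:

request_body = {
    "bot_id": "demo-bot",
    "model": "qwen3-next",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,
    "enable_thinking": True,  # new flag; defaults to False on both request models
}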

View File

@@ -11,6 +11,7 @@ import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
 from utils.settings import MASTERKEY, BACKEND_HOST
+from utils.agent_config import AgentConfig

 USER = "user"
 ASSISTANT = "assistant"
@@ -500,9 +501,10 @@ def get_preamble_text(language: str, system_prompt: str):
     if preamble_matches:
         # Extract the preamble content
         preamble_content = preamble_matches[0].strip()
-        # Remove the preamble code block from the system_prompt
-        cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
-        return preamble_content, cleaned_system_prompt
+        if preamble_content:
+            # Remove the preamble code block from the system_prompt
+            cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
+            return preamble_content, cleaned_system_prompt

     # If no preamble code block was found, use the default preamble selection
     if language == "jp":
@@ -559,12 +561,12 @@
     return default_preamble, system_prompt  # return the default preamble and the original system_prompt

-async def call_preamble_llm(chat_history: str, last_message: str, preamble_choices_text: str, language: str, model_name: str, api_key: str, model_server: str) -> str:
+async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
     """Call the large language model to handle the guideline analysis

     Args:
-        chat_history: chat history
-        guidelines_text: guideline text
+        messages: message list
+        preamble_choices_text: guideline text
         model_name: model name
         api_key: API key
         model_server: model server address
@@ -580,6 +582,14 @@ async def call_preamble_llm(chat_history: str, last_message: str, preamble_choic
         logger.error(f"Error reading guideline prompt template: {e}")
         return ""

+    api_key = config.api_key
+    model_name = config.model_name
+    model_server = config.model_server
+    language = config.language
+    preamble_choices_text = config.preamble_text
+    last_message = get_user_last_message_content(messages)
+    chat_history = format_messages_to_chat_history(messages)
+
     # Replace the placeholders in the template
     system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))

     # Configure the LLM
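
A sketch of the guarded extraction in get_preamble_text: with the new if preamble_content: check, an empty preamble code block no longer short-circuits with an empty string but falls through to the language-specific defaults. The regex below is an assumption for illustration; the real pattern is defined earlier in that function:

import re

PREAMBLE_PATTERN = r"```preamble\n(.*?)```"  # assumed shape of the real pattern

def extract_preamble(system_prompt: str):
    matches = re.findall(PREAMBLE_PATTERN, system_prompt, flags=re.DOTALL)
    if matches:
        content = matches[0].strip()
        if content:  # the new guard: only short-circuit on a non-empty block
            cleaned = re.sub(PREAMBLE_PATTERN, '', system_prompt, flags=re.DOTALL)
            return content, cleaned
    return None, system_prompt  # caller falls back to the default preamble

# An empty preamble block now falls through instead of returning "".
print(extract_preamble("```preamble\n\n```\nrest of prompt"))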