add AgentConfig
parent 73b87bd2eb
commit 9525c0f883
agent/deep_assistant.py
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 import sqlite3
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 from langchain.chat_models import init_chat_model
 # from deepagents import create_deep_agent
 from langchain.agents import create_agent
@@ -15,6 +15,8 @@ from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
+from utils.agent_config import AgentConfig
+
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that logs through the project's logger"""
@@ -22,17 +24,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
     def __init__(self, logger_name: str = 'app'):
         self.logger = logging.getLogger(logger_name)
 
-    # def on_llm_start(
-    #     self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any
-    # ) -> None:
-    #     """Called when the LLM starts"""
-    #     self.logger.info("🤖 LLM Start - Input Messages:")
-    #     if prompts:
-    #         for i, prompt in enumerate(prompts):
-    #             self.logger.info(f"  Message {i+1}:\n{prompt}")
-    #     else:
-    #         self.logger.info("  No prompts")
-
     def on_llm_end(self, response, **kwargs: Any) -> None:
         """Called when the LLM finishes"""
         self.logger.info("✅ LLM End - Output:")
@@ -78,7 +69,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
         """Called when a tool call errors"""
         self.logger.error(f"❌ Tool Error: {error}")
 
-
     def on_agent_action(self, action, **kwargs: Any) -> None:
         """Called when the agent takes an action"""
         self.logger.info(f"🎯 Agent Action: {action.log}")
@@ -97,6 +87,7 @@ def read_mcp_settings():
         mcp_settings_json = json.load(f)
     return mcp_settings_json
 
+
 async def get_tools_from_mcp(mcp):
     """Extract the tools from the MCP configuration"""
     # Defensive handling: make sure mcp is a non-empty list that contains mcpServers
@@ -123,66 +114,60 @@ async def get_tools_from_mcp(mcp):
         # Return an empty list on exceptions so callers don't error out
         return []
 
-async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
-                     model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None,
-                     session_id=None):
+
+async def init_agent(config: AgentConfig):
     """
     Initialize the agent, with support for persistent memory and conversation summarization
 
     Args:
-        bot_id: bot ID
-        model_name: model name
-        api_key: API key
-        model_server: model server address
-        generate_cfg: generation config
-        system_prompt: system prompt
-        mcp: MCP config
-        robot_type: robot type
-        language: language
-        user_identifier: user identifier
-        session_id: session ID (persistent memory is disabled when None)
+        config: AgentConfig object carrying all initialization parameters
     """
-    system = system_prompt if system_prompt else read_system_prompt()
-    mcp = mcp if mcp else read_mcp_settings()
+    # Fall back to read_mcp_settings() when the config carries no mcp_settings
+    mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
+    system = config.system_prompt if config.system_prompt else read_system_prompt()
     mcp_tools = await get_tools_from_mcp(mcp)
 
     # Detect the provider, or use the one specified
-    model_provider, base_url = detect_provider(model_name, model_server)
+    model_provider, base_url = detect_provider(config.model_name, config.model_server)
 
     # Build the model parameters
     model_kwargs = {
-        "model": model_name,
+        "model": config.model_name,
         "model_provider": model_provider,
         "temperature": 0.8,
         "base_url": base_url,
-        "api_key": api_key
+        "api_key": config.api_key
     }
-    if generate_cfg:
-        model_kwargs.update(generate_cfg)
+    if config.generate_cfg:
+        model_kwargs.update(config.generate_cfg)
     llm_instance = init_chat_model(**model_kwargs)
 
     # Create the custom logging handler
    logging_handler = LoggingCallbackHandler()
 
     # Build the middleware list
-    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
+    middleware = []
+
+    # Only add GuidelineMiddleware when enable_thinking is True
+    if config.enable_thinking:
+        middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier))
 
     # Add the tool-output length-control middleware
     tool_output_middleware = ToolOutputLengthMiddleware(
-        max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH,
-        truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'),
-        tool_filters=getattr(generate_cfg, 'tool_output_filters', None),  # specific tools, configurable
-        exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []),  # tools to exclude
-        preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True),
-        preserve_json=getattr(generate_cfg, 'preserve_json', True)
+        max_length=(getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None) or TOOL_OUTPUT_MAX_LENGTH,
+        truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
+        tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,  # specific tools, configurable
+        exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],  # tools to exclude
+        preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
+        preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
     )
     middleware.append(tool_output_middleware)
 
     # Initialize the checkpointer and middleware
     checkpointer = None
 
-    if session_id:
+    if config.session_id:
         checkpointer = MemorySaver()
         summarization_middleware = SummarizationMiddleware(
             model=llm_instance,
@@ -203,6 +188,6 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
     # Store the handler and checkpointer on the agent, for convenient access at call time
     agent.logging_handler = logging_handler
     agent.checkpointer = checkpointer
-    agent.bot_id = bot_id
-    agent.session_id = session_id
+    agent.bot_id = config.bot_id
+    agent.session_id = config.session_id
     return agent
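For orientation, a minimal sketch of driving the refactored entry point; the values are placeholders, and the module path is taken from the `from agent.deep_assistant import init_agent` line later in this commit:

import asyncio

from agent.deep_assistant import init_agent
from utils.agent_config import AgentConfig

async def main():
    # session_id != None turns on the MemorySaver checkpointer;
    # enable_thinking=True opts in to GuidelineMiddleware.
    config = AgentConfig(bot_id="bot-123", session_id="sess-1", enable_thinking=True)
    return await init_agent(config)

asyncio.run(main())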
agent/sharded_agent_manager.py
@@ -26,6 +26,7 @@ logger = logging.getLogger('app')
 
 from agent.deep_assistant import init_agent
 from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
+from utils.agent_config import AgentConfig
 
 
 class ShardedAgentManager:
@@ -67,7 +68,8 @@ class ShardedAgentManager:
 
     def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
                        model_server: str = None, generate_cfg: Dict = None,
-                       system_prompt: str = None, mcp_settings: List[Dict] = None) -> str:
+                       system_prompt: str = None, mcp_settings: List[Dict] = None,
+                       enable_thinking: bool = False) -> str:
         """Hash all relevant parameters into the cache key"""
         cache_data = {
             'bot_id': bot_id,
@@ -76,7 +78,8 @@ class ShardedAgentManager:
             'model_server': model_server or '',
             'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
             'system_prompt': system_prompt or '',
-            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True)
+            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
+            'enable_thinking': enable_thinking
         }
 
         cache_str = json.dumps(cache_data, sort_keys=True)
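Why the new field has to participate in the key: two otherwise-identical requests that differ only in enable_thinking must not share a cached agent, since only one of them carries GuidelineMiddleware. A minimal sketch of the effect; the digest step is an assumption, as the hunk ends before the hashing code:

import hashlib
import json

def cache_key(cache_data: dict) -> str:
    # Stable serialization + digest -> deterministic key (hash choice assumed)
    return hashlib.sha256(json.dumps(cache_data, sort_keys=True).encode()).hexdigest()

a = cache_key({'bot_id': 'b1', 'enable_thinking': True})
b = cache_key({'bot_id': 'b1', 'enable_thinking': False})
assert a != b  # thinking and non-thinking agents no longer collide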
@@ -116,20 +119,12 @@ class ShardedAgentManager:
         if removed_count > 0:
             logger.info(f"Shard evicted {removed_count} expired agent instances from the cache")
 
-    async def get_or_create_agent(self,
-                                  bot_id: str,
-                                  project_dir: Optional[str],
-                                  model_name: str = "qwen3-next",
-                                  api_key: Optional[str] = None,
-                                  model_server: Optional[str] = None,
-                                  generate_cfg: Optional[Dict] = None,
-                                  language: Optional[str] = None,
-                                  system_prompt: Optional[str] = None,
-                                  mcp_settings: Optional[List[Dict]] = None,
-                                  robot_type: Optional[str] = "general_agent",
-                                  user_identifier: Optional[str] = None,
-                                  session_id: Optional[str] = None):
-        """Get or create a file-preloaded agent instance"""
+    async def get_or_create_agent(self, config: AgentConfig):
+        """Get or create a file-preloaded agent instance
+
+        Args:
+            config: AgentConfig object carrying all initialization parameters
+        """
 
         # Update the request statistics
         with self._stats_lock:
@@ -137,14 +132,16 @@ class ShardedAgentManager:
 
         # Load the config files asynchronously (with caching)
         final_system_prompt = await load_system_prompt_async(
-            project_dir, language, system_prompt, robot_type, bot_id, user_identifier
+            config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
         )
         final_mcp_settings = await load_mcp_settings_async(
-            project_dir, mcp_settings, bot_id, robot_type
+            config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
         )
 
-        cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
-                                        generate_cfg, final_system_prompt, final_mcp_settings)
+        config.system_prompt = final_system_prompt
+        config.mcp_settings = final_mcp_settings
+        cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
+                                        config.generate_cfg, final_system_prompt, final_mcp_settings,
+                                        config.enable_thinking)
 
         # Get the shard
         shard_index = self._get_shard_index(cache_key)
@@ -160,7 +157,7 @@ class ShardedAgentManager:
                 with self._stats_lock:
                     self._global_stats['cache_hits'] += 1
 
-                logger.info(f"Shard reused a cached agent instance: {cache_key} (bot_id: {bot_id}, shard: {shard_index})")
+                logger.info(f"Shard reused a cached agent instance: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
                 return agent
 
         # Update the cache-miss statistics
@@ -188,27 +185,15 @@ class ShardedAgentManager:
                 self._cleanup_old_agents(shard)
 
             # Create a new agent instance
-            logger.info(f"Shard created a new agent instance cache entry: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")
+            logger.info(f"Shard created a new agent instance cache entry: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
             current_time = time.time()
 
-            agent = await init_agent(
-                bot_id=bot_id,
-                model_name=model_name,
-                api_key=api_key,
-                model_server=model_server,
-                generate_cfg=generate_cfg,
-                system_prompt=final_system_prompt,
-                mcp=final_mcp_settings,
-                robot_type=robot_type,
-                language=language,
-                user_identifier=user_identifier,
-                session_id=session_id
-            )
+            agent = await init_agent(config)
 
             # Cache the instance
             async with shard['lock']:
                 shard['agents'][cache_key] = agent
-                shard['unique_ids'][cache_key] = bot_id
+                shard['unique_ids'][cache_key] = config.bot_id
                 shard['access_times'][cache_key] = current_time
                 shard['creation_times'][cache_key] = current_time
 
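With the parameter list collapsed into one object, call sites shrink to a single argument. A hedged sketch of the new call shape (values are placeholders, and the call must run inside an async function):

config = AgentConfig(bot_id="bot-123", model_name="qwen3-next", enable_thinking=False)
agent = await agent_manager.get_or_create_agent(config)  # one argument instead of twelve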
routes/chat.py (220 lines changed)
@@ -14,13 +14,14 @@ from utils import (
 from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
-    process_messages, format_messages_to_chat_history,
+    process_messages,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    call_preamble_llm, get_preamble_text, get_user_last_message_content,
+    call_preamble_llm, get_preamble_text,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
 from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
+from utils.agent_config import AgentConfig
 
 router = APIRouter()
 
@@ -69,30 +70,45 @@ def append_assistant_last_message(messages: list, content: str) -> bool:
     messages.append({"role": "assistant", "content": content})
     return messages
 
+
+def get_user_last_message_content(messages: list) -> str:
+    """Return the content of the most recent user message"""
+    if not messages:
+        return ""
+
+    for msg in reversed(messages):
+        if msg.get('role') == 'user':
+            return msg.get('content', '')
+
+    return ""
+
+
+def format_messages_to_chat_history(messages: list) -> str:
+    """Format the messages into a chat-history string"""
+    chat_history = ""
+    for msg in messages:
+        role = msg.get('role', '')
+        content = msg.get('content', '')
+        if role == 'user':
+            chat_history += f"用户: {content}\n"
+        elif role == 'assistant':
+            chat_history += f"助手: {content}\n"
+    return chat_history
+
+
 async def enhanced_generate_stream_response(
     agent_manager,
-    bot_id: str,
-    api_key: str,
     messages: list,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: str,
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str],
-    generate_cfg: Optional[dict],
-    user_identifier: Optional[str],
-    session_id: Optional[str] = None,
+    config: AgentConfig
 ):
-    """Enhanced progressive streaming response generator - concurrency-optimized version"""
-    try:
-        # Prepare the parameters
-        query_text = get_user_last_message_content(messages)
-        chat_history = format_messages_to_chat_history(messages)
-        preamble_text, system_prompt = get_preamble_text(language, system_prompt)
+    """Enhanced progressive streaming response generator - concurrency-optimized version
+
+    Args:
+        agent_manager: the agent manager
+        messages: message list
+        config: AgentConfig object carrying all parameters
+    """
+    try:
         # Create the output queue and the control events
         output_queue = asyncio.Queue()
         preamble_completed = asyncio.Event()
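A quick sanity check of the two helpers now defined locally in routes/chat.py (illustrative messages):

msgs = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "こんにちは"},
    {"role": "user", "content": "weather?"},
]
assert get_user_last_message_content(msgs) == "weather?"  # scans in reverse for the last user turn
history = format_messages_to_chat_history(msgs)
# history == "用户: 你好\n助手: こんにちは\n用户: weather?\n"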
@@ -100,11 +116,11 @@ async def enhanced_generate_stream_response(
         # Preamble task
         async def preamble_task():
             try:
-                preamble_result = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
+                preamble_result = await call_preamble_llm(messages, config)
                 # Only emit when the preamble result is non-empty and not "<empty>"
                 if preamble_result and preamble_result.strip() and preamble_result != "<empty>":
                     preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
-                    chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content)
+                    chunk_data = create_stream_chunk("chatcmpl-preamble", config.model_name, preamble_content)
                     await output_queue.put(("preamble", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
                     logger.info(f"Stream mode: Generated preamble text ({len(preamble_result)} chars)")
                 else:
@@ -124,32 +140,19 @@ async def enhanced_generate_stream_response(
         async def agent_task():
             try:
                 # Prepare the agent
-                agent = await agent_manager.get_or_create_agent(
-                    bot_id=bot_id,
-                    project_dir=project_dir,
-                    model_name=model_name,
-                    api_key=api_key,
-                    model_server=model_server,
-                    generate_cfg=generate_cfg,
-                    language=language,
-                    system_prompt=system_prompt,
-                    mcp_settings=mcp_settings,
-                    robot_type=robot_type,
-                    user_identifier=user_identifier,
-                    session_id=session_id,
-                )
+                agent = await agent_manager.get_or_create_agent(config)
 
                 # Start streaming
                 logger.info(f"Starting agent stream response")
                 chunk_id = 0
                 message_tag = ""
 
-                config = {}
-                if session_id:
-                    config["configurable"] = {"thread_id": session_id}
+                stream_config = {}
+                if config.session_id:
+                    stream_config["configurable"] = {"thread_id": config.session_id}
                 if hasattr(agent, 'logging_handler'):
-                    config["callbacks"] = [agent.logging_handler]
-                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
+                    stream_config["callbacks"] = [agent.logging_handler]
+                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS):
                     new_content = ""
 
                     if isinstance(msg, AIMessageChunk):
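The config -> stream_config rename is load-bearing here: the old local dict shadowed the new config: AgentConfig parameter, so any later config.* attribute access inside agent_task would have failed. A minimal reproduction of the hazard the rename avoids:

def agent_task_sketch(config):  # config is an AgentConfig on entry
    config = {}                 # ...and a plain dict from here on
    return config.session_id    # AttributeError: 'dict' object has no attribute 'session_id'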
@@ -174,7 +177,7 @@ async def enhanced_generate_stream_response(
                             new_content += msg.text
 
                     # Handle tool responses
-                    elif isinstance(msg, ToolMessage) and tool_response and msg.content:
+                    elif isinstance(msg, ToolMessage) and config.tool_response and msg.content:
                         message_tag = "TOOL_RESPONSE"
                         new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n"
 
@@ -183,11 +186,11 @@ async def enhanced_generate_stream_response(
                     if chunk_id == 0:
                         logger.info(f"Agent produced its first token, starting streaming output")
                     chunk_id += 1
-                    chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", model_name, new_content)
+                    chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", config.model_name, new_content)
                     await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
 
                 # Send the final chunk
-                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", model_name, finish_reason="stop")
+                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
                 await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
                 await output_queue.put(("agent_done", None))
 
@@ -196,7 +199,15 @@ async def enhanced_generate_stream_response(
                 await output_queue.put(("agent_done", None))
 
         # Run the tasks concurrently
-        preamble_task_handle = asyncio.create_task(preamble_task())
+        # Only run the preamble task when enable_thinking is True
+        if config.enable_thinking:
+            preamble_task_handle = asyncio.create_task(preamble_task())
+        else:
+            # Thinking disabled: create an empty task that finishes immediately
+            preamble_task_handle = asyncio.create_task(asyncio.sleep(0))
+            # and mark the preamble as completed right away
+            preamble_completed.set()
 
         agent_task_handle = asyncio.create_task(agent_task())
 
        # Output controller: emit the preamble first, then the agent stream
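A note on the else branch: asyncio.create_task(asyncio.sleep(0)) just manufactures a task that completes immediately, so later cleanup code can await or cancel preamble_task_handle without branching; the explicit preamble_completed.set() is still required because nothing else flips the event. An equivalent spelling, if a named no-op reads better:

async def _noop() -> None:
    pass  # already-satisfied stand-in for the skipped preamble task

preamble_task_handle = asyncio.create_task(_noop())
preamble_completed.set()  # the output controller still needs the event flipped by hand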
@@ -214,7 +225,7 @@ async def enhanced_generate_stream_response(
                     preamble_output_done = True
 
                 elif item_type == "preamble_done":
-                    # Preamble finished: mark it and keep processing
+                    # Preamble finished: mark it and continue
                     preamble_output_done = True
 
                 elif item_type == "agent":
@@ -259,77 +270,43 @@ async def enhanced_generate_stream_response(
 
 
 async def create_agent_and_generate_response(
-    bot_id: str,
-    api_key: str,
     messages: list,
-    stream: bool,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: Optional[str],
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str] = None,
-    generate_cfg: Optional[dict] = None,
-    user_identifier: Optional[str] = None,
-    session_id: Optional[str] = None
+    config: AgentConfig
 ) -> Union[ChatResponse, StreamingResponse]:
-    """Shared logic for creating the agent and generating the response"""
-    if generate_cfg is None:
-        generate_cfg = {}
+    """Shared logic for creating the agent and generating the response
+
+    Args:
+        messages: message list
+        config: AgentConfig object carrying all parameters
+    """
+    config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
 
     # In streaming mode, use the enhanced streaming response generator
-    if stream:
+    if config.stream:
         return StreamingResponse(
             enhanced_generate_stream_response(
                 agent_manager=agent_manager,
-                bot_id=bot_id,
-                api_key=api_key,
                 messages=messages,
-                tool_response=tool_response,
-                model_name=model_name,
-                model_server=model_server,
-                language=language,
-                system_prompt=system_prompt or "",
-                mcp_settings=mcp_settings,
-                robot_type=robot_type,
-                project_dir=project_dir,
-                generate_cfg=generate_cfg,
-                user_identifier=user_identifier,
-                session_id=session_id
+                config=config
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
         )
 
-    _, system_prompt = get_preamble_text(language, system_prompt)
-
     # Handle the rest through the shared helper
-    agent = await agent_manager.get_or_create_agent(
-        bot_id=bot_id,
-        project_dir=project_dir,
-        model_name=model_name,
-        api_key=api_key,
-        model_server=model_server,
-        generate_cfg=generate_cfg,
-        language=language,
-        system_prompt=system_prompt,
-        mcp_settings=mcp_settings,
-        robot_type=robot_type,
-        user_identifier=user_identifier,
-        session_id=session_id,
-    )
+    agent = await agent_manager.get_or_create_agent(config)
 
     # Prepare the final messages
     final_messages = messages.copy()
 
     # Non-streaming response
-    config = {}
-    if session_id:
-        config["configurable"] = {"thread_id": session_id}
+    agent_config = {}
+    if config.session_id:
+        agent_config["configurable"] = {"thread_id": config.session_id}
     if hasattr(agent, 'logging_handler'):
-        config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
+        agent_config["callbacks"] = [agent.logging_handler]
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
@@ -340,7 +317,7 @@ async def create_agent_and_generate_response(
             response_text += f"[{meta_message_tag}]\n" + output_text + "\n"
             if len(msg.tool_calls) > 0:
                 response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool['args']) if isinstance(tool['args'], dict) else tool['args']}\n" for tool in msg.tool_calls])
-        elif isinstance(msg, ToolMessage) and tool_response:
+        elif isinstance(msg, ToolMessage) and config.tool_response:
             response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"
 
     if len(response_text) > 0:
@@ -410,29 +387,16 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
         project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)
 
         # Collect the extra parameters as generate_cfg
-        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'}
+        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking'}
         generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
 
         # Process the messages
         messages = process_messages(request.messages, request.language)
 
+        # Build the AgentConfig object
+        config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg)
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=request.model,
-            model_server=request.model_server,
-            language=request.language,
-            system_prompt=request.system_prompt,
-            mcp_settings=request.mcp_settings,
-            robot_type=request.robot_type,
-            project_dir=project_dir,
-            generate_cfg=generate_cfg,
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )
 
     except Exception as e:
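How the pass-through works: any field on the request body that is not explicitly modeled in exclude_fields lands in generate_cfg and eventually flows into model_kwargs.update(...) inside init_agent. A self-contained sketch with a stand-in request model (names and values are illustrative):

from pydantic import BaseModel

class Req(BaseModel):
    model: str
    temperature: float = 0.2  # extra tuning knob, not in exclude_fields

exclude_fields = {'model'}
req = Req(model="qwen3-next", temperature=0.5)
generate_cfg = {k: v for k, v in req.model_dump().items() if k not in exclude_fields}
assert generate_cfg == {'temperature': 0.5}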
@@ -496,38 +460,20 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
 
         # Fetch the bot configuration from the backend API (using v2 authentication)
         bot_config = await fetch_bot_config(bot_id)
 
-        # v2 endpoint: prefer the API key from the backend config over the Authorization header
-        # Note: the Authorization header has already been consumed for authentication and is not reused as the API key
-        api_key = bot_config.get("api_key")
-
         # Create the project directory (dataset_ids come from the backend config)
         project_dir = create_project_directory(
             bot_config.get("dataset_ids", []),
             bot_id,
             bot_config.get("robot_type", "general_agent")
         )
 
         # Process the messages
         messages = process_messages(request.messages, request.language)
 
+        # Build the AgentConfig object
+        config = AgentConfig.from_v2_request(request, bot_config, project_dir)
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
-            model_server=bot_config.get("model_server", ""),
-            language=request.language or bot_config.get("language", "zh"),
-            system_prompt=bot_config.get("system_prompt"),
-            mcp_settings=bot_config.get("mcp_settings", []),
-            robot_type=bot_config.get("robot_type", "general_agent"),
-            project_dir=project_dir,
-            generate_cfg={},  # the v2 endpoint passes no extra generate_cfg
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )
 
     except HTTPException:
utils/agent_config.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+"""Agent configuration class, managing all agent-related parameters"""
+
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+from dataclasses import dataclass
+
+
+@dataclass
+class AgentConfig:
+    """Agent configuration class, holding every parameter needed to create and manage an agent"""
+
+    # Basic parameters
+    bot_id: str
+    api_key: Optional[str] = None
+    model_name: str = "qwen3-next"
+    model_server: Optional[str] = None
+    language: Optional[str] = "jp"
+
+    # Configuration parameters
+    system_prompt: Optional[str] = None
+    mcp_settings: Optional[List[Dict]] = None
+    robot_type: Optional[str] = "general_agent"
+    generate_cfg: Optional[Dict] = None
+    enable_thinking: bool = True
+
+    # Context parameters
+    project_dir: Optional[str] = None
+    user_identifier: Optional[str] = None
+    session_id: Optional[str] = None
+
+    # Response-control parameters
+    stream: bool = False
+    tool_response: bool = True
+    preamble_text: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to a dict, for passing into functions that take **kwargs"""
+        return {
+            'bot_id': self.bot_id,
+            'api_key': self.api_key,
+            'model_name': self.model_name,
+            'model_server': self.model_server,
+            'language': self.language,
+            'system_prompt': self.system_prompt,
+            'mcp_settings': self.mcp_settings,
+            'robot_type': self.robot_type,
+            'generate_cfg': self.generate_cfg,
+            'enable_thinking': self.enable_thinking,
+            'project_dir': self.project_dir,
+            'user_identifier': self.user_identifier,
+            'session_id': self.session_id,
+            'stream': self.stream,
+            'tool_response': self.tool_response,
+            'preamble_text': self.preamble_text
+        }
+
+    @classmethod
+    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None):
+        """Build a config from a v1 request"""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=api_key,
+            model_name=request.model,
+            model_server=request.model_server,
+            language=request.language,
+            system_prompt=request.system_prompt,
+            mcp_settings=request.mcp_settings,
+            robot_type=request.robot_type,
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg=generate_cfg
+        )
+
+    @classmethod
+    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None):
+        """Build a config from a v2 request"""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=bot_config.get("api_key"),
+            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
+            model_server=bot_config.get("model_server", ""),
+            language=request.language or bot_config.get("language", "zh"),
+            system_prompt=bot_config.get("system_prompt"),
+            mcp_settings=bot_config.get("mcp_settings", []),
+            robot_type=bot_config.get("robot_type", "general_agent"),
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg={}  # the v2 endpoint passes no extra generate_cfg
+        )
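A small round-trip sketch of the new class; the request object is a SimpleNamespace stand-in for the pydantic model, and all values are placeholders:

from types import SimpleNamespace

from utils.agent_config import AgentConfig

req = SimpleNamespace(
    bot_id="bot-123", model="qwen3-next", model_server=None, language="jp",
    system_prompt=None, mcp_settings=None, robot_type="general_agent",
    user_identifier="u1", session_id=None, enable_thinking=False,
    stream=False, tool_response=True,
)
cfg = AgentConfig.from_v1_request(req, api_key="sk-placeholder", generate_cfg={})
assert cfg.to_dict()["model_name"] == "qwen3-next"  # request.model maps onto model_name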
utils/api_models.py
@@ -53,6 +53,7 @@ class ChatRequest(BaseModel):
     robot_type: Optional[str] = "general_agent"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False
 
 
 class ChatRequestV2(BaseModel):
@@ -63,6 +64,7 @@ class ChatRequestV2(BaseModel):
     language: Optional[str] = "zh"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False
 
 
 class FileProcessRequest(BaseModel):
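An illustrative v1 request payload exercising the new flag (all values are placeholders):

payload = {
    "bot_id": "bot-123",
    "model": "qwen3-next",
    "messages": [{"role": "user", "content": "hi"}],
    "stream": True,
    "enable_thinking": True,  # opts in to GuidelineMiddleware and the preamble task
}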
utils/fastapi_utils.py
@@ -11,6 +11,7 @@ import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
 from utils.settings import MASTERKEY, BACKEND_HOST
+from utils.agent_config import AgentConfig
 
 USER = "user"
 ASSISTANT = "assistant"
@@ -500,9 +501,10 @@ def get_preamble_text(language: str, system_prompt: str):
     if preamble_matches:
         # Extract the preamble content
         preamble_content = preamble_matches[0].strip()
-        # Strip the preamble code block from the system_prompt
-        cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
-        return preamble_content, cleaned_system_prompt
+        if preamble_content:
+            # Strip the preamble code block from the system_prompt
+            cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
+            return preamble_content, cleaned_system_prompt
 
     # No preamble code block found: fall back to the default preamble selection
     if language == "jp":
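The behavior change in this hunk: an empty preamble block used to be returned as an empty preamble; with the if preamble_content: guard it now falls through to the language defaults. A sketch, assuming preamble_pattern matches a fenced block labeled preamble (the actual pattern is defined outside this hunk):

import re

preamble_pattern = r"```preamble\n(.*?)```"  # hypothetical pattern, for illustration only
system_prompt = "Intro.\n```preamble\n\n```\nRest."
matches = re.findall(preamble_pattern, system_prompt, flags=re.DOTALL)
content = matches[0].strip() if matches else ""
assert content == ""  # empty block: the guard skips it and the defaults apply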
@@ -559,12 +561,12 @@ def get_preamble_text(language: str, system_prompt: str):
     return default_preamble, system_prompt  # return the default preamble with the original system_prompt
 
 
-async def call_preamble_llm(chat_history: str, last_message: str, preamble_choices_text: str, language: str, model_name: str, api_key: str, model_server: str) -> str:
+async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
     """Call the LLM to run the guideline analysis
 
     Args:
-        chat_history: chat history
-        guidelines_text: guideline text
+        messages: message list
+        preamble_choices_text: guideline text
         model_name: model name
         api_key: API key
         model_server: model server address
@@ -580,6 +582,14 @@ async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
         logger.error(f"Error reading guideline prompt template: {e}")
         return ""
 
+    api_key = config.api_key
+    model_name = config.model_name
+    model_server = config.model_server
+    language = config.language
+    preamble_choices_text = config.preamble_text
+    last_message = get_user_last_message_content(messages)
+    chat_history = format_messages_to_chat_history(messages)
+
     # Replace the placeholders in the template
     system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))
     # Configure the LLM
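The placeholder substitution is plain string replacement, not str.format; a sketch with an assumed template (the real template is read from disk earlier in the function):

preamble_template = (
    "Choices:\n{preamble_choices_text}\n"
    "History:\n{chat_history}\n"
    "Last message: {last_message}\n"
    "Answer in {language}."
)
system_prompt = (preamble_template
                 .replace('{preamble_choices_text}', "A / B / C")
                 .replace('{chat_history}', "用户: hi\n")
                 .replace('{last_message}', "hi")
                 .replace('{language}', "Japanese"))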