add AgentConfig

parent 73b87bd2eb
commit 9525c0f883
agent/deep_assistant.py

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 import sqlite3
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 from langchain.chat_models import init_chat_model
 # from deepagents import create_deep_agent
 from langchain.agents import create_agent
@@ -15,6 +15,8 @@ from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
+from utils.agent_config import AgentConfig


 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that logs through the project's logger"""
@@ -22,17 +24,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
     def __init__(self, logger_name: str = 'app'):
         self.logger = logging.getLogger(logger_name)

-    # def on_llm_start(
-    #     self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any
-    # ) -> None:
-    #     """Called when the LLM starts"""
-    #     self.logger.info("🤖 LLM Start - Input Messages:")
-    #     if prompts:
-    #         for i, prompt in enumerate(prompts):
-    #             self.logger.info(f"  Message {i+1}:\n{prompt}")
-    #     else:
-    #         self.logger.info("  No prompts")

     def on_llm_end(self, response, **kwargs: Any) -> None:
         """Called when the LLM finishes"""
         self.logger.info("✅ LLM End - Output:")
@@ -78,7 +69,6 @@ class LoggingCallbackHandler(BaseCallbackHandler):
         """Called when a tool call raises an error"""
         self.logger.error(f"❌ Tool Error: {error}")

-
     def on_agent_action(self, action, **kwargs: Any) -> None:
         """Called when the Agent takes an action"""
         self.logger.info(f"🎯 Agent Action: {action.log}")
@@ -97,6 +87,7 @@ def read_mcp_settings():
         mcp_settings_json = json.load(f)
     return mcp_settings_json

+
 async def get_tools_from_mcp(mcp):
     """Extract tools from the MCP configuration"""
     # Defensive handling: make sure mcp is a non-empty list that contains mcpServers
@@ -123,66 +114,60 @@ async def get_tools_from_mcp(mcp):
         # Return an empty list on exceptions so upstream callers don't error
         return []

-async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
-                     model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None,
-                     session_id=None):
+async def init_agent(config: AgentConfig):
     """
     Initialize the Agent, with support for persistent memory and conversation summarization

     Args:
-        bot_id: Bot ID
-        model_name: Model name
-        api_key: API key
-        model_server: Model server address
-        generate_cfg: Generation config
-        system_prompt: System prompt
-        mcp: MCP configuration
-        robot_type: Robot type
-        language: Language
-        user_identifier: User identifier
-        session_id: Session ID (persistent memory is disabled when None)
+        config: AgentConfig object containing all initialization parameters
+        mcp: MCP configuration (falls back to the config's mcp_settings when None)
     """
-    system = system_prompt if system_prompt else read_system_prompt()
-    mcp = mcp if mcp else read_mcp_settings()
+    # If no mcp was provided, use mcp_settings from the config
+    mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
+    system = config.system_prompt if config.system_prompt else read_system_prompt()
     mcp_tools = await get_tools_from_mcp(mcp)

     # Detect the provider, or use the one specified
-    model_provider,base_url = detect_provider(model_name,model_server)
+    model_provider,base_url = detect_provider(config.model_name, config.model_server)

     # Build the model parameters
     model_kwargs = {
-        "model": model_name,
+        "model": config.model_name,
         "model_provider": model_provider,
         "temperature": 0.8,
         "base_url": base_url,
-        "api_key": api_key
+        "api_key": config.api_key
     }
-    if generate_cfg:
-        model_kwargs.update(generate_cfg)
+    if config.generate_cfg:
+        model_kwargs.update(config.generate_cfg)
     llm_instance = init_chat_model(**model_kwargs)

     # Create the custom logging handler
     logging_handler = LoggingCallbackHandler()

     # Build the middleware list
-    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
+    middleware = []

+    # Only add GuidelineMiddleware when enable_thinking is True
+    if config.enable_thinking:
+        middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier))

     # Add the tool-output length-control middleware
     tool_output_middleware = ToolOutputLengthMiddleware(
-        max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH,
-        truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'),
-        tool_filters=getattr(generate_cfg, 'tool_output_filters', None),  # specific tools are configurable
-        exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []),  # tools to exclude
-        preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True),
-        preserve_json=getattr(generate_cfg, 'preserve_json', True)
+        max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
+        truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
+        tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,  # specific tools are configurable
+        exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],  # tools to exclude
+        preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
+        preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
    )
     middleware.append(tool_output_middleware)

     # Initialize the checkpointer and middleware
     checkpointer = None

-    if session_id:
+    if config.session_id:
         checkpointer = MemorySaver()
         summarization_middleware = SummarizationMiddleware(
             model=llm_instance,
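A caveat in the new default handling above: in Python, `A if cond else None or B` parses as `A if cond else (None or B)`, so once `config.generate_cfg` is truthy the `or TOOL_OUTPUT_MAX_LENGTH` fallback never applies; and since the rest of this diff treats `generate_cfg` as a plain dict, `getattr` on it always returns the default. A minimal sketch of a dict-safe lookup (the `cfg_get` helper is hypothetical, not part of the commit):

```python
from typing import Any, Dict, Optional

def cfg_get(generate_cfg: Optional[Dict], key: str, default: Any) -> Any:
    # generate_cfg is a plain dict at the call sites in this diff,
    # so use .get rather than getattr, and apply the default explicitly
    if not generate_cfg:
        return default
    value = generate_cfg.get(key)
    return value if value is not None else default

# e.g. max_length=cfg_get(config.generate_cfg, 'tool_output_max_length', TOOL_OUTPUT_MAX_LENGTH)
```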
@@ -203,6 +188,6 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
     # Store the handler and checkpointer on the agent for easy access at call time
     agent.logging_handler = logging_handler
     agent.checkpointer = checkpointer
-    agent.bot_id = bot_id
-    agent.session_id = session_id
+    agent.bot_id = config.bot_id
+    agent.session_id = config.session_id
     return agent
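After this refactor a caller builds one AgentConfig instead of threading a dozen keyword arguments through init_agent. A minimal sketch, with illustrative values:

```python
import asyncio

from utils.agent_config import AgentConfig
from agent.deep_assistant import init_agent

async def main():
    # Illustrative values; unset fields fall back to the dataclass defaults
    config = AgentConfig(
        bot_id="bot-123",
        model_name="qwen3-next",
        api_key="sk-...",
        session_id="sess-1",    # non-None session_id enables the MemorySaver checkpointer
        enable_thinking=False,  # skips GuidelineMiddleware entirely
    )
    agent = await init_agent(config)
    print(agent.bot_id, agent.session_id)

asyncio.run(main())
```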
agent/sharded_agent_manager.py

@@ -26,6 +26,7 @@ logger = logging.getLogger('app')

 from agent.deep_assistant import init_agent
 from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
+from utils.agent_config import AgentConfig


 class ShardedAgentManager:
@@ -67,7 +68,8 @@ class ShardedAgentManager:

     def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
                        model_server: str = None, generate_cfg: Dict = None,
-                       system_prompt: str = None, mcp_settings: List[Dict] = None) -> str:
+                       system_prompt: str = None, mcp_settings: List[Dict] = None,
+                       enable_thinking: bool = False) -> str:
         """Hash all relevant parameters into a cache key"""
         cache_data = {
             'bot_id': bot_id,
@@ -76,7 +78,8 @@ class ShardedAgentManager:
             'model_server': model_server or '',
             'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
             'system_prompt': system_prompt or '',
-            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True)
+            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
+            'enable_thinking': enable_thinking
         }

         cache_str = json.dumps(cache_data, sort_keys=True)
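Because enable_thinking now participates in cache_data, two otherwise identical requests that differ only in that flag map to different cache entries. A small sketch of the idea (the hashing step after cache_str is not shown in this diff, so only the serialization is reproduced):

```python
import json

def cache_fields(enable_thinking: bool) -> str:
    # Same serialization the diff shows: stable ordering via sort_keys
    cache_data = {
        'bot_id': 'bot-123',
        'generate_cfg': json.dumps({}, sort_keys=True),
        'enable_thinking': enable_thinking,
    }
    return json.dumps(cache_data, sort_keys=True)

# Different serialized strings, hence different derived cache keys
assert cache_fields(True) != cache_fields(False)
```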
@@ -116,20 +119,12 @@ class ShardedAgentManager:
         if removed_count > 0:
             logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存")

-    async def get_or_create_agent(self,
-                                  bot_id: str,
-                                  project_dir: Optional[str],
-                                  model_name: str = "qwen3-next",
-                                  api_key: Optional[str] = None,
-                                  model_server: Optional[str] = None,
-                                  generate_cfg: Optional[Dict] = None,
-                                  language: Optional[str] = None,
-                                  system_prompt: Optional[str] = None,
-                                  mcp_settings: Optional[List[Dict]] = None,
-                                  robot_type: Optional[str] = "general_agent",
-                                  user_identifier: Optional[str] = None,
-                                  session_id: Optional[str] = None):
-        """Get or create a file-preloaded assistant instance"""
+    async def get_or_create_agent(self, config: AgentConfig):
+        """Get or create a file-preloaded assistant instance
+
+        Args:
+            config: AgentConfig object containing all initialization parameters
+        """
         # Update request statistics
         with self._stats_lock:
@@ -137,14 +132,16 @@ class ShardedAgentManager:

         # Asynchronously load the config files (with caching)
         final_system_prompt = await load_system_prompt_async(
-            project_dir, language, system_prompt, robot_type, bot_id, user_identifier
+            config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
         )
         final_mcp_settings = await load_mcp_settings_async(
-            project_dir, mcp_settings, bot_id, robot_type
+            config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
         )
-        cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
-                                        generate_cfg, final_system_prompt, final_mcp_settings)
+        config.system_prompt = final_system_prompt
+        config.mcp_settings = final_mcp_settings
+        cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
+                                        config.generate_cfg, final_system_prompt, final_mcp_settings,
+                                        config.enable_thinking)

         # Get the shard
         shard_index = self._get_shard_index(cache_key)
@@ -160,7 +157,7 @@ class ShardedAgentManager:
                 with self._stats_lock:
                     self._global_stats['cache_hits'] += 1

-                logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {bot_id}, shard: {shard_index})")
+                logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
                 return agent

         # Update cache-miss statistics
@@ -188,27 +185,15 @@ class ShardedAgentManager:
             self._cleanup_old_agents(shard)

         # Create a new assistant instance
-        logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")
+        logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
         current_time = time.time()

-        agent = await init_agent(
-            bot_id=bot_id,
-            model_name=model_name,
-            api_key=api_key,
-            model_server=model_server,
-            generate_cfg=generate_cfg,
-            system_prompt=final_system_prompt,
-            mcp=final_mcp_settings,
-            robot_type=robot_type,
-            language=language,
-            user_identifier=user_identifier,
-            session_id=session_id
-        )
+        agent = await init_agent(config)

         # Cache the instance
         async with shard['lock']:
             shard['agents'][cache_key] = agent
-            shard['unique_ids'][cache_key] = bot_id
+            shard['unique_ids'][cache_key] = config.bot_id
             shard['access_times'][cache_key] = current_time
             shard['creation_times'][cache_key] = current_time
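With every cache-relevant field living on the config, a repeated call with an equivalent AgentConfig should land on the same shard entry. A minimal sketch (the manager's construction is not shown in this diff, so it is assumed here):

```python
# Sketch only: assumes an already-constructed ShardedAgentManager instance
async def serve_twice(agent_manager, config):
    first = await agent_manager.get_or_create_agent(config)
    # get_or_create_agent wrote the loaded prompt and MCP settings back onto
    # config, so the second call should compute the same cache key and reuse
    # the cached instance (assuming the loaded prompt is stable)
    second = await agent_manager.get_or_create_agent(config)
    return first is second
```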
routes/chat.py  (220 changed lines)
@@ -14,13 +14,14 @@ from utils import (
 from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
-    process_messages, format_messages_to_chat_history,
+    process_messages,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    call_preamble_llm, get_preamble_text, get_user_last_message_content,
+    call_preamble_llm, get_preamble_text,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
 from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
+from utils.agent_config import AgentConfig

 router = APIRouter()

@@ -69,30 +70,45 @@ def append_assistant_last_message(messages: list, content: str) -> bool:
     messages.append({"role":"assistant","content":content})
     return messages


+def get_user_last_message_content(messages: list) -> str:
+    """Get the content of the last user message"""
+    if not messages:
+        return ""
+
+    for msg in reversed(messages):
+        if msg.get('role') == 'user':
+            return msg.get('content', '')
+
+    return ""
+
+
+def format_messages_to_chat_history(messages: list) -> str:
+    """Format messages into a chat-history string"""
+    chat_history = ""
+    for msg in messages:
+        role = msg.get('role', '')
+        content = msg.get('content', '')
+        if role == 'user':
+            chat_history += f"用户: {content}\n"
+        elif role == 'assistant':
+            chat_history += f"助手: {content}\n"
+    return chat_history
+
+
 async def enhanced_generate_stream_response(
     agent_manager,
-    bot_id: str,
-    api_key: str,
     messages: list,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: str,
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str],
-    generate_cfg: Optional[dict],
-    user_identifier: Optional[str],
-    session_id: Optional[str] = None,
+    config: AgentConfig
 ):
-    """Enhanced progressive streaming response generator - concurrency-optimized version"""
-    try:
-        # Prepare the parameters
-        query_text = get_user_last_message_content(messages)
-        chat_history = format_messages_to_chat_history(messages)
-        preamble_text, system_prompt = get_preamble_text(language, system_prompt)
+    """Enhanced progressive streaming response generator - concurrency-optimized version
+
+    Args:
+        agent_manager: the agent manager
+        messages: the message list
+        config: AgentConfig object containing all parameters
+    """
+    try:
         # Create the output queue and control events
         output_queue = asyncio.Queue()
         preamble_completed = asyncio.Event()
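The two helpers are now defined locally instead of imported from utils.fastapi_utils (see the trimmed import list above); assuming they are in scope, their behavior pins down to:

```python
messages = [
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "您好,有什么可以帮您?"},
    {"role": "user", "content": "查一下天气"},
]

assert get_user_last_message_content(messages) == "查一下天气"
# Roles map to the fixed 用户/助手 prefixes; any other role is skipped
assert format_messages_to_chat_history(messages) == (
    "用户: 你好\n助手: 您好,有什么可以帮您?\n用户: 查一下天气\n"
)
```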
@@ -100,11 +116,11 @@ async def enhanced_generate_stream_response(
         # Preamble task
         async def preamble_task():
             try:
-                preamble_result = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
+                preamble_result = await call_preamble_llm(messages,config)
                 # Only emit when the preamble text is non-empty and not "<empty>"
                 if preamble_result and preamble_result.strip() and preamble_result != "<empty>":
                     preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
-                    chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content)
+                    chunk_data = create_stream_chunk(f"chatcmpl-preamble", config.model_name, preamble_content)
                     await output_queue.put(("preamble", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
                     logger.info(f"Stream mode: Generated preamble text ({len(preamble_result)} chars)")
                 else:
@@ -124,32 +140,19 @@ async def enhanced_generate_stream_response(
         async def agent_task():
             try:
                 # Prepare the agent
-                agent = await agent_manager.get_or_create_agent(
-                    bot_id=bot_id,
-                    project_dir=project_dir,
-                    model_name=model_name,
-                    api_key=api_key,
-                    model_server=model_server,
-                    generate_cfg=generate_cfg,
-                    language=language,
-                    system_prompt=system_prompt,
-                    mcp_settings=mcp_settings,
-                    robot_type=robot_type,
-                    user_identifier=user_identifier,
-                    session_id=session_id,
-                )
+                agent = await agent_manager.get_or_create_agent(config)

                 # Start stream processing
                 logger.info(f"Starting agent stream response")
                 chunk_id = 0
                 message_tag = ""

-                config = {}
-                if session_id:
-                    config["configurable"] = {"thread_id": session_id}
+                stream_config = {}
+                if config.session_id:
+                    stream_config["configurable"] = {"thread_id": config.session_id}
                 if hasattr(agent, 'logging_handler'):
-                    config["callbacks"] = [agent.logging_handler]
+                    stream_config["callbacks"] = [agent.logging_handler]
-                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
+                async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS):
                     new_content = ""

                     if isinstance(msg, AIMessageChunk):
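The rename from `config = {}` to `stream_config = {}` is not cosmetic: the local dict would have shadowed the new `config: AgentConfig` parameter, so the later `config.session_id` and `config.tool_response` reads would have hit the empty dict instead. A minimal illustration of the pitfall:

```python
class Cfg:
    session_id = "sess-1"

def pitfall(config: Cfg):
    config = {}                 # rebinding the name hides the Cfg argument
    return config.session_id    # AttributeError: dicts have no session_id

def fixed(config: Cfg):
    stream_config = {}          # distinct name, as the diff now does
    if config.session_id:
        stream_config["configurable"] = {"thread_id": config.session_id}
    return stream_config

print(fixed(Cfg()))  # {'configurable': {'thread_id': 'sess-1'}}
```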
@@ -174,7 +177,7 @@ async def enhanced_generate_stream_response(
                         new_content += msg.text

                     # Handle tool responses
-                    elif isinstance(msg, ToolMessage) and tool_response and msg.content:
+                    elif isinstance(msg, ToolMessage) and config.tool_response and msg.content:
                         message_tag = "TOOL_RESPONSE"
                         new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n"
@@ -183,11 +186,11 @@ async def enhanced_generate_stream_response(
                         if chunk_id == 0:
                             logger.info(f"Agent首个Token已生成, 开始流式输出")
                         chunk_id += 1
-                        chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", model_name, new_content)
+                        chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", config.model_name, new_content)
                         await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))

                 # Send the final chunk
-                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", model_name, finish_reason="stop")
+                final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
                 await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
                 await output_queue.put(("agent_done", None))
@@ -196,7 +199,15 @@ async def enhanced_generate_stream_response(
                 await output_queue.put(("agent_done", None))

         # Run the tasks concurrently
-        preamble_task_handle = asyncio.create_task(preamble_task())
+        # Only run the preamble task when enable_thinking is True
+        if config.enable_thinking:
+            preamble_task_handle = asyncio.create_task(preamble_task())
+        else:
+            # If thinking is disabled, create an empty already-completed task
+            preamble_task_handle = asyncio.create_task(asyncio.sleep(0))
+            # Mark the preamble as done right away
+            preamble_completed.set()

         agent_task_handle = asyncio.create_task(agent_task())

         # Output controller: ensure the preamble is emitted first, then the agent stream
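The else branch keeps the downstream logic uniform: the controller can always wait on the same event and await the same handle, whether or not a real preamble ran. A self-contained sketch of the same trick:

```python
import asyncio

async def main(enabled: bool):
    done = asyncio.Event()

    async def preamble():
        await asyncio.sleep(0.01)  # stand-in for the real preamble work
        done.set()

    if enabled:
        handle = asyncio.create_task(preamble())
    else:
        handle = asyncio.create_task(asyncio.sleep(0))  # no-op placeholder task
        done.set()                                      # unblock waiters immediately

    await done.wait()  # the controller waits on the same event either way
    await handle       # and can always await the handle, real or placeholder

asyncio.run(main(False))
```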
@@ -214,7 +225,7 @@ async def enhanced_generate_stream_response(
                     preamble_output_done = True

                 elif item_type == "preamble_done":
-                    # The preamble has finished; mark it and continue processing
+                    # The preamble has finished; mark it and continue
                     preamble_output_done = True

                 elif item_type == "agent":
@@ -259,77 +270,43 @@ async def enhanced_generate_stream_response(


 async def create_agent_and_generate_response(
-    bot_id: str,
-    api_key: str,
     messages: list,
-    stream: bool,
-    tool_response: bool,
-    model_name: str,
-    model_server: str,
-    language: str,
-    system_prompt: Optional[str],
-    mcp_settings: Optional[list],
-    robot_type: str,
-    project_dir: Optional[str] = None,
-    generate_cfg: Optional[dict] = None,
-    user_identifier: Optional[str] = None,
-    session_id: Optional[str] = None
+    config: AgentConfig
 ) -> Union[ChatResponse, StreamingResponse]:
-    """Shared logic for creating the agent and generating the response"""
-    if generate_cfg is None:
-        generate_cfg = {}
+    """Shared logic for creating the agent and generating the response
+
+    Args:
+        messages: the message list
+        config: AgentConfig object containing all parameters
+    """
+    config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)

     # In streaming mode, use the enhanced streaming response generator
-    if stream:
+    if config.stream:
         return StreamingResponse(
             enhanced_generate_stream_response(
                 agent_manager=agent_manager,
-                bot_id=bot_id,
-                api_key=api_key,
                 messages=messages,
-                tool_response=tool_response,
-                model_name=model_name,
-                model_server=model_server,
-                language=language,
-                system_prompt=system_prompt or "",
-                mcp_settings=mcp_settings,
-                robot_type=robot_type,
-                project_dir=project_dir,
-                generate_cfg=generate_cfg,
-                user_identifier=user_identifier,
-                session_id=session_id
+                config=config
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
         )

-    _, system_prompt = get_preamble_text(language, system_prompt)
     # Handle everything through the shared logic
-    agent = await agent_manager.get_or_create_agent(
-        bot_id=bot_id,
-        project_dir=project_dir,
-        model_name=model_name,
-        api_key=api_key,
-        model_server=model_server,
-        generate_cfg=generate_cfg,
-        language=language,
-        system_prompt=system_prompt,
-        mcp_settings=mcp_settings,
-        robot_type=robot_type,
-        user_identifier=user_identifier,
-        session_id=session_id,
-    )
+    agent = await agent_manager.get_or_create_agent(config)

     # Prepare the final messages
     final_messages = messages.copy()

     # Non-streaming response
-    config = {}
-    if session_id:
-        config["configurable"] = {"thread_id": session_id}
+    agent_config = {}
+    if config.session_id:
+        agent_config["configurable"] = {"thread_id": config.session_id}
     if hasattr(agent, 'logging_handler'):
-        config["callbacks"] = [agent.logging_handler]
+        agent_config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
@@ -340,7 +317,7 @@ async def create_agent_and_generate_response(
             response_text += f"[{meta_message_tag}]\n"+output_text+ "\n"
             if len(msg.tool_calls)>0:
                 response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
-        elif isinstance(msg,ToolMessage) and tool_response:
+        elif isinstance(msg,ToolMessage) and config.tool_response:
             response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"

         if len(response_text) > 0:
@@ -410,29 +387,16 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
         project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)

         # Collect the extra parameters as generate_cfg
-        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'}
+        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking'}
         generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

         # Process the messages
         messages = process_messages(request.messages, request.language)
+        # Build the AgentConfig object
+        config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg)
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=request.model,
-            model_server=request.model_server,
-            language=request.language,
-            system_prompt=request.system_prompt,
-            mcp_settings=request.mcp_settings,
-            robot_type=request.robot_type,
-            project_dir=project_dir,
-            generate_cfg=generate_cfg,
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )

     except Exception as e:
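The exclude set keeps routing fields out of generate_cfg, so only model-tuning extras flow through to the model kwargs; a small illustration of the filtering (the non-excluded field names here are hypothetical):

```python
# Illustrative request payload; only the non-excluded extras survive
request_fields = {
    'messages': [...], 'model': 'qwen3-next', 'stream': True,
    'enable_thinking': False,          # newly excluded, handled via AgentConfig
    'temperature': 0.2, 'top_p': 0.9,  # hypothetical model-tuning extras
}
exclude_fields = {'messages', 'model', 'stream', 'enable_thinking'}
generate_cfg = {k: v for k, v in request_fields.items() if k not in exclude_fields}
assert generate_cfg == {'temperature': 0.2, 'top_p': 0.9}
```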
@@ -496,38 +460,20 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st

         # Fetch the bot configuration from the backend API (using v2 auth)
         bot_config = await fetch_bot_config(bot_id)

-        # v2 endpoint: the API key comes from the backend config first, then from the Authorization header
-        # Note: the Authorization header has already been used for auth and is no longer used as the API key
-        api_key = bot_config.get("api_key")

         # Create the project directory (dataset_ids come from the backend config)
         project_dir = create_project_directory(
             bot_config.get("dataset_ids", []),
             bot_id,
             bot_config.get("robot_type", "general_agent")
         )

         # Process the messages
         messages = process_messages(request.messages, request.language)
+        # Build the AgentConfig object
+        config = AgentConfig.from_v2_request(request, bot_config, project_dir)
         # Invoke the shared agent-creation and response-generation logic
         return await create_agent_and_generate_response(
-            bot_id=bot_id,
-            api_key=api_key,
             messages=messages,
-            stream=request.stream,
-            tool_response=request.tool_response,
-            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
-            model_server=bot_config.get("model_server", ""),
-            language=request.language or bot_config.get("language", "zh"),
-            system_prompt=bot_config.get("system_prompt"),
-            mcp_settings=bot_config.get("mcp_settings", []),
-            robot_type=bot_config.get("robot_type", "general_agent"),
-            project_dir=project_dir,
-            generate_cfg={},  # the v2 endpoint passes no extra generate_cfg
-            user_identifier=request.user_identifier,
-            session_id=request.session_id
+            config=config
         )

     except HTTPException:
utils/agent_config.py  (new file, 97 lines)
@@ -0,0 +1,97 @@
+"""Agent configuration class for managing all Agent-related parameters"""
+
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+from dataclasses import dataclass
+
+
+@dataclass
+class AgentConfig:
+    """Agent configuration class containing every parameter needed to create and manage an Agent"""
+
+    # Basic parameters
+    bot_id: str
+    api_key: Optional[str] = None
+    model_name: str = "qwen3-next"
+    model_server: Optional[str] = None
+    language: Optional[str] = "jp"
+
+    # Configuration parameters
+    system_prompt: Optional[str] = None
+    mcp_settings: Optional[List[Dict]] = None
+    robot_type: Optional[str] = "general_agent"
+    generate_cfg: Optional[Dict] = None
+    enable_thinking: bool = True
+
+    # Context parameters
+    project_dir: Optional[str] = None
+    user_identifier: Optional[str] = None
+    session_id: Optional[str] = None
+
+    # Response-control parameters
+    stream: bool = False
+    tool_response: bool = True
+    preamble_text: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to a dict, for passing to functions that take **kwargs"""
+        return {
+            'bot_id': self.bot_id,
+            'api_key': self.api_key,
+            'model_name': self.model_name,
+            'model_server': self.model_server,
+            'language': self.language,
+            'system_prompt': self.system_prompt,
+            'mcp_settings': self.mcp_settings,
+            'robot_type': self.robot_type,
+            'generate_cfg': self.generate_cfg,
+            'enable_thinking': self.enable_thinking,
+            'project_dir': self.project_dir,
+            'user_identifier': self.user_identifier,
+            'session_id': self.session_id,
+            'stream': self.stream,
+            'tool_response': self.tool_response,
+            'preamble_text': self.preamble_text
+        }
+
+    @classmethod
+    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None):
+        """Build the config from a v1 request"""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=api_key,
+            model_name=request.model,
+            model_server=request.model_server,
+            language=request.language,
+            system_prompt=request.system_prompt,
+            mcp_settings=request.mcp_settings,
+            robot_type=request.robot_type,
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg=generate_cfg
+        )
+
+    @classmethod
+    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None):
+        """Build the config from a v2 request"""
+        return cls(
+            bot_id=request.bot_id,
+            api_key=bot_config.get("api_key"),
+            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
+            model_server=bot_config.get("model_server", ""),
+            language=request.language or bot_config.get("language", "zh"),
+            system_prompt=bot_config.get("system_prompt"),
+            mcp_settings=bot_config.get("mcp_settings", []),
+            robot_type=bot_config.get("robot_type", "general_agent"),
+            user_identifier=request.user_identifier,
+            session_id=request.session_id,
+            enable_thinking=request.enable_thinking,
+            project_dir=project_dir,
+            stream=request.stream,
+            tool_response=request.tool_response,
+            generate_cfg={}  # the v2 endpoint passes no extra generate_cfg
+        )
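Since AgentConfig is a plain dataclass it can be built directly in tests or scripts, and to_dict() bridges to any remaining **kwargs-style callee. Note that the dataclass defaults enable_thinking to True while both request models below default it to False; because from_v1_request/from_v2_request always forward the request value, the effective API default is False. A brief sketch (legacy_init is a hypothetical stand-in):

```python
from utils.agent_config import AgentConfig

config = AgentConfig(bot_id="bot-123", model_name="qwen3-next", stream=True)
assert config.enable_thinking is True       # dataclass default
assert config.to_dict()["stream"] is True   # kwargs-friendly view

def legacy_init(**kwargs):
    # Hypothetical stand-in for an older **kwargs-style function
    return kwargs["bot_id"]

assert legacy_init(**config.to_dict()) == "bot-123"
```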
utils/api_models.py

@@ -53,6 +53,7 @@ class ChatRequest(BaseModel):
     robot_type: Optional[str] = "general_agent"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False


 class ChatRequestV2(BaseModel):
@@ -63,6 +64,7 @@ class ChatRequestV2(BaseModel):
     language: Optional[str] = "zh"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    enable_thinking: Optional[bool] = False


 class FileProcessRequest(BaseModel):
utils/fastapi_utils.py

@@ -11,6 +11,7 @@ import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
 from utils.settings import MASTERKEY, BACKEND_HOST
+from utils.agent_config import AgentConfig

 USER = "user"
 ASSISTANT = "assistant"
@@ -500,9 +501,10 @@ def get_preamble_text(language: str, system_prompt: str):
     if preamble_matches:
         # Extract the preamble content
         preamble_content = preamble_matches[0].strip()
-        # Strip the preamble code block from the system_prompt
-        cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
-        return preamble_content, cleaned_system_prompt
+        if preamble_content:
+            # Strip the preamble code block from the system_prompt
+            cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL)
+            return preamble_content, cleaned_system_prompt

     # If no preamble code block was found, use the default preamble selection
     if language == "jp":
@@ -559,12 +561,12 @@ def get_preamble_text(language: str, system_prompt: str):
     return default_preamble, system_prompt  # return the default preamble and the original system_prompt


-async def call_preamble_llm(chat_history: str, last_message: str, preamble_choices_text: str, language: str, model_name: str, api_key: str, model_server: str) -> str:
+async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
     """Call the LLM to run the guideline analysis

     Args:
-        chat_history: the chat history
-        guidelines_text: the guideline text
+        messages: the message list
+        preamble_choices_text: the guideline text
         model_name: model name
         api_key: API key
         model_server: model server address
@@ -580,6 +582,14 @@ async def call_preamble_llm(chat_history: str, last_message: str, preamble_choic
         logger.error(f"Error reading guideline prompt template: {e}")
         return ""

+    api_key = config.api_key
+    model_name = config.model_name
+    model_server = config.model_server
+    language = config.language
+    preamble_choices_text = config.preamble_text
+    last_message = get_user_last_message_content(messages)
+    chat_history = format_messages_to_chat_history(messages)
+
     # Replace the placeholders in the template
     system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))
     # Configure the LLM
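With this change the preamble call site only needs the raw messages plus the shared config; the old seven-argument call collapses to two. A sketch, assuming config.preamble_text has been populated by get_preamble_text as create_agent_and_generate_response now does up front:

```python
from utils.agent_config import AgentConfig
from utils.fastapi_utils import call_preamble_llm, get_preamble_text

async def preamble_for(messages: list, config: AgentConfig) -> str:
    # get_preamble_text fills config.preamble_text and cleans the system prompt,
    # mirroring the setup in create_agent_and_generate_response
    config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
    # Chat history and the last user message are now derived inside call_preamble_llm
    return await call_preamble_llm(messages, config)
```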