修改agent_config
This commit is contained in:
parent
de72321875
commit
e36787fb63
@ -1,7 +1,7 @@
|
|||||||
"""Agent配置类,用于管理所有Agent相关的参数"""
|
"""Agent配置类,用于管理所有Agent相关的参数"""
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, TYPE_CHECKING
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ class AgentConfig:
|
|||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None
|
||||||
mcp_settings: Optional[List[Dict]] = None
|
mcp_settings: Optional[List[Dict]] = field(default_factory=list)
|
||||||
robot_type: Optional[str] = "general_agent"
|
robot_type: Optional[str] = "general_agent"
|
||||||
generate_cfg: Optional[Dict] = None
|
generate_cfg: Optional[Dict] = None
|
||||||
enable_thinking: bool = False
|
enable_thinking: bool = False
|
||||||
@ -34,6 +34,9 @@ class AgentConfig:
|
|||||||
stream: bool = False
|
stream: bool = False
|
||||||
tool_response: bool = True
|
tool_response: bool = True
|
||||||
preamble_text: Optional[str] = None
|
preamble_text: Optional[str] = None
|
||||||
|
messages: Optional[List] = field(default_factory=list)
|
||||||
|
|
||||||
|
logging_handler: Optional['LoggingCallbackHandler'] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典格式,用于传递给需要**kwargs的函数"""
|
"""转换为字典格式,用于传递给需要**kwargs的函数"""
|
||||||
@ -53,7 +56,8 @@ class AgentConfig:
|
|||||||
'session_id': self.session_id,
|
'session_id': self.session_id,
|
||||||
'stream': self.stream,
|
'stream': self.stream,
|
||||||
'tool_response': self.tool_response,
|
'tool_response': self.tool_response,
|
||||||
'preamble_text': self.preamble_text
|
'preamble_text': self.preamble_text,
|
||||||
|
'messages': self.messages,
|
||||||
}
|
}
|
||||||
|
|
||||||
def safe_print(self):
|
def safe_print(self):
|
||||||
@ -64,8 +68,14 @@ class AgentConfig:
|
|||||||
logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}")
|
logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None):
|
def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
|
||||||
"""从v1请求创建配置"""
|
"""从v1请求创建配置"""
|
||||||
|
# 延迟导入避免循环依赖
|
||||||
|
from .logging_handler import LoggingCallbackHandler
|
||||||
|
|
||||||
|
if messages is None:
|
||||||
|
messages = []
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
bot_id=request.bot_id,
|
bot_id=request.bot_id,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
@ -81,12 +91,20 @@ class AgentConfig:
|
|||||||
project_dir=project_dir,
|
project_dir=project_dir,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
tool_response=request.tool_response,
|
tool_response=request.tool_response,
|
||||||
generate_cfg=generate_cfg
|
generate_cfg=generate_cfg,
|
||||||
|
logging_handler=LoggingCallbackHandler(),
|
||||||
|
messages=messages
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None):
|
def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
|
||||||
"""从v2请求创建配置"""
|
"""从v2请求创建配置"""
|
||||||
|
# 延迟导入避免循环依赖
|
||||||
|
from .logging_handler import LoggingCallbackHandler
|
||||||
|
|
||||||
|
if messages is None:
|
||||||
|
messages = []
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
bot_id=request.bot_id,
|
bot_id=request.bot_id,
|
||||||
api_key=bot_config.get("api_key"),
|
api_key=bot_config.get("api_key"),
|
||||||
@ -102,5 +120,16 @@ class AgentConfig:
|
|||||||
project_dir=project_dir,
|
project_dir=project_dir,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
tool_response=request.tool_response,
|
tool_response=request.tool_response,
|
||||||
generate_cfg={} # v2接口不传递额外的generate_cfg
|
generate_cfg={}, # v2接口不传递额外的generate_cfg
|
||||||
)
|
logging_handler=LoggingCallbackHandler(),
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke_config(self):
|
||||||
|
"""返回Langchain需要的配置字典"""
|
||||||
|
config = {}
|
||||||
|
if self.logging_handler:
|
||||||
|
config["callbacks"] = [self.logging_handler]
|
||||||
|
if self.session_id:
|
||||||
|
config["configurable"] = {"thread_id": self.session_id}
|
||||||
|
return config
|
||||||
@ -8,70 +8,13 @@ from langchain.chat_models import init_chat_model
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import SummarizationMiddleware
|
from langchain.agents.middleware import SummarizationMiddleware
|
||||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from utils.fastapi_utils import detect_provider
|
from utils.fastapi_utils import detect_provider
|
||||||
|
|
||||||
from .guideline_middleware import GuidelineMiddleware
|
from .guideline_middleware import GuidelineMiddleware
|
||||||
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
||||||
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
|
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
|
||||||
from utils.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallbackHandler(BaseCallbackHandler):
|
|
||||||
"""自定义的 CallbackHandler,使用项目的 logger 来记录日志"""
|
|
||||||
|
|
||||||
def __init__(self, logger_name: str = 'app'):
|
|
||||||
self.logger = logging.getLogger(logger_name)
|
|
||||||
|
|
||||||
def on_llm_end(self, response, **kwargs: Any) -> None:
|
|
||||||
"""当 LLM 结束时调用"""
|
|
||||||
self.logger.info("✅ LLM End - Output:")
|
|
||||||
|
|
||||||
# 打印生成的文本
|
|
||||||
if hasattr(response, 'generations') and response.generations:
|
|
||||||
for gen_idx, generation_list in enumerate(response.generations):
|
|
||||||
for msg_idx, generation in enumerate(generation_list):
|
|
||||||
if hasattr(generation, 'text'):
|
|
||||||
output_list = generation.text.split("\n")
|
|
||||||
for i, output in enumerate(output_list):
|
|
||||||
if output.strip():
|
|
||||||
self.logger.info(f"{output}")
|
|
||||||
elif hasattr(generation, 'message'):
|
|
||||||
output_list = generation.message.split("\n")
|
|
||||||
for i, output in enumerate(output_list):
|
|
||||||
if output.strip():
|
|
||||||
self.logger.info(f"{output}")
|
|
||||||
|
|
||||||
def on_llm_error(
|
|
||||||
self, error: Exception, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""当 LLM 出错时调用"""
|
|
||||||
self.logger.error(f"❌ LLM Error: {error}")
|
|
||||||
|
|
||||||
def on_tool_start(
|
|
||||||
self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""当工具开始调用时调用"""
|
|
||||||
if serialized is None:
|
|
||||||
tool_name = 'unknown_tool'
|
|
||||||
else:
|
|
||||||
tool_name = serialized.get('name', 'unknown_tool')
|
|
||||||
self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")
|
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
|
||||||
"""当工具调用结束时调用"""
|
|
||||||
self.logger.info(f"✅ Tool End Output: {output}")
|
|
||||||
|
|
||||||
def on_tool_error(
|
|
||||||
self, error: Exception, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""当工具调用出错时调用"""
|
|
||||||
self.logger.error(f"❌ Tool Error: {error}")
|
|
||||||
|
|
||||||
def on_agent_action(self, action, **kwargs: Any) -> None:
|
|
||||||
"""当 Agent 执行动作时调用"""
|
|
||||||
self.logger.info(f"🎯 Agent Action: {action.log}")
|
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
@ -124,9 +67,9 @@ async def init_agent(config: AgentConfig):
|
|||||||
mcp: MCP配置(如果为None则使用配置中的mcp_settings)
|
mcp: MCP配置(如果为None则使用配置中的mcp_settings)
|
||||||
"""
|
"""
|
||||||
# 如果没有提供mcp,使用config中的mcp_settings
|
# 如果没有提供mcp,使用config中的mcp_settings
|
||||||
mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
|
mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings()
|
||||||
system = config.system_prompt if config.system_prompt else read_system_prompt()
|
system_prompt = config.system_prompt if config.system_prompt else read_system_prompt()
|
||||||
mcp_tools = await get_tools_from_mcp(mcp)
|
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||||
|
|
||||||
# 检测或使用指定的提供商
|
# 检测或使用指定的提供商
|
||||||
model_provider,base_url = detect_provider(config.model_name, config.model_server)
|
model_provider,base_url = detect_provider(config.model_name, config.model_server)
|
||||||
@ -143,15 +86,11 @@ async def init_agent(config: AgentConfig):
|
|||||||
model_kwargs.update(config.generate_cfg)
|
model_kwargs.update(config.generate_cfg)
|
||||||
llm_instance = init_chat_model(**model_kwargs)
|
llm_instance = init_chat_model(**model_kwargs)
|
||||||
|
|
||||||
# 创建自定义的日志处理器
|
|
||||||
logging_handler = LoggingCallbackHandler()
|
|
||||||
|
|
||||||
# 构建中间件列表
|
# 构建中间件列表
|
||||||
middleware = []
|
middleware = []
|
||||||
|
|
||||||
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
||||||
if config.enable_thinking:
|
if config.enable_thinking:
|
||||||
middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier))
|
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
||||||
|
|
||||||
# 添加工具输出长度控制中间件
|
# 添加工具输出长度控制中间件
|
||||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||||
@ -179,15 +118,10 @@ async def init_agent(config: AgentConfig):
|
|||||||
|
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
model=llm_instance,
|
model=llm_instance,
|
||||||
system_prompt=system,
|
system_prompt=system_prompt,
|
||||||
tools=mcp_tools,
|
tools=mcp_tools,
|
||||||
middleware=middleware,
|
middleware=middleware,
|
||||||
checkpointer=checkpointer # 传入 checkpointer 以启用持久化
|
checkpointer=checkpointer # 传入 checkpointer 以启用持久化
|
||||||
)
|
)
|
||||||
|
|
||||||
# 将 handler 和 checkpointer 存储在 agent 的属性中,方便在调用时使用
|
|
||||||
agent.logging_handler = logging_handler
|
|
||||||
agent.checkpointer = checkpointer
|
|
||||||
agent.bot_id = config.bot_id
|
|
||||||
agent.session_id = config.session_id
|
|
||||||
return agent
|
return agent
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
from ast import Str
|
||||||
from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
|
from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
|
||||||
from langchain_core.messages import convert_to_openai_messages
|
from langchain_core.messages import convert_to_openai_messages
|
||||||
from agent.prompt_loader import load_guideline_prompt
|
from agent.prompt_loader import load_guideline_prompt
|
||||||
@ -9,15 +10,18 @@ from langchain_core.messages import SystemMessage
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult
|
||||||
|
from .agent_config import AgentConfig
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger('app')
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
|
|
||||||
class GuidelineMiddleware(AgentMiddleware):
|
class GuidelineMiddleware(AgentMiddleware):
|
||||||
def __init__(self, bot_id: str, model:BaseChatModel, prompt: str, robot_type: str, language: str, user_identifier: str):
|
def __init__(self, model:BaseChatModel, config:AgentConfig, prompt: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.bot_id = bot_id
|
self.bot_id = config.bot_id
|
||||||
|
|
||||||
processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)
|
processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)
|
||||||
|
|
||||||
self.processed_system_prompt = processed_system_prompt
|
self.processed_system_prompt = processed_system_prompt
|
||||||
@ -25,10 +29,10 @@ class GuidelineMiddleware(AgentMiddleware):
|
|||||||
self.tool_description = tool_description
|
self.tool_description = tool_description
|
||||||
self.scenarios = scenarios
|
self.scenarios = scenarios
|
||||||
|
|
||||||
self.language = language
|
self.language = config.language
|
||||||
self.user_identifier = user_identifier
|
self.user_identifier = config.user_identifier
|
||||||
|
|
||||||
self.robot_type = robot_type
|
self.robot_type = config.robot_type
|
||||||
self.terms_list = terms_list
|
self.terms_list = terms_list
|
||||||
|
|
||||||
if self.robot_type == "general_agent":
|
if self.robot_type == "general_agent":
|
||||||
|
|||||||
57
agent/logging_handler.py
Normal file
57
agent/logging_handler.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""日志回调处理器模块"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional, Dict
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""自定义的 CallbackHandler,使用项目的 logger 来记录日志"""
|
||||||
|
|
||||||
|
def __init__(self, logger_name: str = 'app'):
|
||||||
|
self.logger = logging.getLogger(logger_name)
|
||||||
|
|
||||||
|
def on_llm_end(self, response, **kwargs: Any) -> None:
|
||||||
|
"""当 LLM 结束时调用"""
|
||||||
|
self.logger.info("✅ LLM End - Output:")
|
||||||
|
|
||||||
|
# 打印生成的文本
|
||||||
|
if hasattr(response, 'generations') and response.generations:
|
||||||
|
for gen_idx, generation_list in enumerate(response.generations):
|
||||||
|
for msg_idx, generation in enumerate(generation_list):
|
||||||
|
if hasattr(generation, 'text'):
|
||||||
|
output_list = generation.text.split("\n")
|
||||||
|
for i, output in enumerate(output_list):
|
||||||
|
if output.strip():
|
||||||
|
self.logger.info(f"{output}")
|
||||||
|
elif hasattr(generation, 'message'):
|
||||||
|
output_list = generation.message.split("\n")
|
||||||
|
for i, output in enumerate(output_list):
|
||||||
|
if output.strip():
|
||||||
|
self.logger.info(f"{output}")
|
||||||
|
|
||||||
|
def on_llm_error(
|
||||||
|
self, error: Exception, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""当 LLM 出错时调用"""
|
||||||
|
self.logger.error(f"❌ LLM Error: {error}")
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""当工具开始调用时调用"""
|
||||||
|
if serialized is None:
|
||||||
|
tool_name = 'unknown_tool'
|
||||||
|
else:
|
||||||
|
tool_name = serialized.get('name', 'unknown_tool')
|
||||||
|
self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
|
"""当工具调用结束时调用"""
|
||||||
|
self.logger.info(f"✅ Tool End Output: {output}")
|
||||||
|
|
||||||
|
def on_tool_error(
|
||||||
|
self, error: Exception, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""当工具调用出错时调用"""
|
||||||
|
self.logger.error(f"❌ Tool Error: {error}")
|
||||||
@ -26,7 +26,7 @@ logger = logging.getLogger('app')
|
|||||||
|
|
||||||
from agent.deep_assistant import init_agent
|
from agent.deep_assistant import init_agent
|
||||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||||
from utils.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
class ShardedAgentManager:
|
class ShardedAgentManager:
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from utils.fastapi_utils import (
|
|||||||
)
|
)
|
||||||
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
|
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
|
||||||
from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
|
from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
|
||||||
from utils.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -98,14 +98,12 @@ def format_messages_to_chat_history(messages: list) -> str:
|
|||||||
|
|
||||||
async def enhanced_generate_stream_response(
|
async def enhanced_generate_stream_response(
|
||||||
agent_manager,
|
agent_manager,
|
||||||
messages: list,
|
|
||||||
config: AgentConfig
|
config: AgentConfig
|
||||||
):
|
):
|
||||||
"""增强的渐进式流式响应生成器 - 并发优化版本
|
"""增强的渐进式流式响应生成器 - 并发优化版本
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_manager: agent管理器
|
agent_manager: agent管理器
|
||||||
messages: 消息列表
|
|
||||||
config: AgentConfig 对象,包含所有参数
|
config: AgentConfig 对象,包含所有参数
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@ -116,7 +114,7 @@ async def enhanced_generate_stream_response(
|
|||||||
# Preamble 任务
|
# Preamble 任务
|
||||||
async def preamble_task():
|
async def preamble_task():
|
||||||
try:
|
try:
|
||||||
preamble_result = await call_preamble_llm(messages,config)
|
preamble_result = await call_preamble_llm(config)
|
||||||
# 只有当preamble_text不为空且不为"<empty>"时才输出
|
# 只有当preamble_text不为空且不为"<empty>"时才输出
|
||||||
if preamble_result and preamble_result.strip() and preamble_result != "<empty>":
|
if preamble_result and preamble_result.strip() and preamble_result != "<empty>":
|
||||||
preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
|
preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
|
||||||
@ -147,12 +145,7 @@ async def enhanced_generate_stream_response(
|
|||||||
chunk_id = 0
|
chunk_id = 0
|
||||||
message_tag = ""
|
message_tag = ""
|
||||||
|
|
||||||
stream_config = {}
|
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
||||||
if config.session_id:
|
|
||||||
stream_config["configurable"] = {"thread_id": config.session_id}
|
|
||||||
if hasattr(agent, 'logging_handler'):
|
|
||||||
stream_config["callbacks"] = [agent.logging_handler]
|
|
||||||
async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS):
|
|
||||||
new_content = ""
|
new_content = ""
|
||||||
|
|
||||||
if isinstance(msg, AIMessageChunk):
|
if isinstance(msg, AIMessageChunk):
|
||||||
@ -270,17 +263,14 @@ async def enhanced_generate_stream_response(
|
|||||||
|
|
||||||
|
|
||||||
async def create_agent_and_generate_response(
|
async def create_agent_and_generate_response(
|
||||||
messages: list,
|
|
||||||
config: AgentConfig
|
config: AgentConfig
|
||||||
) -> Union[ChatResponse, StreamingResponse]:
|
) -> Union[ChatResponse, StreamingResponse]:
|
||||||
"""创建agent并生成响应的公共逻辑
|
"""创建agent并生成响应的公共逻辑
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
|
||||||
config: AgentConfig 对象,包含所有参数
|
config: AgentConfig 对象,包含所有参数
|
||||||
"""
|
"""
|
||||||
config.safe_print()
|
config.safe_print()
|
||||||
logger.info(f"messages={json.dumps(messages, ensure_ascii=False)}")
|
|
||||||
config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
|
config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
|
||||||
|
|
||||||
# 如果是流式模式,使用增强的流式响应生成器
|
# 如果是流式模式,使用增强的流式响应生成器
|
||||||
@ -288,28 +278,17 @@ async def create_agent_and_generate_response(
|
|||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
enhanced_generate_stream_response(
|
enhanced_generate_stream_response(
|
||||||
agent_manager=agent_manager,
|
agent_manager=agent_manager,
|
||||||
messages=messages,
|
|
||||||
config=config
|
config=config
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
messages = config.messages
|
||||||
# 使用公共函数处理所有逻辑
|
# 使用公共函数处理所有逻辑
|
||||||
agent = await agent_manager.get_or_create_agent(config)
|
agent = await agent_manager.get_or_create_agent(config)
|
||||||
|
agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||||
# 准备最终的消息
|
append_messages = agent_responses["messages"][len(messages):]
|
||||||
final_messages = messages.copy()
|
|
||||||
|
|
||||||
# 非流式响应
|
|
||||||
agent_config = {}
|
|
||||||
if config.session_id:
|
|
||||||
agent_config["configurable"] = {"thread_id": config.session_id}
|
|
||||||
if hasattr(agent, 'logging_handler'):
|
|
||||||
agent_config["callbacks"] = [agent.logging_handler]
|
|
||||||
agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS)
|
|
||||||
append_messages = agent_responses["messages"][len(final_messages):]
|
|
||||||
response_text = ""
|
response_text = ""
|
||||||
for msg in append_messages:
|
for msg in append_messages:
|
||||||
if isinstance(msg,AIMessage):
|
if isinstance(msg,AIMessage):
|
||||||
@ -394,10 +373,9 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
|
|||||||
# 处理消息
|
# 处理消息
|
||||||
messages = process_messages(request.messages, request.language)
|
messages = process_messages(request.messages, request.language)
|
||||||
# 创建 AgentConfig 对象
|
# 创建 AgentConfig 对象
|
||||||
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg)
|
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
|
||||||
# 调用公共的agent创建和响应生成逻辑
|
# 调用公共的agent创建和响应生成逻辑
|
||||||
return await create_agent_and_generate_response(
|
return await create_agent_and_generate_response(
|
||||||
messages=messages,
|
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -471,10 +449,9 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
|||||||
# 处理消息
|
# 处理消息
|
||||||
messages = process_messages(request.messages, request.language)
|
messages = process_messages(request.messages, request.language)
|
||||||
# 创建 AgentConfig 对象
|
# 创建 AgentConfig 对象
|
||||||
config = AgentConfig.from_v2_request(request, bot_config, project_dir)
|
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
|
||||||
# 调用公共的agent创建和响应生成逻辑
|
# 调用公共的agent创建和响应生成逻辑
|
||||||
return await create_agent_and_generate_response(
|
return await create_agent_and_generate_response(
|
||||||
messages=messages,
|
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import logging
|
|||||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from utils.settings import MASTERKEY, BACKEND_HOST
|
from utils.settings import MASTERKEY, BACKEND_HOST
|
||||||
from utils.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
|
|
||||||
USER = "user"
|
USER = "user"
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
@ -561,7 +561,7 @@ def get_preamble_text(language: str, system_prompt: str):
|
|||||||
return default_preamble, system_prompt # 返回默认preamble和原始system_prompt
|
return default_preamble, system_prompt # 返回默认preamble和原始system_prompt
|
||||||
|
|
||||||
|
|
||||||
async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
|
async def call_preamble_llm(config: AgentConfig) -> str:
|
||||||
"""调用大语言模型处理guideline分析
|
"""调用大语言模型处理guideline分析
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -587,8 +587,8 @@ async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
|
|||||||
model_server = config.model_server
|
model_server = config.model_server
|
||||||
language = config.language
|
language = config.language
|
||||||
preamble_choices_text = config.preamble_text
|
preamble_choices_text = config.preamble_text
|
||||||
last_message = get_user_last_message_content(messages)
|
last_message = get_user_last_message_content(config.messages)
|
||||||
chat_history = format_messages_to_chat_history(messages)
|
chat_history = format_messages_to_chat_history(config.messages)
|
||||||
|
|
||||||
# 替换模板中的占位符
|
# 替换模板中的占位符
|
||||||
system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))
|
system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user