update guideline

parent 66b816c3b2
commit ec9558fd4c
@@ -1,10 +1,13 @@
 import json
 from langchain.chat_models import init_chat_model
-from deepagents import create_deep_agent
+# from deepagents import create_deep_agent
+from langchain.agents import create_agent
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from utils.fastapi_utils import detect_provider
 
+from .guideline_middleware import GuidelineMiddleware
 
 # Utility functions
 def read_system_prompt():
     """Read the generic, stateless system prompt"""
@@ -18,9 +21,10 @@ def read_mcp_settings():
         mcp_settings_json = json.load(f)
     return mcp_settings_json
 
-async def init_agent(model_name="qwen3-next", api_key=None,
+
+async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
                      model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None):
+                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None):
     system = system_prompt if system_prompt else read_system_prompt()
     mcp = mcp if mcp else read_mcp_settings()
     # Rewrite the mcp[0]["mcpServers"] entries: rename the `type` field to `transport`; if absent, default to transport: stdio
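The comment above describes a small normalization pass over the MCP settings. A sketch of what it presumably does (the settings shape follows read_mcp_settings above; the concrete server entry and the assumption that mcpServers is a name-to-config dict are illustrative, not from the commit):

# Hypothetical illustration of the type -> transport rename described above;
# key names follow the comment, the server entry itself is made up.
def normalize_transport(mcp: list[dict]) -> list[dict]:
    for server in mcp[0].get("mcpServers", {}).values():
        # prefer an explicit `type`, else keep any existing `transport`, else default to stdio
        server["transport"] = server.pop("type", server.get("transport", "stdio"))
    return mcp

settings = [{"mcpServers": {"rag": {"type": "sse", "url": "http://localhost:8000/sse"}}}]
print(normalize_transport(settings))  # {'rag': {'transport': 'sse', 'url': ...}}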
@@ -50,6 +54,7 @@ async def init_agent(model_name="qwen3-next", api_key=None,
     agent = create_agent(
         model=llm_instance,
         system_prompt=system,
-        tools=mcp_tools
+        tools=mcp_tools,
+        middleware=[GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
     )
     return agent
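A minimal call sketch under the new signature (bot_id is now required and flows into the middleware; the asyncio driver and argument values are illustrative, not part of the commit):

import asyncio

async def main():
    # Hypothetical driver: system prompt and MCP settings fall back to
    # read_system_prompt() / read_mcp_settings() when omitted.
    agent = await init_agent(
        bot_id="demo-bot",              # new required argument
        robot_type="general_agent",     # enables the default guidelines/tool description
        language="jp",
        user_identifier="user-001",
    )
    result = await agent.ainvoke({"messages": [{"role": "user", "content": "hello"}]})
    print(result["messages"][-1].content)

asyncio.run(main())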
agent/guideline_middleware.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
+from langchain_core.messages import convert_to_openai_messages
+from agent.prompt_loader import load_guideline_prompt
+from utils.fastapi_utils import (extract_block_from_system_prompt, format_messages_to_chat_history, get_user_last_message_content)
+from langchain.chat_models import BaseChatModel
+from langgraph.runtime import Runtime
+from typing import Any, Callable
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+import logging
+
+logger = logging.getLogger('app')
+
+
+class ThinkingCallbackHandler(BaseCallbackHandler):
+    """Custom callback handler that converts model response content into the thinking format"""
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        """Process the response after the LLM call ends, converting its content to the thinking format"""
+        logger.info("Successfully converted response content to thinking format")
+
+
+class GuidelineMiddleware(AgentMiddleware):
+    def __init__(self, bot_id: str, model: BaseChatModel, prompt: str, robot_type: str, language: str, user_identifier: str):
+        self.model = model
+        self.bot_id = bot_id
+        processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)
+
+        self.processed_system_prompt = processed_system_prompt
+        self.guidelines = guidelines
+        self.tool_description = tool_description
+        self.scenarios = scenarios
+
+        self.language = language
+        self.user_identifier = user_identifier
+
+        self.robot_type = robot_type
+        self.terms_list = terms_list
+
+        if self.robot_type == "general_agent":
+            if not self.guidelines:
+                self.guidelines = """
+1. General Inquiries
+Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
+Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
+
+2. Social Dialogue
+Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
+Action: Provide concise, friendly, and personified natural responses.
+"""
+            if not self.tool_description:
+                self.tool_description = """
+- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
+"""
+
+    def get_guideline_prompt(self, messages: list[dict[str, Any]]) -> str:
+        ## Process terms
+        terms_analysis = self.get_term_analysis(messages)
+
+        guideline_prompt = ""
+
+        if self.guidelines:
+            chat_history = format_messages_to_chat_history(messages)
+            query_text = get_user_last_message_content(messages)
+            guideline_prompt = load_guideline_prompt(chat_history, query_text, self.guidelines, self.tool_description, self.scenarios, terms_analysis, self.language, self.user_identifier)
+
+        return guideline_prompt
+
+    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
+        ## Process terms
+        terms_analysis = ""
+        if self.terms_list:
+            logger.info(f"Processing terms: {len(self.terms_list)} terms")
+            try:
+                from embedding.embedding import process_terms_with_embedding
+                query_text = get_user_last_message_content(messages)
+                terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
+                if terms_analysis:
+                    self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
+                    logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
+            except Exception as e:
+                logger.error(f"Error processing terms with embedding: {e}")
+                terms_analysis = ""
+        else:
+            # When terms_list is empty, delete the corresponding .pkl cache file
+            try:
+                import os
+                cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
+                if os.path.exists(cache_file):
+                    os.remove(cache_file)
+                    logger.info(f"Removed empty terms cache file: {cache_file}")
+            except Exception as e:
+                logger.error(f"Error removing terms cache file: {e}")
+        return terms_analysis
+
+
+    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        if not self.guidelines:
+            return None
+
+        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state['messages']))
+
+        # Create a callback-handler instance
+        thinking_handler = ThinkingCallbackHandler()
+
+        # Invoke the model with the callback handler attached
+        response = self.model.invoke(
+            guideline_prompt,
+            config={"callbacks": [thinking_handler]}
+        )
+
+        response.additional_kwargs["thinking"] = response.content
+
+        messages = state['messages'] + [response]
+        return {
+            "messages": messages
+        }
+
+    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        if not self.guidelines:
+            return None
+
+        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state['messages']))
+
+        # Invoke the model with the callback handler attached
+        response = await self.model.ainvoke(
+            guideline_prompt,
+            config={"callbacks": [ThinkingCallbackHandler()]}
+        )
+
+        response.additional_kwargs["thinking"] = response.content
+
+        messages = state['messages'] + [response]
+        return {
+            "messages": messages
+        }
+
+    def wrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        return handler(request.override(system_prompt=self.processed_system_prompt))
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        return await handler(request.override(system_prompt=self.processed_system_prompt))
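Since this whole file is new, a condensed view of the two middleware hooks it relies on may help. The method names, signatures, and the request.override() call are exactly as used above; the toy middleware itself is an assumption for illustration, not part of the commit:

from typing import Any, Callable
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest, ModelResponse
from langgraph.runtime import Runtime

class LoggingMiddleware(AgentMiddleware):
    # Runs once before the agent loop starts; returning a dict patches the state,
    # which is how GuidelineMiddleware injects its guideline message above.
    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"starting with {len(state['messages'])} messages")
        return None  # None leaves the state untouched

    # Wraps every model call; override() swaps request fields before the model
    # sees them, which is how GuidelineMiddleware substitutes its processed system prompt.
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        return handler(request.override(system_prompt="You are terse."))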
@@ -11,35 +11,6 @@ import logging
 
 logger = logging.getLogger('app')
 
-def safe_replace(text: str, placeholder: str, value: Any) -> str:
-    """
-    Safe string-replacement helper that guarantees value is converted to a string
-
-    Args:
-        text: the original text
-        placeholder: the placeholder to replace (e.g. '{user_identifier}')
-        value: the replacement value (any type)
-
-    Returns:
-        str: the text after replacement
-    """
-    if not isinstance(text, str):
-        text = str(text)
-
-    # Skip replacement when the placeholder is empty
-    if not placeholder:
-        return text
-
-    # Convert value to a string, handling None and other special cases
-    if value is None:
-        replacement = ""
-    else:
-        replacement = str(value)
-
-    return text.replace(placeholder, replacement)
-
-
 def format_datetime_by_language(language: str) -> str:
     """
     Format the current-time string according to language, computing each time zone's time from a UTC baseline
@@ -126,13 +97,12 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
     # Get the formatted datetime string
     datetime_str = format_datetime_by_language(language) if language else format_datetime_by_language('en')
 
     prompt = system_prompt or ""
-    # If a {language} placeholder exists, use system_prompt directly
     if robot_type == "general_agent" or robot_type == "catalog_agent":
         """
         Prefer the project directory's README.md; fall back to the default system_prompt_{robot_type}.md only when it is missing
         """
-        system_prompt_default = None
+        system_prompt_default = ""
 
         try:
             # Read the default prompt file through the cache
@@ -142,22 +112,17 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
             logger.info(f"Using cached default system prompt for {robot_type} from prompt folder")
         except Exception as e:
             logger.error(f"Failed to load default system prompt for {robot_type}: {str(e)}")
-            system_prompt_default = None
+            system_prompt_default = ""
 
         readme = ""
         # Only attempt to read README.md when project_dir is not None
         if project_dir is not None:
            readme_path = os.path.join(project_dir, "README.md")
            readme = await config_cache.get_text_file(readme_path) or ""
-        if system_prompt_default:
-            system_prompt_default = safe_replace(system_prompt_default, "{readme}", str(readme))
-
-        prompt = system_prompt_default or ""
-        prompt = safe_replace(prompt, "{extra_prompt}", system_prompt or "")
-
-        prompt = safe_replace(prompt, "{language}", language_display)
-        prompt = safe_replace(prompt, '{user_identifier}', user_identifier)
-        prompt = safe_replace(prompt, '{datetime}', datetime_str)
+        prompt = system_prompt_default.format(readme=str(readme), extra_prompt=system_prompt or "", language=language_display, user_identifier=user_identifier, datetime=datetime_str)
     elif system_prompt:
         prompt = system_prompt.format(language=language_display, user_identifier=user_identifier, datetime=datetime_str)
     return prompt or ""
 
 
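One behavioral note on the safe_replace to str.format swap made throughout this file: .format() fills all named fields in one pass, but raises on any brace field it is not given, whereas the removed helper only ever touched its own placeholder. A quick plain-Python illustration (not from the commit):

template = "Hello {user_identifier}, today is {datetime}"
print(template.format(user_identifier="alice", datetime="2024-01-01"))

# Any stray brace in the template now raises, which safe_replace tolerated:
try:
    'config {"key": 1} for {user_identifier}'.format(user_identifier="alice")
except (KeyError, ValueError, IndexError) as e:
    print(f"format failed: {e!r}")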
@@ -175,16 +140,15 @@ def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id:
         for key, value in obj.items():
             if key == 'args' and isinstance(value, list):
                 # Handle the args list specially
-                obj[key] = [safe_replace(safe_replace(item, '{dataset_dir}', dataset_dir), '{bot_id}', bot_id) if isinstance(item, str) else item
+                obj[key] = [item.format(dataset_dir=dataset_dir, bot_id=bot_id) if isinstance(item, str) else item
                             for item in value]
             elif isinstance(value, (dict, list)):
                 obj[key] = replace_placeholders_in_obj(value)
             elif isinstance(value, str):
-                obj[key] = safe_replace(value, '{dataset_dir}', dataset_dir)
-                obj[key] = safe_replace(obj[key], '{bot_id}', bot_id)
+                obj[key] = value.format(dataset_dir=dataset_dir, bot_id=bot_id)
     elif isinstance(obj, list):
         return [replace_placeholders_in_obj(item) if isinstance(item, (dict, list)) else
-                safe_replace(safe_replace(item, '{dataset_dir}', dataset_dir), '{bot_id}', bot_id) if isinstance(item, str) else item
+                item.format(dataset_dir=dataset_dir, bot_id=bot_id) if isinstance(item, str) else item
                 for item in obj]
     return obj
 
@@ -280,7 +244,7 @@ async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot
     return merged_settings
 
 
-async def load_guideline_prompt(chat_history: str, last_message: str, guidelines_text: str, tools: str, scenarios: str, terms: str, language: str, user_identifier: str = "") -> str:
+def load_guideline_prompt(chat_history: str, last_message: str, guidelines_text: str, tools: str, scenarios: str, terms: str, language: str, user_identifier: str = "") -> str:
     """
     Load and process the guideline prompt
 
@@ -295,16 +259,9 @@ async def load_guideline_prompt(chat_history: str, last_message: str, guidelines
     Returns:
         str: the processed guideline prompt
     """
-    try:
-        from agent.config_cache import config_cache
-        guideline_template_file = os.path.join("prompt", "guideline_prompt.md")
-        guideline_template = await config_cache.get_text_file(guideline_template_file)
-        if guideline_template is None:
-            logger.error("Failed to load guideline prompt template from cache")
-            return ""
-    except Exception as e:
-        logger.error(f"Error reading guideline prompt template: {e}")
-        return ""
+    guideline_template_file = os.path.join("prompt", "guideline_prompt.md")
+    with open(guideline_template_file, 'r', encoding='utf-8') as f:
+        guideline_template = f.read()
 
     # Get the language display text
     language_display_map = {
@@ -316,15 +273,17 @@ def load_guideline_prompt(chat_history: str, last_message: str, guidelines
     language_display = language_display_map.get(language, language if language else 'English')
     datetime_str = format_datetime_by_language(language) if language else format_datetime_by_language('en')
     # Fill in the template placeholders
-    system_prompt = safe_replace(guideline_template, '{chat_history}', chat_history)
-    system_prompt = safe_replace(system_prompt, '{last_message}', last_message)
-    system_prompt = safe_replace(system_prompt, '{guidelines_text}', guidelines_text)
-    system_prompt = safe_replace(system_prompt, '{terms}', terms)
-    system_prompt = safe_replace(system_prompt, '{tools}', tools)
-    system_prompt = safe_replace(system_prompt, '{scenarios}', scenarios)
-    system_prompt = safe_replace(system_prompt, '{language}', language_display)
-    system_prompt = safe_replace(system_prompt, '{user_identifier}', user_identifier)
-    system_prompt = safe_replace(system_prompt, '{datetime}', datetime_str)
+    system_prompt = guideline_template.format(
+        chat_history=chat_history,
+        last_message=last_message,
+        guidelines_text=guidelines_text,
+        terms=terms,
+        tools=tools,
+        scenarios=scenarios,
+        language=language_display,
+        user_identifier=user_identifier,
+        datetime=datetime_str
+    )
 
     return system_prompt
 
 
@@ -191,12 +191,16 @@ class ShardedAgentManager:
         current_time = time.time()
 
         agent = await init_agent(
+            bot_id=bot_id,
             model_name=model_name,
             api_key=api_key,
             model_server=model_server,
             generate_cfg=generate_cfg,
             system_prompt=final_system_prompt,
-            mcp=final_mcp_settings
+            mcp=final_mcp_settings,
+            robot_type=robot_type,
+            language=language,
+            user_identifier=user_identifier,
         )
 
         # Cache the instance
 
@@ -22,7 +22,6 @@
 - **Step planning**: Just list the detailed tool-invocation steps; there is no need to draft the wording of the reply to the user, and the number of steps should be adjusted to actual needs.
 
-Follow the thinking framework above to perform a complete analysis, ensuring that you understand the goal, analyze the problem, and formulate a plan,
 ---
 
 ## Chat History
 ```
@@ -48,7 +47,6 @@
 ```
 {guidelines_text}
 ```
----
 
 ## System Information
 - **Current user**: {user_identifier}
routes/chat.py (220 lines changed)
@@ -15,9 +15,9 @@ from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from agent.prompt_loader import load_guideline_prompt
 from utils.fastapi_utils import (
-    process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
+    process_messages, format_messages_to_chat_history,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    process_guideline, call_preamble_llm, get_preamble_text, get_language_text,
+    call_preamble_llm, get_preamble_text, get_user_last_message_content,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
@@ -31,14 +31,7 @@ agent_manager = init_global_sharded_agent_manager(
 )
 
 
-def get_user_last_message_content(messages: list) -> Optional[dict]:
-    """Get the last message in the messages list"""
-    if not messages or len(messages) == 0:
-        return ""
-    last_message = messages[-1]
-    if last_message and last_message.get('role') == 'user':
-        return last_message["content"]
-    return ""
-
-
 def append_user_last_message(messages: list, content: str) -> bool:
     """Append content to the last user message
@@ -78,128 +71,6 @@ def append_assistant_last_message(messages: list, content: str) -> bool:
     messages.append({"role":"assistant","content":content})
     return messages
 
 
-async def process_guidelines_and_terms(
-    bot_id: str,
-    api_key: str,
-    model_name: str,
-    model_server: str,
-    system_prompt: str,
-    messages: list,
-    agent_manager,
-    project_dir: Optional[str],
-    generate_cfg: Optional[dict],
-    language: str,
-    mcp_settings: Optional[list],
-    robot_type: str,
-    user_identifier: Optional[str]
-) -> tuple:
-    """
-    Shared helper: run the guideline analysis and terms processing, returning the agent and analysis results
-
-    Returns:
-        tuple: (agent, processed_system_prompt, guideline_reasoning, terms_analysis)
-    """
-    # Extract the guidelines and terms from the system_prompt
-    processed_system_prompt, guidelines, tools, scenarios, terms_list = extract_block_from_system_prompt(system_prompt)
-
-    # Process terms
-    terms_analysis = ""
-    if terms_list:
-        logger.info(f"Processing terms: {len(terms_list)} terms")
-        try:
-            from embedding.embedding import process_terms_with_embedding
-            query_text = get_user_last_message_content(messages)
-            terms_analysis = process_terms_with_embedding(terms_list, bot_id, query_text)
-            if terms_analysis:
-                processed_system_prompt = processed_system_prompt.replace("{terms}", terms_analysis)
-                logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
-        except Exception as e:
-            logger.error(f"Error processing terms with embedding: {e}")
-            terms_analysis = ""
-    else:
-        # When terms_list is empty, delete the corresponding .pkl cache file
-        try:
-            import os
-            cache_file = f"projects/cache/{bot_id}_terms.pkl"
-            if os.path.exists(cache_file):
-                os.remove(cache_file)
-                logger.info(f"Removed empty terms cache file: {cache_file}")
-        except Exception as e:
-            logger.error(f"Error removing terms cache file: {e}")
-
-    # Create all tasks
-    tasks = []
-
-    # Add the agent-creation task
-    agent_task = agent_manager.get_or_create_agent(
-        bot_id=bot_id,
-        project_dir=project_dir,
-        model_name=model_name,
-        api_key=api_key,
-        model_server=model_server,
-        generate_cfg=generate_cfg,
-        language=language,
-        system_prompt=processed_system_prompt,
-        mcp_settings=mcp_settings,
-        robot_type=robot_type,
-        user_identifier=user_identifier
-    )
-    tasks.append(agent_task)
-
-    guideline_prompt = ""
-
-    if robot_type == "general_agent":
-        if not guidelines:
-            guidelines = """
-1. General Inquiries
-Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
-Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
-
-2. Social Dialogue
-Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
-Action: Provide concise, friendly, and personified natural responses.
-"""
-        if not tools:
-            tools = """
-- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
-"""
-    if guidelines:
-        chat_history = format_messages_to_chat_history(messages)
-        query_text = get_user_last_message_content(messages)
-        guideline_prompt = await load_guideline_prompt(chat_history, query_text, guidelines, tools, scenarios, terms_analysis, language, user_identifier)
-        guideline_task = process_guideline(
-            chat_history=chat_history,
-            guideline_prompt=guideline_prompt,
-            model_name=model_name,
-            api_key=api_key,
-            model_server=model_server
-        )
-        tasks.append(guideline_task)
-
-    # Run all tasks concurrently
-    all_results = await asyncio.gather(*tasks, return_exceptions=True)
-
-    # Process the results
-    agent = all_results[0] if len(all_results) > 0 else None  # result of agent creation
-
-    # Check whether the agent is an exception object
-    if isinstance(agent, Exception):
-        logger.error(f"Error creating agent: {agent}")
-        raise agent
-
-    guideline_reasoning = all_results[1] if len(all_results) > 1 else ""
-    if isinstance(guideline_reasoning, Exception):
-        logger.error(f"Error in guideline processing: {guideline_reasoning}")
-        guideline_reasoning = ""
-    if guideline_prompt or guideline_reasoning:
-        logger.info("Guideline Prompt: %s, Reasoning: %s",
-                    guideline_prompt.replace('\n', '\\n') if guideline_prompt else "None",
-                    guideline_reasoning.replace('\n', '\\n') if guideline_reasoning else "None")
-    logger.info("System Prompt: %s", processed_system_prompt.replace('\n', '\\n'))
-    return agent, processed_system_prompt, guideline_reasoning
-
-
 async def enhanced_generate_stream_response(
     agent_manager,
     bot_id: str,
@@ -224,32 +95,10 @@ async def enhanced_generate_stream_response(
 
     # Create the preamble_text generation task
     preamble_text, system_prompt = get_preamble_text(language, system_prompt)
-    preamble_task = asyncio.create_task(
-        call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
-    )
-
-    # Create the guideline-analysis and agent-creation tasks
-    guidelines_task = asyncio.create_task(
-        process_guidelines_and_terms(
-            bot_id=bot_id,
-            api_key=api_key,
-            model_name=model_name,
-            model_server=model_server,
-            system_prompt=system_prompt,
-            messages=messages,
-            agent_manager=agent_manager,
-            project_dir=project_dir,
-            generate_cfg=generate_cfg,
-            language=language,
-            mcp_settings=mcp_settings,
-            robot_type=robot_type,
-            user_identifier=user_identifier
-        )
-    )
-
-    # Wait for the preamble_text task to finish
     try:
-        preamble_text = await preamble_task
+        preamble_text = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
         # Emit only when preamble_text is non-empty and not "<empty>"
         if preamble_text and preamble_text.strip() and preamble_text != "<empty>":
            preamble_content = f"[PREAMBLE]\n{preamble_text}\n"
@@ -262,24 +111,19 @@ async def enhanced_generate_stream_response(
             logger.error(f"Error generating preamble text: {e}")
 
-    # Wait for the guideline-analysis task to finish
-    agent, system_prompt, guideline_reasoning = await guidelines_task
-
-    # Send the guideline_reasoning immediately
-    if guideline_reasoning:
-        guideline_content = f"[THINK]\n{guideline_reasoning}\n"
-        chunk_data = create_stream_chunk(f"chatcmpl-guideline", model_name, guideline_content)
-        yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
-
-    # Prepare the final messages
-    final_messages = messages.copy()
-    if guideline_reasoning:
-        # Split guideline_reasoning on ### and take the last segment as the Guidelines
-        guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
-        final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
-    else:
-        final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
-
-    logger.debug(f"Final messages for agent (showing first 2): {final_messages[:2]}")
+    agent = await agent_manager.get_or_create_agent(
+        bot_id=bot_id,
+        project_dir=project_dir,
+        model_name=model_name,
+        api_key=api_key,
+        model_server=model_server,
+        generate_cfg=generate_cfg,
+        language=language,
+        system_prompt=system_prompt,
+        mcp_settings=mcp_settings,
+        robot_type=robot_type,
+        user_identifier=user_identifier
+    )
 
     # Phase 3: stream the agent response
     logger.info(f"Starting agent stream response")
@@ -287,7 +131,7 @@ async def enhanced_generate_stream_response(
     message_tag = ""
     function_name = ""
     tool_args = ""
-    async for msg, metadata in agent.astream({"messages": final_messages}, stream_mode="messages"):
+    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages"):
         new_content = ""
         if isinstance(msg, AIMessageChunk):
             # Check whether there is a tool call
@@ -314,6 +158,8 @@ async def enhanced_generate_stream_response(
         elif isinstance(msg, ToolMessage) and len(msg.content) > 0:
             message_tag = "TOOL_RESPONSE"
             new_content = f"[{message_tag}] {msg.name}\n{msg.text}"
+        elif isinstance(msg, AIMessage) and msg.additional_kwargs and "thinking" in msg.additional_kwargs:
+            new_content = "[THINK]\n" + msg.additional_kwargs["thinking"] + "\n"
 
         # Send a chunk only when there is new content
         if new_content:
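For orientation, a minimal standalone consumer of the same streaming interface (the agent.astream call shape and stream_mode="messages" are taken from the loop above; the helper itself is illustrative, not part of the commit):

from langchain_core.messages import AIMessageChunk

# Illustrative only: gather the plain-text deltas from an agent returned by
# create_agent(...), ignoring tool traffic and thinking frames.
async def collect_answer(agent, messages: list[dict]) -> str:
    answer = ""
    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages"):
        if isinstance(msg, AIMessageChunk) and isinstance(msg.content, str):
            answer += msg.content
    return answer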
@@ -388,20 +234,17 @@ async def create_agent_and_generate_response(
         headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
     )
 
 
+    _, system_prompt = get_preamble_text(language, system_prompt)
-    # Handle all of the logic through the shared helper
-    agent, system_prompt, guideline_reasoning = await process_guidelines_and_terms(
+    agent = await agent_manager.get_or_create_agent(
         bot_id=bot_id,
-        api_key=api_key,
-        model_name=model_name,
-        model_server=model_server,
-        system_prompt=system_prompt,
-        messages=messages,
-        agent_manager=agent_manager,
         project_dir=project_dir,
+        model_name=model_name,
+        api_key=api_key,
+        model_server=model_server,
         generate_cfg=generate_cfg,
         language=language,
+        system_prompt=system_prompt,
         mcp_settings=mcp_settings,
         robot_type=robot_type,
         user_identifier=user_identifier
@@ -409,23 +252,16 @@ async def create_agent_and_generate_response(
 
     # Prepare the final messages
     final_messages = messages.copy()
-    if guideline_reasoning:
-        # Split guideline_reasoning on ### and take the last segment as the Guidelines
-        guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
-        final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
-    else:
-        final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
 
     # Non-streaming response
     agent_responses = await agent.ainvoke({"messages": final_messages})
     append_messages = agent_responses["messages"][len(final_messages):]
     # agent_responses = agent.run_nonstream(final_messages)
     response_text = ""
-    if guideline_reasoning:
-        response_text += "[THINK]\n" + guideline_reasoning + "\n"
     for msg in append_messages:
         if isinstance(msg, AIMessage):
-            if len(msg.text) > 0:
+            if msg.additional_kwargs and "thinking" in msg.additional_kwargs:
+                response_text += "[THINK]\n" + msg.additional_kwargs["thinking"] + "\n"
+            elif len(msg.text) > 0:
                 response_text += "[ANSWER]\n" + msg.text + "\n"
             if len(msg.tool_calls) > 0:
                 response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool['args']) if isinstance(tool['args'], dict) else tool['args']}\n" for tool in msg.tool_calls])
 
@@ -320,7 +320,14 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
     return final_messages
 
 
+def get_user_last_message_content(messages: list) -> str:
+    """Return the content of the last message when it is a user message, else an empty string"""
+    if not messages or len(messages) == 0:
+        return ""
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        return last_message["content"]
+    return ""
 
 def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
     """Format messages into a plain-text chat history
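The helper's contract at a glance (it yields the content only when the trailing message is from the user; a plain-Python check, not part of the commit):

assert get_user_last_message_content([{"role": "user", "content": "hi"}]) == "hi"
assert get_user_last_message_content([{"role": "assistant", "content": "yo"}]) == ""
assert get_user_last_message_content([]) == ""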
@@ -351,7 +358,6 @@ def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
             chat_history.append(f"{function_name} call: {arguments}")
 
     recent_chat_history = chat_history[-15:] if len(chat_history) > 15 else chat_history
-    print(f"recent_chat_history:{recent_chat_history}")
     return "\n".join(recent_chat_history)
 
 
@@ -660,30 +666,6 @@ def _get_optimal_batch_size(guidelines_count: int) -> int:
     else:
         return 5
 
 
-async def process_guideline(
-    chat_history: str,
-    guideline_prompt: str,
-    model_name: str,
-    api_key: str,
-    model_server: str
-) -> str:
-    """Process a single guideline batch"""
-    max_retries = 3
-
-    for attempt in range(max_retries):
-        try:
-            logger.info(f"Start processing guideline batch, attempt {attempt + 1}")
-            return await call_guideline_llm(chat_history, guideline_prompt, model_name, api_key, model_server)
-        except Exception as e:
-            logger.error(f"Error processing guideline batch on attempt {attempt + 1}: {e}")
-            if attempt == max_retries - 1:
-                return ""  # final attempt failed; return an empty string
-
-    # Should be unreachable, kept for completeness
-    return ""
 
 
 def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
     """
     Extract the guideline and terms content from the system prompt