update guideline
This commit is contained in:
parent 66b816c3b2 · commit ec9558fd4c
@@ -1,10 +1,13 @@
 import json
 from langchain.chat_models import init_chat_model
-from deepagents import create_deep_agent
+# from deepagents import create_deep_agent
 from langchain.agents import create_agent
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from utils.fastapi_utils import detect_provider
 
+from .guideline_middleware import GuidelineMiddleware
+
+
 # Utility functions
 def read_system_prompt():
     """Read the generic, stateless system prompt"""
@@ -18,9 +21,10 @@ def read_mcp_settings():
         mcp_settings_json = json.load(f)
     return mcp_settings_json
 
-async def init_agent(model_name="qwen3-next", api_key=None,
+
+async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
                      model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None):
+                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None):
     system = system_prompt if system_prompt else read_system_prompt()
     mcp = mcp if mcp else read_mcp_settings()
     # Rewrite the mcp[0]["mcpServers"] entries: rename each "type" field to "transport", defaulting to transport: stdio when missing
@@ -50,6 +54,7 @@ async def init_agent(model_name="qwen3-next", api_key=None,
     agent = create_agent(
         model=llm_instance,
         system_prompt=system,
-        tools=mcp_tools
+        tools=mcp_tools,
+        middleware=[GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
     )
     return agent
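Note for reviewers: with the middleware wired in, every caller now has to pass `bot_id` and may pass the new routing fields. A minimal usage sketch; only `init_agent`'s new signature comes from this diff, and the import path below is a placeholder since the patched file's name is not shown here:

```python
# Hypothetical usage sketch; the import path is assumed, not confirmed by the commit.
import asyncio

from agent import init_agent  # placeholder import path

async def main():
    agent = await init_agent(
        bot_id="demo-bot",           # new first parameter; also keys the per-bot terms cache
        model_name="qwen3-next",
        robot_type="general_agent",  # activates the built-in default guidelines below
        language="jp",
        user_identifier="user-001",
    )
    result = await agent.ainvoke({"messages": [{"role": "user", "content": "hello"}]})
    print(result["messages"][-1].content)

asyncio.run(main())
```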
agent/guideline_middleware.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
+from langchain_core.messages import convert_to_openai_messages
+from agent.prompt_loader import load_guideline_prompt
+from utils.fastapi_utils import (extract_block_from_system_prompt, format_messages_to_chat_history, get_user_last_message_content)
+from langchain.chat_models import BaseChatModel
+from langgraph.runtime import Runtime
+from typing import Any, Callable
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+import logging
+
+logger = logging.getLogger('app')
+
+
+class ThinkingCallbackHandler(BaseCallbackHandler):
+    """Custom callback handler that converts model response content into the "thinking" format"""
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        """Handle the response after the LLM call ends, converting the content to the thinking format"""
+        logger.info("Successfully converted response content to thinking format")
+
+
+class GuidelineMiddleware(AgentMiddleware):
+    def __init__(self, bot_id: str, model: BaseChatModel, prompt: str, robot_type: str, language: str, user_identifier: str):
+        self.model = model
+        self.bot_id = bot_id
+        processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)
+
+        self.processed_system_prompt = processed_system_prompt
+        self.guidelines = guidelines
+        self.tool_description = tool_description
+        self.scenarios = scenarios
+
+        self.language = language
+        self.user_identifier = user_identifier
+
+        self.robot_type = robot_type
+        self.terms_list = terms_list
+
+        if self.robot_type == "general_agent":
+            if not self.guidelines:
+                self.guidelines = """
+1. General Inquiries
+Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
+Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
+
+2. Social Dialogue
+Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
+Action: Provide concise, friendly, and personified natural responses.
+"""
+            if not self.tool_description:
+                self.tool_description = """
+- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
+"""
+
+    def get_guideline_prompt(self, messages: list[dict[str, Any]]) -> str:
+        ## Handle terms
+        terms_analysis = self.get_term_analysis(messages)
+
+        guideline_prompt = ""
+
+        if self.guidelines:
+            chat_history = format_messages_to_chat_history(messages)
+            query_text = get_user_last_message_content(messages)
+            guideline_prompt = load_guideline_prompt(chat_history, query_text, self.guidelines, self.tool_description, self.scenarios, terms_analysis, self.language, self.user_identifier)
+
+        return guideline_prompt
+
+    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
+        ## Handle terms
+        terms_analysis = ""
+        if self.terms_list:
+            logger.info(f"Processing terms: {len(self.terms_list)} terms")
+            try:
+                from embedding.embedding import process_terms_with_embedding
+                query_text = get_user_last_message_content(messages)
+                terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
+                if terms_analysis:
+                    self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
+                    logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
+            except Exception as e:
+                logger.error(f"Error processing terms with embedding: {e}")
+                terms_analysis = ""
+        else:
+            # When terms_list is empty, delete the corresponding pkl cache file
+            try:
+                import os
+                cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
+                if os.path.exists(cache_file):
+                    os.remove(cache_file)
+                    logger.info(f"Removed empty terms cache file: {cache_file}")
+            except Exception as e:
+                logger.error(f"Error removing terms cache file: {e}")
+        return terms_analysis
+
+    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        if not self.guidelines:
+            return None
+
+        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state['messages']))
+
+        # Create the callback handler instance
+        thinking_handler = ThinkingCallbackHandler()
+
+        # Invoke the model with the callback handler
+        response = self.model.invoke(
+            guideline_prompt,
+            config={"callbacks": [thinking_handler]}
+        )
+
+        response.additional_kwargs["thinking"] = response.content
+
+        messages = state['messages'] + [response]
+        return {
+            "messages": messages
+        }
+
+    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        if not self.guidelines:
+            return None
+
+        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state['messages']))
+
+        # Invoke the model with the callback handler
+        response = await self.model.ainvoke(
+            guideline_prompt,
+            config={"callbacks": [ThinkingCallbackHandler()]}
+        )
+
+        response.additional_kwargs["thinking"] = response.content
+
+        messages = state['messages'] + [response]
+        return {
+            "messages": messages
+        }
+
+    def wrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        return handler(request.override(system_prompt=self.processed_system_prompt))
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        return await handler(request.override(system_prompt=self.processed_system_prompt))
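The key contract in this new file is the `additional_kwargs["thinking"]` hand-off: `before_agent` stores the guideline reasoning on the injected message, and `routes/chat.py` (further down in this commit) renders it as a `[THINK]` block. A self-contained sketch of that hand-off; the `SimpleNamespace` stand-in for `AIMessage` is illustrative only:

```python
from types import SimpleNamespace

# Stand-in for the AIMessage returned by self.model.invoke(...) above
response = SimpleNamespace(
    content="1. Match guideline 'General Inquiries' -> call rag_retrieve",
    additional_kwargs={},
)
response.additional_kwargs["thinking"] = response.content  # as done in before_agent

# Downstream, the chat route turns it into a THINK chunk:
if "thinking" in response.additional_kwargs:
    print("[THINK]\n" + response.additional_kwargs["thinking"] + "\n")
```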
@@ -11,35 +11,6 @@ import logging
 
 logger = logging.getLogger('app')
 
-
-def safe_replace(text: str, placeholder: str, value: Any) -> str:
-    """
-    Safe string-replacement helper that ensures value is converted to a string
-
-    Args:
-        text: the original text
-        placeholder: the placeholder to replace (e.g. '{user_identifier}')
-        value: the value to substitute (any type)
-
-    Returns:
-        str: the text after replacement
-    """
-    if not isinstance(text, str):
-        text = str(text)
-
-    # If the placeholder is empty, skip replacement
-    if not placeholder:
-        return text
-
-    # Convert value to a string, handling None and other special cases
-    if value is None:
-        replacement = ""
-    else:
-        replacement = str(value)
-
-    return text.replace(placeholder, replacement)
-
-
 def format_datetime_by_language(language: str) -> str:
     """
     Format the current-time string by language, computing each timezone's time from UTC
@@ -126,13 +97,12 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
     # Get the formatted datetime string
     datetime_str = format_datetime_by_language(language) if language else format_datetime_by_language('en')
 
-    prompt = system_prompt or ""
+    prompt = ""
     # If the {language} placeholder is present, use system_prompt directly
     if robot_type == "general_agent" or robot_type == "catalog_agent":
         """
        Prefer the project directory's README.md; only fall back to the default system_prompt_{robot_type}.md when absent
         """
-        system_prompt_default = None
+        system_prompt_default = ""
 
         try:
             # Read the default prompt file through the cache
@@ -142,22 +112,17 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
             logger.info(f"Using cached default system prompt for {robot_type} from prompt folder")
         except Exception as e:
             logger.error(f"Failed to load default system prompt for {robot_type}: {str(e)}")
-            system_prompt_default = None
+            system_prompt_default = ""
 
         readme = ""
         # Only attempt to read README.md when project_dir is not None
         if project_dir is not None:
             readme_path = os.path.join(project_dir, "README.md")
             readme = await config_cache.get_text_file(readme_path) or ""
-        if system_prompt_default:
-            system_prompt_default = safe_replace(system_prompt_default, "{readme}", str(readme))
 
-        prompt = system_prompt_default or ""
-        prompt = safe_replace(prompt, "{extra_prompt}", system_prompt or "")
-        prompt = safe_replace(prompt, "{language}", language_display)
-        prompt = safe_replace(prompt, '{user_identifier}', user_identifier)
-        prompt = safe_replace(prompt, '{datetime}', datetime_str)
+        prompt = system_prompt_default.format(readme=str(readme), extra_prompt=system_prompt or "", language=language_display, user_identifier=user_identifier, datetime=datetime_str)
+    elif system_prompt:
+        prompt = system_prompt.format(language=language_display, user_identifier=user_identifier, datetime=datetime_str)
     return prompt or ""
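One caveat with this `safe_replace` → `str.format` migration (an observation, not something the diff addresses): `str.format` raises on unknown or unescaped braces, while the removed helper silently left unmatched text alone, so prompt templates must now escape literal braces as `{{ }}`:

```python
template = 'Hello {user_identifier}, sample config: {{"lang": "jp"}}'
print(template.format(user_identifier="user-001", language="jp"))  # unused kwargs are fine
# -> Hello user-001, sample config: {"lang": "jp"}

try:
    "stray { brace".format()  # unescaped literal brace in the template
except ValueError as e:
    print(f"format error: {e}")  # Single '{' encountered in format string
```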
@@ -175,16 +140,15 @@ def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id:
         for key, value in obj.items():
             if key == 'args' and isinstance(value, list):
                 # Special-case the args list
-                obj[key] = [safe_replace(safe_replace(item, '{dataset_dir}', dataset_dir), '{bot_id}', bot_id) if isinstance(item, str) else item
+                obj[key] = [item.format(dataset_dir=dataset_dir, bot_id=bot_id) if isinstance(item, str) else item
                             for item in value]
             elif isinstance(value, (dict, list)):
                 obj[key] = replace_placeholders_in_obj(value)
             elif isinstance(value, str):
-                obj[key] = safe_replace(value, '{dataset_dir}', dataset_dir)
-                obj[key] = safe_replace(obj[key], '{bot_id}', bot_id)
+                obj[key] = value.format(dataset_dir=dataset_dir, bot_id=bot_id)
     elif isinstance(obj, list):
         return [replace_placeholders_in_obj(item) if isinstance(item, (dict, list)) else
-                safe_replace(safe_replace(item, '{dataset_dir}', dataset_dir), '{bot_id}', bot_id) if isinstance(item, str) else item
+                item.format(dataset_dir=dataset_dir, bot_id=bot_id) if isinstance(item, str) else item
                 for item in obj]
     return obj
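For the record, the input/output shape of the new format-based substitution (a standalone mini-version; the real `replace_placeholders_in_obj` above also recurses into nested dicts and lists):

```python
def fill(item: str, dataset_dir: str, bot_id: str) -> str:
    # Same per-string substitution the args branch above performs
    return item.format(dataset_dir=dataset_dir, bot_id=bot_id)

args = ["--data", "{dataset_dir}/index", "--bot", "{bot_id}", 42]
print([fill(a, "projects/demo", "demo-bot") if isinstance(a, str) else a for a in args])
# -> ['--data', 'projects/demo/index', '--bot', 'demo-bot', 42]
```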
@@ -280,7 +244,7 @@ async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot
     return merged_settings
 
 
-async def load_guideline_prompt(chat_history: str, last_message: str, guidelines_text: str, tools: str, scenarios: str, terms: str, language: str, user_identifier: str = "") -> str:
+def load_guideline_prompt(chat_history: str, last_message: str, guidelines_text: str, tools: str, scenarios: str, terms: str, language: str, user_identifier: str = "") -> str:
     """
     Load and process the guideline prompt
 
@@ -295,16 +259,9 @@ async def load_guideline_prompt(chat_history: str, last_message: str, guidelines
     Returns:
         str: the processed guideline prompt
     """
-    try:
-        from agent.config_cache import config_cache
-        guideline_template_file = os.path.join("prompt", "guideline_prompt.md")
-        guideline_template = await config_cache.get_text_file(guideline_template_file)
-        if guideline_template is None:
-            logger.error("Failed to load guideline prompt template from cache")
-            return ""
-    except Exception as e:
-        logger.error(f"Error reading guideline prompt template: {e}")
-        return ""
+    guideline_template_file = os.path.join("prompt", "guideline_prompt.md")
+    with open(guideline_template_file, 'r', encoding='utf-8') as f:
+        guideline_template = f.read()
 
     # Get the display text for the language
     language_display_map = {
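Since the template is now read from disk synchronously on every call (the async `config_cache` path was dropped so the function can run inside the middleware), a tiny cache could win back the old behavior. A hypothetical helper, not part of this commit:

```python
from functools import lru_cache
import os

@lru_cache(maxsize=1)
def _read_guideline_template() -> str:
    # Hypothetical: cache the prompt/guideline_prompt.md read performed above
    with open(os.path.join("prompt", "guideline_prompt.md"), "r", encoding="utf-8") as f:
        return f.read()
```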
@@ -316,15 +273,17 @@ async def load_guideline_prompt(chat_history: str, last_message: str, guidelines
     language_display = language_display_map.get(language, language if language else 'English')
     datetime_str = format_datetime_by_language(language) if language else format_datetime_by_language('en')
     # Replace the placeholders in the template
-    system_prompt = safe_replace(guideline_template, '{chat_history}', chat_history)
-    system_prompt = safe_replace(system_prompt, '{last_message}', last_message)
-    system_prompt = safe_replace(system_prompt, '{guidelines_text}', guidelines_text)
-    system_prompt = safe_replace(system_prompt, '{terms}', terms)
-    system_prompt = safe_replace(system_prompt, '{tools}', tools)
-    system_prompt = safe_replace(system_prompt, '{scenarios}', scenarios)
-    system_prompt = safe_replace(system_prompt, '{language}', language_display)
-    system_prompt = safe_replace(system_prompt, '{user_identifier}', user_identifier)
-    system_prompt = safe_replace(system_prompt, '{datetime}', datetime_str)
+    system_prompt = guideline_template.format(
+        chat_history=chat_history,
+        last_message=last_message,
+        guidelines_text=guidelines_text,
+        terms=terms,
+        tools=tools,
+        scenarios=scenarios,
+        language=language_display,
+        user_identifier=user_identifier,
+        datetime=datetime_str
+    )
 
     return system_prompt
@@ -191,12 +191,16 @@ class ShardedAgentManager:
         current_time = time.time()
 
         agent = await init_agent(
+            bot_id=bot_id,
             model_name=model_name,
             api_key=api_key,
             model_server=model_server,
             generate_cfg=generate_cfg,
             system_prompt=final_system_prompt,
-            mcp=final_mcp_settings
+            mcp=final_mcp_settings,
+            robot_type=robot_type,
+            language=language,
+            user_identifier=user_identifier,
         )
 
         # Cache the instance
@@ -22,7 +22,6 @@
 - **Step planning**: Just list the detailed tool-call steps; there is no need to draft the wording of the reply to the user, and the number of steps can be adjusted to actual needs.
 
 Follow the thinking framework above to produce a complete analysis, making sure to understand the goal, analyze the problem, and draw up a plan.
----
 
 ## Chat History
 ```
@@ -48,7 +47,6 @@
 ```
 {guidelines_text}
 ```
----
 
 ## System Information
 - **Current user**: {user_identifier}
routes/chat.py (220 lines changed)
@@ -15,9 +15,9 @@ from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from agent.prompt_loader import load_guideline_prompt
 from utils.fastapi_utils import (
-    process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
+    process_messages, format_messages_to_chat_history,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    process_guideline, call_preamble_llm, get_preamble_text, get_language_text,
+    call_preamble_llm, get_preamble_text, get_user_last_message_content,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
@@ -31,14 +31,7 @@ agent_manager = init_global_sharded_agent_manager(
 )
 
 
-def get_user_last_message_content(messages: list) -> Optional[dict]:
-    """Return the content of the last user message in the list"""
-    if not messages or len(messages) == 0:
-        return ""
-    last_message = messages[-1]
-    if last_message and last_message.get('role') == 'user':
-        return last_message["content"]
-    return ""
-
 def append_user_last_message(messages: list, content: str) -> bool:
     """Append content to the last user message
@@ -78,128 +71,6 @@ def append_assistant_last_message(messages: list, content: str) -> bool:
     messages.append({"role":"assistant","content":content})
     return messages
 
-
-async def process_guidelines_and_terms(
-    bot_id: str,
-    api_key: str,
-    model_name: str,
-    model_server: str,
-    system_prompt: str,
-    messages: list,
-    agent_manager,
-    project_dir: Optional[str],
-    generate_cfg: Optional[dict],
-    language: str,
-    mcp_settings: Optional[list],
-    robot_type: str,
-    user_identifier: Optional[str]
-) -> tuple:
-    """
-    Shared helper: run guideline analysis and terms processing, returning the agent and analysis results
-
-    Returns:
-        tuple: (agent, processed_system_prompt, guideline_reasoning, terms_analysis)
-    """
-    # Extract the guideline and terms from system_prompt
-    processed_system_prompt, guidelines, tools, scenarios, terms_list = extract_block_from_system_prompt(system_prompt)
-
-    # Handle terms
-    terms_analysis = ""
-    if terms_list:
-        logger.info(f"Processing terms: {len(terms_list)} terms")
-        try:
-            from embedding.embedding import process_terms_with_embedding
-            query_text = get_user_last_message_content(messages)
-            terms_analysis = process_terms_with_embedding(terms_list, bot_id, query_text)
-            if terms_analysis:
-                processed_system_prompt = processed_system_prompt.replace("{terms}", terms_analysis)
-                logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
-        except Exception as e:
-            logger.error(f"Error processing terms with embedding: {e}")
-            terms_analysis = ""
-    else:
-        # When terms_list is empty, delete the corresponding pkl cache file
-        try:
-            import os
-            cache_file = f"projects/cache/{bot_id}_terms.pkl"
-            if os.path.exists(cache_file):
-                os.remove(cache_file)
-                logger.info(f"Removed empty terms cache file: {cache_file}")
-        except Exception as e:
-            logger.error(f"Error removing terms cache file: {e}")
-
-    # Create all the tasks
-    tasks = []
-
-    # Add the agent-creation task
-    agent_task = agent_manager.get_or_create_agent(
-        bot_id=bot_id,
-        project_dir=project_dir,
-        model_name=model_name,
-        api_key=api_key,
-        model_server=model_server,
-        generate_cfg=generate_cfg,
-        language=language,
-        system_prompt=processed_system_prompt,
-        mcp_settings=mcp_settings,
-        robot_type=robot_type,
-        user_identifier=user_identifier
-    )
-    tasks.append(agent_task)
-
-    guideline_prompt = ""
-
-    if robot_type == "general_agent":
-        if not guidelines:
-            guidelines = """
-1. General Inquiries
-Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
-Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
-
-2. Social Dialogue
-Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
-Action: Provide concise, friendly, and personified natural responses.
-"""
-        if not tools:
-            tools = """
-- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
-"""
-    if guidelines:
-        chat_history = format_messages_to_chat_history(messages)
-        query_text = get_user_last_message_content(messages)
-        guideline_prompt = await load_guideline_prompt(chat_history, query_text, guidelines, tools, scenarios, terms_analysis, language, user_identifier)
-        guideline_task = process_guideline(
-            chat_history=chat_history,
-            guideline_prompt=guideline_prompt,
-            model_name=model_name,
-            api_key=api_key,
-            model_server=model_server
-        )
-        tasks.append(guideline_task)
-
-    # Run all tasks concurrently
-    all_results = await asyncio.gather(*tasks, return_exceptions=True)
-
-    # Handle the results
-    agent = all_results[0] if len(all_results) > 0 else None  # result of agent creation
-
-    # Check whether agent is an exception object
-    if isinstance(agent, Exception):
-        logger.error(f"Error creating agent: {agent}")
-        raise agent
-
-    guideline_reasoning = all_results[1] if len(all_results) > 1 else ""
-    if isinstance(guideline_reasoning, Exception):
-        logger.error(f"Error in guideline processing: {guideline_reasoning}")
-        guideline_reasoning = ""
-    if guideline_prompt or guideline_reasoning:
-        logger.info("Guideline Prompt: %s, Reasoning: %s",
-                    guideline_prompt.replace('\n', '\\n') if guideline_prompt else "None",
-                    guideline_reasoning.replace('\n', '\\n') if guideline_reasoning else "None")
-    logger.info("System Prompt: %s", processed_system_prompt.replace('\n', '\\n'))
-    return agent, processed_system_prompt, guideline_reasoning
-
-
 async def enhanced_generate_stream_response(
     agent_manager,
     bot_id: str,
@@ -224,32 +95,10 @@ async def enhanced_generate_stream_response(
 
     # Create the preamble_text generation task
     preamble_text, system_prompt = get_preamble_text(language, system_prompt)
-    preamble_task = asyncio.create_task(
-        call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
-    )
-
-    # Create the guideline-analysis and agent-creation tasks
-    guidelines_task = asyncio.create_task(
-        process_guidelines_and_terms(
-            bot_id=bot_id,
-            api_key=api_key,
-            model_name=model_name,
-            model_server=model_server,
-            system_prompt=system_prompt,
-            messages=messages,
-            agent_manager=agent_manager,
-            project_dir=project_dir,
-            generate_cfg=generate_cfg,
-            language=language,
-            mcp_settings=mcp_settings,
-            robot_type=robot_type,
-            user_identifier=user_identifier
-        )
-    )
-
     # Wait for the preamble_text task to finish
     try:
-        preamble_text = await preamble_task
+        preamble_text = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server)
         # Only emit when preamble_text is non-empty and not "<empty>"
         if preamble_text and preamble_text.strip() and preamble_text != "<empty>":
             preamble_content = f"[PREAMBLE]\n{preamble_text}\n"
@@ -262,24 +111,19 @@ async def enhanced_generate_stream_response(
         logger.error(f"Error generating preamble text: {e}")
 
     # Wait for the guideline analysis task to finish
-    agent, system_prompt, guideline_reasoning = await guidelines_task
-
-    # Send guideline_reasoning immediately
-    if guideline_reasoning:
-        guideline_content = f"[THINK]\n{guideline_reasoning}\n"
-        chunk_data = create_stream_chunk(f"chatcmpl-guideline", model_name, guideline_content)
-        yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
-
-    # Prepare the final messages
-    final_messages = messages.copy()
-    if guideline_reasoning:
-        # Split guideline_reasoning on ### and take the last segment as the Guidelines
-        guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
-        final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
-    else:
-        final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
-
-    logger.debug(f"Final messages for agent (showing first 2): {final_messages[:2]}")
+    agent = await agent_manager.get_or_create_agent(
+        bot_id=bot_id,
+        project_dir=project_dir,
+        model_name=model_name,
+        api_key=api_key,
+        model_server=model_server,
+        generate_cfg=generate_cfg,
+        language=language,
+        system_prompt=system_prompt,
+        mcp_settings=mcp_settings,
+        robot_type=robot_type,
+        user_identifier=user_identifier
+    )
 
     # Phase 3: stream the agent response
     logger.info(f"Starting agent stream response")
@@ -287,7 +131,7 @@ async def enhanced_generate_stream_response(
     message_tag = ""
     function_name = ""
     tool_args = ""
-    async for msg,metadata in agent.astream({"messages": final_messages}, stream_mode="messages"):
+    async for msg,metadata in agent.astream({"messages": messages}, stream_mode="messages"):
         new_content = ""
         if isinstance(msg, AIMessageChunk):
             # Check whether there is a tool call
@@ -314,6 +158,8 @@ async def enhanced_generate_stream_response(
         elif isinstance(msg, ToolMessage) and len(msg.content)>0:
             message_tag = "TOOL_RESPONSE"
             new_content = f"[{message_tag}] {msg.name}\n{msg.text}"
+        elif isinstance(msg, AIMessage) and msg.additional_kwargs and "thinking" in msg.additional_kwargs:
+            new_content = "[THINK]\n"+msg.additional_kwargs["thinking"]+ "\n"
 
         # Only send a chunk when there is new content
         if new_content:
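With the new `elif` branch, the middleware's injected message surfaces in the stream as a `[THINK]` section alongside the existing tags. A hedged client-side sketch of splitting that tagged text back apart (the tag set comes from this file; the sample content is invented):

```python
import re

stream_text = (
    "[THINK]\n1. General Inquiries -> call Knowledge Base Retrieval\n"
    '[TOOL_CALL] rag_retrieve-rag_retrieve\n{"query": "return policy"}\n'
    "[ANSWER]\nYou can return items within 30 days.\n"
)

# Split the concatenated content into (tag, body) sections
sections = re.findall(r"\[(PREAMBLE|THINK|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]\n?([^\[]*)", stream_text)
for tag, body in sections:
    print(tag, "->", body.strip().splitlines()[0])
```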
@@ -388,20 +234,17 @@ async def create_agent_and_generate_response(
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
         )
 
     _, system_prompt = get_preamble_text(language, system_prompt)
     # Handle all the logic through the shared helper
-    agent, system_prompt, guideline_reasoning = await process_guidelines_and_terms(
+    agent = await agent_manager.get_or_create_agent(
         bot_id=bot_id,
-        api_key=api_key,
-        model_name=model_name,
-        model_server=model_server,
-        system_prompt=system_prompt,
-        messages=messages,
-        agent_manager=agent_manager,
         project_dir=project_dir,
+        model_name=model_name,
+        api_key=api_key,
+        model_server=model_server,
         generate_cfg=generate_cfg,
         language=language,
+        system_prompt=system_prompt,
         mcp_settings=mcp_settings,
         robot_type=robot_type,
         user_identifier=user_identifier
@@ -409,23 +252,16 @@ async def create_agent_and_generate_response(
 
     # Prepare the final messages
     final_messages = messages.copy()
-    if guideline_reasoning:
-        # Split guideline_reasoning on ### and take the last segment as the Guidelines
-        guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
-        final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
-    else:
-        final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
-
     # Non-streaming response
     agent_responses = await agent.ainvoke({"messages": final_messages})
     append_messages = agent_responses["messages"][len(final_messages):]
-    # agent_responses = agent.run_nonstream(final_messages)
     response_text = ""
-    if guideline_reasoning:
-        response_text += "[THINK]\n"+guideline_reasoning+ "\n"
     for msg in append_messages:
         if isinstance(msg,AIMessage):
-            if len(msg.text)>0:
+            if msg.additional_kwargs and "thinking" in msg.additional_kwargs:
+                response_text += "[THINK]\n"+msg.additional_kwargs["thinking"]+ "\n"
+            elif len(msg.text)>0:
                 response_text += "[ANSWER]\n"+msg.text+ "\n"
             if len(msg.tool_calls)>0:
                 response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
@@ -320,7 +320,14 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
     return final_messages
 
 
+def get_user_last_message_content(messages: list) -> Optional[dict]:
+    """Return the content of the last user message in the list"""
+    if not messages or len(messages) == 0:
+        return ""
+    last_message = messages[-1]
+    if last_message and last_message.get('role') == 'user':
+        return last_message["content"]
+    return ""
+
 def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
     """Format messages into a plain-text chat history
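A quick behavioral check of the moved helper (illustrative values; note the annotation still says `Optional[dict]` although it returns a string):

```python
msgs = [
    {"role": "assistant", "content": "Hi!"},
    {"role": "user", "content": "What is your return policy?"},
]
assert get_user_last_message_content(msgs) == "What is your return policy?"
assert get_user_last_message_content([]) == ""  # empty history
assert get_user_last_message_content([{"role": "assistant", "content": "x"}]) == ""  # last turn not from the user
```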
@@ -351,7 +358,6 @@ def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
             chat_history.append(f"{function_name} call: {arguments}")
 
     recent_chat_history = chat_history[-15:] if len(chat_history) > 15 else chat_history
-    print(f"recent_chat_history:{recent_chat_history}")
     return "\n".join(recent_chat_history)
@@ -660,30 +666,6 @@ def _get_optimal_batch_size(guidelines_count: int) -> int:
     else:
         return 5
 
-
-async def process_guideline(
-    chat_history: str,
-    guideline_prompt: str,
-    model_name: str,
-    api_key: str,
-    model_server: str
-) -> str:
-    """Process a single guideline batch"""
-    max_retries = 3
-
-    for attempt in range(max_retries):
-        try:
-            logger.info(f"Start processed guideline batch on attempt {attempt + 1}")
-            return await call_guideline_llm(chat_history, guideline_prompt, model_name, api_key, model_server)
-        except Exception as e:
-            logger.error(f"Error processing guideline batch on attempt {attempt + 1}: {e}")
-            if attempt == max_retries - 1:
-                return ""  # last attempt failed; return an empty string
-
-    # Should never be reached, kept for completeness
-    return ""
-
-
 def extract_block_from_system_prompt(system_prompt: str) -> tuple[str, str, str, str, List]:
     """
     Extract the guideline and terms content from the system prompt