import json
import os
import time
import asyncio
from typing import Union, Optional
from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from utils import (
    Message, ChatRequest, ChatResponse
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
from utils.api_models import ChatRequestV2
from utils.fastapi_utils import (
    process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
    create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
    call_guideline_llm, _get_optimal_batch_size, process_guideline_batch, get_content_from_messages
)

router = APIRouter()

# Initialize the global sharded agent manager
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
    shard_count=int(os.getenv("SHARD_COUNT", "16"))
)
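
# How a sharded agent cache of this shape typically routes lookups (a sketch of
# the presumed scheme, not init_global_sharded_agent_manager's actual internals):
#
#   shard_index = hash(bot_id) % shard_count   # pick a shard per bot
#   # each shard then keeps its own LRU cache, bounding the total number of
#   # cached agents at roughly max_cached_agents across all shards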


def get_user_last_message_content(messages: list) -> str:
    """Return the content of the last message if it is from the user, else ""."""
    if not messages:
        return ""
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        return last_message["content"]
    return ""


def append_user_last_message(messages: list, content: str) -> list:
    """Append content to the last message if it is from the user.

    Args:
        messages: the message list
        content: the content to append

    Returns:
        list: the (possibly modified) message list
    """
    if not messages:
        return messages
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        messages[-1]['content'] += content
    return messages
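
# Example:
#   msgs = [{"role": "user", "content": "Hi"}]
#   append_user_last_message(msgs, " there")
#   # msgs is now [{"role": "user", "content": "Hi there"}]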


async def generate_stream_response(agent, messages, thought_list, tool_response: bool, model: str):
    """Generate a streaming SSE response in the OpenAI chunk format."""
    accumulated_content = ""
    chunk_id = 0
    try:
        # Emit the pre-computed reasoning content first, if any
        if thought_list:
            accumulated_content = get_content_from_messages(thought_list, tool_response=tool_response)
            chunk_data = {
                "id": "chatcmpl-thought",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {
                        "content": accumulated_content
                    },
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)

            # Compute the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content

            # Only send a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build an OpenAI-format streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }

                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the end-of-stream marker
        yield "data: [DONE]\n\n"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        from utils.logger import logger
        logger.error(f"Error in generate_stream_response: {str(e)}")
        logger.error(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
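
# Illustrative client-side consumption of the stream above (a sketch, not part
# of this module; assumes the `requests` package and a locally running server):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/v1/chat/completions",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       json={"bot_id": "my-bot-001", "stream": True,
#             "messages": [{"role": "user", "content": "Hello"}]},
#       stream=True,
#   )
#   for line in resp.iter_lines(decode_unicode=True):
#       if line and line.startswith("data: ") and line != "data: [DONE]":
#           chunk = json.loads(line[len("data: "):])
#           for choice in chunk.get("choices", []):
#               print(choice["delta"].get("content", ""), end="")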


async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: list,
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[list],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[dict] = None,
    user_identifier: Optional[str] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared logic for creating an agent and generating a response."""
    if generate_cfg is None:
        generate_cfg = {}

    # 1. Extract the guideline and terms blocks from the system prompt
    system_prompt, guidelines_list, terms_list = extract_block_from_system_prompt(system_prompt)

    # 2. If there are terms, embed them first. The embeddings must be cached (a
    #    tmp-file cache keyed by {bot_id}_terms works); see embedding/embedding.py
    #    for the implementation. With the embeddings, run a similarity search
    #    (cosine similarity for now), keep matches with similarity > 0.7, and
    #    reassemble them as terms_analysis in the format:
    #    1) Name: term_name1, Description: desc, Synonyms: syn1, syn2
    terms_analysis = ""
    if terms_list:
        print(f"terms_list: {terms_list}")
        # Extract the user's query text from the messages for the similarity search
        query_text = get_user_last_message_content(messages)
        # Process the terms with embeddings
        try:
            from embedding.embedding import process_terms_with_embedding
            terms_analysis = process_terms_with_embedding(terms_list, bot_id, query_text)
            if terms_analysis:
                # Inject the terms analysis into the system prompt
                system_prompt = system_prompt.replace("{terms}", terms_analysis)
                print(f"Generated terms analysis: {terms_analysis[:200]}...")  # print only the first 200 characters
        except Exception as e:
            print(f"Error processing terms with embedding: {e}")
            terms_analysis = ""
    else:
        # When terms_list is empty, remove the corresponding pkl cache file
        try:
            cache_file = f"projects/cache/{bot_id}_terms.pkl"
            if os.path.exists(cache_file):
                os.remove(cache_file)
                print(f"Removed empty terms cache file: {cache_file}")
        except Exception as e:
            print(f"Error removing terms cache file: {e}")
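
    # Illustrative sketch of the cosine-similarity matching that
    # process_terms_with_embedding is expected to perform, per the step-2 note
    # above (the numpy dependency, embed() helper, and term-dict keys are
    # assumptions, not this module's API):
    #
    #   import numpy as np
    #
    #   def cosine_similarity(a, b) -> float:
    #       a, b = np.asarray(a), np.asarray(b)
    #       return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    #
    #   hits = [t for t in terms_list
    #           if cosine_similarity(embed(query_text), embed(t["name"])) > 0.7]
    #   terms_analysis = "\n".join(
    #       f"{i}) Name: {t['name']}, Description: {t['description']}, "
    #       f"Synonyms: {', '.join(t.get('synonyms', []))}"
    #       for i, t in enumerate(hits, 1))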

    # 3. If there is guideline content, process it concurrently
    guideline_analysis = ""
    if guidelines_list:
        print(f"guidelines_list: {guidelines_list}")
        guidelines_count = len(guidelines_list)

        if guidelines_count > 0:
            # Get the optimal number of batches (the concurrency level)
            batch_count = _get_optimal_batch_size(guidelines_count)

            # Work out how many guidelines each batch should contain
            guidelines_per_batch = max(1, guidelines_count // batch_count)

            # Split the guidelines into batches, formatting each dict as a string
            # so the LLM can process it in its original layout
            batches = []
            for i in range(0, guidelines_count, guidelines_per_batch):
                batch_guidelines = guidelines_list[i:i + guidelines_per_batch]
                batch_strings = []
                for guideline in batch_guidelines:
                    guideline_str = f"{guideline['id']}) Condition: {guideline['condition']} Action: {guideline['action']}"
                    batch_strings.append(guideline_str)
                batches.append(batch_strings)

            # Make sure the number of batches does not exceed the concurrency level
            while len(batches) > batch_count:
                # Merge the last batch into the second-to-last one
                batches[-2].extend(batches[-1])
                batches.pop()
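
            # Worked example of the batching arithmetic (illustrative): with 10
            # guidelines and a batch_count of, say, 3, guidelines_per_batch is
            # 10 // 3 = 3, so slicing yields 4 batches of sizes 3/3/3/1; the
            # merge loop then folds the trailing batch in, leaving 3 batches of
            # sizes 3/3/4.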
print(f"Processing {guidelines_count} guidelines in {len(batches)} batches with {batch_count} concurrent batches")
|
||
|
||
# 准备chat_history
|
||
chat_history = format_messages_to_chat_history(messages)
|
||
|
||
# 并发执行所有任务:guideline批次处理 + agent创建
|
||
tasks = []
|
||
|
||
# 添加所有guideline批次任务
|
||
for batch in batches:
|
||
task = process_guideline_batch(
|
||
guidelines_batch=batch,
|
||
chat_history=chat_history,
|
||
terms=terms_analysis,
|
||
model_name=model_name,
|
||
api_key=api_key,
|
||
model_server=model_server
|
||
)
|
||
tasks.append(task)
|
||
|
||
# 添加agent创建任务
|
||
agent_task = agent_manager.get_or_create_agent(
|
||
bot_id=bot_id,
|
||
project_dir=project_dir,
|
||
model_name=model_name,
|
||
api_key=api_key,
|
||
model_server=model_server,
|
||
generate_cfg=generate_cfg,
|
||
language=language,
|
||
system_prompt=system_prompt,
|
||
mcp_settings=mcp_settings,
|
||
robot_type=robot_type,
|
||
user_identifier=user_identifier
|
||
)
|
||
tasks.append(agent_task)
|
||
|
||
# 等待所有任务完成
|
||
all_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
||
# 处理结果:最后一个结果是agent,前面的是guideline批次结果
|
||
agent = all_results[-1] # agent创建的结果
|
||
batch_results = all_results[:-1] # guideline批次的结果
|
||
print(f"batch_results:{batch_results}")
|
||
|
||
# 合并guideline分析结果,使用JSON格式的checks数组
|
||
all_checks = []
|
||
for i, result in enumerate(batch_results):
|
||
if isinstance(result, Exception):
|
||
print(f"Guideline batch {i} failed: {result}")
|
||
continue
|
||
if result and isinstance(result, dict) and 'checks' in result:
|
||
# 如果是JSON对象且包含checks数组,只保留applies为true的checks
|
||
applicable_checks = [check for check in result['checks'] if check.get('applies') is True]
|
||
all_checks.extend(applicable_checks)
|
||
elif result and isinstance(result, str) and result.strip():
|
||
# 如果是普通文本,保留原有逻辑
|
||
print(f"Non-JSON result from batch {i}: {result}")
|
||
|
||
if all_checks:
|
||
# 将checks数组格式化为JSON字符串
|
||
guideline_analysis = "\n".join([item["rationale"] for item in all_checks])
|
||
# guideline_analysis = json.dumps({"checks": all_checks}, ensure_ascii=False)
|
||
print(f"Merged guideline analysis result: {guideline_analysis}")
|
||
|
||
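
            # For reference, each successful batch result is expected to look
            # roughly like this (inferred from the handling above; any fields
            # beyond `applies` and `rationale` are assumptions):
            #
            #   {"checks": [{"applies": True,
            #                "rationale": "User asked about pricing, so ..."}]}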

        # Append the analysis result to the last user message
        if guideline_analysis:
            messages = append_user_last_message(messages, f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response.")

    else:
        # No guidelines: fetch or create the agent instance from the global manager
        agent = await agent_manager.get_or_create_agent(
            bot_id=bot_id,
            project_dir=project_dir,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            language=language,
            system_prompt=system_prompt,
            mcp_settings=mcp_settings,
            robot_type=robot_type,
            user_identifier=user_identifier
        )

    if language:
        # Append the reply-language instruction to the end of the last message
        language_map = {
            'zh': '请用中文回复',
            'en': 'Please reply in English',
            'ja': '日本語で回答してください',
            'jp': '日本語で回答してください'
        }
        language_instruction = language_map.get(language.lower(), '')
        if language_instruction:
            messages = append_user_last_message(messages, f"\n\nlanguage:{language_instruction}")

    thought_list = []
    if guideline_analysis != '':
        thought_list = [{"role": "assistant", "reasoning_content": guideline_analysis}]

    # Return a streaming or non-streaming response depending on the stream flag
    if stream:
        return StreamingResponse(
            generate_stream_response(agent, messages, thought_list, tool_response, model_name),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # Non-streaming response
        agent_responses = agent.run_nonstream(messages)
        final_responses = thought_list + agent_responses
        if final_responses:
            # Use get_content_from_messages to process the responses, respecting
            # the tool_response flag
            content = get_content_from_messages(final_responses, tool_response=tool_response)

            # Build an OpenAI-format response; note the usage numbers are
            # character counts used as a rough stand-in for token counts
            return ChatResponse(
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "finish_reason": "stop"
                }],
                usage={
                    "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
                    "completion_tokens": len(content),
                    "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
                }
            )
        else:
            raise HTTPException(status_code=500, detail="No response from agent")
@router.post("/api/v1/chat/completions")
|
||
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
||
|
||
Args:
|
||
request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
|
||
Notes:
|
||
- dataset_ids: 可选参数,当提供时必须是项目ID列表(单个项目也使用数组格式)
|
||
- bot_id: 必需参数,机器人ID
|
||
- 只有当 robot_type == "catalog_agent" 且 dataset_ids 为非空数组时才会创建机器人项目目录:projects/robot/{bot_id}/
|
||
- robot_type 为其他值(包括默认的 "agent")时不创建任何目录
|
||
- dataset_ids 为空数组 []、None 或未提供时不创建任何目录
|
||
- 支持多知识库合并,自动处理文件夹重名冲突
|
||
|
||
Required Parameters:
|
||
- bot_id: str - 目标机器人ID
|
||
- messages: List[Message] - 对话消息列表
|
||
Optional Parameters:
|
||
- dataset_ids: List[str] - 源知识库项目ID列表(单个项目也使用数组格式)
|
||
- robot_type: str - 机器人类型,默认为 "agent"
|
||
|
||
Example:
|
||
{"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
|
||
{"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
|
||
{"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
|
||
{"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
|
||
"""
|
||
    try:
        # v1: extract the API key from the Authorization header and use it as the model API key
        api_key = extract_api_key_from_auth(authorization)

        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # Create the project directory (only when dataset_ids are provided and
        # the robot_type calls for one; see the docstring notes)
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)

        # Collect the remaining request fields as generate_cfg
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

        # Process the messages
        messages = process_messages(request.messages, request.language)

        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=True,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            project_dir=project_dir,
            generate_cfg=generate_cfg,
            user_identifier=request.user_identifier
        )

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) pass through
        # instead of being rewrapped as a 500
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@router.post("/api/v2/chat/completions")
|
||
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
Chat completions API v2 with simplified parameters.
|
||
Only requires messages, stream, tool_response, bot_id, and language parameters.
|
||
Other parameters are fetched from the backend bot configuration API.
|
||
|
||
Args:
|
||
request: ChatRequestV2 containing only essential parameters
|
||
authorization: Authorization header for authentication (different from v1)
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
|
||
Required Parameters:
|
||
- bot_id: str - 目标机器人ID
|
||
- messages: List[Message] - 对话消息列表
|
||
|
||
Optional Parameters:
|
||
- stream: bool - 是否流式输出,默认false
|
||
- tool_response: bool - 是否包含工具响应,默认false
|
||
- language: str - 回复语言,默认"ja"
|
||
|
||
Authentication:
|
||
- Requires valid MD5 hash token: MD5(MASTERKEY:bot_id)
|
||
- Authorization header should contain: Bearer {token}
|
||
- Uses MD5 hash of MASTERKEY:bot_id for backend API authentication
|
||
- Optionally uses API key from bot config for model access
|
||
"""
|
||
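    # Illustrative client-side token generation, per the Authentication note in
    # the docstring (a sketch; MASTERKEY is the shared secret and bot_id the
    # target bot):
    #
    #   import hashlib
    #   token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
    #   headers = {"Authorization": f"Bearer {token}"}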
    try:
        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # v2 authentication check
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)

        if not provided_token:
            raise HTTPException(
                status_code=401,
                detail="Authorization header is required for v2 API"
            )

        if provided_token != expected_token:
            # Do not echo the expected token back to the caller
            raise HTTPException(
                status_code=403,
                detail="Invalid authentication token"
            )

        # Fetch the bot configuration from the backend API (using v2 authentication)
        bot_config = await fetch_bot_config(bot_id)

        # v2: the API key comes from the backend config; the Authorization header
        # has already been consumed for authentication and is not reused as an API key
        api_key = bot_config.get("api_key")

        # Create the project directory (dataset_ids come from the backend config)
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("robot_type", "general_agent")
        )

        # Process the messages
        messages = process_messages(request.messages, request.language)

        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=request.tool_response,
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=request.language or bot_config.get("language", "ja"),
            system_prompt=bot_config.get("system_prompt"),
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "general_agent"),
            project_dir=project_dir,
            generate_cfg={},  # v2 does not pass extra generate_cfg
            user_identifier=request.user_identifier
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions_v2: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")