v2 实现shell_env支持
This commit is contained in:
parent
32fd8c8656
commit
5a7aa06681
@ -203,7 +203,82 @@ class AgentConfig:
|
||||
enable_memori=enable_memori,
|
||||
memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
|
||||
trace_id=trace_id,
|
||||
shell_env=getattr(request, 'shell_env', None) or bot_config.get("shell_env") or {},
|
||||
shell_env=bot_config.get("shell_env") or {},
|
||||
)
|
||||
|
||||
# 在创建 config 时尽早准备 checkpoint 消息
|
||||
if config.session_id:
|
||||
try:
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = manager.checkpointer
|
||||
if checkpointer:
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpointer: {e}")
|
||||
|
||||
config.safe_print()
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
async def from_v3_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, language: Optional[str] = None):
|
||||
"""从v3请求创建配置 - 从数据库读取所有配置"""
|
||||
# 延迟导入避免循环依赖
|
||||
from .logging_handler import LoggingCallbackHandler
|
||||
from utils.fastapi_utils import get_preamble_text
|
||||
from utils.settings import (
|
||||
MEM0_ENABLED,
|
||||
MEM0_SEMANTIC_SEARCH_TOP_K,
|
||||
)
|
||||
from .checkpoint_utils import prepare_checkpoint_message
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
from utils.log_util.context import g
|
||||
|
||||
if messages is None:
|
||||
messages = []
|
||||
|
||||
# 从全局上下文获取 trace_id
|
||||
trace_id = None
|
||||
try:
|
||||
trace_id = getattr(g, 'trace_id', None)
|
||||
except LookupError:
|
||||
pass
|
||||
|
||||
# 从数据库配置获取语言(如果没有传递)
|
||||
if language is None:
|
||||
language = bot_config.get("language", "zh")
|
||||
|
||||
# 处理 system_prompt 和 preamble
|
||||
system_prompt_from_db = bot_config.get("system_prompt", "")
|
||||
preamble_text, system_prompt = get_preamble_text(language, system_prompt_from_db)
|
||||
|
||||
|
||||
# 从数据库配置获取其他参数
|
||||
enable_thinking = bot_config.get("enable_thinking", False)
|
||||
enable_memori = bot_config.get("enable_memori", False)
|
||||
|
||||
config = cls(
|
||||
bot_id=request.bot_id,
|
||||
api_key=bot_config.get("api_key", ""),
|
||||
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
|
||||
model_server=bot_config.get("model_server", ""),
|
||||
language=language,
|
||||
system_prompt=system_prompt,
|
||||
mcp_settings=bot_config.get("mcp_settings", []),
|
||||
user_identifier=bot_config.get("user_identifier", ""),
|
||||
session_id=request.session_id,
|
||||
enable_thinking=enable_thinking,
|
||||
project_dir=project_dir,
|
||||
stream=request.stream,
|
||||
tool_response=bot_config.get("tool_response", True),
|
||||
generate_cfg={}, # v3接口不传递额外的generate_cfg
|
||||
logging_handler=LoggingCallbackHandler(),
|
||||
messages=messages,
|
||||
_origin_messages=messages,
|
||||
preamble_text=preamble_text,
|
||||
dataset_ids=bot_config.get("dataset_ids", []),
|
||||
enable_memori=enable_memori,
|
||||
memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
|
||||
trace_id=trace_id,
|
||||
)
|
||||
|
||||
# 在创建 config 时尽早准备 checkpoint 消息
|
||||
|
||||
@ -12,10 +12,10 @@ logger = logging.getLogger('app')
|
||||
from utils import (
|
||||
Message, ChatRequest, ChatResponse, BatchSaveChatRequest, BatchSaveChatResponse
|
||||
)
|
||||
from utils.api_models import ChatRequestV2
|
||||
from utils.api_models import ChatRequestV2, ChatRequestV3
|
||||
from utils.fastapi_utils import (
|
||||
process_messages,
|
||||
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
|
||||
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, fetch_bot_config_from_db,
|
||||
call_preamble_llm,
|
||||
create_stream_chunk
|
||||
)
|
||||
@ -768,6 +768,95 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/api/v3/chat/completions")
|
||||
async def chat_completions_v3(request: ChatRequestV3, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
Chat completions API v3 - 从数据库读取配置
|
||||
|
||||
与 v2 相比,v3 从本地数据库读取所有配置参数,而不是从后端 API。
|
||||
前端只需要传递 bot_id 和 messages,其他配置从数据库自动读取。
|
||||
|
||||
Args:
|
||||
request: ChatRequestV3 包含 bot_id, messages, stream, session_id
|
||||
authorization: 可选的认证头
|
||||
|
||||
Returns:
|
||||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||||
|
||||
Required Parameters:
|
||||
- bot_id: str - 目标机器人ID(用户创建时填写的ID)
|
||||
- messages: List[Message] - 对话消息列表
|
||||
|
||||
Optional Parameters:
|
||||
- stream: bool - 是否流式输出,默认false
|
||||
- session_id: str - 会话ID,用于保存聊天历史
|
||||
|
||||
Configuration (from database):
|
||||
- model: 模型名称
|
||||
- api_key: API密钥
|
||||
- model_server: 模型服务器地址
|
||||
- language: 回复语言
|
||||
- tool_response: 是否包含工具响应
|
||||
- system_prompt: 系统提示词
|
||||
- dataset_ids: 数据集ID列表
|
||||
- mcp_settings: MCP服务器配置
|
||||
- user_identifier: 用户标识符
|
||||
|
||||
Authentication:
|
||||
- 可选的 Authorization header(如果需要验证)
|
||||
"""
|
||||
try:
|
||||
# 获取bot_id(必需参数)
|
||||
bot_id = request.bot_id
|
||||
if not bot_id:
|
||||
raise HTTPException(status_code=400, detail="bot_id is required")
|
||||
|
||||
# 可选的鉴权验证(如果传递了 authorization header)
|
||||
if authorization:
|
||||
expected_token = generate_v2_auth_token(bot_id)
|
||||
provided_token = extract_api_key_from_auth(authorization)
|
||||
if provided_token and provided_token != expected_token:
|
||||
logger.warning(f"Invalid auth token provided for v3 API, but continuing anyway")
|
||||
|
||||
# 从数据库获取机器人配置
|
||||
bot_config = await fetch_bot_config_from_db(bot_id, request.user_identifier)
|
||||
|
||||
# 构造类 v2 的请求格式
|
||||
# 从数据库配置中提取参数
|
||||
language = bot_config.get("language", "zh")
|
||||
# 创建项目目录(从数据库配置获取)
|
||||
project_dir = create_project_directory(
|
||||
bot_config.get("dataset_ids", []),
|
||||
bot_id,
|
||||
bot_config.get("skills", [])
|
||||
)
|
||||
|
||||
# 处理消息
|
||||
messages = process_messages(request.messages, language)
|
||||
|
||||
# 创建 AgentConfig 对象
|
||||
# 需要构造一个兼容 v2 的配置对象
|
||||
config = await AgentConfig.from_v3_request(
|
||||
request,
|
||||
bot_config,
|
||||
project_dir,
|
||||
messages,
|
||||
language
|
||||
)
|
||||
|
||||
# 调用公共的agent创建和响应生成逻辑
|
||||
return await create_agent_and_generate_response(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Error in chat_completions_v3: {str(e)}")
|
||||
logger.error(f"Full traceback: {error_details}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 聊天历史查询接口
|
||||
# ============================================================================
|
||||
|
||||
@ -68,11 +68,27 @@ class ChatRequestV2(BaseModel):
|
||||
language: Optional[str] = "zh"
|
||||
user_identifier: Optional[str] = ""
|
||||
session_id: Optional[str] = None
|
||||
shell_env: Optional[Dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class ChatRequestV3(BaseModel):
|
||||
"""
|
||||
v3 API 请求模型 - 从数据库读取配置
|
||||
|
||||
所有配置参数从数据库读取,前端只需传递:
|
||||
- bot_id: Bot 的用户ID(用于从数据库查找配置)
|
||||
- messages: 对话消息列表
|
||||
- session_id: 可选的会话ID
|
||||
- user_identifier: 当前登录用户的用户名,用于标识用户身份
|
||||
"""
|
||||
messages: List[Message]
|
||||
bot_id: str
|
||||
stream: Optional[bool] = False
|
||||
session_id: Optional[str] = None
|
||||
user_identifier: Optional[str] = None
|
||||
|
||||
|
||||
class FileProcessRequest(BaseModel):
|
||||
unique_id: str
|
||||
files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user