#!/usr/bin/env python3
"""
API data models and response schemas.
"""
|
||
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from utils.settings import DEFAULT_THINKING_ENABLE
|
||
|
||
class Message(BaseModel):
    """A single chat message exchanged with the model."""

    role: str  # message author, presumably "user"/"assistant" — confirm against callers
    content: str  # raw message text
|
||
|
||
|
||
class DatasetRequest(BaseModel):
    """Request payload for creating/processing a dataset project."""

    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None
    files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
    unique_id: Optional[str] = None

    @field_validator('files', mode='before')
    @classmethod
    def validate_files(cls, v):
        """Ensure ``files`` is a dict of string group keys to lists of string paths."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
        for group, paths in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in files dict must be string, got {type(group)}")
            if not isinstance(paths, list):
                raise ValueError(f"Value in files dict must be list, got {type(paths)} for key '{group}'")
            for path in paths:
                if not isinstance(path, str):
                    raise ValueError(f"File paths must be strings, got {type(path)} in key '{group}'")
        return v
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for the primary chat endpoint (v1)."""

    messages: List[Message]
    model: str = "qwen3-next"
    model_server: str = ""  # empty string means "use the default server" — confirm with caller
    dataset_ids: Optional[List[str]] = None
    bot_id: str
    stream: Optional[bool] = False  # stream the response when True
    language: Optional[str] = "zh"
    tool_response: Optional[bool] = True
    system_prompt: Optional[str] = ""
    mcp_settings: Optional[List[Dict]] = None
    robot_type: Optional[str] = "general_agent"
    user_identifier: Optional[str] = ""
    session_id: Optional[str] = None
    # Default comes from project settings so thinking can be toggled globally.
    enable_thinking: Optional[bool] = DEFAULT_THINKING_ENABLE
    skills: Optional[List[str]] = None
|
||
|
||
|
||
class ChatRequestV2(BaseModel):
    """Slimmed-down chat request used by the v2 endpoint."""

    messages: List[Message]
    stream: Optional[bool] = False
    tool_response: Optional[bool] = True
    bot_id: str
    language: Optional[str] = "zh"
    user_identifier: Optional[str] = ""
    session_id: Optional[str] = None
    # Default comes from project settings so thinking can be toggled globally.
    enable_thinking: Optional[bool] = DEFAULT_THINKING_ENABLE
|
||
|
||
|
||
class FileProcessRequest(BaseModel):
    """Request payload for (re)processing files under an existing project."""

    unique_id: str
    files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None

    # Tolerate extra keys so older/newer clients don't break validation.
    model_config = ConfigDict(extra='allow')

    @field_validator('files', mode='before')
    @classmethod
    def validate_files(cls, v):
        """Ensure ``files`` is a dict of string group keys to lists of string paths."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
        for group, paths in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in files dict must be string, got {type(group)}")
            if not isinstance(paths, list):
                raise ValueError(f"Value in files dict must be list, got {type(paths)} for key '{group}'")
            for path in paths:
                if not isinstance(path, str):
                    raise ValueError(f"File paths must be strings, got {type(path)} in key '{group}'")
        return v
|
||
|
||
|
||
class DatasetResponse(BaseModel):
    """Outcome of a dataset create/process request."""

    success: bool
    message: str
    unique_id: Optional[str] = None
    dataset_structure: Optional[str] = None  # textual tree of the dataset, when available
|
||
|
||
|
||
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response envelope."""

    id: str
    object: str = "chat.completion"
    created: int  # Unix timestamp (seconds)
    model: str
    choices: List[Dict[str, Any]]
    usage: Optional[Dict[str, int]] = None  # token usage stats, when reported
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Minimal chat response: choices plus optional usage stats."""

    choices: List[Dict]
    usage: Optional[Dict] = None
|
||
|
||
|
||
class FileProcessResponse(BaseModel):
    """Outcome of a file-processing request."""

    success: bool
    message: str
    unique_id: str
    processed_files: List[str]  # paths of files that were actually processed
|
||
|
||
|
||
class ErrorResponse(BaseModel):
    """OpenAI-style error envelope: a single ``error`` dict."""

    error: Dict[str, Any]

    @classmethod
    def create(cls, message: str, error_type: str = "invalid_request_error", code: Optional[str] = None):
        """Factory building the error dict from its parts.

        ``code`` is only included when truthy (an empty string is omitted).
        """
        payload: Dict[str, Any] = {"message": message, "type": error_type}
        if code:
            payload["code"] = code
        return cls(error=payload)
|
||
|
||
|
||
class HealthCheckResponse(BaseModel):
    """Health-check endpoint payload."""

    status: str = "healthy"
    timestamp: str  # presumably ISO 8601 — confirm against the endpoint that fills it
    version: str = "1.0.0"
|
||
|
||
|
||
class SystemStatusResponse(BaseModel):
    """System-wide status summary."""

    status: str
    projects_count: int
    total_projects: List[str]  # all known project ids
    active_projects: List[str]  # subset currently active
    system_info: Dict[str, Any]
|
||
|
||
|
||
class CacheStatusResponse(BaseModel):
    """Cache status: which projects are cached and cache metadata."""

    cached_projects: List[str]
    cache_info: Dict[str, Any]
|
||
|
||
|
||
class ProjectStatusResponse(BaseModel):
    """Detailed status of a single project."""

    unique_id: str
    project_exists: bool
    project_path: Optional[str] = None
    processed_files_count: int
    processed_files: Dict[str, Dict]  # per-file processing metadata
    document_files_count: int
    document_files: List[str]
    has_system_prompt: bool
    has_mcp_settings: bool
    readme_exists: bool
    log_file_exists: bool
    dataset_structure: Optional[str] = None  # textual tree of the dataset, when available
    error: Optional[str] = None  # populated when status collection failed
|
||
|
||
|
||
class ProjectListResponse(BaseModel):
    """List of project ids together with their count."""

    projects: List[str]
    count: int
|
||
|
||
|
||
class ProjectStatsResponse(BaseModel):
    """Aggregate statistics for a single project."""

    unique_id: str
    total_processed_files: int
    total_document_files: int
    total_document_size: int  # bytes
    total_document_size_mb: float
    has_system_prompt: bool
    has_mcp_settings: bool
    has_readme: bool
    document_files_detail: List[Dict[str, Any]]
    embedding_files_count: int
    embedding_files_detail: List[Dict[str, Any]]
|
||
|
||
|
||
class ProjectActionResponse(BaseModel):
    """Outcome of an action (e.g. delete/reload) performed on a project."""

    success: bool
    message: str
    unique_id: str
    action: str  # name of the action that was performed
|
||
|
||
|
||
# Utility functions for creating responses
|
||
def create_success_response(message: str, **kwargs) -> Dict[str, Any]:
    """Build a standardized success payload.

    Args:
        message: Human-readable status message.
        **kwargs: Extra fields merged into the response (may override the
            standard keys, matching dict-unpacking semantics).

    Returns:
        Dict with ``success=True``, ``message`` and any extra fields.
    """
    response: Dict[str, Any] = {"success": True, "message": message}
    response.update(kwargs)
    return response
|
||
|
||
|
||
def create_error_response(message: str, error_type: str = "error", **kwargs) -> Dict[str, Any]:
    """Build a standardized error payload.

    Args:
        message: Human-readable error message.
        error_type: Short machine-readable error category.
        **kwargs: Extra fields merged into the response (may override the
            standard keys, matching dict-unpacking semantics).

    Returns:
        Dict with ``success=False``, ``error``, ``message`` and extras.
    """
    response: Dict[str, Any] = {"success": False, "error": error_type, "message": message}
    response.update(kwargs)
    return response
|
||
|
||
|
||
class QueueTaskRequest(BaseModel):
    """Queue task request model."""

    dataset_id: str
    files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
    upload_folder: Optional[Dict[str, str]] = Field(default=None, description="Upload folders organized by group names. Each key maps to a folder name. Example: {'group1': 'my_project1', 'group2': 'my_project2'}")
    priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
    delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")

    # Tolerate extra keys so older/newer clients don't break validation.
    model_config = ConfigDict(extra='allow')

    @field_validator('upload_folder', mode='before')
    @classmethod
    def validate_upload_folder(cls, v):
        """Ensure ``upload_folder`` maps string group names to string folder names."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"upload_folder must be a dict with group names as keys and folder names as values, got {type(v)}")
        for group, folder in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in upload_folder dict must be string, got {type(group)}")
            if not isinstance(folder, str):
                raise ValueError(f"Value in upload_folder dict must be string (folder name), got {type(folder)} for key '{group}'")
        return v

    @field_validator('files', mode='before')
    @classmethod
    def validate_files(cls, v):
        """Ensure ``files`` is a dict of string group keys to lists of string paths."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
        for group, paths in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in files dict must be string, got {type(group)}")
            if not isinstance(paths, list):
                raise ValueError(f"Value in files dict must be list, got {type(paths)} for key '{group}'")
            for path in paths:
                if not isinstance(path, str):
                    raise ValueError(f"File paths must be strings, got {type(path)} in key '{group}'")
        return v
|
||
|
||
|
||
class IncrementalTaskRequest(BaseModel):
    """Incremental file-processing request model."""

    dataset_id: str = Field(..., description="Dataset ID for the project")
    files_to_add: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to add organized by key groups")
    files_to_remove: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to remove organized by key groups")
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None
    priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
    delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")

    # Tolerate extra keys so older/newer clients don't break validation.
    model_config = ConfigDict(extra='allow')

    @field_validator('files_to_add', mode='before')
    @classmethod
    def validate_files_to_add(cls, v):
        """Ensure ``files_to_add`` is a dict of string keys to lists of string paths."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"files_to_add must be a dict with key groups, got {type(v)}")
        for group, paths in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in files_to_add dict must be string, got {type(group)}")
            if not isinstance(paths, list):
                raise ValueError(f"Value in files_to_add dict must be list, got {type(paths)} for key '{group}'")
            for path in paths:
                if not isinstance(path, str):
                    raise ValueError(f"File paths must be strings, got {type(path)} in key '{group}'")
        return v

    @field_validator('files_to_remove', mode='before')
    @classmethod
    def validate_files_to_remove(cls, v):
        """Ensure ``files_to_remove`` is a dict of string keys to lists of string paths."""
        if v is None:
            return None
        # Reject anything that is not a dict up front (guard clause).
        if not isinstance(v, dict):
            raise ValueError(f"files_to_remove must be a dict with key groups, got {type(v)}")
        for group, paths in v.items():
            if not isinstance(group, str):
                raise ValueError(f"Key in files_to_remove dict must be string, got {type(group)}")
            if not isinstance(paths, list):
                raise ValueError(f"Value in files_to_remove dict must be list, got {type(paths)} for key '{group}'")
            for path in paths:
                if not isinstance(path, str):
                    raise ValueError(f"File paths must be strings, got {type(path)} in key '{group}'")
        return v
|
||
|
||
|
||
class QueueTaskResponse(BaseModel):
    """Queue task response model."""

    success: bool
    message: str
    dataset_id: str
    task_id: Optional[str] = None
    task_status: Optional[str] = None
    estimated_processing_time: Optional[int] = None  # seconds
|
||
|
||
|
||
class QueueStatusResponse(BaseModel):
    """Queue status response model."""

    success: bool
    message: str
    queue_stats: Dict[str, Any]
    pending_tasks: List[Dict[str, Any]]
|
||
|
||
|
||
class TaskStatusResponse(BaseModel):
    """Task status response model."""

    success: bool
    message: str
    task_id: str
    task_status: Optional[str] = None
    task_result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None  # populated when the task failed
|
||
|
||
|
||
def create_chat_response(
    messages: List[Message],
    model: str,
    content: str,
    usage: Optional[Dict[str, int]] = None
) -> Dict[str, Any]:
    """Build an OpenAI-compatible ``chat.completion`` response dict.

    Args:
        messages: Conversation history (not read here; kept for caller compatibility).
        model: Model name echoed back in the response.
        content: Assistant reply text placed in the single choice.
        usage: Optional token-usage stats; falsy values fall back to a zeroed stub.

    Returns:
        Dict shaped like an OpenAI chat completion response.
    """
    import time
    import uuid

    assistant_choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": content,
        },
        "finish_reason": "stop",
    }
    # NOTE: `or` (not an `is None` check) — an empty usage dict also falls
    # back to the zeroed stub; preserved from the original behavior.
    token_usage = usage or {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [assistant_choice],
        "usage": token_usage,
    }
|
||
|
||
|
||
# ============================================================================
# Chat history query models
# ============================================================================
|
||
|
||
class ChatHistoryRequest(BaseModel):
    """Chat history query request."""

    session_id: str = Field(..., description="会话ID (thread_id)")
    last_message_id: Optional[str] = Field(None, description="上一条消息的ID，用于分页查询更早的消息")
    limit: int = Field(20, ge=1, le=100, description="每次查询的消息数量上限")
|
||
|
||
|
||
class ChatHistoryMessage(BaseModel):
    """A single message in the stored chat history."""

    id: str = Field(..., description="消息唯一ID")
    role: str = Field(..., description="消息角色: user 或 assistant")
    content: str = Field(..., description="消息内容")
    timestamp: Optional[str] = Field(None, description="消息时间戳 (ISO 8601)")
|
||
|
||
|
||
class ChatHistoryResponse(BaseModel):
    """Chat history query response."""

    messages: List[ChatHistoryMessage] = Field(..., description="消息列表，按时间倒序排列")
    has_more: bool = Field(..., description="是否还有更多历史消息")
|
||
|
||
|
||
class BatchSaveChatRequest(BaseModel):
    """Batch chat-record save request."""

    session_id: str = Field(..., description="会话ID (thread_id)")
    messages: List[Message] = Field(..., description="要保存的消息列表，支持 user 和 assistant 角色")
    bot_id: Optional[str] = Field(None, description="机器人ID")
|
||
|
||
|
||
class BatchSaveChatResponse(BaseModel):
    """Batch chat-record save response."""

    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="响应消息")
    session_id: str = Field(..., description="会话ID")
    saved_count: int = Field(..., description="成功保存的消息数量")
    message_ids: List[str] = Field(..., description="保存的消息ID列表")
|