Compare commits
5 Commits
3dc119bca8
...
f9efe09f24
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9efe09f24 | ||
|
|
5134c0d8a6 | ||
|
|
4e8e94861f | ||
|
|
f8a44e8d6d | ||
|
|
44b4295a87 |
@ -31,6 +31,7 @@ class AgentConfig:
|
||||
user_identifier: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
dataset_ids: Optional[List[str]] = field(default_factory=list)
|
||||
trace_id: Optional[str] = None # 请求追踪ID,从 X-Request-ID header 获取
|
||||
|
||||
# 响应控制参数
|
||||
stream: bool = False
|
||||
@ -72,6 +73,7 @@ class AgentConfig:
|
||||
'messages': self.messages,
|
||||
'enable_memori': self.enable_memori,
|
||||
'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
|
||||
'trace_id': self.trace_id,
|
||||
}
|
||||
|
||||
def safe_print(self):
|
||||
@ -93,10 +95,18 @@ class AgentConfig:
|
||||
)
|
||||
from .checkpoint_utils import prepare_checkpoint_message
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
from utils.log_util.context import g
|
||||
|
||||
if messages is None:
|
||||
messages = []
|
||||
|
||||
# 从全局上下文获取 trace_id
|
||||
trace_id = None
|
||||
try:
|
||||
trace_id = getattr(g, 'trace_id', None)
|
||||
except LookupError:
|
||||
pass
|
||||
|
||||
robot_type = request.robot_type
|
||||
if robot_type == "catalog_agent":
|
||||
robot_type = "deep_agent"
|
||||
@ -130,6 +140,7 @@ class AgentConfig:
|
||||
dataset_ids=request.dataset_ids,
|
||||
enable_memori=enable_memori,
|
||||
memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
|
||||
trace_id=trace_id,
|
||||
)
|
||||
|
||||
# 在创建 config 时尽早准备 checkpoint 消息
|
||||
@ -158,9 +169,17 @@ class AgentConfig:
|
||||
)
|
||||
from .checkpoint_utils import prepare_checkpoint_message
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
from utils.log_util.context import g
|
||||
|
||||
if messages is None:
|
||||
messages = []
|
||||
|
||||
# 从全局上下文获取 trace_id
|
||||
trace_id = None
|
||||
try:
|
||||
trace_id = getattr(g, 'trace_id', None)
|
||||
except LookupError:
|
||||
pass
|
||||
language = request.language or bot_config.get("language", "zh")
|
||||
preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
|
||||
robot_type = bot_config.get("robot_type", "general_agent")
|
||||
@ -194,6 +213,7 @@ class AgentConfig:
|
||||
dataset_ids=bot_config.get("dataset_ids", []), # 从后端配置获取dataset_ids
|
||||
enable_memori=enable_memori,
|
||||
memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
|
||||
trace_id=trace_id,
|
||||
)
|
||||
|
||||
# 在创建 config 时尽早准备 checkpoint 消息
|
||||
|
||||
@ -137,7 +137,7 @@ async def init_agent(config: AgentConfig):
|
||||
|
||||
# 加载配置
|
||||
final_system_prompt = await load_system_prompt_async(
|
||||
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
|
||||
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier, config.trace_id or ""
|
||||
)
|
||||
final_mcp_settings = await load_mcp_settings_async(
|
||||
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
|
||||
@ -240,7 +240,11 @@ async def init_agent(config: AgentConfig):
|
||||
enable_memory=False,
|
||||
workspace_root=workspace_root,
|
||||
middleware=middleware,
|
||||
checkpointer=checkpointer
|
||||
checkpointer=checkpointer,
|
||||
shell_env={
|
||||
"ASSISTANT_ID": config.bot_id,
|
||||
"TRACE_ID": config.trace_id
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
||||
@ -369,6 +373,7 @@ def create_custom_cli_agent(
|
||||
workspace_root: str | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
shell_env: dict[str, str] | None = None,
|
||||
) -> tuple[Pregel, CompositeBackend]:
|
||||
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
||||
|
||||
@ -393,6 +398,8 @@ def create_custom_cli_agent(
|
||||
workspace_root: Working directory for shell commands. If None, uses Path.cwd().
|
||||
checkpointer: Optional checkpointer for persisting conversation state
|
||||
store: Optional BaseStore for persisting user preferences and agent memory
|
||||
shell_env: Optional custom environment variables to pass to ShellMiddleware.
|
||||
These will be merged with os.environ. Custom vars take precedence.
|
||||
|
||||
Returns:
|
||||
2-tuple of (agent_graph, composite_backend)
|
||||
@ -440,15 +447,18 @@ def create_custom_cli_agent(
|
||||
# Add shell middleware (only in local mode)
|
||||
if enable_shell:
|
||||
# Create environment for shell commands
|
||||
# Restore user's original LANGSMITH_PROJECT so their code traces separately
|
||||
shell_env = os.environ.copy()
|
||||
# Start with a copy of current environment
|
||||
final_shell_env = os.environ.copy()
|
||||
# Merge custom environment variables if provided (custom vars take precedence)
|
||||
if shell_env:
|
||||
final_shell_env.update(shell_env)
|
||||
# Use custom workspace_root if provided, otherwise use current directory
|
||||
shell_workspace = workspace_root if workspace_root is not None else str(Path.cwd())
|
||||
|
||||
agent_middleware.append(
|
||||
ShellMiddleware(
|
||||
workspace_root=shell_workspace,
|
||||
env=shell_env,
|
||||
env=final_shell_env,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@ -69,7 +69,7 @@ def format_datetime_by_language(language: str) -> str:
|
||||
return utc_now.strftime("%Y-%m-%d %H:%M:%S") + " UTC"
|
||||
|
||||
|
||||
async def load_system_prompt_async(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "general_agent", bot_id: str="", user_identifier: str = "") -> str:
|
||||
async def load_system_prompt_async(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "general_agent", bot_id: str="", user_identifier: str = "", trace_id: str = "") -> str:
|
||||
"""异步版本的系统prompt加载
|
||||
|
||||
Args:
|
||||
@ -79,6 +79,7 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
|
||||
robot_type: 机器人类型,取值 agent/catalog_agent
|
||||
bot_id: 机器人ID
|
||||
user_identifier: 用户标识符
|
||||
trace_id: 请求追踪ID,用于日志追踪
|
||||
|
||||
Returns:
|
||||
str: 加载到的系统提示词内容
|
||||
@ -127,7 +128,8 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
|
||||
language=language_display,
|
||||
user_identifier=user_identifier,
|
||||
datetime=datetime_str,
|
||||
agent_dir_path="."
|
||||
agent_dir_path=".",
|
||||
trace_id=trace_id or ""
|
||||
)
|
||||
elif system_prompt:
|
||||
prompt = system_prompt.format(language=language_display, user_identifier=user_identifier, datetime=datetime_str)
|
||||
|
||||
@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# Import route modules
|
||||
from routes import chat, files, projects, system, skill_manager
|
||||
from routes import chat, files, projects, system, skill_manager, database
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -174,6 +174,7 @@ app.include_router(files.router)
|
||||
app.include_router(projects.router)
|
||||
app.include_router(system.router)
|
||||
app.include_router(skill_manager.router)
|
||||
app.include_router(database.router)
|
||||
|
||||
# 注册文件管理API路由
|
||||
app.include_router(file_manager_router)
|
||||
|
||||
@ -47,7 +47,7 @@ When executing scripts from SKILL.md files, you MUST convert relative paths to a
|
||||
|
||||
- **`{agent_dir_path}/skills/`** - Skill packages with embedded scripts
|
||||
- **`{agent_dir_path}/dataset/`** - Store file datasets and document data
|
||||
- **`{agent_dir_path}/scripts/`** - Place generated executable scripts here (not skill scripts)
|
||||
- **`{agent_dir_path}/executable_code/`** - Place generated executable scripts here (not skill scripts)
|
||||
- **`{agent_dir_path}/download/`** - Store downloaded files and content
|
||||
|
||||
**Path Examples:**
|
||||
|
||||
291
routes/database.py
Normal file
291
routes/database.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""
|
||||
数据库操作 API 路由
|
||||
提供数据库迁移、表结构变更等功能
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent.db_pool_manager import get_db_pool_manager
|
||||
from utils.settings import MASTERKEY
|
||||
from utils.fastapi_utils import extract_api_key_from_auth
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def verify_database_auth(authorization: Optional[str]) -> None:
|
||||
"""
|
||||
验证数据库操作 API 的认证
|
||||
|
||||
直接使用 MASTERKEY 进行验证
|
||||
|
||||
Args:
|
||||
authorization: Authorization header 值
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401/403 错误
|
||||
"""
|
||||
# 提取提供的 token
|
||||
provided_token = extract_api_key_from_auth(authorization)
|
||||
|
||||
if not provided_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Authorization header is required"
|
||||
)
|
||||
|
||||
if provided_token != MASTERKEY:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid authorization token"
|
||||
)
|
||||
|
||||
|
||||
class DatabaseMigrationResponse(BaseModel):
|
||||
"""数据库迁移响应"""
|
||||
success: bool
|
||||
message: str
|
||||
steps_completed: list[str]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ExecuteSQLRequest(BaseModel):
|
||||
"""执行 SQL 请求"""
|
||||
sql: str
|
||||
autocommit: bool = True
|
||||
|
||||
|
||||
class ExecuteSQLResponse(BaseModel):
|
||||
"""执行 SQL 响应"""
|
||||
success: bool
|
||||
rows_affected: Optional[int] = None
|
||||
message: str
|
||||
columns: Optional[list[str]] = None
|
||||
data: Optional[list[list]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/api/v1/database/migrate-pgvector", response_model=DatabaseMigrationResponse)
|
||||
async def migrate_pgvector(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
执行 pgvector 扩展安装迁移
|
||||
|
||||
执行步骤:
|
||||
1. 将 public.vector 表重命名为 public.vector_legacy
|
||||
2. 创建 pgvector 扩展 (CREATE EXTENSION vector)
|
||||
|
||||
注意:此操作会修改数据库结构,请确保在执行前已做好备份。
|
||||
|
||||
Authentication:
|
||||
- Authorization header should contain: Bearer {MASTERKEY}
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationResponse: 迁移结果
|
||||
"""
|
||||
# 验证认证
|
||||
verify_database_auth(authorization)
|
||||
|
||||
steps_completed = []
|
||||
pool_manager = get_db_pool_manager()
|
||||
|
||||
try:
|
||||
# 获取异步连接
|
||||
pool = pool_manager.pool
|
||||
|
||||
# 步骤 1: 重命名 vector 表为 vector_legacy
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 检查 vector 表是否存在
|
||||
await cursor.execute("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = 'vector'
|
||||
)
|
||||
""")
|
||||
vector_exists = (await cursor.fetchone())[0]
|
||||
|
||||
if vector_exists:
|
||||
# 检查 vector_legacy 是否已存在
|
||||
await cursor.execute("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = 'vector_legacy'
|
||||
)
|
||||
""")
|
||||
legacy_exists = (await cursor.fetchone())[0]
|
||||
|
||||
if legacy_exists:
|
||||
steps_completed.append("vector_legacy 表已存在,跳过重命名")
|
||||
else:
|
||||
# 执行重命名
|
||||
await cursor.execute("ALTER TABLE public.vector RENAME TO vector_legacy")
|
||||
steps_completed.append("已将 public.vector 表重命名为 public.vector_legacy")
|
||||
logger.info("Renamed public.vector to public.vector_legacy")
|
||||
else:
|
||||
steps_completed.append("public.vector 表不存在,跳过重命名")
|
||||
|
||||
# 提交事务
|
||||
await conn.commit()
|
||||
|
||||
# 步骤 2: 创建 pgvector 扩展
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 检查 pgvector 扩展是否已安装
|
||||
await cursor.execute("""
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_extension WHERE extname = 'vector'
|
||||
)
|
||||
""")
|
||||
extension_exists = (await cursor.fetchone())[0]
|
||||
|
||||
if extension_exists:
|
||||
steps_completed.append("pgvector 扩展已安装")
|
||||
else:
|
||||
# 创建 pgvector 扩展
|
||||
await cursor.execute("CREATE EXTENSION vector")
|
||||
steps_completed.append("已成功安装 pgvector 扩展")
|
||||
logger.info("Created pgvector extension")
|
||||
|
||||
# 提交事务
|
||||
await conn.commit()
|
||||
|
||||
return DatabaseMigrationResponse(
|
||||
success=True,
|
||||
message="pgvector 迁移完成",
|
||||
steps_completed=steps_completed
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"pgvector 迁移失败: {e}")
|
||||
return DatabaseMigrationResponse(
|
||||
success=False,
|
||||
message="pgvector 迁移失败",
|
||||
steps_completed=steps_completed,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/database/execute-sql", response_model=ExecuteSQLResponse)
|
||||
async def execute_sql(request: ExecuteSQLRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
执行自定义 SQL 语句
|
||||
|
||||
注意:此接口具有较高权限,请谨慎使用。
|
||||
|
||||
Authentication:
|
||||
- Authorization header should contain: Bearer {MASTERKEY}
|
||||
|
||||
Args:
|
||||
request: 包含 SQL 语句和是否自动提交的请求
|
||||
|
||||
Returns:
|
||||
ExecuteSQLResponse: 执行结果
|
||||
"""
|
||||
# 验证认证
|
||||
verify_database_auth(authorization)
|
||||
|
||||
pool_manager = get_db_pool_manager()
|
||||
|
||||
try:
|
||||
pool = pool_manager.pool
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(request.sql)
|
||||
rows_affected = cursor.rowcount
|
||||
|
||||
# 获取列名
|
||||
columns = None
|
||||
data = None
|
||||
if cursor.description:
|
||||
columns = [desc.name for desc in cursor.description]
|
||||
# 获取所有行数据
|
||||
rows = await cursor.fetchall()
|
||||
data = [list(row) for row in rows]
|
||||
|
||||
if request.autocommit:
|
||||
await conn.commit()
|
||||
|
||||
return ExecuteSQLResponse(
|
||||
success=True,
|
||||
rows_affected=rows_affected,
|
||||
message=f"SQL 执行成功,影响行数: {rows_affected}, 返回数据: {len(data) if data else 0} 行",
|
||||
columns=columns,
|
||||
data=data
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"SQL 执行失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"SQL 执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/database/check-pgvector")
|
||||
async def check_pgvector(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
检查 pgvector 扩展安装状态
|
||||
|
||||
Authentication:
|
||||
- Authorization header should contain: Bearer {MASTERKEY}
|
||||
|
||||
Returns:
|
||||
pgvector 扩展的状态信息
|
||||
"""
|
||||
# 验证认证
|
||||
verify_database_auth(authorization)
|
||||
|
||||
pool_manager = get_db_pool_manager()
|
||||
|
||||
try:
|
||||
pool = pool_manager.pool
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 检查扩展是否存在
|
||||
await cursor.execute("""
|
||||
SELECT
|
||||
extname,
|
||||
extversion
|
||||
FROM pg_extension
|
||||
WHERE extname = 'vector'
|
||||
""")
|
||||
extension_result = await cursor.fetchone()
|
||||
|
||||
# 检查 vector 相关表
|
||||
await cursor.execute("""
|
||||
SELECT
|
||||
table_name,
|
||||
table_type
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name LIKE '%vector%'
|
||||
ORDER BY table_name
|
||||
""")
|
||||
tables = await cursor.fetchall()
|
||||
|
||||
return {
|
||||
"extension_installed": extension_result is not None,
|
||||
"extension_version": extension_result[1] if extension_result else None,
|
||||
"vector_tables": [
|
||||
{"name": row[0], "type": row[1]}
|
||||
for row in tables
|
||||
]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"检查 pgvector 状态失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"检查失败: {str(e)}"
|
||||
)
|
||||
@ -4,6 +4,7 @@ import shutil
|
||||
import zipfile
|
||||
import logging
|
||||
import asyncio
|
||||
import yaml
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
@ -294,14 +295,12 @@ def parse_skill_frontmatter(skill_md_path: str) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
frontmatter = frontmatter_match.group(1)
|
||||
metadata = {}
|
||||
|
||||
# Parse key: value pairs from frontmatter
|
||||
for line in frontmatter.split('\n'):
|
||||
line = line.strip()
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
metadata[key.strip()] = value.strip()
|
||||
# Parse YAML using yaml.safe_load
|
||||
metadata = yaml.safe_load(frontmatter)
|
||||
if not isinstance(metadata, dict):
|
||||
logger.warning(f"Invalid frontmatter format in {skill_md_path}")
|
||||
return None
|
||||
|
||||
# Return name and description if both exist
|
||||
if 'name' in metadata and 'description' in metadata:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user