openai api

This commit is contained in:
parent afe7600534
commit 10c2ef0bbc
BIN  __pycache__/agent_pool.cpython-312.pyc  Normal file  (binary file not shown)
BIN  __pycache__/fastapi_app.cpython-312.pyc  Normal file  (binary file not shown)
BIN  __pycache__/project_config.cpython-312.pyc  Normal file  (binary file not shown)
BIN  __pycache__/session_manager.cpython-312.pyc  Normal file  (binary file not shown)
178  agent_pool.py  Normal file
@@ -0,0 +1,178 @@
import asyncio
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)


class AgentPool:
    """Assistant instance pool manager."""

    def __init__(self, pool_size: int = 5):
        """
        Initialize the assistant instance pool.

        Args:
            pool_size: Number of instances kept in the pool, 5 by default.
        """
        self.pool_size = pool_size
        self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        self.semaphore = asyncio.Semaphore(pool_size)
        self.agents = []  # Keep a reference to every created instance

    async def initialize(self, agent_factory):
        """
        Fill the pool, creating assistant instances with the factory function.

        Args:
            agent_factory: Factory function that creates one assistant instance.
        """
        logger.info(f"正在初始化助手实例池,大小: {self.pool_size}")

        for i in range(self.pool_size):
            try:
                agent = agent_factory()
                await self.pool.put(agent)
                self.agents.append(agent)
                logger.info(f"助手实例 {i+1}/{self.pool_size} 创建成功")
            except Exception as e:
                logger.error(f"创建助手实例 {i+1} 失败: {e}")
                raise

        logger.info("助手实例池初始化完成")

    async def get_agent(self, timeout: Optional[float] = 30.0):
        """
        Acquire an idle assistant instance.

        Args:
            timeout: Acquisition timeout in seconds, 30 by default.

        Returns:
            An assistant instance.

        Raises:
            asyncio.TimeoutError: If acquisition times out.
        """
        try:
            # Use the semaphore to bound concurrency
            await asyncio.wait_for(self.semaphore.acquire(), timeout=timeout)
            # Take an instance from the pool
            agent = await asyncio.wait_for(self.pool.get(), timeout=timeout)
            logger.debug(f"成功获取助手实例,剩余池大小: {self.pool.qsize()}")
            return agent
        except asyncio.TimeoutError:
            logger.error(f"获取助手实例超时 ({timeout}秒)")
            raise

    async def release_agent(self, agent):
        """
        Return an assistant instance to the pool.

        Args:
            agent: The assistant instance to release.
        """
        try:
            await self.pool.put(agent)
            self.semaphore.release()
            logger.debug(f"释放助手实例,当前池大小: {self.pool.qsize()}")
        except Exception as e:
            logger.error(f"释放助手实例失败: {e}")
            # Release the semaphore even if putting the agent back failed
            self.semaphore.release()

    def get_pool_stats(self) -> dict:
        """
        Get pool status statistics.

        Returns:
            A dict describing the pool state.
        """
        return {
            "pool_size": self.pool_size,
            "available_agents": self.pool.qsize(),
            "total_agents": len(self.agents),
            "in_use_agents": len(self.agents) - self.pool.qsize()
        }

    async def shutdown(self):
        """Shut the pool down and clean up resources."""
        logger.info("正在关闭助手实例池...")

        # Drain the queue
        while not self.pool.empty():
            try:
                agent = self.pool.get_nowait()
                # Call the agent's cleanup hook if it has one
                if hasattr(agent, 'cleanup'):
                    await agent.cleanup()
            except asyncio.QueueEmpty:
                break

        logger.info("助手实例池已关闭")


# Global pool singleton
_global_agent_pool: Optional[AgentPool] = None


def get_agent_pool() -> Optional[AgentPool]:
    """Get the global assistant instance pool."""
    return _global_agent_pool


def set_agent_pool(pool: AgentPool):
    """Set the global assistant instance pool."""
    global _global_agent_pool
    _global_agent_pool = pool


async def init_global_agent_pool(pool_size: int = 5, agent_factory=None):
    """
    Initialize the global assistant instance pool.

    Args:
        pool_size: Pool size.
        agent_factory: Instance factory function.
    """
    global _global_agent_pool

    if _global_agent_pool is not None:
        logger.warning("全局助手实例池已存在,跳过初始化")
        return

    if agent_factory is None:
        raise ValueError("必须提供 agent_factory 参数")

    _global_agent_pool = AgentPool(pool_size=pool_size)
    await _global_agent_pool.initialize(agent_factory)
    logger.info("全局助手实例池初始化完成")


async def get_agent_from_pool(timeout: Optional[float] = 30.0):
    """
    Get an assistant instance from the global pool.

    Args:
        timeout: Acquisition timeout.

    Returns:
        An assistant instance.
    """
    if _global_agent_pool is None:
        raise RuntimeError("全局助手实例池未初始化")

    return await _global_agent_pool.get_agent(timeout)


async def release_agent_to_pool(agent):
    """
    Release an assistant instance back to the global pool.

    Args:
        agent: The assistant instance to release.
    """
    if _global_agent_pool is None:
        raise RuntimeError("全局助手实例池未初始化")

    await _global_agent_pool.release_agent(agent)
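A minimal usage sketch of the pool API above (the `make_agent` factory below is a hypothetical stand-in; the application wires in `init_agent_service_universal` instead):

```python
import asyncio
from agent_pool import get_agent_from_pool, init_global_agent_pool, release_agent_to_pool


def make_agent():
    # Hypothetical factory; in fastapi_app.py this is init_agent_service_universal().
    return object()


async def demo():
    await init_global_agent_pool(pool_size=2, agent_factory=make_agent)
    agent = await get_agent_from_pool(timeout=10.0)
    try:
        ...  # run the agent here
    finally:
        await release_agent_to_pool(agent)


asyncio.run(demo())
```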
@@ -17,7 +17,7 @@
 ### Data storage hierarchy
 ```
-./data/
+[当前数据目录]/
 ├── [数据集文件夹]/
 │   ├── schema.json          # inverted-index layer
 │   ├── serialization.txt    # serialized-data layer
@@ -28,7 +28,7 @@
 #### 1. Index layer (schema.json)
 - **Purpose**: inverted index over field enum values, the entry point for queries
-- **Access**: `json-reader-get_all_keys({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema"})`
+- **Access**: `json-reader-get_all_keys({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema"})`
 - **Data structure**:
 ```json
 {
@@ -66,14 +66,14 @@
 **Steps**:
 1. **Load the index**: read schema.json to get the field metadata
 2. **Field analysis**: identify numeric, text, and enum fields
-3. **Field detail analysis**: for the relevant fields call `json-reader-get_value({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})` to inspect the concrete enum values and value ranges
+3. **Field detail analysis**: for the relevant fields call `json-reader-get_value({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})` to inspect the concrete enum values and value ranges
 4. **Strategy**: choose the optimal retrieval path based on the query conditions
 5. **Range estimation**: estimate the data distribution and selectivity of each condition

 ### Phase 2: precise data matching
 **Goal**: extract the records that satisfy the conditions from the serialized data
 **Steps**:
-1. **Pre-check**: `ripgrep-count-matches({"path": "./data/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
+1. **Pre-check**: `ripgrep-count-matches({"path": "[当前数据目录]/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
 2. **Adaptive throttling**:
    - match count > 1000: add filter conditions and re-run the pre-check
    - match count 100-1000: `ripgrep-search({"maxResults": 30})`
@@ -193,3 +193,5 @@ query_pattern = simple_field_match(conditions[0])  # match the main condition first
 ---

 **Execution reminder**: always call the tools with full file-path arguments so that data access stays accurate and safe. Adjust the strategy dynamically during query execution to fit different data characteristics and query requirements.
+
+**Important**: the `[当前数据目录]` placeholder in all file paths is supplied dynamically through a system message; operate on the actual data directory it resolves to.
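To illustrate the flow above, a query might issue a tool-call sequence like the following (a sketch only; the dataset folder `products`, the field `brand`, and the search pattern are hypothetical):

```
json-reader-get_all_keys({"file_path": "[当前数据目录]/products/schema.json", "key_path": "schema"})
json-reader-get_value({"file_path": "[当前数据目录]/products/schema.json", "key_path": "schema.brand"})
ripgrep-count-matches({"path": "[当前数据目录]/products/serialization.txt", "pattern": "brand: HP"})
ripgrep-search({"path": "[当前数据目录]/products/serialization.txt", "pattern": "brand: HP", "maxResults": 30})
```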
BIN  data/all_hp_product_spec_book2506/.DS_Store  vendored  (binary file not shown)
306  fastapi_app.py
@@ -1,62 +1,247 @@
-from typing import Optional
-
-import uvicorn
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-
-from gbase_agent import init_agent_service
-
-app = FastAPI(title="Database Assistant API", version="1.0.0")
-
-# Initialize agent globally at startup
-bot = init_agent_service()
-
-
-class QueryRequest(BaseModel):
-    question: str
-    file_url: Optional[str] = None
-
-
-class QueryResponse(BaseModel):
-    answer: str
-
-
-@app.post("/query", response_model=QueryResponse)
-async def query_database(request: QueryRequest):
-    """
-    Process a database query using the assistant agent.
-
-    Args:
-        request: QueryRequest containing the query and optional file URL
-
-    Returns:
-        QueryResponse containing the assistant's response
-    """
-    try:
-        messages = []
-
-        if request.file_url:
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [{"text": "使用sqlite数据库,用日语回答下面问题:" + request.question}, {"file": request.file_url}],
-                }
-            )
-        else:
-            messages.append({"role": "user", "content": request.question})
-
-        responses = []
-        for response in bot.run(messages):
-            responses.append(response)
-
-        if responses:
-            final_response = responses[-1][-1]
-            return QueryResponse(answer=final_response["content"])
-        else:
-            raise HTTPException(status_code=500, detail="No response from agent")
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+import json
+import os
+from typing import AsyncGenerator, Dict, List, Optional, Union
+
+import uvicorn
+from fastapi import BackgroundTasks, FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from qwen_agent.llm.schema import ASSISTANT, FUNCTION
+
+
+# Custom version: no text parameter and no printing to the terminal
+def get_content_from_messages(messages: List[dict]) -> str:
+    full_text = ''
+    content = []
+    TOOL_CALL_S = '[TOOL_CALL]'
+    TOOL_RESULT_S = '[TOOL_RESPONSE]'
+    THOUGHT_S = '[THINK]'
+    ANSWER_S = '[ANSWER]'
+
+    for msg in messages:
+        if msg['role'] == ASSISTANT:
+            if msg.get('reasoning_content'):
+                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
+                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
+            if msg.get('content'):
+                assert isinstance(msg['content'], str), 'Now only supports text messages'
+                content.append(f'{ANSWER_S}\n{msg["content"]}')
+            if msg.get('function_call'):
+                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
+        elif msg['role'] == FUNCTION:
+            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
+        else:
+            raise TypeError
+    if content:
+        full_text = '\n'.join(content)
+
+    return full_text
+
+
+from agent_pool import (get_agent_from_pool, init_global_agent_pool,
+                        release_agent_to_pool)
+from gbase_agent import init_agent_service_universal, update_agent_llm
+from project_config import project_manager
+
+app = FastAPI(title="Database Assistant API", version="1.0.0")
+
+# Global assistant instance pool, initialized at application startup
+agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1"))
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class ChatRequest(BaseModel):
+    messages: List[Message]
+    model: str = "qwen3-next"
+    api_key: Optional[str] = None
+    extra: Optional[Dict] = None
+    stream: Optional[bool] = False
+    file_url: Optional[str] = None
+
+
+class ChatResponse(BaseModel):
+    choices: List[Dict]
+    usage: Optional[Dict] = None
+
+
+class ChatStreamResponse(BaseModel):
+    choices: List[Dict]
+    usage: Optional[Dict] = None
+
+
+async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
+    """Generate the streaming response."""
+    accumulated_content = ""
+    accumulated_args = ""
+    chunk_id = 0
+    try:
+        for response in agent.run(messages=messages):
+            previous_content = accumulated_content
+            accumulated_content = get_content_from_messages(response)
+
+            # Work out the newly added content
+            if accumulated_content.startswith(previous_content):
+                new_content = accumulated_content[len(previous_content):]
+            else:
+                new_content = accumulated_content
+                previous_content = ""
+
+            # Only emit a chunk when there is new content
+            if new_content:
+                chunk_id += 1
+                # Build an OpenAI-style streaming chunk
+                chunk_data = {
+                    "id": f"chatcmpl-{chunk_id}",
+                    "object": "chat.completion.chunk",
+                    "created": int(__import__('time').time()),
+                    "model": request.model,
+                    "choices": [{
+                        "index": 0,
+                        "delta": {
+                            "content": new_content
+                        },
+                        "finish_reason": None
+                    }]
+                }
+
+                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
+
+        # Send the final completion chunk
+        final_chunk = {
+            "id": f"chatcmpl-{chunk_id + 1}",
+            "object": "chat.completion.chunk",
+            "created": int(__import__('time').time()),
+            "model": request.model,
+            "choices": [{
+                "index": 0,
+                "delta": {},
+                "finish_reason": "stop"
+            }]
+        }
+        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
+
+        # Send the end-of-stream marker
+        yield "data: [DONE]\n\n"
+
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Error in generate_stream_response: {str(e)}")
+        print(f"Full traceback: {error_details}")
+
+        error_data = {
+            "error": {
+                "message": f"Stream error: {str(e)}",
+                "type": "internal_error"
+            }
+        }
+        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
+
+
+@app.post("/chat/completions")
+async def chat_completions(request: ChatRequest):
+    """
+    Chat completions API similar to OpenAI, supports both streaming and non-streaming
+
+    Args:
+        request: ChatRequest containing messages, model, project_id in extra field, etc.
+
+    Returns:
+        Union[ChatResponse, StreamingResponse]: Chat completion response or stream
+    """
+    agent = None
+    try:
+        # Get project_id from the extra field
+        if not request.extra or 'project_id' not in request.extra:
+            raise HTTPException(status_code=400, detail="project_id is required in extra field")
+
+        project_id = request.extra['project_id']
+
+        # Validate project access
+        if not project_manager.validate_project_access(project_id):
+            raise HTTPException(status_code=404, detail=f"Project {project_id} not found or inactive")
+
+        # Get the project data directory
+        project_dir = project_manager.get_project_dir(project_id)
+
+        # Acquire an assistant instance from the pool
+        agent = await get_agent_from_pool(timeout=30.0)
+
+        # Prepare the LLM config and drop project_id from the extra field
+        llm_extra = request.extra.copy() if request.extra else {}
+        llm_extra.pop('project_id', None)  # project_id is not passed on to the LLM
+
+        # Dynamically set the requested model; api_key and extra may be supplied by the caller
+        update_agent_llm(agent, request.model, request.api_key, llm_extra)
+
+        # Build the message context, including the project information
+        messages = [
+            # Project-information system message
+            {
+                "role": "user",
+                "content": f"当前项目ID: {project_id},数据目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}"
+            },
+            # Convert the user messages in bulk
+            *[{"role": msg.role, "content": msg.content} for msg in request.messages]
+        ]
+
+        # Return a streaming or non-streaming response depending on the stream flag
+        if request.stream:
+            return StreamingResponse(
+                generate_stream_response(agent, messages, request),
+                media_type="text/event-stream",
+                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
+            )
+        else:
+            # Non-streaming response
+            final_responses = agent.run_nonstream(messages)
+
+            if final_responses and len(final_responses) > 0:
+                # Take the last response
+                final_response = final_responses[-1]
+
+                # Convert a Message object to a dict if needed
+                if hasattr(final_response, 'model_dump'):
+                    final_response = final_response.model_dump()
+                elif hasattr(final_response, 'dict'):
+                    final_response = final_response.dict()
+
+                content = final_response.get("content", "")
+
+                # Build an OpenAI-style response
+                return ChatResponse(
+                    choices=[{
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": content
+                        },
+                        "finish_reason": "stop"
+                    }],
+                    usage={
+                        "prompt_tokens": sum(len(msg.content) for msg in request.messages),
+                        "completion_tokens": len(content),
+                        "total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
+                    }
+                )
+            else:
+                raise HTTPException(status_code=500, detail="No response from agent")
+
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Error in chat_completions: {str(e)}")
+        print(f"Full traceback: {error_details}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+    finally:
+        # Make sure the assistant instance goes back to the pool
+        if agent is not None:
+            await release_agent_to_pool(agent)
+
+
 @app.get("/")
@@ -65,6 +250,53 @@ async def root():
     return {"message": "Database Assistant API is running"}


+@app.get("/system/status")
+async def system_status():
+    """Return system status information."""
+    from agent_pool import get_agent_pool
+
+    pool = get_agent_pool()
+    pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0}
+
+    return {
+        "status": "running",
+        "storage_type": "Agent Pool API",
+        "agent_pool": {
+            "pool_size": pool_stats["pool_size"],
+            "available_agents": pool_stats["available_agents"],
+            "total_agents": pool_stats["total_agents"],
+            "in_use_agents": pool_stats["in_use_agents"]
+        }
+    }
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the assistant instance pool at application startup."""
+    print(f"正在启动FastAPI应用,初始化助手实例池(大小: {agent_pool_size})...")
+
+    try:
+        def agent_factory():
+            return init_agent_service_universal()
+
+        await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory)
+        print("助手实例池初始化完成!")
+    except Exception as e:
+        print(f"助手实例池初始化失败: {e}")
+        raise
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Clean up the instance pool on application shutdown."""
+    print("正在关闭应用,清理助手实例池...")
+
+    from agent_pool import get_agent_pool
+    pool = get_agent_pool()
+    if pool:
+        await pool.shutdown()
+    print("助手实例池清理完成!")
+
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
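For reference, a minimal client call against the new endpoint might look like this (a sketch; the host/port and the use of `requests` are assumptions, and `demo-project` is the sample project from projects/project_registry.json):

```python
import requests

payload = {
    "model": "qwen3-next",
    "stream": False,
    "extra": {"project_id": "demo-project"},  # project_id is required in the extra field
    "messages": [{"role": "user", "content": "数据库里有几张表"}],
}
resp = requests.post("http://localhost:8000/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```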
302  gbase_agent.py
@@ -16,7 +16,6 @@
 import argparse
 import asyncio
-import copy
 import json
 import os
 from typing import Dict, List, Optional, Union
@@ -30,118 +29,6 @@ from qwen_agent.utils.output_beautify import typewriter_print
 ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")


-class GPT4OChat(TextChatAtOAI):
-    """自定义 GPT-4o 聊天类,修复 tool_call_id 问题"""
-
-    def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]:
-        # 使用父类方法进行基础转换
-        messages = super().convert_messages_to_dicts(messages)
-
-        # 应用修复后的消息转换
-        messages = self._fixed_conv_qwen_agent_messages_to_oai(messages)
-
-        return messages
-
-    @staticmethod
-    def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]):
-        """修复后的消息转换方法,确保 tool 消息包含 tool_call_id 字段"""
-        new_messages = []
-        i = 0
-
-        while i < len(messages):
-            msg = messages[i]
-
-            if msg['role'] == ASSISTANT:
-                # 处理 assistant 消息
-                assistant_msg = {'role': 'assistant'}
-
-                # 设置 content
-                content = msg.get('content', '')
-                if isinstance(content, (list, dict)):
-                    assistant_msg['content'] = json.dumps(content, ensure_ascii=False)
-                elif content is None:
-                    assistant_msg['content'] = ''
-                else:
-                    assistant_msg['content'] = content
-
-                # 设置 reasoning_content
-                if msg.get('reasoning_content'):
-                    assistant_msg['reasoning_content'] = msg['reasoning_content']
-
-                # 检查是否需要构造 tool_calls
-                has_tool_call = False
-                tool_calls = []
-
-                # 情况1:当前消息有 function_call
-                if msg.get('function_call'):
-                    has_tool_call = True
-                    tool_calls.append({
-                        'id': msg.get('extra', {}).get('function_id', '1'),
-                        'type': 'function',
-                        'function': {
-                            'name': msg['function_call']['name'],
-                            'arguments': msg['function_call']['arguments']
-                        }
-                    })
-
-                # 注意:不再为孤立的 tool 消息构造虚假的 tool_call
-
-                if has_tool_call:
-                    assistant_msg['tool_calls'] = tool_calls
-                    new_messages.append(assistant_msg)
-
-                    # 检查后续是否有对应的 tool 消息
-                    if i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
-                        tool_msg = copy.deepcopy(messages[i + 1])
-                        # 确保 tool_call_id 匹配
-                        tool_msg['tool_call_id'] = tool_calls[0]['id']
-                        # 移除多余字段
-                        for field in ['id', 'extra', 'function_call']:
-                            if field in tool_msg:
-                                del tool_msg[field]
-                        # 确保 content 有效且为字符串
-                        content = tool_msg.get('content', '')
-                        if isinstance(content, (list, dict)):
-                            tool_msg['content'] = json.dumps(content, ensure_ascii=False)
-                        elif content is None:
-                            tool_msg['content'] = ''
-                        new_messages.append(tool_msg)
-                        i += 2
-                    else:
-                        i += 1
-                else:
-                    new_messages.append(assistant_msg)
-                    i += 1
-
-            elif msg['role'] == 'tool':
-                # 孤立的 tool 消息,转换为 assistant + user 消息序列
-                # 首先添加一个包含工具结果的 assistant 消息
-                assistant_result = {'role': 'assistant'}
-                content = msg.get('content', '')
-                if isinstance(content, (list, dict)):
-                    content = json.dumps(content, ensure_ascii=False)
-                assistant_result['content'] = f"工具查询结果: {content}"
-                new_messages.append(assistant_result)
-
-                # 然后添加一个 user 消息来继续对话
-                new_messages.append({'role': 'user', 'content': '请继续分析以上结果'})
-                i += 1
-
-            else:
-                # 处理其他角色消息
-                new_msg = copy.deepcopy(msg)
-
-                # 确保 content 有效且为字符串
-                content = new_msg.get('content', '')
-                if isinstance(content, (list, dict)):
-                    new_msg['content'] = json.dumps(content, ensure_ascii=False)
-                elif content is None:
-                    new_msg['content'] = ''
-
-                new_messages.append(new_msg)
-                i += 1
-
-        return new_messages
-
-
 def read_mcp_settings():
@@ -150,107 +37,70 @@ def read_mcp_settings():
     return mcp_settings_json


+def read_mcp_settings_with_project_restriction(project_data_dir: str):
+    """Read the MCP settings and add the project-directory restriction."""
+    with open("./mcp/mcp_settings.json", "r") as f:
+        mcp_settings_json = json.load(f)
+
+    # Add the project-directory restriction for json-reader
+    for server_config in mcp_settings_json:
+        if "mcpServers" in server_config:
+            for server_name, server_info in server_config["mcpServers"].items():
+                if server_name == "json-reader":
+                    # Pass the project-directory restriction through environment variables
+                    if "env" not in server_info:
+                        server_info["env"] = {}
+                    server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
+                    server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default"
+                    break
+
+    return mcp_settings_json
+
+
 def read_system_prompt():
     with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
         return f.read().strip()


+def read_system_prompt():
+    """Read the generic, stateless system prompt."""
+    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
+        return f.read().strip()
+
+
 def init_agent_service():
-    llm_cfg = {
-        "llama-33": {
-            "model": "gbase-llama-33",
-            "model_server": "http://llmapi:9009/v1",
-            "api_key": "any",
-        },
-        "gpt-oss-120b": {
-            "model": "openai/gpt-oss-120b",
-            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
-            "api_key": "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",
-            "generate_cfg": {
-                "use_raw_api": True,  # GPT-OSS true, Qwen false
-                "fncall_prompt_type": "nous",  # 使用 nous 风格的函数调用提示
-            }
-        },
-        "claude-3.7": {
-            "model": "claude-3-7-sonnet-20250219",
-            "model_server": "https://one.felo.me/v1",
-            "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
-            "generate_cfg": {
-                "use_raw_api": True,  # GPT-OSS true, Qwen false
-            },
-        },
-        "gpt-4o": {
-            "model": "gpt-4o",
-            "model_server": "https://one-dev.felo.me/v1",
-            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
-            "generate_cfg": {
-                "use_raw_api": True,  # 启用 raw_api 但使用自定义类修复 tool_call_id 问题
-                "fncall_prompt_type": "nous",  # 使用 nous 风格的函数调用提示
-            },
-        },
-        "Gpt-4o-back": {
-            "model_type": "oai",  # 使用 oai 类型以便使用自定义类
-            "model": "gpt-4o",
-            "model_server": "https://one-dev.felo.me/v1",
-            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
-            "generate_cfg": {
-                "use_raw_api": True,  # 启用 raw_api 但使用自定义类修复 tool_call_id 问题
-                "fncall_prompt_type": "nous",  # 使用 nous 风格的函数调用提示
-            },
-            # 使用自定义的 GPT4OChat 类
-            "llm_class": GPT4OChat,
-        },
-        "glm-45": {
-            "model_server": "https://open.bigmodel.cn/api/paas/v4",
-            "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
-            "model": "glm-4.5",
-            "generate_cfg": {
-                "use_raw_api": True,  # GPT-OSS true, Qwen false
-            },
-        },
-        "qwen3-next": {
-            "model": "qwen/qwen3-next-80b-a3b-instruct",
-            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
-            "api_key": "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",
-        },
-        "deepresearch": {
-            "model": "alibaba/tongyi-deepresearch-30b-a3b",
-            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
-            "api_key": "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",
-        },
-        "qwen3-coder": {
-            "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
-            "model_server": "https://api-inference.modelscope.cn/v1",  # base_url, also known as api_base
-            "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556",
-        },
-        "openrouter-gpt4o": {
-            "model": "openai/gpt-4o",
-            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
-            "api_key": "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",
-            "generate_cfg": {
-                "use_raw_api": True,  # GPT-OSS true, Qwen false
-                "fncall_prompt_type": "nous",  # 使用 nous 风格的函数调用提示
-            },
-        }
-    }
+    """Default initializer, kept for backward compatibility."""
+    return init_agent_service_universal("qwen3-next")
+
+
+def read_llm_config():
+    """Read the LLM configuration file."""
+    with open("./llm_config.json", "r") as f:
+        return json.load(f)
+
+
+def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
+    """Agent initializer that supports a project directory, kept for backward compatibility."""
+    llm_cfg = read_llm_config()
+
+    # Read the generic (stateless) system prompt
     system = read_system_prompt()

-    # 暂时禁用 MCP 工具以测试 GPT-4o
-    tools = read_mcp_settings()
-    # 使用自定义的 GPT-4o 配置
-    llm_instance = llm_cfg["qwen3-next"]
+    # Read the MCP tool settings
+    tools = read_mcp_settings_with_project_restriction(project_data_dir)
+
+    # Use the requested model configuration
+    if model_name not in llm_cfg:
+        raise ValueError(f"Model '{model_name}' not found in llm_config.json. Available models: {list(llm_cfg.keys())}")
+
+    llm_instance = llm_cfg[model_name]
     if "llm_class" in llm_instance:
         llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)

     bot = Assistant(
-        llm=llm_instance,  # 使用自定义的 GPT-4o 实例
-        name="数据库助手",
-        description="数据库查询",
+        llm=llm_instance,  # use the selected model instance
+        name=f"数据库助手-{project_id}",
+        description=f"项目 {project_id} 数据库查询",
         system_message=system,
         function_list=tools,
     )
@@ -258,6 +108,62 @@ def init_agent_service():
     return bot


+def init_agent_service_universal():
+    """Create a stateless, general-purpose assistant instance (default LLM, switchable at runtime)."""
+    llm_cfg = read_llm_config()
+
+    # Read the generic (stateless) system prompt
+    system = read_system_prompt()
+
+    # Read the base MCP tool settings (without project restrictions)
+    tools = read_mcp_settings()
+
+    # Create the assistant with the default model
+    default_model = "qwen3-next"  # default model
+    if default_model not in llm_cfg:
+        # Fall back to the first available model if the default is missing
+        default_model = list(llm_cfg.keys())[0]
+
+    llm_instance = llm_cfg[default_model]
+    if "llm_class" in llm_instance:
+        llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)
+
+    bot = Assistant(
+        llm=llm_instance,  # initialized with the default LLM
+        name="通用数据检索助手",
+        description="无状态通用数据检索助手",
+        system_message=system,
+        function_list=tools,
+    )
+
+    return bot
+
+
+def update_agent_llm(agent, model_name: str, api_key: str = None, extra: Dict = None):
+    """Dynamically update the agent's LLM; parameters may be supplied through the API."""
+
+    # Base configuration
+    llm_config = {
+        "model": model_name,
+        "api_key": api_key,
+    }
+    # Merge any extra parameters passed in through the API
+    if extra is not None:
+        llm_config.update(extra)
+
+    # Create the LLM instance, making sure it is not a plain dict
+    if "llm_class" in llm_config:
+        llm_instance = llm_config.get("llm_class", TextChatAtOAI)(llm_config)
+    else:
+        # Use the default TextChatAtOAI class
+        llm_instance = TextChatAtOAI(llm_config)
+
+    # Set the LLM on the agent
+    agent.llm = llm_instance
+
+    return agent
+
+
 def test(query="数据库里有几张表"):
     # Define the agent
     bot = init_agent_service()
65  llm_config.json  Normal file
@@ -0,0 +1,65 @@
{
  "llama-33": {
    "model": "gbase-llama-33",
    "model_server": "http://llmapi:9009/v1",
    "api_key": "any"
  },
  "gpt-oss-120b": {
    "model": "openai/gpt-oss-120b",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  },
  "claude-3.7": {
    "model": "claude-3-7-sonnet-20250219",
    "model_server": "https://one.felo.me/v1",
    "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
    "generate_cfg": {
      "use_raw_api": true
    }
  },
  "gpt-4o": {
    "model": "gpt-4o",
    "model_server": "https://one-dev.felo.me/v1",
    "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  },
  "glm-45": {
    "model_server": "https://open.bigmodel.cn/api/paas/v4",
    "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
    "model": "glm-4.5",
    "generate_cfg": {
      "use_raw_api": true
    }
  },
  "qwen3-next": {
    "model": "qwen/qwen3-next-80b-a3b-instruct",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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"
  },
  "deepresearch": {
    "model": "alibaba/tongyi-deepresearch-30b-a3b",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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"
  },
  "qwen3-coder": {
    "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
    "model_server": "https://api-inference.modelscope.cn/v1",
    "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556"
  },
  "openrouter-gpt4o": {
    "model": "openai/gpt-4o",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  }
}
BIN  mcp/__pycache__/json_reader_server.cpython-312.pyc  Normal file  (binary file not shown)
BIN  mcp/__pycache__/mcp_wrapper.cpython-312.pyc  Normal file  (binary file not shown)
61  mcp/directory_tree_wrapper_server.py  Normal file
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Directory-tree MCP wrapper server.
Provides safe directory-structure browsing, restricted to the project directory.
"""

import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request


async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_wrapped_request(request, "directory-tree-wrapper")

                # Write to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
@@ -12,6 +12,32 @@ import sys
 import asyncio
 from typing import Any, Dict, List


+def validate_file_path(file_path: str, allowed_dir: str) -> str:
+    """Check that a file path stays inside the allowed directory."""
+    # Convert to an absolute path
+    if not os.path.isabs(file_path):
+        file_path = os.path.abspath(file_path)
+
+    allowed_dir = os.path.abspath(allowed_dir)
+
+    # Check that the path is inside the allowed directory
+    if not file_path.startswith(allowed_dir):
+        raise ValueError(f"访问被拒绝: 路径 {file_path} 不在允许的目录 {allowed_dir} 内")
+
+    # Check for path-traversal attempts
+    if ".." in file_path:
+        raise ValueError(f"访问被拒绝: 检测到路径遍历攻击尝试")
+
+    return file_path
+
+
+def get_allowed_directory():
+    """Get the directory that may be accessed."""
+    # Read the project data directory from the environment
+    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
+    return os.path.abspath(project_dir)
+
+
 async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
     """Handle MCP request"""
     try:
@@ -130,9 +156,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
             }

         try:
-            # Convert relative path to absolute path
-            if not os.path.isabs(file_path):
-                file_path = os.path.abspath(file_path)
+            # Check that the file path is inside the allowed directory
+            allowed_dir = get_allowed_directory()
+            file_path = validate_file_path(file_path, allowed_dir)

             with open(file_path, 'r', encoding='utf-8') as f:
                 data = json.load(f)
@@ -222,9 +248,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
             }

         try:
-            # Convert relative path to absolute path
-            if not os.path.isabs(file_path):
-                file_path = os.path.abspath(file_path)
+            # Check that the file path is inside the allowed directory
+            allowed_dir = get_allowed_directory()
+            file_path = validate_file_path(file_path, allowed_dir)

             with open(file_path, 'r', encoding='utf-8') as f:
                 data = json.load(f)
@@ -307,9 +333,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
             }

         try:
-            # Convert relative path to absolute path
-            if not os.path.isabs(file_path):
-                file_path = os.path.abspath(file_path)
+            # Check that the file path is inside the allowed directory
+            allowed_dir = get_allowed_directory()
+            file_path = validate_file_path(file_path, allowed_dir)

             with open(file_path, 'r', encoding='utf-8') as f:
                 data = json.load(f)
@@ -8,13 +8,6 @@
         "@andredezzy/deep-directory-tree-mcp"
       ]
     },
-    "mcp-server-code-runner": {
-      "command": "npx",
-      "args": [
-        "-y",
-        "mcp-server-code-runner@latest"
-      ]
-    },
     "ripgrep": {
       "command": "npx",
       "args": [
308  mcp/mcp_wrapper.py  Normal file
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
Generic MCP tool wrapper.
Provides directory access control and safety restrictions for all MCP tools.
"""

import os
import sys
import json
import asyncio
import subprocess
from typing import Dict, Any, Optional
from pathlib import Path


class MCPSecurityWrapper:
    """MCP security wrapper."""

    def __init__(self, allowed_directory: str):
        self.allowed_directory = os.path.abspath(allowed_directory)
        self.project_id = os.getenv("PROJECT_ID", "default")

    def validate_path(self, path: str) -> str:
        """Check that a path stays inside the allowed directory."""
        if not os.path.isabs(path):
            path = os.path.abspath(path)

        # Normalize the path
        path = os.path.normpath(path)

        # Check for path traversal
        if ".." in path.split(os.sep):
            raise ValueError(f"路径遍历攻击被阻止: {path}")

        # Check that the path is inside the allowed directory
        if not path.startswith(self.allowed_directory):
            raise ValueError(f"访问被拒绝: {path} 不在允许的目录 {self.allowed_directory} 内")

        return path

    def safe_execute_command(self, command: list, cwd: Optional[str] = None) -> Dict[str, Any]:
        """Execute a command safely, restricting its working directory."""
        if cwd is None:
            cwd = self.allowed_directory
        else:
            cwd = self.validate_path(cwd)

        # Restrict the command through environment variables
        env = os.environ.copy()
        env["PWD"] = cwd
        env["PROJECT_DATA_DIR"] = self.allowed_directory
        env["PROJECT_ID"] = self.project_id

        try:
            result = subprocess.run(
                command,
                cwd=cwd,
                env=env,
                capture_output=True,
                text=True,
                timeout=30  # 30-second timeout
            )

            return {
                "success": result.returncode == 0,
                "stdout": result.stdout,
                "stderr": result.stderr,
                "returncode": result.returncode
            }
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "stdout": "",
                "stderr": "命令执行超时",
                "returncode": -1
            }
        except Exception as e:
            return {
                "success": False,
                "stdout": "",
                "stderr": str(e),
                "returncode": -1
            }


async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
    """Handle a wrapped MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        allowed_dir = get_allowed_directory()
        wrapper = MCPSecurityWrapper(allowed_dir)

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": f"{tool_name}-wrapper",
                        "version": "1.0.0"
                    }
                }
            }

        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }

        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": get_tool_definitions(tool_name)
                }
            }

        elif method == "tools/call":
            return await execute_tool_call(wrapper, tool_name, params, request_id)

        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}"
                }
            }

    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }


def get_allowed_directory():
    """Get the directory that may be accessed."""
    project_dir = os.getenv("PROJECT_DATA_DIR", "./data")
    return os.path.abspath(project_dir)


def get_tool_definitions(tool_name: str) -> list:
    """Return the tool definitions for the given wrapper name."""
    if tool_name == "ripgrep-wrapper":
        return [
            {
                "name": "ripgrep_search",
                "description": "在项目目录内搜索文本",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "pattern": {"type": "string", "description": "搜索模式"},
                        "path": {"type": "string", "description": "搜索路径(相对于项目目录)"},
                        "maxResults": {"type": "integer", "default": 100}
                    },
                    "required": ["pattern"]
                }
            }
        ]
    elif tool_name == "directory-tree-wrapper":
        return [
            {
                "name": "get_directory_tree",
                "description": "获取项目目录结构",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "目录路径(相对于项目目录)"},
                        "max_depth": {"type": "integer", "default": 3}
                    }
                }
            }
        ]
    else:
        return []


async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Execute a tool call."""
    try:
        if tool_name == "ripgrep-wrapper":
            return await execute_ripgrep_search(wrapper, params, request_id)
        elif tool_name == "directory-tree-wrapper":
            return await execute_directory_tree(wrapper, params, request_id)
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown tool: {tool_name}"
                }
            }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": str(e)
            }
        }


async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Run a ripgrep search."""
    pattern = params.get("pattern", "")
    path = params.get("path", ".")
    max_results = params.get("maxResults", 100)

    # Validate and build the search path
    search_path = os.path.join(wrapper.allowed_directory, path)
    search_path = wrapper.validate_path(search_path)

    # Build the ripgrep command
    command = [
        "rg",
        "--json",
        "--max-count", str(max_results),
        pattern,
        search_path
    ]

    result = wrapper.safe_execute_command(command)

    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"搜索失败: {result['stderr']}"
            }
        }


async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Produce the directory tree."""
    path = params.get("path", ".")
    max_depth = params.get("max_depth", 3)

    # Validate and build the directory path
    dir_path = os.path.join(wrapper.allowed_directory, path)
    dir_path = wrapper.validate_path(dir_path)

    # Build the directory-tree command
    command = [
        "find",
        dir_path,
        "-type", "d",
        "-maxdepth", str(max_depth)
    ]

    result = wrapper.safe_execute_command(command)

    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"获取目录树失败: {result['stderr']}"
            }
        }
61  mcp/ripgrep_wrapper_server.py  Normal file
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Ripgrep MCP wrapper server.
Provides safe text search, restricted to the project directory.
"""

import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request


async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_wrapped_request(request, "ripgrep-wrapper")

                # Write to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
154  project_config.py  Normal file
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Project configuration management.
Maps project IDs to data directories and controls project access.
"""

import json
import os
from typing import Dict, Optional, List
from dataclasses import dataclass, asdict


@dataclass
class ProjectConfig:
    """Project configuration dataclass."""
    project_id: str
    data_dir: str
    name: str
    description: str = ""
    allowed_file_types: List[str] = None
    max_file_size_mb: int = 100
    is_active: bool = True

    def __post_init__(self):
        if self.allowed_file_types is None:
            self.allowed_file_types = [".json", ".txt", ".csv", ".pdf"]


class ProjectManager:
    """Project manager."""

    def __init__(self, config_file: str = "./projects/project_registry.json"):
        self.config_file = config_file
        self.projects: Dict[str, ProjectConfig] = {}
        self._ensure_config_dir()
        self._load_projects()

    def _ensure_config_dir(self):
        """Make sure the config directory exists."""
        config_dir = os.path.dirname(self.config_file)
        if not os.path.exists(config_dir):
            os.makedirs(config_dir, exist_ok=True)

    def _load_projects(self):
        """Load projects from the config file."""
        if os.path.exists(self.config_file):
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    for project_data in data.get('projects', []):
                        config = ProjectConfig(**project_data)
                        self.projects[config.project_id] = config
            except Exception as e:
                print(f"加载项目配置失败: {e}")
                self._create_default_config()
        else:
            self._create_default_config()

    def _create_default_config(self):
        """Create the default configuration."""
        default_project = ProjectConfig(
            project_id="default",
            data_dir="./data",
            name="默认项目",
            description="默认数据项目"
        )
        self.projects["default"] = default_project
        self._save_projects()

    def _save_projects(self):
        """Save the project configuration to file."""
        data = {
            "projects": [asdict(project) for project in self.projects.values()]
        }
        try:
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存项目配置失败: {e}")

    def get_project(self, project_id: str) -> Optional[ProjectConfig]:
        """Get a project configuration."""
        return self.projects.get(project_id)

    def add_project(self, config: ProjectConfig) -> bool:
        """Add a project."""
        if config.project_id in self.projects:
            return False

        # Make sure the data directory exists
        if not os.path.isabs(config.data_dir):
            config.data_dir = os.path.abspath(config.data_dir)

        os.makedirs(config.data_dir, exist_ok=True)

        self.projects[config.project_id] = config
        self._save_projects()
        return True

    def update_project(self, project_id: str, **kwargs) -> bool:
        """Update a project configuration."""
        if project_id not in self.projects:
            return False

        project = self.projects[project_id]
        for key, value in kwargs.items():
            if hasattr(project, key):
                setattr(project, key, value)

        self._save_projects()
        return True

    def delete_project(self, project_id: str) -> bool:
        """Delete a project."""
        if project_id not in self.projects:
            return False

        del self.projects[project_id]
        self._save_projects()
        return True

    def list_projects(self) -> List[ProjectConfig]:
        """List all projects."""
        return list(self.projects.values())

    def get_project_dir(self, project_id: str) -> str:
        """Get a project's data directory."""
        project = self.get_project(project_id)
        if project:
            return project.data_dir

        # If the project does not exist, create the default directory layout
        default_dir = f"./projects/{project_id}/data"
        os.makedirs(default_dir, exist_ok=True)

        # Automatically register the new project
        new_project = ProjectConfig(
            project_id=project_id,
            data_dir=default_dir,
            name=f"项目 {project_id}",
            description=f"自动创建的项目 {project_id}"
        )
        self.add_project(new_project)

        return default_dir

    def validate_project_access(self, project_id: str) -> bool:
        """Check that the project may be accessed."""
        project = self.get_project(project_id)
        return project and project.is_active


# Global project-manager instance
project_manager = ProjectManager()
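A short usage sketch of the manager above (the project id and directory below are hypothetical):

```python
from project_config import ProjectConfig, project_manager

# Register a project and resolve its data directory.
project_manager.add_project(ProjectConfig(project_id="hp-specs",
                                          data_dir="./projects/hp-specs/data",
                                          name="HP spec book"))
print(project_manager.get_project_dir("hp-specs"))           # absolute path to the data dir
print(project_manager.validate_project_access("hp-specs"))   # True while is_active is set
```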
18  projects/project_registry.json  Normal file
@@ -0,0 +1,18 @@
{
  "projects": [
    {
      "project_id": "demo-project",
      "data_dir": "/Users/moshui/Documents/felo/qwen-agent/projects/demo-project/",
      "name": "演示项目",
      "description": "演示多项目隔离功能",
      "allowed_file_types": [
        ".json",
        ".txt",
        ".csv",
        ".pdf"
      ],
      "max_file_size_mb": 100,
      "is_active": true
    }
  ]
}