openai api

commit 10c2ef0bbc
parent afe7600534
BIN  __pycache__/agent_pool.cpython-312.pyc  (new file, binary not shown)
BIN  __pycache__/fastapi_app.cpython-312.pyc  (new file, binary not shown)
BIN  __pycache__/project_config.cpython-312.pyc  (new file, binary not shown)
BIN  __pycache__/session_manager.cpython-312.pyc  (new file, binary not shown)

178  agent_pool.py  (new file)
@@ -0,0 +1,178 @@
import asyncio
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)


class AgentPool:
    """Agent instance pool manager."""

    def __init__(self, pool_size: int = 5):
        """
        Initialize the agent instance pool.

        Args:
            pool_size: number of instances in the pool, default 5
        """
        self.pool_size = pool_size
        self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        self.semaphore = asyncio.Semaphore(pool_size)
        self.agents = []  # keep references to all created instances

    async def initialize(self, agent_factory):
        """
        Initialize the pool, creating agent instances with the factory function.

        Args:
            agent_factory: factory function that creates an agent instance
        """
        logger.info(f"Initializing agent pool, size: {self.pool_size}")

        for i in range(self.pool_size):
            try:
                agent = agent_factory()
                await self.pool.put(agent)
                self.agents.append(agent)
                logger.info(f"Agent instance {i+1}/{self.pool_size} created")
            except Exception as e:
                logger.error(f"Failed to create agent instance {i+1}: {e}")
                raise

        logger.info("Agent pool initialized")

    async def get_agent(self, timeout: Optional[float] = 30.0):
        """
        Get an idle agent instance.

        Args:
            timeout: acquisition timeout, default 30 seconds

        Returns:
            An agent instance

        Raises:
            asyncio.TimeoutError: acquisition timed out
        """
        try:
            # Use the semaphore to bound concurrency
            await asyncio.wait_for(self.semaphore.acquire(), timeout=timeout)
            # Take an instance from the pool
            agent = await asyncio.wait_for(self.pool.get(), timeout=timeout)
            logger.debug(f"Got agent instance, remaining pool size: {self.pool.qsize()}")
            return agent
        except asyncio.TimeoutError:
            logger.error(f"Timed out getting an agent instance ({timeout}s)")
            raise

    async def release_agent(self, agent):
        """
        Return an agent instance to the pool.

        Args:
            agent: the agent instance to release
        """
        try:
            await self.pool.put(agent)
            self.semaphore.release()
            logger.debug(f"Released agent instance, current pool size: {self.pool.qsize()}")
        except Exception as e:
            logger.error(f"Failed to release agent instance: {e}")
            # Release the semaphore even if returning the instance failed
            self.semaphore.release()

    def get_pool_stats(self) -> dict:
        """
        Get pool statistics.

        Returns:
            A dict with pool status information
        """
        return {
            "pool_size": self.pool_size,
            "available_agents": self.pool.qsize(),
            "total_agents": len(self.agents),
            "in_use_agents": len(self.agents) - self.pool.qsize()
        }

    async def shutdown(self):
        """Shut down the pool and clean up resources."""
        logger.info("Shutting down the agent pool...")

        # Drain the queue
        while not self.pool.empty():
            try:
                agent = self.pool.get_nowait()
                # Call cleanup if the instance provides it
                if hasattr(agent, 'cleanup'):
                    await agent.cleanup()
            except asyncio.QueueEmpty:
                break

        logger.info("Agent pool shut down")


# Global pool singleton
_global_agent_pool: Optional[AgentPool] = None


def get_agent_pool() -> Optional[AgentPool]:
    """Get the global agent pool."""
    return _global_agent_pool


def set_agent_pool(pool: AgentPool):
    """Set the global agent pool."""
    global _global_agent_pool
    _global_agent_pool = pool


async def init_global_agent_pool(pool_size: int = 5, agent_factory=None):
    """
    Initialize the global agent pool.

    Args:
        pool_size: pool size
        agent_factory: instance factory function
    """
    global _global_agent_pool

    if _global_agent_pool is not None:
        logger.warning("Global agent pool already exists, skipping initialization")
        return

    if agent_factory is None:
        raise ValueError("agent_factory must be provided")

    _global_agent_pool = AgentPool(pool_size=pool_size)
    await _global_agent_pool.initialize(agent_factory)
    logger.info("Global agent pool initialized")


async def get_agent_from_pool(timeout: Optional[float] = 30.0):
    """
    Get an agent instance from the global pool.

    Args:
        timeout: acquisition timeout

    Returns:
        An agent instance
    """
    if _global_agent_pool is None:
        raise RuntimeError("Global agent pool is not initialized")

    return await _global_agent_pool.get_agent(timeout)


async def release_agent_to_pool(agent):
    """
    Release an agent instance back to the global pool.

    Args:
        agent: the agent instance to release
    """
    if _global_agent_pool is None:
        raise RuntimeError("Global agent pool is not initialized")

    await _global_agent_pool.release_agent(agent)
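A minimal usage sketch of the acquire/release pattern this pool expects, mirroring how the FastAPI handler further down uses the global helpers; the `DummyAgent` factory is a hypothetical stand-in for the real qwen-agent instance:

```python
# Hypothetical sketch: per-request acquire/release around the global pool helpers.
import asyncio

from agent_pool import (get_agent_from_pool, init_global_agent_pool,
                        release_agent_to_pool)


class DummyAgent:
    """Stand-in for the real agent created by the factory."""

    def run_nonstream(self, messages):
        return [{"role": "assistant", "content": "ok"}]


async def handle_one_request():
    agent = await get_agent_from_pool(timeout=30.0)
    try:
        return agent.run_nonstream([{"role": "user", "content": "ping"}])
    finally:
        # Always give the instance back, even if the call above raised
        await release_agent_to_pool(agent)


async def main():
    await init_global_agent_pool(pool_size=2, agent_factory=DummyAgent)
    print(await handle_one_request())


asyncio.run(main())
```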
@@ -17,7 +17,7 @@

### Data storage hierarchy
```
./data/
[当前数据目录]/
├── [数据集文件夹]/
│   ├── schema.json          # inverted-index layer
│   ├── serialization.txt    # serialized-data layer
@@ -28,7 +28,7 @@

#### 1. Index layer (schema.json)
- **Purpose**: inverted index of field enum values; the query entry point
- **Access**: `json-reader-get_all_keys({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema"})`
- **Access**: `json-reader-get_all_keys({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema"})`
- **Data structure**:
```json
{
@@ -66,14 +66,14 @@
**Steps**:
1. **Load the index**: read schema.json to obtain the field metadata
2. **Field analysis**: identify numeric, text, and enum fields
3. **Field detail analysis**: for the relevant fields call `json-reader-get_value({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})` to inspect the concrete enum values and value ranges
3. **Field detail analysis**: for the relevant fields call `json-reader-get_value({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})` to inspect the concrete enum values and value ranges
4. **Strategy selection**: choose the optimal retrieval path based on the query conditions
5. **Range estimation**: estimate the data distribution and selectivity of each condition

### Phase 2: Precise data matching
**Goal**: extract the records that satisfy the conditions from the serialized data
**Steps**:
1. **Pre-check**: `ripgrep-count-matches({"path": "./data/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
1. **Pre-check**: `ripgrep-count-matches({"path": "[当前数据目录]/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
2. **Adaptive throttling**:
   - More than 1000 matches: add filter conditions and re-run the pre-check
   - 100–1000 matches: `ripgrep-search({"maxResults": 30})`
@@ -193,3 +193,5 @@ query_pattern = simple_field_match(conditions[0])  # match the primary condition first

---

**Execution reminder**: always call tools with complete file path arguments to keep data access accurate and safe. Adjust the strategy dynamically during query execution to fit different data characteristics and query needs.

**Important**: the `[当前数据目录]` placeholder in all file paths is supplied dynamically via a system message; operate on the actual data directory path it names.
BIN  data/all_hp_product_spec_book2506/.DS_Store  (vendored, binary not shown)

304  fastapi_app.py
@@ -1,62 +1,247 @@
from typing import Optional
import json
import os
from typing import AsyncGenerator, Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from qwen_agent.llm.schema import ASSISTANT, FUNCTION

from gbase_agent import init_agent_service

# Custom version: takes no text parameter and prints nothing to the terminal
def get_content_from_messages(messages: List[dict]) -> str:
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError
    if content:
        full_text = '\n'.join(content)

    return full_text

from agent_pool import (get_agent_from_pool, init_global_agent_pool,
                        release_agent_to_pool)
from gbase_agent import init_agent_service_universal, update_agent_llm
from project_config import project_manager

app = FastAPI(title="Database Assistant API", version="1.0.0")

# Initialize agent globally at startup
bot = init_agent_service()
# Global agent instance pool, initialized at application startup
agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1"))


class QueryRequest(BaseModel):
    question: str
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    model: str = "qwen3-next"
    api_key: Optional[str] = None
    extra: Optional[Dict] = None
    stream: Optional[bool] = False
    file_url: Optional[str] = None


class QueryResponse(BaseModel):
    answer: str
class ChatResponse(BaseModel):
    choices: List[Dict]
    usage: Optional[Dict] = None


@app.post("/query", response_model=QueryResponse)
async def query_database(request: QueryRequest):
    """
    Process a database query using the assistant agent.
class ChatStreamResponse(BaseModel):
    choices: List[Dict]
    usage: Optional[Dict] = None

    Args:
        request: QueryRequest containing the query and optional file URL

    Returns:
        QueryResponse containing the assistant's response
    """
async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Generate a streaming response."""
    accumulated_content = ""
    accumulated_args = ""
    chunk_id = 0
    try:
        messages = []
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)

        if request.file_url:
            messages.append(
                {
                    "role": "user",
                    "content": [{"text": "使用sqlite数据库,用日语回答下面问题:" + request.question}, {"file": request.file_url}],
            # Work out the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
                previous_content = ""

            # Only emit a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build an OpenAI-style streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(__import__('time').time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }
            )
        else:
            messages.append({"role": "user", "content": request.question})

        responses = []
        for response in bot.run(messages):
            responses.append(response)
                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        if responses:
            final_response = responses[-1][-1]
            return QueryResponse(answer=final_response["content"])
        else:
            raise HTTPException(status_code=500, detail="No response from agent")
        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(__import__('time').time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the end-of-stream marker
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"


@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    Args:
        request: ChatRequest containing messages, model, project_id in extra field, etc.

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream
    """
    agent = None
    try:
        # Get project_id from the extra field
        if not request.extra or 'project_id' not in request.extra:
            raise HTTPException(status_code=400, detail="project_id is required in extra field")

        project_id = request.extra['project_id']

        # Validate project access
        if not project_manager.validate_project_access(project_id):
            raise HTTPException(status_code=404, detail=f"Project {project_id} not found or inactive")

        # Get the project data directory
        project_dir = project_manager.get_project_dir(project_id)

        # Get an agent instance from the pool
        agent = await get_agent_from_pool(timeout=30.0)

        # Prepare the LLM config; strip project_id from the extra field
        llm_extra = request.extra.copy() if request.extra else {}
        llm_extra.pop('project_id', None)  # remove project_id, do not pass it to the LLM

        # Dynamically set the requested model; api_key and extra can be passed in via the API
        update_agent_llm(agent, request.model, request.api_key, llm_extra)

        # Build the message context including the project information
        messages = [
            # Project-info system message
            {
                "role": "user",
                "content": f"当前项目ID: {project_id},数据目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}"
            },
            # Bulk-convert the user messages
            *[{"role": msg.role, "content": msg.content} for msg in request.messages]
        ]

        # Return a streaming or non-streaming response depending on the stream flag
        if request.stream:
            return StreamingResponse(
                generate_stream_response(agent, messages, request),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            # Non-streaming response
            final_responses = agent.run_nonstream(messages)

            if final_responses and len(final_responses) > 0:
                # Take the last response
                final_response = final_responses[-1]

                # If a Message object is returned, convert it to a dict
                if hasattr(final_response, 'model_dump'):
                    final_response = final_response.model_dump()
                elif hasattr(final_response, 'dict'):
                    final_response = final_response.dict()

                content = final_response.get("content", "")

                # Build an OpenAI-style response
                return ChatResponse(
                    choices=[{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content
                        },
                        "finish_reason": "stop"
                    }],
                    usage={
                        "prompt_tokens": sum(len(msg.content) for msg in request.messages),
                        "completion_tokens": len(content),
                        "total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
                    }
                )
            else:
                raise HTTPException(status_code=500, detail="No response from agent")

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
    finally:
        # Always return the agent instance to the pool
        if agent is not None:
            await release_agent_to_pool(agent)


@app.get("/")
@@ -65,6 +250,53 @@ async def root():
    return {"message": "Database Assistant API is running"}


@app.get("/system/status")
async def system_status():
    """Get system status information."""
    from agent_pool import get_agent_pool

    pool = get_agent_pool()
    pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0}

    return {
        "status": "running",
        "storage_type": "Agent Pool API",
        "agent_pool": {
            "pool_size": pool_stats["pool_size"],
            "available_agents": pool_stats["available_agents"],
            "total_agents": pool_stats["total_agents"],
            "in_use_agents": pool_stats["in_use_agents"]
        }
    }


@app.on_event("startup")
async def startup_event():
    """Initialize the agent pool when the application starts."""
    print(f"Starting the FastAPI app, initializing the agent pool (size: {agent_pool_size})...")

    try:
        def agent_factory():
            return init_agent_service_universal()

        await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory)
        print("Agent pool initialized!")
    except Exception as e:
        print(f"Agent pool initialization failed: {e}")
        raise


@app.on_event("shutdown")
async def shutdown_event():
    """Clean up the agent pool when the application shuts down."""
    print("Shutting down, cleaning up the agent pool...")

    from agent_pool import get_agent_pool
    pool = get_agent_pool()
    if pool:
        await pool.shutdown()
    print("Agent pool cleaned up!")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
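For reference, a hypothetical client call against the endpoint above; the host, port, and "demo-project" id are assumptions, and the payload shape follows the ChatRequest model in this diff:

```python
# Hypothetical client sketch: stream an answer from /chat/completions.
import json
import requests

payload = {
    "model": "qwen3-next",
    "stream": True,
    "extra": {"project_id": "demo-project"},   # required; stripped before reaching the LLM
    "messages": [{"role": "user", "content": "How many records match brand=HP?"}],
}

with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"].get("content", ""), end="")
```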
300  gbase_agent.py
@@ -16,7 +16,6 @@

import argparse
import asyncio
import copy
import json
import os
from typing import Dict, List, Optional, Union
@@ -30,118 +29,6 @@ from qwen_agent.utils.output_beautify import typewriter_print

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")


class GPT4OChat(TextChatAtOAI):
    """Custom GPT-4o chat class that fixes the tool_call_id problem."""

    def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]:
        # Use the parent implementation for the base conversion
        messages = super().convert_messages_to_dicts(messages)

        # Apply the fixed message conversion
        messages = self._fixed_conv_qwen_agent_messages_to_oai(messages)

        return messages

    @staticmethod
    def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]):
        """Fixed message conversion that makes sure tool messages carry a tool_call_id field."""
        new_messages = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            if msg['role'] == ASSISTANT:
                # Handle assistant messages
                assistant_msg = {'role': 'assistant'}

                # Set content
                content = msg.get('content', '')
                if isinstance(content, (list, dict)):
                    assistant_msg['content'] = json.dumps(content, ensure_ascii=False)
                elif content is None:
                    assistant_msg['content'] = ''
                else:
                    assistant_msg['content'] = content

                # Set reasoning_content
                if msg.get('reasoning_content'):
                    assistant_msg['reasoning_content'] = msg['reasoning_content']

                # Check whether tool_calls need to be constructed
                has_tool_call = False
                tool_calls = []

                # Case 1: the current message has a function_call
                if msg.get('function_call'):
                    has_tool_call = True
                    tool_calls.append({
                        'id': msg.get('extra', {}).get('function_id', '1'),
                        'type': 'function',
                        'function': {
                            'name': msg['function_call']['name'],
                            'arguments': msg['function_call']['arguments']
                        }
                    })

                # Note: no longer fabricate a tool_call for orphaned tool messages

                if has_tool_call:
                    assistant_msg['tool_calls'] = tool_calls
                    new_messages.append(assistant_msg)

                    # Check whether the next message is the matching tool message
                    if i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
                        tool_msg = copy.deepcopy(messages[i + 1])
                        # Make sure the tool_call_id matches
                        tool_msg['tool_call_id'] = tool_calls[0]['id']
                        # Remove extraneous fields
                        for field in ['id', 'extra', 'function_call']:
                            if field in tool_msg:
                                del tool_msg[field]
                        # Make sure content is valid and a string
                        content = tool_msg.get('content', '')
                        if isinstance(content, (list, dict)):
                            tool_msg['content'] = json.dumps(content, ensure_ascii=False)
                        elif content is None:
                            tool_msg['content'] = ''
                        new_messages.append(tool_msg)
                        i += 2
                    else:
                        i += 1
                else:
                    new_messages.append(assistant_msg)
                    i += 1

            elif msg['role'] == 'tool':
                # Orphaned tool message: convert it into an assistant + user message pair
                # First add an assistant message containing the tool result
                assistant_result = {'role': 'assistant'}
                content = msg.get('content', '')
                if isinstance(content, (list, dict)):
                    content = json.dumps(content, ensure_ascii=False)
                assistant_result['content'] = f"工具查询结果: {content}"
                new_messages.append(assistant_result)

                # Then add a user message to continue the conversation
                new_messages.append({'role': 'user', 'content': '请继续分析以上结果'})
                i += 1

            else:
                # Handle messages with other roles
                new_msg = copy.deepcopy(msg)

                # Make sure content is valid and a string
                content = new_msg.get('content', '')
                if isinstance(content, (list, dict)):
                    new_msg['content'] = json.dumps(content, ensure_ascii=False)
                elif content is None:
                    new_msg['content'] = ''

                new_messages.append(new_msg)
                i += 1

        return new_messages


def read_mcp_settings():
@@ -150,107 +37,70 @@ def read_mcp_settings():
    return mcp_settings_json


def read_mcp_settings_with_project_restriction(project_data_dir: str):
    """Read the MCP settings and add the project-directory restriction."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)

    # Add the project-directory restriction for json-reader
    for server_config in mcp_settings_json:
        if "mcpServers" in server_config:
            for server_name, server_info in server_config["mcpServers"].items():
                if server_name == "json-reader":
                    # Pass the project-directory restriction via environment variables
                    if "env" not in server_info:
                        server_info["env"] = {}
                    server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
                    server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default"
                    break

    return mcp_settings_json


def read_system_prompt():
    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_system_prompt():
    """Read the universal, stateless system prompt."""
    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
        return f.read().strip()


def init_agent_service():
    llm_cfg = {
        "llama-33": {
            "model": "gbase-llama-33",
            "model_server": "http://llmapi:9009/v1",
            "api_key": "any",
        },
        "gpt-oss-120b": {
            "model": "openai/gpt-oss-120b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
    """Default initialization, kept for backward compatibility."""
    return init_agent_service_universal("qwen3-next")

            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            }
        },

        "claude-3.7": {
            "model": "claude-3-7-sonnet-20250219",
            "model_server": "https://one.felo.me/v1",
            "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
            },
        },
        "gpt-4o": {
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
            "generate_cfg": {
                "use_raw_api": True,  # enable raw_api; the custom class fixes the tool_call_id problem
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
        },
        "Gpt-4o-back": {
            "model_type": "oai",  # use the oai type so the custom class can be used
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
            "generate_cfg": {
                "use_raw_api": True,  # enable raw_api; the custom class fixes the tool_call_id problem
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
            # Use the custom GPT4OChat class
            "llm_class": GPT4OChat,
        },
def read_llm_config():
    """Read the LLM configuration file."""
    with open("./llm_config.json", "r") as f:
        return json.load(f)

        "glm-45": {
            "model_server": "https://open.bigmodel.cn/api/paas/v4",
            "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
            "model": "glm-4.5",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
            },
        },
        "qwen3-next": {
            "model": "qwen/qwen3-next-80b-a3b-instruct",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",

        },
        "deepresearch": {
            "model": "alibaba/tongyi-deepresearch-30b-a3b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
        },
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
    """Agent initialization with a project directory - kept for backward compatibility."""
    llm_cfg = read_llm_config()

        "qwen3-coder": {
            "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
            "model_server": "https://api-inference.modelscope.cn/v1",  # base_url, also known as api_base
            "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556",
        },
        "openrouter-gpt4o": {
            "model": "openai/gpt-4o",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
        }
    }
    # Read the universal (stateless) system prompt
    system = read_system_prompt()

    # Temporarily disable MCP tools to test GPT-4o
    tools = read_mcp_settings()
    # Use the custom GPT-4o configuration
    llm_instance = llm_cfg["qwen3-next"]
    # Read the MCP tool configuration
    tools = read_mcp_settings_with_project_restriction(project_data_dir)

    # Use the requested model configuration
    if model_name not in llm_cfg:
        raise ValueError(f"Model '{model_name}' not found in llm_config.json. Available models: {list(llm_cfg.keys())}")

    llm_instance = llm_cfg[model_name]
    if "llm_class" in llm_instance:
        llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)

    bot = Assistant(
        llm=llm_instance,  # use the custom GPT-4o instance
        name="数据库助手",
        description="数据库查询",
        llm=llm_instance,  # use the requested model instance
        name=f"数据库助手-{project_id}",
        description=f"项目 {project_id} 数据库查询",
        system_message=system,
        function_list=tools,
    )
@@ -258,6 +108,62 @@ def init_agent_service():

    return bot


def init_agent_service_universal():
    """Create a stateless universal agent instance (default LLM, switchable at runtime)."""
    llm_cfg = read_llm_config()

    # Read the universal (stateless) system prompt
    system = read_system_prompt()

    # Read the base MCP tool configuration (no project restriction)
    tools = read_mcp_settings()

    # Create the agent with the default model
    default_model = "qwen3-next"  # default model
    if default_model not in llm_cfg:
        # If the default model is missing, fall back to the first available one
        default_model = list(llm_cfg.keys())[0]

    llm_instance = llm_cfg[default_model]
    if "llm_class" in llm_instance:
        llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)

    bot = Assistant(
        llm=llm_instance,  # initialize with the default LLM
        name="通用数据检索助手",
        description="无状态通用数据检索助手",
        system_message=system,
        function_list=tools,
    )

    return bot


def update_agent_llm(agent, model_name: str, api_key: str = None, extra: Dict = None):
    """Dynamically update the agent's LLM, with parameters passed in from the API."""

    # Base configuration
    llm_config = {
        "model": model_name,
        "api_key": api_key,
    }
    # Merge extra parameters passed in from the API, if any
    if extra is not None:
        llm_config.update(extra)

    # Create the LLM instance (make sure it is not a plain dict)
    if "llm_class" in llm_config:
        llm_instance = llm_config.get("llm_class", TextChatAtOAI)(llm_config)
    else:
        # Use the default TextChatAtOAI class
        llm_instance = TextChatAtOAI(llm_config)

    # Set the LLM dynamically
    agent.llm = llm_instance

    return agent


def test(query="数据库里有几张表"):
    # Define the agent
    bot = init_agent_service()
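A small sketch of how the universal instance and update_agent_llm are meant to work together when a pooled agent is re-pointed at the model a caller requests; the model name, key, and extra fields here are placeholders:

```python
# Hypothetical sketch: re-point a pooled universal agent at a caller-chosen backend.
from gbase_agent import init_agent_service_universal, update_agent_llm

agent = init_agent_service_universal()          # built once, e.g. by the pool factory
update_agent_llm(
    agent,
    model_name="qwen/qwen3-next-80b-a3b-instruct",
    api_key="sk-...",                            # placeholder key
    extra={"model_server": "https://openrouter.ai/api/v1"},
)
responses = agent.run_nonstream([{"role": "user", "content": "数据库里有几张表"}])
print(responses[-1])
```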
65  llm_config.json  (new file)
@@ -0,0 +1,65 @@
{
  "llama-33": {
    "model": "gbase-llama-33",
    "model_server": "http://llmapi:9009/v1",
    "api_key": "any"
  },
  "gpt-oss-120b": {
    "model": "openai/gpt-oss-120b",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  },
  "claude-3.7": {
    "model": "claude-3-7-sonnet-20250219",
    "model_server": "https://one.felo.me/v1",
    "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
    "generate_cfg": {
      "use_raw_api": true
    }
  },
  "gpt-4o": {
    "model": "gpt-4o",
    "model_server": "https://one-dev.felo.me/v1",
    "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  },
  "glm-45": {
    "model_server": "https://open.bigmodel.cn/api/paas/v4",
    "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
    "model": "glm-4.5",
    "generate_cfg": {
      "use_raw_api": true
    }
  },
  "qwen3-next": {
    "model": "qwen/qwen3-next-80b-a3b-instruct",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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"
  },
  "deepresearch": {
    "model": "alibaba/tongyi-deepresearch-30b-a3b",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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"
  },
  "qwen3-coder": {
    "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
    "model_server": "https://api-inference.modelscope.cn/v1",
    "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556"
  },
  "openrouter-gpt4o": {
    "model": "openai/gpt-4o",
    "model_server": "https://openrouter.ai/api/v1",
    "api_key": "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",
    "generate_cfg": {
      "use_raw_api": true,
      "fncall_prompt_type": "nous"
    }
  }
}
BIN  mcp/__pycache__/json_reader_server.cpython-312.pyc  (new file, binary not shown)
BIN  mcp/__pycache__/mcp_wrapper.cpython-312.pyc  (new file, binary not shown)

61  mcp/directory_tree_wrapper_server.py  (new file)
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Directory-tree MCP wrapper server.
Provides safe directory-structure viewing, restricted to the project directory.
"""

import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request


async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_wrapped_request(request, "directory-tree-wrapper")

                # Write to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
@@ -12,6 +12,32 @@ import sys
import asyncio
from typing import Any, Dict, List


def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path is inside the allowed directory."""
    # Convert to an absolute path
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)

    allowed_dir = os.path.abspath(allowed_dir)

    # Check that the path is inside the allowed directory
    if not file_path.startswith(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not inside the allowed directory {allowed_dir}")

    # Check for path traversal attacks
    if ".." in file_path:
        raise ValueError(f"Access denied: path traversal attempt detected")

    return file_path


def get_allowed_directory():
    """Get the directory that may be accessed."""
    # Read the project data directory from the environment
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle MCP request"""
    try:
@@ -130,9 +156,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
            }

        try:
            # Convert relative path to absolute path
            if not os.path.isabs(file_path):
                file_path = os.path.abspath(file_path)
            # Validate that the file path is inside the allowed directory
            allowed_dir = get_allowed_directory()
            file_path = validate_file_path(file_path, allowed_dir)

            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
@@ -222,9 +248,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
            }

        try:
            # Convert relative path to absolute path
            if not os.path.isabs(file_path):
                file_path = os.path.abspath(file_path)
            # Validate that the file path is inside the allowed directory
            allowed_dir = get_allowed_directory()
            file_path = validate_file_path(file_path, allowed_dir)

            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
@@ -307,9 +333,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
            }

        try:
            # Convert relative path to absolute path
            if not os.path.isabs(file_path):
                file_path = os.path.abspath(file_path)
            # Validate that the file path is inside the allowed directory
            allowed_dir = get_allowed_directory()
            file_path = validate_file_path(file_path, allowed_dir)

            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
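To illustrate what the validation added above does, a hypothetical check; the paths and the environment value are assumptions, and the two helpers are the ones defined in this hunk:

```python
# Hypothetical illustration of the path restriction; the directory is an assumed location.
import os

os.environ["PROJECT_DATA_DIR"] = "/srv/projects/demo-project/data"

allowed = get_allowed_directory()
print(validate_file_path("/srv/projects/demo-project/data/schema.json", allowed))  # accepted

try:
    validate_file_path("/etc/passwd", allowed)
except ValueError as e:
    print(e)  # rejected: outside the allowed directory
```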
@@ -8,13 +8,6 @@
        "@andredezzy/deep-directory-tree-mcp"
      ]
    },
    "mcp-server-code-runner": {
      "command": "npx",
      "args": [
        "-y",
        "mcp-server-code-runner@latest"
      ]
    },
    "ripgrep": {
      "command": "npx",
      "args": [
308  mcp/mcp_wrapper.py  (new file)
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
Generic MCP tool wrapper.
Provides directory access control and safety restrictions for all MCP tools.
"""

import os
import sys
import json
import asyncio
import subprocess
from typing import Dict, Any, Optional
from pathlib import Path


class MCPSecurityWrapper:
    """MCP security wrapper."""

    def __init__(self, allowed_directory: str):
        self.allowed_directory = os.path.abspath(allowed_directory)
        self.project_id = os.getenv("PROJECT_ID", "default")

    def validate_path(self, path: str) -> str:
        """Validate that a path is inside the allowed directory."""
        if not os.path.isabs(path):
            path = os.path.abspath(path)

        # Normalize the path
        path = os.path.normpath(path)

        # Check for path traversal
        if ".." in path.split(os.sep):
            raise ValueError(f"Path traversal attempt blocked: {path}")

        # Check that the path is inside the allowed directory
        if not path.startswith(self.allowed_directory):
            raise ValueError(f"Access denied: {path} is not inside the allowed directory {self.allowed_directory}")

        return path

    def safe_execute_command(self, command: list, cwd: Optional[str] = None) -> Dict[str, Any]:
        """Safely execute a command with a restricted working directory."""
        if cwd is None:
            cwd = self.allowed_directory
        else:
            cwd = self.validate_path(cwd)

        # Constrain the environment
        env = os.environ.copy()
        env["PWD"] = cwd
        env["PROJECT_DATA_DIR"] = self.allowed_directory
        env["PROJECT_ID"] = self.project_id

        try:
            result = subprocess.run(
                command,
                cwd=cwd,
                env=env,
                capture_output=True,
                text=True,
                timeout=30  # 30-second timeout
            )

            return {
                "success": result.returncode == 0,
                "stdout": result.stdout,
                "stderr": result.stderr,
                "returncode": result.returncode
            }
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "stdout": "",
                "stderr": "Command timed out",
                "returncode": -1
            }
        except Exception as e:
            return {
                "success": False,
                "stdout": "",
                "stderr": str(e),
                "returncode": -1
            }


async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
    """Handle a wrapped MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        allowed_dir = get_allowed_directory()
        wrapper = MCPSecurityWrapper(allowed_dir)

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": f"{tool_name}-wrapper",
                        "version": "1.0.0"
                    }
                }
            }

        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }

        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": get_tool_definitions(tool_name)
                }
            }

        elif method == "tools/call":
            return await execute_tool_call(wrapper, tool_name, params, request_id)

        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}"
                }
            }

    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }


def get_allowed_directory():
    """Get the directory that may be accessed."""
    project_dir = os.getenv("PROJECT_DATA_DIR", "./data")
    return os.path.abspath(project_dir)


def get_tool_definitions(tool_name: str) -> list:
    """Return tool definitions for the given tool name."""
    if tool_name == "ripgrep-wrapper":
        return [
            {
                "name": "ripgrep_search",
                "description": "在项目目录内搜索文本",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "pattern": {"type": "string", "description": "搜索模式"},
                        "path": {"type": "string", "description": "搜索路径(相对于项目目录)"},
                        "maxResults": {"type": "integer", "default": 100}
                    },
                    "required": ["pattern"]
                }
            }
        ]
    elif tool_name == "directory-tree-wrapper":
        return [
            {
                "name": "get_directory_tree",
                "description": "获取项目目录结构",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "目录路径(相对于项目目录)"},
                        "max_depth": {"type": "integer", "default": 3}
                    }
                }
            }
        ]
    else:
        return []


async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Dispatch a tool call."""
    try:
        if tool_name == "ripgrep-wrapper":
            return await execute_ripgrep_search(wrapper, params, request_id)
        elif tool_name == "directory-tree-wrapper":
            return await execute_directory_tree(wrapper, params, request_id)
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown tool: {tool_name}"
                }
            }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": str(e)
            }
        }


async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Run a ripgrep search."""
    pattern = params.get("pattern", "")
    path = params.get("path", ".")
    max_results = params.get("maxResults", 100)

    # Validate and build the search path
    search_path = os.path.join(wrapper.allowed_directory, path)
    search_path = wrapper.validate_path(search_path)

    # Build the ripgrep command
    command = [
        "rg",
        "--json",
        "--max-count", str(max_results),
        pattern,
        search_path
    ]

    result = wrapper.safe_execute_command(command)

    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"Search failed: {result['stderr']}"
            }
        }


async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Get the directory tree."""
    path = params.get("path", ".")
    max_depth = params.get("max_depth", 3)

    # Validate and build the directory path
    dir_path = os.path.join(wrapper.allowed_directory, path)
    dir_path = wrapper.validate_path(dir_path)

    # Build the directory-tree command
    command = [
        "find",
        dir_path,
        "-type", "d",
        "-maxdepth", str(max_depth)
    ]

    result = wrapper.safe_execute_command(command)

    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"Failed to get directory tree: {result['stderr']}"
            }
        }
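As a quick smoke test of the wrapper, one JSON-RPC request can be fed straight to handle_wrapped_request instead of going through stdio; the project directory is an assumed path, ripgrep (rg) must be installed, and note that this wrapper reads the tool arguments directly from params (see execute_ripgrep_search above):

```python
# Hypothetical smoke test: call the wrapper handler directly with one JSON-RPC request.
import asyncio
import json
import os

os.environ["PROJECT_DATA_DIR"] = "/srv/projects/demo-project/data"   # assumed location

from mcp_wrapper import handle_wrapped_request

request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {"pattern": "HP", "path": ".", "maxResults": 10},
}

response = asyncio.run(handle_wrapped_request(request, "ripgrep-wrapper"))
print(json.dumps(response, ensure_ascii=False, indent=2))
```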
61  mcp/ripgrep_wrapper_server.py  (new file)
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Ripgrep MCP wrapper server.
Provides safe text search, restricted to the project directory.
"""

import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request


async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_wrapped_request(request, "ripgrep-wrapper")

                # Write to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
154  project_config.py  (new file)
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Project configuration management.
Manages the mapping from project IDs to data directories, and project access control.
"""

import json
import os
from typing import Dict, Optional, List
from dataclasses import dataclass, asdict


@dataclass
class ProjectConfig:
    """Project configuration dataclass."""
    project_id: str
    data_dir: str
    name: str
    description: str = ""
    allowed_file_types: List[str] = None
    max_file_size_mb: int = 100
    is_active: bool = True

    def __post_init__(self):
        if self.allowed_file_types is None:
            self.allowed_file_types = [".json", ".txt", ".csv", ".pdf"]


class ProjectManager:
    """Project manager."""

    def __init__(self, config_file: str = "./projects/project_registry.json"):
        self.config_file = config_file
        self.projects: Dict[str, ProjectConfig] = {}
        self._ensure_config_dir()
        self._load_projects()

    def _ensure_config_dir(self):
        """Make sure the config directory exists."""
        config_dir = os.path.dirname(self.config_file)
        if not os.path.exists(config_dir):
            os.makedirs(config_dir, exist_ok=True)

    def _load_projects(self):
        """Load projects from the config file."""
        if os.path.exists(self.config_file):
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    for project_data in data.get('projects', []):
                        config = ProjectConfig(**project_data)
                        self.projects[config.project_id] = config
            except Exception as e:
                print(f"Failed to load project config: {e}")
                self._create_default_config()
        else:
            self._create_default_config()

    def _create_default_config(self):
        """Create the default configuration."""
        default_project = ProjectConfig(
            project_id="default",
            data_dir="./data",
            name="默认项目",
            description="默认数据项目"
        )
        self.projects["default"] = default_project
        self._save_projects()

    def _save_projects(self):
        """Save the project configuration to file."""
        data = {
            "projects": [asdict(project) for project in self.projects.values()]
        }
        try:
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"Failed to save project config: {e}")

    def get_project(self, project_id: str) -> Optional[ProjectConfig]:
        """Get a project's configuration."""
        return self.projects.get(project_id)

    def add_project(self, config: ProjectConfig) -> bool:
        """Add a project."""
        if config.project_id in self.projects:
            return False

        # Make sure the data directory exists
        if not os.path.isabs(config.data_dir):
            config.data_dir = os.path.abspath(config.data_dir)

        os.makedirs(config.data_dir, exist_ok=True)

        self.projects[config.project_id] = config
        self._save_projects()
        return True

    def update_project(self, project_id: str, **kwargs) -> bool:
        """Update a project's configuration."""
        if project_id not in self.projects:
            return False

        project = self.projects[project_id]
        for key, value in kwargs.items():
            if hasattr(project, key):
                setattr(project, key, value)

        self._save_projects()
        return True

    def delete_project(self, project_id: str) -> bool:
        """Delete a project."""
        if project_id not in self.projects:
            return False

        del self.projects[project_id]
        self._save_projects()
        return True

    def list_projects(self) -> List[ProjectConfig]:
        """List all projects."""
        return list(self.projects.values())

    def get_project_dir(self, project_id: str) -> str:
        """Get a project's data directory."""
        project = self.get_project(project_id)
        if project:
            return project.data_dir

        # If the project does not exist, create a default directory structure
        default_dir = f"./projects/{project_id}/data"
        os.makedirs(default_dir, exist_ok=True)

        # Auto-create a new project configuration
        new_project = ProjectConfig(
            project_id=project_id,
            data_dir=default_dir,
            name=f"项目 {project_id}",
            description=f"自动创建的项目 {project_id}"
        )
        self.add_project(new_project)

        return default_dir

    def validate_project_access(self, project_id: str) -> bool:
        """Validate project access."""
        project = self.get_project(project_id)
        return project and project.is_active


# Global project manager instance
project_manager = ProjectManager()
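A short sketch of how the registry is meant to be used from the API layer: register a project (or let get_project_dir auto-create one) and check access before serving a request; the id and paths here are placeholders:

```python
# Hypothetical sketch of the ProjectManager flow used by /chat/completions.
from project_config import ProjectConfig, project_manager

# Register a project explicitly (id and path are placeholders)...
project_manager.add_project(ProjectConfig(
    project_id="demo-project",
    data_dir="./projects/demo-project/data",
    name="demo",
))

# ...or let get_project_dir auto-create an entry on first use.
if project_manager.validate_project_access("demo-project"):
    print(project_manager.get_project_dir("demo-project"))
```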
18  projects/project_registry.json  (new file)
@@ -0,0 +1,18 @@
{
  "projects": [
    {
      "project_id": "demo-project",
      "data_dir": "/Users/moshui/Documents/felo/qwen-agent/projects/demo-project/",
      "name": "演示项目",
      "description": "演示多项目隔离功能",
      "allowed_file_types": [
        ".json",
        ".txt",
        ".csv",
        ".pdf"
      ],
      "max_file_size_mb": 100,
      "is_active": true
    }
  ]
}