openai api

This commit is contained in:
朱潮 2025-10-07 12:25:41 +08:00
parent afe7600534
commit 10c2ef0bbc
23 changed files with 1259 additions and 255 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

178
agent_pool.py Normal file
View File

@ -0,0 +1,178 @@
import asyncio
from typing import List, Optional
import logging
logger = logging.getLogger(__name__)


class AgentPool:
    """Fixed-size pool of pre-created assistant instances.

    Agents are created once by a factory, then borrowed with get_agent()
    and returned with release_agent().  A semaphore bounds how many agents
    may be checked out concurrently.
    """

    def __init__(self, pool_size: int = 5):
        """
        Args:
            pool_size: number of instances kept in the pool (default 5).
        """
        self.pool_size = pool_size
        self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        self.semaphore = asyncio.Semaphore(pool_size)
        self.agents = []  # every instance ever created, kept for stats/cleanup

    async def initialize(self, agent_factory):
        """Fill the pool by calling agent_factory() pool_size times.

        Args:
            agent_factory: zero-argument callable returning a new agent.

        Raises:
            Exception: whatever the factory raises is propagated; the pool
                is left partially filled in that case.
        """
        logger.info(f"正在初始化助手实例池,大小: {self.pool_size}")
        for i in range(self.pool_size):
            try:
                agent = agent_factory()
                await self.pool.put(agent)
                self.agents.append(agent)
                logger.info(f"助手实例 {i+1}/{self.pool_size} 创建成功")
            except Exception as e:
                logger.error(f"创建助手实例 {i+1} 失败: {e}")
                raise
        logger.info("助手实例池初始化完成")

    async def get_agent(self, timeout: Optional[float] = 30.0):
        """Borrow an idle agent, waiting up to *timeout* seconds.

        Args:
            timeout: seconds to wait for a free slot (default 30).

        Returns:
            An agent instance; it must be given back via release_agent().

        Raises:
            asyncio.TimeoutError: if no agent becomes free in time.
        """
        try:
            await asyncio.wait_for(self.semaphore.acquire(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.error(f"获取助手实例超时 ({timeout}秒)")
            raise
        try:
            agent = await asyncio.wait_for(self.pool.get(), timeout=timeout)
        except asyncio.TimeoutError:
            # BUG FIX: the semaphore permit was already acquired at this
            # point; return it so a timed-out call cannot permanently
            # shrink the pool's effective capacity.
            self.semaphore.release()
            logger.error(f"获取助手实例超时 ({timeout}秒)")
            raise
        logger.debug(f"成功获取助手实例,剩余池大小: {self.pool.qsize()}")
        return agent

    async def release_agent(self, agent):
        """Return a borrowed agent to the pool.

        The semaphore permit is released exactly once, even if putting the
        agent back into the queue fails.
        """
        try:
            await self.pool.put(agent)
            logger.debug(f"释放助手实例,当前池大小: {self.pool.qsize()}")
        except Exception as e:
            logger.error(f"释放助手实例失败: {e}")
        finally:
            # BUG FIX: the old code released the semaphore in both the try
            # and the except path, so one failure mode released it twice.
            self.semaphore.release()

    def get_pool_stats(self) -> dict:
        """Return pool utilisation counters.

        Returns:
            dict with pool_size, available_agents, total_agents and
            in_use_agents keys.
        """
        return {
            "pool_size": self.pool_size,
            "available_agents": self.pool.qsize(),
            "total_agents": len(self.agents),
            "in_use_agents": len(self.agents) - self.pool.qsize(),
        }

    async def shutdown(self):
        """Drain the queue and await each agent's cleanup() if it has one."""
        logger.info("正在关闭助手实例池...")
        while not self.pool.empty():
            try:
                agent = self.pool.get_nowait()
                if hasattr(agent, 'cleanup'):
                    await agent.cleanup()
            except asyncio.QueueEmpty:
                break
        logger.info("助手实例池已关闭")
# Process-wide singleton holding the shared agent pool.
_global_agent_pool: "Optional[AgentPool]" = None


def get_agent_pool() -> "Optional[AgentPool]":
    """Return the process-wide agent pool, or None if not yet configured."""
    return _global_agent_pool


def set_agent_pool(pool: "AgentPool"):
    """Install *pool* as the process-wide agent pool."""
    global _global_agent_pool
    _global_agent_pool = pool
async def init_global_agent_pool(pool_size: int = 5, agent_factory=None):
    """Create and populate the process-wide agent pool, once.

    Args:
        pool_size: number of agents to pre-create.
        agent_factory: zero-argument callable producing one agent instance.

    Raises:
        ValueError: if no agent_factory is supplied.
    """
    global _global_agent_pool
    if _global_agent_pool is not None:
        # Idempotent: a second call is a no-op, not an error.
        logger.warning("全局助手实例池已存在,跳过初始化")
        return
    if agent_factory is None:
        raise ValueError("必须提供 agent_factory 参数")
    _global_agent_pool = AgentPool(pool_size=pool_size)
    await _global_agent_pool.initialize(agent_factory)
    logger.info("全局助手实例池初始化完成")
async def get_agent_from_pool(timeout: Optional[float] = 30.0):
    """Borrow an agent from the global pool.

    Args:
        timeout: seconds to wait for a free agent.

    Returns:
        An agent instance; return it with release_agent_to_pool().

    Raises:
        RuntimeError: if the global pool has not been initialized.
    """
    pool = _global_agent_pool
    if pool is None:
        raise RuntimeError("全局助手实例池未初始化")
    return await pool.get_agent(timeout)
async def release_agent_to_pool(agent):
    """Return *agent* to the global pool.

    Args:
        agent: the instance previously obtained via get_agent_from_pool().

    Raises:
        RuntimeError: if the global pool has not been initialized.
    """
    pool = _global_agent_pool
    if pool is None:
        raise RuntimeError("全局助手实例池未初始化")
    await pool.release_agent(agent)

View File

@ -17,7 +17,7 @@
### 数据存储层次
```
./data/
[当前数据目录]/
├── [数据集文件夹]/
│ ├── schema.json # 倒排索引层
│ ├── serialization.txt # 序列化数据层
@ -28,7 +28,7 @@
#### 1. 索引层 (schema.json)
- **功能**:字段枚举值倒排索引,查询入口点
- **访问方式**`json-reader-get_all_keys({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema"})`
- **访问方式**`json-reader-get_all_keys({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema"})`
- **数据结构**
```json
{
@ -66,14 +66,14 @@
**执行步骤**
1. **加载索引**读取schema.json获取字段元数据
2. **字段分析**:识别数值字段、文本字段、枚举字段
3. **字段详情分析**:对于相关字段调用`json-reader-get_value({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})`查看具体的枚举值和取值范围
3. **字段详情分析**:对于相关字段调用`json-reader-get_value({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})`查看具体的枚举值和取值范围
4. **策略制定**:基于查询条件选择最优检索路径
5. **范围预估**:评估各条件的数据分布和选择度
### 阶段2精准数据匹配
**目标**:从序列化数据中提取符合条件的记录
**执行步骤**
1. **预检查**`ripgrep-count-matches({"path": "./data/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
1. **预检查**`ripgrep-count-matches({"path": "[当前数据目录]/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})`
2. **智能限流**
- 匹配数 > 1000增加过滤条件重新预检查
- 匹配数 100-1000`ripgrep-search({"maxResults": 30})`
@ -193,3 +193,5 @@ query_pattern = simple_field_match(conditions[0]) # 先匹配主要条件
---
**执行提醒**:始终使用完整的文件路径参数调用工具,确保数据访问的准确性和安全性。在查询执行过程中,动态调整策略以适应不同的数据特征和查询需求。
**重要说明**:所有文件路径中的 `[当前数据目录]` 将通过系统消息动态提供,请根据实际的数据目录路径进行操作。

Binary file not shown.

View File

@ -1,62 +1,247 @@
from typing import Optional
import json
import os
from typing import AsyncGenerator, Dict, List, Optional, Union
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from gbase_agent import init_agent_service
# 自定义版本不需要text参数不打印到终端
def get_content_from_messages(messages: List[dict]) -> str:
    """Render a qwen-agent message list as one tagged text transcript.

    Assistant messages contribute [THINK]/[ANSWER]/[TOOL_CALL] sections and
    function (tool) messages contribute [TOOL_RESPONSE] sections.  Custom
    version of the library helper: takes no ``text`` argument and prints
    nothing to the terminal.

    Args:
        messages: qwen-agent message dicts (roles: assistant / function).

    Returns:
        The concatenated transcript, or '' when there is nothing to show.

    Raises:
        TypeError: for a message whose role is neither assistant nor function.
        AssertionError: if a content field is not a plain string.
    """
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'
    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            # BUG FIX: was a bare `raise TypeError` with no context at all;
            # include the offending role so failures are diagnosable.
            raise TypeError(f"Unsupported message role: {msg['role']!r}")
    if content:
        full_text = '\n'.join(content)
    return full_text
from agent_pool import (get_agent_from_pool, init_global_agent_pool,
release_agent_to_pool)
from gbase_agent import init_agent_service_universal, update_agent_llm
from project_config import project_manager
app = FastAPI(title="Database Assistant API", version="1.0.0")
# Initialize agent globally at startup
bot = init_agent_service()
# 全局助手实例池,在应用启动时初始化
agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1"))
class QueryRequest(BaseModel):
question: str
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: List[Message]
model: str = "qwen3-next"
api_key: Optional[str] = None
extra: Optional[Dict] = None
stream: Optional[bool] = False
file_url: Optional[str] = None
class QueryResponse(BaseModel):
answer: str
class ChatResponse(BaseModel):
choices: List[Dict]
usage: Optional[Dict] = None
@app.post("/query", response_model=QueryResponse)
async def query_database(request: QueryRequest):
"""
Process a database query using the assistant agent.
class ChatStreamResponse(BaseModel):
choices: List[Dict]
usage: Optional[Dict] = None
Args:
request: QueryRequest containing the query and optional file URL
Returns:
QueryResponse containing the assistant's response
"""
async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
"""生成流式响应"""
accumulated_content = ""
accumulated_args = ""
chunk_id = 0
try:
messages = []
if request.file_url:
messages.append(
{
"role": "user",
"content": [{"text":"使用sqlite数据库用日语回答下面问题"+request.question}, {"file": request.file_url}],
for response in agent.run(messages=messages):
previous_content = accumulated_content
accumulated_content = get_content_from_messages(response)
# 计算新增的内容
if accumulated_content.startswith(previous_content):
new_content = accumulated_content[len(previous_content):]
else:
new_content = accumulated_content
previous_content = ""
# 只有当有新内容时才发送chunk
if new_content:
chunk_id += 1
# 构造OpenAI格式的流式响应
chunk_data = {
"id": f"chatcmpl-{chunk_id}",
"object": "chat.completion.chunk",
"created": int(__import__('time').time()),
"model": request.model,
"choices": [{
"index": 0,
"delta": {
"content": new_content
},
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
# 发送最终完成标记
final_chunk = {
"id": f"chatcmpl-{chunk_id + 1}",
"object": "chat.completion.chunk",
"created": int(__import__('time').time()),
"model": request.model,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
# 发送结束标记
yield "data: [DONE]\n\n"
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in generate_stream_response: {str(e)}")
print(f"Full traceback: {error_details}")
error_data = {
"error": {
"message": f"Stream error: {str(e)}",
"type": "internal_error"
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""
Chat completions API similar to OpenAI, supports both streaming and non-streaming
Args:
request: ChatRequest containing messages, model, project_id in extra field, etc.
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
"""
agent = None
try:
# 从extra字段中获取project_id
if not request.extra or 'project_id' not in request.extra:
raise HTTPException(status_code=400, detail="project_id is required in extra field")
project_id = request.extra['project_id']
# 验证项目访问权限
if not project_manager.validate_project_access(project_id):
raise HTTPException(status_code=404, detail=f"Project {project_id} not found or inactive")
# 获取项目数据目录
project_dir = project_manager.get_project_dir(project_id)
# 从实例池获取助手实例
agent = await get_agent_from_pool(timeout=30.0)
# 准备LLM配置从extra字段中移除project_id
llm_extra = request.extra.copy() if request.extra else {}
llm_extra.pop('project_id', None) # 移除project_id不传递给LLM
# 动态设置请求的模型支持从接口传入api_key和extra参数
update_agent_llm(agent, request.model, request.api_key, llm_extra)
# 构建包含项目信息的消息上下文
messages = [
# 项目信息系统消息
{
"role": "user",
"content": f"当前项目ID: {project_id},数据目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}"
},
# 用户消息批量转换
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
]
# 根据stream参数决定返回流式还是非流式响应
if request.stream:
return StreamingResponse(
generate_stream_response(agent, messages, request),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
else:
messages.append({"role": "user", "content": request.question})
responses = []
for response in bot.run(messages):
responses.append(response)
if responses:
final_response = responses[-1][-1]
return QueryResponse(answer=final_response["content"])
else:
raise HTTPException(status_code=500, detail="No response from agent")
# 非流式响应
final_responses = agent.run_nonstream(messages)
if final_responses and len(final_responses) > 0:
# 取最后一个响应
final_response = final_responses[-1]
# 如果返回的是Message对象需要转换为字典
if hasattr(final_response, 'model_dump'):
final_response = final_response.model_dump()
elif hasattr(final_response, 'dict'):
final_response = final_response.dict()
content = final_response.get("content", "")
# 构造OpenAI格式的响应
return ChatResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}],
usage={
"prompt_tokens": sum(len(msg.content) for msg in request.messages),
"completion_tokens": len(content),
"total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
}
)
else:
raise HTTPException(status_code=500, detail="No response from agent")
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in chat_completions: {str(e)}")
print(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
finally:
# 确保释放助手实例回池
if agent is not None:
await release_agent_to_pool(agent)
@app.get("/")
@ -65,6 +250,53 @@ async def root():
return {"message": "Database Assistant API is running"}
@app.get("/system/status")
async def system_status():
    """Report service health plus agent-pool utilisation counters."""
    from agent_pool import get_agent_pool

    pool = get_agent_pool()
    if pool:
        stats = pool.get_pool_stats()
    else:
        # Pool not initialized yet: report an empty pool rather than failing.
        stats = {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0}
    return {
        "status": "running",
        "storage_type": "Agent Pool API",
        "agent_pool": {
            "pool_size": stats["pool_size"],
            "available_agents": stats["available_agents"],
            "total_agents": stats["total_agents"],
            "in_use_agents": stats["in_use_agents"],
        },
    }
@app.on_event("startup")
async def startup_event():
    """Build the shared agent pool before the app serves its first request."""
    print(f"正在启动FastAPI应用初始化助手实例池大小: {agent_pool_size}...")
    try:
        await init_global_agent_pool(
            pool_size=agent_pool_size,
            agent_factory=lambda: init_agent_service_universal(),
        )
    except Exception as e:
        # Fail fast: a server without agents cannot answer anything.
        print(f"助手实例池初始化失败: {e}")
        raise
    print("助手实例池初始化完成!")
@app.on_event("shutdown")
async def shutdown_event():
    """Drain and clean up the shared agent pool on application exit."""
    print("正在关闭应用,清理助手实例池...")
    from agent_pool import get_agent_pool

    pool = get_agent_pool()
    if pool is not None:
        await pool.shutdown()
    print("助手实例池清理完成!")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -16,7 +16,6 @@
import argparse
import asyncio
import copy
import json
import os
from typing import Dict, List, Optional, Union
@ -30,118 +29,6 @@ from qwen_agent.utils.output_beautify import typewriter_print
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
class GPT4OChat(TextChatAtOAI):
"""自定义 GPT-4o 聊天类,修复 tool_call_id 问题"""
def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]:
# 使用父类方法进行基础转换
messages = super().convert_messages_to_dicts(messages)
# 应用修复后的消息转换
messages = self._fixed_conv_qwen_agent_messages_to_oai(messages)
return messages
@staticmethod
def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]):
"""修复后的消息转换方法,确保 tool 消息包含 tool_call_id 字段"""
new_messages = []
i = 0
while i < len(messages):
msg = messages[i]
if msg['role'] == ASSISTANT:
# 处理 assistant 消息
assistant_msg = {'role': 'assistant'}
# 设置 content
content = msg.get('content', '')
if isinstance(content, (list, dict)):
assistant_msg['content'] = json.dumps(content, ensure_ascii=False)
elif content is None:
assistant_msg['content'] = ''
else:
assistant_msg['content'] = content
# 设置 reasoning_content
if msg.get('reasoning_content'):
assistant_msg['reasoning_content'] = msg['reasoning_content']
# 检查是否需要构造 tool_calls
has_tool_call = False
tool_calls = []
# 情况1当前消息有 function_call
if msg.get('function_call'):
has_tool_call = True
tool_calls.append({
'id': msg.get('extra', {}).get('function_id', '1'),
'type': 'function',
'function': {
'name': msg['function_call']['name'],
'arguments': msg['function_call']['arguments']
}
})
# 注意:不再为孤立的 tool 消息构造虚假的 tool_call
if has_tool_call:
assistant_msg['tool_calls'] = tool_calls
new_messages.append(assistant_msg)
# 检查后续是否有对应的 tool 消息
if i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
tool_msg = copy.deepcopy(messages[i + 1])
# 确保 tool_call_id 匹配
tool_msg['tool_call_id'] = tool_calls[0]['id']
# 移除多余字段
for field in ['id', 'extra', 'function_call']:
if field in tool_msg:
del tool_msg[field]
# 确保 content 有效且为字符串
content = tool_msg.get('content', '')
if isinstance(content, (list, dict)):
tool_msg['content'] = json.dumps(content, ensure_ascii=False)
elif content is None:
tool_msg['content'] = ''
new_messages.append(tool_msg)
i += 2
else:
i += 1
else:
new_messages.append(assistant_msg)
i += 1
elif msg['role'] == 'tool':
# 孤立的 tool 消息,转换为 assistant + user 消息序列
# 首先添加一个包含工具结果的 assistant 消息
assistant_result = {'role': 'assistant'}
content = msg.get('content', '')
if isinstance(content, (list, dict)):
content = json.dumps(content, ensure_ascii=False)
assistant_result['content'] = f"工具查询结果: {content}"
new_messages.append(assistant_result)
# 然后添加一个 user 消息来继续对话
new_messages.append({'role': 'user', 'content': '请继续分析以上结果'})
i += 1
else:
# 处理其他角色消息
new_msg = copy.deepcopy(msg)
# 确保 content 有效且为字符串
content = new_msg.get('content', '')
if isinstance(content, (list, dict)):
new_msg['content'] = json.dumps(content, ensure_ascii=False)
elif content is None:
new_msg['content'] = ''
new_messages.append(new_msg)
i += 1
return new_messages
def read_mcp_settings():
@ -150,107 +37,70 @@ def read_mcp_settings():
return mcp_settings_json
def read_mcp_settings_with_project_restriction(project_data_dir: str):
    """Load the MCP settings and scope the json-reader server to one project.

    The json-reader server entry (if present) receives PROJECT_DATA_DIR and
    PROJECT_ID environment variables so it can refuse paths outside the
    project directory.

    Args:
        project_data_dir: absolute or relative data directory of the project.

    Returns:
        The (mutated) parsed settings list.
    """
    with open("./mcp/mcp_settings.json", "r") as f:
        settings = json.load(f)

    # Project id is taken from the second-to-last "/"-separated component
    # (i.e. the parent directory name); "default" when no separator exists.
    if "/" in project_data_dir:
        project_id = project_data_dir.split("/")[-2]
    else:
        project_id = "default"

    for server_config in settings:
        if "mcpServers" not in server_config:
            continue
        for server_name, server_info in server_config["mcpServers"].items():
            if server_name == "json-reader":
                env = server_info.setdefault("env", {})
                env["PROJECT_DATA_DIR"] = project_data_dir
                env["PROJECT_ID"] = project_id
                break
    return settings
def read_system_prompt():
with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
return f.read().strip()
def read_system_prompt():
    """Return the stateless system prompt, stripped of surrounding whitespace."""
    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
        prompt = f.read()
    return prompt.strip()
def init_agent_service():
llm_cfg = {
"llama-33": {
"model": "gbase-llama-33",
"model_server": "http://llmapi:9009/v1",
"api_key": "any",
},
"gpt-oss-120b": {
"model": "openai/gpt-oss-120b",
"model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base
"api_key": "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",
"generate_cfg": {
"use_raw_api": True, # GPT-OSS true ,Qwen false
"fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示
}
},
"""默认初始化函数,保持向后兼容"""
return init_agent_service_universal("qwen3-next")
"claude-3.7": {
"model": "claude-3-7-sonnet-20250219",
"model_server": "https://one.felo.me/v1",
"api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
"generate_cfg": {
"use_raw_api": True, # GPT-OSS true ,Qwen false
},
},
"gpt-4o": {
"model": "gpt-4o",
"model_server": "https://one-dev.felo.me/v1",
"api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
"generate_cfg": {
"use_raw_api": True, # 启用 raw_api 但使用自定义类修复 tool_call_id 问题
"fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示
},
},
"Gpt-4o-back": {
"model_type": "oai", # 使用 oai 类型以便使用自定义类
"model": "gpt-4o",
"model_server": "https://one-dev.felo.me/v1",
"api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
"generate_cfg": {
"use_raw_api": True, # 启用 raw_api 但使用自定义类修复 tool_call_id 问题
"fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示
},
# 使用自定义的 GPT4OChat 类
"llm_class": GPT4OChat,
},
"glm-45": {
"model_server": "https://open.bigmodel.cn/api/paas/v4",
"api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
"model": "glm-4.5",
"generate_cfg": {
"use_raw_api": True, # GPT-OSS true ,Qwen false
},
},
"qwen3-next": {
"model": "qwen/qwen3-next-80b-a3b-instruct",
"model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base
"api_key": "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",
def read_llm_config():
    """Load the model configurations from ./llm_config.json."""
    with open("./llm_config.json", "r") as f:
        raw = f.read()
    return json.loads(raw)
},
"deepresearch": {
"model": "alibaba/tongyi-deepresearch-30b-a3b",
"model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base
"api_key": "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",
},
"qwen3-coder":{
"model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
"model_server": "https://api-inference.modelscope.cn/v1", # base_url, also known as api_base
"api_key": "ms-92027446-2787-4fd6-af01-f002459ec556",
},
"openrouter-gpt4o":{
"model": "openai/gpt-4o",
"model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base
"api_key": "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",
"generate_cfg": {
"use_raw_api": True, # GPT-OSS true ,Qwen false
"fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示
},
}
}
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
"""支持项目目录的agent初始化函数 - 保持向后兼容"""
llm_cfg = read_llm_config()
# 读取通用的系统prompt无状态
system = read_system_prompt()
# 暂时禁用 MCP 工具以测试 GPT-4o
tools = read_mcp_settings()
# 使用自定义的 GPT-4o 配置
llm_instance = llm_cfg["qwen3-next"]
# 读取MCP工具配置
tools = read_mcp_settings_with_project_restriction(project_data_dir)
# 使用指定的模型配置
if model_name not in llm_cfg:
raise ValueError(f"Model '{model_name}' not found in llm_config.json. Available models: {list(llm_cfg.keys())}")
llm_instance = llm_cfg[model_name]
if "llm_class" in llm_instance:
llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)
bot = Assistant(
llm=llm_instance, # 使用自定义的 GPT-4o 实例
name="数据库助手",
description="数据库查询",
llm=llm_instance, # 使用指定的模型实例
name=f"数据库助手-{project_id}",
description=f"项目 {project_id} 数据库查询",
system_message=system,
function_list=tools,
)
@ -258,6 +108,62 @@ def init_agent_service():
return bot
def init_agent_service_universal(default_model: str = "qwen3-next"):
    """Create a stateless, project-agnostic assistant instance.

    Args:
        default_model: preferred key in llm_config.json.  Falls back to the
            first configured model when the key is absent.
            (BUG FIX: the previous signature accepted no argument, yet
            init_agent_service() calls this with a model name, which raised
            TypeError at runtime.)

    Returns:
        A configured qwen-agent Assistant whose LLM can later be swapped
        per request via update_agent_llm().
    """
    llm_cfg = read_llm_config()
    # Generic, stateless system prompt.
    system = read_system_prompt()
    # Base MCP tool config, without per-project path restrictions.
    tools = read_mcp_settings()
    if default_model not in llm_cfg:
        # Fall back to the first available model.
        default_model = list(llm_cfg.keys())[0]
    llm_instance = llm_cfg[default_model]
    if "llm_class" in llm_instance:
        # A config entry may name a custom chat class to wrap itself with.
        llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)
    bot = Assistant(
        llm=llm_instance,
        name="通用数据检索助手",
        description="无状态通用数据检索助手",
        system_message=system,
        function_list=tools,
    )
    return bot
def update_agent_llm(agent, model_name: str, api_key: str = None, extra: Dict = None):
    """Swap the LLM on an existing agent using per-request parameters.

    Args:
        agent: assistant instance whose ``.llm`` is replaced in place.
        model_name: model identifier passed to the backend.
        api_key: optional API key for the backend.
        extra: optional extra config keys merged over the base config
            (may include a custom ``llm_class``).

    Returns:
        The same agent, for call chaining.
    """
    cfg = {"model": model_name, "api_key": api_key}
    if extra is not None:
        cfg.update(extra)
    # Pick the chat class: a caller-supplied llm_class wins, otherwise the
    # default OpenAI-compatible implementation is used.
    llm_class = cfg.get("llm_class", TextChatAtOAI)
    agent.llm = llm_class(cfg)
    return agent
def test(query="数据库里有几张表"):
# Define the agent
bot = init_agent_service()

65
llm_config.json Normal file
View File

@ -0,0 +1,65 @@
{
"llama-33": {
"model": "gbase-llama-33",
"model_server": "http://llmapi:9009/v1",
"api_key": "any"
},
"gpt-oss-120b": {
"model": "openai/gpt-oss-120b",
"model_server": "https://openrouter.ai/api/v1",
"api_key": "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",
"generate_cfg": {
"use_raw_api": true,
"fncall_prompt_type": "nous"
}
},
"claude-3.7": {
"model": "claude-3-7-sonnet-20250219",
"model_server": "https://one.felo.me/v1",
"api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
"generate_cfg": {
"use_raw_api": true
}
},
"gpt-4o": {
"model": "gpt-4o",
"model_server": "https://one-dev.felo.me/v1",
"api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
"generate_cfg": {
"use_raw_api": true,
"fncall_prompt_type": "nous"
}
},
"glm-45": {
"model_server": "https://open.bigmodel.cn/api/paas/v4",
"api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
"model": "glm-4.5",
"generate_cfg": {
"use_raw_api": true
}
},
"qwen3-next": {
"model": "qwen/qwen3-next-80b-a3b-instruct",
"model_server": "https://openrouter.ai/api/v1",
"api_key": "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"
},
"deepresearch": {
"model": "alibaba/tongyi-deepresearch-30b-a3b",
"model_server": "https://openrouter.ai/api/v1",
"api_key": "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"
},
"qwen3-coder": {
"model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
"model_server": "https://api-inference.modelscope.cn/v1",
"api_key": "ms-92027446-2787-4fd6-af01-f002459ec556"
},
"openrouter-gpt4o": {
"model": "openai/gpt-4o",
"model_server": "https://openrouter.ai/api/v1",
"api_key": "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",
"generate_cfg": {
"use_raw_api": true,
"fncall_prompt_type": "nous"
}
}
}

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
目录树 MCP包装器服务器
提供安全的目录结构查看功能限制在项目目录内
"""
import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request
async def main():
    """Serve JSON-RPC requests for the directory-tree wrapper over stdio."""
    try:
        while True:
            # readline blocks, so delegate it to the default thread executor.
            raw = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not raw:
                break  # EOF: parent closed our stdin
            raw = raw.strip()
            if not raw:
                continue  # ignore blank lines
            try:
                request = json.loads(raw)
                response = await handle_wrapped_request(request, "directory-tree-wrapper")
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                reply = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(reply) + "\n")
                sys.stdout.flush()
            except Exception as e:
                reply = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(reply) + "\n")
                sys.stdout.flush()
    except KeyboardInterrupt:
        pass

View File

@ -12,6 +12,32 @@ import sys
import asyncio
from typing import Any, Dict, List
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Resolve *file_path* and ensure it lies inside *allowed_dir*.

    Args:
        file_path: path to check (relative paths resolve against the cwd).
        allowed_dir: directory that must contain the file.

    Returns:
        The absolute file path.

    Raises:
        ValueError: if the path escapes the allowed directory.
    """
    file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # BUG FIX: the old plain startswith() prefix test accepted sibling
    # directories such as "<allowed_dir>x/...".  Compare whole path
    # components instead.  abspath() has already collapsed any ".."
    # segments, so the previous separate ".." scan was dead code (and it
    # also wrongly rejected legitimate filenames containing "..").
    if os.path.commonpath([file_path, allowed_dir]) != allowed_dir:
        raise ValueError(f"访问被拒绝: 路径 {file_path} 不在允许的目录 {allowed_dir}")
    return file_path
def get_allowed_directory():
    """Absolute path of the directory this server may read (./projects by default)."""
    # The wrapping process injects PROJECT_DATA_DIR to scope file access.
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects"))
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
"""Handle MCP request"""
try:
@ -130,9 +156,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
}
try:
# Convert relative path to absolute path
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
# 验证文件路径是否在允许的目录内
allowed_dir = get_allowed_directory()
file_path = validate_file_path(file_path, allowed_dir)
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
@ -222,9 +248,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
}
try:
# Convert relative path to absolute path
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
# 验证文件路径是否在允许的目录内
allowed_dir = get_allowed_directory()
file_path = validate_file_path(file_path, allowed_dir)
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
@ -307,9 +333,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
}
try:
# Convert relative path to absolute path
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
# 验证文件路径是否在允许的目录内
allowed_dir = get_allowed_directory()
file_path = validate_file_path(file_path, allowed_dir)
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)

View File

@ -8,13 +8,6 @@
"@andredezzy/deep-directory-tree-mcp"
]
},
"mcp-server-code-runner": {
"command": "npx",
"args": [
"-y",
"mcp-server-code-runner@latest"
]
},
"ripgrep": {
"command": "npx",
"args": [

308
mcp/mcp_wrapper.py Normal file
View File

@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
通用MCP工具包装器
为所有MCP工具提供目录访问控制和安全限制
"""
import os
import sys
import json
import asyncio
import subprocess
from typing import Dict, Any, Optional
from pathlib import Path
class MCPSecurityWrapper:
    """Confines MCP tool execution to a single allowed directory."""

    def __init__(self, allowed_directory: str):
        """
        Args:
            allowed_directory: root directory tools may touch.
        """
        self.allowed_directory = os.path.abspath(allowed_directory)
        # Project id is injected by the parent process; "default" otherwise.
        self.project_id = os.getenv("PROJECT_ID", "default")

    def validate_path(self, path: str) -> str:
        """Return *path* as an absolute path verified to be inside the sandbox.

        Raises:
            ValueError: if the resolved path is outside allowed_directory.
        """
        path = os.path.normpath(os.path.abspath(path))
        # BUG FIX: the old startswith() prefix check accepted sibling paths
        # such as "<allowed>x/..."; compare whole path components instead.
        # abspath/normpath already collapse ".." segments, so the previous
        # separate ".." scan was dead code.
        if os.path.commonpath([path, self.allowed_directory]) != self.allowed_directory:
            raise ValueError(f"访问被拒绝: {path} 不在允许的目录 {self.allowed_directory}")
        return path

    def safe_execute_command(self, command: list, cwd: Optional[str] = None) -> Dict[str, Any]:
        """Run *command* with its cwd pinned inside the sandbox; never raises.

        Args:
            command: argv list passed to subprocess (shell=False).
            cwd: working directory; defaults to the sandbox root and is
                validated when supplied.

        Returns:
            dict with success / stdout / stderr / returncode.  A timeout or
            any launch failure is reported as success=False, returncode=-1.
        """
        if cwd is None:
            cwd = self.allowed_directory
        else:
            cwd = self.validate_path(cwd)
        # Propagate sandbox info to the child via environment variables.
        env = os.environ.copy()
        env["PWD"] = cwd
        env["PROJECT_DATA_DIR"] = self.allowed_directory
        env["PROJECT_ID"] = self.project_id
        try:
            result = subprocess.run(
                command,
                cwd=cwd,
                env=env,
                capture_output=True,
                text=True,
                timeout=30,  # hard cap on tool runtime
            )
            return {
                "success": result.returncode == 0,
                "stdout": result.stdout,
                "stderr": result.stderr,
                "returncode": result.returncode,
            }
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "stdout": "",
                "stderr": "命令执行超时",
                "returncode": -1,
            }
        except Exception as e:
            return {
                "success": False,
                "stdout": "",
                "stderr": str(e),
                "returncode": -1,
            }
async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
    """Dispatch one JSON-RPC request for a wrapped MCP tool.

    Supported methods: initialize, ping, tools/list, tools/call.  Any
    unexpected failure is reported as a JSON-RPC internal error (-32603)
    rather than raised.
    """
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        # Sandbox is built for every request, regardless of the method.
        wrapper = MCPSecurityWrapper(get_allowed_directory())

        if method == "initialize":
            result = {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {
                    "name": f"{tool_name}-wrapper",
                    "version": "1.0.0",
                },
            }
        elif method == "ping":
            result = {"pong": True}
        elif method == "tools/list":
            result = {"tools": get_tool_definitions(tool_name)}
        elif method == "tools/call":
            return await execute_tool_call(wrapper, tool_name, params, request_id)
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}",
                },
            }
        return {"jsonrpc": "2.0", "id": request_id, "result": result}
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}",
            },
        }
def get_allowed_directory():
    """Absolute sandbox root, taken from PROJECT_DATA_DIR (default ./data)."""
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./data"))
def get_tool_definitions(tool_name: str) -> list:
    """Return the MCP tool schema list for a wrapper server (empty if unknown)."""
    ripgrep_tools = [
        {
            "name": "ripgrep_search",
            "description": "在项目目录内搜索文本",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "pattern": {"type": "string", "description": "搜索模式"},
                    "path": {"type": "string", "description": "搜索路径(相对于项目目录)"},
                    "maxResults": {"type": "integer", "default": 100},
                },
                "required": ["pattern"],
            },
        }
    ]
    tree_tools = [
        {
            "name": "get_directory_tree",
            "description": "获取项目目录结构",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "目录路径(相对于项目目录)"},
                    "max_depth": {"type": "integer", "default": 3},
                },
            },
        }
    ]
    catalog = {
        "ripgrep-wrapper": ripgrep_tools,
        "directory-tree-wrapper": tree_tools,
    }
    return catalog.get(tool_name, [])
async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Route a tools/call request to the matching tool implementation.

    Any exception from a tool is converted into a JSON-RPC internal error.
    """
    try:
        if tool_name == "ripgrep-wrapper":
            return await execute_ripgrep_search(wrapper, params, request_id)
        if tool_name == "directory-tree-wrapper":
            return await execute_directory_tree(wrapper, params, request_id)
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32601,
                "message": f"Unknown tool: {tool_name}",
            },
        }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": str(e),
            },
        }
async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Run a ripgrep search restricted to the wrapper's allowed directory.

    Params: ``pattern`` (required regex), ``path`` (relative search root,
    default "."), ``maxResults`` (per-file match cap, default 100).
    Returns a JSON-RPC 2.0 result with rg's --json output as text, or a
    -32603 error carrying rg's stderr on failure.
    """
    pattern = params.get("pattern", "")
    path = params.get("path", ".")
    max_results = params.get("maxResults", 100)
    # Resolve the search root and verify it stays inside the sandbox.
    search_path = os.path.join(wrapper.allowed_directory, path)
    search_path = wrapper.validate_path(search_path)
    # Pass the pattern via -e so patterns beginning with "-" cannot be
    # misparsed as command-line flags.
    # NOTE(review): --max-count caps matches *per file*, not the overall
    # total implied by "maxResults" — confirm this is the intent.
    command = [
        "rg",
        "--json",
        "--max-count", str(max_results),
        "-e", pattern,
        search_path
    ]
    result = wrapper.safe_execute_command(command)
    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"搜索失败: {result['stderr']}"
            }
        }
async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """List directories under a path inside the wrapper's allowed directory.

    Params: ``path`` (relative root, default ".") and ``max_depth``
    (default 3). Returns a JSON-RPC 2.0 result with ``find``'s output as
    text, or a -32603 error carrying stderr on failure.
    """
    path = params.get("path", ".")
    max_depth = params.get("max_depth", 3)
    # Resolve the root and verify it stays inside the sandbox.
    dir_path = os.path.join(wrapper.allowed_directory, path)
    dir_path = wrapper.validate_path(dir_path)
    # -maxdepth is a global option and must precede tests like -type;
    # placing it after -type makes GNU find print a warning to stderr.
    command = [
        "find",
        dir_path,
        "-maxdepth", str(max_depth),
        "-type", "d"
    ]
    result = wrapper.safe_execute_command(command)
    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"获取目录树失败: {result['stderr']}"
            }
        }

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Ripgrep MCP包装器服务器
提供安全的文本搜索功能限制在项目目录内
"""
import asyncio
import json
import sys
from mcp_wrapper import handle_wrapped_request
async def main():
    """Serve JSON-RPC requests for the ripgrep wrapper over stdio.

    Reads one JSON request per line from stdin and writes one JSON
    response per line to stdout until EOF, an unexpected error, or
    Ctrl-C.
    """
    try:
        # get_running_loop() replaces get_event_loop(), which is
        # deprecated inside a coroutine since Python 3.10.
        loop = asyncio.get_running_loop()
        while True:
            # readline blocks, so run it in the default executor to keep
            # the event loop responsive.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF: the client closed stdin.
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
                response = await handle_wrapped_request(request, "ripgrep-wrapper")
                # 写入stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                # JSON-RPC 2.0 requires "id": null when the request id
                # could not be determined (parse error).
                error_response = {
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
    except Exception as e:
        # Unexpected failure: report it once (id unknown -> null per the
        # JSON-RPC 2.0 spec) and stop serving.
        error_response = {
            "jsonrpc": "2.0",
            "id": None,
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }
        sys.stdout.write(json.dumps(error_response) + "\n")
        sys.stdout.flush()
    except KeyboardInterrupt:
        pass
if __name__ == "__main__":
    # Run the stdio server loop when executed as a script.
    asyncio.run(main())

154
project_config.py Normal file
View File

@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
项目配置管理系统
负责管理项目ID到数据目录的映射以及项目访问权限控制
"""
import json
import os
from typing import Dict, Optional, List
from dataclasses import dataclass, asdict
@dataclass
class ProjectConfig:
    """Per-project settings: data directory, display name, access limits."""
    project_id: str  # unique identifier used as the registry key
    data_dir: str  # directory holding this project's files
    name: str  # human-readable project name
    description: str = ""
    # None means "use the default list"; the original annotation
    # (List[str] = None) was mistyped — a None default must be Optional.
    allowed_file_types: Optional[List[str]] = None
    max_file_size_mb: int = 100
    is_active: bool = True  # inactive projects fail access checks
    def __post_init__(self):
        # Substitute a fresh list per instance; a class-level mutable
        # default would be shared by all instances.
        if self.allowed_file_types is None:
            self.allowed_file_types = [".json", ".txt", ".csv", ".pdf"]
class ProjectManager:
    """Registry of projects, persisted as JSON at *config_file*.

    A default project is registered when the registry file is missing
    or unreadable.
    """
    def __init__(self, config_file: str = "./projects/project_registry.json"):
        self.config_file = config_file
        self.projects: Dict[str, ProjectConfig] = {}
        self._ensure_config_dir()
        self._load_projects()
    def _ensure_config_dir(self):
        """Create the directory that holds the registry file, if any."""
        config_dir = os.path.dirname(self.config_file)
        # dirname() returns "" for a bare filename, and makedirs("")
        # raises — so only create when there is a directory component.
        # exist_ok=True also avoids the racey exists()-then-create check.
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
    def _load_projects(self):
        """Load project configs from the registry file, or create defaults."""
        if os.path.exists(self.config_file):
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                for project_data in data.get('projects', []):
                    config = ProjectConfig(**project_data)
                    self.projects[config.project_id] = config
            except Exception as e:
                # Corrupt or incompatible registry: report and fall back
                # to a fresh default configuration.
                print(f"加载项目配置失败: {e}")
                self._create_default_config()
        else:
            self._create_default_config()
    def _create_default_config(self):
        """Register and persist the built-in default project."""
        default_project = ProjectConfig(
            project_id="default",
            data_dir="./data",
            name="默认项目",
            description="默认数据项目"
        )
        self.projects["default"] = default_project
        self._save_projects()
    def _save_projects(self):
        """Persist all project configs to the registry file (best-effort)."""
        data = {
            "projects": [asdict(project) for project in self.projects.values()]
        }
        try:
            with open(self.config_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存项目配置失败: {e}")
    def get_project(self, project_id: str) -> Optional[ProjectConfig]:
        """Return the config for *project_id*, or None if unknown."""
        return self.projects.get(project_id)
    def add_project(self, config: ProjectConfig) -> bool:
        """Register a new project; returns False if the id already exists.

        A relative data_dir is normalized to an absolute path (mutating
        *config* in place) and the directory is created if missing.
        """
        if config.project_id in self.projects:
            return False
        if not os.path.isabs(config.data_dir):
            config.data_dir = os.path.abspath(config.data_dir)
        os.makedirs(config.data_dir, exist_ok=True)
        self.projects[config.project_id] = config
        self._save_projects()
        return True
    def update_project(self, project_id: str, **kwargs) -> bool:
        """Update existing attributes of a project; False if id unknown.

        Keyword names the config does not define are silently ignored.
        """
        if project_id not in self.projects:
            return False
        project = self.projects[project_id]
        for key, value in kwargs.items():
            if hasattr(project, key):
                setattr(project, key, value)
        self._save_projects()
        return True
    def delete_project(self, project_id: str) -> bool:
        """Remove a project from the registry; False if id unknown."""
        if project_id not in self.projects:
            return False
        del self.projects[project_id]
        self._save_projects()
        return True
    def list_projects(self) -> List[ProjectConfig]:
        """Return all registered project configs."""
        return list(self.projects.values())
    def get_project_dir(self, project_id: str) -> str:
        """Return the data directory for *project_id*.

        Unknown ids are auto-registered with a directory created under
        ./projects/<id>/data.
        """
        project = self.get_project(project_id)
        if project:
            return project.data_dir
        # Project missing: create its directory and register it.
        default_dir = f"./projects/{project_id}/data"
        os.makedirs(default_dir, exist_ok=True)
        new_project = ProjectConfig(
            project_id=project_id,
            data_dir=default_dir,
            name=f"项目 {project_id}",
            description=f"自动创建的项目 {project_id}"
        )
        self.add_project(new_project)
        return default_dir
    def validate_project_access(self, project_id: str) -> bool:
        """Return True when the project exists and is active."""
        project = self.get_project(project_id)
        # Coerce to a real bool: "project and project.is_active" returned
        # None (not False) for unknown ids despite the -> bool annotation.
        return project is not None and project.is_active
# Module-level project-manager singleton shared across the application.
project_manager = ProjectManager()

View File

@ -0,0 +1,18 @@
{
"projects": [
{
"project_id": "demo-project",
"data_dir": "/Users/moshui/Documents/felo/qwen-agent/projects/demo-project/",
"name": "演示项目",
"description": "演示多项目隔离功能",
"allowed_file_types": [
".json",
".txt",
".csv",
".pdf"
],
"max_file_size_mb": 100,
"is_active": true
}
]
}