add extra_prompt
This commit is contained in:
parent
cc88d52b14
commit
7b538d4967
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
projects/*
|
projects/*
|
||||||
|
workspace
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -186,12 +186,8 @@ query_pattern = simple_field_match(conditions[0]) # 先匹配主要条件
|
|||||||
- **渐进式扩展**:逐步放宽查询条件以发现更多相关数据
|
- **渐进式扩展**:逐步放宽查询条件以发现更多相关数据
|
||||||
- **交叉验证**:使用多种方法验证搜索结果的完整性
|
- **交叉验证**:使用多种方法验证搜索结果的完整性
|
||||||
|
|
||||||
## 重要说明
|
|
||||||
1. 查询的设备类型为第一优先级,比如笔记本和台式机。
|
|
||||||
2. 针对"CPU处理器"和"GPU显卡"的查询,因为命名方式多样性,查询优先级最低。
|
|
||||||
3. 如果确实无法找到完全匹配的数据,根据用户要求,可接受性能更高(更低)的CPU处理器和GPU显卡是作为代替。
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**执行提醒**:始终使用完整的文件路径参数调用工具,确保数据访问的准确性和安全性。在查询执行过程中,动态调整策略以适应不同的数据特征和查询需求。
|
**重要说明**:所有文件路径中的 `[当前数据目录]` 将通过系统消息动态提供,请根据实际的数据目录路径进行操作。始终使用完整的文件路径参数调用工具,确保数据访问的准确性和安全性。在查询执行过程中,动态调整策略以适应不同的数据特征和查询需求。
|
||||||
|
|
||||||
|
|
||||||
**重要说明**:所有文件路径中的 `[当前数据目录]` 将通过系统消息动态提供,请根据实际的数据目录路径进行操作。
|
|
||||||
|
|||||||
@ -62,6 +62,7 @@ class ChatRequest(BaseModel):
|
|||||||
extra: Optional[Dict] = None
|
extra: Optional[Dict] = None
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
file_url: Optional[str] = None
|
file_url: Optional[str] = None
|
||||||
|
extra_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
@ -174,12 +175,13 @@ async def chat_completions(request: ChatRequest):
|
|||||||
# 动态设置请求的模型,支持从接口传入api_key、model_server和extra参数
|
# 动态设置请求的模型,支持从接口传入api_key、model_server和extra参数
|
||||||
update_agent_llm(agent, request.model, request.api_key, request.model_server)
|
update_agent_llm(agent, request.model, request.api_key, request.model_server)
|
||||||
|
|
||||||
|
extra_prompt = request.extra_prompt if request.extra_prompt else ""
|
||||||
# 构建包含项目信息的消息上下文
|
# 构建包含项目信息的消息上下文
|
||||||
messages = [
|
messages = [
|
||||||
# 项目信息系统消息
|
# 项目信息系统消息
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}"
|
"content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}\n"+ extra_prompt
|
||||||
},
|
},
|
||||||
# 用户消息批量转换
|
# 用户消息批量转换
|
||||||
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
|
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||||
|
|||||||
@ -57,11 +57,6 @@ def read_mcp_settings_with_project_restriction(project_data_dir: str):
|
|||||||
return mcp_settings_json
|
return mcp_settings_json
|
||||||
|
|
||||||
|
|
||||||
def read_system_prompt():
|
|
||||||
with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
|
|
||||||
return f.read().strip()
|
|
||||||
|
|
||||||
|
|
||||||
def read_system_prompt():
|
def read_system_prompt():
|
||||||
"""读取通用的无状态系统prompt"""
|
"""读取通用的无状态系统prompt"""
|
||||||
with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
|
with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
|
||||||
@ -73,32 +68,28 @@ def init_agent_service():
|
|||||||
return init_agent_service_universal("qwen3-next")
|
return init_agent_service_universal("qwen3-next")
|
||||||
|
|
||||||
|
|
||||||
def read_llm_config():
|
|
||||||
"""读取LLM配置文件"""
|
|
||||||
with open("./llm_config.json", "r") as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
|
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
|
||||||
"""支持项目目录的agent初始化函数 - 保持向后兼容"""
|
"""支持项目目录的agent初始化函数 - 保持向后兼容"""
|
||||||
llm_cfg = read_llm_config()
|
|
||||||
|
|
||||||
# 读取通用的系统prompt(无状态)
|
# 读取通用的系统prompt(无状态)
|
||||||
system = read_system_prompt()
|
system = read_system_prompt()
|
||||||
|
|
||||||
# 读取MCP工具配置
|
# 读取MCP工具配置
|
||||||
tools = read_mcp_settings_with_project_restriction(project_data_dir)
|
tools = read_mcp_settings_with_project_restriction(project_data_dir)
|
||||||
|
|
||||||
# 使用指定的模型配置
|
# 创建默认的LLM配置(可以通过update_agent_llm动态更新)
|
||||||
if model_name not in llm_cfg:
|
llm_config = {
|
||||||
raise ValueError(f"Model '{model_name}' not found in llm_config.json. Available models: {list(llm_cfg.keys())}")
|
"model": model_name,
|
||||||
|
"model_server": "https://openrouter.ai/api/v1", # 默认服务器
|
||||||
|
"api_key": "default-key" # 默认密钥,实际使用时需要通过API传入
|
||||||
|
}
|
||||||
|
|
||||||
llm_instance = llm_cfg[model_name]
|
# 创建LLM实例
|
||||||
if "llm_class" in llm_instance:
|
llm_instance = TextChatAtOAI(llm_config)
|
||||||
llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)
|
|
||||||
|
|
||||||
bot = Assistant(
|
bot = Assistant(
|
||||||
llm=llm_instance, # 使用指定的模型实例
|
llm=llm_instance, # 使用默认LLM初始化,可通过update_agent_llm动态更新
|
||||||
name=f"数据库助手-{project_id}",
|
name=f"数据库助手-{project_id}",
|
||||||
description=f"项目 {project_id} 数据库查询",
|
description=f"项目 {project_id} 数据库查询",
|
||||||
system_message=system,
|
system_message=system,
|
||||||
@ -110,26 +101,24 @@ def init_agent_service_with_project(project_id: str, project_data_dir: str, mode
|
|||||||
|
|
||||||
def init_agent_service_universal():
|
def init_agent_service_universal():
|
||||||
"""创建无状态的通用助手实例(使用默认LLM,可动态切换)"""
|
"""创建无状态的通用助手实例(使用默认LLM,可动态切换)"""
|
||||||
llm_cfg = read_llm_config()
|
|
||||||
|
|
||||||
# 读取通用的系统prompt(无状态)
|
# 读取通用的系统prompt(无状态)
|
||||||
system = read_system_prompt()
|
system = read_system_prompt()
|
||||||
|
|
||||||
# 读取基础的MCP工具配置(不包含项目限制)
|
# 读取基础的MCP工具配置(不包含项目限制)
|
||||||
tools = read_mcp_settings()
|
tools = read_mcp_settings()
|
||||||
|
|
||||||
# 使用默认模型创建助手实例
|
# 创建默认的LLM配置(可以通过update_agent_llm动态更新)
|
||||||
default_model = "qwen3-next" # 默认模型
|
llm_config = {
|
||||||
if default_model not in llm_cfg:
|
"model": "qwen3-next", # 默认模型
|
||||||
# 如果默认模型不存在,使用第一个可用模型
|
"model_server": "https://openrouter.ai/api/v1", # 默认服务器
|
||||||
default_model = list(llm_cfg.keys())[0]
|
"api_key": "default-key" # 默认密钥,实际使用时需要通过API传入
|
||||||
|
}
|
||||||
|
|
||||||
llm_instance = llm_cfg[default_model]
|
# 创建LLM实例
|
||||||
if "llm_class" in llm_instance:
|
llm_instance = TextChatAtOAI(llm_config)
|
||||||
llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance)
|
|
||||||
|
|
||||||
bot = Assistant(
|
bot = Assistant(
|
||||||
llm=llm_instance, # 使用默认LLM初始化
|
llm=llm_instance, # 使用默认LLM初始化,可通过update_agent_llm动态更新
|
||||||
name="通用数据检索助手",
|
name="通用数据检索助手",
|
||||||
description="无状态通用数据检索助手",
|
description="无状态通用数据检索助手",
|
||||||
system_message=system,
|
system_message=system,
|
||||||
@ -163,8 +152,8 @@ def update_agent_llm(agent, model_name: str, api_key: str = None, model_server:
|
|||||||
|
|
||||||
|
|
||||||
def test(query="数据库里有几张表"):
|
def test(query="数据库里有几张表"):
|
||||||
# Define the agent
|
# Define the agent - 使用通用初始化
|
||||||
bot = init_agent_service()
|
bot = init_agent_service_universal()
|
||||||
|
|
||||||
# Chat
|
# Chat
|
||||||
messages = []
|
messages = []
|
||||||
@ -182,8 +171,8 @@ def test(query="数据库里有几张表"):
|
|||||||
|
|
||||||
|
|
||||||
def app_tui():
|
def app_tui():
|
||||||
# Define the agent
|
# Define the agent - 使用通用初始化
|
||||||
bot = init_agent_service()
|
bot = init_agent_service_universal()
|
||||||
|
|
||||||
# Chat
|
# Chat
|
||||||
messages = []
|
messages = []
|
||||||
@ -209,8 +198,8 @@ def app_tui():
|
|||||||
|
|
||||||
|
|
||||||
def app_gui():
|
def app_gui():
|
||||||
# Define the agent
|
# Define the agent - 使用通用初始化
|
||||||
bot = init_agent_service()
|
bot = init_agent_service_universal()
|
||||||
chatbot_config = {
|
chatbot_config = {
|
||||||
"prompt.suggestions": [
|
"prompt.suggestions": [
|
||||||
"数据库里有几张表",
|
"数据库里有几张表",
|
||||||
|
|||||||
@ -1,65 +0,0 @@
|
|||||||
{
|
|
||||||
"llama-33": {
|
|
||||||
"model": "gbase-llama-33",
|
|
||||||
"model_server": "http://llmapi:9009/v1",
|
|
||||||
"api_key": "any"
|
|
||||||
},
|
|
||||||
"gpt-oss-120b": {
|
|
||||||
"model": "openai/gpt-oss-120b",
|
|
||||||
"model_server": "https://openrouter.ai/api/v1",
|
|
||||||
"api_key": "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",
|
|
||||||
"generate_cfg": {
|
|
||||||
"use_raw_api": true,
|
|
||||||
"fncall_prompt_type": "nous"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"claude-3.7": {
|
|
||||||
"model": "claude-3-7-sonnet-20250219",
|
|
||||||
"model_server": "https://one.felo.me/v1",
|
|
||||||
"api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
|
|
||||||
"generate_cfg": {
|
|
||||||
"use_raw_api": true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gpt-4o": {
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"model_server": "https://one-dev.felo.me/v1",
|
|
||||||
"api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
|
|
||||||
"generate_cfg": {
|
|
||||||
"use_raw_api": true,
|
|
||||||
"fncall_prompt_type": "nous"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"glm-45": {
|
|
||||||
"model_server": "https://open.bigmodel.cn/api/paas/v4",
|
|
||||||
"api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
|
|
||||||
"model": "glm-4.5",
|
|
||||||
"generate_cfg": {
|
|
||||||
"use_raw_api": true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"qwen3-next": {
|
|
||||||
"model": "qwen/qwen3-next-80b-a3b-instruct",
|
|
||||||
"model_server": "https://openrouter.ai/api/v1",
|
|
||||||
"api_key": "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"
|
|
||||||
},
|
|
||||||
"deepresearch": {
|
|
||||||
"model": "alibaba/tongyi-deepresearch-30b-a3b",
|
|
||||||
"model_server": "https://openrouter.ai/api/v1",
|
|
||||||
"api_key": "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"
|
|
||||||
},
|
|
||||||
"qwen3-coder": {
|
|
||||||
"model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
|
|
||||||
"model_server": "https://api-inference.modelscope.cn/v1",
|
|
||||||
"api_key": "ms-92027446-2787-4fd6-af01-f002459ec556"
|
|
||||||
},
|
|
||||||
"openrouter-gpt4o": {
|
|
||||||
"model": "openai/gpt-4o",
|
|
||||||
"model_server": "https://openrouter.ai/api/v1",
|
|
||||||
"api_key": "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",
|
|
||||||
"generate_cfg": {
|
|
||||||
"use_raw_api": true,
|
|
||||||
"fncall_prompt_type": "nous"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,61 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
目录树 MCP包装器服务器
|
|
||||||
提供安全的目录结构查看功能,限制在项目目录内
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from mcp_wrapper import handle_wrapped_request
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主入口点"""
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# 从stdin读取
|
|
||||||
line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
request = json.loads(line)
|
|
||||||
response = await handle_wrapped_request(request, "directory-tree-wrapper")
|
|
||||||
|
|
||||||
# 写入stdout
|
|
||||||
sys.stdout.write(json.dumps(response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32700,
|
|
||||||
"message": "Parse error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sys.stdout.write(json.dumps(error_response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": f"Internal error: {str(e)}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sys.stdout.write(json.dumps(error_response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@ -1,308 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
通用MCP工具包装器
|
|
||||||
为所有MCP工具提供目录访问控制和安全限制
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
import subprocess
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class MCPSecurityWrapper:
|
|
||||||
"""MCP安全包装器"""
|
|
||||||
|
|
||||||
def __init__(self, allowed_directory: str):
|
|
||||||
self.allowed_directory = os.path.abspath(allowed_directory)
|
|
||||||
self.project_id = os.getenv("PROJECT_ID", "default")
|
|
||||||
|
|
||||||
def validate_path(self, path: str) -> str:
|
|
||||||
"""验证路径是否在允许的目录内"""
|
|
||||||
if not os.path.isabs(path):
|
|
||||||
path = os.path.abspath(path)
|
|
||||||
|
|
||||||
# 规范化路径
|
|
||||||
path = os.path.normpath(path)
|
|
||||||
|
|
||||||
# 检查路径遍历
|
|
||||||
if ".." in path.split(os.sep):
|
|
||||||
raise ValueError(f"路径遍历攻击被阻止: {path}")
|
|
||||||
|
|
||||||
# 检查是否在允许的目录内
|
|
||||||
if not path.startswith(self.allowed_directory):
|
|
||||||
raise ValueError(f"访问被拒绝: {path} 不在允许的目录 {self.allowed_directory} 内")
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
def safe_execute_command(self, command: list, cwd: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""安全执行命令,限制工作目录"""
|
|
||||||
if cwd is None:
|
|
||||||
cwd = self.allowed_directory
|
|
||||||
else:
|
|
||||||
cwd = self.validate_path(cwd)
|
|
||||||
|
|
||||||
# 设置环境变量限制
|
|
||||||
env = os.environ.copy()
|
|
||||||
env["PWD"] = cwd
|
|
||||||
env["PROJECT_DATA_DIR"] = self.allowed_directory
|
|
||||||
env["PROJECT_ID"] = self.project_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
command,
|
|
||||||
cwd=cwd,
|
|
||||||
env=env,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=30 # 30秒超时
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": result.returncode == 0,
|
|
||||||
"stdout": result.stdout,
|
|
||||||
"stderr": result.stderr,
|
|
||||||
"returncode": result.returncode
|
|
||||||
}
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"stdout": "",
|
|
||||||
"stderr": "命令执行超时",
|
|
||||||
"returncode": -1
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"stdout": "",
|
|
||||||
"stderr": str(e),
|
|
||||||
"returncode": -1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
|
|
||||||
"""处理包装后的MCP请求"""
|
|
||||||
try:
|
|
||||||
method = request.get("method")
|
|
||||||
params = request.get("params", {})
|
|
||||||
request_id = request.get("id")
|
|
||||||
|
|
||||||
allowed_dir = get_allowed_directory()
|
|
||||||
wrapper = MCPSecurityWrapper(allowed_dir)
|
|
||||||
|
|
||||||
if method == "initialize":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"result": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {
|
|
||||||
"tools": {}
|
|
||||||
},
|
|
||||||
"serverInfo": {
|
|
||||||
"name": f"{tool_name}-wrapper",
|
|
||||||
"version": "1.0.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif method == "ping":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"result": {
|
|
||||||
"pong": True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif method == "tools/list":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"result": {
|
|
||||||
"tools": get_tool_definitions(tool_name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif method == "tools/call":
|
|
||||||
return await execute_tool_call(wrapper, tool_name, params, request_id)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"error": {
|
|
||||||
"code": -32601,
|
|
||||||
"message": f"Unknown method: {method}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": f"Internal error: {str(e)}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_directory():
|
|
||||||
"""获取允许访问的目录"""
|
|
||||||
project_dir = os.getenv("PROJECT_DATA_DIR", "./data")
|
|
||||||
return os.path.abspath(project_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_definitions(tool_name: str) -> list:
|
|
||||||
"""根据工具名称返回工具定义"""
|
|
||||||
if tool_name == "ripgrep-wrapper":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": "ripgrep_search",
|
|
||||||
"description": "在项目目录内搜索文本",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"pattern": {"type": "string", "description": "搜索模式"},
|
|
||||||
"path": {"type": "string", "description": "搜索路径(相对于项目目录)"},
|
|
||||||
"maxResults": {"type": "integer", "default": 100}
|
|
||||||
},
|
|
||||||
"required": ["pattern"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
elif tool_name == "directory-tree-wrapper":
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": "get_directory_tree",
|
|
||||||
"description": "获取项目目录结构",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path": {"type": "string", "description": "目录路径(相对于项目目录)"},
|
|
||||||
"max_depth": {"type": "integer", "default": 3}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
|
||||||
"""执行工具调用"""
|
|
||||||
try:
|
|
||||||
if tool_name == "ripgrep-wrapper":
|
|
||||||
return await execute_ripgrep_search(wrapper, params, request_id)
|
|
||||||
elif tool_name == "directory-tree-wrapper":
|
|
||||||
return await execute_directory_tree(wrapper, params, request_id)
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"error": {
|
|
||||||
"code": -32601,
|
|
||||||
"message": f"Unknown tool: {tool_name}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": str(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
|
||||||
"""执行ripgrep搜索"""
|
|
||||||
pattern = params.get("pattern", "")
|
|
||||||
path = params.get("path", ".")
|
|
||||||
max_results = params.get("maxResults", 100)
|
|
||||||
|
|
||||||
# 验证和构建搜索路径
|
|
||||||
search_path = os.path.join(wrapper.allowed_directory, path)
|
|
||||||
search_path = wrapper.validate_path(search_path)
|
|
||||||
|
|
||||||
# 构建ripgrep命令
|
|
||||||
command = [
|
|
||||||
"rg",
|
|
||||||
"--json",
|
|
||||||
"--max-count", str(max_results),
|
|
||||||
pattern,
|
|
||||||
search_path
|
|
||||||
]
|
|
||||||
|
|
||||||
result = wrapper.safe_execute_command(command)
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": result["stdout"]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": f"搜索失败: {result['stderr']}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
|
|
||||||
"""执行目录树获取"""
|
|
||||||
path = params.get("path", ".")
|
|
||||||
max_depth = params.get("max_depth", 3)
|
|
||||||
|
|
||||||
# 验证和构建目录路径
|
|
||||||
dir_path = os.path.join(wrapper.allowed_directory, path)
|
|
||||||
dir_path = wrapper.validate_path(dir_path)
|
|
||||||
|
|
||||||
# 构建目录树命令
|
|
||||||
command = [
|
|
||||||
"find",
|
|
||||||
dir_path,
|
|
||||||
"-type", "d",
|
|
||||||
"-maxdepth", str(max_depth)
|
|
||||||
]
|
|
||||||
|
|
||||||
result = wrapper.safe_execute_command(command)
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": result["stdout"]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request_id,
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": f"获取目录树失败: {result['stderr']}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,61 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Ripgrep MCP包装器服务器
|
|
||||||
提供安全的文本搜索功能,限制在项目目录内
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from mcp_wrapper import handle_wrapped_request
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主入口点"""
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# 从stdin读取
|
|
||||||
line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
request = json.loads(line)
|
|
||||||
response = await handle_wrapped_request(request, "ripgrep-wrapper")
|
|
||||||
|
|
||||||
# 写入stdout
|
|
||||||
sys.stdout.write(json.dumps(response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32700,
|
|
||||||
"message": "Parse error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sys.stdout.write(json.dumps(error_response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": f"Internal error: {str(e)}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sys.stdout.write(json.dumps(error_response) + "\n")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
Loading…
Reference in New Issue
Block a user