qwen_agent/mcp/multi_keyword_search_server.py
2025-10-10 08:58:23 +08:00

355 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
多关键词搜索MCP服务器
支持关键词数组匹配,按匹配数量排序输出
参考json_reader_server.py的实现方式
"""
import json
import os
import sys
import asyncio
from typing import Any, Dict, List, Optional
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Ensure *file_path* stays inside *allowed_dir* and return its absolute form.

    Containment is decided on fully resolved paths (``os.path.realpath``) with
    ``os.path.commonpath``, so all of the following are rejected:

    - ``..`` segments that climb out of the directory;
    - sibling directories that merely share a name prefix
      (``/data/projects2`` vs ``/data/projects``);
    - symlinks that point outside the directory.

    File names that simply contain two dots (``a..b.txt``) are accepted.

    Args:
        file_path: Absolute path, or path relative to the current working
            directory, to check.
        allowed_dir: Directory the path must resolve within.

    Returns:
        The normalized absolute form of ``file_path``.

    Raises:
        ValueError: If the path resolves outside ``allowed_dir``.
    """
    abs_path = os.path.abspath(file_path)
    abs_allowed = os.path.abspath(allowed_dir)
    real_path = os.path.realpath(abs_path)
    real_allowed = os.path.realpath(abs_allowed)
    try:
        inside = os.path.commonpath([real_path, real_allowed]) == real_allowed
    except ValueError:
        # commonpath raises when the paths cannot share a root (e.g. different
        # drives on Windows); such a path is certainly not inside allowed_dir.
        inside = False
    if not inside:
        raise ValueError(f"访问被拒绝: 路径 {abs_path} 不在允许的目录 {abs_allowed}")
    return abs_path
def get_allowed_directory():
    """Return the absolute path of the directory this server may read from.

    The location comes from the ``PROJECT_DATA_DIR`` environment variable and
    falls back to ``./projects``. Either is resolved against the current
    working directory.
    """
    configured = os.environ.get("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(configured)
def _text_result(text: str) -> Dict[str, Any]:
    """Wrap *text* in the MCP ``content`` envelope returned by the search tool."""
    return {
        "content": [
            {
                "type": "text",
                "text": text
            }
        ]
    }


def _is_inside_directory(path: str, directory: str) -> bool:
    """Return True if *path* resolves to *directory* or somewhere beneath it."""
    real_path = os.path.realpath(path)
    real_dir = os.path.realpath(directory)
    try:
        return os.path.commonpath([real_path, real_dir]) == real_dir
    except ValueError:
        # commonpath raises for paths that cannot share a root (e.g. different
        # drives on Windows); such a path is not inside the directory.
        return False


def _resolve_project_files(file_paths: List[str], project_data_dir: str) -> List[str]:
    """Map requested paths to existing files inside *project_data_dir*.

    Order is preserved and duplicates are dropped.
    """
    resolved: List[str] = []
    seen = set()
    for file_path in file_paths:
        if not isinstance(file_path, str) or not file_path:
            continue
        if os.path.isabs(file_path):
            candidate = file_path
        else:
            # Join the relative path as-is. The previous lstrip('./') removed
            # *characters*, turning e.g. ".env" into "env".
            candidate = os.path.join(project_data_dir, file_path)
            if not os.path.isfile(candidate):
                # Not found directly: fall back to a recursive project search.
                candidate = find_file_in_project(file_path, project_data_dir)
        # Only regular files can be searched.
        if not candidate or not os.path.isfile(candidate):
            continue
        # Enforce the project boundary on the resolved path. This blocks
        # "../" traversal, sibling dirs sharing a name prefix, and symlinks
        # leading outside the project.
        if not _is_inside_directory(candidate, project_data_dir):
            continue
        key = os.path.realpath(candidate)
        if key in seen:
            continue  # same file requested twice -> would duplicate every hit
        seen.add(key)
        resolved.append(os.path.abspath(candidate))
    return resolved


def multi_keyword_search(keywords: List[str], file_paths: List[str],
                         limit: int = 10, case_sensitive: bool = False) -> Dict[str, Any]:
    """Search *file_paths* for lines containing any of *keywords*.

    Hits are ranked by how many keywords the line contains (descending). The
    sort is stable, so ties keep file order, then line order. Only files inside
    the project data directory (``get_allowed_directory()``) are searched.

    Args:
        keywords: Substrings to look for. Empty strings and non-strings are
            ignored, because "" is a substring of every line. A bare string is
            treated as a one-element list.
        file_paths: Paths relative to the project directory (looked up directly,
            then via ``find_file_in_project``), or absolute paths inside it.
            A bare string is treated as a one-element list.
        limit: Maximum number of lines returned. Non-integers fall back to 10;
            negative values are treated as 0.
        case_sensitive: Whether matching distinguishes letter case.

    Returns:
        An MCP tool result whose single text item holds one
        ``line:match_count(n):content`` row per hit, or an error /
        "no results" message.
    """
    # Tolerate a bare string where the schema asks for an array.
    if isinstance(keywords, str):
        keywords = [keywords]
    if isinstance(file_paths, str):
        file_paths = [file_paths]
    keywords = [kw for kw in (keywords or []) if isinstance(kw, str) and kw]
    if not keywords:
        return _text_result("错误:关键词列表不能为空")
    if not file_paths:
        return _text_result("错误:文件路径列表不能为空")

    try:
        limit = max(int(limit), 0)
    except (TypeError, ValueError):
        limit = 10

    # Restrict every lookup to the project directory.
    project_data_dir = get_allowed_directory()
    valid_paths = _resolve_project_files(file_paths, project_data_dir)
    if not valid_paths:
        return _text_result(f"错误:在项目目录 {project_data_dir} 中未找到指定文件")

    all_results: List[Dict[str, Any]] = []
    for path in valid_paths:
        try:
            all_results.extend(search_keywords_in_file(path, keywords, case_sensitive))
        except OSError:
            # Best effort: an unreadable file is skipped rather than failing
            # the whole search.
            continue

    # Most matched keywords first; stable sort keeps file/line order on ties.
    all_results.sort(key=lambda x: x['match_count'], reverse=True)
    limited_results = all_results[:limit]
    if not limited_results:
        return _text_result("未找到匹配的结果")

    formatted_output = "\n".join(
        f"{result['line_number']}:match_count({result['match_count']}):{result['content']}"
        for result in limited_results
    )
    return _text_result(formatted_output)
def search_keywords_in_file(file_path: str, keywords: List[str],
                            case_sensitive: bool) -> List[Dict[str, Any]]:
    """Return every line of *file_path* that contains at least one keyword.

    The file is read as UTF-8; undecodable bytes are dropped. It is streamed
    line by line rather than loaded whole.

    Keyword handling:
    - Empty keywords are ignored, because "" is a substring of every line.
    - Duplicates are counted once. Without case sensitivity, duplicates are
      detected after ``str.casefold``, so "Apple"/"apple" count once.
      Otherwise repeated keywords would inflate ``match_count`` and distort
      the ranking.

    Args:
        file_path: File to scan.
        keywords: Substrings to look for.
        case_sensitive: If False, compare casefolded line and keywords.

    Returns:
        One dict per matching line, in file order, with keys
        ``line_number`` (1-based), ``content`` (line without trailing
        newline), ``match_count``, ``matched_keywords`` (caller's original
        spelling) and ``file_path``. An unreadable file yields ``[]``.
    """
    # (needle used for matching, keyword as the caller spelled it)
    prepared = []
    seen = set()
    for keyword in keywords:
        if not isinstance(keyword, str) or not keyword:
            continue
        needle = keyword if case_sensitive else keyword.casefold()
        if needle in seen:
            continue
        seen.add(needle)
        prepared.append((needle, keyword))

    results: List[Dict[str, Any]] = []
    if not prepared:
        return results

    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line_number, line in enumerate(f, 1):
                line_content = line.rstrip('\n\r')
                haystack = line_content if case_sensitive else line_content.casefold()
                matched_keywords = [orig for needle, orig in prepared if needle in haystack]
                if matched_keywords:
                    results.append({
                        'line_number': line_number,
                        'content': line_content,
                        'match_count': len(matched_keywords),
                        'matched_keywords': matched_keywords,
                        'file_path': file_path
                    })
    except OSError:
        # Unreadable file (missing, directory, permissions): report no hits.
        return []
    return results
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search *project_dir* for *filename*; return the first hit.

    *filename* may be a bare name (``a.txt``) or a relative sub-path
    (``sub/a.txt``). Each directory visited is tried as a base for it.
    Directories are walked top-down in sorted order, so the result is
    deterministic when several copies exist.

    Args:
        filename: Relative name or sub-path of the file to find.
        project_dir: Directory to search.

    Returns:
        The path of the first matching regular file, or None. Absolute names
        and names that climb out with ``..`` always return None.
    """
    relative = os.path.normpath(filename)
    # Refuse anything that could resolve outside the directory being walked.
    if (os.path.isabs(relative) or relative == os.pardir
            or relative.startswith(os.pardir + os.sep)):
        return None
    for root, dirs, _files in os.walk(project_dir):
        dirs.sort()  # in-place sort controls os.walk's descent order
        candidate = os.path.join(root, relative)
        if os.path.isfile(candidate):
            return candidate
    return None
def _jsonrpc_result(request_id: Any, result: Dict[str, Any]) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success response."""
    return {"jsonrpc": "2.0", "id": request_id, "result": result}


def _jsonrpc_error(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error response."""
    return {"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle one decoded MCP (JSON-RPC 2.0) message and return its response.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call``. The only tool is ``multi_keyword_search``.

    Errors returned:
    - A non-object message gets ``-32600 Invalid Request``.
    - Unknown methods/tools get ``-32601``.
    - Any unexpected exception becomes ``-32603 Internal error``.

    ``params`` / ``arguments`` that are missing, ``null`` or not objects are
    treated as empty objects.
    """
    if not isinstance(request, dict):
        return _jsonrpc_error(None, -32600, "Invalid Request")
    request_id = request.get("id")
    try:
        method = request.get("method")
        params = request.get("params")
        if not isinstance(params, dict):
            params = {}

        if method == "initialize":
            return _jsonrpc_result(request_id, {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "serverInfo": {
                    "name": "multi-keyword-search",
                    "version": "1.0.0"
                }
            })

        if method == "ping":
            return _jsonrpc_result(request_id, {"pong": True})

        if method == "tools/list":
            return _jsonrpc_result(request_id, {
                "tools": [
                    {
                        "name": "multi_keyword_search",
                        # Must match the row format multi_keyword_search emits.
                        "description": "多关键词搜索工具,返回按匹配数量排序的结果。格式:[行号]:match_count([匹配数量]):[行的原始内容]",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "keywords": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                    "description": "要搜索的关键词数组"
                                },
                                "file_paths": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                    "description": "要搜索的文件路径列表"
                                },
                                "limit": {
                                    "type": "integer",
                                    "description": "返回结果的最大数量默认10",
                                    "default": 10
                                },
                                "case_sensitive": {
                                    "type": "boolean",
                                    "description": "是否区分大小写默认false",
                                    "default": False
                                }
                            },
                            "required": ["keywords", "file_paths"]
                        }
                    }
                ]
            })

        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments")
            if not isinstance(arguments, dict):
                arguments = {}
            if tool_name != "multi_keyword_search":
                return _jsonrpc_error(request_id, -32601, f"Unknown tool: {tool_name}")
            result = multi_keyword_search(
                arguments.get("keywords", []),
                arguments.get("file_paths", []),
                arguments.get("limit", 10),
                arguments.get("case_sensitive", False),
            )
            return _jsonrpc_result(request_id, result)

        return _jsonrpc_error(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return _jsonrpc_error(request_id, -32603, f"Internal error: {str(e)}")
async def main():
    """Serve newline-delimited JSON-RPC 2.0 messages from stdin until EOF.

    Message handling:
    - Each non-blank line is decoded and passed to ``handle_request``.
    - Its response is written to stdout as one JSON line and flushed.
    - Notifications (objects without an ``id`` member) are still handled but
      never answered, as JSON-RPC 2.0 requires. MCP clients send e.g.
      ``notifications/initialized`` right after ``initialize``.
    - Undecodable lines get a ``-32700 Parse error`` response with
      ``"id": null``.
    """
    loop = asyncio.get_running_loop()

    def _write(message: Dict[str, Any]) -> None:
        # One JSON document per line; flush so the client sees it immediately.
        sys.stdout.write(json.dumps(message) + "\n")
        sys.stdout.flush()

    try:
        while True:
            # Blocking readline runs in a worker thread to keep the loop free.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF: the client closed stdin
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                _write({
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                })
                continue

            is_notification = isinstance(request, dict) and "id" not in request
            try:
                response = await handle_request(request)
            except Exception as e:
                response = {
                    "jsonrpc": "2.0",
                    "id": request.get("id") if isinstance(request, dict) else None,
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
            if not is_notification:
                _write(response)
    except KeyboardInterrupt:
        pass
# Start the stdio server only when run as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())