qwen_agent/mcp/multi_keyword_search_server.py
2025-10-20 12:51:36 +08:00

459 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
多关键词搜索MCP服务器
支持关键词数组匹配,按匹配数量排序输出
参考json_reader_server.py的实现方式
"""
import json
import os
import sys
import asyncio
import re
from typing import Any, Dict, List, Optional, Union
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that *file_path* lies inside *allowed_dir*.

    Args:
        file_path: Path to check; relative paths are resolved to absolute.
        allowed_dir: Directory the path must stay within.

    Returns:
        The absolute form of ``file_path``.

    Raises:
        ValueError: If the path contains a traversal component or resolves
            outside the allowed directory.
    """
    # Reject ".." before normalization: os.path.abspath would silently
    # resolve traversal components and hide the attempt.
    if ".." in file_path:
        raise ValueError(f"访问被拒绝: 检测到路径遍历攻击尝试")
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # Use commonpath, not startswith: a plain prefix test wrongly accepts
    # sibling directories such as "/data_evil" when allowed_dir is "/data".
    if os.path.commonpath([file_path, allowed_dir]) != allowed_dir:
        raise ValueError(f"访问被拒绝: 路径 {file_path} 不在允许的目录 {allowed_dir}")
    return file_path
def get_allowed_directory():
    """Return the absolute path of the directory this server may access.

    The location comes from the PROJECT_DATA_DIR environment variable,
    falling back to ./projects when it is unset.
    """
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects"))
def load_tools_from_json() -> List[Dict[str, Any]]:
    """Load the MCP tool definitions shipped next to this module.

    Reads tools/multi_keyword_search_tools.json relative to this file.
    Returns an empty list when the file is absent or unreadable, printing
    a warning in the unreadable case.
    """
    tools_path = os.path.join(
        os.path.dirname(__file__), "tools", "multi_keyword_search_tools.json"
    )
    if not os.path.exists(tools_path):
        # A missing definition file is not fatal: fall back to no tools.
        return []
    try:
        with open(tools_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as exc:
        print(f"警告: 无法加载工具定义 JSON 文件: {str(exc)}")
        return []
def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether *pattern* should be treated as a regex.

    Recognized forms: /body/ delimiters, r"body" / r'body' raw-string
    wrappers, or any string containing a regex metacharacter.
    """
    if len(pattern) > 2 and pattern.startswith('/') and pattern.endswith('/'):
        return True
    if (len(pattern) > 3
            and pattern.startswith(('r"', "r'"))
            and pattern.endswith(('"', "'"))):
        return True
    # Plain keywords contain none of the regex metacharacters below.
    metacharacters = '*+?|()[]{}^$\\.'
    return any(ch in metacharacters for ch in pattern)
def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Turn *pattern* into a usable matcher.

    Returns the original string for plain keywords, a compiled re.Pattern
    for regex-like input, or None when the regex fails to compile (a
    warning is printed in that case).
    """
    if not is_regex_pattern(pattern):
        return pattern
    body = pattern
    if pattern.startswith('/') and pattern.endswith('/'):
        body = pattern[1:-1]   # strip the /.../ delimiters
    elif pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
        body = pattern[2:-1]   # strip the r"..." / r'...' wrapper
    try:
        return re.compile(body)
    except re.error as exc:
        print(f"警告: 正则表达式 '{pattern}' 编译失败: {exc}")
        return None
def multi_keyword_search(keywords: List[str], file_paths: List[str],
                         limit: int = 10, case_sensitive: bool = False) -> Dict[str, Any]:
    """Search files for multiple keywords/regexes, ranked by match count.

    Args:
        keywords: Plain keywords or regex patterns (/.../, r"..." forms or
            strings containing regex metacharacters).
        file_paths: Files to search; relative paths are resolved inside the
            project data directory (PROJECT_DATA_DIR).
        limit: Maximum number of result lines returned.
        case_sensitive: Whether matching respects letter case.

    Returns:
        MCP-style content dict whose text holds one formatted line per
        matching source line, sorted by descending match count.
    """
    if not keywords:
        return _text_result("错误:关键词列表不能为空")
    if not file_paths:
        return _text_result("错误:文件路径列表不能为空")

    # Drop keywords whose regex form does not compile; warn but continue.
    valid_keywords = []
    regex_errors = []
    for keyword in keywords:
        if compile_pattern(keyword) is None:
            regex_errors.append(keyword)
        else:
            valid_keywords.append(keyword)
    if regex_errors:
        print(f"警告: 以下正则表达式编译失败,将被忽略: {', '.join(regex_errors)}")

    project_data_dir = get_allowed_directory()
    valid_paths = _resolve_search_paths(file_paths, project_data_dir)
    if not valid_paths:
        return _text_result(f"错误:在项目目录 {project_data_dir} 中未找到指定文件")

    # Collect matches from every resolvable file; unreadable files are skipped.
    all_results = []
    for file_path in valid_paths:
        try:
            all_results.extend(
                search_keywords_in_file(file_path, valid_keywords, case_sensitive))
        except Exception:
            continue

    # Best-matching lines (most patterns hit) first, capped at `limit`.
    all_results.sort(key=lambda r: r['match_count'], reverse=True)
    limited_results = all_results[:limit]
    if not limited_results:
        return _text_result("未找到匹配的结果")
    return _text_result("\n".join(_format_result(r) for r in limited_results))


def _text_result(text: str) -> Dict[str, Any]:
    """Wrap plain text in the MCP content envelope."""
    return {"content": [{"type": "text", "text": text}]}


def _strip_projects_prefix(path: str) -> str:
    """Remove a leading 'projects/' or './projects/' segment, if present."""
    if path.startswith('projects/'):
        return path[len('projects/'):]
    if path.startswith('./projects/'):
        return path[len('./projects/'):]
    return path


def _resolve_search_paths(file_paths: List[str], project_data_dir: str) -> List[str]:
    """Map requested paths to existing files inside *project_data_dir*.

    Relative paths are tried directly under the project directory, then
    located by recursive search; absolute paths are accepted only when they
    already live under the project directory and exist.
    """
    resolved = []
    for file_path in file_paths:
        try:
            if not os.path.isabs(file_path):
                clean_path = _strip_projects_prefix(file_path)
                # Strip only literal "./" prefixes. The old lstrip('./')
                # removed any leading '.' or '/' characters, which mangled
                # dotfile names such as ".env" into "env".
                while clean_path.startswith('./'):
                    clean_path = clean_path[2:]
                full_path = os.path.join(project_data_dir, clean_path)
                if os.path.exists(full_path):
                    resolved.append(full_path)
                else:
                    found = find_file_in_project(clean_path, project_data_dir)
                    if found:
                        resolved.append(found)
            elif file_path.startswith(project_data_dir) and os.path.exists(file_path):
                resolved.append(file_path)
        except Exception:
            # A single bad path must not abort resolution of the others.
            continue
    return resolved


def _format_result(result: Dict[str, Any]) -> str:
    """Render one match record as 'lineno:match_count(n):[details]:content'."""
    prefix = f"{result['line_number']}:match_count({result['match_count']}):"
    details = []
    for pattern in result['matched_patterns']:
        if pattern['type'] == 'regex':
            details.append(f"[regex:{pattern['original']}={pattern['match']}]")
        else:
            details.append(f"[keyword:{pattern['match']}]")
    if details:
        return f"{prefix}{' '.join(details)}:{result['content']}"
    return f"{prefix}{result['content']}"
def search_keywords_in_file(file_path: str, keywords: List[str],
                            case_sensitive: bool) -> List[Dict[str, Any]]:
    """Search one file line by line for keywords and regex patterns.

    Args:
        file_path: File to scan (read as UTF-8, undecodable bytes ignored).
        keywords: Plain keywords or regex patterns; invalid regexes are
            skipped (compile_pattern returns None for them).
        case_sensitive: Whether matching respects letter case.

    Returns:
        One record per matching line with line number, content, match count
        and per-pattern match details; unreadable files yield [].
    """
    results: List[Dict[str, Any]] = []
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            lines = f.readlines()
    except Exception:
        return results

    # Prepare every pattern ONCE, outside the per-line loop. The previous
    # version recompiled the case-insensitive regex variant and re-lowered
    # plain keywords for every line of the file.
    processed_patterns = []
    for keyword in keywords:
        compiled = compile_pattern(keyword)
        if compiled is None:
            continue  # invalid regex, already warned about upstream
        if isinstance(compiled, re.Pattern):
            if not case_sensitive:
                compiled = re.compile(compiled.pattern,
                                      compiled.flags | re.IGNORECASE)
            processed_patterns.append(
                {'original': keyword, 'pattern': compiled, 'is_regex': True})
        else:
            # 'needle' is the pre-folded form actually searched for;
            # 'pattern' keeps the original casing for the match report.
            needle = compiled if case_sensitive else compiled.lower()
            processed_patterns.append(
                {'original': keyword, 'pattern': compiled,
                 'needle': needle, 'is_regex': False})

    for line_number, line in enumerate(lines, 1):
        line_content = line.rstrip('\n\r')
        search_line = line_content if case_sensitive else line_content.lower()
        matched_patterns = []
        for info in processed_patterns:
            if info['is_regex']:
                match = info['pattern'].search(line_content)
                if match:
                    matched_patterns.append({
                        'original': info['original'],
                        'type': 'regex',
                        'match': match.group(0),
                    })
            elif info['needle'] in search_line:
                matched_patterns.append({
                    'original': info['original'],
                    'type': 'keyword',
                    'match': info['pattern'],
                })
        if matched_patterns:
            results.append({
                'line_number': line_number,
                'content': line_content,
                'match_count': len(matched_patterns),
                'matched_patterns': matched_patterns,
                'file_path': file_path,
            })
    return results
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search *project_dir* for *filename*.

    Returns the full path of the first directory-walk hit, or None when no
    file with that basename exists anywhere under the tree.
    """
    hits = (
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(project_dir)
        if filename in files
    )
    return next(hits, None)
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one JSON-RPC request dict and return the response dict.

    Supports initialize, ping, tools/list and tools/call; any other method
    or tool yields a -32601 error, and unexpected failures yield -32603.
    """
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        def ok(result: Dict[str, Any]) -> Dict[str, Any]:
            return {"jsonrpc": "2.0", "id": request_id, "result": result}

        def fail(code: int, message: str) -> Dict[str, Any]:
            return {"jsonrpc": "2.0", "id": request_id,
                    "error": {"code": code, "message": message}}

        if method == "initialize":
            return ok({
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {
                    "name": "multi-keyword-search",
                    "version": "1.0.0",
                },
            })
        if method == "ping":
            return ok({"pong": True})
        if method == "tools/list":
            # Tool definitions live in a JSON file next to this module.
            return ok({"tools": load_tools_from_json()})
        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name != "multi_keyword_search":
                return fail(-32601, f"Unknown tool: {tool_name}")
            return ok(multi_keyword_search(
                arguments.get("keywords", []),
                arguments.get("file_paths", []),
                arguments.get("limit", 10),
                arguments.get("case_sensitive", False),
            ))
        return fail(-32601, f"Unknown method: {method}")
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}",
            },
        }
async def main():
    """Serve JSON-RPC over stdio: one request per line, one response per line."""
    loop = asyncio.get_event_loop()
    try:
        while True:
            # stdin.readline blocks, so run it in the default thread executor
            # to keep the event loop responsive.
            raw = await loop.run_in_executor(None, sys.stdin.readline)
            if not raw:
                break  # EOF: client closed our stdin
            raw = raw.strip()
            if not raw:
                continue  # skip blank keep-alive lines
            try:
                reply = await handle_request(json.loads(raw))
            except json.JSONDecodeError:
                reply = {
                    "jsonrpc": "2.0",
                    "error": {"code": -32700, "message": "Parse error"},
                }
            except Exception as e:
                reply = {
                    "jsonrpc": "2.0",
                    "error": {"code": -32603,
                              "message": f"Internal error: {str(e)}"},
                }
            sys.stdout.write(json.dumps(reply) + "\n")
            sys.stdout.flush()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())