qwen_agent/mcp/multi_keyword_search_server.py
#!/usr/bin/env python3
"""
多关键词搜索MCP服务器
支持关键词数组匹配,按匹配数量排序输出
参考json_reader_server.py的实现方式
"""
import json
import os
import sys
import asyncio
import re
from typing import Any, Dict, List, Optional, Union


def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path lies inside the allowed directory."""
    # Convert to absolute paths; os.path.abspath also resolves ".." segments,
    # which guards against path-traversal attacks.
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # Compare with a trailing separator so a sibling directory such as
    # "<allowed_dir>_other" does not pass as a plain prefix match.
    if file_path != allowed_dir and not file_path.startswith(allowed_dir + os.sep):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")
    return file_path


def get_allowed_directory() -> str:
    """Return the directory the server is allowed to access."""
    # Read the project data directory from the environment.
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)


def load_tools_from_json() -> List[Dict[str, Any]]:
    """Load tool definitions from the bundled JSON file."""
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", "multi_keyword_search_tools.json")
        if os.path.exists(tools_file):
            with open(tools_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        # Fall back to an empty tool list if the JSON file is missing.
        return []
    except Exception as e:
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}")
        return []
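
# For reference, each entry in multi_keyword_search_tools.json is expected to be an
# MCP tool definition (name / description / inputSchema). A plausible, hypothetical
# entry for the "search" tool, matching the arguments handled in handle_request():
#   {
#     "name": "search",
#     "description": "Weighted multi-keyword/regex search over project files",
#     "inputSchema": {
#       "type": "object",
#       "properties": {
#         "patterns": {"type": "array"},
#         "file_paths": {"type": "array"},
#         "limit": {"type": "integer"},
#         "case_sensitive": {"type": "boolean"}
#       },
#       "required": ["patterns", "file_paths"]
#     }
#   }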


def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether a string should be treated as a regular expression."""
    # /pattern/ form
    if pattern.startswith('/') and pattern.endswith('/') and len(pattern) > 2:
        return True
    # r"pattern" or r'pattern' form
    if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")) and len(pattern) > 3:
        return True
    # Otherwise, treat the string as a regex if it contains any regex metacharacter.
    regex_chars = {'*', '+', '?', '|', '(', ')', '[', ']', '{', '}', '^', '$', '\\', '.'}
    return any(char in pattern for char in regex_chars)


def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern; return the original string for plain keywords,
    or None if the regex fails to compile."""
    if not is_regex_pattern(pattern):
        return pattern
    try:
        # /pattern/ form
        if pattern.startswith('/') and pattern.endswith('/'):
            return re.compile(pattern[1:-1])
        # r"pattern" or r'pattern' form
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            return re.compile(pattern[2:-1])
        # Bare string containing regex metacharacters: compile it directly.
        return re.compile(pattern)
    except re.error as e:
        # None signals an invalid regex to the callers.
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}")
        return None
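
# Illustrative examples of the three accepted pattern forms (not used at runtime):
#   compile_pattern("/err(or)?/")   -> re.Pattern matching "err" or "error"
#   compile_pattern("r'\d+'")       -> re.Pattern matching a run of digits
#   compile_pattern("timeout")      -> the plain string "timeout" (substring match)
# Note the heuristic's side effect: a bare string containing any metacharacter,
# including ".", e.g. "file.txt", is compiled as a regex rather than matched literally.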


def parse_patterns_with_weights(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Parse the search pattern list. Only the weighted format is supported:
    [{"pattern": "keyword1", "weight": 2.0}, {"pattern": "/regex/", "weight": 0.5}]
    """
    parsed_patterns = []
    for item in patterns:
        if not isinstance(item, dict):
            raise ValueError(f"Error: Search pattern must be in dictionary format with 'pattern' and 'weight' fields. Invalid item: {item}")
        pattern = item.get('pattern')
        weight = item.get('weight')
        if pattern is None:
            raise ValueError(f"Error: Missing 'pattern' field. Invalid item: {item}")
        if weight is None:
            raise ValueError(f"Error: Missing 'weight' field. Invalid item: {item}")
        # Coerce the weight to a float.
        try:
            weight = float(weight)
        except (ValueError, TypeError):
            raise ValueError(f"Error: Weight must be a valid number. Invalid weight: {weight}")
        # Validate positivity outside the try block so this error is not
        # swallowed and replaced by the "valid number" message above.
        if weight <= 0:
            raise ValueError(f"Error: Weight must be a positive number. Invalid weight: {weight}")
        parsed_patterns.append({
            'pattern': pattern,
            'weight': weight
        })
    return parsed_patterns


def search_count(patterns: List[Dict[str, Any]], file_paths: List[str],
                 case_sensitive: bool = False) -> Dict[str, Any]:
    """Compute aggregate match statistics for weighted patterns (keywords and
    regular expressions). Weights are required."""
    if not patterns:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Search pattern list cannot be empty"
                }
            ]
        }
    # Parse the search patterns and their weights.
    try:
        parsed_patterns = parse_patterns_with_weights(patterns)
    except ValueError as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": str(e)
                }
            ]
        }
    if not parsed_patterns:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: No valid search patterns"
                }
            ]
        }
    if not file_paths:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: File path list cannot be empty"
                }
            ]
        }
    # Precompile the search patterns, dropping any invalid regexes.
    valid_patterns = []
    regex_errors = []
    for pattern_info in parsed_patterns:
        pattern = pattern_info['pattern']
        compiled = compile_pattern(pattern)
        if compiled is None:
            regex_errors.append(pattern)
        else:
            valid_patterns.append({
                'pattern': pattern,
                'weight': pattern_info['weight'],
                'compiled_pattern': compiled
            })
    if regex_errors:
        print(f"Warning: The following regular expressions failed to compile and will be ignored: {', '.join(regex_errors)}")
    # Restrict access to the project data directory.
    project_data_dir = get_allowed_directory()
    # Resolve and validate the requested file paths.
    valid_paths = []
    for file_path in file_paths:
        try:
            if not os.path.isabs(file_path):
                # Strip a leading "projects/" or "./projects/" prefix if present.
                clean_path = file_path
                if clean_path.startswith('projects/'):
                    clean_path = clean_path[len('projects/'):]
                elif clean_path.startswith('./projects/'):
                    clean_path = clean_path[len('./projects/'):]
                # Look for the file directly under the project directory.
                full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
                if os.path.exists(full_path):
                    valid_paths.append(full_path)
                else:
                    # Fall back to a recursive search of the project directory.
                    found = find_file_in_project(clean_path, project_data_dir)
                    if found:
                        valid_paths.append(found)
            else:
                if file_path.startswith(project_data_dir) and os.path.exists(file_path):
                    valid_paths.append(file_path)
        except Exception:
            continue
    if not valid_paths:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: Specified files not found in project directory {project_data_dir}"
                }
            ]
        }
    # Collect matches from every file.
    all_results = []
    for file_path in valid_paths:
        try:
            all_results.extend(search_patterns_in_file(file_path, valid_patterns, case_sensitive))
        except Exception:
            continue
    # Aggregate statistics.
    total_lines_searched = 0
    total_weight_score = 0.0
    pattern_match_stats = {}
    file_match_stats = {}
    # Initialize per-pattern statistics.
    for pattern_info in valid_patterns:
        pattern_match_stats[pattern_info['pattern']] = {
            'match_count': 0,
            'weight_score': 0.0,
            'lines_matched': set()
        }
    # Count the total number of lines searched.
    for file_path in valid_paths:
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                total_lines_searched += len(f.readlines())
        except Exception:
            continue
    # Fold every matching line into the per-file and per-pattern statistics.
    for result in all_results:
        total_weight_score += result.get('weight_score', 0)
        # Per-file statistics.
        file_path = result['file_path']
        if file_path not in file_match_stats:
            file_match_stats[file_path] = {
                'match_count': 0,
                'weight_score': 0.0,
                'lines_matched': set()
            }
        file_match_stats[file_path]['match_count'] += 1
        file_match_stats[file_path]['weight_score'] += result.get('weight_score', 0)
        file_match_stats[file_path]['lines_matched'].add(result['line_number'])
        # Per-pattern statistics; set entries include the file path so equal
        # line numbers in different files are not conflated.
        for pattern in result['matched_patterns']:
            original_pattern = pattern['original']
            if original_pattern in pattern_match_stats:
                pattern_match_stats[original_pattern]['match_count'] += pattern['match_count']
                pattern_match_stats[original_pattern]['weight_score'] += pattern['weight_score']
                pattern_match_stats[original_pattern]['lines_matched'].add((file_path, result['line_number']))
    # Format the statistics report.
    formatted_lines = []
    formatted_lines.append("=== Matching Statistics Evaluation ===")
    formatted_lines.append(f"Files searched: {len(valid_paths)}")
    formatted_lines.append(f"Total lines searched: {total_lines_searched}")
    formatted_lines.append(f"Total matched lines: {len(all_results)}")
    formatted_lines.append(f"Total weight score: {total_weight_score:.2f}")
    if total_lines_searched > 0:
        formatted_lines.append(f"Match rate: {len(all_results) / total_lines_searched * 100:.2f}%")
    else:
        formatted_lines.append("Match rate: 0.00%")
    formatted_lines.append("")
    # Per-file breakdown, highest weight score first.
    formatted_lines.append("=== Statistics by File ===")
    for file_path, stats in sorted(file_match_stats.items(), key=lambda x: x[1]['weight_score'], reverse=True):
        formatted_lines.append(f"File: {os.path.basename(file_path)}")
        formatted_lines.append(f"  Matched lines: {len(stats['lines_matched'])}")
        formatted_lines.append(f"  Weight score: {stats['weight_score']:.2f}")
        formatted_lines.append("")
    # Per-pattern breakdown, highest weight score first.
    formatted_lines.append("=== Statistics by Pattern ===")
    for pattern, stats in sorted(pattern_match_stats.items(), key=lambda x: x[1]['weight_score'], reverse=True):
        formatted_lines.append(f"Pattern: {pattern}")
        formatted_lines.append(f"  Match count: {stats['match_count']}")
        formatted_lines.append(f"  Matched lines: {len(stats['lines_matched'])}")
        formatted_lines.append(f"  Weight score: {stats['weight_score']:.2f}")
        formatted_lines.append("")
    return {
        "content": [
            {
                "type": "text",
                "text": "\n".join(formatted_lines)
            }
        ]
    }
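
# For orientation, search_count's report looks roughly like this (values invented):
#   === Matching Statistics Evaluation ===
#   Files searched: 2
#   Total lines searched: 340
#   Total matched lines: 12
#   Total weight score: 19.50
#   Match rate: 3.53%
#   ...followed by the per-file and per-pattern breakdowns in the same format.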


def search(patterns: List[Dict[str, Any]], file_paths: List[str],
           limit: int = 10, case_sensitive: bool = False) -> Dict[str, Any]:
    """Run a weighted multi-pattern search (keywords and regular expressions)
    and return the highest-scoring matching lines. Weights are required."""
    if not patterns:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Search pattern list cannot be empty"
                }
            ]
        }
    # Parse the search patterns and their weights.
    try:
        parsed_patterns = parse_patterns_with_weights(patterns)
    except ValueError as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": str(e)
                }
            ]
        }
    if not parsed_patterns:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: No valid search patterns"
                }
            ]
        }
    if not file_paths:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: File path list cannot be empty"
                }
            ]
        }
    # Precompile the search patterns, dropping any invalid regexes.
    valid_patterns = []
    regex_errors = []
    for pattern_info in parsed_patterns:
        pattern = pattern_info['pattern']
        compiled = compile_pattern(pattern)
        if compiled is None:
            regex_errors.append(pattern)
        else:
            valid_patterns.append({
                'pattern': pattern,
                'weight': pattern_info['weight'],
                'compiled_pattern': compiled
            })
    if regex_errors:
        print(f"Warning: The following regular expressions failed to compile and will be ignored: {', '.join(regex_errors)}")
    # Restrict access to the project data directory.
    project_data_dir = get_allowed_directory()
    # Resolve and validate the requested file paths.
    valid_paths = []
    for file_path in file_paths:
        try:
            if not os.path.isabs(file_path):
                # Strip a leading "projects/" or "./projects/" prefix if present.
                clean_path = file_path
                if clean_path.startswith('projects/'):
                    clean_path = clean_path[len('projects/'):]
                elif clean_path.startswith('./projects/'):
                    clean_path = clean_path[len('./projects/'):]
                # Look for the file directly under the project directory.
                full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
                if os.path.exists(full_path):
                    valid_paths.append(full_path)
                else:
                    # Fall back to a recursive search of the project directory.
                    found = find_file_in_project(clean_path, project_data_dir)
                    if found:
                        valid_paths.append(found)
            else:
                if file_path.startswith(project_data_dir) and os.path.exists(file_path):
                    valid_paths.append(file_path)
        except Exception:
            continue
    if not valid_paths:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: Specified files not found in project directory {project_data_dir}"
                }
            ]
        }
    # Collect matches from every file.
    all_results = []
    for file_path in valid_paths:
        try:
            all_results.extend(search_patterns_in_file(file_path, valid_patterns, case_sensitive))
        except Exception:
            continue
    # Sort by weight score (descending), breaking ties by match count.
    all_results.sort(key=lambda x: (x.get('weight_score', 0), x['match_count']), reverse=True)
    # Cap the number of returned results.
    limited_results = all_results[:limit]
    if not limited_results:
        return {
            "content": [
                {
                    "type": "text",
                    "text": "No matching results found"
                }
            ]
        }
    # Format the output: a summary line first, then one line per match showing
    # the line number, weight score, match details, and line content.
    formatted_lines = []
    total_matches = len(all_results)
    showing_count = len(limited_results)
    formatted_lines.append(f"Found {total_matches} matches, showing top {showing_count} results:")
    for result in limited_results:
        weight_score = result.get('weight_score', 0)
        line_prefix = f"{result['line_number']}:weight({weight_score:.2f}):"
        # Describe each matched pattern.
        match_details = []
        for pattern in result['matched_patterns']:
            if pattern['type'] == 'regex':
                match_details.append(f"[regex:{pattern['original']}={pattern['match']}]")
            else:
                match_details.append(f"[keyword:{pattern['match']}]")
        match_info = " ".join(match_details)
        if match_info:
            formatted_lines.append(f"{line_prefix}{match_info}:{result['content']}")
        else:
            formatted_lines.append(f"{line_prefix}{result['content']}")
    return {
        "content": [
            {
                "type": "text",
                "text": "\n".join(formatted_lines)
            }
        ]
    }
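
# For orientation, a formatted result line from search() looks like this
# (hypothetical line number, patterns, and content):
#   42:weight(3.00):[keyword:error] [regex:/time(out)?/=timeout]:error: request timeout after 30s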


def search_patterns_in_file(file_path: str, patterns: List[Dict[str, Any]],
                            case_sensitive: bool) -> List[Dict[str, Any]]:
    """Search a single file for the given patterns (keywords and regular
    expressions), computing a weighted score per matching line."""
    results = []
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            lines = f.readlines()
    except Exception:
        return results
    # Preprocess all patterns, keeping the weight information. For
    # case-insensitive searches, regexes are recompiled once here with
    # re.IGNORECASE instead of once per line.
    processed_patterns = []
    for pattern_info in patterns:
        compiled = pattern_info['compiled_pattern']
        if compiled is None:  # Skip invalid regexes.
            continue
        is_regex = isinstance(compiled, re.Pattern)
        if is_regex and not case_sensitive:
            compiled = re.compile(compiled.pattern, compiled.flags | re.IGNORECASE)
        processed_patterns.append({
            'original': pattern_info['pattern'],
            'pattern': compiled,
            'is_regex': is_regex,
            'weight': pattern_info['weight']
        })
    for line_number, line in enumerate(lines, 1):
        line_content = line.rstrip('\n\r')
        search_line = line_content if case_sensitive else line_content.lower()
        # Record which patterns match and accumulate the weighted score.
        matched_patterns = []
        weight_score = 0.0
        for pattern_info in processed_patterns:
            pattern = pattern_info['pattern']
            is_regex = pattern_info['is_regex']
            weight = pattern_info['weight']
            match_found = False
            match_details = None
            if is_regex:
                # Regular expression match; only the first match is reported.
                match = pattern.search(line_content)
                if match:
                    match_found = True
                    match_details = match.group(0)
            else:
                # Plain substring match.
                search_keyword = pattern if case_sensitive else pattern.lower()
                if search_keyword in search_line:
                    match_found = True
                    match_details = pattern
            if match_found:
                # Each pattern contributes its weight at most once per line,
                # even if it occurs several times in that line.
                weight_score += weight
                matched_patterns.append({
                    'original': pattern_info['original'],
                    'type': 'regex' if is_regex else 'keyword',
                    'match': match_details,
                    'weight': weight,
                    'match_count': 1,
                    'weight_score': weight
                })
        if weight_score > 0:
            results.append({
                'line_number': line_number,
                'content': line_content,
                'match_count': len(matched_patterns),
                'weight_score': weight_score,
                'matched_patterns': matched_patterns,
                'file_path': file_path
            })
    return results
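
# A worked example of the per-line scoring above (hypothetical patterns and line):
# given [{"pattern": "error", "weight": 2.0}, {"pattern": "/time(out)?/", "weight": 1.0}],
# the line "error: timeout after 30s" matches both patterns, each counted once, so
# weight_score = 2.0 * 1 + 1.0 * 1 = 3.0 and match_count = 2.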


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search the project directory for a file by basename.

    Comparing basenames lets a relative path such as "subdir/file.txt" still
    be located anywhere under project_dir.
    """
    target = os.path.basename(filename)
    for root, _dirs, files in os.walk(project_dir):
        if target in files:
            return os.path.join(root, target)
    return None


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle a single MCP request and build the JSON-RPC response."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "multi-keyword-search",
                        "version": "1.0.0"
                    }
                }
            }
        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }
        elif method == "tools/list":
            # Load the tool definitions from the JSON file.
            tools = load_tools_from_json()
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": tools
                }
            }
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name == "search":
                patterns = arguments.get("patterns", [])
                file_paths = arguments.get("file_paths", [])
                limit = arguments.get("limit", 10)
                case_sensitive = arguments.get("case_sensitive", False)
                result = search(patterns, file_paths, limit, case_sensitive)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            elif tool_name == "search_count":
                patterns = arguments.get("patterns", [])
                file_paths = arguments.get("file_paths", [])
                case_sensitive = arguments.get("case_sensitive", False)
                result = search_count(patterns, file_paths, case_sensitive)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32601,
                        "message": f"Unknown tool: {tool_name}"
                    }
                }
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}"
                }
            }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }
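
# For orientation, a typical tools/call exchange over stdio (hypothetical file name;
# "->" is a request line read from stdin, "<-" the response written to stdout):
#   -> {"jsonrpc": "2.0", "id": 3, "method": "tools/call", "params": {"name": "search",
#      "arguments": {"patterns": [{"pattern": "error", "weight": 2.0}],
#      "file_paths": ["example.log"], "limit": 5}}}
#   <- {"jsonrpc": "2.0", "id": 3, "result": {"content": [{"type": "text", "text": "..."}]}}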


async def main():
    """Main entry point: serve JSON-RPC requests line-by-line over stdio."""
    try:
        while True:
            # Read one line from stdin without blocking the event loop.
            line = await asyncio.get_running_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
                response = await handle_request(request)
                # Write the response to stdout.
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    # Per JSON-RPC 2.0, id must be null when the request cannot be parsed.
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
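
# A minimal smoke test over stdin (assumes the server is started directly from
# this file; each request must be a single JSON line):
#   echo '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}' \
#     | python multi_keyword_search_server.py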