#!/usr/bin/env python3
"""
Multi-keyword search MCP server.

Matches files against an array of weighted search patterns (plain keywords
and regular expressions) and reports results sorted by weighted match score.
The stdio JSON-RPC structure follows json_reader_server.py.
"""

import asyncio
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple, Union


def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path lies inside the allowed directory."""
    # Normalize to absolute paths
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)

    # Reject path traversal attempts
    if ".." in file_path:
        raise ValueError("Access denied: path traversal attack detected")

    # The resolved path must stay inside the allowed directory
    if not file_path.startswith(allowed_dir):
        raise ValueError(
            f"Access denied: path {file_path} is not within allowed directory {allowed_dir}"
        )

    return file_path


def get_allowed_directory() -> str:
    """Return the directory the server is allowed to read from."""
    # The project data directory is taken from the environment
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)


def load_tools_from_json() -> List[Dict[str, Any]]:
    """Load tool definitions from a JSON file."""
    try:
        tools_file = os.path.join(
            os.path.dirname(__file__), "tools", "multi_keyword_search_tools.json"
        )
        if os.path.exists(tools_file):
            with open(tools_file, "r", encoding="utf-8") as f:
                return json.load(f)
        # Fall back to an empty definition if the JSON file is missing
        return []
    except Exception as e:
        # Diagnostics go to stderr so they cannot corrupt the JSON-RPC stream
        print(f"Warning: Unable to load tool definition JSON file: {e}", file=sys.stderr)
        return []


def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether a string should be treated as a regex."""
    # /pattern/ form
    if pattern.startswith("/") and pattern.endswith("/") and len(pattern) > 2:
        return True
    # r"pattern" or r'pattern' form
    if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")) and len(pattern) > 3:
        return True
    # Any regex metacharacter promotes the string to a regex; note this means
    # a plain keyword containing '.', '+', etc. is also compiled as a regex.
    regex_chars = {'*', '+', '?', '|', '(', ')', '[', ']', '{', '}', '^', '$', '\\', '.'}
    return any(char in pattern for char in regex_chars)


def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern; return plain keywords unchanged and None for
    regexes that fail to compile."""
    if not is_regex_pattern(pattern):
        return pattern

    try:
        # /pattern/ form
        if pattern.startswith("/") and pattern.endswith("/"):
            return re.compile(pattern[1:-1])
        # r"pattern" or r'pattern' form
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            return re.compile(pattern[2:-1])
        # A bare string containing regex metacharacters
        return re.compile(pattern)
    except re.error as e:
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None


def parse_patterns_with_weights(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Parse the search pattern list. Only the weighted format is supported:
    [{"pattern": "keyword1", "weight": 2.0}, {"pattern": "/regex/", "weight": 0.5}]
    """
    parsed_patterns = []
    for item in patterns:
        if not isinstance(item, dict):
            raise ValueError(
                "Error: Search pattern must be in dictionary format with "
                f"'pattern' and 'weight' fields. Invalid item: {item}"
            )

        pattern = item.get("pattern")
        weight = item.get("weight")

        if pattern is None:
            raise ValueError(f"Error: Missing 'pattern' field. Invalid item: {item}")
        if weight is None:
            raise ValueError(f"Error: Missing 'weight' field. Invalid item: {item}")

        # Coerce the weight to float; the positivity check sits outside the
        # try block so its error message is not masked by the generic one.
        try:
            weight = float(weight)
        except (ValueError, TypeError):
            raise ValueError(f"Error: Weight must be a valid number. Invalid weight: {weight}")
        if weight <= 0:
            raise ValueError(f"Error: Weight must be a positive number. Invalid weight: {weight}")

        parsed_patterns.append({"pattern": pattern, "weight": weight})

    return parsed_patterns
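# Illustrative payload (not executed by the server): parse_patterns_with_weights
# expects a list of {"pattern", "weight"} dicts and normalizes each weight to
# float, leaving the pattern string untouched, e.g.
#
#   parse_patterns_with_weights([
#       {"pattern": "timeout", "weight": 2},          # plain keyword
#       {"pattern": "/err(or)?\\d*/", "weight": 0.5}, # /.../ regex form
#   ])
#   # -> [{'pattern': 'timeout', 'weight': 2.0},
#   #     {'pattern': '/err(or)?\\d*/', 'weight': 0.5}]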
def text_response(text: str) -> Dict[str, Any]:
    """Wrap plain text in the MCP tool-result content envelope."""
    return {"content": [{"type": "text", "text": text}]}


def compile_valid_patterns(
    parsed_patterns: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[str]]:
    """Compile every parsed pattern, separating out regexes that fail."""
    valid_patterns = []
    regex_errors = []
    for pattern_info in parsed_patterns:
        compiled = compile_pattern(pattern_info["pattern"])
        if compiled is None:
            regex_errors.append(pattern_info["pattern"])
        else:
            valid_patterns.append(
                {
                    "pattern": pattern_info["pattern"],
                    "weight": pattern_info["weight"],
                    "compiled_pattern": compiled,
                }
            )
    return valid_patterns, regex_errors


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search the project directory for a file by name."""
    for root, _dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None


def resolve_file_paths(file_paths: List[str], project_data_dir: str) -> List[str]:
    """Resolve paths against the project data directory, falling back to a
    recursive search when the direct path does not exist."""
    valid_paths = []
    for file_path in file_paths:
        try:
            if not os.path.isabs(file_path):
                # Strip a leading projects/ or ./projects/ prefix if present
                clean_path = file_path
                if clean_path.startswith("projects/"):
                    clean_path = clean_path[len("projects/"):]
                elif clean_path.startswith("./projects/"):
                    clean_path = clean_path[len("./projects/"):]

                # Try the path directly inside the project directory first
                full_path = os.path.join(project_data_dir, clean_path.lstrip("./"))
                if os.path.exists(full_path):
                    valid_paths.append(full_path)
                else:
                    # Otherwise search the project tree recursively
                    found = find_file_in_project(clean_path, project_data_dir)
                    if found:
                        valid_paths.append(found)
            else:
                # Absolute paths must already live inside the project directory
                if file_path.startswith(project_data_dir) and os.path.exists(file_path):
                    valid_paths.append(file_path)
        except Exception:
            continue
    return valid_paths


def search_count(
    patterns: List[Dict[str, Any]], file_paths: List[str], case_sensitive: bool = False
) -> Dict[str, Any]:
    """Compute match statistics for weighted patterns (keywords and regexes)."""
    if not patterns:
        return text_response("Error: Search pattern list cannot be empty")

    # Parse the patterns and their mandatory weights
    try:
        parsed_patterns = parse_patterns_with_weights(patterns)
    except ValueError as e:
        return text_response(str(e))

    if not parsed_patterns:
        return text_response("Error: No valid search patterns")
    if not file_paths:
        return text_response("Error: File path list cannot be empty")

    # Compile the regexes up front; invalid ones are reported and skipped
    valid_patterns, regex_errors = compile_valid_patterns(parsed_patterns)
    if regex_errors:
        print(
            "Warning: The following regular expressions failed to compile and "
            "will be ignored: " + ", ".join(regex_errors),
            file=sys.stderr,
        )

    # Restrict access to the project data directory
    project_data_dir = get_allowed_directory()
    valid_paths = resolve_file_paths(file_paths, project_data_dir)
    if not valid_paths:
        return text_response(
            f"Error: Specified files not found in project directory {project_data_dir}"
        )

    # Collect every match
    all_results = []
    for file_path in valid_paths:
        try:
            all_results.extend(search_patterns_in_file(file_path, valid_patterns, case_sensitive))
        except Exception:
            continue

    # Aggregate statistics
    total_lines_searched = 0
    total_weight_score = 0.0
    file_match_stats = {}
    pattern_match_stats = {
        pattern_info["pattern"]: {"match_count": 0, "weight_score": 0.0, "lines_matched": set()}
        for pattern_info in valid_patterns
    }

    # Count the lines of every searched file
    for file_path in valid_paths:
        try:
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                total_lines_searched += len(f.readlines())
        except Exception:
            continue

    # Fold per-line results into file-level and pattern-level statistics
    for result in all_results:
        total_weight_score += result.get("weight_score", 0)

        file_path = result["file_path"]
        if file_path not in file_match_stats:
            file_match_stats[file_path] = {
                "match_count": 0,
                "weight_score": 0.0,
                "lines_matched": set(),
            }
        file_match_stats[file_path]["match_count"] += 1
        file_match_stats[file_path]["weight_score"] += result.get("weight_score", 0)
        file_match_stats[file_path]["lines_matched"].add(result["line_number"])

        for pattern in result["matched_patterns"]:
            stats = pattern_match_stats.get(pattern["original"])
            if stats is not None:
                stats["match_count"] += pattern["match_count"]
                stats["weight_score"] += pattern["weight_score"]
                stats["lines_matched"].add(result["line_number"])

    # Format the statistics report
    formatted_lines = ["=== Matching Statistics Evaluation ==="]
    formatted_lines.append(f"Files searched: {len(valid_paths)}")
    formatted_lines.append(f"Total lines searched: {total_lines_searched}")
    formatted_lines.append(f"Total matched lines: {len(all_results)}")
    formatted_lines.append(f"Total weight score: {total_weight_score:.2f}")
    if total_lines_searched > 0:
        formatted_lines.append(
            f"Match rate: {len(all_results) / total_lines_searched * 100:.2f}%"
        )
    else:
        formatted_lines.append("Match rate: 0.00%")
    formatted_lines.append("")

    formatted_lines.append("=== Statistics by File ===")
    for file_path, stats in sorted(
        file_match_stats.items(), key=lambda x: x[1]["weight_score"], reverse=True
    ):
        formatted_lines.append(f"File: {os.path.basename(file_path)}")
        formatted_lines.append(f"  Matched lines: {len(stats['lines_matched'])}")
        formatted_lines.append(f"  Weight score: {stats['weight_score']:.2f}")
        formatted_lines.append("")

    formatted_lines.append("=== Statistics by Pattern ===")
    for pattern, stats in sorted(
        pattern_match_stats.items(), key=lambda x: x[1]["weight_score"], reverse=True
    ):
        formatted_lines.append(f"Pattern: {pattern}")
        formatted_lines.append(f"  Match count: {stats['match_count']}")
        formatted_lines.append(f"  Matched lines: {len(stats['lines_matched'])}")
        formatted_lines.append(f"  Weight score: {stats['weight_score']:.2f}")
        formatted_lines.append("")

    return text_response("\n".join(formatted_lines))
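# Illustrative search_count report (numbers are made up for the example; the
# section headers and field labels match the formatting code above):
#
#   === Matching Statistics Evaluation ===
#   Files searched: 2
#   Total lines searched: 340
#   Total matched lines: 12
#   Total weight score: 18.50
#   Match rate: 3.53%
#
#   === Statistics by File ===
#   File: app.log
#     Matched lines: 9
#     Weight score: 14.00
#   ...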
def search(
    patterns: List[Dict[str, Any]],
    file_paths: List[str],
    limit: int = 10,
    case_sensitive: bool = False,
) -> Dict[str, Any]:
    """Run a multi-pattern search (keywords and regexes) with required weights."""
    if not patterns:
        return text_response("Error: Search pattern list cannot be empty")

    # Parse the patterns and their mandatory weights
    try:
        parsed_patterns = parse_patterns_with_weights(patterns)
    except ValueError as e:
        return text_response(str(e))

    if not parsed_patterns:
        return text_response("Error: No valid search patterns")
    if not file_paths:
        return text_response("Error: File path list cannot be empty")

    # Compile the regexes up front; invalid ones are reported and skipped
    valid_patterns, regex_errors = compile_valid_patterns(parsed_patterns)
    if regex_errors:
        print(
            "Warning: The following regular expressions failed to compile and "
            "will be ignored: " + ", ".join(regex_errors),
            file=sys.stderr,
        )

    # Restrict access to the project data directory
    project_data_dir = get_allowed_directory()
    valid_paths = resolve_file_paths(file_paths, project_data_dir)
    if not valid_paths:
        return text_response(
            f"Error: Specified files not found in project directory {project_data_dir}"
        )

    # Collect every match
    all_results = []
    for file_path in valid_paths:
        try:
            all_results.extend(search_patterns_in_file(file_path, valid_patterns, case_sensitive))
        except Exception:
            continue

    # Sort by weight score (descending), breaking ties by match count
    all_results.sort(key=lambda x: (x.get("weight_score", 0), x["match_count"]), reverse=True)
    limited_results = all_results[:limit]

    if not limited_results:
        return text_response("No matching results found")

    # The first line reports the totals; each following line carries the
    # weight score, the matched patterns, and the matching line content.
    formatted_lines = [
        f"Found {len(all_results)} matches, showing top {len(limited_results)} results:"
    ]
    for result in limited_results:
        weight_score = result.get("weight_score", 0)
        line_prefix = f"{result['line_number']}:weight({weight_score:.2f}):"

        # Describe each matched pattern
        match_details = []
        for pattern in result["matched_patterns"]:
            if pattern["type"] == "regex":
                match_details.append(f"[regex:{pattern['original']}={pattern['match']}]")
            else:
                match_details.append(f"[keyword:{pattern['match']}]")

        match_info = " ".join(match_details)
        if match_info:
            formatted_lines.append(f"{line_prefix}{match_info}:{result['content']}")
        else:
            formatted_lines.append(f"{line_prefix}{result['content']}")

    return text_response("\n".join(formatted_lines))
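# Illustrative search output (values are made up): the summary line comes
# first, then one line per match in the form
# <line_number>:weight(<score>):<match details>:<line content>, e.g.
#
#   Found 42 matches, showing top 2 results:
#   17:weight(2.50):[keyword:timeout] [regex:/err(or)?\d*/=error7]:request error7 after timeout
#   88:weight(2.00):[keyword:timeout]:client timeout while polling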
def search_patterns_in_file(
    file_path: str, patterns: List[Dict[str, Any]], case_sensitive: bool
) -> List[Dict[str, Any]]:
    """Search one file for the given weighted patterns (keywords and regexes),
    returning one result per matching line with its accumulated weight score."""
    results = []
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()
    except Exception:
        return results

    # Preprocess every pattern once, carrying its weight along.  For
    # case-insensitive searches the regex is recompiled here with
    # re.IGNORECASE rather than once per line.
    processed_patterns = []
    for pattern_info in patterns:
        compiled = pattern_info["compiled_pattern"]
        if compiled is None:
            continue  # Skip invalid regexes
        is_regex = isinstance(compiled, re.Pattern)
        if is_regex and not case_sensitive:
            compiled = re.compile(compiled.pattern, compiled.flags | re.IGNORECASE)
        processed_patterns.append(
            {
                "original": pattern_info["pattern"],
                "pattern": compiled,
                "is_regex": is_regex,
                "weight": pattern_info["weight"],
            }
        )

    for line_number, line in enumerate(lines, 1):
        line_content = line.rstrip("\n\r")
        search_line = line_content if case_sensitive else line_content.lower()

        # Collect matching patterns and accumulate the weighted score
        matched_patterns = []
        weight_score = 0.0

        for pattern_info in processed_patterns:
            pattern = pattern_info["pattern"]
            is_regex = pattern_info["is_regex"]
            weight = pattern_info["weight"]

            match_found = False
            match_details = None

            if is_regex:
                # Regular expression match; only the first occurrence is kept
                match = pattern.search(line_content)
                if match:
                    match_found = True
                    match_details = match.group(0)
            else:
                # Plain keyword match
                search_keyword = pattern if case_sensitive else pattern.lower()
                if search_keyword in search_line:
                    match_found = True
                    match_details = pattern

            if match_found:
                # Each pattern contributes its weight once per line, no matter
                # how many times it occurs within the line.
                weight_score += weight
                matched_patterns.append(
                    {
                        "original": pattern_info["original"],
                        "type": "regex" if is_regex else "keyword",
                        "match": match_details,
                        "weight": weight,
                        "match_count": 1,
                        "weight_score": weight,
                    }
                )

        if weight_score > 0:
            results.append(
                {
                    "line_number": line_number,
                    "content": line_content,
                    "match_count": len(matched_patterns),
                    "weight_score": weight_score,
                    "matched_patterns": matched_patterns,
                    "file_path": file_path,
                }
            )

    return results


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an MCP JSON-RPC request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"tools": {}},
                    "serverInfo": {"name": "multi-keyword-search", "version": "1.0.0"},
                },
            }
        elif method == "ping":
            return {"jsonrpc": "2.0", "id": request_id, "result": {"pong": True}}
        elif method == "tools/list":
            # Tool definitions are loaded from the JSON file
            tools = load_tools_from_json()
            return {"jsonrpc": "2.0", "id": request_id, "result": {"tools": tools}}
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name == "search":
                result = search(
                    arguments.get("patterns", []),
                    arguments.get("file_paths", []),
                    arguments.get("limit", 10),
                    arguments.get("case_sensitive", False),
                )
                return {"jsonrpc": "2.0", "id": request_id, "result": result}
            elif tool_name == "search_count":
                result = search_count(
                    arguments.get("patterns", []),
                    arguments.get("file_paths", []),
                    arguments.get("case_sensitive", False),
                )
                return {"jsonrpc": "2.0", "id": request_id, "result": result}
            else:
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
                }
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {"code": -32601, "message": f"Unknown method: {method}"},
            }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
        }
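# Illustrative stdio exchange for the "search" tool (one JSON object per line;
# the request shape follows handle_request above, the response envelope is
# built by text_response):
#
#   -> {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#       "params": {"name": "search",
#                  "arguments": {"patterns": [{"pattern": "timeout", "weight": 2.0}],
#                                "file_paths": ["app.log"], "limit": 5}}}
#   <- {"jsonrpc": "2.0", "id": 1,
#       "result": {"content": [{"type": "text", "text": "Found 3 matches, ..."}]}}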
async def main():
    """Main entry point: read JSON-RPC requests from stdin, one per line."""
    try:
        while True:
            # Read one line from stdin without blocking the event loop
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_request(request)
                # Responses go to stdout, one JSON object per line
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {"code": -32700, "message": "Parse error"},
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
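# Quick smoke test from a shell (assuming this file is saved as
# multi_keyword_search_server.py; adjust the file name and PROJECT_DATA_DIR
# to match your setup):
#
#   echo '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}' \
#     | PROJECT_DATA_DIR=./projects python3 multi_keyword_search_server.py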