qwen_agent/skills/autoload/support/rag-retrieve/mcp_common.py
朱潮 425f3c5bb4 chore: replace Chinese comments and log messages with English
Convert all Chinese comments, docstrings, logger/print output,
HTTPException detail messages, and API response messages to English
across the entire codebase. Functional zh/ja localized strings
(e.g. prompt templates, timezone display names, date formats) are
preserved as-is.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 19:45:35 +08:00

253 lines
8.3 KiB
Python

#!/usr/bin/env python3
"""
Common utility functions for MCP servers.
Provides shared functionality for path handling, file validation, and request processing.
"""
import json
import os
import sys
import asyncio
from typing import Any, Dict, List, Optional, Union
import re
def get_allowed_directory():
    """Return the absolute path of the directory that may be accessed.

    The first command-line argument (``sys.argv[1]``) takes precedence when
    it is present. Otherwise the ``PROJECT_DATA_DIR`` environment variable
    is used, falling back to ``./projects/data`` when it is not set.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        # A dataset directory supplied on the command line wins.
        base_dir = cli_args[0]
    else:
        # No CLI argument: consult the environment (or its default).
        base_dir = os.getenv("PROJECT_DATA_DIR", "./projects/data")
    return os.path.abspath(base_dir)
def _is_within_directory(base_dir: str, candidate: str) -> bool:
    """Return True when *candidate* normalizes to a path inside *base_dir*."""
    base = os.path.abspath(base_dir)
    target = os.path.abspath(candidate)
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # commonpath raises for unrelated roots (e.g. different drives).
        return False


def resolve_file_path(file_path: str, default_subfolder: str = "default") -> str:
    """
    Resolve a file path, supporting both folder/document.txt and document.txt formats.

    Args:
        file_path: Input file path. Backslashes are treated as separators and
            a leading ``projects/`` or ``./projects/`` prefix is removed.
        default_subfolder: Subfolder tried first when only a filename is provided.

    Returns:
        The resolved path of an existing file inside the allowed directory.

    Raises:
        FileNotFoundError: If no existing file inside the allowed directory
            matches ``file_path``.
    """
    has_separator = '/' in file_path or '\\' in file_path
    if has_separator:
        clean_path = file_path.replace('\\', '/')
        # Remove the projects/ prefix if present (only one of the two forms).
        for prefix in ('projects/', './projects/'):
            if clean_path.startswith(prefix):
                clean_path = clean_path[len(prefix):]
                break
    else:
        # Only a filename was provided: look in the default subfolder first.
        clean_path = f"{default_subfolder}/{file_path}"

    project_data_dir = get_allowed_directory()

    # Drop leading "", "." and ".." segments so the path is relative to the
    # project directory. str.lstrip('./') must not be used here: it strips a
    # *set of characters* and would turn ".hidden/a.txt" into "hidden/a.txt".
    segments = clean_path.split('/')
    while segments and segments[0] in ('', '.', '..'):
        del segments[0]
    relative_path = '/'.join(segments)

    # Direct lookup. Candidates that normalize to a location outside the
    # allowed directory (e.g. "a/../../x") are never returned.
    full_path = os.path.join(project_data_dir, relative_path)
    if _is_within_directory(project_data_dir, full_path) and os.path.exists(full_path):
        return full_path

    # Fallback search. find_file_in_project only walks the whole tree when
    # clean_path has no '/' left (e.g. "projects/name.txt" after the prefix
    # was stripped); otherwise it checks that one specific location.
    found = find_file_in_project(clean_path, project_data_dir)
    if found and _is_within_directory(project_data_dir, found):
        return found

    # A bare filename that is not in the default subfolder may still live
    # directly in the project root.
    if not has_separator:
        root_path = os.path.join(project_data_dir, file_path)
        if _is_within_directory(project_data_dir, root_path) and os.path.exists(root_path):
            return root_path

    raise FileNotFoundError(f"File not found: {file_path} (searched in {project_data_dir})")
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Locate *filename* beneath *project_dir*.

    A bare filename is searched for in the whole directory tree and the
    first match found by ``os.walk`` is returned. A name containing ``/`` is
    only looked up at that exact location relative to *project_dir*.

    Returns:
        The path of the match, or ``None`` when nothing was found.
    """
    if '/' not in filename:
        # Plain filename: walk the entire project directory.
        for current_dir, _subdirs, entries in os.walk(project_dir):
            if filename in entries:
                return os.path.join(current_dir, filename)
        return None

    # Path-qualified name: check only the named sub-directory.
    *dir_parts, base_name = filename.split('/')
    folder = os.path.join(project_dir, *dir_parts)
    candidate = os.path.join(folder, base_name)
    if os.path.exists(folder) and os.path.exists(candidate):
        return candidate
    return None
def load_tools_from_json(tools_file_name: str) -> List[Dict[str, Any]]:
    """Load tool definitions from a JSON file located next to this module.

    Args:
        tools_file_name: File name (or path) joined onto this module's directory.

    Returns:
        The parsed JSON content, or an empty list when the file does not
        exist or cannot be read/parsed.
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), tools_file_name)
        with open(tools_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # If the JSON file does not exist, return the default empty definition.
        return []
    except Exception as e:
        # Deliberately best-effort: a broken definition file must not crash
        # the server. The warning goes to stderr because this module writes
        # its JSON-RPC responses to stdout, so stdout must stay clean.
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []
def create_error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error envelope.

    Args:
        request_id: Value placed in the ``id`` member.
        code: Numeric error code.
        message: Human-readable error description.

    Returns:
        A dict with ``jsonrpc``, ``id`` and ``error`` members.
    """
    error_detail = {"code": code, "message": message}
    envelope: Dict[str, Any] = {"jsonrpc": "2.0", "id": request_id}
    envelope["error"] = error_detail
    return envelope
def create_success_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success envelope.

    Args:
        request_id: Value placed in the ``id`` member.
        result: Payload placed, unchanged, in the ``result`` member.

    Returns:
        A dict with ``jsonrpc``, ``id`` and ``result`` members.
    """
    return dict(jsonrpc="2.0", id=request_id, result=result)
def create_initialize_response(
    request_id: Any,
    server_name: str,
    server_version: str = "1.0.0",
    protocol_version: str = "2024-11-05",
) -> Dict[str, Any]:
    """Create a standardized initialize response.

    Args:
        request_id: Value placed in the ``id`` member.
        server_name: Reported as ``serverInfo.name``.
        server_version: Reported as ``serverInfo.version``.
        protocol_version: Reported as ``protocolVersion``. Defaults to the
            previously hard-coded ``"2024-11-05"``; pass another value when
            the server needs to advertise a different protocol revision.

    Returns:
        A JSON-RPC 2.0 response dict whose ``result`` carries the protocol
        version, an empty ``tools`` capability and the server info.
    """
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "protocolVersion": protocol_version,
            "capabilities": {
                "tools": {},
            },
            "serverInfo": {
                "name": server_name,
                "version": server_version,
            },
        },
    }
def create_ping_response(request_id: Any) -> Dict[str, Any]:
    """Build the reply to a ping request.

    Args:
        request_id: Value placed in the ``id`` member.

    Returns:
        A JSON-RPC 2.0 response dict whose ``result`` is ``{"pong": True}``.
    """
    pong_payload = {"pong": True}
    reply: Dict[str, Any] = {"jsonrpc": "2.0", "id": request_id}
    reply["result"] = pong_payload
    return reply
def create_tools_list_response(request_id: Any, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the reply that lists the available tools.

    Args:
        request_id: Value placed in the ``id`` member.
        tools: Tool definitions, stored as-is under ``result.tools``.

    Returns:
        A JSON-RPC 2.0 response dict with the tools under ``result``.
    """
    listing = {"tools": tools}
    return dict(jsonrpc="2.0", id=request_id, result=listing)
def is_regex_pattern(pattern: str) -> bool:
    """Decide whether *pattern* should be handled as a regular expression.

    A string counts as a regex when it is written as ``/body/`` (more than
    two characters), as ``r"body"`` / ``r'body'`` (more than three
    characters), or when it contains any of the characters
    ``* + ? | ( ) [ ] { } ^ $ \\ .``.
    """
    # Explicit /pattern/ delimiters.
    wrapped_in_slashes = pattern.startswith('/') and pattern.endswith('/')
    if wrapped_in_slashes and len(pattern) > 2:
        return True

    # Raw-string style: r"pattern" or r'pattern'.
    raw_prefixed = pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'"))
    if raw_prefixed and len(pattern) > 3:
        return True

    # Otherwise any single regex metacharacter is enough.
    metacharacters = frozenset('*+?|()[]{}^$\\.')
    return not metacharacters.isdisjoint(pattern)
def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern, or return the original string if it is not regex.

    Args:
        pattern: Plain text, ``/body/``, ``r"body"`` / ``r'body'``, or a
            string containing regex metacharacters.

    Returns:
        The unchanged string when ``is_regex_pattern`` rejects it, a compiled
        ``re.Pattern`` when compilation succeeds, or ``None`` when the regex
        is invalid.
    """
    if not is_regex_pattern(pattern):
        return pattern

    # Unwrap the explicit delimiter forms; anything else is compiled verbatim.
    if pattern.startswith('/') and pattern.endswith('/'):
        regex_body = pattern[1:-1]
    elif pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
        regex_body = pattern[2:-1]
    else:
        regex_body = pattern

    try:
        return re.compile(regex_body)
    except re.error as e:
        # Return None when compilation fails to indicate an invalid regex.
        # The warning goes to stderr: this module writes its JSON-RPC
        # responses to stdout, so diagnostics must not be mixed into it.
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None
async def handle_mcp_streaming(request_handler):
    """Run the standard main loop for processing MCP requests.

    Reads newline-delimited JSON from stdin, awaits ``request_handler`` for
    each parsed request and writes the returned response to stdout as one
    JSON line. The loop ends at end-of-input or on ``KeyboardInterrupt``.

    Args:
        request_handler: Async callable taking the parsed request and
            returning the response object to send. Returning ``None`` means
            "send nothing" (used for requests that expect no reply).

    Error handling:
        * A line that is not valid JSON yields a ``-32700`` error with a
          null ``id``.
        * Any exception raised while handling or sending yields a ``-32603``
          error carrying the request's ``id`` (null when unavailable).
    """

    def _send(message) -> None:
        # One JSON document per line, flushed so the peer sees it immediately.
        sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
        sys.stdout.flush()

    def _error(request_id, code, message):
        # JSON-RPC 2.0 responses always carry an "id" member (null if unknown).
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": code,
                "message": message,
            },
        }

    loop = asyncio.get_running_loop()
    try:
        while True:
            # Blocking stdin read is pushed to the default executor.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF
            line = line.strip()
            if not line:
                continue

            # Parse separately so that only malformed input is reported as a
            # parse error, not a JSONDecodeError raised inside the handler.
            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                _send(_error(None, -32700, "Parse error"))
                continue

            request_id = request.get("id") if isinstance(request, dict) else None
            try:
                response = await request_handler(request)
                if response is not None:
                    _send(response)
            except Exception as e:
                _send(_error(request_id, -32603, f"Internal error: {str(e)}"))
    except KeyboardInterrupt:
        pass