qwen_agent/mcp/excel_csv_operator_server.py
#!/usr/bin/env python3
"""
MCP server for Excel and CSV file operations.
Supports reading, searching, and extracting enum values from Excel/CSV files.
The implementation follows the pattern of multi_keyword_search_server.py.
"""
import json
import os
import sys
import asyncio
import re
import chardet
from typing import Any, Dict, List, Optional, Union

import pandas as pd

def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path lies within the allowed directory."""
    # Normalize to absolute paths (os.path.abspath also collapses "..")
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # Check that the path is inside the allowed directory; appending a trailing
    # separator avoids prefix bypasses such as "/data_evil" matching "/data"
    if file_path != allowed_dir and not file_path.startswith(allowed_dir + os.sep):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")
    # Reject any path traversal components that survived normalization
    if ".." in file_path:
        raise ValueError("Access denied: path traversal attack detected")
    return file_path

def get_allowed_directory():
    """Return the directory that the server is allowed to access."""
    # Read the project data directory from the environment
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)

def load_tools_from_json() -> List[Dict[str, Any]]:
    """Load tool definitions from the JSON file."""
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", "excel_csv_operator_tools.json")
        if os.path.exists(tools_file):
            with open(tools_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        else:
            # Fall back to an empty list if the JSON file does not exist
            return []
    except Exception as e:
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}")
        return []

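# Illustrative sketch (not taken from this repo): the tools JSON is expected to
# hold a list of MCP tool definitions. The entry below is a hypothetical example
# assuming the standard MCP tool shape (name / description / inputSchema); the
# real definitions live in tools/excel_csv_operator_tools.json. The argument
# names mirror the ones read in handle_request below.
#
# [
#   {
#     "name": "full_text_search",
#     "description": "Search all rows of an Excel/CSV file for keywords or regexes",
#     "inputSchema": {
#       "type": "object",
#       "properties": {
#         "file_path": {"type": "string"},
#         "keywords": {"type": "array", "items": {"type": "string"}},
#         "top_k": {"type": "integer", "default": 10},
#         "case_sensitive": {"type": "boolean", "default": false}
#       },
#       "required": ["file_path", "keywords"]
#     }
#   }
# ]
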
def detect_encoding(file_path: str) -> str:
    """Detect the text encoding of a file."""
    try:
        with open(file_path, 'rb') as f:
            raw_data = f.read(10000)  # Sample the first 10 KB for detection
            result = chardet.detect(raw_data)
            return result['encoding'] or 'utf-8'
    except Exception:
        return 'utf-8'

def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether a string should be treated as a regex."""
    # /pattern/ form
    if pattern.startswith('/') and pattern.endswith('/') and len(pattern) > 2:
        return True
    # r"pattern" or r'pattern' form
    if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")) and len(pattern) > 3:
        return True
    # Otherwise, treat strings containing regex metacharacters as regexes
    regex_chars = {'*', '+', '?', '|', '(', ')', '[', ']', '{', '}', '^', '$', '\\', '.'}
    return any(char in pattern for char in regex_chars)

def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern; return the original string if it is a plain keyword."""
    if not is_regex_pattern(pattern):
        return pattern
    try:
        # /pattern/ form
        if pattern.startswith('/') and pattern.endswith('/'):
            regex_body = pattern[1:-1]
            return re.compile(regex_body)
        # r"pattern" or r'pattern' form
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            regex_body = pattern[2:-1]
            return re.compile(regex_body)
        # Compile strings that merely contain regex metacharacters as-is
        return re.compile(pattern)
    except re.error as e:
        # Return None for regexes that fail to compile
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}")
        return None

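# Illustrative examples of how keywords are classified (behavior follows
# is_regex_pattern/compile_pattern above):
#
#   compile_pattern("apple")          -> "apple" (plain substring keyword)
#   compile_pattern("/ORD-\\d+/")     -> re.Pattern compiled from "ORD-\d+"
#   compile_pattern("r'\\bVIP\\b'")   -> re.Pattern compiled from "\bVIP\b"
#   compile_pattern("1.5")            -> re.Pattern ("." is a metacharacter, so it
#                                        also matches "125"; escape literal dots)
#   compile_pattern("a(b")            -> None (invalid regex, reported and skipped)
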
class ExcelCSVOperator:
    """Core operations on Excel and CSV files."""

    def __init__(self):
        self.supported_extensions = ['.xlsx', '.xls', '.csv']
        self.encoding_cache = {}

    def _validate_file(self, file_path: str) -> str:
        """Validate a file path and resolve it inside the project directory."""
        project_data_dir = get_allowed_directory()
        # Resolve relative paths against the project data directory
        if not os.path.isabs(file_path):
            # Strip a leading 'projects/' or './projects/' prefix if present
            clean_path = file_path
            if clean_path.startswith('projects/'):
                clean_path = clean_path[len('projects/'):]
            elif clean_path.startswith('./projects/'):
                clean_path = clean_path[len('./projects/'):]
            # Try the path directly under the project directory first
            full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
            if os.path.exists(full_path):
                file_path = full_path
            else:
                # Fall back to a recursive search of the project directory
                found = self._find_file_in_project(clean_path, project_data_dir)
                if found:
                    file_path = found
                else:
                    raise ValueError(f"File does not exist: {file_path}")
        # Validate the file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.supported_extensions:
            raise ValueError(f"Unsupported file format: {file_ext}, supported formats: {self.supported_extensions}")
        return file_path

    def _find_file_in_project(self, filename: str, project_dir: str) -> Optional[str]:
        """Recursively search the project directory for a file by name."""
        for root, dirs, files in os.walk(project_dir):
            if filename in files:
                return os.path.join(root, filename)
        return None

    def load_data(self, file_path: str, sheet_name: str = None) -> pd.DataFrame:
        """Load an Excel or CSV file into a DataFrame."""
        file_path = self._validate_file(file_path)
        file_ext = os.path.splitext(file_path)[1].lower()
        try:
            if file_ext == '.csv':
                encoding = detect_encoding(file_path)
                df = pd.read_csv(file_path, encoding=encoding)
            else:
                # Excel file: read the requested sheet, or the first sheet by default
                if sheet_name:
                    df = pd.read_excel(file_path, sheet_name=sheet_name)
                else:
                    df = pd.read_excel(file_path)
            # Replace NaN values with empty strings
            df = df.fillna('')
            return df
        except Exception as e:
            raise ValueError(f"File loading failed: {str(e)}")

    def get_sheets(self, file_path: str) -> List[str]:
        """Return the sheet names of an Excel file."""
        file_path = self._validate_file(file_path)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.csv':
            return ['default']  # A CSV file has a single default sheet
        try:
            excel_file = pd.ExcelFile(file_path)
            return excel_file.sheet_names
        except Exception as e:
            raise ValueError(f"Failed to read Excel sheet list: {str(e)}")

    def get_schema(self, file_path: str, sheet_name: str = None) -> List[str]:
        """Return the column names (schema) of the file."""
        try:
            df = self.load_data(file_path, sheet_name)
            return df.columns.tolist()
        except Exception as e:
            raise ValueError(f"Failed to get schema: {str(e)}")

    def full_text_search(self, file_path: str, keywords: List[str],
                         top_k: int = 10, case_sensitive: bool = False) -> str:
        """Full-text search across all rows (and all sheets for Excel files)."""
        if not keywords:
            return "Error: Keyword list cannot be empty"
        # Pre-validate any regular expressions among the keywords
        valid_keywords = []
        regex_errors = []
        for keyword in keywords:
            compiled = compile_pattern(keyword)
            if compiled is None:
                regex_errors.append(keyword)
            else:
                valid_keywords.append(keyword)
        if regex_errors:
            error_msg = f"Warning: The following regular expressions failed to compile and will be ignored: {', '.join(regex_errors)}"
            print(error_msg)
        if not valid_keywords:
            return "Error: No valid search keywords"
        try:
            # Validate the file path
            validated_path = self._validate_file(file_path)
            file_ext = os.path.splitext(validated_path)[1].lower()
            all_results = []
            if file_ext == '.csv':
                # A CSV file is a single data set
                results = self._search_in_file(validated_path, valid_keywords, case_sensitive, 'default')
                all_results.extend(results)
            else:
                # Search every sheet of the Excel file
                sheets = self.get_sheets(validated_path)
                for sheet in sheets:
                    results = self._search_in_file(validated_path, valid_keywords, case_sensitive, sheet)
                    all_results.extend(results)
            # Sort by number of matched patterns (descending) and keep the top_k rows
            all_results.sort(key=lambda x: x['match_count'], reverse=True)
            limited_results = all_results[:top_k]
            if not limited_results:
                return "No matching results found"
            # Format the results as CSV
            csv_lines = []
            headers = ["sheet", "row_index", "match_count", "matched_content", "match_details"]
            csv_lines.append(",".join(headers))
            for result in limited_results:
                # Strip characters that would break the CSV layout
                sheet = str(result.get('sheet', '')).replace(',', '')
                row_index = str(result.get('row_index', ''))
                match_count = str(result.get('match_count', 0))
                matched_content = str(result.get('matched_content', '')).replace(',', '').replace('\n', ' ')
                match_details = str(result.get('match_details', '')).replace(',', '')
                csv_lines.append(f"{sheet},{row_index},{match_count},{matched_content},{match_details}")
            return "\n".join(csv_lines)
        except Exception as e:
            return f"Search failed: {str(e)}"

    def _search_in_file(self, file_path: str, keywords: List[str],
                        case_sensitive: bool, sheet_name: str = None) -> List[Dict[str, Any]]:
        """Search one file (or one sheet) for the given keywords and regexes."""
        results = []
        try:
            df = self.load_data(file_path, sheet_name)
            # Pre-compile all patterns
            processed_patterns = []
            for keyword in keywords:
                compiled = compile_pattern(keyword)
                if compiled is not None:
                    processed_patterns.append({
                        'original': keyword,
                        'pattern': compiled,
                        'is_regex': isinstance(compiled, re.Pattern)
                    })
            # Search row by row
            for row_index, row in df.iterrows():
                # Join the whole row into a single string for matching
                row_content = " ".join([str(cell) for cell in row.values if str(cell).strip()])
                search_content = row_content if case_sensitive else row_content.lower()
                # Collect the patterns that match this row
                matched_patterns = []
                for pattern_info in processed_patterns:
                    pattern = pattern_info['pattern']
                    is_regex = pattern_info['is_regex']
                    match_found = False
                    match_details = None
                    if is_regex:
                        # Regex match; recompile with IGNORECASE when case-insensitive
                        if case_sensitive:
                            match = pattern.search(row_content)
                        else:
                            flags = pattern.flags | re.IGNORECASE
                            case_insensitive_pattern = re.compile(pattern.pattern, flags)
                            match = case_insensitive_pattern.search(row_content)
                        if match:
                            match_found = True
                            match_details = match.group(0)
                    else:
                        # Plain substring match
                        search_keyword = pattern if case_sensitive else pattern.lower()
                        if search_keyword in search_content:
                            match_found = True
                            match_details = pattern
                    if match_found:
                        matched_patterns.append({
                            'original': pattern_info['original'],
                            'type': 'regex' if is_regex else 'keyword',
                            'match': match_details
                        })
                match_count = len(matched_patterns)
                if match_count > 0:
                    # Build the match detail string
                    match_details = []
                    for pattern in matched_patterns:
                        if pattern['type'] == 'regex':
                            match_details.append(f"[regex:{pattern['original']}={pattern['match']}]")
                        else:
                            match_details.append(f"[keyword:{pattern['match']}]")
                    match_info = " ".join(match_details)
                    results.append({
                        'sheet': sheet_name,
                        'row_index': row_index,
                        'match_count': match_count,
                        'matched_content': row_content,
                        'match_details': match_info
                    })
        except Exception as e:
            print(f"Error searching file {file_path} (sheet: {sheet_name}): {str(e)}")
        return results

    def filter_search(self, file_path: str, filters: Dict,
                      sheet_name: str = None) -> str:
        """Search by applying per-field filter conditions."""
        if not filters:
            return "Error: Filter conditions cannot be empty"
        try:
            df = self.load_data(file_path, sheet_name)
            # Apply each filter condition in turn
            filtered_df = df.copy()
            for field_name, filter_condition in filters.items():
                if field_name not in df.columns:
                    return f"Error: Field '{field_name}' does not exist"
                operator = filter_condition.get('operator', 'eq')
                value = filter_condition.get('value')
                if operator == 'eq':  # equal to
                    filtered_df = filtered_df[filtered_df[field_name] == value]
                elif operator == 'gt':  # greater than
                    filtered_df = filtered_df[filtered_df[field_name] > value]
                elif operator == 'lt':  # less than
                    filtered_df = filtered_df[filtered_df[field_name] < value]
                elif operator == 'gte':  # greater than or equal to
                    filtered_df = filtered_df[filtered_df[field_name] >= value]
                elif operator == 'lte':  # less than or equal to
                    filtered_df = filtered_df[filtered_df[field_name] <= value]
                elif operator == 'contains':  # literal substring match
                    filtered_df = filtered_df[filtered_df[field_name].astype(str).str.contains(str(value), na=False, regex=False)]
                elif operator == 'regex':  # regular expression (str.match anchors at the start of the value)
                    try:
                        pattern = re.compile(str(value))
                        filtered_df = filtered_df[filtered_df[field_name].astype(str).str.match(pattern, na=False)]
                    except re.error as e:
                        return f"Error: Regular expression '{value}' compilation failed: {str(e)}"
                else:
                    return f"Error: Unsupported operator '{operator}'"
            # Format the result as CSV
            if filtered_df.empty:
                return "No records matching conditions found"
            csv_result = filtered_df.to_csv(index=False, encoding='utf-8')
            return csv_result
        except Exception as e:
            return f"Filter search failed: {str(e)}"

    def get_field_enums(self, file_path: str, field_names: List[str],
                        sheet_name: str = None, max_enum_count: int = 100,
                        min_occurrence: int = 1) -> str:
        """Return the enum (distinct) values of the given fields."""
        if not field_names:
            return "Error: Field name list cannot be empty"
        try:
            df = self.load_data(file_path, sheet_name)
            # Verify that every requested field exists
            missing_fields = [field for field in field_names if field not in df.columns]
            if missing_fields:
                return f"Error: Fields do not exist: {', '.join(missing_fields)}"
            # Compute the enum values of each field
            enum_results = {}
            for field in field_names:
                # Count occurrences of each value
                value_counts = df[field].value_counts()
                # Drop values that occur fewer than min_occurrence times
                filtered_counts = value_counts[value_counts >= min_occurrence]
                # Keep at most max_enum_count values
                top_values = filtered_counts.head(max_enum_count)
                # Format as "value(count)"
                enum_values = []
                for value, count in top_values.items():
                    enum_values.append(f"{value}({count})")
                enum_results[field] = {
                    'enum_values': enum_values,
                    'total_unique': len(value_counts),
                    'total_filtered': len(filtered_counts),
                    'total_rows': len(df)
                }
            # Format the output
            output_lines = []
            for field, data in enum_results.items():
                enum_str = ", ".join(data['enum_values'])
                field_info = (f"{field}: [{enum_str}] "
                              f"(unique values: {data['total_unique']}, "
                              f"after filtering: {data['total_filtered']}, "
                              f"total rows: {data['total_rows']})")
                output_lines.append(field_info)
            return "\n".join(output_lines)
        except Exception as e:
            return f"Failed to get enum values: {str(e)}"


# Global operator instance
operator = ExcelCSVOperator()


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle MCP request"""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "excel-csv-operator",
                        "version": "1.0.0"
                    }
                }
            }
        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }
elif method == "tools/list":
# 从 JSON 文件加载工具定义
tools = load_tools_from_json()
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"tools": tools
}
}
elif method == "tools/call":
tool_name = params.get("name")
arguments = params.get("arguments", {})
if tool_name == "get_excel_sheets":
file_path = arguments.get("file_path")
result = operator.get_sheets(file_path)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"content": [
{
"type": "text",
"text": json.dumps(result, ensure_ascii=False, indent=2)
}
]
}
}
elif tool_name == "get_table_schema":
file_path = arguments.get("file_path")
sheet_name = arguments.get("sheet_name")
result = operator.get_schema(file_path, sheet_name)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"content": [
{
"type": "text",
"text": json.dumps(result, ensure_ascii=False, indent=2)
}
]
}
}
elif tool_name == "full_text_search":
file_path = arguments.get("file_path")
keywords = arguments.get("keywords", [])
top_k = arguments.get("top_k", 10)
case_sensitive = arguments.get("case_sensitive", False)
result = operator.full_text_search(file_path, keywords, top_k, case_sensitive)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"content": [
{
"type": "text",
"text": result
}
]
}
}
elif tool_name == "filter_search":
file_path = arguments.get("file_path")
sheet_name = arguments.get("sheet_name")
filters = arguments.get("filters")
result = operator.filter_search(file_path, filters, sheet_name)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"content": [
{
"type": "text",
"text": result
}
]
}
}
elif tool_name == "get_field_enums":
file_path = arguments.get("file_path")
sheet_name = arguments.get("sheet_name")
field_names = arguments.get("field_names", [])
max_enum_count = arguments.get("max_enum_count", 100)
min_occurrence = arguments.get("min_occurrence", 1)
result = operator.get_field_enums(file_path, field_names, sheet_name, max_enum_count, min_occurrence)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"content": [
{
"type": "text",
"text": result
}
]
}
}
else:
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32601,
"message": f"Unknown tool: {tool_name}"
}
}
else:
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32601,
"message": f"Unknown method: {method}"
}
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {
"code": -32603,
"message": f"Internal error: {str(e)}"
}
}
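# Illustrative sketch of a tools/call exchange over stdio (the file path and
# values are hypothetical; the envelope mirrors handle_request above):
#
#   -> {"jsonrpc": "2.0", "id": 3, "method": "tools/call",
#       "params": {"name": "full_text_search",
#                  "arguments": {"file_path": "sales.csv",
#                                "keywords": ["Alice", "/ORD-\\d+/"],
#                                "top_k": 5}}}
#   <- {"jsonrpc": "2.0", "id": 3,
#       "result": {"content": [{"type": "text",
#                               "text": "sheet,row_index,match_count,..."}]}}
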
async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
                response = await handle_request(request)
                # Write to stdout
                sys.stdout.write(json.dumps(response, ensure_ascii=False) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response, ensure_ascii=False) + "\n")
                sys.stdout.flush()
            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response, ensure_ascii=False) + "\n")
                sys.stdout.flush()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
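
# Illustrative usage sketch (the PROJECT_DATA_DIR value is hypothetical): the
# server speaks newline-delimited JSON-RPC over stdio, so it is normally
# launched by an MCP client rather than by hand, e.g.
#
#   PROJECT_DATA_DIR=/data/projects python excel_csv_operator_server.py
#
# and then driven by requests such as
#
#   {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}}
#   {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}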