diff --git a/__pycache__/agent_pool.cpython-312.pyc b/__pycache__/agent_pool.cpython-312.pyc new file mode 100644 index 0000000..b674af2 Binary files /dev/null and b/__pycache__/agent_pool.cpython-312.pyc differ diff --git a/__pycache__/fastapi_app.cpython-312.pyc b/__pycache__/fastapi_app.cpython-312.pyc new file mode 100644 index 0000000..2a1a252 Binary files /dev/null and b/__pycache__/fastapi_app.cpython-312.pyc differ diff --git a/__pycache__/gbase_agent.cpython-312.pyc b/__pycache__/gbase_agent.cpython-312.pyc index 42beec4..3b844b9 100644 Binary files a/__pycache__/gbase_agent.cpython-312.pyc and b/__pycache__/gbase_agent.cpython-312.pyc differ diff --git a/__pycache__/project_config.cpython-312.pyc b/__pycache__/project_config.cpython-312.pyc new file mode 100644 index 0000000..c24f269 Binary files /dev/null and b/__pycache__/project_config.cpython-312.pyc differ diff --git a/__pycache__/session_manager.cpython-312.pyc b/__pycache__/session_manager.cpython-312.pyc new file mode 100644 index 0000000..9a9599d Binary files /dev/null and b/__pycache__/session_manager.cpython-312.pyc differ diff --git a/agent_pool.py b/agent_pool.py new file mode 100644 index 0000000..82b9684 --- /dev/null +++ b/agent_pool.py @@ -0,0 +1,178 @@ +import asyncio +from typing import List, Optional +import logging + +logger = logging.getLogger(__name__) + + +class AgentPool: + """助手实例池管理器""" + + def __init__(self, pool_size: int = 5): + """ + 初始化助手实例池 + + Args: + pool_size: 池中实例的数量,默认5个 + """ + self.pool_size = pool_size + self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size) + self.semaphore = asyncio.Semaphore(pool_size) + self.agents = [] # 保存所有创建的实例引用 + + async def initialize(self, agent_factory): + """ + 初始化实例池,使用工厂函数创建助手实例 + + Args: + agent_factory: 创建助手实例的工厂函数 + """ + logger.info(f"正在初始化助手实例池,大小: {self.pool_size}") + + for i in range(self.pool_size): + try: + agent = agent_factory() + await self.pool.put(agent) + self.agents.append(agent) + logger.info(f"助手实例 {i+1}/{self.pool_size} 创建成功") + except Exception as e: + logger.error(f"创建助手实例 {i+1} 失败: {e}") + raise + + logger.info("助手实例池初始化完成") + + async def get_agent(self, timeout: Optional[float] = 30.0): + """ + 获取空闲的助手实例 + + Args: + timeout: 获取超时时间,默认30秒 + + Returns: + 助手实例 + + Raises: + asyncio.TimeoutError: 获取超时 + """ + try: + # 使用信号量控制并发 + await asyncio.wait_for(self.semaphore.acquire(), timeout=timeout) + # 从池中获取实例 + agent = await asyncio.wait_for(self.pool.get(), timeout=timeout) + logger.debug(f"成功获取助手实例,剩余池大小: {self.pool.qsize()}") + return agent + except asyncio.TimeoutError: + logger.error(f"获取助手实例超时 ({timeout}秒)") + raise + + async def release_agent(self, agent): + """ + 释放助手实例回池 + + Args: + agent: 要释放的助手实例 + """ + try: + await self.pool.put(agent) + self.semaphore.release() + logger.debug(f"释放助手实例,当前池大小: {self.pool.qsize()}") + except Exception as e: + logger.error(f"释放助手实例失败: {e}") + # 即使释放失败也要释放信号量 + self.semaphore.release() + + def get_pool_stats(self) -> dict: + """ + 获取池状态统计信息 + + Returns: + 包含池状态信息的字典 + """ + return { + "pool_size": self.pool_size, + "available_agents": self.pool.qsize(), + "total_agents": len(self.agents), + "in_use_agents": len(self.agents) - self.pool.qsize() + } + + async def shutdown(self): + """关闭实例池,清理资源""" + logger.info("正在关闭助手实例池...") + + # 清空队列 + while not self.pool.empty(): + try: + agent = self.pool.get_nowait() + # 如果有清理方法,调用清理 + if hasattr(agent, 'cleanup'): + await agent.cleanup() + except asyncio.QueueEmpty: + break + + logger.info("助手实例池已关闭") + + +# 全局实例池单例 +_global_agent_pool: Optional[AgentPool] = None + + +def get_agent_pool() -> Optional[AgentPool]: + """获取全局助手实例池""" + return _global_agent_pool + + +def set_agent_pool(pool: AgentPool): + """设置全局助手实例池""" + global _global_agent_pool + _global_agent_pool = pool + + +async def init_global_agent_pool(pool_size: int = 5, agent_factory=None): + """ + 初始化全局助手实例池 + + Args: + pool_size: 池大小 + agent_factory: 实例工厂函数 + """ + global _global_agent_pool + + if _global_agent_pool is not None: + logger.warning("全局助手实例池已存在,跳过初始化") + return + + if agent_factory is None: + raise ValueError("必须提供 agent_factory 参数") + + _global_agent_pool = AgentPool(pool_size=pool_size) + await _global_agent_pool.initialize(agent_factory) + logger.info("全局助手实例池初始化完成") + + +async def get_agent_from_pool(timeout: Optional[float] = 30.0): + """ + 从全局池获取助手实例 + + Args: + timeout: 获取超时时间 + + Returns: + 助手实例 + """ + if _global_agent_pool is None: + raise RuntimeError("全局助手实例池未初始化") + + return await _global_agent_pool.get_agent(timeout) + + +async def release_agent_to_pool(agent): + """ + 释放助手实例到全局池 + + Args: + agent: 要释放的助手实例 + """ + if _global_agent_pool is None: + raise RuntimeError("全局助手实例池未初始化") + + await _global_agent_pool.release_agent(agent) \ No newline at end of file diff --git a/agent_prompt.txt b/agent_prompt.txt index 815463b..e7826b3 100644 --- a/agent_prompt.txt +++ b/agent_prompt.txt @@ -17,7 +17,7 @@ ### 数据存储层次 ``` -./data/ +[当前数据目录]/ ├── [数据集文件夹]/ │ ├── schema.json # 倒排索引层 │ ├── serialization.txt # 序列化数据层 @@ -28,7 +28,7 @@ #### 1. 索引层 (schema.json) - **功能**:字段枚举值倒排索引,查询入口点 -- **访问方式**:`json-reader-get_all_keys({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema"})` +- **访问方式**:`json-reader-get_all_keys({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema"})` - **数据结构**: ```json { @@ -66,14 +66,14 @@ **执行步骤**: 1. **加载索引**:读取schema.json获取字段元数据 2. **字段分析**:识别数值字段、文本字段、枚举字段 -3. **字段详情分析**:对于相关字段调用`json-reader-get_value({"file_path": "./data/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})`查看具体的枚举值和取值范围 +3. **字段详情分析**:对于相关字段调用`json-reader-get_value({"file_path": "[当前数据目录]/[数据集文件夹]/schema.json", "key_path": "schema.[字段名]"})`查看具体的枚举值和取值范围 4. **策略制定**:基于查询条件选择最优检索路径 5. **范围预估**:评估各条件的数据分布和选择度 ### 阶段2:精准数据匹配 **目标**:从序列化数据中提取符合条件的记录 **执行步骤**: -1. **预检查**:`ripgrep-count-matches({"path": "./data/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})` +1. **预检查**:`ripgrep-count-matches({"path": "[当前数据目录]/[数据集文件夹]/serialization.txt", "pattern": "匹配模式"})` 2. **智能限流**: - 匹配数 > 1000:增加过滤条件,重新预检查 - 匹配数 100-1000:`ripgrep-search({"maxResults": 30})` @@ -193,3 +193,5 @@ query_pattern = simple_field_match(conditions[0]) # 先匹配主要条件 --- **执行提醒**:始终使用完整的文件路径参数调用工具,确保数据访问的准确性和安全性。在查询执行过程中,动态调整策略以适应不同的数据特征和查询需求。 + +**重要说明**:所有文件路径中的 `[当前数据目录]` 将通过系统消息动态提供,请根据实际的数据目录路径进行操作。 diff --git a/data/all_hp_product_spec_book2506/.DS_Store b/data/all_hp_product_spec_book2506/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/data/all_hp_product_spec_book2506/.DS_Store and /dev/null differ diff --git a/fastapi_app.py b/fastapi_app.py index 97bc74a..5c16e4a 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -1,62 +1,247 @@ -from typing import Optional +import json +import os +from typing import AsyncGenerator, Dict, List, Optional, Union import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import BackgroundTasks, FastAPI, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel +from qwen_agent.llm.schema import ASSISTANT, FUNCTION -from gbase_agent import init_agent_service + +# 自定义版本,不需要text参数,不打印到终端 +def get_content_from_messages(messages: List[dict]) -> str: + full_text = '' + content = [] + TOOL_CALL_S = '[TOOL_CALL]' + TOOL_RESULT_S = '[TOOL_RESPONSE]' + THOUGHT_S = '[THINK]' + ANSWER_S = '[ANSWER]' + + for msg in messages: + if msg['role'] == ASSISTANT: + if msg.get('reasoning_content'): + assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages' + content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}') + if msg.get('content'): + assert isinstance(msg['content'], str), 'Now only supports text messages' + content.append(f'{ANSWER_S}\n{msg["content"]}') + if msg.get('function_call'): + content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}') + elif msg['role'] == FUNCTION: + content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}') + else: + raise TypeError + if content: + full_text = '\n'.join(content) + + return full_text + +from agent_pool import (get_agent_from_pool, init_global_agent_pool, + release_agent_to_pool) +from gbase_agent import init_agent_service_universal, update_agent_llm +from project_config import project_manager app = FastAPI(title="Database Assistant API", version="1.0.0") -# Initialize agent globally at startup -bot = init_agent_service() +# 全局助手实例池,在应用启动时初始化 +agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1")) -class QueryRequest(BaseModel): - question: str +class Message(BaseModel): + role: str + content: str + + +class ChatRequest(BaseModel): + messages: List[Message] + model: str = "qwen3-next" + api_key: Optional[str] = None + extra: Optional[Dict] = None + stream: Optional[bool] = False file_url: Optional[str] = None -class QueryResponse(BaseModel): - answer: str +class ChatResponse(BaseModel): + choices: List[Dict] + usage: Optional[Dict] = None -@app.post("/query", response_model=QueryResponse) -async def query_database(request: QueryRequest): - """ - Process a database query using the assistant agent. +class ChatStreamResponse(BaseModel): + choices: List[Dict] + usage: Optional[Dict] = None - Args: - request: QueryRequest containing the query and optional file URL - Returns: - QueryResponse containing the assistant's response - """ +async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]: + """生成流式响应""" + accumulated_content = "" + accumulated_args = "" + chunk_id = 0 try: - messages = [] - - if request.file_url: - messages.append( - { - "role": "user", - "content": [{"text":"使用sqlite数据库,用日语回答下面问题:"+request.question}, {"file": request.file_url}], + for response in agent.run(messages=messages): + previous_content = accumulated_content + accumulated_content = get_content_from_messages(response) + + # 计算新增的内容 + if accumulated_content.startswith(previous_content): + new_content = accumulated_content[len(previous_content):] + else: + new_content = accumulated_content + previous_content = "" + + # 只有当有新内容时才发送chunk + if new_content: + chunk_id += 1 + # 构造OpenAI格式的流式响应 + chunk_data = { + "id": f"chatcmpl-{chunk_id}", + "object": "chat.completion.chunk", + "created": int(__import__('time').time()), + "model": request.model, + "choices": [{ + "index": 0, + "delta": { + "content": new_content + }, + "finish_reason": None + }] } + + yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" + + # 发送最终完成标记 + final_chunk = { + "id": f"chatcmpl-{chunk_id + 1}", + "object": "chat.completion.chunk", + "created": int(__import__('time').time()), + "model": request.model, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + } + yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" + + # 发送结束标记 + yield "data: [DONE]\n\n" + + except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"Error in generate_stream_response: {str(e)}") + print(f"Full traceback: {error_details}") + + error_data = { + "error": { + "message": f"Stream error: {str(e)}", + "type": "internal_error" + } + } + yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" + + +@app.post("/chat/completions") +async def chat_completions(request: ChatRequest): + """ + Chat completions API similar to OpenAI, supports both streaming and non-streaming + + Args: + request: ChatRequest containing messages, model, project_id in extra field, etc. + + Returns: + Union[ChatResponse, StreamingResponse]: Chat completion response or stream + """ + agent = None + try: + # 从extra字段中获取project_id + if not request.extra or 'project_id' not in request.extra: + raise HTTPException(status_code=400, detail="project_id is required in extra field") + + project_id = request.extra['project_id'] + + # 验证项目访问权限 + if not project_manager.validate_project_access(project_id): + raise HTTPException(status_code=404, detail=f"Project {project_id} not found or inactive") + + # 获取项目数据目录 + project_dir = project_manager.get_project_dir(project_id) + + # 从实例池获取助手实例 + agent = await get_agent_from_pool(timeout=30.0) + + # 准备LLM配置,从extra字段中移除project_id + llm_extra = request.extra.copy() if request.extra else {} + llm_extra.pop('project_id', None) # 移除project_id,不传递给LLM + + # 动态设置请求的模型,支持从接口传入api_key和extra参数 + update_agent_llm(agent, request.model, request.api_key, llm_extra) + + # 构建包含项目信息的消息上下文 + messages = [ + # 项目信息系统消息 + { + "role": "user", + "content": f"当前项目ID: {project_id},数据目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}" + }, + # 用户消息批量转换 + *[{"role": msg.role, "content": msg.content} for msg in request.messages] + ] + + # 根据stream参数决定返回流式还是非流式响应 + if request.stream: + return StreamingResponse( + generate_stream_response(agent, messages, request), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) else: - messages.append({"role": "user", "content": request.question}) - - responses = [] - for response in bot.run(messages): - responses.append(response) - - if responses: - final_response = responses[-1][-1] - return QueryResponse(answer=final_response["content"]) - else: - raise HTTPException(status_code=500, detail="No response from agent") + # 非流式响应 + final_responses = agent.run_nonstream(messages) + + if final_responses and len(final_responses) > 0: + # 取最后一个响应 + final_response = final_responses[-1] + + # 如果返回的是Message对象,需要转换为字典 + if hasattr(final_response, 'model_dump'): + final_response = final_response.model_dump() + elif hasattr(final_response, 'dict'): + final_response = final_response.dict() + + content = final_response.get("content", "") + + # 构造OpenAI格式的响应 + return ChatResponse( + choices=[{ + "index": 0, + "message": { + "role": "assistant", + "content": content + }, + "finish_reason": "stop" + }], + usage={ + "prompt_tokens": sum(len(msg.content) for msg in request.messages), + "completion_tokens": len(content), + "total_tokens": sum(len(msg.content) for msg in request.messages) + len(content) + } + ) + else: + raise HTTPException(status_code=500, detail="No response from agent") except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"Error in chat_completions: {str(e)}") + print(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + finally: + # 确保释放助手实例回池 + if agent is not None: + await release_agent_to_pool(agent) + + @app.get("/") @@ -65,6 +250,53 @@ async def root(): return {"message": "Database Assistant API is running"} +@app.get("/system/status") +async def system_status(): + """获取系统状态信息""" + from agent_pool import get_agent_pool + + pool = get_agent_pool() + pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0} + + return { + "status": "running", + "storage_type": "Agent Pool API", + "agent_pool": { + "pool_size": pool_stats["pool_size"], + "available_agents": pool_stats["available_agents"], + "total_agents": pool_stats["total_agents"], + "in_use_agents": pool_stats["in_use_agents"] + } + } + + +@app.on_event("startup") +async def startup_event(): + """应用启动时初始化助手实例池""" + print(f"正在启动FastAPI应用,初始化助手实例池(大小: {agent_pool_size})...") + + try: + def agent_factory(): + return init_agent_service_universal() + + await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory) + print("助手实例池初始化完成!") + except Exception as e: + print(f"助手实例池初始化失败: {e}") + raise + + +@app.on_event("shutdown") +async def shutdown_event(): + """应用关闭时清理实例池""" + print("正在关闭应用,清理助手实例池...") + + from agent_pool import get_agent_pool + pool = get_agent_pool() + if pool: + await pool.shutdown() + print("助手实例池清理完成!") + + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) - diff --git a/gbase_agent.py b/gbase_agent.py index a3d7cdd..f682094 100644 --- a/gbase_agent.py +++ b/gbase_agent.py @@ -16,7 +16,6 @@ import argparse import asyncio -import copy import json import os from typing import Dict, List, Optional, Union @@ -30,118 +29,6 @@ from qwen_agent.utils.output_beautify import typewriter_print ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource") -class GPT4OChat(TextChatAtOAI): - """自定义 GPT-4o 聊天类,修复 tool_call_id 问题""" - - def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]: - # 使用父类方法进行基础转换 - messages = super().convert_messages_to_dicts(messages) - - # 应用修复后的消息转换 - messages = self._fixed_conv_qwen_agent_messages_to_oai(messages) - - return messages - - @staticmethod - def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]): - """修复后的消息转换方法,确保 tool 消息包含 tool_call_id 字段""" - new_messages = [] - i = 0 - - while i < len(messages): - msg = messages[i] - - if msg['role'] == ASSISTANT: - # 处理 assistant 消息 - assistant_msg = {'role': 'assistant'} - - # 设置 content - content = msg.get('content', '') - if isinstance(content, (list, dict)): - assistant_msg['content'] = json.dumps(content, ensure_ascii=False) - elif content is None: - assistant_msg['content'] = '' - else: - assistant_msg['content'] = content - - # 设置 reasoning_content - if msg.get('reasoning_content'): - assistant_msg['reasoning_content'] = msg['reasoning_content'] - - # 检查是否需要构造 tool_calls - has_tool_call = False - tool_calls = [] - - # 情况1:当前消息有 function_call - if msg.get('function_call'): - has_tool_call = True - tool_calls.append({ - 'id': msg.get('extra', {}).get('function_id', '1'), - 'type': 'function', - 'function': { - 'name': msg['function_call']['name'], - 'arguments': msg['function_call']['arguments'] - } - }) - - # 注意:不再为孤立的 tool 消息构造虚假的 tool_call - - if has_tool_call: - assistant_msg['tool_calls'] = tool_calls - new_messages.append(assistant_msg) - - # 检查后续是否有对应的 tool 消息 - if i + 1 < len(messages) and messages[i + 1]['role'] == 'tool': - tool_msg = copy.deepcopy(messages[i + 1]) - # 确保 tool_call_id 匹配 - tool_msg['tool_call_id'] = tool_calls[0]['id'] - # 移除多余字段 - for field in ['id', 'extra', 'function_call']: - if field in tool_msg: - del tool_msg[field] - # 确保 content 有效且为字符串 - content = tool_msg.get('content', '') - if isinstance(content, (list, dict)): - tool_msg['content'] = json.dumps(content, ensure_ascii=False) - elif content is None: - tool_msg['content'] = '' - new_messages.append(tool_msg) - i += 2 - else: - i += 1 - else: - new_messages.append(assistant_msg) - i += 1 - - elif msg['role'] == 'tool': - # 孤立的 tool 消息,转换为 assistant + user 消息序列 - # 首先添加一个包含工具结果的 assistant 消息 - assistant_result = {'role': 'assistant'} - content = msg.get('content', '') - if isinstance(content, (list, dict)): - content = json.dumps(content, ensure_ascii=False) - assistant_result['content'] = f"工具查询结果: {content}" - new_messages.append(assistant_result) - - # 然后添加一个 user 消息来继续对话 - new_messages.append({'role': 'user', 'content': '请继续分析以上结果'}) - i += 1 - - else: - # 处理其他角色消息 - new_msg = copy.deepcopy(msg) - - # 确保 content 有效且为字符串 - content = new_msg.get('content', '') - if isinstance(content, (list, dict)): - new_msg['content'] = json.dumps(content, ensure_ascii=False) - elif content is None: - new_msg['content'] = '' - - new_messages.append(new_msg) - i += 1 - - return new_messages def read_mcp_settings(): @@ -150,107 +37,70 @@ def read_mcp_settings(): return mcp_settings_json +def read_mcp_settings_with_project_restriction(project_data_dir: str): + """读取MCP配置并添加项目目录限制""" + with open("./mcp/mcp_settings.json", "r") as f: + mcp_settings_json = json.load(f) + + # 为json-reader添加项目目录限制 + for server_config in mcp_settings_json: + if "mcpServers" in server_config: + for server_name, server_info in server_config["mcpServers"].items(): + if server_name == "json-reader": + # 添加环境变量来传递项目目录限制 + if "env" not in server_info: + server_info["env"] = {} + server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir + server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default" + break + + return mcp_settings_json + + def read_system_prompt(): with open("./agent_prompt.txt", "r", encoding="utf-8") as f: return f.read().strip() +def read_system_prompt(): + """读取通用的无状态系统prompt""" + with open("./agent_prompt.txt", "r", encoding="utf-8") as f: + return f.read().strip() + + def init_agent_service(): - llm_cfg = { - "llama-33": { - "model": "gbase-llama-33", - "model_server": "http://llmapi:9009/v1", - "api_key": "any", - }, - "gpt-oss-120b": { - "model": "openai/gpt-oss-120b", - "model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base - "api_key": "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", - - "generate_cfg": { - "use_raw_api": True, # GPT-OSS true ,Qwen false - "fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示 - } - }, + """默认初始化函数,保持向后兼容""" + return init_agent_service_universal("qwen3-next") - "claude-3.7": { - "model": "claude-3-7-sonnet-20250219", - "model_server": "https://one.felo.me/v1", - "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2", - "generate_cfg": { - "use_raw_api": True, # GPT-OSS true ,Qwen false - }, - }, - "gpt-4o": { - "model": "gpt-4o", - "model_server": "https://one-dev.felo.me/v1", - "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4", - "generate_cfg": { - "use_raw_api": True, # 启用 raw_api 但使用自定义类修复 tool_call_id 问题 - "fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示 - }, - }, - "Gpt-4o-back": { - "model_type": "oai", # 使用 oai 类型以便使用自定义类 - "model": "gpt-4o", - "model_server": "https://one-dev.felo.me/v1", - "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4", - "generate_cfg": { - "use_raw_api": True, # 启用 raw_api 但使用自定义类修复 tool_call_id 问题 - "fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示 - }, - # 使用自定义的 GPT4OChat 类 - "llm_class": GPT4OChat, - }, - "glm-45": { - "model_server": "https://open.bigmodel.cn/api/paas/v4", - "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm", - "model": "glm-4.5", - "generate_cfg": { - "use_raw_api": True, # GPT-OSS true ,Qwen false - }, - }, - "qwen3-next": { - "model": "qwen/qwen3-next-80b-a3b-instruct", - "model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base - "api_key": "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", +def read_llm_config(): + """读取LLM配置文件""" + with open("./llm_config.json", "r") as f: + return json.load(f) - }, - "deepresearch": { - "model": "alibaba/tongyi-deepresearch-30b-a3b", - "model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base - "api_key": "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", - }, - "qwen3-coder":{ - "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct", - "model_server": "https://api-inference.modelscope.cn/v1", # base_url, also known as api_base - "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556", - }, - "openrouter-gpt4o":{ - "model": "openai/gpt-4o", - "model_server": "https://openrouter.ai/api/v1", # base_url, also known as api_base - "api_key": "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", - "generate_cfg": { - "use_raw_api": True, # GPT-OSS true ,Qwen false - "fncall_prompt_type": "nous", # 使用 nous 风格的函数调用提示 - }, - } - } +def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"): + """支持项目目录的agent初始化函数 - 保持向后兼容""" + llm_cfg = read_llm_config() + + # 读取通用的系统prompt(无状态) system = read_system_prompt() - # 暂时禁用 MCP 工具以测试 GPT-4o - tools = read_mcp_settings() - # 使用自定义的 GPT-4o 配置 - llm_instance = llm_cfg["qwen3-next"] + # 读取MCP工具配置 + tools = read_mcp_settings_with_project_restriction(project_data_dir) + + # 使用指定的模型配置 + if model_name not in llm_cfg: + raise ValueError(f"Model '{model_name}' not found in llm_config.json. Available models: {list(llm_cfg.keys())}") + + llm_instance = llm_cfg[model_name] if "llm_class" in llm_instance: llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance) bot = Assistant( - llm=llm_instance, # 使用自定义的 GPT-4o 实例 - name="数据库助手", - description="数据库查询", + llm=llm_instance, # 使用指定的模型实例 + name=f"数据库助手-{project_id}", + description=f"项目 {project_id} 数据库查询", system_message=system, function_list=tools, ) @@ -258,6 +108,62 @@ def init_agent_service(): return bot +def init_agent_service_universal(): + """创建无状态的通用助手实例(使用默认LLM,可动态切换)""" + llm_cfg = read_llm_config() + + # 读取通用的系统prompt(无状态) + system = read_system_prompt() + + # 读取基础的MCP工具配置(不包含项目限制) + tools = read_mcp_settings() + + # 使用默认模型创建助手实例 + default_model = "qwen3-next" # 默认模型 + if default_model not in llm_cfg: + # 如果默认模型不存在,使用第一个可用模型 + default_model = list(llm_cfg.keys())[0] + + llm_instance = llm_cfg[default_model] + if "llm_class" in llm_instance: + llm_instance = llm_instance.get("llm_class", TextChatAtOAI)(llm_instance) + + bot = Assistant( + llm=llm_instance, # 使用默认LLM初始化 + name="通用数据检索助手", + description="无状态通用数据检索助手", + system_message=system, + function_list=tools, + ) + + return bot + + +def update_agent_llm(agent, model_name: str, api_key: str = None, extra: Dict = None): + """动态更新助手实例的LLM,支持从接口传入参数""" + + # 获取基础配置 + llm_config = { + "model": model_name, + "api_key": api_key, + } + # 如果接口传入了extra参数,则合并到配置中 + if extra is not None: + llm_config.update(extra) + + # 创建LLM实例,确保不是字典 + if "llm_class" in llm_config: + llm_instance = llm_config.get("llm_class", TextChatAtOAI)(llm_config) + else: + # 使用默认的 TextChatAtOAI 类 + llm_instance = TextChatAtOAI(llm_config) + + # 动态设置LLM + agent.llm = llm_instance + + return agent + + def test(query="数据库里有几张表"): # Define the agent bot = init_agent_service() diff --git a/llm_config.json b/llm_config.json new file mode 100644 index 0000000..c5df2e2 --- /dev/null +++ b/llm_config.json @@ -0,0 +1,65 @@ +{ + "llama-33": { + "model": "gbase-llama-33", + "model_server": "http://llmapi:9009/v1", + "api_key": "any" + }, + "gpt-oss-120b": { + "model": "openai/gpt-oss-120b", + "model_server": "https://openrouter.ai/api/v1", + "api_key": "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", + "generate_cfg": { + "use_raw_api": true, + "fncall_prompt_type": "nous" + } + }, + "claude-3.7": { + "model": "claude-3-7-sonnet-20250219", + "model_server": "https://one.felo.me/v1", + "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2", + "generate_cfg": { + "use_raw_api": true + } + }, + "gpt-4o": { + "model": "gpt-4o", + "model_server": "https://one-dev.felo.me/v1", + "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4", + "generate_cfg": { + "use_raw_api": true, + "fncall_prompt_type": "nous" + } + }, + "glm-45": { + "model_server": "https://open.bigmodel.cn/api/paas/v4", + "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm", + "model": "glm-4.5", + "generate_cfg": { + "use_raw_api": true + } + }, + "qwen3-next": { + "model": "qwen/qwen3-next-80b-a3b-instruct", + "model_server": "https://openrouter.ai/api/v1", + "api_key": "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" + }, + "deepresearch": { + "model": "alibaba/tongyi-deepresearch-30b-a3b", + "model_server": "https://openrouter.ai/api/v1", + "api_key": "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" + }, + "qwen3-coder": { + "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct", + "model_server": "https://api-inference.modelscope.cn/v1", + "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556" + }, + "openrouter-gpt4o": { + "model": "openai/gpt-4o", + "model_server": "https://openrouter.ai/api/v1", + "api_key": "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", + "generate_cfg": { + "use_raw_api": true, + "fncall_prompt_type": "nous" + } + } +} \ No newline at end of file diff --git a/mcp/__pycache__/json_reader_server.cpython-312.pyc b/mcp/__pycache__/json_reader_server.cpython-312.pyc new file mode 100644 index 0000000..c13b4ff Binary files /dev/null and b/mcp/__pycache__/json_reader_server.cpython-312.pyc differ diff --git a/mcp/__pycache__/mcp_wrapper.cpython-312.pyc b/mcp/__pycache__/mcp_wrapper.cpython-312.pyc new file mode 100644 index 0000000..081bcd1 Binary files /dev/null and b/mcp/__pycache__/mcp_wrapper.cpython-312.pyc differ diff --git a/mcp/directory_tree_wrapper_server.py b/mcp/directory_tree_wrapper_server.py new file mode 100644 index 0000000..f533988 --- /dev/null +++ b/mcp/directory_tree_wrapper_server.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +""" +目录树 MCP包装器服务器 +提供安全的目录结构查看功能,限制在项目目录内 +""" + +import asyncio +import json +import sys +from mcp_wrapper import handle_wrapped_request + + +async def main(): + """主入口点""" + try: + while True: + # 从stdin读取 + line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline) + if not line: + break + + line = line.strip() + if not line: + continue + + try: + request = json.loads(line) + response = await handle_wrapped_request(request, "directory-tree-wrapper") + + # 写入stdout + sys.stdout.write(json.dumps(response) + "\n") + sys.stdout.flush() + + except json.JSONDecodeError: + error_response = { + "jsonrpc": "2.0", + "error": { + "code": -32700, + "message": "Parse error" + } + } + sys.stdout.write(json.dumps(error_response) + "\n") + sys.stdout.flush() + + except Exception as e: + error_response = { + "jsonrpc": "2.0", + "error": { + "code": -32603, + "message": f"Internal error: {str(e)}" + } + } + sys.stdout.write(json.dumps(error_response) + "\n") + sys.stdout.flush() + + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/mcp/json_reader_server.py b/mcp/json_reader_server.py index 94fda8b..d2d07b3 100644 --- a/mcp/json_reader_server.py +++ b/mcp/json_reader_server.py @@ -12,6 +12,32 @@ import sys import asyncio from typing import Any, Dict, List + +def validate_file_path(file_path: str, allowed_dir: str) -> str: + """验证文件路径是否在允许的目录内""" + # 转换为绝对路径 + if not os.path.isabs(file_path): + file_path = os.path.abspath(file_path) + + allowed_dir = os.path.abspath(allowed_dir) + + # 检查路径是否在允许的目录内 + if not file_path.startswith(allowed_dir): + raise ValueError(f"访问被拒绝: 路径 {file_path} 不在允许的目录 {allowed_dir} 内") + + # 检查路径遍历攻击 + if ".." in file_path: + raise ValueError(f"访问被拒绝: 检测到路径遍历攻击尝试") + + return file_path + + +def get_allowed_directory(): + """获取允许访问的目录""" + # 从环境变量读取项目数据目录 + project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") + return os.path.abspath(project_dir) + async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: """Handle MCP request""" try: @@ -130,9 +156,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: } try: - # Convert relative path to absolute path - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) + # 验证文件路径是否在允许的目录内 + allowed_dir = get_allowed_directory() + file_path = validate_file_path(file_path, allowed_dir) with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) @@ -222,9 +248,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: } try: - # Convert relative path to absolute path - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) + # 验证文件路径是否在允许的目录内 + allowed_dir = get_allowed_directory() + file_path = validate_file_path(file_path, allowed_dir) with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) @@ -307,9 +333,9 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: } try: - # Convert relative path to absolute path - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) + # 验证文件路径是否在允许的目录内 + allowed_dir = get_allowed_directory() + file_path = validate_file_path(file_path, allowed_dir) with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) diff --git a/mcp/mcp_settings.json b/mcp/mcp_settings.json index ab6ed0f..ccc07f6 100644 --- a/mcp/mcp_settings.json +++ b/mcp/mcp_settings.json @@ -8,13 +8,6 @@ "@andredezzy/deep-directory-tree-mcp" ] }, - "mcp-server-code-runner": { - "command": "npx", - "args": [ - "-y", - "mcp-server-code-runner@latest" - ] - }, "ripgrep": { "command": "npx", "args": [ diff --git a/mcp/mcp_wrapper.py b/mcp/mcp_wrapper.py new file mode 100644 index 0000000..879b641 --- /dev/null +++ b/mcp/mcp_wrapper.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +""" +通用MCP工具包装器 +为所有MCP工具提供目录访问控制和安全限制 +""" + +import os +import sys +import json +import asyncio +import subprocess +from typing import Dict, Any, Optional +from pathlib import Path + + +class MCPSecurityWrapper: + """MCP安全包装器""" + + def __init__(self, allowed_directory: str): + self.allowed_directory = os.path.abspath(allowed_directory) + self.project_id = os.getenv("PROJECT_ID", "default") + + def validate_path(self, path: str) -> str: + """验证路径是否在允许的目录内""" + if not os.path.isabs(path): + path = os.path.abspath(path) + + # 规范化路径 + path = os.path.normpath(path) + + # 检查路径遍历 + if ".." in path.split(os.sep): + raise ValueError(f"路径遍历攻击被阻止: {path}") + + # 检查是否在允许的目录内 + if not path.startswith(self.allowed_directory): + raise ValueError(f"访问被拒绝: {path} 不在允许的目录 {self.allowed_directory} 内") + + return path + + def safe_execute_command(self, command: list, cwd: Optional[str] = None) -> Dict[str, Any]: + """安全执行命令,限制工作目录""" + if cwd is None: + cwd = self.allowed_directory + else: + cwd = self.validate_path(cwd) + + # 设置环境变量限制 + env = os.environ.copy() + env["PWD"] = cwd + env["PROJECT_DATA_DIR"] = self.allowed_directory + env["PROJECT_ID"] = self.project_id + + try: + result = subprocess.run( + command, + cwd=cwd, + env=env, + capture_output=True, + text=True, + timeout=30 # 30秒超时 + ) + + return { + "success": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "returncode": result.returncode + } + except subprocess.TimeoutExpired: + return { + "success": False, + "stdout": "", + "stderr": "命令执行超时", + "returncode": -1 + } + except Exception as e: + return { + "success": False, + "stdout": "", + "stderr": str(e), + "returncode": -1 + } + + +async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]: + """处理包装后的MCP请求""" + try: + method = request.get("method") + params = request.get("params", {}) + request_id = request.get("id") + + allowed_dir = get_allowed_directory() + wrapper = MCPSecurityWrapper(allowed_dir) + + if method == "initialize": + return { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": { + "name": f"{tool_name}-wrapper", + "version": "1.0.0" + } + } + } + + elif method == "ping": + return { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "pong": True + } + } + + elif method == "tools/list": + return { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "tools": get_tool_definitions(tool_name) + } + } + + elif method == "tools/call": + return await execute_tool_call(wrapper, tool_name, params, request_id) + + else: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": f"Unknown method: {method}" + } + } + + except Exception as e: + return { + "jsonrpc": "2.0", + "id": request.get("id"), + "error": { + "code": -32603, + "message": f"Internal error: {str(e)}" + } + } + + +def get_allowed_directory(): + """获取允许访问的目录""" + project_dir = os.getenv("PROJECT_DATA_DIR", "./data") + return os.path.abspath(project_dir) + + +def get_tool_definitions(tool_name: str) -> list: + """根据工具名称返回工具定义""" + if tool_name == "ripgrep-wrapper": + return [ + { + "name": "ripgrep_search", + "description": "在项目目录内搜索文本", + "inputSchema": { + "type": "object", + "properties": { + "pattern": {"type": "string", "description": "搜索模式"}, + "path": {"type": "string", "description": "搜索路径(相对于项目目录)"}, + "maxResults": {"type": "integer", "default": 100} + }, + "required": ["pattern"] + } + } + ] + elif tool_name == "directory-tree-wrapper": + return [ + { + "name": "get_directory_tree", + "description": "获取项目目录结构", + "inputSchema": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "目录路径(相对于项目目录)"}, + "max_depth": {"type": "integer", "default": 3} + } + } + } + ] + else: + return [] + + +async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]: + """执行工具调用""" + try: + if tool_name == "ripgrep-wrapper": + return await execute_ripgrep_search(wrapper, params, request_id) + elif tool_name == "directory-tree-wrapper": + return await execute_directory_tree(wrapper, params, request_id) + else: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": f"Unknown tool: {tool_name}" + } + } + except Exception as e: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32603, + "message": str(e) + } + } + + +async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]: + """执行ripgrep搜索""" + pattern = params.get("pattern", "") + path = params.get("path", ".") + max_results = params.get("maxResults", 100) + + # 验证和构建搜索路径 + search_path = os.path.join(wrapper.allowed_directory, path) + search_path = wrapper.validate_path(search_path) + + # 构建ripgrep命令 + command = [ + "rg", + "--json", + "--max-count", str(max_results), + pattern, + search_path + ] + + result = wrapper.safe_execute_command(command) + + if result["success"]: + return { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "content": [ + { + "type": "text", + "text": result["stdout"] + } + ] + } + } + else: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32603, + "message": f"搜索失败: {result['stderr']}" + } + } + + +async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]: + """执行目录树获取""" + path = params.get("path", ".") + max_depth = params.get("max_depth", 3) + + # 验证和构建目录路径 + dir_path = os.path.join(wrapper.allowed_directory, path) + dir_path = wrapper.validate_path(dir_path) + + # 构建目录树命令 + command = [ + "find", + dir_path, + "-type", "d", + "-maxdepth", str(max_depth) + ] + + result = wrapper.safe_execute_command(command) + + if result["success"]: + return { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "content": [ + { + "type": "text", + "text": result["stdout"] + } + ] + } + } + else: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32603, + "message": f"获取目录树失败: {result['stderr']}" + } + } \ No newline at end of file diff --git a/mcp/ripgrep_wrapper_server.py b/mcp/ripgrep_wrapper_server.py new file mode 100644 index 0000000..ce6ac84 --- /dev/null +++ b/mcp/ripgrep_wrapper_server.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +""" +Ripgrep MCP包装器服务器 +提供安全的文本搜索功能,限制在项目目录内 +""" + +import asyncio +import json +import sys +from mcp_wrapper import handle_wrapped_request + + +async def main(): + """主入口点""" + try: + while True: + # 从stdin读取 + line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline) + if not line: + break + + line = line.strip() + if not line: + continue + + try: + request = json.loads(line) + response = await handle_wrapped_request(request, "ripgrep-wrapper") + + # 写入stdout + sys.stdout.write(json.dumps(response) + "\n") + sys.stdout.flush() + + except json.JSONDecodeError: + error_response = { + "jsonrpc": "2.0", + "error": { + "code": -32700, + "message": "Parse error" + } + } + sys.stdout.write(json.dumps(error_response) + "\n") + sys.stdout.flush() + + except Exception as e: + error_response = { + "jsonrpc": "2.0", + "error": { + "code": -32603, + "message": f"Internal error: {str(e)}" + } + } + sys.stdout.write(json.dumps(error_response) + "\n") + sys.stdout.flush() + + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/project_config.py b/project_config.py new file mode 100644 index 0000000..0277cd3 --- /dev/null +++ b/project_config.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +项目配置管理系统 +负责管理项目ID到数据目录的映射,以及项目访问权限控制 +""" + +import json +import os +from typing import Dict, Optional, List +from dataclasses import dataclass, asdict + + +@dataclass +class ProjectConfig: + """项目配置数据类""" + project_id: str + data_dir: str + name: str + description: str = "" + allowed_file_types: List[str] = None + max_file_size_mb: int = 100 + is_active: bool = True + + def __post_init__(self): + if self.allowed_file_types is None: + self.allowed_file_types = [".json", ".txt", ".csv", ".pdf"] + + +class ProjectManager: + """项目管理器""" + + def __init__(self, config_file: str = "./projects/project_registry.json"): + self.config_file = config_file + self.projects: Dict[str, ProjectConfig] = {} + self._ensure_config_dir() + self._load_projects() + + def _ensure_config_dir(self): + """确保配置目录存在""" + config_dir = os.path.dirname(self.config_file) + if not os.path.exists(config_dir): + os.makedirs(config_dir, exist_ok=True) + + def _load_projects(self): + """从配置文件加载项目""" + if os.path.exists(self.config_file): + try: + with open(self.config_file, 'r', encoding='utf-8') as f: + data = json.load(f) + for project_data in data.get('projects', []): + config = ProjectConfig(**project_data) + self.projects[config.project_id] = config + except Exception as e: + print(f"加载项目配置失败: {e}") + self._create_default_config() + else: + self._create_default_config() + + def _create_default_config(self): + """创建默认配置""" + default_project = ProjectConfig( + project_id="default", + data_dir="./data", + name="默认项目", + description="默认数据项目" + ) + self.projects["default"] = default_project + self._save_projects() + + def _save_projects(self): + """保存项目配置到文件""" + data = { + "projects": [asdict(project) for project in self.projects.values()] + } + try: + with open(self.config_file, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"保存项目配置失败: {e}") + + def get_project(self, project_id: str) -> Optional[ProjectConfig]: + """获取项目配置""" + return self.projects.get(project_id) + + def add_project(self, config: ProjectConfig) -> bool: + """添加项目""" + if config.project_id in self.projects: + return False + + # 确保数据目录存在 + if not os.path.isabs(config.data_dir): + config.data_dir = os.path.abspath(config.data_dir) + + os.makedirs(config.data_dir, exist_ok=True) + + self.projects[config.project_id] = config + self._save_projects() + return True + + def update_project(self, project_id: str, **kwargs) -> bool: + """更新项目配置""" + if project_id not in self.projects: + return False + + project = self.projects[project_id] + for key, value in kwargs.items(): + if hasattr(project, key): + setattr(project, key, value) + + self._save_projects() + return True + + def delete_project(self, project_id: str) -> bool: + """删除项目""" + if project_id not in self.projects: + return False + + del self.projects[project_id] + self._save_projects() + return True + + def list_projects(self) -> List[ProjectConfig]: + """列出所有项目""" + return list(self.projects.values()) + + def get_project_dir(self, project_id: str) -> str: + """获取项目数据目录""" + project = self.get_project(project_id) + if project: + return project.data_dir + + # 如果项目不存在,创建默认目录结构 + default_dir = f"./projects/{project_id}/data" + os.makedirs(default_dir, exist_ok=True) + + # 自动创建新项目配置 + new_project = ProjectConfig( + project_id=project_id, + data_dir=default_dir, + name=f"项目 {project_id}", + description=f"自动创建的项目 {project_id}" + ) + self.add_project(new_project) + + return default_dir + + def validate_project_access(self, project_id: str) -> bool: + """验证项目访问权限""" + project = self.get_project(project_id) + return project and project.is_active + + +# 全局项目管理器实例 +project_manager = ProjectManager() \ No newline at end of file diff --git a/data/all_hp_product_spec_book2506/document.txt b/projects/demo-project/all_hp_product_spec_book2506/document.txt similarity index 100% rename from data/all_hp_product_spec_book2506/document.txt rename to projects/demo-project/all_hp_product_spec_book2506/document.txt diff --git a/data/all_hp_product_spec_book2506/schema.json b/projects/demo-project/all_hp_product_spec_book2506/schema.json similarity index 100% rename from data/all_hp_product_spec_book2506/schema.json rename to projects/demo-project/all_hp_product_spec_book2506/schema.json diff --git a/data/all_hp_product_spec_book2506/serialization.txt b/projects/demo-project/all_hp_product_spec_book2506/serialization.txt similarity index 100% rename from data/all_hp_product_spec_book2506/serialization.txt rename to projects/demo-project/all_hp_product_spec_book2506/serialization.txt diff --git a/projects/project_registry.json b/projects/project_registry.json new file mode 100644 index 0000000..6d8d20d --- /dev/null +++ b/projects/project_registry.json @@ -0,0 +1,18 @@ +{ + "projects": [ + { + "project_id": "demo-project", + "data_dir": "/Users/moshui/Documents/felo/qwen-agent/projects/demo-project/", + "name": "演示项目", + "description": "演示多项目隔离功能", + "allowed_file_types": [ + ".json", + ".txt", + ".csv", + ".pdf" + ], + "max_file_size_mb": 100, + "is_active": true + } + ] +} \ No newline at end of file