From 77c8f5e5013f0c0d7857904ab2d59a705f646492 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?=
Date: Mon, 15 Dec 2025 21:58:54 +0800
Subject: [PATCH] settings: centralize environment configuration in
 utils/settings.py

---
 agent/deep_assistant.py       |  7 ++-----
 embedding/embedding.py        |  4 ++--
 embedding/manager.py          |  4 ++--
 mcp/mcp_common.py             |  1 -
 mcp/rag_retrieve_server.py    |  8 ++++----
 mcp/semantic_search_server.py |  3 ++-
 routes/chat.py                |  9 +++++----
 routes/system.py              | 35 ++++++++++++++++++++---------------
 utils/fastapi_utils.py        |  7 +++----
 utils/settings.py             | 35 +++++++++++++++++++++++++++++++++++
 10 files changed, 75 insertions(+), 38 deletions(-)
 create mode 100644 utils/settings.py

diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index 32b0f6f..f49a63f 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -13,10 +13,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 
-
-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
-MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+from utils.settings import SUMMARIZATION_MAX_TOKENS
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that logs through the project's logger"""
@@ -182,7 +179,7 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
             summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
         )
         middleware.append(summarization_middleware)
-    
+
     agent = create_agent(
         model=llm_instance,
         system_prompt=system,
diff --git a/embedding/embedding.py b/embedding/embedding.py
index 1cc61b6..9c66385 100644
--- a/embedding/embedding.py
+++ b/embedding/embedding.py
@@ -8,6 +8,7 @@ import asyncio
 import hashlib
 import json
 import logging
+from utils.settings import FASTAPI_URL
 
 # Configure logger
 logger = logging.getLogger('app')
@@ -19,8 +20,7 @@ def encode_texts_via_api(texts, batch_size=32):
     try:
         # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+        api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
 
         # Call the encoding endpoint
         request_data = {
diff --git a/embedding/manager.py b/embedding/manager.py
index 9fb506d..807a57b 100644
--- a/embedding/manager.py
+++ b/embedding/manager.py
@@ -13,6 +13,7 @@ import logging
 from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from collections import OrderedDict
+from utils.settings import SENTENCE_TRANSFORMER_MODEL
 import threading
 import psutil
 import numpy as np
@@ -126,6 +127,5 @@ def get_model_manager() -> GlobalModelManager:
     """Get the model manager instance"""
     global _model_manager
     if _model_manager is None:
-        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
-        _model_manager = GlobalModelManager(model_name)
+        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
     return _model_manager
diff --git a/mcp/mcp_common.py b/mcp/mcp_common.py
index f016fe6..a08f4ef 100644
--- a/mcp/mcp_common.py
+++ b/mcp/mcp_common.py
@@ -11,7 +11,6 @@ import asyncio
 from typing import Any, Dict, List, Optional, Union
 import re
 
-
 def get_allowed_directory():
     """Get the directory that is allowed to be accessed"""
     # Prefer the dataset_dir passed in via command-line arguments
diff --git a/mcp/rag_retrieve_server.py b/mcp/rag_retrieve_server.py
index ae832e0..18c8187 100644
--- a/mcp/rag_retrieve_server.py
+++ b/mcp/rag_retrieve_server.py
@@ -26,8 +26,8 @@ from mcp_common import (
     load_tools_from_json,
     handle_mcp_streaming
 )
-
-backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
"https://api-dev.gptbase.ai") +BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") +MASTERKEY = os.getenv("MASTERKEY", "master") def rag_retrieve(query: str) -> Dict[str, Any]: """调用RAG检索API""" @@ -36,7 +36,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]: if len(sys.argv) > 1: bot_id = sys.argv[1] - url = f"{backend_host}/v1/rag_retrieve/{bot_id}" + url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}" if not url: return { "content": [ @@ -48,7 +48,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]: } # 获取masterkey并生成认证token - masterkey = os.getenv("MASTERKEY", "master") + masterkey = MASTERKEY token_input = f"{masterkey}:{bot_id}" auth_token = hashlib.md5(token_input.encode()).hexdigest() diff --git a/mcp/semantic_search_server.py b/mcp/semantic_search_server.py index 242fb70..655c010 100644 --- a/mcp/semantic_search_server.py +++ b/mcp/semantic_search_server.py @@ -25,6 +25,7 @@ from mcp_common import ( create_tools_list_response, handle_mcp_streaming ) +from utils.settings import FASTAPI_URL import requests @@ -32,7 +33,7 @@ import requests def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray: """通过API编码单个查询""" if not fastapi_url: - fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') + fastapi_url = FASTAPI_URL api_endpoint = f"{fastapi_url}/api/v1/embedding/encode" diff --git a/routes/chat.py b/routes/chat.py index 52e3509..d695d39 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -20,13 +20,14 @@ from utils.fastapi_utils import ( create_stream_chunk ) from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage +from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT router = APIRouter() # 初始化全局助手管理器 agent_manager = init_global_sharded_agent_manager( - max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")), - shard_count=int(os.getenv("SHARD_COUNT", "16")) + max_cached_agents=MAX_CACHED_AGENTS, + shard_count=SHARD_COUNT ) @@ -148,7 +149,7 @@ async def enhanced_generate_stream_response( config["configurable"] = {"thread_id": session_id} if hasattr(agent, 'logging_handler'): config["callbacks"] = [agent.logging_handler] - async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config): + async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS): new_content = "" if isinstance(msg, AIMessageChunk): @@ -328,7 +329,7 @@ async def create_agent_and_generate_response( config["configurable"] = {"thread_id": session_id} if hasattr(agent, 'logging_handler'): config["callbacks"] = [agent.logging_handler] - agent_responses = await agent.ainvoke({"messages": final_messages}, config=config) + agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS) append_messages = agent_responses["messages"][len(final_messages):] response_text = "" for msg in append_messages: diff --git a/routes/system.py b/routes/system.py index 70e927c..f6b44a8 100644 --- a/routes/system.py +++ b/routes/system.py @@ -20,6 +20,11 @@ except ImportError: from embedding import get_model_manager from pydantic import BaseModel import logging +from utils.settings import ( + MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL, + KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL, + TOKENIZERS_PARALLELISM +) logger = logging.getLogger('app') @@ -45,28 +50,28 @@ logger.info("正在初始化系统优化...") 
 system_optimizer = setup_system_optimizations()
 
 # Global agent manager configuration (uses the optimized settings)
-max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # increased cache size
-shard_count = int(os.getenv("SHARD_COUNT", "16"))  # number of shards
+max_cached_agents = MAX_CACHED_AGENTS  # increased cache size
+shard_count = SHARD_COUNT  # number of shards
 
 # Initialize the optimized global agent manager
 agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=max_cached_agents, 
+    max_cached_agents=max_cached_agents,
     shard_count=shard_count
 )
 
 # Initialize the connection pool
 connection_pool = init_global_connection_pool(
-    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
-    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
-    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
-    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
-    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
+    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
+    max_connections_total=MAX_CONNECTIONS_TOTAL,
+    keepalive_timeout=KEEPALIVE_TIMEOUT,
+    connect_timeout=CONNECT_TIMEOUT,
+    total_timeout=TOTAL_TIMEOUT
 )
 
 # Initialize the file cache
 file_cache = init_global_file_cache(
-    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
-    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
+    cache_size=FILE_CACHE_SIZE,
+    ttl=FILE_CACHE_TTL
 )
 
 logger.info("系统优化初始化完成")
@@ -191,11 +196,11 @@ async def get_system_config():
         "config": {
             "max_cached_agents": max_cached_agents,
             "shard_count": shard_count,
-            "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
-            "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
-            "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
-            "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
-            "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
+            "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
+            "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
+            "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
+            "file_cache_size": str(FILE_CACHE_SIZE),
+            "file_cache_ttl": str(FILE_CACHE_TTL)
         }
     }
diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py
index cf78233..29ea6d1 100644
--- a/utils/fastapi_utils.py
+++ b/utils/fastapi_utils.py
@@ -10,6 +10,7 @@ from fastapi import HTTPException
 import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
+from utils.settings import MASTERKEY, BACKEND_HOST
 
 USER = "user"
 ASSISTANT = "assistant"
@@ -389,16 +390,14 @@ def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
 
 def generate_v2_auth_token(bot_id: str) -> str:
     """Generate the auth token for the v2 API"""
-    masterkey = os.getenv("MASTERKEY", "master")
-    token_input = f"{masterkey}:{bot_id}"
+    token_input = f"{MASTERKEY}:{bot_id}"
     return hashlib.md5(token_input.encode()).hexdigest()
 
 
 async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
     """Fetch the bot configuration from the backend API"""
     try:
-        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
-        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
+        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
         auth_token = generate_v2_auth_token(bot_id)
 
         headers = {
diff --git a/utils/settings.py b/utils/settings.py
new file mode 100644
index 0000000..0714fe0
--- /dev/null
+++ b/utils/settings.py
@@ -0,0 +1,35 @@
+import os
+
+# LLM Token Settings
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
+SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+
+# Agent and Shard Settings
+MAX_CACHED_AGENTS = int(os.getenv("MAX_CACHED_AGENTS", 50))
int(os.getenv("MAX_CACHED_AGENTS", 50)) +SHARD_COUNT = int(os.getenv("SHARD_COUNT", 16)) + +# Connection Settings +MAX_CONNECTIONS_PER_HOST = int(os.getenv("MAX_CONNECTIONS_PER_HOST", 100)) +MAX_CONNECTIONS_TOTAL = int(os.getenv("MAX_CONNECTIONS_TOTAL", 500)) +KEEPALIVE_TIMEOUT = int(os.getenv("KEEPALIVE_TIMEOUT", 30)) +CONNECT_TIMEOUT = int(os.getenv("CONNECT_TIMEOUT", 10)) +TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", 60)) + +# File Cache Settings +FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000)) +FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300)) + +# API Settings +BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") +MASTERKEY = os.getenv("MASTERKEY", "master") +FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') + +# Project Settings +PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data") + +# Tokenizer Settings +TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true") + +# Embedding Model Settings +SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny") \ No newline at end of file