朱潮 2025-12-15 21:58:54 +08:00
parent 9ada70eb58
commit 77c8f5e501
10 changed files with 75 additions and 38 deletions

View File

@@ -13,10 +13,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
+from utils.settings import SUMMARIZATION_MAX_TOKENS
 
-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
-MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that uses the project's logger to record logs"""

View File

@@ -8,6 +8,7 @@ import asyncio
 import hashlib
 import json
 import logging
+from utils.settings import FASTAPI_URL
 
 # Configure logger
 logger = logging.getLogger('app')
@@ -19,8 +20,7 @@ def encode_texts_via_api(texts, batch_size=32):
     try:
         # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+        api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
 
         # Call the encoding endpoint
         request_data = {
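For context, the call this hunk rewires is an HTTP POST to the embedding service. A minimal sketch, where the request/response field names (texts, batch_size, embeddings) are assumptions the diff does not confirm:

import requests
from utils.settings import FASTAPI_URL

def encode_texts_sketch(texts, batch_size=32):
    # POST the batch to the embedding service; field names are assumed.
    resp = requests.post(
        f"{FASTAPI_URL}/api/v1/embedding/encode",
        json={"texts": texts, "batch_size": batch_size},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["embeddings"]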

View File

@@ -13,6 +13,7 @@ import logging
 from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from collections import OrderedDict
+from utils.settings import SENTENCE_TRANSFORMER_MODEL
 import threading
 import psutil
 import numpy as np
@@ -126,6 +127,5 @@ def get_model_manager() -> GlobalModelManager:
     """Get the model manager instance"""
     global _model_manager
     if _model_manager is None:
-        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
-        _model_manager = GlobalModelManager(model_name)
+        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
     return _model_manager
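get_model_manager is a lazy singleton. Since this module already imports threading, a lock-guarded variant would look like the sketch below; this is an illustration only, not the project's code, and GlobalModelManager is assumed from the surrounding module:

import threading

_model_manager = None
_manager_lock = threading.Lock()

def get_model_manager_locked():
    # Double-checked locking: skip the lock once the manager exists, and
    # ensure only one thread constructs it on first use.
    global _model_manager
    if _model_manager is None:
        with _manager_lock:
            if _model_manager is None:
                _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager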

View File

@@ -11,7 +11,6 @@ import asyncio
 from typing import Any, Dict, List, Optional, Union
 import re
 
 def get_allowed_directory():
     """Get the directory that access is allowed to"""
     # Prefer the dataset_dir passed in as a command-line argument
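The comment describes a precedence order for the allowed directory. A hypothetical sketch of that fallback using PROJECT_DATA_DIR from the new settings module (the real function body is cut off in this hunk, so this is an assumption):

from utils.settings import PROJECT_DATA_DIR

def get_allowed_directory_sketch(cli_dataset_dir=None):
    # Prefer a dataset_dir supplied on the command line; otherwise fall
    # back to the PROJECT_DATA_DIR default from utils.settings.
    return cli_dataset_dir or PROJECT_DATA_DIR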

View File

@@ -26,8 +26,8 @@ from mcp_common import (
     load_tools_from_json,
     handle_mcp_streaming
 )
 
-backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")
 
 def rag_retrieve(query: str) -> Dict[str, Any]:
     """Call the RAG retrieval API"""
@@ -36,7 +36,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
     if len(sys.argv) > 1:
         bot_id = sys.argv[1]
-    url = f"{backend_host}/v1/rag_retrieve/{bot_id}"
+    url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
     if not url:
         return {
             "content": [
@@ -48,7 +48,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
     }
 
     # Get the masterkey and generate the auth token
-    masterkey = os.getenv("MASTERKEY", "master")
+    masterkey = MASTERKEY
     token_input = f"{masterkey}:{bot_id}"
     auth_token = hashlib.md5(token_input.encode()).hexdigest()
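The auth token here is the hex MD5 digest of "masterkey:bot_id". A self-contained sketch of the derivation, plus the symmetric check a server could perform (the verification side is an assumption; the diff only shows generation):

import hashlib

def make_auth_token(masterkey: str, bot_id: str) -> str:
    # Same derivation as rag_retrieve above: hex MD5 of "masterkey:bot_id".
    return hashlib.md5(f"{masterkey}:{bot_id}".encode()).hexdigest()

def verify_auth_token(masterkey: str, bot_id: str, token: str) -> bool:
    # Hypothetical server-side check: recompute and compare.
    return make_auth_token(masterkey, bot_id) == token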

View File

@@ -25,6 +25,7 @@ from mcp_common import (
     create_tools_list_response,
     handle_mcp_streaming
 )
+from utils.settings import FASTAPI_URL
 
 import requests
@@ -32,7 +33,7 @@ import requests
 def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
     """Encode a single query via the API"""
     if not fastapi_url:
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+        fastapi_url = FASTAPI_URL
     api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

View File

@@ -20,13 +20,14 @@ from utils.fastapi_utils import (
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
+from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
 
 router = APIRouter()
 
 # Initialize the global agent manager
 agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
-    shard_count=int(os.getenv("SHARD_COUNT", "16"))
+    max_cached_agents=MAX_CACHED_AGENTS,
+    shard_count=SHARD_COUNT
 )
@@ -148,7 +149,7 @@ async def enhanced_generate_stream_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config):
+    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
 
         new_content = ""
         if isinstance(msg, AIMessageChunk):
@@ -328,7 +329,7 @@ async def create_agent_and_generate_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config)
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
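For context, stream_mode="messages" yields (message_chunk, metadata) pairs. A minimal consumption sketch in the style of the handler above; agent, messages, and config are assumed to be set up as in the diff:

from langchain_core.messages import AIMessageChunk

async def collect_stream_text(agent, messages, config):
    # Accumulate assistant text from streamed AIMessageChunk pieces.
    parts = []
    async for msg, metadata in agent.astream(
        {"messages": messages}, stream_mode="messages", config=config
    ):
        if isinstance(msg, AIMessageChunk) and isinstance(msg.content, str):
            parts.append(msg.content)
    return "".join(parts)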

View File

@@ -20,6 +20,11 @@ except ImportError:
     from embedding import get_model_manager
 from pydantic import BaseModel
 import logging
+from utils.settings import (
+    MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
+    KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
+    TOKENIZERS_PARALLELISM
+)
 
 logger = logging.getLogger('app')
@@ -45,8 +50,8 @@ logger.info("Initializing system optimizations...")
 system_optimizer = setup_system_optimizations()
 
 # Global agent manager configuration (using the optimized values)
-max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # increased cache size
-shard_count = int(os.getenv("SHARD_COUNT", "16"))  # number of shards
+max_cached_agents = MAX_CACHED_AGENTS  # increased cache size
+shard_count = SHARD_COUNT  # number of shards
 
 # Initialize the optimized global agent manager
 agent_manager = init_global_sharded_agent_manager(
@@ -56,17 +61,17 @@ agent_manager = init_global_sharded_agent_manager(
 
 # Initialize the connection pool
 connection_pool = init_global_connection_pool(
-    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
-    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
-    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
-    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
-    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
+    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
+    max_connections_total=MAX_CONNECTIONS_TOTAL,
+    keepalive_timeout=KEEPALIVE_TIMEOUT,
+    connect_timeout=CONNECT_TIMEOUT,
+    total_timeout=TOTAL_TIMEOUT
 )
 
 # Initialize the file cache
 file_cache = init_global_file_cache(
-    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
-    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
+    cache_size=FILE_CACHE_SIZE,
+    ttl=FILE_CACHE_TTL
 )
 
 logger.info("System optimization setup complete")
@@ -191,11 +196,11 @@ async def get_system_config():
         "config": {
             "max_cached_agents": max_cached_agents,
             "shard_count": shard_count,
-            "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
-            "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
-            "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
-            "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
-            "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
+            "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
+            "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
+            "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
+            "file_cache_size": str(FILE_CACHE_SIZE),
+            "file_cache_ttl": str(FILE_CACHE_TTL)
         }
     }

View File

@@ -10,6 +10,7 @@ from fastapi import HTTPException
 import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
+from utils.settings import MASTERKEY, BACKEND_HOST
 
 USER = "user"
 ASSISTANT = "assistant"
@@ -389,16 +390,14 @@ def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
 def generate_v2_auth_token(bot_id: str) -> str:
     """Generate the auth token for the v2 API"""
-    masterkey = os.getenv("MASTERKEY", "master")
-    token_input = f"{masterkey}:{bot_id}"
+    token_input = f"{MASTERKEY}:{bot_id}"
     return hashlib.md5(token_input.encode()).hexdigest()
 
 
 async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
     """Fetch the bot configuration from the backend API"""
     try:
-        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
-        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
+        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
         auth_token = generate_v2_auth_token(bot_id)
 
         headers = {
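The hunk cuts off at the request headers. A hypothetical end-to-end sketch of fetch_bot_config, where the HTTP client (httpx) and the Authorization header format are assumptions the diff does not confirm:

import httpx
from utils.settings import BACKEND_HOST

async def fetch_bot_config_sketch(bot_id: str) -> dict:
    # generate_v2_auth_token is the helper defined in the hunk above.
    url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
    headers = {"Authorization": f"Bearer {generate_v2_auth_token(bot_id)}"}
    async with httpx.AsyncClient() as client:
        resp = await client.get(url, headers=headers)
        resp.raise_for_status()
        return resp.json()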

utils/settings.py (new file, +35 lines)
View File

@@ -0,0 +1,35 @@
+import os
+
+# LLM Token Settings
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
+SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+
+# Agent and Shard Settings
+MAX_CACHED_AGENTS = int(os.getenv("MAX_CACHED_AGENTS", 50))
+SHARD_COUNT = int(os.getenv("SHARD_COUNT", 16))
+
+# Connection Settings
+MAX_CONNECTIONS_PER_HOST = int(os.getenv("MAX_CONNECTIONS_PER_HOST", 100))
+MAX_CONNECTIONS_TOTAL = int(os.getenv("MAX_CONNECTIONS_TOTAL", 500))
+KEEPALIVE_TIMEOUT = int(os.getenv("KEEPALIVE_TIMEOUT", 30))
+CONNECT_TIMEOUT = int(os.getenv("CONNECT_TIMEOUT", 10))
+TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", 60))
+
+# File Cache Settings
+FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000))
+FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300))
+
+# API Settings
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")
+FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+
+# Project Settings
+PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data")
+
+# Tokenizer Settings
+TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
+
+# Embedding Model Settings
+SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
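One property of this layout worth noting: every value is evaluated once, at import time, so environment overrides must be in place before utils.settings is first imported anywhere in the process. A small sketch of that behavior with a hypothetical override:

import os

# Must run before the first `import utils.settings` in the process:
os.environ["MAX_OUTPUT_TOKENS"] = "4096"

from utils.settings import MAX_OUTPUT_TOKENS, SUMMARIZATION_MAX_TOKENS

print(MAX_OUTPUT_TOKENS)         # 4096
print(SUMMARIZATION_MAX_TOKENS)  # 65536 - 4096 - 1000 = 60440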