朱潮 2025-12-15 21:58:54 +08:00
parent 9ada70eb58
commit 77c8f5e501
10 changed files with 75 additions and 38 deletions

View File

@@ -13,10 +13,7 @@ from langgraph.checkpoint.memory import MemorySaver
from utils.fastapi_utils import detect_provider
from .guideline_middleware import GuidelineMiddleware
MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
from utils.settings import SUMMARIZATION_MAX_TOKENS
class LoggingCallbackHandler(BaseCallbackHandler):
"""自定义的 CallbackHandler使用项目的 logger 来记录日志"""
@ -182,7 +179,7 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
)
middleware.append(summarization_middleware)
agent = create_agent(
model=llm_instance,
system_prompt=system,

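For context, the three constants this hunk replaces with an import encode a simple context-window budget: the summarized history has to fit in whatever remains of the context window after reserving room for the model's output, minus a fixed safety margin. A minimal sketch of the arithmetic, using the defaults now kept in utils/settings.py:

import os

# Mirrors utils/settings.py; values are read once, at import time.
MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))  # model context window
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))     # reserved for generation
SAFETY_MARGIN = 1000                                              # headroom for prompt overhead

# Whatever is left may be spent on the (summarized) conversation history.
SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - SAFETY_MARGIN
assert SUMMARIZATION_MAX_TOKENS > 0, "output reservation must fit inside the context window"
# With all defaults: 65536 - 8000 - 1000 = 56536 tokens.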
View File

@@ -8,6 +8,7 @@ import asyncio
import hashlib
import json
import logging
from utils.settings import FASTAPI_URL
# Configure logger
logger = logging.getLogger('app')
@@ -19,8 +20,7 @@ def encode_texts_via_api(texts, batch_size=32):
try:
# FastAPI service address
fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
# Call the encoding endpoint
request_data = {

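The hunk cuts off before request_data is populated, but the shape of the call against this endpoint follows from the surrounding code. A sketch of the request, where the payload field names ("texts", "batch_size") are assumptions rather than a documented schema:

import requests
from utils.settings import FASTAPI_URL

def encode_texts_sketch(texts, batch_size=32):
    # Endpoint path comes from the diff above; the JSON fields are hypothetical.
    api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
    response = requests.post(
        api_endpoint,
        json={"texts": texts, "batch_size": batch_size},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()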
View File

@@ -13,6 +13,7 @@ import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
from utils.settings import SENTENCE_TRANSFORMER_MODEL
import threading
import psutil
import numpy as np
@@ -126,6 +127,5 @@ def get_model_manager() -> GlobalModelManager:
"""Get the model manager instance."""
global _model_manager
if _model_manager is None:
model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
_model_manager = GlobalModelManager(model_name)
_model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
return _model_manager

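get_model_manager is a lazy singleton accessor, and as written its check-then-create is not guarded against two threads racing through the first call. Since this module already imports threading, a lock-guarded variant would look like the sketch below (same constructor, double-checked locking):

import threading

_model_manager = None
_model_manager_lock = threading.Lock()

def get_model_manager() -> "GlobalModelManager":
    """Get the model manager instance, creating it at most once across threads."""
    global _model_manager
    if _model_manager is None:                # fast path, lock-free
        with _model_manager_lock:
            if _model_manager is None:        # re-check under the lock
                _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager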
View File

@@ -11,7 +11,6 @@ import asyncio
from typing import Any, Dict, List, Optional, Union
import re
def get_allowed_directory():
"""获取允许访问的目录"""
# 优先使用命令行参数传入的dataset_dir

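The hunk truncates before the fallback logic, but the docstring and comment describe the resolution order. A sketch of that order, under the assumption that the CLI value lands in a module-level dataset_dir and that PROJECT_DATA_DIR from the new utils/settings.py is the default (both are assumptions; the diff cuts off here):

import os
from utils.settings import PROJECT_DATA_DIR

dataset_dir = None  # may be set from a command-line argument at startup

def get_allowed_directory() -> str:
    """Get the directory that access is restricted to."""
    # 1) command-line override, 2) PROJECT_DATA_DIR default (assumed)
    return os.path.abspath(dataset_dir or PROJECT_DATA_DIR)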
View File

@@ -26,8 +26,8 @@ from mcp_common import (
load_tools_from_json,
handle_mcp_streaming
)
backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
MASTERKEY = os.getenv("MASTERKEY", "master")
def rag_retrieve(query: str) -> Dict[str, Any]:
"""调用RAG检索API"""
@ -36,7 +36,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
if len(sys.argv) > 1:
bot_id = sys.argv[1]
url = f"{backend_host}/v1/rag_retrieve/{bot_id}"
url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
if not url:
return {
"content": [
@@ -48,7 +48,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
}
# Get the masterkey and generate the auth token
masterkey = os.getenv("MASTERKEY", "master")
masterkey = MASTERKEY
token_input = f"{masterkey}:{bot_id}"
auth_token = hashlib.md5(token_input.encode()).hexdigest()

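The auth scheme here (and again in generate_v2_auth_token further down) is a deterministic MD5 digest over "masterkey:bot_id". A minimal reproduction of the lines above:

import hashlib
from utils.settings import MASTERKEY

def v2_auth_token(bot_id: str) -> str:
    # token = hex(md5("masterkey" + ":" + bot_id))
    return hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()

# With the default masterkey, v2_auth_token("demo-bot") equals
# hashlib.md5(b"master:demo-bot").hexdigest().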
View File

@@ -25,6 +25,7 @@ from mcp_common import (
create_tools_list_response,
handle_mcp_streaming
)
from utils.settings import FASTAPI_URL
import requests
@@ -32,7 +33,7 @@ import requests
def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
"""Encode a single query via the API."""
if not fastapi_url:
fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
fastapi_url = FASTAPI_URL
api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

View File

@@ -20,13 +20,14 @@ from utils.fastapi_utils import (
create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
router = APIRouter()
# Initialize the global agent manager
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
shard_count=int(os.getenv("SHARD_COUNT", "16"))
max_cached_agents=MAX_CACHED_AGENTS,
shard_count=SHARD_COUNT
)
@@ -148,7 +149,7 @@ async def enhanced_generate_stream_response(
config["configurable"] = {"thread_id": session_id}
if hasattr(agent, 'logging_handler'):
config["callbacks"] = [agent.logging_handler]
async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config):
async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
new_content = ""
if isinstance(msg, AIMessageChunk):
@@ -328,7 +329,7 @@ async def create_agent_and_generate_response(
config["configurable"] = {"thread_id": session_id}
if hasattr(agent, 'logging_handler'):
config["callbacks"] = [agent.logging_handler]
agent_responses = await agent.ainvoke({"messages": final_messages}, config=config)
agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
append_messages = agent_responses["messages"][len(final_messages):]
response_text = ""
for msg in append_messages:

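init_global_sharded_agent_manager itself is not in this diff, but the MAX_CACHED_AGENTS/SHARD_COUNT pair suggests the usual design: SHARD_COUNT independently locked LRU caches, with each key hashed to a shard so concurrent requests rarely contend on the same lock. A sketch under those assumptions (names and structure are hypothetical):

import threading
from collections import OrderedDict
from utils.settings import MAX_CACHED_AGENTS, SHARD_COUNT

class ShardedAgentCache:
    def __init__(self, max_cached_agents=MAX_CACHED_AGENTS, shard_count=SHARD_COUNT):
        self.per_shard = max(1, max_cached_agents // shard_count)
        self.shards = [OrderedDict() for _ in range(shard_count)]
        self.locks = [threading.Lock() for _ in range(shard_count)]

    def get_or_create(self, key, factory):
        i = hash(key) % len(self.shards)      # stable within one process
        with self.locks[i]:                   # only this shard blocks
            shard = self.shards[i]
            if key in shard:
                shard.move_to_end(key)        # LRU touch
                return shard[key]
            agent = shard[key] = factory()
            if len(shard) > self.per_shard:   # evict least-recently used
                shard.popitem(last=False)
            return agent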
View File

@@ -20,6 +20,11 @@ except ImportError:
from embedding import get_model_manager
from pydantic import BaseModel
import logging
from utils.settings import (
MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
TOKENIZERS_PARALLELISM
)
logger = logging.getLogger('app')
@@ -45,28 +50,28 @@ logger.info("Initializing system optimizations...")
system_optimizer = setup_system_optimizations()
# Global agent manager configuration (using the optimized settings)
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50")) # larger cache size
shard_count = int(os.getenv("SHARD_COUNT", "16")) # number of shards
max_cached_agents = MAX_CACHED_AGENTS # larger cache size
shard_count = SHARD_COUNT # number of shards
# Initialize the optimized global agent manager
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=max_cached_agents,
max_cached_agents=max_cached_agents,
shard_count=shard_count
)
# Initialize the connection pool
connection_pool = init_global_connection_pool(
max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
max_connections_total=MAX_CONNECTIONS_TOTAL,
keepalive_timeout=KEEPALIVE_TIMEOUT,
connect_timeout=CONNECT_TIMEOUT,
total_timeout=TOTAL_TIMEOUT
)
# Initialize the file cache
file_cache = init_global_file_cache(
cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
cache_size=FILE_CACHE_SIZE,
ttl=FILE_CACHE_TTL
)
logger.info("系统优化初始化完成")
@ -191,11 +196,11 @@ async def get_system_config():
"config": {
"max_cached_agents": max_cached_agents,
"shard_count": shard_count,
"tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
"max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
"max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
"file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
"file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
"tokenizer_parallelism": TOKENIZERS_PARALLELISM,
"max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
"max_connections_total": str(MAX_CONNECTIONS_TOTAL),
"file_cache_size": str(FILE_CACHE_SIZE),
"file_cache_ttl": str(FILE_CACHE_TTL)
}
}

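init_global_connection_pool is project code, but its five parameters correspond one-to-one to the knobs on aiohttp's connector and timeout objects. If the pool is backed by aiohttp (an assumption, though a common choice for async FastAPI services), the mapping is roughly:

import aiohttp
from utils.settings import (
    MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
    KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT,
)

def make_session() -> aiohttp.ClientSession:
    # limit / limit_per_host cap concurrent sockets; keepalive_timeout is how
    # long an idle connection is kept open for reuse.
    connector = aiohttp.TCPConnector(
        limit=MAX_CONNECTIONS_TOTAL,
        limit_per_host=MAX_CONNECTIONS_PER_HOST,
        keepalive_timeout=KEEPALIVE_TIMEOUT,
    )
    timeout = aiohttp.ClientTimeout(total=TOTAL_TIMEOUT, connect=CONNECT_TIMEOUT)
    return aiohttp.ClientSession(connector=connector, timeout=timeout)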
View File

@@ -10,6 +10,7 @@ from fastapi import HTTPException
import logging
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain.chat_models import init_chat_model
from utils.settings import MASTERKEY, BACKEND_HOST
USER = "user"
ASSISTANT = "assistant"
@@ -389,16 +390,14 @@ def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
def generate_v2_auth_token(bot_id: str) -> str:
"""Generate the auth token for the v2 API."""
masterkey = os.getenv("MASTERKEY", "master")
token_input = f"{masterkey}:{bot_id}"
token_input = f"{MASTERKEY}:{bot_id}"
return hashlib.md5(token_input.encode()).hexdigest()
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
"""获取机器人配置从后端API"""
try:
backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
auth_token = generate_v2_auth_token(bot_id)
headers = {

utils/settings.py (new file, 35 additions)
View File

@@ -0,0 +1,35 @@
import os
# LLM Token Settings
MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
# Agent and Shard Settings
MAX_CACHED_AGENTS = int(os.getenv("MAX_CACHED_AGENTS", 50))
SHARD_COUNT = int(os.getenv("SHARD_COUNT", 16))
# Connection Settings
MAX_CONNECTIONS_PER_HOST = int(os.getenv("MAX_CONNECTIONS_PER_HOST", 100))
MAX_CONNECTIONS_TOTAL = int(os.getenv("MAX_CONNECTIONS_TOTAL", 500))
KEEPALIVE_TIMEOUT = int(os.getenv("KEEPALIVE_TIMEOUT", 30))
CONNECT_TIMEOUT = int(os.getenv("CONNECT_TIMEOUT", 10))
TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", 60))
# File Cache Settings
FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000))
FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300))
# API Settings
BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
MASTERKEY = os.getenv("MASTERKEY", "master")
FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
# Project Settings
PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data")
# Tokenizer Settings
TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
# Embedding Model Settings
SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
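One property of this module worth keeping in mind: every value is resolved once, when utils.settings is first imported, so changes to the process environment afterwards have no effect. A short usage sketch:

# Overrides must be exported before the process starts:
#   MAX_OUTPUT_TOKENS=4096 FASTAPI_URL=http://embed:8001 python main.py
from utils.settings import MAX_OUTPUT_TOKENS, SUMMARIZATION_MAX_TOKENS, FASTAPI_URL

print(MAX_OUTPUT_TOKENS)         # 8000 unless overridden at startup
print(SUMMARIZATION_MAX_TOKENS)  # 56536 with all defaults
print(FASTAPI_URL)               # http://127.0.0.1:8001 by default

# In tests, set the variable and reload the module, e.g. with pytest:
#   monkeypatch.setenv("FASTAPI_URL", "http://localhost:9999")
#   importlib.reload(utils.settings)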