settings

This commit is contained in:
parent 9ada70eb58 · commit 77c8f5e501
@@ -13,10 +13,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider

 from .guideline_middleware import GuidelineMiddleware

-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
-MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+from utils.settings import SUMMARIZATION_MAX_TOKENS

 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that logs through the project's logger"""
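For context, the summarization budget that now lives in utils.settings is a simple subtraction; a worked example with the default values (SAFETY_MARGIN is just an illustrative name for the hard-coded 1000):

    # Defaults from utils/settings.py; real values come from the environment.
    MAX_CONTEXT_TOKENS = 65536   # model context window
    MAX_OUTPUT_TOKENS = 8000     # reserved for the model's reply
    SAFETY_MARGIN = 1000         # illustrative name for the literal slack value
    SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - SAFETY_MARGIN
    assert SUMMARIZATION_MAX_TOKENS == 56536  # tokens left for summarized history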
@@ -182,7 +179,7 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
         summary_prompt="Please concisely summarize the key points of the conversation so far, including important user details, topics discussed, and key conclusions."
     )
     middleware.append(summarization_middleware)

     agent = create_agent(
         model=llm_instance,
         system_prompt=system,
@@ -8,6 +8,7 @@ import asyncio
 import hashlib
 import json
 import logging
+from utils.settings import FASTAPI_URL

 # Configure logger
 logger = logging.getLogger('app')
@@ -19,8 +20,7 @@ def encode_texts_via_api(texts, batch_size=32):
     try:
         # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+        api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"

         # Call the encoding endpoint
         request_data = {
@@ -13,6 +13,7 @@ import logging
 from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from collections import OrderedDict
+from utils.settings import SENTENCE_TRANSFORMER_MODEL
 import threading
 import psutil
 import numpy as np
@@ -126,6 +127,5 @@ def get_model_manager() -> GlobalModelManager:
     """Get the model manager instance"""
     global _model_manager
     if _model_manager is None:
-        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
-        _model_manager = GlobalModelManager(model_name)
+        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
     return _model_manager
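get_model_manager is the usual lazily-initialized module-level singleton. A minimal self-contained sketch of the same pattern; note the threading.Lock is an addition of this sketch (the hunk itself shows no locking), and the GlobalModelManager body is stubbed:

    import os
    import threading

    SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")

    class GlobalModelManager:
        def __init__(self, model_name: str):
            self.model_name = model_name  # the expensive model load would happen here

    _model_manager = None
    _lock = threading.Lock()

    def get_model_manager() -> GlobalModelManager:
        """Build the manager on first call, then return the cached instance."""
        global _model_manager
        if _model_manager is None:
            with _lock:
                if _model_manager is None:  # double-checked: another thread may have won
                    _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
        return _model_manager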
@@ -11,7 +11,6 @@ import asyncio
 from typing import Any, Dict, List, Optional, Union
 import re


 def get_allowed_directory():
     """Get the directory allowed for access"""
     # Prefer the dataset_dir passed in via command-line argument
@@ -26,8 +26,8 @@ from mcp_common import (
     load_tools_from_json,
     handle_mcp_streaming
 )

-backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")

 def rag_retrieve(query: str) -> Dict[str, Any]:
     """Call the RAG retrieval API"""
@@ -36,7 +36,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
     if len(sys.argv) > 1:
         bot_id = sys.argv[1]

-    url = f"{backend_host}/v1/rag_retrieve/{bot_id}"
+    url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
     if not url:
         return {
             "content": [
@@ -48,7 +48,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
     }

     # Fetch the masterkey and generate the auth token
-    masterkey = os.getenv("MASTERKEY", "master")
+    masterkey = MASTERKEY
     token_input = f"{masterkey}:{bot_id}"
     auth_token = hashlib.md5(token_input.encode()).hexdigest()
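The auth token here is just an MD5 digest of "masterkey:bot_id". A minimal sketch of what a client would compute (the bot_id value is made up for illustration):

    import hashlib

    MASTERKEY = "master"   # default from utils.settings; deployments override it
    bot_id = "bot-123"     # hypothetical bot id

    auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
    # 32-char hex digest, matching what generate_v2_auth_token produces server-side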
@@ -25,6 +25,7 @@ from mcp_common import (
     create_tools_list_response,
     handle_mcp_streaming
 )
+from utils.settings import FASTAPI_URL

 import requests
@@ -32,7 +33,7 @@ import requests
 def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
     """Encode a single query via the API"""
     if not fastapi_url:
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+        fastapi_url = FASTAPI_URL

     api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
@@ -20,13 +20,14 @@ from utils.fastapi_utils import (
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
+from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT

 router = APIRouter()

 # Initialize the global agent manager
 agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
-    shard_count=int(os.getenv("SHARD_COUNT", "16"))
+    max_cached_agents=MAX_CACHED_AGENTS,
+    shard_count=SHARD_COUNT
 )
@@ -148,7 +149,7 @@ async def enhanced_generate_stream_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config):
+    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
         new_content = ""

         if isinstance(msg, AIMessageChunk):
@@ -328,7 +329,7 @@ async def create_agent_and_generate_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config)
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
@@ -20,6 +20,11 @@ except ImportError:
     from embedding import get_model_manager
 from pydantic import BaseModel
 import logging
+from utils.settings import (
+    MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
+    KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
+    TOKENIZERS_PARALLELISM
+)

 logger = logging.getLogger('app')
@@ -45,28 +50,28 @@ logger.info("Initializing system optimizations...")
 system_optimizer = setup_system_optimizations()

 # Global agent manager configuration (uses the optimized values)
-max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # increased cache size
-shard_count = int(os.getenv("SHARD_COUNT", "16"))  # shard count
+max_cached_agents = MAX_CACHED_AGENTS  # increased cache size
+shard_count = SHARD_COUNT  # shard count

 # Initialize the optimized global agent manager
 agent_manager = init_global_sharded_agent_manager(
     max_cached_agents=max_cached_agents,
     shard_count=shard_count
 )

 # Initialize the connection pool
 connection_pool = init_global_connection_pool(
-    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
-    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
-    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
-    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
-    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
+    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
+    max_connections_total=MAX_CONNECTIONS_TOTAL,
+    keepalive_timeout=KEEPALIVE_TIMEOUT,
+    connect_timeout=CONNECT_TIMEOUT,
+    total_timeout=TOTAL_TIMEOUT
 )

 # Initialize the file cache
 file_cache = init_global_file_cache(
-    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
-    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
+    cache_size=FILE_CACHE_SIZE,
+    ttl=FILE_CACHE_TTL
 )

 logger.info("System optimization initialization complete")
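init_global_connection_pool itself is not part of this diff, so the following is only a sketch of how these knobs typically map onto aiohttp's client primitives, under the assumption that the pool wraps an aiohttp.ClientSession:

    import aiohttp
    from utils.settings import (
        MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
        KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT,
    )

    def make_session() -> aiohttp.ClientSession:
        connector = aiohttp.TCPConnector(
            limit=MAX_CONNECTIONS_TOTAL,              # cap across all hosts
            limit_per_host=MAX_CONNECTIONS_PER_HOST,  # cap per remote host
            keepalive_timeout=KEEPALIVE_TIMEOUT,      # seconds idle connections are kept
        )
        timeout = aiohttp.ClientTimeout(total=TOTAL_TIMEOUT, connect=CONNECT_TIMEOUT)
        return aiohttp.ClientSession(connector=connector, timeout=timeout)

(A ClientSession should be created from within a running event loop.)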
@@ -191,11 +196,11 @@ async def get_system_config():
         "config": {
             "max_cached_agents": max_cached_agents,
             "shard_count": shard_count,
-            "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
-            "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
-            "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
-            "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
-            "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
+            "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
+            "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
+            "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
+            "file_cache_size": str(FILE_CACHE_SIZE),
+            "file_cache_ttl": str(FILE_CACHE_TTL)
         }
     }
@@ -10,6 +10,7 @@ from fastapi import HTTPException
 import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
+from utils.settings import MASTERKEY, BACKEND_HOST

 USER = "user"
 ASSISTANT = "assistant"
@@ -389,16 +390,14 @@ def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:

 def generate_v2_auth_token(bot_id: str) -> str:
     """Generate the auth token for the v2 API"""
-    masterkey = os.getenv("MASTERKEY", "master")
-    token_input = f"{masterkey}:{bot_id}"
+    token_input = f"{MASTERKEY}:{bot_id}"
     return hashlib.md5(token_input.encode()).hexdigest()


 async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
     """Fetch the bot configuration from the backend API"""
     try:
-        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
-        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
+        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"

         auth_token = generate_v2_auth_token(bot_id)
         headers = {
utils/settings.py — new file (35 lines)
@@ -0,0 +1,35 @@
+import os
+
+# LLM Token Settings
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
+SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+
+# Agent and Shard Settings
+MAX_CACHED_AGENTS = int(os.getenv("MAX_CACHED_AGENTS", 50))
+SHARD_COUNT = int(os.getenv("SHARD_COUNT", 16))
+
+# Connection Settings
+MAX_CONNECTIONS_PER_HOST = int(os.getenv("MAX_CONNECTIONS_PER_HOST", 100))
+MAX_CONNECTIONS_TOTAL = int(os.getenv("MAX_CONNECTIONS_TOTAL", 500))
+KEEPALIVE_TIMEOUT = int(os.getenv("KEEPALIVE_TIMEOUT", 30))
+CONNECT_TIMEOUT = int(os.getenv("CONNECT_TIMEOUT", 10))
+TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", 60))
+
+# File Cache Settings
+FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000))
+FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300))
+
+# API Settings
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")
+FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+
+# Project Settings
+PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data")
+
+# Tokenizer Settings
+TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
+
+# Embedding Model Settings
+SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
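The pattern this commit applies throughout: each environment variable is read once, at import time, in utils/settings.py, and every other module imports the resulting constant instead of calling os.getenv inline. A minimal sketch of a consumer module under that convention (module and function names are illustrative):

    # some_module.py -- hypothetical consumer of the centralized settings
    from utils.settings import FASTAPI_URL

    def build_encode_endpoint() -> str:
        # No os.getenv here: the value was resolved when utils.settings was first imported.
        return f"{FASTAPI_URL}/api/v1/embedding/encode"

One behavioral consequence: because values are fixed at import time, changing an environment variable in a running process has no effect, whereas the old inline os.getenv calls re-read it on every call.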