settings
This commit is contained in:
parent 9ada70eb58 · commit 77c8f5e501
@@ -13,10 +13,7 @@ from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider
 
 from .guideline_middleware import GuidelineMiddleware
-
-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
-MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+from utils.settings import SUMMARIZATION_MAX_TOKENS
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that uses the project's logger for logging"""
@@ -8,6 +8,7 @@ import asyncio
 import hashlib
 import json
 import logging
+from utils.settings import FASTAPI_URL
 
 # Configure logger
 logger = logging.getLogger('app')
@@ -19,8 +20,7 @@ def encode_texts_via_api(texts, batch_size=32):
 
     try:
         # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+        api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
 
         # Call the encoding endpoint
         request_data = {
@@ -13,6 +13,7 @@ import logging
 from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from collections import OrderedDict
+from utils.settings import SENTENCE_TRANSFORMER_MODEL
 import threading
 import psutil
 import numpy as np
@@ -126,6 +127,5 @@ def get_model_manager() -> GlobalModelManager:
     """Get the model manager instance"""
     global _model_manager
    if _model_manager is None:
-        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
-        _model_manager = GlobalModelManager(model_name)
+        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
     return _model_manager
@@ -11,7 +11,6 @@ import asyncio
 from typing import Any, Dict, List, Optional, Union
 import re
-
 
 def get_allowed_directory():
     """Get the directory allowed for access"""
     # Prefer the dataset_dir passed in via command-line arguments
@@ -26,8 +26,8 @@ from mcp_common import (
     load_tools_from_json,
     handle_mcp_streaming
 )
-
-backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")
 
 def rag_retrieve(query: str) -> Dict[str, Any]:
     """Call the RAG retrieval API"""
@@ -36,7 +36,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
     if len(sys.argv) > 1:
         bot_id = sys.argv[1]
 
-    url = f"{backend_host}/v1/rag_retrieve/{bot_id}"
+    url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
     if not url:
         return {
             "content": [
@@ -48,7 +48,7 @@ def rag_retrieve(query: str) -> Dict[str, Any]:
         }
 
     # Get the masterkey and generate the auth token
-    masterkey = os.getenv("MASTERKEY", "master")
+    masterkey = MASTERKEY
     token_input = f"{masterkey}:{bot_id}"
     auth_token = hashlib.md5(token_input.encode()).hexdigest()
 
@@ -25,6 +25,7 @@ from mcp_common import (
     create_tools_list_response,
     handle_mcp_streaming
 )
+from utils.settings import FASTAPI_URL
 
 import requests
 
@@ -32,7 +33,7 @@ import requests
 def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
     """Encode a single query via the API"""
     if not fastapi_url:
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+        fastapi_url = FASTAPI_URL
 
     api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
 
@@ -20,13 +20,14 @@ from utils.fastapi_utils import (
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
+from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
 
 router = APIRouter()
 
 # Initialize the global agent manager
 agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")),
-    shard_count=int(os.getenv("SHARD_COUNT", "16"))
+    max_cached_agents=MAX_CACHED_AGENTS,
+    shard_count=SHARD_COUNT
 )
 
 
@@ -148,7 +149,7 @@ async def enhanced_generate_stream_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config):
+    async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS):
         new_content = ""
 
         if isinstance(msg, AIMessageChunk):
@@ -328,7 +329,7 @@ async def create_agent_and_generate_response(
     config["configurable"] = {"thread_id": session_id}
     if hasattr(agent, 'logging_handler'):
         config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config)
+    agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
     for msg in append_messages:
@@ -20,6 +20,11 @@ except ImportError:
     from embedding import get_model_manager
 from pydantic import BaseModel
 import logging
+from utils.settings import (
+    MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
+    KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
+    TOKENIZERS_PARALLELISM
+)
 
 logger = logging.getLogger('app')
 
@@ -45,8 +50,8 @@ logger.info("正在初始化系统优化...")
 system_optimizer = setup_system_optimizations()
 
 # Global agent manager configuration (using the optimized settings)
-max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # increased cache size
-shard_count = int(os.getenv("SHARD_COUNT", "16"))  # number of shards
+max_cached_agents = MAX_CACHED_AGENTS  # increased cache size
+shard_count = SHARD_COUNT  # number of shards
 
 # Initialize the optimized global agent manager
 agent_manager = init_global_sharded_agent_manager(
@@ -56,17 +61,17 @@ agent_manager = init_global_sharded_agent_manager(
 
 # Initialize the connection pool
 connection_pool = init_global_connection_pool(
-    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
-    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
-    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
-    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
-    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
+    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
+    max_connections_total=MAX_CONNECTIONS_TOTAL,
+    keepalive_timeout=KEEPALIVE_TIMEOUT,
+    connect_timeout=CONNECT_TIMEOUT,
+    total_timeout=TOTAL_TIMEOUT
 )
 
 # Initialize the file cache
 file_cache = init_global_file_cache(
-    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
-    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
+    cache_size=FILE_CACHE_SIZE,
+    ttl=FILE_CACHE_TTL
 )
 
 logger.info("系统优化初始化完成")
@@ -191,11 +196,11 @@ async def get_system_config():
         "config": {
             "max_cached_agents": max_cached_agents,
             "shard_count": shard_count,
-            "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
-            "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
-            "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
-            "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
-            "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
+            "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
+            "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
+            "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
+            "file_cache_size": str(FILE_CACHE_SIZE),
+            "file_cache_ttl": str(FILE_CACHE_TTL)
         }
     }
 
@@ -10,6 +10,7 @@ from fastapi import HTTPException
 import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
+from utils.settings import MASTERKEY, BACKEND_HOST
 
 USER = "user"
 ASSISTANT = "assistant"
@@ -389,16 +390,14 @@ def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
 
 def generate_v2_auth_token(bot_id: str) -> str:
     """Generate the auth token for the v2 API"""
-    masterkey = os.getenv("MASTERKEY", "master")
-    token_input = f"{masterkey}:{bot_id}"
+    token_input = f"{MASTERKEY}:{bot_id}"
     return hashlib.md5(token_input.encode()).hexdigest()
 
 
 async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
     """Fetch the bot configuration from the backend API"""
     try:
-        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
-        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
+        url = f"{BACKEND_HOST}/v1/agent_bot_config/{bot_id}"
 
         auth_token = generate_v2_auth_token(bot_id)
         headers = {
utils/settings.py (new file, 35 additions)
@@ -0,0 +1,35 @@
+import os
+
+# LLM Token Settings
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
+SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
+
+# Agent and Shard Settings
+MAX_CACHED_AGENTS = int(os.getenv("MAX_CACHED_AGENTS", 50))
+SHARD_COUNT = int(os.getenv("SHARD_COUNT", 16))
+
+# Connection Settings
+MAX_CONNECTIONS_PER_HOST = int(os.getenv("MAX_CONNECTIONS_PER_HOST", 100))
+MAX_CONNECTIONS_TOTAL = int(os.getenv("MAX_CONNECTIONS_TOTAL", 500))
+KEEPALIVE_TIMEOUT = int(os.getenv("KEEPALIVE_TIMEOUT", 30))
+CONNECT_TIMEOUT = int(os.getenv("CONNECT_TIMEOUT", 10))
+TOTAL_TIMEOUT = int(os.getenv("TOTAL_TIMEOUT", 60))
+
+# File Cache Settings
+FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000))
+FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300))
+
+# API Settings
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+MASTERKEY = os.getenv("MASTERKEY", "master")
+FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+
+# Project Settings
+PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data")
+
+# Tokenizer Settings
+TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
+
+# Embedding Model Settings
+SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
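
Usage sketch (not part of the diff): after this change, modules import the already-parsed constants from utils.settings instead of calling os.getenv at each call site. The FASTAPI_URL setting and the /api/v1/embedding/encode path below come from the hunks above; the helper name and the request payload shape are illustrative assumptions, since the diff truncates request_data.

import requests

from utils.settings import FASTAPI_URL


def encode_texts(texts):
    # Build the endpoint from the centralized setting, mirroring the pattern in the diff
    api_endpoint = f"{FASTAPI_URL}/api/v1/embedding/encode"
    # Payload shape assumed for illustration only
    response = requests.post(api_endpoint, json={"texts": texts}, timeout=30)
    response.raise_for_status()
    return response.json()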