midyf model_id
This commit is contained in:
parent
51f436d7f7
commit
070b3e0057
@ -15,17 +15,23 @@ class MaxKBMinerUConfig(MinerUConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, llm_model_id: str = None, vision_model_id: str = None):
|
def create(cls, llm_model_id: str = None, vision_model_id: str = None):
|
||||||
"""Factory method to create config with specific model IDs"""
|
"""Factory method to create config with specific model IDs"""
|
||||||
instance = cls()
|
|
||||||
# Override model IDs after creation
|
|
||||||
if llm_model_id:
|
|
||||||
instance.llm_model_id = llm_model_id
|
|
||||||
if vision_model_id:
|
|
||||||
instance.vision_model_id = vision_model_id
|
|
||||||
|
|
||||||
# Log the configured model IDs
|
|
||||||
from .logger import get_module_logger
|
from .logger import get_module_logger
|
||||||
logger = get_module_logger('config_maxkb')
|
logger = get_module_logger('config_maxkb')
|
||||||
logger.info(f"MaxKBMinerUConfig.create() set LLM={instance.llm_model_id}, Vision={instance.vision_model_id}")
|
logger.info(f"MaxKBMinerUConfig.create() called with llm_model_id={llm_model_id}, vision_model_id={vision_model_id}")
|
||||||
|
|
||||||
|
instance = cls()
|
||||||
|
logger.info(f"After cls(), before override: LLM={instance.llm_model_id}, Vision={instance.vision_model_id}")
|
||||||
|
|
||||||
|
# Override model IDs after creation - MUST override both to prevent defaults
|
||||||
|
if llm_model_id:
|
||||||
|
instance.llm_model_id = llm_model_id
|
||||||
|
logger.info(f"Set llm_model_id to {llm_model_id}")
|
||||||
|
if vision_model_id:
|
||||||
|
instance.vision_model_id = vision_model_id
|
||||||
|
logger.info(f"Set vision_model_id to {vision_model_id}")
|
||||||
|
|
||||||
|
# Log the final configured model IDs
|
||||||
|
logger.info(f"MaxKBMinerUConfig.create() final: LLM={instance.llm_model_id}, Vision={instance.vision_model_id}")
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -34,8 +40,11 @@ class MaxKBMinerUConfig(MinerUConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# MaxKB specific settings from environment or defaults
|
# MaxKB specific settings from environment or defaults
|
||||||
# 如果环境变量中设置了具体的UUID,使用UUID;否则使用默认值或自动检测
|
# 只有在属性不存在时才设置默认值,避免覆盖已经设置的值
|
||||||
|
if not hasattr(self, 'llm_model_id'):
|
||||||
self.llm_model_id = os.getenv('MAXKB_LLM_MODEL_ID', self._get_default_llm_model_id())
|
self.llm_model_id = os.getenv('MAXKB_LLM_MODEL_ID', self._get_default_llm_model_id())
|
||||||
|
|
||||||
|
if not hasattr(self, 'vision_model_id'):
|
||||||
self.vision_model_id = os.getenv('MAXKB_VISION_MODEL_ID', self._get_default_vision_model_id())
|
self.vision_model_id = os.getenv('MAXKB_VISION_MODEL_ID', self._get_default_vision_model_id())
|
||||||
|
|
||||||
# Log the configured model IDs
|
# Log the configured model IDs
|
||||||
@ -287,21 +296,34 @@ class MaxKBMinerUConfig(MinerUConfig):
|
|||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from models_provider.models import Model
|
from models_provider.models import Model
|
||||||
|
|
||||||
# 首先尝试获取专门的视觉模型
|
# 首先尝试获取专门的视觉模型(IMAGE类型)
|
||||||
model = QuerySet(Model).filter(
|
model = QuerySet(Model).filter(
|
||||||
model_type__in=['VISION', 'MULTIMODAL']
|
model_type__in=['IMAGE', 'VISION', 'MULTIMODAL']
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
# 如果没有,获取任意LLM模型(许多LLM支持视觉)
|
# 如果没有IMAGE类型,尝试查找名称包含vision的模型
|
||||||
if not model:
|
if not model:
|
||||||
model = QuerySet(Model).filter(
|
model = QuerySet(Model).filter(
|
||||||
model_type__in=['LLM', 'CHAT']
|
model_name__icontains='vision'
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
|
# 最后的备选:获取不同于LLM的模型
|
||||||
|
if not model:
|
||||||
|
# 先获取已经用作LLM的模型ID
|
||||||
|
llm_id = self.llm_model_id if hasattr(self, 'llm_model_id') else None
|
||||||
|
if llm_id:
|
||||||
|
# 获取一个不同的模型
|
||||||
|
model = QuerySet(Model).exclude(id=llm_id).first()
|
||||||
|
else:
|
||||||
|
# 如果没有llm_id,获取任意模型
|
||||||
|
model = QuerySet(Model).first()
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
return str(model.id)
|
return str(model.id)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
from .logger import get_module_logger
|
||||||
|
logger = get_module_logger('config_maxkb')
|
||||||
|
logger.warning(f"Failed to get default vision model: {e}")
|
||||||
|
|
||||||
# 返回默认值
|
# 返回默认值
|
||||||
return 'default-vision'
|
return 'default-vision'
|
||||||
71
test_config_chain.py
Normal file
71
test_config_chain.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试配置对象的传递链
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 设置环境变量,避免从环境获取默认值
|
||||||
|
os.environ['MAXKB_LLM_MODEL_ID'] = ''
|
||||||
|
os.environ['MAXKB_VISION_MODEL_ID'] = ''
|
||||||
|
|
||||||
|
print("Testing config chain")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 模拟 dataclass
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseConfig:
|
||||||
|
"""Base configuration"""
|
||||||
|
api_url: str = "default_url"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
print(f" BaseConfig.__post_init__ called")
|
||||||
|
|
||||||
|
class TestConfig(BaseConfig):
|
||||||
|
"""Test configuration with model IDs"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, llm_id=None, vision_id=None):
|
||||||
|
print(f"TestConfig.create() called with llm_id={llm_id}, vision_id={vision_id}")
|
||||||
|
instance = cls()
|
||||||
|
print(f" After cls(): llm={getattr(instance, 'llm_id', 'NOT SET')}, vision={getattr(instance, 'vision_id', 'NOT SET')}")
|
||||||
|
|
||||||
|
if llm_id:
|
||||||
|
instance.llm_id = llm_id
|
||||||
|
print(f" Set llm_id to {llm_id}")
|
||||||
|
if vision_id:
|
||||||
|
instance.vision_id = vision_id
|
||||||
|
print(f" Set vision_id to {vision_id}")
|
||||||
|
|
||||||
|
print(f" Final: llm={instance.llm_id}, vision={instance.vision_id}")
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
print(f" TestConfig.__post_init__ called")
|
||||||
|
super().__post_init__()
|
||||||
|
# Set defaults
|
||||||
|
self.llm_id = "default_llm"
|
||||||
|
self.vision_id = "default_vision"
|
||||||
|
print(f" Set defaults: llm={self.llm_id}, vision={self.vision_id}")
|
||||||
|
|
||||||
|
# Test 1: Direct creation
|
||||||
|
print("\nTest 1: Direct creation (should use defaults)")
|
||||||
|
config1 = TestConfig()
|
||||||
|
print(f"Result: llm={config1.llm_id}, vision={config1.vision_id}")
|
||||||
|
|
||||||
|
# Test 2: Factory method
|
||||||
|
print("\nTest 2: Factory method with IDs")
|
||||||
|
config2 = TestConfig.create(llm_id="llm_123", vision_id="vision_456")
|
||||||
|
print(f"Result: llm={config2.llm_id}, vision={config2.vision_id}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Analysis:")
|
||||||
|
if config2.llm_id == "llm_123" and config2.vision_id == "vision_456":
|
||||||
|
print("✅ Factory method correctly overrides defaults")
|
||||||
|
else:
|
||||||
|
print("❌ Problem: Factory method failed to override defaults")
|
||||||
|
print(f" Expected: llm=llm_123, vision=vision_456")
|
||||||
|
print(f" Got: llm={config2.llm_id}, vision={config2.vision_id}")
|
||||||
Loading…
Reference in New Issue
Block a user