refactor: 优化模型管理 (#749)
This commit is contained in:
parent
1452df7f1c
commit
410e065b52
@ -9,9 +9,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, Iterator
|
||||
from typing import Dict, Iterator, Type, List
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
||||
@ -47,39 +47,53 @@ class DownModelChunk:
|
||||
|
||||
|
||||
class IModelProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_model_info_manage(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_provide_info(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_type_list(self):
|
||||
pass
|
||||
return self.get_model_info_manage().get_model_type_list()
|
||||
|
||||
@abstractmethod
|
||||
def get_model_list(self, model_type):
|
||||
pass
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return self.get_model_info_manage().get_model_list()
|
||||
|
||||
@abstractmethod
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
pass
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_credential
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
|
||||
pass
|
||||
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
|
||||
raise_exception=raise_exception)
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_dialogue_number(self):
|
||||
pass
|
||||
return 3
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
raise AppApiException(500, "当前平台不支持下载模型")
|
||||
|
||||
|
||||
class MaxKBBaseModel(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelCredential(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False):
|
||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -113,15 +127,18 @@ class BaseModelCredential(ABC):
|
||||
|
||||
class ModelTypeConst(Enum):
|
||||
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
||||
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
|
||||
model_class: Type[MaxKBBaseModel],
|
||||
**keywords):
|
||||
self.name = name
|
||||
self.desc = desc
|
||||
self.model_type = model_type.name
|
||||
self.model_credential = model_credential
|
||||
self.model_class = model_class
|
||||
if keywords is not None:
|
||||
for key in keywords.keys():
|
||||
self.__setattr__(key, keywords.get(key))
|
||||
@ -143,10 +160,66 @@ class ModelInfo:
|
||||
def get_model_type(self):
|
||||
return self.model_type
|
||||
|
||||
def get_model_class(self):
|
||||
return self.model_class
|
||||
|
||||
def to_dict(self):
|
||||
return reduce(lambda x, y: {**x, **y},
|
||||
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
|
||||
not attr.startswith("__") and not attr == 'model_credential'], {})
|
||||
not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})
|
||||
|
||||
|
||||
class ModelInfoManage:
|
||||
def __init__(self):
|
||||
self.model_dict = {}
|
||||
self.model_list = []
|
||||
self.default_model_list = []
|
||||
self.default_model_dict = {}
|
||||
|
||||
def append_model_info(self, model_info: ModelInfo):
|
||||
self.model_list.append(model_info)
|
||||
model_type_dict = self.model_dict.get(model_info.model_type)
|
||||
if model_type_dict is None:
|
||||
self.model_dict[model_info.model_type] = {model_info.name: model_info}
|
||||
else:
|
||||
model_type_dict[model_info.name] = model_info
|
||||
|
||||
def append_default_model_info(self, model_info: ModelInfo):
|
||||
self.default_model_list.append(model_info)
|
||||
self.default_model_dict[model_info.model_type] = model_info
|
||||
|
||||
def get_model_list(self):
|
||||
return [model.to_dict() for model in self.model_list]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
|
||||
len([model for model in self.model_list if model.model_type == _type.name]) > 0]
|
||||
|
||||
def get_model_info(self, model_type, model_name) -> ModelInfo:
|
||||
model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
|
||||
if model_info is None:
|
||||
raise AppApiException(500, '模型不支持')
|
||||
return model_info
|
||||
|
||||
class builder:
|
||||
def __init__(self):
|
||||
self.modelInfoManage = ModelInfoManage()
|
||||
|
||||
def append_model_info(self, model_info: ModelInfo):
|
||||
self.modelInfoManage.append_model_info(model_info)
|
||||
return self
|
||||
|
||||
def append_model_info_list(self, model_info_list: List[ModelInfo]):
|
||||
for model_info in model_info_list:
|
||||
self.modelInfoManage.append_model_info(model_info)
|
||||
return self
|
||||
|
||||
def append_default_model_info(self, model_info: ModelInfo):
|
||||
self.modelInfoManage.append_default_model_info(model_info)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.modelInfoManage
|
||||
|
||||
|
||||
class ModelProvideInfo:
|
||||
|
||||
@ -9,15 +9,15 @@
|
||||
from enum import Enum
|
||||
|
||||
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
||||
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
||||
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
||||
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
|
||||
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
||||
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
||||
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
|
||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
||||
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
||||
|
||||
|
||||
class ModelProvideConstants(Enum):
|
||||
|
||||
@ -7,98 +7,30 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst, ValidCode
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
|
||||
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
base_azure_llm_model_credential = AzureLLMModelCredential()
|
||||
|
||||
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
|
||||
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
|
||||
)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = AzureModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_version = forms.TextInputField("API 版本 (api_version)", required=True)
|
||||
|
||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||
|
||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||
|
||||
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
|
||||
|
||||
|
||||
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
|
||||
base_azure_llm_model_credential, api_version='2024-02-15-preview'
|
||||
)
|
||||
}
|
||||
model_info_manage = ModelInfoManage.builder().append_default_model_info(default_model_info).append_model_info(
|
||||
default_model_info).build()
|
||||
|
||||
|
||||
class AzureModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
|
||||
azure_chat_open_ai = AzureChatModel(
|
||||
azure_endpoint=model_credential.get('api_base'),
|
||||
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
|
||||
deployment_name=model_credential.get('deployment_name'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
openai_api_type="azure"
|
||||
)
|
||||
return azure_chat_open_ai
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return base_azure_llm_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
|
||||
'azure_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -0,0 +1,55 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 17:08
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_version = forms.TextInputField("API 版本 (api_version)", required=True)
|
||||
|
||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||
|
||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||
|
||||
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
|
||||
@ -6,15 +6,26 @@
|
||||
@date:2024/4/28 11:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class AzureChatModel(AzureChatOpenAI):
|
||||
class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return AzureChatModel(
|
||||
azure_endpoint=model_credential.get('api_base'),
|
||||
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
|
||||
deployment_name=model_credential.get('deployment_name'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
openai_api_type="azure"
|
||||
)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
|
||||
@ -0,0 +1,48 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 17:51
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -7,91 +7,35 @@
|
||||
@Date :5/12/24 7:40 AM
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, ModelTypeConst, ValidCode
|
||||
from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
|
||||
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = DeepSeekModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
|
||||
deepseek_llm_model_credential,
|
||||
),
|
||||
'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
|
||||
deepseek_llm_model_credential,
|
||||
),
|
||||
}
|
||||
deepseek_chat = ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
|
||||
deepseek_llm_model_credential, DeepSeekChatModel
|
||||
)
|
||||
|
||||
deepseek_coder = ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
|
||||
deepseek_llm_model_credential,
|
||||
DeepSeekChatModel)
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info(deepseek_chat).append_model_info(
|
||||
deepseek_coder).append_default_model_info(
|
||||
deepseek_coder).build()
|
||||
|
||||
|
||||
class DeepSeekModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel:
|
||||
deepseek_chat_open_ai = DeepSeekChatModel(
|
||||
model=model_name,
|
||||
openai_api_base='https://api.deepseek.com',
|
||||
openai_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return deepseek_chat_open_ai
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return deepseek_llm_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
|
||||
'deepseek_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :MaxKB
|
||||
@File :deepseek_chat_model.py
|
||||
@Author :Brian Yang
|
||||
@Date :5/12/24 7:44 AM
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
|
||||
class DeepSeekChatModel(ChatOpenAI):
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
try:
|
||||
return super().get_num_tokens(text)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :MaxKB
|
||||
@File :llm.py
|
||||
@Author :Brian Yang
|
||||
@Date :5/12/24 7:44 AM
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
deepseek_chat_open_ai = DeepSeekChatModel(
|
||||
model=model_name,
|
||||
openai_api_base='https://api.deepseek.com',
|
||||
openai_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return deepseek_chat_open_ai
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
@ -0,0 +1,48 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 17:57
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -7,93 +7,36 @@
|
||||
@Date :5/13/24 7:47 AM
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, ModelTypeConst, ValidCode
|
||||
from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
|
||||
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = GeminiModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = GeminiModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
gemini_llm_model_credential = GeminiLLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
|
||||
gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
|
||||
ModelTypeConst.LLM,
|
||||
gemini_llm_model_credential,
|
||||
GeminiChatModel)
|
||||
|
||||
gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
|
||||
ModelTypeConst.LLM,
|
||||
gemini_llm_model_credential,
|
||||
),
|
||||
'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
|
||||
ModelTypeConst.LLM,
|
||||
gemini_llm_model_credential,
|
||||
),
|
||||
}
|
||||
GeminiChatModel)
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info(
|
||||
gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build()
|
||||
|
||||
|
||||
class GeminiModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
|
||||
**model_kwargs) -> GeminiChatModel:
|
||||
gemini_chat = GeminiChatModel(
|
||||
model=model_name,
|
||||
google_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return gemini_chat
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return gemini_llm_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
|
||||
'gemini_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :MaxKB
|
||||
@File :gemini_chat_model.py
|
||||
@Author :Brian Yang
|
||||
@Date :5/13/24 7:40 AM
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
|
||||
class GeminiChatModel(ChatGoogleGenerativeAI):
    """ChatGoogleGenerativeAI with a local-tokenizer fallback for token counts.

    The parent implementation may need a live API client; when it raises for
    any reason, counting falls back to the shared local tokenizer.
    """

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            return sum(len(fallback.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            return len(fallback.encode(text))
|
||||
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :MaxKB
|
||||
@File :llm.py
|
||||
@Author :Brian Yang
|
||||
@Date :5/13/24 7:40 AM
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
    """Gemini chat model; token counting always uses the shared local tokenizer."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the model registry; only the API key is read from
        # the credential dict.
        return GeminiChatModel(
            model=model_name,
            google_api_key=model_credential.get('api_key')
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
@ -0,0 +1,49 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Kimi (Moonshot) LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        # Reject model types this provider does not declare.
        supported = provider.get_model_type_list()
        if not any(item.get('value') == model_type for item in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A one-shot chat round-trip proves the credential works end to end.
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -7,103 +7,36 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst, ValidCode
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.kimi_model_provider.credential.llm import KimiLLMModelCredential
|
||||
from setting.models_provider.impl.kimi_model_provider.model.llm import KimiChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from setting.models_provider.impl.kimi_model_provider.model.kimi_chat_model import KimiChatModel
|
||||
|
||||
|
||||
|
||||
|
||||
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Kimi (Moonshot) LLM models.

    is_valid() checks the declared model type, required credential fields,
    and finally performs a real chat round-trip against the API.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        model_type_list = KimiModelProvider().get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # A one-shot chat round-trip proves the credential works end to end.
            model = KimiModelProvider().get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
kimi_llm_model_credential = KimiLLMModelCredential()

# One ModelInfo per supported Moonshot context-window variant; all share the
# same credential form and chat-model implementation.
moonshot_v1_8k = ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
                           KimiChatModel)
moonshot_v1_32k = ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
                            KimiChatModel)
moonshot_v1_128k = ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
                             KimiChatModel)

# Registry consumed by KimiModelProvider.get_model_info_manage().
model_info_manage = ModelInfoManage.builder().append_model_info(moonshot_v1_8k).append_model_info(
    moonshot_v1_32k).append_default_model_info(moonshot_v1_128k).append_default_model_info(moonshot_v1_8k).build()
|
||||
|
||||
|
||||
class KimiModelProvider(IModelProvider):
    """Model provider for Kimi (Moonshot AI).

    Model listing, credential lookup and model instantiation are supplied by
    the IModelProvider base class through the ModelInfoManage registry; only
    provider-specific overrides remain here.
    """

    def get_model_info_manage(self):
        # Registry describing the Moonshot models this provider exposes.
        return model_info_manage

    def get_dialogue_number(self):
        # Number of historical dialogue rounds carried into a prompt.
        return 3

    def get_model_provide_info(self):
        # Static provider metadata: identifier, display name and icon bytes.
        return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon',
                         'kimi_icon_svg')))
|
||||
|
||||
@ -2,19 +2,29 @@
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: kimi_chat_model.py
|
||||
@file: llm.py
|
||||
@date:2023/11/10 17:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
    """Kimi (Moonshot) chat model accessed via its OpenAI-compatible API."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the model registry; Moonshot speaks the OpenAI
        # wire protocol, so ChatOpenAI settings apply directly.
        kimi_chat_open_ai = KimiChatModel(
            openai_api_base=model_credential['api_base'],
            openai_api_key=model_credential['api_key'],
            model_name=model_name,
        )
        return kimi_chat_open_ai

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Count tokens locally with the shared tokenizer instead of calling
        # the remote API.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
@ -0,0 +1,44 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:19
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for LLMs served by a local Ollama instance."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        supported = provider.get_model_type_list()
        if not any(item.get('value') == model_type for item in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            # Ask the Ollama server which models it actually hosts.
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
        hosted = model_list.get('models') if model_list.get('models') is not None else []
        matches = [entry for entry in hosted
                   if entry.get('model') == model_name or entry.get('model').replace(":latest", "") == model_name]
        if not matches:
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        for field in ('api_key', 'model'):
            if field not in model_info:
                raise AppApiException(500, f'{field} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -0,0 +1,40 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/3/6 11:48
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
from urllib.parse import urlparse, ParseResult
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
def get_base_url(url: str):
    """Reduce *url* to its scheme://netloc root, dropping path, params,
    query and fragment."""
    parsed = urlparse(url)
    return ParseResult(scheme=parsed.scheme, netloc=parsed.netloc, path='', params='',
                       query='',
                       fragment='').geturl()
|
||||
|
||||
|
||||
class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
    """Ollama chat model accessed through its OpenAI-compatible /v1 endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Normalise the configured address to scheme://host, then target the
        # OpenAI-compatible /v1 route Ollama exposes.
        configured = model_credential.get('api_base', '')
        root = get_base_url(configured)
        return OllamaChatModel(model=model_name, openai_api_base=(root + '/v1'),
                               openai_api_key=model_credential.get('api_key'))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
@ -1,24 +0,0 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: ollama_chat_model.py
|
||||
@date:2024/3/6 11:48
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
|
||||
class OllamaChatModel(ChatOpenAI):
    """ChatOpenAI subclass that counts tokens with the shared local tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
@ -19,106 +19,83 @@ from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
|
||||
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
|
||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
""
|
||||
|
||||
|
||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for LLMs served by a local Ollama instance."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        supported = OllamaModelProvider().get_model_type_list()
        if not any(item.get('value') == model_type for item in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            # Ask the Ollama server which models it actually hosts.
            model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
        except Exception:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
        hosted = model_list.get('models') if model_list.get('models') is not None else []
        matches = [entry for entry in hosted
                   if entry.get('model') == model_name or entry.get('model').replace(":latest", "") == model_name]
        if not matches:
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        for field in ('api_key', 'model'):
            if field not in model_info:
                raise AppApiException(500, f'{field} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
ollama_llm_model_credential = OllamaLLMModelCredential()

# Catalogue of Ollama-hosted LLMs this provider knows about.  Every entry
# shares the same credential form and chat-model implementation; only the
# model name and description differ.
model_info_list = [
    ModelInfo(
        'llama2',
        'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'llama2:13b',
        'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'llama2:70b',
        'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'llama2-chinese:13b',
        '由于Llama2本身的中文对齐较弱,我们采用中文指令集,对meta-llama/Llama-2-13b-chat-hf进行LoRA微调,使其具备较强的中文对话能力。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'llama3:8b',
        'Meta Llama 3:迄今为止最有能力的公开产品LLM。80亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'llama3:70b',
        'Meta Llama 3:迄今为止最有能力的公开产品LLM。700亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:0.5b',
        'qwen 1.5 0.5b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。5亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:1.8b',
        'qwen 1.5 1.8b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。18亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:4b',
        'qwen 1.5 4b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。40亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:7b',
        'qwen 1.5 7b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。70亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:14b',
        'qwen 1.5 14b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。140亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:32b',
        'qwen 1.5 32b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。320亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:72b',
        'qwen 1.5 72b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。720亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'qwen:110b',
        'qwen 1.5 110b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1100亿参数。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
    ModelInfo(
        'phi3',
        'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
]

# Registry consumed by OllamaModelProvider.get_model_info_manage(); phi3 is
# registered again as the default model.
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
    ModelInfo(
        'phi3',
        'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
        ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
|
||||
|
||||
|
||||
def get_base_url(url: str):
|
||||
@ -169,32 +146,14 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
|
||||
|
||||
|
||||
class OllamaModelProvider(IModelProvider):
    """Model provider for a locally hosted Ollama server.

    Model listing, credential lookup and model instantiation are supplied by
    the IModelProvider base class through the ModelInfoManage registry; only
    provider-specific overrides remain here.
    """

    def get_model_info_manage(self):
        # Registry describing the Ollama models this provider exposes.
        return model_info_manage

    def get_dialogue_number(self):
        # Number of historical dialogue rounds carried into a prompt.
        return 2

    def get_model_provide_info(self):
        # Static provider metadata: identifier, display name and icon bytes.
        return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
                         'ollama_icon_svg')))
||||
|
||||
@ -0,0 +1,49 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        # Reject model types this provider does not declare.
        supported = provider.get_model_type_list()
        if not any(item.get('value') == model_type for item in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A one-shot chat round-trip proves the credential works end to end.
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -0,0 +1,34 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/4/18 15:28
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
    """OpenAI chat model; token counting uses the shared local tokenizer."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the model registry.
        chat_model = OpenAIChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key')
        )
        return chat_model

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
@ -1,30 +0,0 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: openai_chat_model.py
|
||||
@date:2024/4/18 15:28
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
|
||||
class OpenAIChatModel(ChatOpenAI):
    """ChatOpenAI with a local-tokenizer fallback when the parent's token
    counting raises (e.g. unknown model encoding)."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            return sum(len(fallback.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            return len(fallback.encode(text))
|
||||
@ -7,127 +7,70 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst, ValidCode
|
||||
from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        # Reject model types this provider does not declare.
        supported = OpenAIModelProvider().get_model_type_list()
        if not any(item.get('value') == model_type for item in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A one-shot chat round-trip proves the credential works end to end.
            chat_model = OpenAIModelProvider().get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is returned to clients.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
openai_llm_model_credential = OpenAILLMModelCredential()
|
||||
model_info_list = [
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
),
|
||||
ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-3.5-turbo-0125',
|
||||
'2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-3.5-turbo-1106',
|
||||
'2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-3.5-turbo-0613',
|
||||
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4o-2024-05-13',
|
||||
'2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-turbo-2024-04-09',
|
||||
'2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel)
|
||||
]
|
||||
|
||||
model_dict = {
|
||||
'gpt-3.5-turbo': ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4': ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4o': ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-3.5-turbo-0125': ModelInfo('gpt-3.5-turbo-0125',
|
||||
'2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
),
|
||||
'gpt-3.5-turbo-1106': ModelInfo('gpt-3.5-turbo-1106',
|
||||
'2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
),
|
||||
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613',
|
||||
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4o-2024-05-13': ModelInfo('gpt-4o-2024-05-13',
|
||||
'2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4-turbo-2024-04-09': ModelInfo('gpt-4-turbo-2024-04-09',
|
||||
'2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4-0125-preview': ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
'gpt-4-1106-preview': ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
),
|
||||
}
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
)).build()
|
||||
|
||||
|
||||
class OpenAIModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel:
|
||||
azure_chat_open_ai = OpenAIChatModel(
|
||||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return azure_chat_open_ai
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return openai_llm_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon',
|
||||
'openai_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:41
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -2,19 +2,28 @@
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: qwen_chat_model.py
|
||||
@file: llm.py
|
||||
@date:2024/4/28 11:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class QwenChatModel(ChatTongyi):
|
||||
class QwenChatModel(MaxKBBaseModel, ChatTongyi):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
chat_tong_yi = QwenChatModel(
|
||||
model_name=model_name,
|
||||
dashscope_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return chat_tong_yi
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
@ -7,87 +7,33 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
||||
ModelInfo, IModelProvider, ValidCode
|
||||
from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
|
||||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = QwenModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = QwenModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
qwen_model_credential = OpenAILLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'qwen-turbo': ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'qwen-plus': ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'qwen-max': ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential)
|
||||
}
|
||||
module_info_list = [
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
|
||||
|
||||
|
||||
class QwenModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
|
||||
chat_tong_yi = QwenChatModel(
|
||||
model_name=model_name,
|
||||
dashscope_api_key=model_credential.get('api_key')
|
||||
)
|
||||
return chat_tong_yi
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return qwen_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon',
|
||||
'qwen_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -0,0 +1,55 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/12 10:19
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model_info = [model.lower() for model in model.client.models()]
|
||||
if not model_info.__contains__(model_name.lower()):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
||||
for key in ['api_key', 'secret_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model.invoke(
|
||||
[HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
raise e
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['api_key', 'secret_key', 'model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
self.api_key = model_info.get('api_key')
|
||||
self.secret_key = model_info.get('secret_key')
|
||||
return self
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
||||
@ -0,0 +1,33 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2023/11/10 17:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return QianfanChatModel(model=model_name,
|
||||
qianfan_ak=model_credential.get('api_key'),
|
||||
qianfan_sk=model_credential.get('secret_key'),
|
||||
streaming=model_kwargs.get('streaming', False))
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
@ -1,32 +0,0 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: qian_fan_chat_model.py
|
||||
@date:2023/11/10 17:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import Optional, List, Any, Iterator, cast
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.load import dumpd
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
|
||||
class QianfanChatModel(QianfanChatEndpoint):
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
@ -7,121 +7,53 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
from qianfan import ChatCompletion
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
||||
ModelInfo, IModelProvider, ValidCode
|
||||
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.wenxin_model_provider.credential.llm import WenxinLLMModelCredential
|
||||
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = WenxinModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
model = WenxinModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model_info = [model.lower() for model in model.client.models()]
|
||||
if not model_info.__contains__(model_name.lower()):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
||||
for key in ['api_key', 'secret_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model.invoke(
|
||||
[HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
raise e
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['api_key', 'secret_key', 'model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
self.api_key = model_info.get('api_key')
|
||||
self.secret_key = model_info.get('secret_key')
|
||||
return self
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
||||
|
||||
|
||||
win_xin_llm_model_credential = WenxinLLMModelCredential()
|
||||
model_dict = {
|
||||
'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4',
|
||||
model_info_list = [ModelInfo('ERNIE-Bot-4',
|
||||
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('ERNIE-Bot',
|
||||
'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('ERNIE-Bot-turbo',
|
||||
'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('BLOOMZ-7B',
|
||||
'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('Llama-2-7b-chat',
|
||||
'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('Llama-2-13b-chat',
|
||||
'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('Llama-2-70b-chat',
|
||||
'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||
ModelInfo('Qianfan-Chinese-Llama-2-7B',
|
||||
'千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel)
|
||||
]
|
||||
|
||||
'ERNIE-Bot': ModelInfo('ERNIE-Bot',
|
||||
'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo',
|
||||
'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'BLOOMZ-7B': ModelInfo('BLOOMZ-7B',
|
||||
'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat',
|
||||
'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat',
|
||||
'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat',
|
||||
'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
||||
|
||||
'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B',
|
||||
'千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。',
|
||||
ModelTypeConst.LLM, win_xin_llm_model_credential)
|
||||
}
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('ERNIE-Bot-4',
|
||||
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||
ModelTypeConst.LLM,
|
||||
win_xin_llm_model_credential,
|
||||
QianfanChatModel)).build()
|
||||
|
||||
|
||||
class WenxinModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 2
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
|
||||
**model_kwargs) -> QianfanChatEndpoint:
|
||||
return QianfanChatModel(model=model_name,
|
||||
qianfan_ak=model_credential.get('api_key'),
|
||||
qianfan_sk=model_credential.get('secret_key'),
|
||||
streaming=model_kwargs.get('streaming', False))
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
def get_model_list(self, model_type):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return win_xin_llm_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/12 10:29
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||
|
||||
spark_api_url = forms.TextInputField('API 域名', required=True)
|
||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||
@ -7,7 +7,7 @@
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Any, Iterator
|
||||
from typing import List, Optional, Any, Iterator, Dict
|
||||
|
||||
from langchain_community.chat_models import ChatSparkLLM
|
||||
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
|
||||
@ -16,9 +16,21 @@ from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_stri
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class XFChatSparkLLM(ChatSparkLLM):
|
||||
class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return XFChatSparkLLM(
|
||||
spark_app_id=model_credential.get('spark_app_id'),
|
||||
spark_api_key=model_credential.get('spark_api_key'),
|
||||
spark_api_secret=model_credential.get('spark_api_secret'),
|
||||
spark_api_url=model_credential.get('spark_api_url'),
|
||||
spark_llm_domain=model_name
|
||||
)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
@ -7,97 +7,33 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_community.chat_models import ChatSparkLLM
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
||||
ModelInfo, IModelProvider, ValidCode
|
||||
from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
import ssl
|
||||
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
ssl._create_default_https_context = ssl.create_default_context()
|
||||
|
||||
|
||||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = XunFeiModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = XunFeiModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||
|
||||
spark_api_url = forms.TextInputField('API 域名', required=True)
|
||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||
|
||||
|
||||
qwen_model_credential = XunFeiLLMModelCredential()
|
||||
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
|
||||
]
|
||||
|
||||
model_dict = {
|
||||
'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential)
|
||||
}
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
|
||||
|
||||
|
||||
class XunFeiModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM:
|
||||
zhipuai_chat = XFChatSparkLLM(
|
||||
spark_app_id=model_credential.get('spark_app_id'),
|
||||
spark_api_key=model_credential.get('spark_api_key'),
|
||||
spark_api_secret=model_credential.get('spark_api_secret'),
|
||||
spark_api_url=model_credential.get('spark_api_url'),
|
||||
spark_llm_domain=model_name
|
||||
)
|
||||
return zhipuai_chat
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return qwen_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
|
||||
'xf_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/12 10:46
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
@ -2,19 +2,29 @@
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: zhipu_chat_model.py
|
||||
@file: llm.py
|
||||
@date:2024/4/28 11:42
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class ZhipuChatModel(ChatZhipuAI):
|
||||
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
zhipuai_chat = ZhipuChatModel(
|
||||
temperature=0.5,
|
||||
api_key=model_credential.get('api_key'),
|
||||
model=model_name
|
||||
)
|
||||
return zhipuai_chat
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
@ -7,88 +7,30 @@
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
||||
ModelInfo, IModelProvider, ValidCode
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = ZhiPuModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
|
||||
qwen_model_credential = ZhiPuLLMModelCredential()
|
||||
|
||||
model_dict = {
|
||||
'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential),
|
||||
'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential)
|
||||
}
|
||||
model_info_list = [
|
||||
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||
ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||
ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)
|
||||
]
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info(
|
||||
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)).build()
|
||||
|
||||
|
||||
class ZhiPuModelProvider(IModelProvider):
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
|
||||
zhipuai_chat = ZhipuChatModel(
|
||||
temperature=0.5,
|
||||
api_key=model_credential.get('api_key'),
|
||||
model=model_name
|
||||
)
|
||||
return zhipuai_chat
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
if model_name in model_dict:
|
||||
return model_dict.get(model_name).model_credential
|
||||
return qwen_model_credential
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
|
||||
'zhipuai_icon_svg')))
|
||||
|
||||
def get_model_list(self, model_type: str):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return [model_dict.get(key).to_dict() for key in
|
||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
||||
|
||||
@ -115,7 +115,7 @@ class ModelSerializer(serializers.Serializer):
|
||||
model_name = self.data.get(
|
||||
'model_name')
|
||||
credential = self.data.get('credential')
|
||||
|
||||
provider_handler = ModelProvideConstants[provider].value
|
||||
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
||||
model_name)
|
||||
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
||||
@ -124,7 +124,7 @@ class ModelSerializer(serializers.Serializer):
|
||||
for k in source_encryption_model_credential.keys():
|
||||
if credential[k] == source_encryption_model_credential[k]:
|
||||
credential[k] = source_model_credential[k]
|
||||
return credential, model_credential
|
||||
return credential, model_credential, provider_handler
|
||||
|
||||
class Create(serializers.Serializer):
|
||||
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
@ -145,13 +145,10 @@ class ModelSerializer(serializers.Serializer):
|
||||
name=self.data.get('name')).exists():
|
||||
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
||||
# 校验模型认证数据
|
||||
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'),
|
||||
self.data.get(
|
||||
'model_name')).is_valid(
|
||||
self.data.get('model_type'),
|
||||
self.data.get('model_name'),
|
||||
self.data.get('credential'),
|
||||
raise_exception=True)
|
||||
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
|
||||
self.data.get('model_name'),
|
||||
self.data.get('credential')
|
||||
)
|
||||
|
||||
def insert(self, user_id, with_valid=False):
|
||||
status = Status.SUCCESS
|
||||
@ -232,16 +229,17 @@ class ModelSerializer(serializers.Serializer):
|
||||
if model is None:
|
||||
raise AppApiException(500, '不存在的id')
|
||||
else:
|
||||
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
|
||||
credential, model_credential, provider_handler = ModelSerializer.Edit(
|
||||
data={**instance, 'user_id': user_id}).is_valid(
|
||||
model=model)
|
||||
try:
|
||||
model.status = Status.SUCCESS
|
||||
# 校验模型认证数据
|
||||
model_credential.is_valid(
|
||||
model.model_type,
|
||||
instance.get("model_name"),
|
||||
credential,
|
||||
raise_exception=True)
|
||||
provider_handler.is_valid_credential(model.model_type,
|
||||
instance.get("model_name"),
|
||||
credential,
|
||||
raise_exception=True)
|
||||
|
||||
except AppApiException as e:
|
||||
if e.code == ValidCode.model_not_fount:
|
||||
model.status = Status.DOWNLOAD
|
||||
|
||||
Loading…
Reference in New Issue
Block a user