refactor: 优化模型管理 (#749)
This commit is contained in:
parent
1452df7f1c
commit
410e065b52
@ -9,9 +9,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict, Iterator
|
from typing import Dict, Iterator, Type, List
|
||||||
|
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from pydantic.v1 import BaseModel
|
||||||
|
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
|
|
||||||
@ -47,39 +47,53 @@ class DownModelChunk:
|
|||||||
|
|
||||||
|
|
||||||
class IModelProvider(ABC):
|
class IModelProvider(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_model_info_manage(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model_type_list(self):
|
def get_model_type_list(self):
|
||||||
pass
|
return self.get_model_info_manage().get_model_type_list()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model_list(self, model_type):
|
def get_model_list(self, model_type):
|
||||||
pass
|
if model_type is None:
|
||||||
|
raise AppApiException(500, '模型类型不能为空')
|
||||||
|
return self.get_model_info_manage().get_model_list()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
def get_model_credential(self, model_type, model_name):
|
||||||
pass
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
|
return model_info.model_credential
|
||||||
|
|
||||||
@abstractmethod
|
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
pass
|
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
|
||||||
|
raise_exception=raise_exception)
|
||||||
|
|
||||||
|
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
|
||||||
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
|
return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dialogue_number(self):
|
def get_dialogue_number(self):
|
||||||
pass
|
return 3
|
||||||
|
|
||||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||||
raise AppApiException(500, "当前平台不支持下载模型")
|
raise AppApiException(500, "当前平台不支持下载模型")
|
||||||
|
|
||||||
|
|
||||||
|
class MaxKBBaseModel(ABC):
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseModelCredential(ABC):
|
class BaseModelCredential(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False):
|
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -113,15 +127,18 @@ class BaseModelCredential(ABC):
|
|||||||
|
|
||||||
class ModelTypeConst(Enum):
|
class ModelTypeConst(Enum):
|
||||||
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
||||||
|
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
|
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
|
||||||
|
model_class: Type[MaxKBBaseModel],
|
||||||
**keywords):
|
**keywords):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.desc = desc
|
self.desc = desc
|
||||||
self.model_type = model_type.name
|
self.model_type = model_type.name
|
||||||
self.model_credential = model_credential
|
self.model_credential = model_credential
|
||||||
|
self.model_class = model_class
|
||||||
if keywords is not None:
|
if keywords is not None:
|
||||||
for key in keywords.keys():
|
for key in keywords.keys():
|
||||||
self.__setattr__(key, keywords.get(key))
|
self.__setattr__(key, keywords.get(key))
|
||||||
@ -143,10 +160,66 @@ class ModelInfo:
|
|||||||
def get_model_type(self):
|
def get_model_type(self):
|
||||||
return self.model_type
|
return self.model_type
|
||||||
|
|
||||||
|
def get_model_class(self):
|
||||||
|
return self.model_class
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return reduce(lambda x, y: {**x, **y},
|
return reduce(lambda x, y: {**x, **y},
|
||||||
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
|
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
|
||||||
not attr.startswith("__") and not attr == 'model_credential'], {})
|
not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfoManage:
|
||||||
|
def __init__(self):
|
||||||
|
self.model_dict = {}
|
||||||
|
self.model_list = []
|
||||||
|
self.default_model_list = []
|
||||||
|
self.default_model_dict = {}
|
||||||
|
|
||||||
|
def append_model_info(self, model_info: ModelInfo):
|
||||||
|
self.model_list.append(model_info)
|
||||||
|
model_type_dict = self.model_dict.get(model_info.model_type)
|
||||||
|
if model_type_dict is None:
|
||||||
|
self.model_dict[model_info.model_type] = {model_info.name: model_info}
|
||||||
|
else:
|
||||||
|
model_type_dict[model_info.name] = model_info
|
||||||
|
|
||||||
|
def append_default_model_info(self, model_info: ModelInfo):
|
||||||
|
self.default_model_list.append(model_info)
|
||||||
|
self.default_model_dict[model_info.model_type] = model_info
|
||||||
|
|
||||||
|
def get_model_list(self):
|
||||||
|
return [model.to_dict() for model in self.model_list]
|
||||||
|
|
||||||
|
def get_model_type_list(self):
|
||||||
|
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
|
||||||
|
len([model for model in self.model_list if model.model_type == _type.name]) > 0]
|
||||||
|
|
||||||
|
def get_model_info(self, model_type, model_name) -> ModelInfo:
|
||||||
|
model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
|
||||||
|
if model_info is None:
|
||||||
|
raise AppApiException(500, '模型不支持')
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
class builder:
|
||||||
|
def __init__(self):
|
||||||
|
self.modelInfoManage = ModelInfoManage()
|
||||||
|
|
||||||
|
def append_model_info(self, model_info: ModelInfo):
|
||||||
|
self.modelInfoManage.append_model_info(model_info)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def append_model_info_list(self, model_info_list: List[ModelInfo]):
|
||||||
|
for model_info in model_info_list:
|
||||||
|
self.modelInfoManage.append_model_info(model_info)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def append_default_model_info(self, model_info: ModelInfo):
|
||||||
|
self.modelInfoManage.append_default_model_info(model_info)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
return self.modelInfoManage
|
||||||
|
|
||||||
|
|
||||||
class ModelProvideInfo:
|
class ModelProvideInfo:
|
||||||
|
|||||||
@ -9,15 +9,15 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
||||||
|
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
||||||
|
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
||||||
|
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
|
||||||
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
||||||
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
||||||
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
||||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||||
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
|
|
||||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||||
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
|
||||||
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProvideConstants(Enum):
|
class ModelProvideConstants(Enum):
|
||||||
|
|||||||
@ -7,98 +7,30 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||||
ModelInfo, \
|
ModelTypeConst, ModelInfoManage
|
||||||
ModelTypeConst, ValidCode
|
from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
|
||||||
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
|
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
base_azure_llm_model_credential = AzureLLMModelCredential()
|
||||||
|
|
||||||
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
|
default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
|
||||||
|
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
|
||||||
|
)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
model_info_manage = ModelInfoManage.builder().append_default_model_info(default_model_info).append_model_info(
|
||||||
model_type_list = AzureModelProvider().get_model_type_list()
|
default_model_info).build()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_version = forms.TextInputField("API 版本 (api_version)", required=True)
|
|
||||||
|
|
||||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
|
||||||
|
|
||||||
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
|
|
||||||
|
|
||||||
|
|
||||||
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
|
|
||||||
|
|
||||||
model_dict = {
|
|
||||||
'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
|
|
||||||
base_azure_llm_model_credential, api_version='2024-02-15-preview'
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AzureModelProvider(IModelProvider):
|
class AzureModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
|
|
||||||
azure_chat_open_ai = AzureChatModel(
|
|
||||||
azure_endpoint=model_credential.get('api_base'),
|
|
||||||
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
|
|
||||||
deployment_name=model_credential.get('deployment_name'),
|
|
||||||
openai_api_key=model_credential.get('api_key'),
|
|
||||||
openai_api_type="azure"
|
|
||||||
)
|
|
||||||
return azure_chat_open_ai
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return base_azure_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
|
||||||
'azure_icon_svg')))
|
'azure_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -0,0 +1,55 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 17:08
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_version = forms.TextInputField("API 版本 (api_version)", required=True)
|
||||||
|
|
||||||
|
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||||
|
|
||||||
|
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
|
||||||
@ -6,15 +6,26 @@
|
|||||||
@date:2024/4/28 11:45
|
@date:2024/4/28 11:45
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_openai import AzureChatOpenAI
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class AzureChatModel(AzureChatOpenAI):
|
class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
return AzureChatModel(
|
||||||
|
azure_endpoint=model_credential.get('api_base'),
|
||||||
|
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
|
||||||
|
deployment_name=model_credential.get('deployment_name'),
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
openai_api_type="azure"
|
||||||
|
)
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
try:
|
try:
|
||||||
return super().get_num_tokens_from_messages(messages)
|
return super().get_num_tokens_from_messages(messages)
|
||||||
|
|||||||
@ -0,0 +1,48 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 17:51
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -7,91 +7,35 @@
|
|||||||
@Date :5/12/24 7:40 AM
|
@Date :5/12/24 7:40 AM
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
ModelInfo, ModelTypeConst, ValidCode
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel
|
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
|
||||||
|
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = DeepSeekModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
|
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
|
||||||
|
|
||||||
model_dict = {
|
deepseek_chat = ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
|
||||||
'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
|
deepseek_llm_model_credential, DeepSeekChatModel
|
||||||
deepseek_llm_model_credential,
|
)
|
||||||
),
|
|
||||||
'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
|
deepseek_coder = ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
|
||||||
deepseek_llm_model_credential,
|
deepseek_llm_model_credential,
|
||||||
),
|
DeepSeekChatModel)
|
||||||
}
|
|
||||||
|
model_info_manage = ModelInfoManage.builder().append_model_info(deepseek_chat).append_model_info(
|
||||||
|
deepseek_coder).append_default_model_info(
|
||||||
|
deepseek_coder).build()
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekModelProvider(IModelProvider):
|
class DeepSeekModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel:
|
|
||||||
deepseek_chat_open_ai = DeepSeekChatModel(
|
|
||||||
model=model_name,
|
|
||||||
openai_api_base='https://api.deepseek.com',
|
|
||||||
openai_api_key=model_credential.get('api_key')
|
|
||||||
)
|
|
||||||
return deepseek_chat_open_ai
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return deepseek_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
|
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
|
||||||
'deepseek_icon_svg')))
|
'deepseek_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -1,30 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: UTF-8 -*-
|
|
||||||
"""
|
|
||||||
@Project :MaxKB
|
|
||||||
@File :deepseek_chat_model.py
|
|
||||||
@Author :Brian Yang
|
|
||||||
@Date :5/12/24 7:44 AM
|
|
||||||
"""
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekChatModel(ChatOpenAI):
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens_from_messages(messages)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens(text)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
@ -0,0 +1,34 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :MaxKB
|
||||||
|
@File :llm.py
|
||||||
|
@Author :Brian Yang
|
||||||
|
@Date :5/12/24 7:44 AM
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
deepseek_chat_open_ai = DeepSeekChatModel(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base='https://api.deepseek.com',
|
||||||
|
openai_api_key=model_credential.get('api_key')
|
||||||
|
)
|
||||||
|
return deepseek_chat_open_ai
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 17:57
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -7,93 +7,36 @@
|
|||||||
@Date :5/13/24 7:47 AM
|
@Date :5/13/24 7:47 AM
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
ModelInfo, ModelTypeConst, ValidCode
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel
|
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
|
||||||
|
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = GeminiModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = GeminiModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
gemini_llm_model_credential = GeminiLLMModelCredential()
|
gemini_llm_model_credential = GeminiLLMModelCredential()
|
||||||
|
|
||||||
model_dict = {
|
gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
|
||||||
'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
|
ModelTypeConst.LLM,
|
||||||
|
gemini_llm_model_credential,
|
||||||
|
GeminiChatModel)
|
||||||
|
|
||||||
|
gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
|
||||||
ModelTypeConst.LLM,
|
ModelTypeConst.LLM,
|
||||||
gemini_llm_model_credential,
|
gemini_llm_model_credential,
|
||||||
),
|
GeminiChatModel)
|
||||||
'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
|
|
||||||
ModelTypeConst.LLM,
|
model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info(
|
||||||
gemini_llm_model_credential,
|
gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build()
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiModelProvider(IModelProvider):
|
class GeminiModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
|
|
||||||
**model_kwargs) -> GeminiChatModel:
|
|
||||||
gemini_chat = GeminiChatModel(
|
|
||||||
model=model_name,
|
|
||||||
google_api_key=model_credential.get('api_key')
|
|
||||||
)
|
|
||||||
return gemini_chat
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return gemini_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
|
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
|
||||||
'gemini_icon_svg')))
|
'gemini_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -1,30 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: UTF-8 -*-
|
|
||||||
"""
|
|
||||||
@Project :MaxKB
|
|
||||||
@File :gemini_chat_model.py
|
|
||||||
@Author :Brian Yang
|
|
||||||
@Date :5/13/24 7:40 AM
|
|
||||||
"""
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiChatModel(ChatGoogleGenerativeAI):
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens_from_messages(messages)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens(text)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :MaxKB
|
||||||
|
@File :llm.py
|
||||||
|
@Author :Brian Yang
|
||||||
|
@Date :5/13/24 7:40 AM
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
gemini_chat = GeminiChatModel(
|
||||||
|
model=model_name,
|
||||||
|
google_api_key=model_credential.get('api_key')
|
||||||
|
)
|
||||||
|
return gemini_chat
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:06
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_base', 'api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -7,103 +7,36 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from langchain.chat_models.base import BaseChatModel
|
|
||||||
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||||
ModelInfo, \
|
ModelTypeConst, ModelInfoManage
|
||||||
ModelTypeConst, ValidCode
|
from setting.models_provider.impl.kimi_model_provider.credential.llm import KimiLLMModelCredential
|
||||||
|
from setting.models_provider.impl.kimi_model_provider.model.llm import KimiChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
from setting.models_provider.impl.kimi_model_provider.model.kimi_chat_model import KimiChatModel
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = KimiModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['api_base', 'api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
# llm_kimi = Moonshot(
|
|
||||||
# model_name=model_name,
|
|
||||||
# base_url=model_credential['api_base'],
|
|
||||||
# moonshot_api_key=model_credential['api_key']
|
|
||||||
# )
|
|
||||||
|
|
||||||
model = KimiModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_base = forms.TextInputField('API 域名', required=True)
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
kimi_llm_model_credential = KimiLLMModelCredential()
|
kimi_llm_model_credential = KimiLLMModelCredential()
|
||||||
|
|
||||||
model_dict = {
|
moonshot_v1_8k = ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
||||||
'moonshot-v1-8k': ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
KimiChatModel)
|
||||||
),
|
moonshot_v1_32k = ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
||||||
'moonshot-v1-32k': ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
KimiChatModel)
|
||||||
),
|
moonshot_v1_128k = ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
||||||
'moonshot-v1-128k': ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
|
KimiChatModel)
|
||||||
)
|
|
||||||
}
|
model_info_manage = ModelInfoManage.builder().append_model_info(moonshot_v1_8k).append_model_info(
|
||||||
|
moonshot_v1_32k).append_default_model_info(moonshot_v1_128k).append_default_model_info(moonshot_v1_8k).build()
|
||||||
|
|
||||||
|
|
||||||
class KimiModelProvider(IModelProvider):
|
class KimiModelProvider(IModelProvider):
|
||||||
|
|
||||||
|
def get_model_info_manage(self):
|
||||||
|
return model_info_manage
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_dialogue_number(self):
|
||||||
return 3
|
return 3
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
|
|
||||||
kimi_chat_open_ai = KimiChatModel(
|
|
||||||
openai_api_base=model_credential['api_base'],
|
|
||||||
openai_api_key=model_credential['api_key'],
|
|
||||||
model_name=model_name,
|
|
||||||
)
|
|
||||||
return kimi_chat_open_ai
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return kimi_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content(
|
return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon',
|
||||||
'kimi_icon_svg')))
|
'kimi_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -2,19 +2,29 @@
|
|||||||
"""
|
"""
|
||||||
@project: maxkb
|
@project: maxkb
|
||||||
@Author:虎
|
@Author:虎
|
||||||
@file: kimi_chat_model.py
|
@file: llm.py
|
||||||
@date:2023/11/10 17:45
|
@date:2023/11/10 17:45
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class KimiChatModel(ChatOpenAI):
|
class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
kimi_chat_open_ai = KimiChatModel(
|
||||||
|
openai_api_base=model_credential['api_base'],
|
||||||
|
openai_api_key=model_credential['api_key'],
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
return kimi_chat_open_ai
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:19
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
try:
|
||||||
|
model_list = provider.get_base_model_list(model_credential.get('api_base'))
|
||||||
|
except Exception as e:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
|
||||||
|
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
|
||||||
|
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
||||||
|
if len(exist) == 0:
|
||||||
|
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model_info: Dict[str, object]):
|
||||||
|
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
|
||||||
|
|
||||||
|
def build_model(self, model_info: Dict[str, object]):
|
||||||
|
for key in ['api_key', 'model']:
|
||||||
|
if key not in model_info:
|
||||||
|
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||||
|
self.api_key = model_info.get('api_key')
|
||||||
|
return self
|
||||||
|
|
||||||
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/3/6 11:48
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
from urllib.parse import urlparse, ParseResult
|
||||||
|
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_url(url: str):
|
||||||
|
parse = urlparse(url)
|
||||||
|
return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='',
|
||||||
|
query='',
|
||||||
|
fragment='').geturl()
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
api_base = model_credential.get('api_base', '')
|
||||||
|
base_url = get_base_url(api_base)
|
||||||
|
return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'),
|
||||||
|
openai_api_key=model_credential.get('api_key'))
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
@ -1,24 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: maxkb
|
|
||||||
@Author:虎
|
|
||||||
@file: ollama_chat_model.py
|
|
||||||
@date:2024/3/6 11:48
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaChatModel(ChatOpenAI):
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
@ -19,106 +19,83 @@ from common.exception.app_exception import AppApiException
|
|||||||
from common.forms import BaseForm
|
from common.forms import BaseForm
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
|
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
||||||
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
|
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||||
|
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
""
|
""
|
||||||
|
|
||||||
|
|
||||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = OllamaModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
try:
|
|
||||||
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
|
|
||||||
except Exception as e:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
|
|
||||||
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
|
|
||||||
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
|
||||||
if len(exist) == 0:
|
|
||||||
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model_info: Dict[str, object]):
|
|
||||||
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
|
|
||||||
|
|
||||||
def build_model(self, model_info: Dict[str, object]):
|
|
||||||
for key in ['api_key', 'model']:
|
|
||||||
if key not in model_info:
|
|
||||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
|
||||||
self.api_key = model_info.get('api_key')
|
|
||||||
return self
|
|
||||||
|
|
||||||
api_base = forms.TextInputField('API 域名', required=True)
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
ollama_llm_model_credential = OllamaLLMModelCredential()
|
ollama_llm_model_credential = OllamaLLMModelCredential()
|
||||||
|
model_info_list = [
|
||||||
model_dict = {
|
ModelInfo(
|
||||||
'llama2': ModelInfo(
|
|
||||||
'llama2',
|
'llama2',
|
||||||
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'llama2:13b': ModelInfo(
|
ModelInfo(
|
||||||
'llama2:13b',
|
'llama2:13b',
|
||||||
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'llama2:70b': ModelInfo(
|
ModelInfo(
|
||||||
'llama2:70b',
|
'llama2:70b',
|
||||||
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'llama2-chinese:13b': ModelInfo(
|
ModelInfo(
|
||||||
'llama2-chinese:13b',
|
'llama2-chinese:13b',
|
||||||
'由于Llama2本身的中文对齐较弱,我们采用中文指令集,对meta-llama/Llama-2-13b-chat-hf进行LoRA微调,使其具备较强的中文对话能力。',
|
'由于Llama2本身的中文对齐较弱,我们采用中文指令集,对meta-llama/Llama-2-13b-chat-hf进行LoRA微调,使其具备较强的中文对话能力。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'llama3:8b': ModelInfo(
|
ModelInfo(
|
||||||
'llama3:8b',
|
'llama3:8b',
|
||||||
'Meta Llama 3:迄今为止最有能力的公开产品LLM。8亿参数。',
|
'Meta Llama 3:迄今为止最有能力的公开产品LLM。80亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'llama3:70b': ModelInfo(
|
ModelInfo(
|
||||||
'llama3:70b',
|
'llama3:70b',
|
||||||
'Meta Llama 3:迄今为止最有能力的公开产品LLM。70亿参数。',
|
'Meta Llama 3:迄今为止最有能力的公开产品LLM。700亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:0.5b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:0.5b',
|
'qwen:0.5b',
|
||||||
'qwen 1.5 0.5b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。0.5亿参数。',
|
'qwen 1.5 0.5b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。5亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:1.8b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:1.8b',
|
'qwen:1.8b',
|
||||||
'qwen 1.5 1.8b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1.8亿参数。',
|
'qwen 1.5 1.8b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。18亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:4b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:4b',
|
'qwen:4b',
|
||||||
'qwen 1.5 4b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。4亿参数。',
|
'qwen 1.5 4b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。40亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:7b': ModelInfo(
|
|
||||||
|
ModelInfo(
|
||||||
'qwen:7b',
|
'qwen:7b',
|
||||||
'qwen 1.5 7b 相较于以往版本,模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。7亿参数。',
|
'qwen 1.5 7b 相较于以往版本,模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。70亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:14b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:14b',
|
'qwen:14b',
|
||||||
'qwen 1.5 14b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。14亿参数。',
|
'qwen 1.5 14b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。140亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:32b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:32b',
|
'qwen:32b',
|
||||||
'qwen 1.5 32b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。32亿参数。',
|
'qwen 1.5 32b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。320亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:72b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:72b',
|
'qwen:72b',
|
||||||
'qwen 1.5 72b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。72亿参数。',
|
'qwen 1.5 72b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。720亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'qwen:110b': ModelInfo(
|
ModelInfo(
|
||||||
'qwen:110b',
|
'qwen:110b',
|
||||||
'qwen 1.5 110b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。110亿参数。',
|
'qwen 1.5 110b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1100亿参数。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
'phi3': ModelInfo(
|
ModelInfo(
|
||||||
'phi3',
|
'phi3',
|
||||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential),
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
|
||||||
}
|
]
|
||||||
|
|
||||||
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
|
ModelInfo(
|
||||||
|
'phi3',
|
||||||
|
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||||
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
|
||||||
|
|
||||||
|
|
||||||
def get_base_url(url: str):
|
def get_base_url(url: str):
|
||||||
@ -169,32 +146,14 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
|
|||||||
|
|
||||||
|
|
||||||
class OllamaModelProvider(IModelProvider):
|
class OllamaModelProvider(IModelProvider):
|
||||||
|
def get_model_info_manage(self):
|
||||||
|
return model_info_manage
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
|
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
||||||
'ollama_icon_svg')))
|
'ollama_icon_svg')))
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|
||||||
def get_model_list(self, model_type):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
# 如果使用模型不在配置中,则使用默认认证
|
|
||||||
return ollama_llm_model_credential
|
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
|
|
||||||
api_base = model_credential.get('api_base')
|
|
||||||
base_url = get_base_url(api_base)
|
|
||||||
return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'),
|
|
||||||
openai_api_key=model_credential.get('api_key'))
|
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_dialogue_number(self):
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,49 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:32
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_base', 'api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -0,0 +1,34 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/4/18 15:28
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
azure_chat_open_ai = OpenAIChatModel(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base=model_credential.get('api_base'),
|
||||||
|
openai_api_key=model_credential.get('api_key')
|
||||||
|
)
|
||||||
|
return azure_chat_open_ai
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
@ -1,30 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: maxkb
|
|
||||||
@Author:虎
|
|
||||||
@file: openai_chat_model.py
|
|
||||||
@date:2024/4/18 15:28
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatModel(ChatOpenAI):
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens_from_messages(messages)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
try:
|
|
||||||
return super().get_num_tokens(text)
|
|
||||||
except Exception as e:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
@ -7,127 +7,70 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||||
ModelInfo, \
|
ModelTypeConst, ModelInfoManage
|
||||||
ModelTypeConst, ValidCode
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel
|
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = OpenAIModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['api_base', 'api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = OpenAIModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_base = forms.TextInputField('API 域名', required=True)
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
openai_llm_model_credential = OpenAILLMModelCredential()
|
openai_llm_model_credential = OpenAILLMModelCredential()
|
||||||
|
model_info_list = [
|
||||||
|
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
|
),
|
||||||
|
ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-3.5-turbo-0125',
|
||||||
|
'2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-3.5-turbo-1106',
|
||||||
|
'2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-3.5-turbo-0613',
|
||||||
|
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4o-2024-05-13',
|
||||||
|
'2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4-turbo-2024-04-09',
|
||||||
|
'2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel),
|
||||||
|
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
OpenAIChatModel)
|
||||||
|
]
|
||||||
|
|
||||||
model_dict = {
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
'gpt-3.5-turbo': ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
openai_llm_model_credential,
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
),
|
)).build()
|
||||||
'gpt-4': ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4o': ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
|
||||||
openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-3.5-turbo-0125': ModelInfo('gpt-3.5-turbo-0125',
|
|
||||||
'2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
|
||||||
openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-3.5-turbo-1106': ModelInfo('gpt-3.5-turbo-1106',
|
|
||||||
'2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
|
|
||||||
openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613',
|
|
||||||
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4o-2024-05-13': ModelInfo('gpt-4o-2024-05-13',
|
|
||||||
'2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4-turbo-2024-04-09': ModelInfo('gpt-4-turbo-2024-04-09',
|
|
||||||
'2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4-0125-preview': ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
'gpt-4-1106-preview': ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(IModelProvider):
|
class OpenAIModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel:
|
|
||||||
azure_chat_open_ai = OpenAIChatModel(
|
|
||||||
model=model_name,
|
|
||||||
openai_api_base=model_credential.get('api_base'),
|
|
||||||
openai_api_key=model_credential.get('api_key')
|
|
||||||
)
|
|
||||||
return azure_chat_open_ai
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return openai_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content(
|
return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon',
|
||||||
'openai_icon_svg')))
|
'openai_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:41
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -2,19 +2,28 @@
|
|||||||
"""
|
"""
|
||||||
@project: maxkb
|
@project: maxkb
|
||||||
@Author:虎
|
@Author:虎
|
||||||
@file: qwen_chat_model.py
|
@file: llm.py
|
||||||
@date:2024/4/28 11:44
|
@date:2024/4/28 11:44
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatTongyi
|
from langchain_community.chat_models import ChatTongyi
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class QwenChatModel(ChatTongyi):
|
class QwenChatModel(MaxKBBaseModel, ChatTongyi):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
chat_tong_yi = QwenChatModel(
|
||||||
|
model_name=model_name,
|
||||||
|
dashscope_api_key=model_credential.get('api_key')
|
||||||
|
)
|
||||||
|
return chat_tong_yi
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
@ -7,87 +7,33 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfo, IModelProvider, ValidCode
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
|
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
|
||||||
|
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = QwenModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
for key in ['api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = QwenModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
qwen_model_credential = OpenAILLMModelCredential()
|
qwen_model_credential = OpenAILLMModelCredential()
|
||||||
|
|
||||||
model_dict = {
|
module_info_list = [
|
||||||
'qwen-turbo': ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential),
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||||
'qwen-plus': ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential),
|
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||||
'qwen-max': ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential)
|
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
|
||||||
}
|
]
|
||||||
|
|
||||||
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
|
||||||
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
|
||||||
|
|
||||||
|
|
||||||
class QwenModelProvider(IModelProvider):
|
class QwenModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
|
|
||||||
chat_tong_yi = QwenChatModel(
|
|
||||||
model_name=model_name,
|
|
||||||
dashscope_api_key=model_credential.get('api_key')
|
|
||||||
)
|
|
||||||
return chat_tong_yi
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return qwen_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content(
|
return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon',
|
||||||
'qwen_icon_svg')))
|
'qwen_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -0,0 +1,55 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/12 10:19
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model_info = [model.lower() for model in model.client.models()]
|
||||||
|
if not model_info.__contains__(model_name.lower()):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
||||||
|
for key in ['api_key', 'secret_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model.invoke(
|
||||||
|
[HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model_info: Dict[str, object]):
|
||||||
|
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||||
|
|
||||||
|
def build_model(self, model_info: Dict[str, object]):
|
||||||
|
for key in ['api_key', 'secret_key', 'model']:
|
||||||
|
if key not in model_info:
|
||||||
|
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||||
|
self.api_key = model_info.get('api_key')
|
||||||
|
self.secret_key = model_info.get('secret_key')
|
||||||
|
return self
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2023/11/10 17:45
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_community.chat_models import QianfanChatEndpoint
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
return QianfanChatModel(model=model_name,
|
||||||
|
qianfan_ak=model_credential.get('api_key'),
|
||||||
|
qianfan_sk=model_credential.get('secret_key'),
|
||||||
|
streaming=model_kwargs.get('streaming', False))
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
@ -1,32 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: maxkb
|
|
||||||
@Author:虎
|
|
||||||
@file: qian_fan_chat_model.py
|
|
||||||
@date:2023/11/10 17:45
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from typing import Optional, List, Any, Iterator, cast
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
from langchain.chat_models.base import BaseChatModel
|
|
||||||
from langchain.load import dumpd
|
|
||||||
from langchain.schema import LLMResult
|
|
||||||
from langchain.schema.language_model import LanguageModelInput
|
|
||||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
|
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
|
||||||
from langchain.schema.runnable import RunnableConfig
|
|
||||||
from langchain_community.chat_models import QianfanChatEndpoint
|
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
|
||||||
|
|
||||||
|
|
||||||
class QianfanChatModel(QianfanChatEndpoint):
|
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
@ -7,121 +7,53 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from langchain_community.chat_models import QianfanChatEndpoint
|
|
||||||
from qianfan import ChatCompletion
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfo, IModelProvider, ValidCode
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel
|
from setting.models_provider.impl.wenxin_model_provider.credential.llm import WenxinLLMModelCredential
|
||||||
|
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = WenxinModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
model = WenxinModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model_info = [model.lower() for model in model.client.models()]
|
|
||||||
if not model_info.__contains__(model_name.lower()):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
|
||||||
for key in ['api_key', 'secret_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model.invoke(
|
|
||||||
[HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model_info: Dict[str, object]):
|
|
||||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
|
||||||
|
|
||||||
def build_model(self, model_info: Dict[str, object]):
|
|
||||||
for key in ['api_key', 'secret_key', 'model']:
|
|
||||||
if key not in model_info:
|
|
||||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
|
||||||
self.api_key = model_info.get('api_key')
|
|
||||||
self.secret_key = model_info.get('secret_key')
|
|
||||||
return self
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
|
||||||
|
|
||||||
|
|
||||||
win_xin_llm_model_credential = WenxinLLMModelCredential()
|
win_xin_llm_model_credential = WenxinLLMModelCredential()
|
||||||
model_dict = {
|
model_info_list = [ModelInfo('ERNIE-Bot-4',
|
||||||
'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4',
|
|
||||||
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('ERNIE-Bot',
|
||||||
|
'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('ERNIE-Bot-turbo',
|
||||||
|
'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('BLOOMZ-7B',
|
||||||
|
'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('Llama-2-7b-chat',
|
||||||
|
'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('Llama-2-13b-chat',
|
||||||
|
'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('Llama-2-70b-chat',
|
||||||
|
'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
|
||||||
|
ModelInfo('Qianfan-Chinese-Llama-2-7B',
|
||||||
|
'千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。',
|
||||||
|
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel)
|
||||||
|
]
|
||||||
|
|
||||||
'ERNIE-Bot': ModelInfo('ERNIE-Bot',
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
ModelInfo('ERNIE-Bot-4',
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。',
|
||||||
|
ModelTypeConst.LLM,
|
||||||
'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo',
|
win_xin_llm_model_credential,
|
||||||
'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。',
|
QianfanChatModel)).build()
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
|
||||||
|
|
||||||
'BLOOMZ-7B': ModelInfo('BLOOMZ-7B',
|
|
||||||
'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。',
|
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
|
||||||
|
|
||||||
'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat',
|
|
||||||
'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。',
|
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
|
||||||
|
|
||||||
'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat',
|
|
||||||
'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。',
|
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
|
||||||
|
|
||||||
'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat',
|
|
||||||
'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。',
|
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential),
|
|
||||||
|
|
||||||
'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B',
|
|
||||||
'千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。',
|
|
||||||
ModelTypeConst.LLM, win_xin_llm_model_credential)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WenxinModelProvider(IModelProvider):
|
class WenxinModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 2
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
|
|
||||||
**model_kwargs) -> QianfanChatEndpoint:
|
|
||||||
return QianfanChatModel(model=model_name,
|
|
||||||
qianfan_ak=model_credential.get('api_key'),
|
|
||||||
qianfan_sk=model_credential.get('secret_key'),
|
|
||||||
streaming=model_kwargs.get('streaming', False))
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|
||||||
def get_model_list(self, model_type):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return win_xin_llm_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
|
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
|
||||||
|
|||||||
@ -0,0 +1,51 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/12 10:29
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||||
|
|
||||||
|
spark_api_url = forms.TextInputField('API 域名', required=True)
|
||||||
|
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||||
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
@ -7,7 +7,7 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Any, Iterator
|
from typing import List, Optional, Any, Iterator, Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatSparkLLM
|
from langchain_community.chat_models import ChatSparkLLM
|
||||||
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
|
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
|
||||||
@ -16,9 +16,21 @@ from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_stri
|
|||||||
from langchain_core.outputs import ChatGenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class XFChatSparkLLM(ChatSparkLLM):
|
class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
return XFChatSparkLLM(
|
||||||
|
spark_app_id=model_credential.get('spark_app_id'),
|
||||||
|
spark_api_key=model_credential.get('spark_api_key'),
|
||||||
|
spark_api_secret=model_credential.get('spark_api_secret'),
|
||||||
|
spark_api_url=model_credential.get('spark_api_url'),
|
||||||
|
spark_llm_domain=model_name
|
||||||
|
)
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
@ -7,97 +7,33 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from langchain_community.chat_models import ChatSparkLLM
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
|
||||||
ModelInfo, IModelProvider, ValidCode
|
|
||||||
from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM
|
|
||||||
from smartdoc.conf import PROJECT_DIR
|
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
|
from common.util.file_util import get_file_content
|
||||||
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
|
ModelInfoManage
|
||||||
|
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
ssl._create_default_https_context = ssl.create_default_context()
|
ssl._create_default_https_context = ssl.create_default_context()
|
||||||
|
|
||||||
|
|
||||||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = XunFeiModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
|
|
||||||
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = XunFeiModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
|
||||||
|
|
||||||
spark_api_url = forms.TextInputField('API 域名', required=True)
|
|
||||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
|
||||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
|
||||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
qwen_model_credential = XunFeiLLMModelCredential()
|
qwen_model_credential = XunFeiLLMModelCredential()
|
||||||
|
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
|
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
|
||||||
|
]
|
||||||
|
|
||||||
model_dict = {
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential),
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
|
||||||
'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential),
|
|
||||||
'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class XunFeiModelProvider(IModelProvider):
|
class XunFeiModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM:
|
|
||||||
zhipuai_chat = XFChatSparkLLM(
|
|
||||||
spark_app_id=model_credential.get('spark_app_id'),
|
|
||||||
spark_api_key=model_credential.get('spark_api_key'),
|
|
||||||
spark_api_secret=model_credential.get('spark_api_secret'),
|
|
||||||
spark_api_url=model_credential.get('spark_api_url'),
|
|
||||||
spark_llm_domain=model_name
|
|
||||||
)
|
|
||||||
return zhipuai_chat
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return qwen_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
|
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
|
||||||
'xf_icon_svg')))
|
'xf_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/12 10:46
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -2,19 +2,29 @@
|
|||||||
"""
|
"""
|
||||||
@project: maxkb
|
@project: maxkb
|
||||||
@Author:虎
|
@Author:虎
|
||||||
@file: zhipu_chat_model.py
|
@file: llm.py
|
||||||
@date:2024/4/28 11:42
|
@date:2024/4/28 11:42
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatZhipuAI
|
from langchain_community.chat_models import ChatZhipuAI
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class ZhipuChatModel(ChatZhipuAI):
|
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
zhipuai_chat = ZhipuChatModel(
|
||||||
|
temperature=0.5,
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
model=model_name
|
||||||
|
)
|
||||||
|
return zhipuai_chat
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
@ -7,88 +7,30 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from langchain_community.chat_models import ChatZhipuAI
|
|
||||||
|
|
||||||
from common import forms
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.forms import BaseForm
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfo, IModelProvider, ValidCode
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
|
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
|
||||||
|
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
|
||||||
model_type_list = ZhiPuModelProvider().get_model_type_list()
|
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
for key in ['api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential)
|
|
||||||
model.invoke([HumanMessage(content='你好')])
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, AppApiException):
|
|
||||||
raise e
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]):
|
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
|
||||||
|
|
||||||
|
|
||||||
qwen_model_credential = ZhiPuLLMModelCredential()
|
qwen_model_credential = ZhiPuLLMModelCredential()
|
||||||
|
model_info_list = [
|
||||||
model_dict = {
|
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||||
'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential),
|
ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||||
'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential),
|
ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)
|
||||||
'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential)
|
]
|
||||||
}
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info(
|
||||||
|
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)).build()
|
||||||
|
|
||||||
|
|
||||||
class ZhiPuModelProvider(IModelProvider):
|
class ZhiPuModelProvider(IModelProvider):
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_model_info_manage(self):
|
||||||
return 3
|
return model_info_manage
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
|
|
||||||
zhipuai_chat = ZhipuChatModel(
|
|
||||||
temperature=0.5,
|
|
||||||
api_key=model_credential.get('api_key'),
|
|
||||||
model=model_name
|
|
||||||
)
|
|
||||||
return zhipuai_chat
|
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
|
||||||
if model_name in model_dict:
|
|
||||||
return model_dict.get(model_name).model_credential
|
|
||||||
return qwen_model_credential
|
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
|
return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
|
||||||
'zhipuai_icon_svg')))
|
'zhipuai_icon_svg')))
|
||||||
|
|
||||||
def get_model_list(self, model_type: str):
|
|
||||||
if model_type is None:
|
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
|
||||||
return [model_dict.get(key).to_dict() for key in
|
|
||||||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
|
|
||||||
|
|
||||||
def get_model_type_list(self):
|
|
||||||
return [{'key': "大语言模型", 'value': "LLM"}]
|
|
||||||
|
|||||||
@ -115,7 +115,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
model_name = self.data.get(
|
model_name = self.data.get(
|
||||||
'model_name')
|
'model_name')
|
||||||
credential = self.data.get('credential')
|
credential = self.data.get('credential')
|
||||||
|
provider_handler = ModelProvideConstants[provider].value
|
||||||
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
||||||
model_name)
|
model_name)
|
||||||
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
||||||
@ -124,7 +124,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
for k in source_encryption_model_credential.keys():
|
for k in source_encryption_model_credential.keys():
|
||||||
if credential[k] == source_encryption_model_credential[k]:
|
if credential[k] == source_encryption_model_credential[k]:
|
||||||
credential[k] = source_model_credential[k]
|
credential[k] = source_model_credential[k]
|
||||||
return credential, model_credential
|
return credential, model_credential, provider_handler
|
||||||
|
|
||||||
class Create(serializers.Serializer):
|
class Create(serializers.Serializer):
|
||||||
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
@ -145,13 +145,10 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
name=self.data.get('name')).exists():
|
name=self.data.get('name')).exists():
|
||||||
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
||||||
# 校验模型认证数据
|
# 校验模型认证数据
|
||||||
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'),
|
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
|
||||||
self.data.get(
|
self.data.get('model_name'),
|
||||||
'model_name')).is_valid(
|
self.data.get('credential')
|
||||||
self.data.get('model_type'),
|
)
|
||||||
self.data.get('model_name'),
|
|
||||||
self.data.get('credential'),
|
|
||||||
raise_exception=True)
|
|
||||||
|
|
||||||
def insert(self, user_id, with_valid=False):
|
def insert(self, user_id, with_valid=False):
|
||||||
status = Status.SUCCESS
|
status = Status.SUCCESS
|
||||||
@ -232,16 +229,17 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
if model is None:
|
if model is None:
|
||||||
raise AppApiException(500, '不存在的id')
|
raise AppApiException(500, '不存在的id')
|
||||||
else:
|
else:
|
||||||
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
|
credential, model_credential, provider_handler = ModelSerializer.Edit(
|
||||||
|
data={**instance, 'user_id': user_id}).is_valid(
|
||||||
model=model)
|
model=model)
|
||||||
try:
|
try:
|
||||||
model.status = Status.SUCCESS
|
model.status = Status.SUCCESS
|
||||||
# 校验模型认证数据
|
# 校验模型认证数据
|
||||||
model_credential.is_valid(
|
provider_handler.is_valid_credential(model.model_type,
|
||||||
model.model_type,
|
instance.get("model_name"),
|
||||||
instance.get("model_name"),
|
credential,
|
||||||
credential,
|
raise_exception=True)
|
||||||
raise_exception=True)
|
|
||||||
except AppApiException as e:
|
except AppApiException as e:
|
||||||
if e.code == ValidCode.model_not_fount:
|
if e.code == ValidCode.model_not_fount:
|
||||||
model.status = Status.DOWNLOAD
|
model.status = Status.DOWNLOAD
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user