refactor: 优化模型管理 (#749)

This commit is contained in:
shaohuzhang1 2024-07-12 14:15:42 +08:00 committed by GitHub
parent 1452df7f1c
commit 410e065b52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 1082 additions and 1028 deletions

View File

@ -9,9 +9,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import Dict, Iterator from typing import Dict, Iterator, Type, List
from langchain.chat_models.base import BaseChatModel from pydantic.v1 import BaseModel
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
@ -47,39 +47,53 @@ class DownModelChunk:
class IModelProvider(ABC): class IModelProvider(ABC):
@abstractmethod
def get_model_info_manage(self):
pass
@abstractmethod @abstractmethod
def get_model_provide_info(self): def get_model_provide_info(self):
pass pass
@abstractmethod
def get_model_type_list(self): def get_model_type_list(self):
pass return self.get_model_info_manage().get_model_type_list()
@abstractmethod
def get_model_list(self, model_type): def get_model_list(self, model_type):
pass if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return self.get_model_info_manage().get_model_list()
@abstractmethod
def get_model_credential(self, model_type, model_name): def get_model_credential(self, model_type, model_name):
pass model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential
@abstractmethod def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel: model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
pass return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
raise_exception=raise_exception)
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
@abstractmethod
def get_dialogue_number(self): def get_dialogue_number(self):
pass return 3
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
raise AppApiException(500, "当前平台不支持下载模型") raise AppApiException(500, "当前平台不支持下载模型")
class MaxKBBaseModel(ABC):
    """Abstract base declaring the factory hook that model classes must implement."""

    @staticmethod
    @abstractmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a model object for the given type, name and credential mapping.

        Subclasses must override this static factory; the base implementation
        does nothing and returns ``None``.
        """
        return None
class BaseModelCredential(ABC): class BaseModelCredential(ABC):
@abstractmethod @abstractmethod
def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
pass pass
@abstractmethod @abstractmethod
@ -113,15 +127,18 @@ class BaseModelCredential(ABC):
class ModelTypeConst(Enum): class ModelTypeConst(Enum):
LLM = {'code': 'LLM', 'message': '大语言模型'} LLM = {'code': 'LLM', 'message': '大语言模型'}
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
class ModelInfo: class ModelInfo:
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential, def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
model_class: Type[MaxKBBaseModel],
**keywords): **keywords):
self.name = name self.name = name
self.desc = desc self.desc = desc
self.model_type = model_type.name self.model_type = model_type.name
self.model_credential = model_credential self.model_credential = model_credential
self.model_class = model_class
if keywords is not None: if keywords is not None:
for key in keywords.keys(): for key in keywords.keys():
self.__setattr__(key, keywords.get(key)) self.__setattr__(key, keywords.get(key))
@ -143,10 +160,66 @@ class ModelInfo:
def get_model_type(self): def get_model_type(self):
return self.model_type return self.model_type
def get_model_class(self):
return self.model_class
def to_dict(self): def to_dict(self):
return reduce(lambda x, y: {**x, **y}, return reduce(lambda x, y: {**x, **y},
[{attr: self.__getattribute__(attr)} for attr in vars(self) if [{attr: self.__getattribute__(attr)} for attr in vars(self) if
not attr.startswith("__") and not attr == 'model_credential'], {}) not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})
class ModelInfoManage:
    """Registry of the models a provider offers, indexed by model type and name.

    Instances are normally assembled through the nested fluent ``builder``.
    """

    def __init__(self):
        # model_type -> {model_name -> model info}
        self.model_dict = {}
        # Every registered model info, in registration order.
        self.model_list = []
        # Every registered default model info, in registration order.
        self.default_model_list = []
        # model_type -> fallback model info used when a name is not registered.
        self.default_model_dict = {}

    def append_model_info(self, model_info: 'ModelInfo'):
        """Register ``model_info`` under its model type and name."""
        self.model_list.append(model_info)
        self.model_dict.setdefault(model_info.model_type, {})[model_info.name] = model_info

    def append_default_model_info(self, model_info: 'ModelInfo'):
        """Register ``model_info`` as the fallback for its model type.

        A later default for the same type replaces the earlier one in the lookup
        table; ``default_model_list`` keeps every default ever appended.
        """
        self.default_model_list.append(model_info)
        self.default_model_dict[model_info.model_type] = model_info

    def get_model_list(self, model_type=None):
        """Return the registered models as dicts, optionally filtered by type.

        :param model_type: when given, only models whose ``model_type`` equals
            this value are returned; ``None`` (the default) returns every model,
            which is the original behaviour.
        :return: list of ``to_dict()`` results in registration order.
        """
        return [model.to_dict() for model in self.model_list
                if model_type is None or model.model_type == model_type]

    def get_model_type_list(self):
        """Return ``{'key': message, 'value': code}`` for each type that has a registered model."""
        # Collect the registered types once instead of rescanning the model list per enum member.
        registered_types = {model.model_type for model in self.model_list}
        return [{'key': _type.value.get('message'), 'value': _type.value.get('code')}
                for _type in ModelTypeConst if _type.name in registered_types]

    def get_model_info(self, model_type, model_name) -> 'ModelInfo':
        """Look up a model by type and name, falling back to the type's default.

        :raises AppApiException: when neither the named model nor a default for
            ``model_type`` is registered.
        """
        model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
        if model_info is None:
            raise AppApiException(500, '模型不支持')
        return model_info

    class builder:
        """Fluent helper that fills a ``ModelInfoManage`` and hands it back via ``build()``."""

        def __init__(self):
            self.modelInfoManage = ModelInfoManage()

        def append_model_info(self, model_info: 'ModelInfo'):
            """Register one model and return the builder for chaining."""
            self.modelInfoManage.append_model_info(model_info)
            return self

        def append_model_info_list(self, model_info_list: 'List[ModelInfo]'):
            """Register every model in ``model_info_list`` and return the builder."""
            for model_info in model_info_list:
                self.modelInfoManage.append_model_info(model_info)
            return self

        def append_default_model_info(self, model_info: 'ModelInfo'):
            """Register a per-type default model and return the builder."""
            self.modelInfoManage.append_default_model_info(model_info)
            return self

        def build(self):
            """Return the populated ``ModelInfoManage``."""
            return self.modelInfoManage
class ModelProvideInfo: class ModelProvideInfo:

View File

@ -9,15 +9,15 @@
from enum import Enum from enum import Enum
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
class ModelProvideConstants(Enum): class ModelProvideConstants(Enum):

View File

@ -7,98 +7,30 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelInfo, \ ModelTypeConst, ModelInfoManage
ModelTypeConst, ValidCode from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
base_azure_llm_model_credential = AzureLLMModelCredential()
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential): default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): model_info_manage = ModelInfoManage.builder().append_default_model_info(default_model_info).append_model_info(
model_type_list = AzureModelProvider().get_model_type_list() default_model_info).build()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_version = forms.TextInputField("API 版本 (api_version)", required=True)
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
model_dict = {
'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
base_azure_llm_model_credential, api_version='2024-02-15-preview'
)
}
class AzureModelProvider(IModelProvider): class AzureModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
azure_chat_open_ai = AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure"
)
return azure_chat_open_ai
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return base_azure_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content( return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon',
'azure_icon_svg'))) 'azure_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:08
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI LLM deployments."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential by building the model and sending it a short test message.

        An unsupported ``model_type`` always raises ``AppApiException``. A missing
        required key or a failing test call raises only when ``raise_exception`` is
        true; otherwise ``False`` is returned. Returns ``True`` when the call succeeds.
        """
        matching_types = [entry for entry in provider.get_model_type_list()
                          if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('api_base', 'api_key', 'deployment_name', 'api_version')
        missing_key = next((key for key in required_keys if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors are passed through unchanged.
            raise
        except Exception:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        masked = dict(model)
        masked['api_key'] = super().encryption(model.get('api_key', ''))
        return masked

    # Form fields of this credential.
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)
    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)
    deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)

View File

@ -6,15 +6,26 @@
@date2024/4/28 11:45 @date2024/4/28 11:45
@desc: @desc:
""" """
from typing import List from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class AzureChatModel(AzureChatOpenAI): class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Build an ``AzureChatModel`` from the stored credential mapping.

    ``api_version`` falls back to ``'2024-02-15-preview'`` when absent from the
    credential; ``model_type``, ``model_name`` and ``model_kwargs`` are not used here.
    """
    azure_options = {
        'azure_endpoint': model_credential.get('api_base'),
        'openai_api_version': model_credential.get('api_version', '2024-02-15-preview'),
        'deployment_name': model_credential.get('deployment_name'),
        'openai_api_key': model_credential.get('api_key'),
        'openai_api_type': "azure",
    }
    return AzureChatModel(**azure_options)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try: try:
return super().get_num_tokens_from_messages(messages) return super().get_num_tokens_from_messages(messages)

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:51
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for DeepSeek LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential by building the model and sending it a short test message.

        An unsupported ``model_type`` always raises ``AppApiException``. A missing
        required key or a failing test call raises only when ``raise_exception`` is
        true; otherwise ``False`` is returned. Returns ``True`` when the call succeeds.
        """
        matching_types = [entry for entry in provider.get_model_type_list()
                          if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('api_key',)
        missing_key = next((key for key in required_keys if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors are passed through unchanged.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        masked = dict(model)
        masked['api_key'] = super().encryption(model.get('api_key', ''))
        return masked

    # Form field of this credential.
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -7,91 +7,35 @@
@Date 5/12/24 7:40 AM @Date 5/12/24 7:40 AM
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfo, ModelTypeConst, ValidCode ModelInfoManage
from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = DeepSeekModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
deepseek_llm_model_credential = DeepSeekLLMModelCredential() deepseek_llm_model_credential = DeepSeekLLMModelCredential()
model_dict = { deepseek_chat = ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM,
'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM, deepseek_llm_model_credential, DeepSeekChatModel
deepseek_llm_model_credential, )
),
'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM, deepseek_coder = ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM,
deepseek_llm_model_credential, deepseek_llm_model_credential,
), DeepSeekChatModel)
}
model_info_manage = ModelInfoManage.builder().append_model_info(deepseek_chat).append_model_info(
deepseek_coder).append_default_model_info(
deepseek_coder).build()
class DeepSeekModelProvider(IModelProvider): class DeepSeekModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel:
deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,
openai_api_base='https://api.deepseek.com',
openai_api_key=model_credential.get('api_key')
)
return deepseek_chat_open_ai
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return deepseek_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content( return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
'deepseek_icon_svg'))) 'deepseek_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -1,30 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File deepseek_chat_model.py
@Author Brian Yang
@Date 5/12/24 7:44 AM
"""
from typing import List
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
class DeepSeekChatModel(ChatOpenAI):
    """``ChatOpenAI`` subclass whose token counting falls back to the shared tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens for ``messages``; use the ``TokenizerManage`` tokenizer if the parent raises."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            token_total = 0
            for message in messages:
                token_total += len(fallback_tokenizer.encode(get_buffer_string([message])))
            return token_total

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in ``text``; use the ``TokenizerManage`` tokenizer if the parent raises."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            encoded = fallback_tokenizer.encode(text)
            return len(encoded)

View File

@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/12/24 7:44 AM
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
    """DeepSeek chat model that counts tokens with the ``TokenizerManage`` tokenizer."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``DeepSeekChatModel`` for ``model_name`` using the credential's ``api_key``.

        ``model_type`` and ``model_kwargs`` are not used here.
        """
        return DeepSeekChatModel(
            model=model_name,
            openai_api_base='https://api.deepseek.com',
            openai_api_key=model_credential.get('api_key'),
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the summed token count of each message rendered via ``get_buffer_string``."""
        shared_tokenizer = TokenizerManage.get_tokenizer()
        token_total = 0
        for message in messages:
            token_total += len(shared_tokenizer.encode(get_buffer_string([message])))
        return token_total

    def get_num_tokens(self, text: str) -> int:
        """Return the number of tokens the shared tokenizer produces for ``text``."""
        encoded = TokenizerManage.get_tokenizer().encode(text)
        return len(encoded)

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:57
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Gemini LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential by building the model and sending it a short test message.

        An unsupported ``model_type`` always raises ``AppApiException``. A missing
        required key or a failing test call raises only when ``raise_exception`` is
        true; otherwise ``False`` is returned. Returns ``True`` when the call succeeds.
        """
        matching_types = [entry for entry in provider.get_model_type_list()
                          if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('api_key',)
        missing_key = next((key for key in required_keys if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors are passed through unchanged.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        masked = dict(model)
        masked['api_key'] = super().encryption(model.get('api_key', ''))
        return masked

    # Form field of this credential.
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -7,93 +7,36 @@
@Date 5/13/24 7:47 AM @Date 5/13/24 7:47 AM
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfo, ModelTypeConst, ValidCode ModelInfoManage
from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = GeminiModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = GeminiModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
gemini_llm_model_credential = GeminiLLMModelCredential() gemini_llm_model_credential = GeminiLLMModelCredential()
model_dict = { gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型随Google更新而更新',
'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型随Google更新而更新', ModelTypeConst.LLM,
gemini_llm_model_credential,
GeminiChatModel)
gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型随Google更新而更新',
ModelTypeConst.LLM, ModelTypeConst.LLM,
gemini_llm_model_credential, gemini_llm_model_credential,
), GeminiChatModel)
'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型随Google更新而更新',
ModelTypeConst.LLM, model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info(
gemini_llm_model_credential, gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build()
),
}
class GeminiModelProvider(IModelProvider): class GeminiModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
**model_kwargs) -> GeminiChatModel:
gemini_chat = GeminiChatModel(
model=model_name,
google_api_key=model_credential.get('api_key')
)
return gemini_chat
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return gemini_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content( return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
'gemini_icon_svg'))) 'gemini_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -1,30 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File gemini_chat_model.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""
from typing import List
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_google_genai import ChatGoogleGenerativeAI
from common.config.tokenizer_manage_config import TokenizerManage
class GeminiChatModel(ChatGoogleGenerativeAI):
    """``ChatGoogleGenerativeAI`` subclass whose token counting falls back to the shared tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens for ``messages``; use the ``TokenizerManage`` tokenizer if the parent raises."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            token_total = 0
            for message in messages:
                token_total += len(fallback_tokenizer.encode(get_buffer_string([message])))
            return token_total

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in ``text``; use the ``TokenizerManage`` tokenizer if the parent raises."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            encoded = fallback_tokenizer.encode(text)
            return len(encoded)

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_google_genai import ChatGoogleGenerativeAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
    """Gemini chat model that counts tokens with the ``TokenizerManage`` tokenizer."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``GeminiChatModel`` for ``model_name`` using the credential's ``api_key``.

        ``model_type`` and ``model_kwargs`` are not used here.
        """
        return GeminiChatModel(
            model=model_name,
            google_api_key=model_credential.get('api_key'),
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the summed token count of each message rendered via ``get_buffer_string``."""
        shared_tokenizer = TokenizerManage.get_tokenizer()
        token_total = 0
        for message in messages:
            token_total += len(shared_tokenizer.encode(get_buffer_string([message])))
        return token_total

    def get_num_tokens(self, text: str) -> int:
        """Return the number of tokens the shared tokenizer produces for ``text``."""
        encoded = TokenizerManage.get_tokenizer().encode(text)
        return len(encoded)

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:06
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Kimi LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that ``model_credential`` can drive ``model_name``.

        Steps: the model type must appear among ``provider.get_model_type_list()``,
        both ``api_base`` and ``api_key`` must be present, and a model obtained
        from ``provider.get_model`` must answer a short test message.

        :return: ``True`` on success; ``False`` on a missing key or a failed
                 invocation when ``raise_exception`` is false.
        :raises AppApiException: always for an unsupported model type; for the
                 other failures only when ``raise_exception`` is true. An
                 ``AppApiException`` raised during the invocation is re-raised
                 regardless of ``raise_exception``.
        """
        supported = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            # A real round trip is the only check that the credential works.
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -7,103 +7,36 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain.chat_models.base import BaseChatModel
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelInfo, \ ModelTypeConst, ModelInfoManage
ModelTypeConst, ValidCode from setting.models_provider.impl.kimi_model_provider.credential.llm import KimiLLMModelCredential
from setting.models_provider.impl.kimi_model_provider.model.llm import KimiChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from setting.models_provider.impl.kimi_model_provider.model.kimi_chat_model import KimiChatModel
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Kimi LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        """Check that ``model_credential`` can drive ``model_name``.

        The model type must appear among ``KimiModelProvider().get_model_type_list()``,
        both ``api_base`` and ``api_key`` must be present, and a model obtained
        from ``KimiModelProvider().get_model`` must answer a short test message.

        :return: ``True`` on success; ``False`` on a missing key or a failed
                 invocation when ``raise_exception`` is false.
        :raises AppApiException: always for an unsupported model type; for the
                 other failures only when ``raise_exception`` is true. An
                 ``AppApiException`` raised during the invocation is re-raised
                 regardless of ``raise_exception``.
        """
        supported = [entry for entry in KimiModelProvider().get_model_type_list()
                     if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')

        try:
            chat_model = KimiModelProvider().get_model(model_type, model_name, model_credential)
            # A real round trip is the only check that the credential works.
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
kimi_llm_model_credential = KimiLLMModelCredential() kimi_llm_model_credential = KimiLLMModelCredential()
model_dict = { moonshot_v1_8k = ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
'moonshot-v1-8k': ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential, KimiChatModel)
), moonshot_v1_32k = ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
'moonshot-v1-32k': ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential, KimiChatModel)
), moonshot_v1_128k = ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential,
'moonshot-v1-128k': ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential, KimiChatModel)
)
} model_info_manage = ModelInfoManage.builder().append_model_info(moonshot_v1_8k).append_model_info(
moonshot_v1_32k).append_default_model_info(moonshot_v1_128k).append_default_model_info(moonshot_v1_8k).build()
class KimiModelProvider(IModelProvider): class KimiModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_dialogue_number(self): def get_dialogue_number(self):
return 3 return 3
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'],
openai_api_key=model_credential['api_key'],
model_name=model_name,
)
return kimi_chat_open_ai
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return kimi_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content( return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon',
'kimi_icon_svg'))) 'kimi_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -2,19 +2,29 @@
""" """
@project: maxkb @project: maxkb
@Author @Author
@file kimi_chat_model.py @file llm.py
@date2023/11/10 17:45 @date2023/11/10 17:45
@desc: @desc:
""" """
from typing import List from typing import List, Dict
from langchain_community.chat_models import ChatOpenAI from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class KimiChatModel(ChatOpenAI): class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'],
openai_api_key=model_credential['api_key'],
model_name=model_name,
)
return kimi_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer() tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

View File

@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:19
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Ollama LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that ``model_name`` is available at the configured Ollama endpoint.

        The model type must appear among ``provider.get_model_type_list()``, the
        model listing must be retrievable through ``provider.get_base_model_list``,
        and ``model_name`` must match a listed model either exactly or once a
        ``:latest`` suffix is dropped from the listed name.

        :return: ``True`` when the model is found.
        :raises AppApiException: on every failure; ``raise_exception`` is not
                 consulted by this implementation.
        """
        supported = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        try:
            listing = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")

        available = listing.get('models')
        if available is None:
            available = []

        matching = []
        for entry in available:
            listed_name = entry.get('model')
            if listed_name == model_name or listed_name.replace(":latest", "") == model_name:
                matching.append(entry)

        if not matching:
            # NOTE(review): the enum member is passed here, not ``.value`` as in the
            # other raises; callers may compare against the member - confirm before changing.
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model_info)
        encrypted['api_key'] = super().encryption(model_info.get('api_key', ''))
        return encrypted

    def build_model(self, model_info: Dict[str, object]):
        """Store ``model_info['api_key']`` on this instance and return ``self``.

        :raises AppApiException: when ``api_key`` or ``model`` is missing from ``model_info``.
        """
        for required_key in ('api_key', 'model'):
            if required_key not in model_info:
                raise AppApiException(500, f'{required_key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,40 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/3/6 11:48
@desc:
"""
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
def get_base_url(url: str):
    """Return ``url`` reduced to its scheme and network location.

    The path, params, query and fragment components are replaced with empty
    strings before the URL is reassembled.
    """
    parts = urlparse(url)
    origin = ParseResult(parts.scheme, parts.netloc, '', '', '', '')
    return origin.geturl()
class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
    """Ollama chat model reached through an OpenAI-style ``/v1`` endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``OllamaChatModel`` for ``model_name``.

        ``model_credential['api_base']`` (empty string when absent) is passed
        through ``get_base_url`` and suffixed with ``/v1``; the API key is read
        from ``model_credential['api_key']``. ``model_type`` and
        ``model_kwargs`` are accepted but unused.
        """
        endpoint = get_base_url(model_credential.get('api_base', '')) + '/v1'
        return OllamaChatModel(model=model_name, openai_api_base=endpoint,
                               openai_api_key=model_credential.get('api_key'))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the tokenizer-encoded length of each message's buffer string."""
        encoder = TokenizerManage.get_tokenizer()
        return sum(len(encoder.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        """Return the length of the tokenizer's encoding of ``text``."""
        encoder = TokenizerManage.get_tokenizer()
        encoded = encoder.encode(text)
        return len(encoded)

View File

@ -1,24 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file ollama_chat_model.py
@date2024/3/6 11:48
@desc:
"""
from typing import List
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
class OllamaChatModel(ChatOpenAI):
    """Ollama chat model that counts tokens with the shared tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the tokenizer-encoded length of each message's buffer string."""
        encoder = TokenizerManage.get_tokenizer()
        return sum(len(encoder.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        """Return the length of the tokenizer's encoding of ``text``."""
        encoder = TokenizerManage.get_tokenizer()
        encoded = encoder.encode(text)
        return len(encoded)

View File

@ -19,106 +19,83 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
"" ""
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Ollama LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        """Check that ``model_name`` is available at the configured Ollama endpoint.

        The model type must appear among ``OllamaModelProvider().get_model_type_list()``,
        the model listing must be retrievable through
        ``OllamaModelProvider.get_base_model_list``, and ``model_name`` must match
        a listed model either exactly or once a ``:latest`` suffix is dropped
        from the listed name.

        :return: ``True`` when the model is found.
        :raises AppApiException: on every failure; ``raise_exception`` is not
                 consulted by this implementation.
        """
        supported = [entry for entry in OllamaModelProvider().get_model_type_list()
                     if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        try:
            listing = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
        except Exception:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")

        available = listing.get('models')
        if available is None:
            available = []

        matching = []
        for entry in available:
            listed_name = entry.get('model')
            if listed_name == model_name or listed_name.replace(":latest", "") == model_name:
                matching.append(entry)

        if not matching:
            # NOTE(review): the enum member is passed here, not ``.value`` as in the
            # other raises; callers may compare against the member - confirm before changing.
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model_info)
        encrypted['api_key'] = super().encryption(model_info.get('api_key', ''))
        return encrypted

    def build_model(self, model_info: Dict[str, object]):
        """Store ``model_info['api_key']`` on this instance and return ``self``.

        :raises AppApiException: when ``api_key`` or ``model`` is missing from ``model_info``.
        """
        for required_key in ('api_key', 'model'):
            if required_key not in model_info:
                raise AppApiException(500, f'{required_key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
ollama_llm_model_credential = OllamaLLMModelCredential() ollama_llm_model_credential = OllamaLLMModelCredential()
model_info_list = [
model_dict = { ModelInfo(
'llama2': ModelInfo(
'llama2', 'llama2',
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'llama2:13b': ModelInfo( ModelInfo(
'llama2:13b', 'llama2:13b',
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'llama2:70b': ModelInfo( ModelInfo(
'llama2:70b', 'llama2:70b',
'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'llama2-chinese:13b': ModelInfo( ModelInfo(
'llama2-chinese:13b', 'llama2-chinese:13b',
'由于Llama2本身的中文对齐较弱我们采用中文指令集对meta-llama/Llama-2-13b-chat-hf进行LoRA微调使其具备较强的中文对话能力。', '由于Llama2本身的中文对齐较弱我们采用中文指令集对meta-llama/Llama-2-13b-chat-hf进行LoRA微调使其具备较强的中文对话能力。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'llama3:8b': ModelInfo( ModelInfo(
'llama3:8b', 'llama3:8b',
'Meta Llama 3迄今为止最有能力的公开产品LLM。8亿参数。', 'Meta Llama 3迄今为止最有能力的公开产品LLM。80亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'llama3:70b': ModelInfo( ModelInfo(
'llama3:70b', 'llama3:70b',
'Meta Llama 3迄今为止最有能力的公开产品LLM。70亿参数。', 'Meta Llama 3迄今为止最有能力的公开产品LLM。700亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:0.5b': ModelInfo( ModelInfo(
'qwen:0.5b', 'qwen:0.5b',
'qwen 1.5 0.5b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。0.5亿参数。', 'qwen 1.5 0.5b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。5亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:1.8b': ModelInfo( ModelInfo(
'qwen:1.8b', 'qwen:1.8b',
'qwen 1.5 1.8b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1.8亿参数。', 'qwen 1.5 1.8b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。18亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:4b': ModelInfo( ModelInfo(
'qwen:4b', 'qwen:4b',
'qwen 1.5 4b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。4亿参数。', 'qwen 1.5 4b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。40亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:7b': ModelInfo(
ModelInfo(
'qwen:7b', 'qwen:7b',
'qwen 1.5 7b 相较于以往版本模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。7亿参数。', 'qwen 1.5 7b 相较于以往版本模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。70亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:14b': ModelInfo( ModelInfo(
'qwen:14b', 'qwen:14b',
'qwen 1.5 14b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。14亿参数。', 'qwen 1.5 14b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。140亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:32b': ModelInfo( ModelInfo(
'qwen:32b', 'qwen:32b',
'qwen 1.5 32b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。32亿参数。', 'qwen 1.5 32b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。320亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:72b': ModelInfo( ModelInfo(
'qwen:72b', 'qwen:72b',
'qwen 1.5 72b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。72亿参数。', 'qwen 1.5 72b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。720亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'qwen:110b': ModelInfo( ModelInfo(
'qwen:110b', 'qwen:110b',
'qwen 1.5 110b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。110亿参数。', 'qwen 1.5 110b 相较于以往版本模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1100亿参数。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
'phi3': ModelInfo( ModelInfo(
'phi3', 'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential), ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
} ]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
def get_base_url(url: str): def get_base_url(url: str):
@ -169,32 +146,14 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
class OllamaModelProvider(IModelProvider): class OllamaModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content( return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
'ollama_icon_svg'))) 'ollama_icon_svg')))
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
# 如果使用模型不在配置中,则使用默认认证
return ollama_llm_model_credential
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
api_base = model_credential.get('api_base')
base_url = get_base_url(api_base)
return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'),
openai_api_key=model_credential.get('api_key'))
def get_dialogue_number(self): def get_dialogue_number(self):
return 2 return 2

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:32
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that ``model_credential`` can drive ``model_name``.

        Steps: the model type must appear among ``provider.get_model_type_list()``,
        both ``api_base`` and ``api_key`` must be present, and a model obtained
        from ``provider.get_model`` must answer a short test message.

        :return: ``True`` on success; ``False`` on a missing key or a failed
                 invocation when ``raise_exception`` is false.
        :raises AppApiException: always for an unsupported model type; for the
                 other failures only when ``raise_exception`` is true. An
                 ``AppApiException`` raised during the invocation is re-raised
                 regardless of ``raise_exception``.
        """
        supported = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            # A real round trip is the only check that the credential works.
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,34 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/4/18 15:28
@desc:
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
    """OpenAI chat model that counts tokens with the shared tokenizer."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``OpenAIChatModel`` for ``model_name``.

        The base URL and API key are read from ``model_credential['api_base']``
        and ``model_credential['api_key']`` (``None`` when absent).
        ``model_type`` and ``model_kwargs`` are accepted but unused.
        """
        api_base = model_credential.get('api_base')
        api_key = model_credential.get('api_key')
        return OpenAIChatModel(model=model_name, openai_api_base=api_base, openai_api_key=api_key)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the tokenizer-encoded length of each message's buffer string."""
        encoder = TokenizerManage.get_tokenizer()
        return sum(len(encoder.encode(get_buffer_string([message]))) for message in messages)

    def get_num_tokens(self, text: str) -> int:
        """Return the length of the tokenizer's encoding of ``text``."""
        encoder = TokenizerManage.get_tokenizer()
        encoded = encoder.encode(text)
        return len(encoded)

View File

@ -1,30 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file openai_chat_model.py
@date2024/4/18 15:28
@desc:
"""
from typing import List
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
class OpenAIChatModel(ChatOpenAI):
    """OpenAI chat model whose token counting falls back to the shared tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count the tokens in ``messages``.

        The parent implementation is tried first. If it raises, every message
        is rendered with ``get_buffer_string`` and encoded with the tokenizer
        from ``TokenizerManage``; the encoded lengths are summed.
        """
        try:
            token_count = super().get_num_tokens_from_messages(messages)
        except Exception:
            encoder = TokenizerManage.get_tokenizer()
            token_count = sum(len(encoder.encode(get_buffer_string([message]))) for message in messages)
        return token_count

    def get_num_tokens(self, text: str) -> int:
        """Count the tokens in ``text``.

        The parent implementation is tried first. If it raises, the length of
        the ``TokenizerManage`` tokenizer's encoding of ``text`` is returned.
        """
        try:
            token_count = super().get_num_tokens(text)
        except Exception:
            encoder = TokenizerManage.get_tokenizer()
            token_count = len(encoder.encode(text))
        return token_count

View File

@ -7,127 +7,70 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelInfo, \ ModelTypeConst, ModelInfoManage
ModelTypeConst, ValidCode from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI LLM models: an API base URL and an API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        """Check that ``model_credential`` can drive ``model_name``.

        The model type must appear among ``OpenAIModelProvider().get_model_type_list()``,
        both ``api_base`` and ``api_key`` must be present, and a model obtained
        from ``OpenAIModelProvider().get_model`` must answer a short test message.

        :return: ``True`` on success; ``False`` on a missing key or a failed
                 invocation when ``raise_exception`` is false.
        :raises AppApiException: always for an unsupported model type; for the
                 other failures only when ``raise_exception`` is true. An
                 ``AppApiException`` raised during the invocation is re-raised
                 regardless of ``raise_exception``.
        """
        supported = [entry for entry in OpenAIModelProvider().get_model_type_list()
                     if entry.get('value') == model_type]
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_base', 'api_key'):
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')

        try:
            chat_model = OpenAIModelProvider().get_model(model_type, model_name, model_credential)
            # A real round trip is the only check that the credential works.
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
openai_llm_model_credential = OpenAILLMModelCredential() openai_llm_model_credential = OpenAILLMModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
),
ModelInfo('gpt-4', '最新的gpt-4随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4o', '最新的GPT-4o比gpt-4-turbo更便宜、更快随OpenAI调整而更新',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview随OpenAI调整而更新',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-3.5-turbo-0125',
'2024年1月25日的gpt-3.5-turbo快照支持上下文长度16,385 tokens', ModelTypeConst.LLM,
openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-3.5-turbo-1106',
'2023年11月6日的gpt-3.5-turbo快照支持上下文长度16,385 tokens', ModelTypeConst.LLM,
openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-3.5-turbo-0613',
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照将于2024年6月13日弃用',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4o-2024-05-13',
'2024年5月13日的gpt-4o快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-turbo-2024-04-09',
'2024年4月9日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel)
]
model_dict = { model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
'gpt-3.5-turbo': ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM, ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, openai_llm_model_credential, OpenAIChatModel
), )).build()
'gpt-4': ModelInfo('gpt-4', '最新的gpt-4随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4o': ModelInfo('gpt-4o', '最新的GPT-4o比gpt-4-turbo更便宜、更快随OpenAI调整而更新',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential,
),
'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview随OpenAI调整而更新',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-3.5-turbo-0125': ModelInfo('gpt-3.5-turbo-0125',
'2024年1月25日的gpt-3.5-turbo快照支持上下文长度16,385 tokens', ModelTypeConst.LLM,
openai_llm_model_credential,
),
'gpt-3.5-turbo-1106': ModelInfo('gpt-3.5-turbo-1106',
'2023年11月6日的gpt-3.5-turbo快照支持上下文长度16,385 tokens', ModelTypeConst.LLM,
openai_llm_model_credential,
),
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613',
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照将于2024年6月13日弃用',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4o-2024-05-13': ModelInfo('gpt-4o-2024-05-13',
'2024年5月13日的gpt-4o快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4-turbo-2024-04-09': ModelInfo('gpt-4-turbo-2024-04-09',
'2024年4月9日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4-0125-preview': ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
),
'gpt-4-1106-preview': ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
),
}
class OpenAIModelProvider(IModelProvider): class OpenAIModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel:
azure_chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key')
)
return azure_chat_open_ai
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return openai_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content( return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon',
'openai_icon_svg'))) 'openai_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:41
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for LLM models that authenticate with a single API key."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to talk to the given model.

        :param model_type: model type value, compared with the ``value`` of each entry of
            ``provider.get_model_type_list()``
        :param model_name: name of the model to build for the probe call
        :param model_credential: credential values; must contain ``api_key``
        :param provider: provider used to list the model types and to build the model
        :param raise_exception: when true, failures raise ``AppApiException`` instead of
            returning ``False``
        :return: ``True`` when the probe message was answered, otherwise ``False``
        :raises AppApiException: always for an unsupported model type; for the other failures
            only when ``raise_exception`` is true
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Only the first missing field is reported, in declaration order.
        missing_keys = [name for name in ['api_key'] if name not in model_credential]
        if missing_keys:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_keys[0]} 字段为必填字段')

        try:
            # Send one short message to prove the credential really works.
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application errors pass through unchanged regardless of raise_exception.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        protected = dict(model)
        protected['api_key'] = super().encryption(model.get('api_key', ''))
        return protected

    # Password-style input labelled 'API Key'; marked as required.
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -2,19 +2,28 @@
""" """
@project: maxkb @project: maxkb
@Author @Author
@file qwen_chat_model.py @file llm.py
@date2024/4/28 11:44 @date2024/4/28 11:44
@desc: @desc:
""" """
from typing import List from typing import List, Dict
from langchain_community.chat_models import ChatTongyi from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class QwenChatModel(ChatTongyi): class QwenChatModel(MaxKBBaseModel, ChatTongyi):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
chat_tong_yi = QwenChatModel(
model_name=model_name,
dashscope_api_key=model_credential.get('api_key')
)
return chat_tong_yi
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer() tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

View File

@ -7,87 +7,33 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain_community.chat_models.tongyi import ChatTongyi
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfo, IModelProvider, ValidCode ModelInfoManage
from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = QwenModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = QwenModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
qwen_model_credential = OpenAILLMModelCredential() qwen_model_credential = OpenAILLMModelCredential()
model_dict = { module_info_list = [
'qwen-turbo': ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential), ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
'qwen-plus': ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential), ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
'qwen-max': ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential) ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
} ]
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
class QwenModelProvider(IModelProvider): class QwenModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
chat_tong_yi = QwenChatModel(
model_name=model_name,
dashscope_api_key=model_credential.get('api_key')
)
return chat_tong_yi
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return qwen_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content( return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon',
'qwen_icon_svg'))) 'qwen_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/12 10:19
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Wenxin (Qianfan) LLM models (API key + secret key)."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to talk to the given model.

        :param model_type: model type value, compared with the ``value`` of each entry of
            ``provider.get_model_type_list()``
        :param model_name: model name; must appear (case-insensitively) in ``model.client.models()``
        :param model_credential: credential values; must contain ``api_key`` and ``secret_key``
        :param provider: provider used to list the model types and to build the model
        :param raise_exception: when true, failures raise ``AppApiException`` instead of
            returning ``False``
        :return: ``True`` when the probe message was answered, otherwise ``False``
        :raises AppApiException: always for an unsupported model type or model name; for missing
            fields and probe failures only when ``raise_exception`` is true
        """
        model_type_list = provider.get_model_type_list()
        if not any([mt for mt in model_type_list if mt.get('value') == model_type]):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Validate required fields before building the model, so a missing key is reported as
        # such instead of surfacing as an opaque client error from model construction.
        for key in ['api_key', 'secret_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            supported_names = [name.lower() for name in model.client.models()]
            if model_name.lower() not in supported_names:
                raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
            # Send one short message to prove the credential really works.
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application errors (including the unsupported-model error above) pass through.
            raise
        except Exception as e:
            # Honour raise_exception the same way the other credential validators do.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` whose ``secret_key`` went through the inherited ``encryption``."""
        return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        """Copy ``api_key`` and ``secret_key`` from ``model_info`` onto this instance.

        :param model_info: must contain ``api_key``, ``secret_key`` and ``model``
        :return: this credential instance
        :raises AppApiException: code 500 when one of the required keys is missing
        """
        for key in ['api_key', 'secret_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        # NOTE(review): these instance attributes shadow the class-level form fields of the
        # same name on this instance — confirm that is intended.
        self.api_key = model_info.get('api_key')
        self.secret_key = model_info.get('secret_key')
        return self

    # Password-style inputs labelled 'API Key' and 'Secret Key'; both marked as required.
    api_key = forms.PasswordInputField('API Key', required=True)
    secret_key = forms.PasswordInputField("Secret Key", required=True)

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2023/11/10 17:45
@desc:
"""
from typing import List, Dict
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain_community.chat_models import QianfanChatEndpoint
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
    """Qianfan chat endpoint that counts tokens with the tokenizer from ``TokenizerManage``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a chat model from a credential dict.

        :param model_type: accepted for interface compatibility; not used here
        :param model_name: passed to the endpoint as ``model``
        :param model_credential: supplies ``api_key`` (as ``qianfan_ak``) and ``secret_key``
            (as ``qianfan_sk``)
        :param model_kwargs: only ``streaming`` is read; defaults to ``False``
        :return: a new ``QianfanChatModel``
        """
        init_kwargs = {
            'model': model_name,
            'qianfan_ak': model_credential.get('api_key'),
            'qianfan_sk': model_credential.get('secret_key'),
            'streaming': model_kwargs.get('streaming', False),
        }
        return QianfanChatModel(**init_kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the total token count of ``messages``, each rendered with ``get_buffer_string``."""
        tokenizer = TokenizerManage.get_tokenizer()
        total = 0
        for message in messages:
            total += len(tokenizer.encode(get_buffer_string([message])))
        return total

    def get_num_tokens(self, text: str) -> int:
        """Return the number of tokens the tokenizer produces for ``text``."""
        encoded = TokenizerManage.get_tokenizer().encode(text)
        return len(encoded)

View File

@ -1,32 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file qian_fan_chat_model.py
@date2023/11/10 17:45
@desc:
"""
from typing import Optional, List, Any, Iterator, cast
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
from langchain_community.chat_models import QianfanChatEndpoint
from common.config.tokenizer_manage_config import TokenizerManage
class QianfanChatModel(QianfanChatEndpoint):
    """Qianfan chat endpoint whose token counts come from the ``TokenizerManage`` tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the token counts of all messages, each rendered with ``get_buffer_string``."""
        encode = TokenizerManage.get_tokenizer().encode
        per_message = (len(encode(get_buffer_string([msg]))) for msg in messages)
        return sum(per_message)

    def get_num_tokens(self, text: str) -> int:
        """Count the tokens the tokenizer produces for ``text``."""
        token_ids = TokenizerManage.get_tokenizer().encode(text)
        return len(token_ids)

View File

@ -7,121 +7,53 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain_community.chat_models import QianfanChatEndpoint
from qianfan import ChatCompletion
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfo, IModelProvider, ValidCode ModelInfoManage
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel from setting.models_provider.impl.wenxin_model_provider.credential.llm import WenxinLLMModelCredential
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = WenxinModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
model = WenxinModelProvider().get_model(model_type, model_name, model_credential)
model_info = [model.lower() for model in model.client.models()]
if not model_info.__contains__(model_name.lower()):
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
for key in ['api_key', 'secret_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model.invoke(
[HumanMessage(content='你好')])
except Exception as e:
raise e
return True
def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
def build_model(self, model_info: Dict[str, object]):
for key in ['api_key', 'secret_key', 'model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
self.api_key = model_info.get('api_key')
self.secret_key = model_info.get('secret_key')
return self
api_key = forms.PasswordInputField('API Key', required=True)
secret_key = forms.PasswordInputField("Secret Key", required=True)
win_xin_llm_model_credential = WenxinLLMModelCredential() win_xin_llm_model_credential = WenxinLLMModelCredential()
model_dict = { model_info_list = [ModelInfo('ERNIE-Bot-4',
'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4',
'ERNIE-Bot-4是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。', 'ERNIE-Bot-4是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。',
ModelTypeConst.LLM, win_xin_llm_model_credential), ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('ERNIE-Bot',
'ERNIE-Bot是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('ERNIE-Bot-turbo',
'ERNIE-Bot-turbo是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力响应速度更快。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('BLOOMZ-7B',
'BLOOMZ-7B是业内知名的大语言模型由BigScience研发并开源能够以46种语言和13种编程语言输出文本。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('Llama-2-7b-chat',
'Llama-2-7b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-7b-chat是高性能原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('Llama-2-13b-chat',
'Llama-2-13b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-13b-chat是性能与效果均衡的原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('Llama-2-70b-chat',
'Llama-2-70b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-70b-chat是高精度效果的原生开源版本。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
ModelInfo('Qianfan-Chinese-Llama-2-7B',
'千帆团队在Llama-2-7b基础上的中文增强版本在CMMLU、C-EVAL等中文知识库上表现优异。',
ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel)
]
'ERNIE-Bot': ModelInfo('ERNIE-Bot', model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
'ERNIE-Bot是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。', ModelInfo('ERNIE-Bot-4',
ModelTypeConst.LLM, win_xin_llm_model_credential), 'ERNIE-Bot-4是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力。',
ModelTypeConst.LLM,
'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo', win_xin_llm_model_credential,
'ERNIE-Bot-turbo是百度自行研发的大语言模型覆盖海量中文数据具有更强的对话问答、内容创作生成等能力响应速度更快。', QianfanChatModel)).build()
ModelTypeConst.LLM, win_xin_llm_model_credential),
'BLOOMZ-7B': ModelInfo('BLOOMZ-7B',
'BLOOMZ-7B是业内知名的大语言模型由BigScience研发并开源能够以46种语言和13种编程语言输出文本。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat',
'Llama-2-7b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-7b-chat是高性能原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat',
'Llama-2-13b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-13b-chat是性能与效果均衡的原生开源版本适用于对话场景。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat',
'Llama-2-70b-chat由Meta AI研发并开源在编码、推理及知识应用等场景表现优秀Llama-2-70b-chat是高精度效果的原生开源版本。',
ModelTypeConst.LLM, win_xin_llm_model_credential),
'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B',
'千帆团队在Llama-2-7b基础上的中文增强版本在CMMLU、C-EVAL等中文知识库上表现优异。',
ModelTypeConst.LLM, win_xin_llm_model_credential)
}
class WenxinModelProvider(IModelProvider): class WenxinModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 2 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object],
**model_kwargs) -> QianfanChatEndpoint:
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False))
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return win_xin_llm_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content( return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(

View File

@ -0,0 +1,51 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/12 10:29
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for XunFei Spark LLM models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to talk to the given model.

        :param model_type: model type value, compared with the ``value`` of each entry of
            ``provider.get_model_type_list()``
        :param model_name: name of the model to build for the probe call
        :param model_credential: credential values; must contain ``spark_api_url``,
            ``spark_app_id``, ``spark_api_key`` and ``spark_api_secret``
        :param provider: provider used to list the model types and to build the model
        :param raise_exception: when true, failures raise ``AppApiException`` instead of
            returning ``False``
        :return: ``True`` when the probe message was answered, otherwise ``False``
        :raises AppApiException: always for an unsupported model type; for the other failures
            only when ``raise_exception`` is true
        """
        type_matches = [entry for entry in provider.get_model_type_list()
                        if entry.get('value') == model_type]
        if not any(type_matches):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Stop at the first missing field, checked in this order.
        for required in ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret'):
            if required in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required} 字段为必填字段')

        try:
            # Send one short message to prove the credential really works.
            spark_model = provider.get_model(model_type, model_name, model_credential)
            spark_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application errors pass through unchanged regardless of raise_exception.
            raise
        except Exception as error:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(error)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``spark_api_secret`` went through the inherited ``encryption``."""
        secret = super().encryption(model.get('spark_api_secret', ''))
        return {**model, 'spark_api_secret': secret}

    # Form fields, in display order: two text inputs followed by two password-style inputs.
    spark_api_url = forms.TextInputField('API 域名', required=True)
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

View File

@ -7,7 +7,7 @@
@desc: @desc:
""" """
from typing import List, Optional, Any, Iterator from typing import List, Optional, Any, Iterator, Dict
from langchain_community.chat_models import ChatSparkLLM from langchain_community.chat_models import ChatSparkLLM
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
@ -16,9 +16,21 @@ from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_stri
from langchain_core.outputs import ChatGenerationChunk from langchain_core.outputs import ChatGenerationChunk
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class XFChatSparkLLM(ChatSparkLLM): class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
spark_llm_domain=model_name
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer() tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

View File

@ -7,97 +7,33 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatSparkLLM
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
ModelInfo, IModelProvider, ValidCode
from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM
from smartdoc.conf import PROJECT_DIR
import ssl import ssl
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from smartdoc.conf import PROJECT_DIR
ssl._create_default_https_context = ssl.create_default_context() ssl._create_default_https_context = ssl.create_default_context()
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = XunFeiModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = XunFeiModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
spark_api_url = forms.TextInputField('API 域名', required=True)
spark_app_id = forms.TextInputField('APP ID', required=True)
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
qwen_model_credential = XunFeiLLMModelCredential() qwen_model_credential = XunFeiLLMModelCredential()
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
]
model_dict = { model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential), ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential),
'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential)
}
class XunFeiModelProvider(IModelProvider): class XunFeiModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM:
zhipuai_chat = XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
spark_llm_domain=model_name
)
return zhipuai_chat
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return qwen_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content( return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
'xf_icon_svg'))) 'xf_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/12 10:46
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for ZhiPu LLM models (single API key)."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to talk to the given model.

        :param model_type: model type value, compared with the ``value`` of each entry of
            ``provider.get_model_type_list()``
        :param model_name: name of the model to build for the probe call
        :param model_credential: credential values; must contain ``api_key``
        :param provider: provider used to list the model types and to build the model
        :param raise_exception: when true, failures raise ``AppApiException`` instead of
            returning ``False``
        :return: ``True`` when the probe message was answered, otherwise ``False``
        :raises AppApiException: always for an unsupported model type; for the other failures
            only when ``raise_exception`` is true
        """
        known_types = provider.get_model_type_list()
        type_supported = any(mt for mt in known_types if mt.get('value') == model_type)
        if not type_supported:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # First required field that is absent from the credential, if any.
        absent = next((name for name in ['api_key'] if name not in model_credential), None)
        if absent is not None:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{absent} 字段为必填字段')
            return False

        try:
            # Send one short message to prove the credential really works.
            probe_messages = [HumanMessage(content='你好')]
            provider.get_model(model_type, model_name, model_credential).invoke(probe_messages)
        except AppApiException:
            # Application errors pass through unchanged regardless of raise_exception.
            raise
        except Exception as exc:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(exc)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
        result = {**model}
        result['api_key'] = super().encryption(model.get('api_key', ''))
        return result

    # Password-style input labelled 'API Key'; marked as required.
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -2,19 +2,29 @@
""" """
@project: maxkb @project: maxkb
@Author @Author
@file zhipu_chat_model.py @file llm.py
@date2024/4/28 11:42 @date2024/4/28 11:42
@desc: @desc:
""" """
from typing import List from typing import List, Dict
from langchain_community.chat_models import ChatZhipuAI from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class ZhipuChatModel(ChatZhipuAI): class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
zhipuai_chat = ZhipuChatModel(
temperature=0.5,
api_key=model_credential.get('api_key'),
model=model_name
)
return zhipuai_chat
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer() tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

View File

@ -7,88 +7,30 @@
@desc: @desc:
""" """
import os import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatZhipuAI
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfo, IModelProvider, ValidCode ModelInfoManage
from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = ZhiPuModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
qwen_model_credential = ZhiPuLLMModelCredential() qwen_model_credential = ZhiPuLLMModelCredential()
model_info_list = [
model_dict = { ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential), ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential), ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)
'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential) ]
} model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info(
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)).build()
class ZhiPuModelProvider(IModelProvider): class ZhiPuModelProvider(IModelProvider):
def get_dialogue_number(self): def get_model_info_manage(self):
return 3 return model_info_manage
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
zhipuai_chat = ZhipuChatModel(
temperature=0.5,
api_key=model_credential.get('api_key'),
model=model_name
)
return zhipuai_chat
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return qwen_model_credential
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content( return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
'zhipuai_icon_svg'))) 'zhipuai_icon_svg')))
def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]

View File

@ -115,7 +115,7 @@ class ModelSerializer(serializers.Serializer):
model_name = self.data.get( model_name = self.data.get(
'model_name') 'model_name')
credential = self.data.get('credential') credential = self.data.get('credential')
provider_handler = ModelProvideConstants[provider].value
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name) model_name)
source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_model_credential = json.loads(rsa_long_decrypt(model.credential))
@ -124,7 +124,7 @@ class ModelSerializer(serializers.Serializer):
for k in source_encryption_model_credential.keys(): for k in source_encryption_model_credential.keys():
if credential[k] == source_encryption_model_credential[k]: if credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k] credential[k] = source_model_credential[k]
return credential, model_credential return credential, model_credential, provider_handler
class Create(serializers.Serializer): class Create(serializers.Serializer):
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -145,13 +145,10 @@ class ModelSerializer(serializers.Serializer):
name=self.data.get('name')).exists(): name=self.data.get('name')).exists():
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
# 校验模型认证数据 # 校验模型认证数据
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'), ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
self.data.get( self.data.get('model_name'),
'model_name')).is_valid( self.data.get('credential')
self.data.get('model_type'), )
self.data.get('model_name'),
self.data.get('credential'),
raise_exception=True)
def insert(self, user_id, with_valid=False): def insert(self, user_id, with_valid=False):
status = Status.SUCCESS status = Status.SUCCESS
@ -232,16 +229,17 @@ class ModelSerializer(serializers.Serializer):
if model is None: if model is None:
raise AppApiException(500, '不存在的id') raise AppApiException(500, '不存在的id')
else: else:
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid( credential, model_credential, provider_handler = ModelSerializer.Edit(
data={**instance, 'user_id': user_id}).is_valid(
model=model) model=model)
try: try:
model.status = Status.SUCCESS model.status = Status.SUCCESS
# 校验模型认证数据 # 校验模型认证数据
model_credential.is_valid( provider_handler.is_valid_credential(model.model_type,
model.model_type, instance.get("model_name"),
instance.get("model_name"), credential,
credential, raise_exception=True)
raise_exception=True)
except AppApiException as e: except AppApiException as e:
if e.code == ValidCode.model_not_fount: if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD model.status = Status.DOWNLOAD