diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 3796b5bb..aa6c525e 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -9,9 +9,9 @@ from abc import ABC, abstractmethod from enum import Enum from functools import reduce -from typing import Dict, Iterator +from typing import Dict, Iterator, Type, List -from langchain.chat_models.base import BaseChatModel +from pydantic.v1 import BaseModel from common.exception.app_exception import AppApiException @@ -47,39 +47,53 @@ class DownModelChunk: class IModelProvider(ABC): + @abstractmethod + def get_model_info_manage(self): + pass @abstractmethod def get_model_provide_info(self): pass - @abstractmethod def get_model_type_list(self): - pass + return self.get_model_info_manage().get_model_type_list() - @abstractmethod def get_model_list(self, model_type): - pass + if model_type is None: + raise AppApiException(500, '模型类型不能为空') + return self.get_model_info_manage().get_model_list() - @abstractmethod def get_model_credential(self, model_type, model_name): - pass + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_credential - @abstractmethod - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel: - pass + def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_credential.is_valid(model_type, model_name, model_credential, self, + raise_exception=raise_exception) + + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel: + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return 
model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs) - @abstractmethod def get_dialogue_number(self): - pass + return 3 def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: raise AppApiException(500, "当前平台不支持下载模型") +class MaxKBBaseModel(ABC): + @staticmethod + @abstractmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + pass + + class BaseModelCredential(ABC): @abstractmethod - def is_valid(self, model_type: str, model_name, model: Dict[str, object], raise_exception=False): + def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True): pass @abstractmethod @@ -113,15 +127,18 @@ class BaseModelCredential(ABC): class ModelTypeConst(Enum): LLM = {'code': 'LLM', 'message': '大语言模型'} + EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} class ModelInfo: def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential, + model_class: Type[MaxKBBaseModel], **keywords): self.name = name self.desc = desc self.model_type = model_type.name self.model_credential = model_credential + self.model_class = model_class if keywords is not None: for key in keywords.keys(): self.__setattr__(key, keywords.get(key)) @@ -143,10 +160,66 @@ class ModelInfo: def get_model_type(self): return self.model_type + def get_model_class(self): + return self.model_class + def to_dict(self): return reduce(lambda x, y: {**x, **y}, [{attr: self.__getattribute__(attr)} for attr in vars(self) if - not attr.startswith("__") and not attr == 'model_credential'], {}) + not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {}) + + +class ModelInfoManage: + def __init__(self): + self.model_dict = {} + self.model_list = [] + self.default_model_list = [] + self.default_model_dict = {} + + def append_model_info(self, model_info: 
ModelInfo): + self.model_list.append(model_info) + model_type_dict = self.model_dict.get(model_info.model_type) + if model_type_dict is None: + self.model_dict[model_info.model_type] = {model_info.name: model_info} + else: + model_type_dict[model_info.name] = model_info + + def append_default_model_info(self, model_info: ModelInfo): + self.default_model_list.append(model_info) + self.default_model_dict[model_info.model_type] = model_info + + def get_model_list(self): + return [model.to_dict() for model in self.model_list] + + def get_model_type_list(self): + return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if + len([model for model in self.model_list if model.model_type == _type.name]) > 0] + + def get_model_info(self, model_type, model_name) -> ModelInfo: + model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type)) + if model_info is None: + raise AppApiException(500, '模型不支持') + return model_info + + class builder: + def __init__(self): + self.modelInfoManage = ModelInfoManage() + + def append_model_info(self, model_info: ModelInfo): + self.modelInfoManage.append_model_info(model_info) + return self + + def append_model_info_list(self, model_info_list: List[ModelInfo]): + for model_info in model_info_list: + self.modelInfoManage.append_model_info(model_info) + return self + + def append_default_model_info(self, model_info: ModelInfo): + self.modelInfoManage.append_default_model_info(model_info) + return self + + def build(self): + return self.modelInfoManage class ModelProvideInfo: diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 0a7565f3..1db811f9 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -9,15 +9,15 @@ from enum import Enum from 
setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider +from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider +from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider +from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider -from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider -from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider -from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider class ModelProvideConstants(Enum): diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py index 3164dd8e..8b95dfef 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -7,98 +7,30 @@ @desc: """ import os -from typing import Dict -from langchain.schema import HumanMessage - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from 
common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ - ModelInfo, \ - ModelTypeConst, ValidCode +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel from smartdoc.conf import PROJECT_DIR +base_azure_llm_model_credential = AzureLLMModelCredential() -class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential): +default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM, + base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview' + ) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = AzureModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_base', 'api_key', 'deployment_name', 'api_version']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = AzureModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确') - else: - return False - - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_version = forms.TextInputField("API 版本 (api_version)", required=True) - 
- api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) - - api_key = forms.PasswordInputField("API Key (api_key)", required=True) - - deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True) - - -base_azure_llm_model_credential = DefaultAzureLLMModelCredential() - -model_dict = { - 'deployment_name': ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM, - base_azure_llm_model_credential, api_version='2024-02-15-preview' - ) -} +model_info_manage = ModelInfoManage.builder().append_default_model_info(default_model_info).append_model_info( + default_model_info).build() class AzureModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel: - azure_chat_open_ai = AzureChatModel( - azure_endpoint=model_credential.get('api_base'), - openai_api_version=model_credential.get('api_version', '2024-02-15-preview'), - deployment_name=model_credential.get('deployment_name'), - openai_api_key=model_credential.get('api_key'), - openai_api_type="azure" - ) - return azure_chat_open_ai - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return base_azure_llm_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'azure_model_provider', 'icon', 'azure_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 
'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py new file mode 100644 index 00000000..ddfb395a --- /dev/null +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 17:08 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class AzureLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key', 'deployment_name', 'api_version']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确') + else: + return False + + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_version = forms.TextInputField("API 版本 (api_version)", required=True) + + api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) + + api_key = 
forms.PasswordInputField("API Key (api_key)", required=True) + + deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True) diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py index 6388dbde..45b388af 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py +++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py @@ -6,15 +6,26 @@ @date:2024/4/28 11:45 @desc: """ -from typing import List +from typing import List, Dict from langchain_core.messages import BaseMessage, get_buffer_string from langchain_openai import AzureChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel -class AzureChatModel(AzureChatOpenAI): +class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return AzureChatModel( + azure_endpoint=model_credential.get('api_base'), + openai_api_version=model_credential.get('api_version', '2024-02-15-preview'), + deployment_name=model_credential.get('deployment_name'), + openai_api_key=model_credential.get('api_key'), + openai_api_type="azure" + ) + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: try: return super().get_num_tokens_from_messages(messages) diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py new file mode 100644 index 00000000..29d9e344 --- /dev/null +++ b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 17:51 + @desc: +""" +from typing import Dict + +from 
langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py b/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py index 3baa5f04..f60f26fa 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/deepseek_model_provider.py @@ -7,91 +7,35 @@ @Date :5/12/24 7:40 AM """ import os -from typing import Dict -from langchain.schema import HumanMessage - -from common import forms -from common.exception.app_exception import AppApiException -from 
common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ - ModelInfo, ModelTypeConst, ValidCode -from setting.models_provider.impl.deepseek_model_provider.model.deepseek_chat_model import DeepSeekChatModel +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential +from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel from smartdoc.conf import PROJECT_DIR - -class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = DeepSeekModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = DeepSeekModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_key = forms.PasswordInputField('API Key', required=True) - - deepseek_llm_model_credential = DeepSeekLLMModelCredential() -model_dict = { - 'deepseek-chat': ModelInfo('deepseek-chat', '擅长通用对话任务,支持 
32K 上下文', ModelTypeConst.LLM, - deepseek_llm_model_credential, - ), - 'deepseek-coder': ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM, - deepseek_llm_model_credential, - ), -} +deepseek_chat = ModelInfo('deepseek-chat', '擅长通用对话任务,支持 32K 上下文', ModelTypeConst.LLM, + deepseek_llm_model_credential, DeepSeekChatModel + ) + +deepseek_coder = ModelInfo('deepseek-coder', '擅长处理编程任务,支持 16K 上下文', ModelTypeConst.LLM, + deepseek_llm_model_credential, + DeepSeekChatModel) + +model_info_manage = ModelInfoManage.builder().append_model_info(deepseek_chat).append_model_info( + deepseek_coder).append_default_model_info( + deepseek_coder).build() class DeepSeekModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> DeepSeekChatModel: - deepseek_chat_open_ai = DeepSeekChatModel( - model=model_name, - openai_api_base='https://api.deepseek.com', - openai_api_key=model_credential.get('api_key') - ) - return deepseek_chat_open_ai - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return deepseek_llm_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'deepseek_model_provider', 'icon', 'deepseek_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git 
a/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py deleted file mode 100644 index b7a54b30..00000000 --- a/apps/setting/models_provider/impl/deepseek_model_provider/model/deepseek_chat_model.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -@Project :MaxKB -@File :deepseek_chat_model.py -@Author :Brian Yang -@Date :5/12/24 7:44 AM -""" -from typing import List - -from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai import ChatOpenAI - -from common.config.tokenizer_manage_config import TokenizerManage - - -class DeepSeekChatModel(ChatOpenAI): - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py new file mode 100644 index 00000000..086b02e6 --- /dev/null +++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :MaxKB +@File :llm.py +@Author :Brian Yang +@Date :5/12/24 7:44 AM +""" +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_openai import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class 
DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + deepseek_chat_open_ai = DeepSeekChatModel( + model=model_name, + openai_api_base='https://api.deepseek.com', + openai_api_key=model_credential.get('api_key') + ) + return deepseek_chat_open_ai + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py new file mode 100644 index 00000000..85ea1b23 --- /dev/null +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 17:57 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class GeminiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: 
+ model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py b/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py index 5ddddf78..b6dd442c 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py @@ -7,93 +7,36 @@ @Date :5/13/24 7:47 AM """ import os -from typing import Dict -from langchain.schema import HumanMessage - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ - ModelInfo, ModelTypeConst, ValidCode -from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential +from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel from smartdoc.conf import PROJECT_DIR - -class GeminiLLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], 
raise_exception=False): - model_type_list = GeminiModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = GeminiModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_key = forms.PasswordInputField('API Key', required=True) - - gemini_llm_model_credential = GeminiLLMModelCredential() -model_dict = { - 'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新', +gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新', + ModelTypeConst.LLM, + gemini_llm_model_credential, + GeminiChatModel) + +gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新', ModelTypeConst.LLM, gemini_llm_model_credential, - ), - 'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新', - ModelTypeConst.LLM, - gemini_llm_model_credential, - ), -} + GeminiChatModel) + +model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info( + gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build() class GeminiModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], - **model_kwargs) -> 
GeminiChatModel: - gemini_chat = GeminiChatModel( - model=model_name, - google_api_key=model_credential.get('api_key') - ) - return gemini_chat - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return gemini_llm_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon', 'gemini_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py b/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py deleted file mode 100644 index 7a972d9d..00000000 --- a/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -@Project :MaxKB -@File :gemini_chat_model.py -@Author :Brian Yang -@Date :5/13/24 7:40 AM -""" -from typing import List - -from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_google_genai import ChatGoogleGenerativeAI - -from common.config.tokenizer_manage_config import TokenizerManage - - -class GeminiChatModel(ChatGoogleGenerativeAI): - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return 
sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py new file mode 100644 index 00000000..b3e0bbc9 --- /dev/null +++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :MaxKB +@File :llm.py +@Author :Brian Yang +@Date :5/13/24 7:40 AM +""" +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_google_genai import ChatGoogleGenerativeAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + gemini_chat = GeminiChatModel( + model=model_name, + google_api_key=model_credential.get('api_key') + ) + return gemini_chat + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py new file mode 100644 index 00000000..9e5835cc --- /dev/null +++ b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + 
@project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:06 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class KimiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py b/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py index 6394e590..1347df46 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py @@ -7,103 +7,36 @@ @desc: """ import os -from typing import Dict -from 
langchain.schema import HumanMessage -from langchain.chat_models.base import BaseChatModel - - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ - ModelInfo, \ - ModelTypeConst, ValidCode +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.kimi_model_provider.credential.llm import KimiLLMModelCredential +from setting.models_provider.impl.kimi_model_provider.model.llm import KimiChatModel from smartdoc.conf import PROJECT_DIR -from setting.models_provider.impl.kimi_model_provider.model.kimi_chat_model import KimiChatModel - - - - -class KimiLLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = KimiModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_base', 'api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - # llm_kimi = Moonshot( - # model_name=model_name, - # base_url=model_credential['api_base'], - # moonshot_api_key=model_credential['api_key'] - # ) - - model = KimiModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def 
encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_base = forms.TextInputField('API 域名', required=True) - api_key = forms.PasswordInputField('API Key', required=True) - kimi_llm_model_credential = KimiLLMModelCredential() -model_dict = { - 'moonshot-v1-8k': ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential, - ), - 'moonshot-v1-32k': ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential, - ), - 'moonshot-v1-128k': ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential, - ) -} +moonshot_v1_8k = ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) +moonshot_v1_32k = ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) +moonshot_v1_128k = ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) + +model_info_manage = ModelInfoManage.builder().append_model_info(moonshot_v1_8k).append_model_info( + moonshot_v1_32k).append_default_model_info(moonshot_v1_128k).append_default_model_info(moonshot_v1_8k).build() class KimiModelProvider(IModelProvider): + def get_model_info_manage(self): + return model_info_manage + def get_dialogue_number(self): return 3 - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel: - kimi_chat_open_ai = KimiChatModel( - openai_api_base=model_credential['api_base'], - openai_api_key=model_credential['api_key'], - model_name=model_name, - ) - return kimi_chat_open_ai - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return kimi_llm_model_credential - def get_model_provide_info(self): return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", 
"setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon', 'kimi_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py similarity index 55% rename from apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py rename to apps/setting/models_provider/impl/kimi_model_provider/model/llm.py index deee11a0..3e4d4f28 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -2,19 +2,29 @@ """ @project: maxkb @Author:虎 - @file: kimi_chat_model.py + @file: llm.py @date:2023/11/10 17:45 @desc: """ -from typing import List +from typing import List, Dict from langchain_community.chat_models import ChatOpenAI from langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel -class KimiChatModel(ChatOpenAI): +class KimiChatModel(MaxKBBaseModel, ChatOpenAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + kimi_chat_open_ai = KimiChatModel( + openai_api_base=model_credential['api_base'], + openai_api_key=model_credential['api_key'], + model_name=model_name, + ) + return kimi_chat_open_ai + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff 
--git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py new file mode 100644 index 00000000..030924ca --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -0,0 +1,44 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:19 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OllamaLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + return self + + api_base = forms.TextInputField('API 域名', required=True) + api_key = 
forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py new file mode 100644 index 00000000..89cb3ed3 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -0,0 +1,40 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2024/3/6 11:48 + @desc: +""" +from typing import List, Dict +from urllib.parse import urlparse, ParseResult + +from langchain_community.chat_models import ChatOpenAI +from langchain_core.messages import BaseMessage, get_buffer_string + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +def get_base_url(url: str): + parse = urlparse(url) + return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', + query='', + fragment='').geturl() + + +class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + api_base = model_credential.get('api_base', '') + base_url = get_base_url(api_base) + return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'), + openai_api_key=model_credential.get('api_key')) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py deleted file mode 100644 index 86c5219d..00000000 --- 
a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py +++ /dev/null @@ -1,24 +0,0 @@ -# coding=utf-8 -""" - @project: maxkb - @Author:虎 - @file: ollama_chat_model.py - @date:2024/3/6 11:48 - @desc: -""" -from typing import List - -from langchain_community.chat_models import ChatOpenAI -from langchain_core.messages import BaseMessage, get_buffer_string - -from common.config.tokenizer_manage_config import TokenizerManage - - -class OllamaChatModel(ChatOpenAI): - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index 73239921..3839d5fc 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -19,106 +19,83 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ - BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode -from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel + BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage +from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential +from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel from smartdoc.conf import PROJECT_DIR "" - 
-class OllamaLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = OllamaModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - try: - model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base')) - except Exception as e: - raise AppApiException(ValidCode.valid_error.value, "API 域名无效") - exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if - model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] - if len(exist) == 0: - raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") - return True - - def encryption_dict(self, model_info: Dict[str, object]): - return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} - - def build_model(self, model_info: Dict[str, object]): - for key in ['api_key', 'model']: - if key not in model_info: - raise AppApiException(500, f'{key} 字段为必填字段') - self.api_key = model_info.get('api_key') - return self - - api_base = forms.TextInputField('API 域名', required=True) - api_key = forms.PasswordInputField('API Key', required=True) - - ollama_llm_model_credential = OllamaLLMModelCredential() - -model_dict = { - 'llama2': ModelInfo( +model_info_list = [ + ModelInfo( 'llama2', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 7B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'llama2:13b': ModelInfo( + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'llama2:13b', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 13B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'llama2:70b': ModelInfo( + ModelTypeConst.LLM, 
ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'llama2:70b', 'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 70B 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'llama2-chinese:13b': ModelInfo( + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'llama2-chinese:13b', '由于Llama2本身的中文对齐较弱,我们采用中文指令集,对meta-llama/Llama-2-13b-chat-hf进行LoRA微调,使其具备较强的中文对话能力。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'llama3:8b': ModelInfo( + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'llama3:8b', - 'Meta Llama 3:迄今为止最有能力的公开产品LLM。8亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'llama3:70b': ModelInfo( + 'Meta Llama 3:迄今为止最有能力的公开产品LLM。80亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'llama3:70b', - 'Meta Llama 3:迄今为止最有能力的公开产品LLM。70亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:0.5b': ModelInfo( + 'Meta Llama 3:迄今为止最有能力的公开产品LLM。700亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:0.5b', - 'qwen 1.5 0.5b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。0.5亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:1.8b': ModelInfo( + 'qwen 1.5 0.5b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。5亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:1.8b', - 'qwen 1.5 1.8b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1.8亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:4b': ModelInfo( + 'qwen 1.5 1.8b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。18亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:4b', - 'qwen 1.5 4b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。4亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:7b': ModelInfo( + 
'qwen 1.5 4b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。40亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + + ModelInfo( 'qwen:7b', - 'qwen 1.5 7b 相较于以往版本,模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。7亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:14b': ModelInfo( + 'qwen 1.5 7b 相较于以往版本,模型与人类偏好的对齐程度以及多语1言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。70亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:14b', - 'qwen 1.5 14b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。14亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:32b': ModelInfo( + 'qwen 1.5 14b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。140亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:32b', - 'qwen 1.5 32b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。32亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:72b': ModelInfo( + 'qwen 1.5 32b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。320亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:72b', - 'qwen 1.5 72b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。72亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'qwen:110b': ModelInfo( + 'qwen 1.5 72b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。720亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'qwen:110b', - 'qwen 1.5 110b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。110亿参数。', - ModelTypeConst.LLM, ollama_llm_model_credential), - 'phi3': ModelInfo( + 'qwen 1.5 110b 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。1100亿参数。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), + ModelInfo( 'phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', - 
ModelTypeConst.LLM, ollama_llm_model_credential), -} + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel) +] + +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo( + 'phi3', + 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build() def get_base_url(url: str): @@ -169,32 +146,14 @@ def convert(response_stream) -> Iterator[DownModelChunk]: class OllamaModelProvider(IModelProvider): + def get_model_info_manage(self): + return model_info_manage + def get_model_provide_info(self): return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon', 'ollama_icon_svg'))) - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] - - def get_model_list(self, model_type): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - # 如果使用模型不在配置中,则使用默认认证 - return ollama_llm_model_credential - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel: - api_base = model_credential.get('api_base') - base_url = get_base_url(api_base) - return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'), - openai_api_key=model_credential.get('api_key')) - def get_dialogue_number(self): return 2 diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py new file mode 100644 index 00000000..60ff3b88 --- /dev/null +++ 
b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:32 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAILLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py new file mode 100644 index 00000000..7ad5f49f --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -0,0 +1,34 @@ +# coding=utf-8 
+""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2024/4/18 15:28 + @desc: +""" +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_openai import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + azure_chat_open_ai = OpenAIChatModel( + model=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key') + ) + return azure_chat_open_ai + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py deleted file mode 100644 index 7271fe8a..00000000 --- a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py +++ /dev/null @@ -1,30 +0,0 @@ -# coding=utf-8 -""" - @project: maxkb - @Author:虎 - @file: openai_chat_model.py - @date:2024/4/18 15:28 - @desc: -""" -from typing import List - -from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai import ChatOpenAI - -from common.config.tokenizer_manage_config import TokenizerManage - - -class OpenAIChatModel(ChatOpenAI): - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - 
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 324e851c..6d12869c 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -7,127 +7,70 @@ @desc: """ import os -from typing import Dict -from langchain.schema import HumanMessage - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ - ModelInfo, \ - ModelTypeConst, ValidCode -from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from smartdoc.conf import PROJECT_DIR - -class OpenAILLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = OpenAIModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in 
['api_base', 'api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = OpenAIModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_base = forms.TextInputField('API 域名', required=True) - api_key = forms.PasswordInputField('API Key', required=True) - - openai_llm_model_credential = OpenAILLMModelCredential() +model_info_list = [ + ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + ), + ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-0125', + '2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-1106', + '2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-0613', + '[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用', + ModelTypeConst.LLM, openai_llm_model_credential, + 
OpenAIChatModel), + ModelInfo('gpt-4o-2024-05-13', + '2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo-2024-04-09', + '2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel) +] -model_dict = { - 'gpt-3.5-turbo': ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, - openai_llm_model_credential, - ), - 'gpt-4': ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4o': ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, - openai_llm_model_credential, - ), - 'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-3.5-turbo-0125': ModelInfo('gpt-3.5-turbo-0125', - '2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, - openai_llm_model_credential, - ), - 'gpt-3.5-turbo-1106': ModelInfo('gpt-3.5-turbo-1106', - '2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, - openai_llm_model_credential, - ), - 'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', - '[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4o-2024-05-13': ModelInfo('gpt-4o-2024-05-13', - '2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4-turbo-2024-04-09': 
ModelInfo('gpt-4-turbo-2024-04-09', - '2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4-0125-preview': ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens', - ModelTypeConst.LLM, openai_llm_model_credential, - ), - 'gpt-4-1106-preview': ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', - ModelTypeConst.LLM, openai_llm_model_credential, - ), -} +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + )).build() class OpenAIModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel: - azure_chat_open_ai = OpenAIChatModel( - model=model_name, - openai_api_base=model_credential.get('api_base'), - openai_api_key=model_credential.get('api_key') - ) - return azure_chat_open_ai - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return openai_llm_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_openai_provider', name='OpenAI', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'openai_model_provider', 'icon', 'openai_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git 
a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py new file mode 100644 index 00000000..e31624c9 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:41 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAILLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py similarity index 58% 
rename from apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py rename to apps/setting/models_provider/impl/qwen_model_provider/model/llm.py index d3894d1d..fa4b1109 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py @@ -2,19 +2,28 @@ """ @project: maxkb @Author:虎 - @file: qwen_chat_model.py + @file: llm.py @date:2024/4/28 11:44 @desc: """ -from typing import List +from typing import List, Dict from langchain_community.chat_models import ChatTongyi from langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel -class QwenChatModel(ChatTongyi): +class QwenChatModel(MaxKBBaseModel, ChatTongyi): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + chat_tong_yi = QwenChatModel( + model_name=model_name, + dashscope_api_key=model_credential.get('api_key') + ) + return chat_tong_yi + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff --git a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py index 179f9036..dd0a924b 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py @@ -7,87 +7,33 @@ @desc: """ import os -from typing import Dict -from langchain.schema import HumanMessage -from langchain_community.chat_models.tongyi import ChatTongyi - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from 
common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ - ModelInfo, IModelProvider, ValidCode -from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential + +from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel from smartdoc.conf import PROJECT_DIR - -class OpenAILLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = QwenModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - for key in ['api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = QwenModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - api_key = forms.PasswordInputField('API Key', required=True) - - qwen_model_credential = OpenAILLMModelCredential() -model_dict = { - 'qwen-turbo': ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential), - 'qwen-plus': ModelInfo('qwen-plus', '', 
ModelTypeConst.LLM, qwen_model_credential), - 'qwen-max': ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential) -} +module_info_list = [ + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), + ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), + ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel) +] + +model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info( + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build() class QwenModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi: - chat_tong_yi = QwenChatModel( - model_name=model_name, - dashscope_api_key=model_credential.get('api_key') - ) - return chat_tong_yi - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return qwen_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon', 'qwen_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py new file mode 100644 index 00000000..55166ade 
--- /dev/null +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/12 10:19 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class WenxinLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + model = provider.get_model(model_type, model_name, model_credential) + model_info = [model.lower() for model in model.client.models()] + if not model_info.__contains__(model_name.lower()): + raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持') + for key in ['api_key', 'secret_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model.invoke( + [HumanMessage(content='你好')]) + except Exception as e: + raise e + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'secret_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + self.secret_key = model_info.get('secret_key') + return self + + api_key = forms.PasswordInputField('API Key', required=True) + + secret_key = 
forms.PasswordInputField("Secret Key", required=True) diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py new file mode 100644 index 00000000..bddd1f6d --- /dev/null +++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py @@ -0,0 +1,33 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2023/11/10 17:45 + @desc: +""" +from typing import List, Dict + +from langchain.schema.messages import BaseMessage, get_buffer_string +from langchain_community.chat_models import QianfanChatEndpoint + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return QianfanChatModel(model=model_name, + qianfan_ak=model_credential.get('api_key'), + qianfan_sk=model_credential.get('secret_key'), + streaming=model_kwargs.get('streaming', False)) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py deleted file mode 100644 index c8892e05..00000000 --- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py +++ /dev/null @@ -1,32 +0,0 @@ -# coding=utf-8 -""" - @project: maxkb - @Author:虎 - @file: qian_fan_chat_model.py - @date:2023/11/10 17:45 - @desc: -""" -from typing import Optional, List, 
Any, Iterator, cast - -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models.base import BaseChatModel -from langchain.load import dumpd -from langchain.schema import LLMResult -from langchain.schema.language_model import LanguageModelInput -from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string -from langchain.schema.output import ChatGenerationChunk -from langchain.schema.runnable import RunnableConfig -from langchain_community.chat_models import QianfanChatEndpoint - -from common.config.tokenizer_manage_config import TokenizerManage - - -class QianfanChatModel(QianfanChatEndpoint): - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py b/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py index b61e2f6b..92e955fa 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py @@ -7,121 +7,53 @@ @desc: """ import os -from typing import Dict -from langchain.schema import HumanMessage -from langchain_community.chat_models import QianfanChatEndpoint -from qianfan import ChatCompletion - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ - ModelInfo, IModelProvider, ValidCode -from 
setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.wenxin_model_provider.credential.llm import WenxinLLMModelCredential +from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel from smartdoc.conf import PROJECT_DIR - -class WenxinLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = WenxinModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - model = WenxinModelProvider().get_model(model_type, model_name, model_credential) - model_info = [model.lower() for model in model.client.models()] - if not model_info.__contains__(model_name.lower()): - raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持') - for key in ['api_key', 'secret_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model.invoke( - [HumanMessage(content='你好')]) - except Exception as e: - raise e - return True - - def encryption_dict(self, model_info: Dict[str, object]): - return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))} - - def build_model(self, model_info: Dict[str, object]): - for key in ['api_key', 'secret_key', 'model']: - if key not in model_info: - raise AppApiException(500, f'{key} 字段为必填字段') - self.api_key = model_info.get('api_key') - self.secret_key = model_info.get('secret_key') - return self - - api_key = forms.PasswordInputField('API Key', required=True) - - secret_key = 
forms.PasswordInputField("Secret Key", required=True) - - win_xin_llm_model_credential = WenxinLLMModelCredential() -model_dict = { - 'ERNIE-Bot-4': ModelInfo('ERNIE-Bot-4', +model_info_list = [ModelInfo('ERNIE-Bot-4', 'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。', - ModelTypeConst.LLM, win_xin_llm_model_credential), + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('ERNIE-Bot', + 'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('ERNIE-Bot-turbo', + 'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('BLOOMZ-7B', + 'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('Llama-2-7b-chat', + 'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('Llama-2-13b-chat', + 'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('Llama-2-70b-chat', + 'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel), + ModelInfo('Qianfan-Chinese-Llama-2-7B', + '千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。', + ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel) + ] - 'ERNIE-Bot': ModelInfo('ERNIE-Bot', - 'ERNIE-Bot是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'ERNIE-Bot-turbo': ModelInfo('ERNIE-Bot-turbo', - 'ERNIE-Bot-turbo是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力,响应速度更快。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'BLOOMZ-7B': 
ModelInfo('BLOOMZ-7B', - 'BLOOMZ-7B是业内知名的大语言模型,由BigScience研发并开源,能够以46种语言和13种编程语言输出文本。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'Llama-2-7b-chat': ModelInfo('Llama-2-7b-chat', - 'Llama-2-7b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-7b-chat是高性能原生开源版本,适用于对话场景。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'Llama-2-13b-chat': ModelInfo('Llama-2-13b-chat', - 'Llama-2-13b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-13b-chat是性能与效果均衡的原生开源版本,适用于对话场景。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'Llama-2-70b-chat': ModelInfo('Llama-2-70b-chat', - 'Llama-2-70b-chat由Meta AI研发并开源,在编码、推理及知识应用等场景表现优秀,Llama-2-70b-chat是高精度效果的原生开源版本。', - ModelTypeConst.LLM, win_xin_llm_model_credential), - - 'Qianfan-Chinese-Llama-2-7B': ModelInfo('Qianfan-Chinese-Llama-2-7B', - '千帆团队在Llama-2-7b基础上的中文增强版本,在CMMLU、C-EVAL等中文知识库上表现优异。', - ModelTypeConst.LLM, win_xin_llm_model_credential) -} +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo('ERNIE-Bot-4', + 'ERNIE-Bot-4是百度自行研发的大语言模型,覆盖海量中文数据,具有更强的对话问答、内容创作生成等能力。', + ModelTypeConst.LLM, + win_xin_llm_model_credential, + QianfanChatModel)).build() class WenxinModelProvider(IModelProvider): - def get_dialogue_number(self): - return 2 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], - **model_kwargs) -> QianfanChatEndpoint: - return QianfanChatModel(model=model_name, - qianfan_ak=model_credential.get('api_key'), - qianfan_sk=model_credential.get('secret_key'), - streaming=model_kwargs.get('streaming', False)) - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] - - def get_model_list(self, model_type): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_credential(self, model_type, model_name): - if 
model_name in model_dict: - return model_dict.get(model_name).model_credential - return win_xin_llm_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content( diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py new file mode 100644 index 00000000..88d47266 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py @@ -0,0 +1,51 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/12 10:29 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def 
encryption_dict(self, model: Dict[str, object]): + return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} + + spark_api_url = forms.TextInputField('API 域名', required=True) + spark_app_id = forms.TextInputField('APP ID', required=True) + spark_api_key = forms.PasswordInputField("API Key", required=True) + spark_api_secret = forms.PasswordInputField('API Secret', required=True) diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py similarity index 74% rename from apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py rename to apps/setting/models_provider/impl/xf_model_provider/model/llm.py index 3b6a22c4..a9e41dbc 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py @@ -7,7 +7,7 @@ @desc: """ -from typing import List, Optional, Any, Iterator +from typing import List, Optional, Any, Iterator, Dict from langchain_community.chat_models import ChatSparkLLM from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk @@ -16,9 +16,21 @@ from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_stri from langchain_core.outputs import ChatGenerationChunk from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel -class XFChatSparkLLM(ChatSparkLLM): +class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return XFChatSparkLLM( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + 
spark_llm_domain=model_name + ) + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 28059c5c..d33b944e 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -7,97 +7,33 @@ @desc: """ import os -from typing import Dict - -from langchain.schema import HumanMessage -from langchain_community.chat_models import ChatSparkLLM - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ - ModelInfo, IModelProvider, ValidCode -from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM -from smartdoc.conf import PROJECT_DIR import ssl +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential +from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM +from smartdoc.conf import PROJECT_DIR + ssl._create_default_https_context = ssl.create_default_context() - -class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): - model_type_list = XunFeiModelProvider().get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') 
== model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = XunFeiModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - - spark_api_url = forms.TextInputField('API 域名', required=True) - spark_app_id = forms.TextInputField('APP ID', required=True) - spark_api_key = forms.PasswordInputField("API Key", required=True) - spark_api_secret = forms.PasswordInputField('API Secret', required=True) - - qwen_model_credential = XunFeiLLMModelCredential() +model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM) + ] -model_dict = { - 'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential), - 'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential), - 'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential) -} +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build() class 
XunFeiModelProvider(IModelProvider): - def get_dialogue_number(self): - return 3 - - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM: - zhipuai_chat = XFChatSparkLLM( - spark_app_id=model_credential.get('spark_app_id'), - spark_api_key=model_credential.get('spark_api_key'), - spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - spark_llm_domain=model_name - ) - return zhipuai_chat - - def get_model_credential(self, model_type, model_name): - if model_name in model_dict: - return model_dict.get(model_name).model_credential - return qwen_model_credential + def get_model_info_manage(self): + return model_info_manage def get_model_provide_info(self): return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content( os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon', 'xf_icon_svg'))) - - def get_model_list(self, model_type: str): - if model_type is None: - raise AppApiException(500, '模型类型不能为空') - return [model_dict.get(key).to_dict() for key in - list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] - - def get_model_type_list(self): - return [{'key': "大语言模型", 'value': "LLM"}] diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py new file mode 100644 index 00000000..a88f65ce --- /dev/null +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/12 10:46 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from 
setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py similarity index 57% rename from apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py rename to apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py index ceab8988..86c5b1a4 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py @@ -2,19 +2,29 @@ """ @project: maxkb @Author:虎 - @file: zhipu_chat_model.py + @file: llm.py @date:2024/4/28 11:42 @desc: """ -from typing import List +from typing import List, Dict from langchain_community.chat_models import ChatZhipuAI from 
langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel -class ZhipuChatModel(ChatZhipuAI): +class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + zhipuai_chat = ZhipuChatModel( + temperature=0.5, + api_key=model_credential.get('api_key'), + model=model_name + ) + return zhipuai_chat + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py index ebbb3b46..da872c54 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py @@ -7,88 +7,30 @@ @desc: """ import os -from typing import Dict -from langchain.schema import HumanMessage -from langchain_community.chat_models import ChatZhipuAI - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm from common.util.file_util import get_file_content -from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ - ModelInfo, IModelProvider, ValidCode -from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential +from 
# Single shared credential form instance for all Zhipu LLM entries.
qwen_model_credential = ZhiPuLLMModelCredential()

# Every Zhipu model exposed to MaxKB: (name, description, type, credential
# form, concrete model class).
model_info_list = [
    ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
    ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
    ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)
]

# Registry consumed by IModelProvider's template methods; glm-4 is appended
# again individually — presumably to mark it as the provider's default entry
# (ModelInfoManage semantics live in base_model_provider).
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info(
    ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)).build()


class ZhiPuModelProvider(IModelProvider):
    """Provider plug-in for the Zhipu AI (智谱) platform."""

    def get_model_info_manage(self):
        """Expose the module-level model registry to the base provider."""
        return model_info_manage

    def get_model_provide_info(self):
        """Describe this provider (key, display name, inline SVG icon)."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider',
                                 'icon', 'zhipuai_icon_svg')
        return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI',
                                icon=get_file_content(icon_path))
credential[k] == source_encryption_model_credential[k]: credential[k] = source_model_credential[k] - return credential, model_credential + return credential, model_credential, provider_handler class Create(serializers.Serializer): user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -145,13 +145,10 @@ class ModelSerializer(serializers.Serializer): name=self.data.get('name')).exists(): raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') # 校验模型认证数据 - ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'), - self.data.get( - 'model_name')).is_valid( - self.data.get('model_type'), - self.data.get('model_name'), - self.data.get('credential'), - raise_exception=True) + ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'), + self.data.get('model_name'), + self.data.get('credential') + ) def insert(self, user_id, with_valid=False): status = Status.SUCCESS @@ -232,16 +229,17 @@ class ModelSerializer(serializers.Serializer): if model is None: raise AppApiException(500, '不存在的id') else: - credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid( + credential, model_credential, provider_handler = ModelSerializer.Edit( + data={**instance, 'user_id': user_id}).is_valid( model=model) try: model.status = Status.SUCCESS # 校验模型认证数据 - model_credential.is_valid( - model.model_type, - instance.get("model_name"), - credential, - raise_exception=True) + provider_handler.is_valid_credential(model.model_type, + instance.get("model_name"), + credential, + raise_exception=True) + except AppApiException as e: if e.code == ValidCode.model_not_fount: model.status = Status.DOWNLOAD