From 2301190c4d8af546ca92f08a5b1a14777a0382a3 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Mon, 20 Jan 2025 13:50:51 +0800 Subject: [PATCH] feat: Support vllm image model --- .../vllm_model_provider/credential/image.py | 68 +++++++++++++++++++ .../impl/vllm_model_provider/model/image.py | 20 ++++++ .../vllm_model_provider.py | 24 +++++-- 3 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/credential/image.py create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/model/image.py diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/image.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/image.py new file mode 100644 index 00000000..9cc8215e --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/image.py @@ -0,0 +1,68 @@ +# coding=utf-8 +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from django.utils.translation import gettext_lazy as _ + +class VllmImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + + +class VllmImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API Url', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, _('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) + for chunk in res: + print(chunk) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, _('Verification failed, please check whether the parameters are correct: {error}').format(error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return VllmImageModelParams() diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py new file mode 100644 index 00000000..f3b69a38 --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py @@ -0,0 +1,20 @@ +from typing import Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +class VllmImage(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return VllmImage( + model_name=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + # stream_options={"include_usage": True}, + streaming=True, + stream_usage=True, + **optional_params, + ) diff --git a/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py index aaeec966..96169591 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py @@ -7,12 +7,16 @@ import requests from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ ModelInfoManage +from setting.models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential from setting.models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential +from setting.models_provider.impl.vllm_model_provider.model.image import VllmImage from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext_lazy as _ v_llm_model_credential = VLLMModelCredential() +image_model_credential = VllmImageModelCredential() + model_info_list = [ ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), @@ -20,12 +24,20 @@ model_info_list = [ ] -model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - ModelInfo( - 'facebook/opt-125m', - _('Facebook’s 125M parameter model'), - ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)) - .build()) +image_model_info_list = [ + ModelInfo('Qwen/Qwen2-VL-2B-Instruct', '', ModelTypeConst.IMAGE, image_model_credential, VllmImage), +] + +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info(ModelInfo('facebook/opt-125m', + _('Facebook’s 125M parameter model'), + ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)) + .append_model_info_list(image_model_info_list) + .append_default_model_info(image_model_info_list[0]) + .build() +) def get_base_url(url: str):