feat: 大语言模型支持自定义参数入参

This commit is contained in:
shaohuzhang1 2024-10-25 16:40:22 +08:00 committed by shaohuzhang1
parent cbdf4b7fd9
commit 36efb58a05
16 changed files with 40 additions and 86 deletions

View File

@ -93,6 +93,14 @@ class MaxKBBaseModel(ABC):
def is_cache_model(): def is_cache_model():
return True return True
@staticmethod
def filter_optional_params(model_kwargs):
optional_params = {}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params[key] = value
return optional_params
class BaseModelCredential(ABC): class BaseModelCredential(ABC):

View File

@ -40,12 +40,7 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
@classmethod @classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel': **model_kwargs) -> 'BedrockModel':
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
keyword = get_max_tokens_keyword(model_name)
optional_params[keyword] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return cls( return cls(
model_id=model_name, model_id=model_name,

View File

@ -26,11 +26,7 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return AzureChatModel( return AzureChatModel(
azure_endpoint=model_credential.get('api_base'), azure_endpoint=model_credential.get('api_base'),

View File

@ -20,11 +20,7 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
deepseek_chat_open_ai = DeepSeekChatModel( deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name, model=model_name,

View File

@ -30,11 +30,7 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
if 'max_tokens' in model_kwargs:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
gemini_chat = GeminiChatModel( gemini_chat = GeminiChatModel(
model=model_name, model=model_name,

View File

@ -20,11 +20,7 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
kimi_chat_open_ai = KimiChatModel( kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'], openai_api_base=model_credential['api_base'],

View File

@ -34,11 +34,7 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
api_base = model_credential.get('api_base', '') api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base) base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return OllamaChatModel(model=model_name, openai_api_base=base_url, return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'), openai_api_key=model_credential.get('api_key'),

View File

@ -30,11 +30,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
azure_chat_open_ai = OpenAIChatModel( azure_chat_open_ai = OpenAIChatModel(
model=model_name, model=model_name,
openai_api_base=model_credential.get('api_base'), openai_api_base=model_credential.get('api_base'),

View File

@ -27,11 +27,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
chat_tong_yi = QwenChatModel( chat_tong_yi = QwenChatModel(
model_name=model_name, model_name=model_name,
dashscope_api_key=model_credential.get('api_key'), dashscope_api_key=model_credential.get('api_key'),

View File

@ -18,9 +18,7 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
hunyuan_secret_id = credentials.get('hunyuan_secret_id') hunyuan_secret_id = credentials.get('hunyuan_secret_id')
hunyuan_secret_key = credentials.get('hunyuan_secret_key') hunyuan_secret_key = credentials.get('hunyuan_secret_key')
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
if 'temperature' in kwargs:
optional_params['temperature'] = kwargs['temperature']
if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]): if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
raise ValueError( raise ValueError(

View File

@ -22,11 +22,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
vllm_chat_open_ai = VllmChatModel( vllm_chat_open_ai = VllmChatModel(
model=model_name, model=model_name,
openai_api_base=model_credential.get('api_base'), openai_api_base=model_credential.get('api_base'),

View File

@ -12,11 +12,7 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return VolcanicEngineChatModel( return VolcanicEngineChatModel(
model=model_name, model=model_name,
openai_api_base=model_credential.get('api_base'), openai_api_base=model_credential.get('api_base'),

View File

@ -26,11 +26,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return QianfanChatModel(model=model_name, return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'), qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'), qianfan_sk=model_credential.get('secret_key'),

View File

@ -6,11 +6,10 @@
@date2024/04/19 15:55 @date2024/04/19 15:55
@desc: @desc:
""" """
import json
from typing import List, Optional, Any, Iterator, Dict from typing import List, Optional, Any, Iterator, Dict
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \ from langchain_community.chat_models.sparkllm import \
ChatSparkLLM ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk from langchain_core.outputs import ChatGenerationChunk
@ -25,11 +24,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return XFChatSparkLLM( return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'), spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'), spark_api_key=model_credential.get('spark_api_key'),

View File

@ -1,8 +1,12 @@
# coding=utf-8 # coding=utf-8
from typing import Dict from typing import Dict, Optional, List, Any, Iterator
from urllib.parse import urlparse, ParseResult from urllib.parse import urlparse, ParseResult
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessageChunk
from langchain_core.runnables import RunnableConfig
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@ -26,11 +30,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
api_base = model_credential.get('api_base', '') api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base) base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return XinferenceChatModel( return XinferenceChatModel(
model=model_name, model=model_name,
openai_api_base=base_url, openai_api_base=base_url,

View File

@ -7,43 +7,41 @@
@desc: @desc:
""" """
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
import json import json
from collections.abc import Iterator from collections.abc import Iterator
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from langchain_core.callbacks import ( from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.messages import ( from langchain_core.messages import (
AIMessageChunk, AIMessageChunk,
BaseMessage BaseMessage
) )
from langchain_core.outputs import ChatGenerationChunk from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
optional_params: dict
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
return False return False
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {} optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
zhipuai_chat = ZhipuChatModel( zhipuai_chat = ZhipuChatModel(
api_key=model_credential.get('api_key'), api_key=model_credential.get('api_key'),
model=model_name, model=model_name,
streaming=model_kwargs.get('streaming', False), streaming=model_kwargs.get('streaming', False),
**optional_params optional_params=optional_params,
**optional_params,
) )
return zhipuai_chat return zhipuai_chat
@ -71,7 +69,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
if self.zhipuai_api_base is None: if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.") raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True} payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
_truncate_params(payload) _truncate_params(payload)
headers = { headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key), "Authorization": _get_jwt_token(self.zhipuai_api_key),