Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
1cf8008f68
@ -10,7 +10,7 @@ import os
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
from langchain.schema import HumanMessage
|
||||||
from langchain_community.chat_models import AzureChatOpenAI
|
from langchain_community.chat_models.azure_openai import AzureChatOpenAI
|
||||||
|
|
||||||
from common import froms
|
from common import froms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -29,9 +29,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
if model_name not in model_dict:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
|
|
||||||
|
|
||||||
for key in ['api_base', 'api_key', 'deployment_name']:
|
for key in ['api_base', 'api_key', 'deployment_name']:
|
||||||
if key not in model_credential:
|
if key not in model_credential:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
@ -40,7 +37,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
|
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
|
||||||
model.invoke([HumanMessage(content='valid')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
raise e
|
raise e
|
||||||
@ -61,8 +58,48 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
deployment_name = froms.TextInputField("部署名", required=True)
|
deployment_name = froms.TextInputField("部署名", required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||||
|
model_type_list = AzureModelProvider().get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
|
||||||
|
model.invoke([HumanMessage(content='你好')])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_version = froms.TextInputField("api_version", required=True)
|
||||||
|
|
||||||
|
api_base = froms.TextInputField('API 域名', required=True)
|
||||||
|
|
||||||
|
api_key = froms.PasswordInputField("API Key", required=True)
|
||||||
|
|
||||||
|
deployment_name = froms.TextInputField("部署名", required=True)
|
||||||
|
|
||||||
|
|
||||||
azure_llm_model_credential = AzureLLMModelCredential()
|
azure_llm_model_credential = AzureLLMModelCredential()
|
||||||
|
|
||||||
|
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
|
||||||
|
|
||||||
model_dict = {
|
model_dict = {
|
||||||
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
|
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
|
||||||
api_version='2023-07-01-preview'),
|
api_version='2023-07-01-preview'),
|
||||||
@ -84,18 +121,18 @@ class AzureModelProvider(IModelProvider):
|
|||||||
model_info: ModelInfo = model_dict.get(model_name)
|
model_info: ModelInfo = model_dict.get(model_name)
|
||||||
azure_chat_open_ai = AzureChatOpenAI(
|
azure_chat_open_ai = AzureChatOpenAI(
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_version=model_info.api_version,
|
openai_api_version=model_credential.get(
|
||||||
|
'api_version') if 'api_version' in model_credential else model_info.api_version,
|
||||||
deployment_name=model_credential.get('deployment_name'),
|
deployment_name=model_credential.get('deployment_name'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
openai_api_type="azure",
|
openai_api_type="azure"
|
||||||
tiktoken_model_name=model_name
|
|
||||||
)
|
)
|
||||||
return azure_chat_open_ai
|
return azure_chat_open_ai
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
def get_model_credential(self, model_type, model_name):
|
||||||
if model_name in model_dict:
|
if model_name in model_dict:
|
||||||
return model_dict.get(model_name).model_credential
|
return model_dict.get(model_name).model_credential
|
||||||
raise AppApiException(500, f'不支持的模型:{model_name}')
|
return base_azure_llm_model_credential
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
|
||||||
|
|||||||
@ -9,8 +9,9 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import QianfanChatEndpoint
|
|
||||||
from langchain.schema import HumanMessage
|
from langchain.schema import HumanMessage
|
||||||
|
from langchain_community.chat_models import QianfanChatEndpoint
|
||||||
|
from qianfan import ChatCompletion
|
||||||
|
|
||||||
from common import froms
|
from common import froms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -27,10 +28,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
model_type_list = WenxinModelProvider().get_model_type_list()
|
model_type_list = WenxinModelProvider().get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
model_info = [model.lower() for model in ChatCompletion.models()]
|
||||||
if model_name not in model_dict:
|
if not model_info.__contains__(model_name.lower()):
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
||||||
|
|
||||||
for key in ['api_key', 'secret_key']:
|
for key in ['api_key', 'secret_key']:
|
||||||
if key not in model_credential:
|
if key not in model_credential:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
@ -39,10 +39,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
|
WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
|
||||||
[HumanMessage(content='valid')])
|
[HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if raise_exception:
|
raise e
|
||||||
raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def encryption_dict(self, model_info: Dict[str, object]):
|
def encryption_dict(self, model_info: Dict[str, object]):
|
||||||
@ -121,7 +120,7 @@ class WenxinModelProvider(IModelProvider):
|
|||||||
def get_model_credential(self, model_type, model_name):
|
def get_model_credential(self, model_type, model_name):
|
||||||
if model_name in model_dict:
|
if model_name in model_dict:
|
||||||
return model_dict.get(model_name).model_credential
|
return model_dict.get(model_name).model_credential
|
||||||
raise AppApiException(500, f'不支持的模型:{model_name}')
|
return win_xin_llm_model_credential
|
||||||
|
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
|
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
|
||||||
|
|||||||
@ -23,7 +23,7 @@ sentence-transformers = "^2.2.2"
|
|||||||
blinker = "^1.6.3"
|
blinker = "^1.6.3"
|
||||||
openai = "^1.13.3"
|
openai = "^1.13.3"
|
||||||
tiktoken = "^0.5.1"
|
tiktoken = "^0.5.1"
|
||||||
qianfan = "^0.1.1"
|
qianfan = "^0.3.6.1"
|
||||||
pycryptodome = "^3.19.0"
|
pycryptodome = "^3.19.0"
|
||||||
beautifulsoup4 = "^4.12.2"
|
beautifulsoup4 = "^4.12.2"
|
||||||
html2text = "^2024.2.26"
|
html2text = "^2024.2.26"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user