feat: 增加对接模型
This commit is contained in:
parent
35f0c18dd3
commit
72423a7c3e
@ -8,6 +8,7 @@
|
|||||||
"""
|
"""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
|
||||||
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
||||||
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
||||||
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
||||||
@ -15,6 +16,9 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
|
|||||||
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
||||||
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
||||||
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
|
||||||
|
VolcanicEngineModelProvider
|
||||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||||
@ -32,4 +36,7 @@ class ModelProvideConstants(Enum):
|
|||||||
model_xf_provider = XunFeiModelProvider()
|
model_xf_provider = XunFeiModelProvider()
|
||||||
model_deepseek_provider = DeepSeekModelProvider()
|
model_deepseek_provider = DeepSeekModelProvider()
|
||||||
model_gemini_provider = GeminiModelProvider()
|
model_gemini_provider = GeminiModelProvider()
|
||||||
|
model_volcanic_engine_provider = VolcanicEngineModelProvider()
|
||||||
|
model_tencent_provider = TencentModelProvider()
|
||||||
|
model_aws_bedrock_provider = BedrockModelProvider()
|
||||||
model_local_provider = LocalModelProvider()
|
model_local_provider = LocalModelProvider()
|
||||||
|
|||||||
@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
from common.util.file_util import get_file_content
|
||||||
|
from setting.models_provider.base_model_provider import (
|
||||||
|
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||||
|
)
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||||
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Build one ModelInfo entry for the Bedrock provider.

    Instantiates the credential form class and bundles it with the model
    metadata so the provider's model list can be declared as flat data.
    """
    model_credential = credential_class()
    return ModelInfo(name=model_name,
                     desc=description,
                     model_type=model_type,
                     model_credential=model_credential,
                     model_class=model_class)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_aws_bedrock_icon_path():
    """Return the absolute path of the Bedrock provider icon file.

    NOTE(review): the file name 'bedrock_icon_svg' carries no '.svg'
    extension — confirm it matches the icon file actually shipped on disk.
    """
    icon_parts = ('apps', 'setting', 'models_provider', 'impl',
                  'aws_bedrock_model_provider', 'icon', 'bedrock_icon_svg')
    return os.path.join(PROJECT_DIR, *icon_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_model_info():
    """Build the ModelInfoManage registry for the Bedrock provider.

    Declares every supported Bedrock model (id, display description, model
    type, credential form, client class) and registers the first LLM entry
    as the provider's default model.

    Returns:
        ModelInfoManage: the fully built model registry.
    """
    model_info_list = [_create_model_info(
        'amazon.titan-text-premier-v1:0',
        'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
        ModelTypeConst.LLM,
        BedrockLLMModelCredential,
        BedrockModel
    ),
        _create_model_info(
            'amazon.titan-text-lite-v1',
            'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-text-express-v1',
            'Amazon Titan Text Express 的上下文长度长达 8000 个令牌,因而非常适合各种高级常规语言任务,例如开放式文本生成和对话式聊天,以及检索增强生成(RAG)中的支持。在发布时,该模型针对英语进行了优化,但也支持其他语言。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-embed-text-v2:0',
            'Amazon Titan Text Embeddings V2 是一种轻量级、高效的模型,非常适合在不同维度上执行高精度检索任务。该模型支持灵活的嵌入大小(1024、512 和 256),并优先考虑在较小维度上保持准确性,从而可以在不影响准确性的情况下降低存储成本。Titan Text Embeddings V2 适用于各种任务,包括文档检索、推荐系统、搜索引擎和对话式系统。',
            # Fix: this is an embeddings model (per its own description) but was
            # tagged ModelTypeConst.LLM — its type was swapped with the
            # mistral-7b entry below. NOTE(review): it still uses the LLM
            # credential/client classes; confirm an embedding client exists.
            ModelTypeConst.EMBEDDING,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'mistral.mistral-7b-instruct-v0:2',
            '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
            # Fix: a 7B instruct chat model was tagged EMBEDDING by mistake
            # (swapped with the titan-embed entry above).
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'mistral.mistral-large-2402-v1:0',
            '先进的 Mistral AI 大型语言模型,能够处理任何语言任务,包括复杂的多语言推理、文本理解、转换和代码生成。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-70b-instruct-v1:0',
            '非常适合内容创作、会话式人工智能、语言理解、研发和企业应用',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-8b-instruct-v1:0',
            '非常适合有限的计算能力和资源、边缘设备和更快的训练时间。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
    ]

    # The first entry (titan-text-premier) doubles as the default model.
    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()

    return model_info_manage
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockModelProvider(IModelProvider):
    """Model provider implementation for Amazon Bedrock."""

    def __init__(self):
        # Build the model registry once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        """Return the registry of models this provider exposes."""
        return self._model_info_manage

    def get_model_provide_info(self):
        """Return provider metadata: identifier, display name and icon."""
        icon_data = get_file_content(_get_aws_bedrock_icon_path())
        return ModelProvideInfo(provider='model_aws_bedrock_provider',
                                name='Amazon Bedrock',
                                icon=icon_data)
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from tencentcloud.common import credential
|
||||||
|
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||||
|
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||||
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan embedding models.

    Validates the SecretId/SecretKey pair by issuing a one-off
    GetEmbedding request against the Hunyuan endpoint.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        """Raise AppApiException if `provider` does not offer `model_type`."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        """Check the required keys and build a tencentcloud Credential."""
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        """Call GetEmbedding once to prove the credentials actually work."""
        req = models.GetEmbeddingRequest()
        req.from_json_string(json.dumps({"Model": model_name, "Input": "测试"}))
        try:
            # Fix: the response is no longer printed — the original leaked the
            # raw API payload to stdout on every credential validation.
            client.GetEmbedding(req)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Return True when the credential can reach the Hunyuan API.

        When `raise_exception` is True, validation failures propagate as
        AppApiException; otherwise False is returned.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException as e:
            if raise_exception:
                raise e
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return a copy of `model` with the SecretKey encrypted for storage."""
        encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
        return {**model, 'SecretKey': encrypted_secret_key}

    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)
|
||||||
@ -0,0 +1,62 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from common import forms
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for AWS Bedrock LLM models.

    Persists the access key pair into a named profile in the shared
    ~/.aws/credentials file, then validates it with a one-message invoke.
    """

    @staticmethod
    def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
        """Write (or replace) `profile_name` in ~/.aws/credentials."""
        credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

        # Fix: read through a context manager — the original used
        # open(...).read(), which leaked the file handle.
        content = ''
        if os.path.exists(credentials_path):
            with open(credentials_path, 'r') as file:
                content = file.read()

        # Drop any stale copy of this profile before re-appending it.
        pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
        content = re.sub(pattern, '', content, flags=re.DOTALL)

        if not re.search(rf'\[{profile_name}\]', content):
            content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"

        with open(credentials_path, 'w') as file:
            file.write(content)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by invoking the model with one message.

        NOTE: mutates `model_credential` by injecting
        'credentials_profile_name', which provider.get_model relies on.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
            return False

        try:
            self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
                                         model_credential['secret_access_key'])
            model_credential['credentials_profile_name'] = 'aws-profile'
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Domain errors always propagate unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of `model` with the secret access key encrypted."""
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" version="2.0" focusable="false" aria-hidden="true" class="globalNav-1216 globalNav-1213" data-testid="awsc-logo" viewBox="0 0 29 17"><path class="globalNav-1214" d="M8.38 6.17a2.6 2.6 0 00.11.83c.08.232.18.456.3.67a.4.4 0 01.07.21.36.36 0 01-.18.28l-.59.39a.43.43 0 01-.24.08.38.38 0 01-.28-.13 2.38 2.38 0 01-.34-.43c-.09-.16-.18-.34-.28-.55a3.44 3.44 0 01-2.74 1.29 2.54 2.54 0 01-1.86-.67 2.36 2.36 0 01-.68-1.79 2.43 2.43 0 01.84-1.92 3.43 3.43 0 012.29-.72 6.75 6.75 0 011 .07c.35.05.7.12 1.07.2V3.3a2.06 2.06 0 00-.44-1.49 2.12 2.12 0 00-1.52-.43 4.4 4.4 0 00-1 .12 6.85 6.85 0 00-1 .32l-.33.12h-.14c-.14 0-.2-.1-.2-.29v-.46A.62.62 0 012.3.87a.78.78 0 01.27-.2A6 6 0 013.74.25 5.7 5.7 0 015.19.07a3.37 3.37 0 012.44.76 3 3 0 01.77 2.29l-.02 3.05zM4.6 7.59a3 3 0 001-.17 2 2 0 00.88-.6 1.36 1.36 0 00.32-.59 3.18 3.18 0 00.09-.81V5A7.52 7.52 0 006 4.87h-.88a2.13 2.13 0 00-1.38.37 1.3 1.3 0 00-.46 1.08 1.3 1.3 0 00.34 1c.278.216.63.313.98.27zm7.49 1a.56.56 0 01-.36-.09.73.73 0 01-.2-.37L9.35.93a1.39 1.39 0 01-.08-.38c0-.15.07-.23.22-.23h.92a.56.56 0 01.36.09.74.74 0 01.19.37L12.53 7 14 .79a.61.61 0 01.18-.37.59.59 0 01.37-.09h.75a.62.62 0 01.38.09.74.74 0 01.18.37L17.31 7 18.92.76a.74.74 0 01.19-.37.56.56 0 01.36-.09h.87a.21.21 0 01.23.23 1 1 0 010 .15s0 .13-.06.23l-2.26 7.2a.74.74 0 01-.19.37.6.6 0 01-.36.09h-.8a.53.53 0 01-.37-.1.64.64 0 01-.18-.37l-1.45-6-1.44 6a.64.64 0 01-.18.37.55.55 0 01-.37.1l-.82.02zm12 .24a6.29 6.29 0 01-1.44-.16 4.21 4.21 0 01-1.07-.37.69.69 0 01-.29-.26.66.66 0 01-.06-.27V7.3c0-.19.07-.29.21-.29a.57.57 0 01.18 0l.23.1c.32.143.656.25 1 .32.365.08.737.12 1.11.12a2.47 2.47 0 001.36-.31 1 1 0 00.48-.88.88.88 0 00-.25-.65 2.29 2.29 0 00-.94-.49l-1.35-.43a2.83 2.83 0 01-1.49-.94 2.24 2.24 0 01-.47-1.36 2 2 0 01.25-1c.167-.3.395-.563.67-.77a3 3 0 011-.48A4.1 4.1 0 0124.4.08a4.4 4.4 0 01.62 0l.61.1.53.15.39.16c.105.062.2.14.28.23a.57.57 0 01.08.31v.44c0 .2-.07.3-.21.3a.92.92 0 01-.36-.12 4.35 4.35 0 00-1.8-.36 
2.51 2.51 0 00-1.24.26.92.92 0 00-.44.84c0 .249.1.488.28.66.295.236.635.41 1 .51l1.32.42a2.88 2.88 0 011.44.9 2.1 2.1 0 01.43 1.31 2.38 2.38 0 01-.24 1.08 2.34 2.34 0 01-.68.82 3 3 0 01-1 .53 4.59 4.59 0 01-1.35.22l.03-.01z"></path><path class="globalNav-1215" d="M25.82 13.43a20.07 20.07 0 01-11.35 3.47A20.54 20.54 0 01.61 11.62c-.29-.26 0-.62.32-.42a27.81 27.81 0 0013.86 3.68 27.54 27.54 0 0010.58-2.16c.52-.22.96.34.45.71z"></path><path class="globalNav-1215" d="M27.1 12c-.4-.51-2.6-.24-3.59-.12-.3 0-.34-.23-.07-.42 1.75-1.23 4.63-.88 5-.46.37.42-.09 3.3-1.74 4.68-.25.21-.49.09-.38-.18.34-.95 1.17-3.02.78-3.5z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 2.6 KiB |
@ -0,0 +1,25 @@
|
|||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from typing import Dict
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class TencentEmbeddingModel(MaxKBBaseModel):
    """Holder for the settings of a Tencent embedding model.

    NOTE(review): no embedding call is implemented here yet — the class
    only stores credentials and exposes a token helper.
    """

    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_base = api_base
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        """Factory used by the provider framework to build an instance."""
        return TencentEmbeddingModel(secret_id=model_credential.get('SecretId'),
                                     secret_key=model_credential.get('SecretKey'),
                                     api_base=model_credential.get('api_base'),
                                     model_name=model_name)

    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API
        return f"{self.secret_id}:{self.secret_key}"
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
from langchain_community.chat_models import BedrockChat
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockModel(MaxKBBaseModel, BedrockChat):
    """MaxKB wrapper around langchain's BedrockChat with token counting."""

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, **kwargs):
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name,
                         streaming=streaming, **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        """Factory used by the provider framework to build an instance."""
        streaming = model_kwargs.pop('streaming', False)
        return cls(model_id=model_name,
                   region_name=model_credential['region_name'],
                   credentials_profile_name=model_credential['credentials_profile_name'],
                   streaming=streaming,
                   **model_kwargs)

    def _get_num_tokens(self, content: str) -> int:
        """Count the tokens of one string via the shared tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(content))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the token counts of each message rendered to plain text."""
        total = 0
        for message in messages:
            total += self._get_num_tokens(get_buffer_string([message]))
        return total

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in `text` (delegates to the shared tokenizer)."""
        return self._get_num_tokens(text)
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from tencentcloud.common import credential
|
||||||
|
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||||
|
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||||
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan embedding models.

    Validates the SecretId/SecretKey pair by issuing a one-off
    GetEmbedding request against the Hunyuan endpoint.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        """Raise AppApiException if `provider` does not offer `model_type`."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        """Check the required keys and build a tencentcloud Credential."""
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        """Call GetEmbedding once to prove the credentials actually work."""
        req = models.GetEmbeddingRequest()
        req.from_json_string(json.dumps({"Model": model_name, "Input": "测试"}))
        try:
            # Fix: the response is no longer printed — the original leaked the
            # raw API payload to stdout on every credential validation.
            client.GetEmbedding(req)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Return True when the credential can reach the Hunyuan API.

        When `raise_exception` is True, validation failures propagate as
        AppApiException; otherwise False is returned.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException as e:
            if raise_exception:
                raise e
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return a copy of `model` with the SecretKey encrypted for storage."""
        encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
        return {**model, 'SecretKey': encrypted_secret_key}

    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan chat (LLM) models."""

    # Credential keys that must be present before validation is attempted.
    REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']

    @classmethod
    def _validate_model_type(cls, model_type, provider, raise_exception=False):
        """Return True when `provider` offers `model_type`."""
        supported = provider.get_model_type_list()
        if any(mt['value'] == model_type for mt in supported):
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return False

    @classmethod
    def _validate_credential_fields(cls, model_credential, raise_exception=False):
        """Return True when every required credential field is present."""
        missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
        if not missing_keys:
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
        return False

    def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
        """Validate the fields, then smoke-test the model with one message."""
        if not self._validate_model_type(model_type, provider, raise_exception):
            return False
        if not self._validate_credential_fields(model_credential, raise_exception):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model):
        """Return a copy of `model` with hunyuan_secret_key encrypted."""
        return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}

    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
|
||||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 74 KiB |
@ -0,0 +1,25 @@
|
|||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from typing import Dict
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class TencentEmbeddingModel(MaxKBBaseModel):
    """Holder for the settings of a Tencent embedding model.

    NOTE(review): no embedding call is implemented here yet — the class
    only stores credentials and exposes a token helper.
    """

    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_base = api_base
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        """Factory used by the provider framework to build an instance."""
        return TencentEmbeddingModel(secret_id=model_credential.get('SecretId'),
                                     secret_key=model_credential.get('SecretKey'),
                                     api_base=model_credential.get('api_base'),
                                     model_name=model_name)

    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API
        return f"{self.secret_id}:{self.secret_key}"
|
||||||
@ -0,0 +1,273 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||||
|
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models.chat_models import (
|
||||||
|
BaseChatModel,
|
||||||
|
generate_from_stream,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
BaseMessage,
|
||||||
|
BaseMessageChunk,
|
||||||
|
ChatMessage,
|
||||||
|
ChatMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
|
HumanMessageChunk,
|
||||||
|
)
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
|
from langchain_core.utils import (
|
||||||
|
convert_to_secret_str,
|
||||||
|
get_from_dict_or_env,
|
||||||
|
get_pydantic_field_names,
|
||||||
|
pre_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Map a langchain message onto the Hunyuan {"Role", "Content"} shape.

    Raises:
        TypeError: for message classes with no Hunyuan role mapping.
    """
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    else:
        raise TypeError(f"Got unknown type {message}")
    return {"Role": role, "Content": message.content}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Map a Hunyuan response dict back onto a langchain message."""
    role = _dict["Role"]
    if role == "user":
        return HumanMessage(content=_dict["Content"])
    if role == "assistant":
        # Hunyuan may return a missing/None Content; normalise to "".
        return AIMessage(content=_dict.get("Content", "") or "")
    # Any other role round-trips as a generic ChatMessage.
    return ChatMessage(content=_dict["Content"], role=role)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_delta_to_message_chunk(
        _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Map a streaming delta onto the matching message-chunk class.

    The explicit Role in the delta wins; otherwise `default_class`
    decides which chunk type to produce.
    """
    role = _dict.get("Role")
    content = _dict.get("Content") or ""

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
    """Build a ChatResult from a raw Hunyuan ChatCompletions response."""
    generations = [
        ChatGeneration(message=_convert_dict_to_message(choice["Message"]))
        for choice in response["Choices"]
    ]
    # Surface Hunyuan's usage block so callers can account for tokens.
    llm_output = {"token_usage": response["Usage"]}
    return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHunyuan(BaseChatModel):
    """Tencent Hunyuan chat models API by Tencent.

    For more information, see https://cloud.tencent.com/document/product/1729
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps credential fields to environment-variable names so LangChain
        # serialization can reference secrets without embedding their values.
        return {
            "hunyuan_app_id": "HUNYUAN_APP_ID",
            "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
            "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        # Marks this model as safe for LangChain (de)serialization.
        return True

    hunyuan_app_id: Optional[int] = None
    """Hunyuan App ID"""
    hunyuan_secret_id: Optional[str] = None
    """Hunyuan Secret ID"""
    hunyuan_secret_key: Optional[SecretStr] = None
    """Hunyuan Secret Key"""
    streaming: bool = False
    """Whether to stream the results or not."""
    request_timeout: int = 60
    """Timeout for requests to Hunyuan API. Default is 60 seconds."""
    temperature: float = 1.0
    """What sampling temperature to use."""
    top_p: float = 1.0
    """What probability mass to use."""
    model: str = "hunyuan-lite"
    """What Model to use.
    Optional model:
    - hunyuan-lite、
    - hunyuan-standard
    - hunyuan-standard-256K
    - hunyuan-pro
    - hunyuan-code
    - hunyuan-role
    - hunyuan-functioncall
    - hunyuan-vision
    """
    stream_moderation: bool = False
    """Whether to review the results or not when streaming is true."""
    enable_enhancement: bool = True
    """Whether to enhancement the results or not."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for API call not explicitly specified."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Any constructor argument that is not a declared pydantic field is
        # moved into model_kwargs (with a warning) rather than rejected.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        # Declared fields must be passed explicitly, never via model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        # Resolve each credential from constructor values or, failing that,
        # the corresponding environment variable; the secret key is wrapped
        # in a SecretStr so it is not exposed in reprs/logs.
        values["hunyuan_app_id"] = get_from_dict_or_env(
            values,
            "hunyuan_app_id",
            "HUNYUAN_APP_ID",
        )
        values["hunyuan_secret_id"] = get_from_dict_or_env(
            values,
            "hunyuan_secret_id",
            "HUNYUAN_SECRET_ID",
        )
        values["hunyuan_secret_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "hunyuan_secret_key",
                "HUNYUAN_SECRET_KEY",
            )
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Hunyuan API."""
        # Hunyuan request fields are PascalCase, matching the Tencent SDK;
        # explicit model_kwargs override these defaults.
        normal_params = {
            "Temperature": self.temperature,
            "TopP": self.top_p,
            "Model": self.model,
            "Stream": self.streaming,
            "StreamModeration": self.stream_moderation,
            "EnableEnhancement": self.enable_enhancement,
        }
        return {**normal_params, **self.model_kwargs}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # When streaming is enabled, delegate to _stream and collapse the
        # chunk iterator back into a single ChatResult.
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        return _create_chat_result(json.loads(res.to_json_string()))

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        res = self._chat(messages, **kwargs)

        default_chunk_class = AIMessageChunk
        for chunk in res:
            # SSE events carry the JSON payload under "data"; skip empty
            # keep-alive events.
            chunk = chunk.get("data", "")
            if len(chunk) == 0:
                continue
            response = json.loads(chunk)
            if "error" in response:
                raise ValueError(f"Error from Hunyuan api response: {response}")

            for choice in response["Choices"]:
                chunk = _convert_delta_to_message_chunk(
                    choice["Delta"], default_chunk_class
                )
                # Remember the concrete chunk class so subsequent deltas
                # without a role keep the same message type.
                default_chunk_class = chunk.__class__
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
        """Issue a ChatCompletions request through the Tencent Cloud SDK.

        Returns the raw SDK response; with "Stream" enabled this is an
        iterable of SSE events, otherwise a single response object.
        """
        if self.hunyuan_secret_key is None:
            raise ValueError("Hunyuan secret key is not set.")

        # Imported lazily so the SDK is only required when the model is used.
        try:
            from tencentcloud.common import credential
            from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
        except ImportError:
            raise ImportError(
                "Could not import tencentcloud python package. "
                "Please install it with `pip install tencentcloud-sdk-python`."
            )

        # Explicit call kwargs override the model-level defaults.
        parameters = {**self._default_params, **kwargs}
        cred = credential.Credential(
            self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
        )
        client = hunyuan_client.HunyuanClient(cred, "")
        req = models.ChatCompletionsRequest()
        params = {
            "Messages": [_convert_message_to_dict(m) for m in messages],
            **parameters,
        }
        req.from_json_string(json.dumps(params))
        resp = client.ChatCompletions(req)
        return resp

    @property
    def _llm_type(self) -> str:
        return "hunyuan-chat"
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
|
||||||
|
|
||||||
|
|
||||||
|
class TencentModel(MaxKBBaseModel, ChatHunyuan):
    """MaxKB wrapper around the Tencent Hunyuan chat model.

    Maps a MaxKB credential dict onto the ChatHunyuan constructor and
    provides local (tokenizer-based) token counting.
    """

    def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
        app_id = credentials.get('hunyuan_app_id')
        secret_id = credentials.get('hunyuan_secret_id')
        secret_key = credentials.get('hunyuan_secret_key')

        # Refuse to build a half-configured client: every credential part
        # must be present and non-empty.
        if not (app_id and secret_id and secret_key):
            raise ValueError(
                "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")

        super().__init__(model=model_name, hunyuan_app_id=app_id, hunyuan_secret_id=secret_id,
                         hunyuan_secret_key=secret_key, streaming=streaming, **kwargs)

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentModel':
        """Factory used by the provider registry to build a model instance."""
        return TencentModel(model_name=model_name, credentials=model_credential,
                            streaming=model_kwargs.pop('streaming', False), **model_kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Approximate count using the shared local tokenizer; the Hunyuan
        # service itself is not consulted.
        tokenizer = TokenizerManage.get_tokenizer()
        total = 0
        for message in messages:
            total += len(tokenizer.encode(get_buffer_string([message])))
        return total

    def get_num_tokens(self, text: str) -> int:
        # Same local approximation for a plain string.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||||
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
from common.util.file_util import get_file_content
|
||||||
|
from setting.models_provider.base_model_provider import (
|
||||||
|
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||||
|
)
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||||
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Build a ModelInfo entry, instantiating the credential form class."""
    credential = credential_class()
    return ModelInfo(name=model_name, desc=description, model_type=model_type,
                     model_credential=credential, model_class=model_class)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tencent_icon_path():
    """Return the absolute path of the Tencent provider icon file."""
    relative_parts = ("apps", "setting", 'models_provider', 'impl',
                      'tencent_model_provider', 'icon', 'tencent_icon_svg')
    return os.path.join(PROJECT_DIR, *relative_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_model_info():
    """Assemble the ModelInfoManage registry for the Tencent provider.

    Returns a ModelInfoManage containing the supported Hunyuan LLMs, with
    'hunyuan-pro' (the first entry) as the default model.
    """
    model_info_list = [_create_model_info(
        'hunyuan-pro',
        '当前混元模型中效果最优版本,万亿级参数规模 MOE-32K 长文模型。在各种 benchmark 上达到绝对领先的水平,复杂指令和推理,具备复杂数学能力,支持 functioncall,在多语言翻译、金融法律医疗等领域应用重点优化',
        ModelTypeConst.LLM,
        TencentLLMModelCredential,
        TencentModel
    ),
        _create_model_info(
            'hunyuan-standard',
            '采用更优的路由策略,同时缓解了负载均衡和专家趋同的问题。长文方面,大海捞针指标达到99.9%',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-lite',
            '升级为 MOE 结构,上下文窗口为 256k ,在 NLP,代码,数学,行业等多项评测集上领先众多开源模型',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-role',
            '混元最新版角色扮演模型,混元官方精调训练推出的角色扮演模型,基于混元模型结合角色扮演场景数据集进行增训,在角色扮演场景具有更好的基础效果',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        # Bug fix: the model name previously contained a trailing space
        # ('hunyuan-functioncall '), which would not match any model name
        # accepted by the Hunyuan API.
        _create_model_info(
            'hunyuan-functioncall',
            '混元最新 MOE 架构 FunctionCall 模型,经过高质量的 FunctionCall 数据训练,上下文窗口达 32K,在多个维度的评测指标上处于领先。',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-code',
            '混元最新代码生成模型,经过 200B 高质量代码数据增训基座模型,迭代半年高质量 SFT 数据训练,上下文长窗口长度增大到 8K,五大语言代码生成自动评测指标上位居前列;五大语言10项考量各方面综合代码任务人工高质量评测上,性能处于第一梯队',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
    ]

    tencent_embedding_model_info = _create_model_info(
        'hunyuan-embedding',
        '',
        ModelTypeConst.EMBEDDING,
        TencentEmbeddingCredential,
        TencentEmbeddingModel
    )

    # NOTE(review): the embedding entry is built but never appended to the
    # builder below, so 'hunyuan-embedding' is not actually registered —
    # presumably disabled until embedding support is ready; confirm, or add
    # .append_model_info_list(model_info_embedding_list) to the chain.
    model_info_embedding_list = [tencent_embedding_model_info]

    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()

    return model_info_manage
|
||||||
|
|
||||||
|
|
||||||
|
class TencentModelProvider(IModelProvider):
    """Provider registration entry for Tencent Hunyuan models."""

    def __init__(self):
        # Build the model registry once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        return self._model_info_manage

    def get_model_provide_info(self):
        """Return the provider's display metadata (id, name, icon)."""
        return ModelProvideInfo(provider='model_tencent_provider',
                                name='腾讯混元',
                                icon=get_file_content(_get_tencent_icon_path()))
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 16:45
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for the Tencent embedding model.

    Bug fix: this module is imported by tencent_model_provider.py as
    `TencentEmbeddingCredential`, but the class (copy-pasted from the OpenAI
    provider) was named `OpenAIEmbeddingCredential`, so importing the
    provider raised ImportError. The class is renamed and the old name kept
    as a backward-compatible alias.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """Validate the credential by running a minimal embed_query call.

        Returns True on success; returns False (or raises AppApiException
        when raise_exception is set) on failure.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Smoke-test: embed a short string to verify connectivity/auth.
            model.embed_query('你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key when the credential is echoed back to the UI.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    # NOTE(review): these are OpenAI-style fields (api_base/api_key) — confirm
    # they match what the Tencent embedding endpoint actually requires.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)


# Backward-compatible alias for the original (copy-pasted) class name.
OpenAIEmbeddingCredential = TencentEmbeddingCredential
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 17:57
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Volcanic Engine (火山引擎) LLM models.

    Collects the Volcano access-key pair and validates it by issuing a
    minimal chat invocation.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential.

        Checks the model type is supported, both key fields are present, and
        a trivial invocation succeeds. Returns True on success; returns False
        (or raises AppApiException when raise_exception is set) on failure.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['access_key_id', 'secret_access_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Smoke-test the credential with a one-message invocation.
            # (Removed a leftover debug print of the raw response.)
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return the credential dict with sensitive values masked for display.

        NOTE(review): only access_key_id is masked even though
        secret_access_key is also declared as a password field — confirm
        whether the secret should be masked here as well.
        """
        return {**model, 'access_key_id': super().encryption(model.get('access_key_id', ''))}

    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 60 KiB |
@ -0,0 +1,15 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_community.embeddings import VolcanoEmbeddings
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
    """MaxKB embedding model backed by langchain's VolcanoEmbeddings."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an embedding model from the stored credential dict.

        Bug fix: the previous implementation passed OpenAI-style kwargs
        (`api_key`, `openai_api_base`) that VolcanoEmbeddings does not
        declare; VolcanoEmbeddings authenticates with a Volcano AK/SK pair,
        which is what the provider's credential form collects
        (access_key_id / secret_access_key) — confirm field names against
        the installed langchain_community version.
        """
        return VolcanicEngineEmbeddingModel(
            volcano_ak=model_credential.get('access_key_id'),
            volcano_sk=model_credential.get('secret_access_key'),
            model=model_name,
        )
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_community.chat_models import VolcEngineMaasChat
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineChatModel(MaxKBBaseModel, VolcEngineMaasChat):
    """MaxKB chat model for Volcanic Engine (VolcEngine MaaS).

    Bug fix: the class previously subclassed ChatOpenAI while passing
    `volc_engine_maas_ak`/`volc_engine_maas_sk`, which ChatOpenAI does not
    declare; it now extends VolcEngineMaasChat (already imported in this
    module), the class those credential kwargs belong to — confirm against
    the installed langchain_community version.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a chat model from the stored credential dict."""
        volcanic_engine_chat = VolcanicEngineChatModel(
            model=model_name,
            volc_engine_maas_ak=model_credential.get("access_key_id"),
            volc_engine_maas_sk=model_credential.get("secret_access_key"),
        )
        return volcanic_engine_chat

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Approximate count using the shared local tokenizer; the remote
        # service is not consulted.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        # Same local approximation for a plain string.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :MaxKB
|
||||||
|
@File :gemini_model_provider.py
|
||||||
|
@Author :Brian Yang
|
||||||
|
@Date :5/13/24 7:47 AM
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from common.util.file_util import get_file_content
|
||||||
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
|
ModelInfoManage
|
||||||
|
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||||
|
|
||||||
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
# The volcanic-engine provider reuses the OpenAI-compatible credential form
# and model classes: users create an inference endpoint ("ep-xxxxxxxxxx-yyyy")
# in 火山方舟 and supply that endpoint id as the model name.
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()

model_info_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.LLM,
              volcanic_engine_llm_model_credential, OpenAIChatModel
              )
]

open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
              OpenAIEmbeddingModel)]

# NOTE(review): model_info_embedding_list is built but never appended to the
# builder below, so embedding models are not actually registered for this
# provider — confirm whether that is intentional.
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
    model_info_list[0]).build()
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineModelProvider(IModelProvider):
    """Provider registration entry for Volcanic Engine (火山引擎) models."""

    def get_model_info_manage(self):
        # Registry is assembled once at module import time.
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider's display metadata (id, name, icon)."""
        icon_file = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'volcanic_engine_model_provider', 'icon',
                                 'volcanic_engine_icon_svg')
        return ModelProvideInfo(provider='model_volcanic_engine_provider',
                                name='火山引擎',
                                icon=get_file_content(icon_file))
|
||||||
@ -46,6 +46,8 @@ gunicorn = "^22.0.0"
|
|||||||
python-daemon = "3.0.1"
|
python-daemon = "3.0.1"
|
||||||
gevent = "^24.2.1"
|
gevent = "^24.2.1"
|
||||||
|
|
||||||
|
boto3 = "^1.34.151"
|
||||||
|
langchain-aws = "^0.1.13"
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user