feat: Volcanic Engine Image Model
This commit is contained in:
parent
4d977fd765
commit
7de58de42a
@ -12,6 +12,7 @@ from application.flow.i_step_node import NodeResult, INode
|
|||||||
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
||||||
from dataset.models import File
|
from dataset.models import File
|
||||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||||
|
from imghdr import what
|
||||||
|
|
||||||
|
|
||||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||||
@ -59,8 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||||||
|
|
||||||
def file_id_to_base64(file_id: str):
|
def file_id_to_base64(file_id: str):
|
||||||
file = QuerySet(File).filter(id=file_id).first()
|
file = QuerySet(File).filter(id=file_id).first()
|
||||||
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
|
file_bytes = file.get_byte()
|
||||||
return base64_image
|
base64_image = base64.b64encode(file_bytes).decode("utf-8")
|
||||||
|
return [base64_image, what(None, file_bytes.tobytes())]
|
||||||
|
|
||||||
|
|
||||||
class BaseImageUnderstandNode(IImageUnderstandNode):
|
class BaseImageUnderstandNode(IImageUnderstandNode):
|
||||||
@ -77,7 +79,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||||||
# 处理不正确的参数
|
# 处理不正确的参数
|
||||||
if image is None or not isinstance(image, list):
|
if image is None or not isinstance(image, list):
|
||||||
image = []
|
image = []
|
||||||
|
print(model_params_setting)
|
||||||
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
|
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
|
||||||
# 执行详情中的历史消息不需要图片内容
|
# 执行详情中的历史消息不需要图片内容
|
||||||
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
|
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
|
||||||
@ -152,7 +154,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||||||
return HumanMessage(
|
return HumanMessage(
|
||||||
content=[
|
content=[
|
||||||
{'type': 'text', 'text': data['question']},
|
{'type': 'text', 'text': data['question']},
|
||||||
*[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
|
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||||
base64_image in image_base64_list]
|
base64_image in image_base64_list]
|
||||||
])
|
])
|
||||||
return HumanMessage(content=chat_record.problem_text)
|
return HumanMessage(content=chat_record.problem_text)
|
||||||
@ -167,8 +169,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||||||
for img in image:
|
for img in image:
|
||||||
file_id = img['file_id']
|
file_id = img['file_id']
|
||||||
file = QuerySet(File).filter(id=file_id).first()
|
file = QuerySet(File).filter(id=file_id).first()
|
||||||
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
|
image_bytes = file.get_byte()
|
||||||
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}})
|
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
image_format = what(None, image_bytes.tobytes())
|
||||||
|
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||||
messages = [HumanMessage(
|
messages = [HumanMessage(
|
||||||
content=[
|
content=[
|
||||||
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
||||||
|
|||||||
@ -0,0 +1,63 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
class VolcanicEngineImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||||
|
required=True, default_value=0.95,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||||
|
required=True, default_value=1024,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_key', 'api_base']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
|
for chunk in res:
|
||||||
|
print(chunk)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return VolcanicEngineImageModelParams()
|
||||||
@ -0,0 +1,62 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineTTIModelGeneralParams(BaseForm):
|
||||||
|
size = forms.SingleSelect(
|
||||||
|
TooltipLabel('图片尺寸',
|
||||||
|
'宽、高与512差距过大,则出图效果不佳、延迟过长概率显著增加。超分前建议比例及对应宽高:width*height'),
|
||||||
|
required=True,
|
||||||
|
default_value='512*512',
|
||||||
|
option_list=[
|
||||||
|
{'value': '512*512', 'label': '512*512'},
|
||||||
|
{'value': '512*384', 'label': '512*384'},
|
||||||
|
{'value': '384*512', 'label': '384*512'},
|
||||||
|
{'value': '512*341', 'label': '512*341'},
|
||||||
|
{'value': '341*512', 'label': '341*512'},
|
||||||
|
{'value': '512*288', 'label': '512*288'},
|
||||||
|
{'value': '288*512', 'label': '288*512'},
|
||||||
|
],
|
||||||
|
text_field='label',
|
||||||
|
value_field='value')
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
access_key = forms.PasswordInputField('Access Key', required=True)
|
||||||
|
secret_key = forms.PasswordInputField('Secret Key', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['access_key', 'secret_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.check_auth()
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'secret_key': super().encryption(model.get('secret_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return VolcanicEngineTTIModelGeneralParams()
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineImage(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
return VolcanicEngineImage(
|
||||||
|
model_name=model_name,
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
openai_api_base=model_credential.get('api_base'),
|
||||||
|
# stream_options={"include_usage": True},
|
||||||
|
streaming=True,
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
@ -0,0 +1,173 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
'''
|
||||||
|
requires Python 3.6 or later
|
||||||
|
|
||||||
|
pip install asyncio
|
||||||
|
pip install websockets
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
|
||||||
|
method = 'POST'
|
||||||
|
host = 'visual.volcengineapi.com'
|
||||||
|
region = 'cn-north-1'
|
||||||
|
endpoint = 'https://visual.volcengineapi.com'
|
||||||
|
service = 'cv'
|
||||||
|
|
||||||
|
req_key_dict = {
|
||||||
|
'general_v1.4': 'high_aes_general_v14',
|
||||||
|
'general_v2.0': 'high_aes_general_v20',
|
||||||
|
'general_v2.0_L': 'high_aes_general_v20_L',
|
||||||
|
'anime_v1.3': 'high_aes',
|
||||||
|
'anime_v1.3.1': 'high_aes',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sign(key, msg):
|
||||||
|
return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def getSignatureKey(key, dateStamp, regionName, serviceName):
|
||||||
|
kDate = sign(key.encode('utf-8'), dateStamp)
|
||||||
|
kRegion = sign(kDate, regionName)
|
||||||
|
kService = sign(kRegion, serviceName)
|
||||||
|
kSigning = sign(kService, 'request')
|
||||||
|
return kSigning
|
||||||
|
|
||||||
|
|
||||||
|
def formatQuery(parameters):
|
||||||
|
request_parameters_init = ''
|
||||||
|
for key in sorted(parameters):
|
||||||
|
request_parameters_init += key + '=' + parameters[key] + '&'
|
||||||
|
request_parameters = request_parameters_init[:-1]
|
||||||
|
return request_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def signV4Request(access_key, secret_key, service, req_query, req_body):
|
||||||
|
if access_key is None or secret_key is None:
|
||||||
|
print('No access key is available.')
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
t = datetime.datetime.utcnow()
|
||||||
|
current_date = t.strftime('%Y%m%dT%H%M%SZ')
|
||||||
|
# current_date = '20210818T095729Z'
|
||||||
|
datestamp = t.strftime('%Y%m%d') # Date w/o time, used in credential scope
|
||||||
|
canonical_uri = '/'
|
||||||
|
canonical_querystring = req_query
|
||||||
|
signed_headers = 'content-type;host;x-content-sha256;x-date'
|
||||||
|
payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
|
||||||
|
content_type = 'application/json'
|
||||||
|
canonical_headers = 'content-type:' + content_type + '\n' + 'host:' + host + \
|
||||||
|
'\n' + 'x-content-sha256:' + payload_hash + \
|
||||||
|
'\n' + 'x-date:' + current_date + '\n'
|
||||||
|
canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + \
|
||||||
|
'\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash
|
||||||
|
# print(canonical_request)
|
||||||
|
algorithm = 'HMAC-SHA256'
|
||||||
|
credential_scope = datestamp + '/' + region + '/' + service + '/' + 'request'
|
||||||
|
string_to_sign = algorithm + '\n' + current_date + '\n' + credential_scope + '\n' + hashlib.sha256(
|
||||||
|
canonical_request.encode('utf-8')).hexdigest()
|
||||||
|
# print(string_to_sign)
|
||||||
|
signing_key = getSignatureKey(secret_key, datestamp, region, service)
|
||||||
|
# print(signing_key)
|
||||||
|
signature = hmac.new(signing_key, (string_to_sign).encode(
|
||||||
|
'utf-8'), hashlib.sha256).hexdigest()
|
||||||
|
# print(signature)
|
||||||
|
|
||||||
|
authorization_header = algorithm + ' ' + 'Credential=' + access_key + '/' + \
|
||||||
|
credential_scope + ', ' + 'SignedHeaders=' + \
|
||||||
|
signed_headers + ', ' + 'Signature=' + signature
|
||||||
|
# print(authorization_header)
|
||||||
|
headers = {'X-Date': current_date,
|
||||||
|
'Authorization': authorization_header,
|
||||||
|
'X-Content-Sha256': payload_hash,
|
||||||
|
'Content-Type': content_type
|
||||||
|
}
|
||||||
|
# print(headers)
|
||||||
|
|
||||||
|
# ************* SEND THE REQUEST *************
|
||||||
|
request_url = endpoint + '?' + canonical_querystring
|
||||||
|
|
||||||
|
print('\nBEGIN REQUEST++++++++++++++++++++++++++++++++++++')
|
||||||
|
print('Request URL = ' + request_url)
|
||||||
|
try:
|
||||||
|
r = requests.post(request_url, headers=headers, data=req_body)
|
||||||
|
except Exception as err:
|
||||||
|
print(f'error occurred: {err}')
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
print('\nRESPONSE++++++++++++++++++++++++++++++++++++')
|
||||||
|
print(f'Response code: {r.status_code}\n')
|
||||||
|
# 使用 replace 方法将 \u0026 替换为 &
|
||||||
|
resp_str = r.text.replace("\\u0026", "&")
|
||||||
|
if r.status_code != 200:
|
||||||
|
raise Exception(f'Error: {resp_str}')
|
||||||
|
print(f'Response body: {resp_str}\n')
|
||||||
|
return json.loads(resp_str)['data']['image_urls']
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
access_key: str
|
||||||
|
secret_key: str
|
||||||
|
model_version: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.access_key = kwargs.get('access_key')
|
||||||
|
self.secret_key = kwargs.get('secret_key')
|
||||||
|
self.model_version = kwargs.get('model_version')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = {'params': {}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
return VolcanicEngineTextToImage(
|
||||||
|
model_version=model_name,
|
||||||
|
access_key=model_credential.get('access_key'),
|
||||||
|
secret_key=model_credential.get('secret_key'),
|
||||||
|
**optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
res = self.generate_image('生成一张小猫图片')
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
# 请求Query,按照接口文档中填入即可
|
||||||
|
query_params = {
|
||||||
|
'Action': 'CVProcess',
|
||||||
|
'Version': '2022-08-31',
|
||||||
|
}
|
||||||
|
formatted_query = formatQuery(query_params)
|
||||||
|
size = self.params.pop('size', '512*512').split('*')
|
||||||
|
body_params = {
|
||||||
|
"req_key": req_key_dict[self.model_version],
|
||||||
|
"prompt": prompt,
|
||||||
|
"model_version": self.model_version,
|
||||||
|
"return_url": True,
|
||||||
|
"width": int(size[0]),
|
||||||
|
"height": int(size[1]),
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
formatted_body = json.dumps(body_params)
|
||||||
|
return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)
|
||||||
|
|
||||||
|
def is_cache_model(self):
|
||||||
|
return False
|
||||||
@ -14,10 +14,15 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.image import \
|
||||||
|
VolcanicEngineImageModelCredential
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.tti import VolcanicEngineTextToImage
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
|
||||||
|
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
@ -25,6 +30,8 @@ from smartdoc.conf import PROJECT_DIR
|
|||||||
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
||||||
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
||||||
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
|
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
|
||||||
|
volcanic_engine_image_model_credential = VolcanicEngineImageModelCredential()
|
||||||
|
volcanic_engine_tti_model_credential = VolcanicEngineTTIModelCredential()
|
||||||
|
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
||||||
@ -32,6 +39,11 @@ model_info_list = [
|
|||||||
ModelTypeConst.LLM,
|
ModelTypeConst.LLM,
|
||||||
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
|
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
|
||||||
),
|
),
|
||||||
|
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
||||||
|
'用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
|
||||||
|
ModelTypeConst.IMAGE,
|
||||||
|
volcanic_engine_image_model_credential, VolcanicEngineImage
|
||||||
|
),
|
||||||
ModelInfo('asr',
|
ModelInfo('asr',
|
||||||
'',
|
'',
|
||||||
ModelTypeConst.STT,
|
ModelTypeConst.STT,
|
||||||
@ -42,6 +54,31 @@ model_info_list = [
|
|||||||
ModelTypeConst.TTS,
|
ModelTypeConst.TTS,
|
||||||
volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech
|
volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech
|
||||||
),
|
),
|
||||||
|
ModelInfo('general_v2.0',
|
||||||
|
'通用2.0-文生图',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
|
),
|
||||||
|
ModelInfo('general_v2.0_L',
|
||||||
|
'通用2.0Pro-文生图',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
|
),
|
||||||
|
ModelInfo('general_v1.4',
|
||||||
|
'通用1.4-文生图',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
|
),
|
||||||
|
ModelInfo('anime_v1.3',
|
||||||
|
'动漫1.3.0-文生图',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
|
),
|
||||||
|
ModelInfo('anime_v1.3.1',
|
||||||
|
'动漫1.3.1-文生图',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||||
@ -51,8 +88,13 @@ model_info_embedding_list = [
|
|||||||
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
|
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
|
||||||
OpenAIEmbeddingModel)]
|
OpenAIEmbeddingModel)]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = (
|
||||||
model_info_list[0]).build()
|
ModelInfoManage.builder()
|
||||||
|
.append_model_info_list(model_info_list)
|
||||||
|
.append_default_model_info(model_info_list[0])
|
||||||
|
.append_default_model_info(model_info_list[1])
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VolcanicEngineModelProvider(IModelProvider):
|
class VolcanicEngineModelProvider(IModelProvider):
|
||||||
|
|||||||
@ -25,6 +25,15 @@
|
|||||||
<div>
|
<div>
|
||||||
<span>图片理解模型<span class="danger">*</span></span>
|
<span>图片理解模型<span class="danger">*</span></span>
|
||||||
</div>
|
</div>
|
||||||
|
<el-button
|
||||||
|
:disabled="!form_data.model_id"
|
||||||
|
type="primary"
|
||||||
|
link
|
||||||
|
@click="openAIParamSettingDialog(form_data.model_id)"
|
||||||
|
@refreshForm="refreshParam"
|
||||||
|
>
|
||||||
|
{{ $t('views.application.applicationForm.form.paramSetting') }}
|
||||||
|
</el-button>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<el-select
|
<el-select
|
||||||
@ -183,6 +192,7 @@
|
|||||||
</el-form-item>
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</el-card>
|
</el-card>
|
||||||
|
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
|
||||||
</NodeContainer>
|
</NodeContainer>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@ -197,6 +207,7 @@ import { app } from '@/main'
|
|||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||||
import type { FormInstance } from 'element-plus'
|
import type { FormInstance } from 'element-plus'
|
||||||
|
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
|
||||||
|
|
||||||
const { model } = useStore()
|
const { model } = useStore()
|
||||||
|
|
||||||
@ -207,6 +218,7 @@ const {
|
|||||||
const props = defineProps<{ nodeModel: any }>()
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
const modelOptions = ref<any>(null)
|
const modelOptions = ref<any>(null)
|
||||||
const providerOptions = ref<Array<Provider>>([])
|
const providerOptions = ref<Array<Provider>>([])
|
||||||
|
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
|
||||||
|
|
||||||
const aiChatNodeFormRef = ref<FormInstance>()
|
const aiChatNodeFormRef = ref<FormInstance>()
|
||||||
const validate = () => {
|
const validate = () => {
|
||||||
@ -281,6 +293,16 @@ function submitDialog(val: string) {
|
|||||||
set(props.nodeModel.properties.node_data, 'prompt', val)
|
set(props.nodeModel.properties.node_data, 'prompt', val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openAIParamSettingDialog = (modelId: string) => {
|
||||||
|
if (modelId) {
|
||||||
|
AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshParam(data: any) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
getModel()
|
getModel()
|
||||||
getProvider()
|
getProvider()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user