feat: Support image generate model
This commit is contained in:
parent
9e859be5ff
commit
add99fabc6
@ -18,6 +18,7 @@ from .reranker_node import *
|
|||||||
|
|
||||||
from .document_extract_node import *
|
from .document_extract_node import *
|
||||||
from .image_understand_step_node import *
|
from .image_understand_step_node import *
|
||||||
|
from .image_generate_step_node import *
|
||||||
|
|
||||||
from .search_dataset_node import *
|
from .search_dataset_node import *
|
||||||
from .start_node import *
|
from .start_node import *
|
||||||
@ -25,7 +26,7 @@ from .start_node import *
|
|||||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
||||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
||||||
BaseDocumentExtractNode,
|
BaseDocumentExtractNode,
|
||||||
BaseImageUnderstandNode, BaseFormNode]
|
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
|
||||||
|
|
||||||
|
|
||||||
def get_node(node_type):
|
def get_node(node_type):
|
||||||
|
|||||||
@ -0,0 +1,3 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from .impl import *
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.flow.i_step_node import INode, NodeResult
|
||||||
|
from common.util.field_message import ErrMessage
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerateNodeSerializer(serializers.Serializer):
|
||||||
|
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||||
|
|
||||||
|
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))
|
||||||
|
|
||||||
|
negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
|
||||||
|
# 多轮对话数量
|
||||||
|
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||||
|
|
||||||
|
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
|
||||||
|
|
||||||
|
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||||
|
|
||||||
|
|
||||||
|
class IImageGenerateNode(INode):
|
||||||
|
type = 'image-generate-node'
|
||||||
|
|
||||||
|
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||||
|
return ImageGenerateNodeSerializer
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||||
|
|
||||||
|
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||||
|
chat_record_id,
|
||||||
|
**kwargs) -> NodeResult:
|
||||||
|
pass
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from .base_image_generate_node import BaseImageGenerateNode
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from functools import reduce
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
||||||
|
|
||||||
|
from application.flow.i_step_node import NodeResult
|
||||||
|
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
|
||||||
|
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||||
|
|
||||||
|
|
||||||
|
class BaseImageGenerateNode(IImageGenerateNode):
|
||||||
|
def save_context(self, details, workflow_manage):
|
||||||
|
self.context['answer'] = details.get('answer')
|
||||||
|
self.context['question'] = details.get('question')
|
||||||
|
self.answer_text = details.get('answer')
|
||||||
|
|
||||||
|
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||||
|
chat_record_id,
|
||||||
|
**kwargs) -> NodeResult:
|
||||||
|
|
||||||
|
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||||
|
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||||
|
self.context['history_message'] = history_message
|
||||||
|
question = self.generate_prompt_question(prompt)
|
||||||
|
self.context['question'] = question
|
||||||
|
message_list = self.generate_message_list(question, history_message)
|
||||||
|
self.context['message_list'] = message_list
|
||||||
|
self.context['dialogue_type'] = dialogue_type
|
||||||
|
print(message_list)
|
||||||
|
print(negative_prompt)
|
||||||
|
image_urls = tti_model.generate_image(question, negative_prompt)
|
||||||
|
self.context['image_list'] = image_urls
|
||||||
|
answer = '\n'.join([f"" for path in image_urls])
|
||||||
|
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
|
||||||
|
'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
|
||||||
|
'history_message': history_message, 'question': question}, {})
|
||||||
|
|
||||||
|
def generate_history_ai_message(self, chat_record):
|
||||||
|
for val in chat_record.details.values():
|
||||||
|
if self.node.id == val['node_id'] and 'image_list' in val:
|
||||||
|
if val['dialogue_type'] == 'WORKFLOW':
|
||||||
|
return chat_record.get_ai_message()
|
||||||
|
return AIMessage(content=val['answer'])
|
||||||
|
return chat_record.get_ai_message()
|
||||||
|
|
||||||
|
def get_history_message(self, history_chat_record, dialogue_number):
|
||||||
|
start_index = len(history_chat_record) - dialogue_number
|
||||||
|
history_message = reduce(lambda x, y: [*x, *y], [
|
||||||
|
[self.generate_history_human_message(history_chat_record[index]),
|
||||||
|
self.generate_history_ai_message(history_chat_record[index])]
|
||||||
|
for index in
|
||||||
|
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||||
|
return history_message
|
||||||
|
|
||||||
|
def generate_history_human_message(self, chat_record):
|
||||||
|
|
||||||
|
for data in chat_record.details.values():
|
||||||
|
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||||
|
image_list = data['image_list']
|
||||||
|
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||||
|
return HumanMessage(content=chat_record.problem_text)
|
||||||
|
return HumanMessage(content=data['question'])
|
||||||
|
return HumanMessage(content=chat_record.problem_text)
|
||||||
|
|
||||||
|
def generate_prompt_question(self, prompt):
|
||||||
|
return self.workflow_manage.generate_prompt(prompt)
|
||||||
|
|
||||||
|
def generate_message_list(self, question: str, history_message):
|
||||||
|
return [
|
||||||
|
*history_message,
|
||||||
|
question
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||||
|
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||||
|
message
|
||||||
|
in
|
||||||
|
message_list]
|
||||||
|
result.append({'role': 'ai', 'content': answer_text})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_details(self, index: int, **kwargs):
|
||||||
|
return {
|
||||||
|
'name': self.node.properties.get('stepName'),
|
||||||
|
"index": index,
|
||||||
|
'run_time': self.context.get('run_time'),
|
||||||
|
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||||
|
(self.context.get('history_message') if self.context.get(
|
||||||
|
'history_message') is not None else [])],
|
||||||
|
'question': self.context.get('question'),
|
||||||
|
'answer': self.context.get('answer'),
|
||||||
|
'type': self.node.type,
|
||||||
|
'message_tokens': self.context.get('message_tokens'),
|
||||||
|
'answer_tokens': self.context.get('answer_tokens'),
|
||||||
|
'status': self.status,
|
||||||
|
'err_message': self.err_message,
|
||||||
|
'image_list': self.context.get('image_list'),
|
||||||
|
'dialogue_type': self.context.get('dialogue_type')
|
||||||
|
}
|
||||||
@ -54,7 +54,7 @@ class Node:
|
|||||||
|
|
||||||
|
|
||||||
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
|
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
|
||||||
'image-understand-node']
|
'image-understand-node', 'image-generate-node']
|
||||||
|
|
||||||
|
|
||||||
class Flow:
|
class Flow:
|
||||||
|
|||||||
@ -8,9 +8,12 @@
|
|||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
|
import mimetypes
|
||||||
|
import io
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from django.core.files.uploadedfile import InMemoryUploadedFile
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
from ..exception.app_exception import AppApiException
|
from ..exception.app_exception import AppApiException
|
||||||
@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
|
|||||||
batch = data[i:i + batch_size]
|
batch = data[i:i + batch_size]
|
||||||
model.objects.bulk_create(batch)
|
model.objects.bulk_create(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
||||||
|
content_type, _ = mimetypes.guess_type(file_name)
|
||||||
|
if content_type is None:
|
||||||
|
# 如果未能识别,设置为默认的二进制文件类型
|
||||||
|
content_type = "application/octet-stream"
|
||||||
|
# 创建一个内存中的字节流对象
|
||||||
|
file_stream = io.BytesIO(file_bytes)
|
||||||
|
|
||||||
|
# 获取文件大小
|
||||||
|
file_size = len(file_bytes)
|
||||||
|
|
||||||
|
# 创建 InMemoryUploadedFile 对象
|
||||||
|
uploaded_file = InMemoryUploadedFile(
|
||||||
|
file=file_stream,
|
||||||
|
field_name=None,
|
||||||
|
name=file_name,
|
||||||
|
content_type=content_type,
|
||||||
|
size=file_size,
|
||||||
|
charset=None,
|
||||||
|
)
|
||||||
|
return uploaded_file
|
||||||
|
|||||||
@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
|
|||||||
STT = {'code': 'STT', 'message': '语音识别'}
|
STT = {'code': 'STT', 'message': '语音识别'}
|
||||||
TTS = {'code': 'TTS', 'message': '语音合成'}
|
TTS = {'code': 'TTS', 'message': '语音合成'}
|
||||||
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
|
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
|
||||||
|
TTI = {'code': 'TTI', 'message': '图片生成'}
|
||||||
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
14
apps/setting/models_provider/impl/base_tti.py
Normal file
14
apps/setting/models_provider/impl/base_tti.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTextToImage(BaseModel):
|
||||||
|
@abstractmethod
|
||||||
|
def check_auth(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
pass
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_base', 'api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
res = model.check_auth()
|
||||||
|
print(res)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
pass
|
||||||
@ -0,0 +1,67 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from common.util.common import bytes_to_uploaded_file
|
||||||
|
from dataset.serializers.file_serializers import FileSerializer
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
api_base: str
|
||||||
|
api_key: str
|
||||||
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.api_key = kwargs.get('api_key')
|
||||||
|
self.api_base = kwargs.get('api_base')
|
||||||
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = {'params': {}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
return OpenAITextToImage(
|
||||||
|
model=model_name,
|
||||||
|
api_base=model_credential.get('api_base'),
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||||
|
response_list = chat.models.with_raw_response.list()
|
||||||
|
|
||||||
|
# self.generate_image('生成一个小猫图片')
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
|
||||||
|
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||||
|
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
|
||||||
|
|
||||||
|
file_urls = []
|
||||||
|
for content in res.data:
|
||||||
|
url = content.url
|
||||||
|
print(url)
|
||||||
|
file_name = 'generated_image.png'
|
||||||
|
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
|
||||||
|
meta = {'debug': True}
|
||||||
|
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||||
|
file_urls.append(file_url)
|
||||||
|
|
||||||
|
return file_urls
|
||||||
@ -15,11 +15,13 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
|
|||||||
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
||||||
|
from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
|
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
|
||||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||||
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage
|
||||||
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
|
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -27,6 +29,7 @@ openai_llm_model_credential = OpenAILLMModelCredential()
|
|||||||
openai_stt_model_credential = OpenAISTTModelCredential()
|
openai_stt_model_credential = OpenAISTTModelCredential()
|
||||||
openai_tts_model_credential = OpenAITTSModelCredential()
|
openai_tts_model_credential = OpenAITTSModelCredential()
|
||||||
openai_image_model_credential = OpenAIImageModelCredential()
|
openai_image_model_credential = OpenAIImageModelCredential()
|
||||||
|
openai_tti_model_credential = OpenAITextToImageModelCredential()
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
openai_llm_model_credential, OpenAIChatModel
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
@ -37,8 +40,8 @@ model_info_list = [
|
|||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
OpenAIChatModel),
|
OpenAIChatModel),
|
||||||
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
|
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
OpenAIChatModel),
|
OpenAIChatModel),
|
||||||
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
openai_llm_model_credential,
|
openai_llm_model_credential,
|
||||||
OpenAIChatModel),
|
OpenAIChatModel),
|
||||||
@ -100,11 +103,27 @@ model_info_image_list = [
|
|||||||
OpenAIImage),
|
OpenAIImage),
|
||||||
]
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_tti_list = [
|
||||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('dall-e-2', '',
|
||||||
openai_llm_model_credential, OpenAIChatModel
|
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||||
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
|
OpenAITextToImage),
|
||||||
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
|
ModelInfo('dall-e-3', '',
|
||||||
|
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||||
|
OpenAITextToImage),
|
||||||
|
]
|
||||||
|
|
||||||
|
model_info_manage = (
|
||||||
|
ModelInfoManage.builder()
|
||||||
|
.append_model_info_list(model_info_list)
|
||||||
|
.append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
|
))
|
||||||
|
.append_model_info_list(model_info_embedding_list)
|
||||||
|
.append_default_model_info(model_info_embedding_list[0])
|
||||||
|
.append_model_info_list(model_info_image_list)
|
||||||
|
.append_model_info_list(model_info_tti_list)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(IModelProvider):
|
class OpenAIModelProvider(IModelProvider):
|
||||||
|
|||||||
@ -0,0 +1,70 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:41
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||||
|
required=True, default_value=1.0,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.9,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
res = model.check_auth()
|
||||||
|
print(res)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return QwenModelParams()
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from http import HTTPStatus
|
||||||
|
from pathlib import PurePosixPath
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from dashscope import ImageSynthesis
|
||||||
|
from langchain_community.chat_models import ChatTongyi
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common.util.common import bytes_to_uploaded_file
|
||||||
|
from dataset.serializers.file_serializers import FileSerializer
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
|
||||||
|
|
||||||
|
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
api_key: str
|
||||||
|
model_name: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.api_key = kwargs.get('api_key')
|
||||||
|
self.model_name = kwargs.get('model_name')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = {'params': {}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
chat_tong_yi = QwenTextToImageModel(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
|
return chat_tong_yi
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
|
||||||
|
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||||
|
rsp = ImageSynthesis.call(api_key=self.api_key,
|
||||||
|
model=self.model_name,
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
**self.params)
|
||||||
|
file_urls = []
|
||||||
|
if rsp.status_code == HTTPStatus.OK:
|
||||||
|
for result in rsp.output.results:
|
||||||
|
file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
|
||||||
|
file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
|
||||||
|
meta = {'debug': True}
|
||||||
|
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||||
|
file_urls.append(file_url)
|
||||||
|
else:
|
||||||
|
print('sync_call Failed, status_code: %s, code: %s, message: %s' %
|
||||||
|
(rsp.status_code, rsp.code, rsp.message))
|
||||||
|
return file_urls
|
||||||
@ -13,13 +13,16 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
|
|||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
|
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
|
||||||
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
from setting.models_provider.impl.qwen_model_provider.credential.tti import QwenTextToImageModelCredential
|
||||||
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
|
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
|
||||||
|
|
||||||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||||
|
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
qwen_model_credential = OpenAILLMModelCredential()
|
qwen_model_credential = OpenAILLMModelCredential()
|
||||||
qwenvl_model_credential = QwenVLModelCredential()
|
qwenvl_model_credential = QwenVLModelCredential()
|
||||||
|
qwentti_model_credential = QwenTextToImageModelCredential()
|
||||||
|
|
||||||
module_info_list = [
|
module_info_list = [
|
||||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||||
@ -31,13 +34,21 @@ module_info_vl_list = [
|
|||||||
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||||
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||||
]
|
]
|
||||||
|
module_info_tti_list = [
|
||||||
|
ModelInfo('wanx-v1',
|
||||||
|
'通义万相-文本生成图像大模型,支持中英文双语输入,支持输入参考图片进行参考内容或者参考风格迁移,重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。',
|
||||||
|
ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
|
||||||
|
]
|
||||||
|
|
||||||
model_info_manage = (ModelInfoManage.builder()
|
model_info_manage = (
|
||||||
.append_model_info_list(module_info_list)
|
ModelInfoManage.builder()
|
||||||
.append_default_model_info(
|
.append_model_info_list(module_info_list)
|
||||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
.append_default_model_info(
|
||||||
.append_model_info_list(module_info_vl_list)
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
||||||
.build())
|
.append_model_info_list(module_info_vl_list)
|
||||||
|
.append_model_info_list(module_info_tti_list)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QwenModelProvider(IModelProvider):
|
class QwenModelProvider(IModelProvider):
|
||||||
|
|||||||
@ -0,0 +1,108 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentTTIModelParams(BaseForm):
|
||||||
|
Style = forms.SingleSelect(
|
||||||
|
TooltipLabel('绘画风格', '不传默认使用201(日系动漫风格)'),
|
||||||
|
required=True,
|
||||||
|
default_value='201',
|
||||||
|
option_list=[
|
||||||
|
{'value': '000', 'label': '不限定风格'},
|
||||||
|
{'value': '101', 'label': '水墨画'},
|
||||||
|
{'value': '102', 'label': '概念艺术'},
|
||||||
|
{'value': '103', 'label': '油画1'},
|
||||||
|
{'value': '118', 'label': '油画2(梵高)'},
|
||||||
|
{'value': '104', 'label': '水彩画'},
|
||||||
|
{'value': '105', 'label': '像素画'},
|
||||||
|
{'value': '106', 'label': '厚涂风格'},
|
||||||
|
{'value': '107', 'label': '插图'},
|
||||||
|
{'value': '108', 'label': '剪纸风格'},
|
||||||
|
{'value': '109', 'label': '印象派1(莫奈)'},
|
||||||
|
{'value': '119', 'label': '印象派2'},
|
||||||
|
{'value': '110', 'label': '2.5D'},
|
||||||
|
{'value': '111', 'label': '古典肖像画'},
|
||||||
|
{'value': '112', 'label': '黑白素描画'},
|
||||||
|
{'value': '113', 'label': '赛博朋克'},
|
||||||
|
{'value': '114', 'label': '科幻风格'},
|
||||||
|
{'value': '115', 'label': '暗黑风格'},
|
||||||
|
{'value': '116', 'label': '3D'},
|
||||||
|
{'value': '117', 'label': '蒸汽波'},
|
||||||
|
{'value': '201', 'label': '日系动漫'},
|
||||||
|
{'value': '202', 'label': '怪兽风格'},
|
||||||
|
{'value': '203', 'label': '唯美古风'},
|
||||||
|
{'value': '204', 'label': '复古动漫'},
|
||||||
|
{'value': '301', 'label': '游戏卡通手绘'},
|
||||||
|
{'value': '401', 'label': '通用写实风格'},
|
||||||
|
],
|
||||||
|
value_field='value',
|
||||||
|
text_field='label'
|
||||||
|
)
|
||||||
|
|
||||||
|
Resolution = forms.SingleSelect(
|
||||||
|
TooltipLabel('生成图分辨率', '不传默认使用768:768。'),
|
||||||
|
required=True,
|
||||||
|
default_value='768:768',
|
||||||
|
option_list=[
|
||||||
|
{'value': '768:768', 'label': '768:768(1:1)'},
|
||||||
|
{'value': '768:1024', 'label': '768:1024(3:4)'},
|
||||||
|
{'value': '1024:768', 'label': '1024:768(4:3)'},
|
||||||
|
{'value': '1024:1024', 'label': '1024:1024(1:1)'},
|
||||||
|
{'value': '720:1280', 'label': '720:1280(9:16)'},
|
||||||
|
{'value': '1280:720', 'label': '1280:720(16:9)'},
|
||||||
|
{'value': '768:1280', 'label': '768:1280(3:5)'},
|
||||||
|
{'value': '1280:768', 'label': '1280:768(5:3)'},
|
||||||
|
{'value': '1080:1920', 'label': '1080:1920(9:16)'},
|
||||||
|
{'value': '1920:1080', 'label': '1920:1080(16:9)'},
|
||||||
|
],
|
||||||
|
value_field='value',
|
||||||
|
text_field='label'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentTTIModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_model_type(cls, model_type, provider, raise_exception=False):
|
||||||
|
if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_credential_fields(cls, model_credential, raise_exception=False):
|
||||||
|
missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
|
||||||
|
if missing_keys:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
|
||||||
|
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
||||||
|
self._validate_credential_fields(model_credential, raise_exception)):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.check_auth()
|
||||||
|
except Exception as e:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model):
|
||||||
|
return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
|
||||||
|
|
||||||
|
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
|
||||||
|
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return TencentTTIModelParams()
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from tencentcloud.common import credential
|
||||||
|
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||||||
|
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||||
|
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||||
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||||
|
|
||||||
|
from common.util.common import bytes_to_uploaded_file
|
||||||
|
from dataset.serializers.file_serializers import FileSerializer
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
|
||||||
|
|
||||||
|
|
||||||
|
class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
hunyuan_secret_id: str
|
||||||
|
hunyuan_secret_key: str
|
||||||
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_cache_model():
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
|
||||||
|
self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
|
||||||
|
self.model = kwargs.get('model_name')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
|
||||||
|
**model_kwargs) -> 'TencentTextToImageModel':
|
||||||
|
optional_params = {'params': {'Style': '201', 'Resolution': '768:768'}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
return TencentTextToImageModel(
|
||||||
|
model=model_name,
|
||||||
|
hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
|
||||||
|
hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
|
||||||
|
**optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
chat = ChatHunyuan(hunyuan_app_id='111111',
|
||||||
|
hunyuan_secret_id=self.hunyuan_secret_id,
|
||||||
|
hunyuan_secret_key=self.hunyuan_secret_key,
|
||||||
|
model="hunyuan-standard")
|
||||||
|
res = chat.invoke('你好')
|
||||||
|
# print(res)
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
try:
|
||||||
|
# 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
|
||||||
|
# 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
|
||||||
|
# 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
|
||||||
|
cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)
|
||||||
|
# 实例化一个http选项,可选的,没有特殊需求可以跳过
|
||||||
|
httpProfile = HttpProfile()
|
||||||
|
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
|
||||||
|
|
||||||
|
# 实例化一个client选项,可选的,没有特殊需求可以跳过
|
||||||
|
clientProfile = ClientProfile()
|
||||||
|
clientProfile.httpProfile = httpProfile
|
||||||
|
# 实例化要请求产品的client对象,clientProfile是可选的
|
||||||
|
client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", clientProfile)
|
||||||
|
|
||||||
|
# 实例化一个请求对象,每个接口都会对应一个request对象
|
||||||
|
req = models.TextToImageLiteRequest()
|
||||||
|
params = {
|
||||||
|
"Prompt": prompt,
|
||||||
|
"NegativePrompt": negative_prompt,
|
||||||
|
"RspImgType": "url",
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
req.from_json_string(json.dumps(params))
|
||||||
|
|
||||||
|
# 返回的resp是一个TextToImageLiteResponse的实例,与请求对象对应
|
||||||
|
resp = client.TextToImageLite(req)
|
||||||
|
# 输出json格式的字符串回包
|
||||||
|
print(resp.to_json_string())
|
||||||
|
file_urls = []
|
||||||
|
file_name = 'generated_image.png'
|
||||||
|
file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name)
|
||||||
|
meta = {'debug': True}
|
||||||
|
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||||
|
file_urls.append(file_url)
|
||||||
|
return file_urls
|
||||||
|
except TencentCloudSDKException as err:
|
||||||
|
print(err)
|
||||||
|
|
||||||
@ -9,9 +9,11 @@ from setting.models_provider.base_model_provider import (
|
|||||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
|
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.tti import TencentTTIModelCredential
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
|
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -87,11 +89,19 @@ def _initialize_model_info():
|
|||||||
TencentVisionModelCredential,
|
TencentVisionModelCredential,
|
||||||
TencentVision)]
|
TencentVision)]
|
||||||
|
|
||||||
|
model_info_tti_list = [_create_model_info(
|
||||||
|
'hunyuan-dit',
|
||||||
|
'混元生图模型',
|
||||||
|
ModelTypeConst.TTI,
|
||||||
|
TencentTTIModelCredential,
|
||||||
|
TencentTextToImageModel)]
|
||||||
|
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder() \
|
model_info_manage = ModelInfoManage.builder() \
|
||||||
.append_model_info_list(model_info_list) \
|
.append_model_info_list(model_info_list) \
|
||||||
.append_model_info_list(model_info_embedding_list) \
|
.append_model_info_list(model_info_embedding_list) \
|
||||||
.append_model_info_list(model_info_vision_list) \
|
.append_model_info_list(model_info_vision_list) \
|
||||||
|
.append_model_info_list(model_info_tti_list) \
|
||||||
.append_default_model_info(model_info_list[0]) \
|
.append_default_model_info(model_info_list[0]) \
|
||||||
.build()
|
.build()
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,44 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
res = model.check_auth()
|
||||||
|
print(res)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
pass
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_community.chat_models import ChatZhipuAI
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from zhipuai import ZhipuAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from common.util.common import bytes_to_uploaded_file
|
||||||
|
from dataset.serializers.file_serializers import FileSerializer
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
api_key: str
|
||||||
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.api_key = kwargs.get('api_key')
|
||||||
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = {'params': {}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
return ZhiPuTextToImage(
|
||||||
|
model=model_name,
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
chat = ChatZhipuAI(
|
||||||
|
zhipuai_api_key=self.api_key,
|
||||||
|
model_name=self.model,
|
||||||
|
)
|
||||||
|
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
|
||||||
|
|
||||||
|
# self.generate_image('生成一个小猫图片')
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
# chat = ChatZhipuAI(
|
||||||
|
# zhipuai_api_key=self.api_key,
|
||||||
|
# model_name=self.model,
|
||||||
|
# )
|
||||||
|
chat = ZhipuAI(api_key=self.api_key)
|
||||||
|
response = chat.images.generations(
|
||||||
|
model=self.model, # 填写需要调用的模型编码
|
||||||
|
prompt=prompt, # 填写需要生成图片的文本
|
||||||
|
**self.params # 填写额外参数
|
||||||
|
)
|
||||||
|
file_urls = []
|
||||||
|
for content in response.data:
|
||||||
|
url = content['url']
|
||||||
|
print(url)
|
||||||
|
file_name = url.split('/')[-1]
|
||||||
|
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
|
||||||
|
meta = {'debug': True}
|
||||||
|
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||||
|
file_urls.append(file_url)
|
||||||
|
|
||||||
|
return file_urls
|
||||||
@ -13,12 +13,15 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
|
|||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.zhipu_model_provider.credential.image import ZhiPuImageModelCredential
|
from setting.models_provider.impl.zhipu_model_provider.credential.image import ZhiPuImageModelCredential
|
||||||
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
|
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
|
||||||
|
from setting.models_provider.impl.zhipu_model_provider.credential.tti import ZhiPuTextToImageModelCredential
|
||||||
from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuImage
|
from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuImage
|
||||||
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
||||||
|
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
qwen_model_credential = ZhiPuLLMModelCredential()
|
qwen_model_credential = ZhiPuLLMModelCredential()
|
||||||
zhipu_image_model_credential = ZhiPuImageModelCredential()
|
zhipu_image_model_credential = ZhiPuImageModelCredential()
|
||||||
|
zhipu_tti_model_credential = ZhiPuTextToImageModelCredential()
|
||||||
|
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||||
@ -38,11 +41,21 @@ model_info_image_list = [
|
|||||||
ZhiPuImage),
|
ZhiPuImage),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
model_info_tti_list = [
|
||||||
|
ModelInfo('cogview-3', '根据用户文字描述快速、精准生成图像。分辨率支持1024x1024',
|
||||||
|
ModelTypeConst.TTI, zhipu_tti_model_credential,
|
||||||
|
ZhiPuTextToImage),
|
||||||
|
ModelInfo('cogview-3-plus', '根据用户文字描述生成高质量图像,支持多图片尺寸',
|
||||||
|
ModelTypeConst.TTI, zhipu_tti_model_credential,
|
||||||
|
ZhiPuTextToImage),
|
||||||
|
]
|
||||||
|
|
||||||
model_info_manage = (
|
model_info_manage = (
|
||||||
ModelInfoManage.builder()
|
ModelInfoManage.builder()
|
||||||
.append_model_info_list(model_info_list)
|
.append_model_info_list(model_info_list)
|
||||||
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel))
|
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel))
|
||||||
.append_model_info_list(model_info_image_list)
|
.append_model_info_list(model_info_image_list)
|
||||||
|
.append_model_info_list(model_info_tti_list)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -293,6 +293,13 @@ const getApplicationImageModel: (
|
|||||||
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
|
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getApplicationTTIModel: (
|
||||||
|
application_id: string,
|
||||||
|
loading?: Ref<boolean>
|
||||||
|
) => Promise<Result<Array<any>>> = (application_id, loading) => {
|
||||||
|
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 发布应用
|
* 发布应用
|
||||||
@ -523,6 +530,7 @@ export default {
|
|||||||
getApplicationSTTModel,
|
getApplicationSTTModel,
|
||||||
getApplicationTTSModel,
|
getApplicationTTSModel,
|
||||||
getApplicationImageModel,
|
getApplicationImageModel,
|
||||||
|
getApplicationTTIModel,
|
||||||
postSpeechToText,
|
postSpeechToText,
|
||||||
postTextToSpeech,
|
postTextToSpeech,
|
||||||
getPlatformStatus,
|
getPlatformStatus,
|
||||||
|
|||||||
@ -32,6 +32,7 @@
|
|||||||
item.type === WorkflowType.Question ||
|
item.type === WorkflowType.Question ||
|
||||||
item.type === WorkflowType.AiChat ||
|
item.type === WorkflowType.AiChat ||
|
||||||
item.type === WorkflowType.ImageUnderstandNode ||
|
item.type === WorkflowType.ImageUnderstandNode ||
|
||||||
|
item.type === WorkflowType.ImageGenerateNode ||
|
||||||
item.type === WorkflowType.Application
|
item.type === WorkflowType.Application
|
||||||
"
|
"
|
||||||
>{{ item?.message_tokens + item?.answer_tokens }} tokens</span
|
>{{ item?.message_tokens + item?.answer_tokens }} tokens</span
|
||||||
@ -444,6 +445,65 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
<!-- 图片生成 -->
|
||||||
|
<template v-if="item.type == WorkflowType.ImageGenerateNode">
|
||||||
|
<div
|
||||||
|
class="card-never border-r-4 mt-8"
|
||||||
|
v-if="item.type !== WorkflowType.Application"
|
||||||
|
>
|
||||||
|
<h5 class="p-8-12">历史记录</h5>
|
||||||
|
<div class="p-8-12 border-t-dashed lighter">
|
||||||
|
<template v-if="item.history_message?.length > 0">
|
||||||
|
<p
|
||||||
|
class="mt-4 mb-4"
|
||||||
|
v-for="(history, historyIndex) in item.history_message"
|
||||||
|
:key="historyIndex"
|
||||||
|
>
|
||||||
|
<span class="color-secondary mr-4">{{ history.role }}:</span>
|
||||||
|
|
||||||
|
<span v-if="Array.isArray(history.content)">
|
||||||
|
<template v-for="(h, i) in history.content" :key="i">
|
||||||
|
<el-image
|
||||||
|
v-if="h.type === 'image_url'"
|
||||||
|
:src="h.image_url.url"
|
||||||
|
alt=""
|
||||||
|
fit="cover"
|
||||||
|
style="width: 40px; height: 40px; display: inline-block"
|
||||||
|
class="border-r-4 mr-8"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<span v-else>{{ h.text }}<br /></span>
|
||||||
|
</template>
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<span v-else>{{ history.content }}</span>
|
||||||
|
</p>
|
||||||
|
</template>
|
||||||
|
<template v-else> - </template>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="card-never border-r-4 mt-8">
|
||||||
|
<h5 class="p-8-12">本次对话</h5>
|
||||||
|
<div class="p-8-12 border-t-dashed lighter pre-wrap">
|
||||||
|
{{ item.question || '-' }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="card-never border-r-4 mt-8">
|
||||||
|
<h5 class="p-8-12">
|
||||||
|
{{ item.type == WorkflowType.Application ? '参数输出' : 'AI 回答' }}
|
||||||
|
</h5>
|
||||||
|
<div class="p-8-12 border-t-dashed lighter">
|
||||||
|
<MdPreview
|
||||||
|
v-if="item.answer"
|
||||||
|
ref="editorRef"
|
||||||
|
editorId="preview-only"
|
||||||
|
:modelValue="item.answer"
|
||||||
|
style="background: none"
|
||||||
|
/>
|
||||||
|
<template v-else> - </template>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
</template>
|
</template>
|
||||||
<template v-else>
|
<template v-else>
|
||||||
<div class="card-never border-r-4">
|
<div class="card-never border-r-4">
|
||||||
|
|||||||
@ -13,5 +13,6 @@ export enum modelType {
|
|||||||
STT = '语音识别',
|
STT = '语音识别',
|
||||||
TTS = '语音合成',
|
TTS = '语音合成',
|
||||||
IMAGE = '图片理解',
|
IMAGE = '图片理解',
|
||||||
|
TTI = '图片生成',
|
||||||
RERANKER = '重排模型'
|
RERANKER = '重排模型'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,5 +12,6 @@ export enum WorkflowType {
|
|||||||
Application = 'application-node',
|
Application = 'application-node',
|
||||||
DocumentExtractNode = 'document-extract-node',
|
DocumentExtractNode = 'document-extract-node',
|
||||||
ImageUnderstandNode = 'image-understand-node',
|
ImageUnderstandNode = 'image-understand-node',
|
||||||
|
ImageGenerateNode = 'image-generate-node',
|
||||||
FormNode = 'form-node'
|
FormNode = 'form-node'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -63,6 +63,7 @@ const modelTypeOptions = ref([
|
|||||||
{ text: '语音识别', value: 'STT' },
|
{ text: '语音识别', value: 'STT' },
|
||||||
{ text: '语音合成', value: 'TTS' },
|
{ text: '语音合成', value: 'TTS' },
|
||||||
{ text: '图片理解', value: 'IMAGE' },
|
{ text: '图片理解', value: 'IMAGE' },
|
||||||
|
{ text: '图片生成', value: 'TTI' },
|
||||||
])
|
])
|
||||||
|
|
||||||
const open = () => {
|
const open = () => {
|
||||||
|
|||||||
@ -133,6 +133,7 @@
|
|||||||
<el-option label="语音识别" value="STT" />
|
<el-option label="语音识别" value="STT" />
|
||||||
<el-option label="语音合成" value="TTS" />
|
<el-option label="语音合成" value="TTS" />
|
||||||
<el-option label="图片理解" value="IMAGE" />
|
<el-option label="图片理解" value="IMAGE" />
|
||||||
|
<el-option label="图片生成" value="TTI" />
|
||||||
</el-select>
|
</el-select>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -227,6 +227,28 @@ export const imageUnderstandNode = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const imageGenerateNode = {
|
||||||
|
type: WorkflowType.ImageGenerateNode,
|
||||||
|
text: '根据提供的文本内容生成图片',
|
||||||
|
label: '图片生成',
|
||||||
|
height: 252,
|
||||||
|
properties: {
|
||||||
|
stepName: '图片生成',
|
||||||
|
config: {
|
||||||
|
fields: [
|
||||||
|
{
|
||||||
|
label: 'AI 回答内容',
|
||||||
|
value: 'answer'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '图片',
|
||||||
|
value: 'image'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
export const menuNodes = [
|
export const menuNodes = [
|
||||||
aiChatNode,
|
aiChatNode,
|
||||||
searchDatasetNode,
|
searchDatasetNode,
|
||||||
@ -236,6 +258,7 @@ export const menuNodes = [
|
|||||||
rerankerNode,
|
rerankerNode,
|
||||||
documentExtractNode,
|
documentExtractNode,
|
||||||
imageUnderstandNode,
|
imageUnderstandNode,
|
||||||
|
imageGenerateNode,
|
||||||
formNode
|
formNode
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -326,7 +349,8 @@ export const nodeDict: any = {
|
|||||||
[WorkflowType.FormNode]: formNode,
|
[WorkflowType.FormNode]: formNode,
|
||||||
[WorkflowType.Application]: applicationNode,
|
[WorkflowType.Application]: applicationNode,
|
||||||
[WorkflowType.DocumentExtractNode]: documentExtractNode,
|
[WorkflowType.DocumentExtractNode]: documentExtractNode,
|
||||||
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode
|
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
|
||||||
|
[WorkflowType.ImageGenerateNode]: imageGenerateNode
|
||||||
}
|
}
|
||||||
export function isWorkFlow(type: string | undefined) {
|
export function isWorkFlow(type: string | undefined) {
|
||||||
return type === 'WORK_FLOW'
|
return type === 'WORK_FLOW'
|
||||||
|
|||||||
@ -6,6 +6,7 @@ const end_nodes: Array<string> = [
|
|||||||
WorkflowType.FunctionLib,
|
WorkflowType.FunctionLib,
|
||||||
WorkflowType.FunctionLibCustom,
|
WorkflowType.FunctionLibCustom,
|
||||||
WorkflowType.ImageUnderstandNode,
|
WorkflowType.ImageUnderstandNode,
|
||||||
|
WorkflowType.ImageGenerateNode,
|
||||||
WorkflowType.Application
|
WorkflowType.Application
|
||||||
]
|
]
|
||||||
export class WorkFlowInstance {
|
export class WorkFlowInstance {
|
||||||
|
|||||||
6
ui/src/workflow/icons/image-generate-node-icon.vue
Normal file
6
ui/src/workflow/icons/image-generate-node-icon.vue
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<template>
|
||||||
|
<AppAvatar shape="square" style="background: #14C0FF;">
|
||||||
|
<img src="@/assets/icon_image.svg" style="width: 65%" alt="" />
|
||||||
|
</AppAvatar>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts"></script>
|
||||||
14
ui/src/workflow/nodes/image-generate/index.ts
Normal file
14
ui/src/workflow/nodes/image-generate/index.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import ImageGenerateNodeVue from './index.vue'
|
||||||
|
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||||
|
|
||||||
|
class RerankerNode extends AppNode {
|
||||||
|
constructor(props: any) {
|
||||||
|
super(props, ImageGenerateNodeVue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default {
|
||||||
|
type: 'image-generate-node',
|
||||||
|
model: AppNodeModel,
|
||||||
|
view: RerankerNode
|
||||||
|
}
|
||||||
323
ui/src/workflow/nodes/image-generate/index.vue
Normal file
323
ui/src/workflow/nodes/image-generate/index.vue
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
<template>
|
||||||
|
<NodeContainer :node-model="nodeModel">
|
||||||
|
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||||
|
<el-card shadow="never" class="card-never">
|
||||||
|
<el-form
|
||||||
|
@submit.prevent
|
||||||
|
:model="form_data"
|
||||||
|
label-position="top"
|
||||||
|
require-asterisk-position="right"
|
||||||
|
label-width="auto"
|
||||||
|
ref="aiChatNodeFormRef"
|
||||||
|
hide-required-asterisk
|
||||||
|
>
|
||||||
|
<el-form-item
|
||||||
|
label="图片生成模型"
|
||||||
|
prop="model_id"
|
||||||
|
:rules="{
|
||||||
|
required: true,
|
||||||
|
message: '请选择图片生成模型',
|
||||||
|
trigger: 'change'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex-between w-full">
|
||||||
|
<div>
|
||||||
|
<span>图片生成模型<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-button
|
||||||
|
:disabled="!form_data.model_id"
|
||||||
|
type="primary"
|
||||||
|
link
|
||||||
|
@click="openAIParamSettingDialog(form_data.model_id)"
|
||||||
|
@refreshForm="refreshParam"
|
||||||
|
>
|
||||||
|
{{ $t('views.application.applicationForm.form.paramSetting') }}
|
||||||
|
</el-button>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-select
|
||||||
|
@change="model_change"
|
||||||
|
@wheel="wheel"
|
||||||
|
:teleported="false"
|
||||||
|
v-model="form_data.model_id"
|
||||||
|
placeholder="请选择图片生成模型"
|
||||||
|
class="w-full"
|
||||||
|
popper-class="select-model"
|
||||||
|
:clearable="true"
|
||||||
|
>
|
||||||
|
<el-option-group
|
||||||
|
v-for="(value, label) in modelOptions"
|
||||||
|
:key="value"
|
||||||
|
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
|
||||||
|
>公用
|
||||||
|
</el-tag>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||||
|
<Check />
|
||||||
|
</el-icon>
|
||||||
|
</el-option>
|
||||||
|
<!-- 不可用 -->
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
disabled
|
||||||
|
>
|
||||||
|
<div class="flex">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
<span class="danger">(不可用)</span>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||||
|
<Check />
|
||||||
|
</el-icon>
|
||||||
|
</el-option>
|
||||||
|
</el-option-group>
|
||||||
|
</el-select>
|
||||||
|
</el-form-item>
|
||||||
|
|
||||||
|
|
||||||
|
<el-form-item
|
||||||
|
label="提示词(正向)"
|
||||||
|
prop="prompt"
|
||||||
|
:rules="{
|
||||||
|
required: true,
|
||||||
|
message: '请输入提示词',
|
||||||
|
trigger: 'blur'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>提示词(正向)<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content
|
||||||
|
>正向提示词,用来描述生成图像中期望包含的元素和视觉特点
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<MdEditorMagnify
|
||||||
|
@wheel="wheel"
|
||||||
|
title="提示词(正向)"
|
||||||
|
v-model="form_data.prompt"
|
||||||
|
style="height: 150px"
|
||||||
|
@submitDialog="submitDialog"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item
|
||||||
|
label="提示词(负向)"
|
||||||
|
prop="prompt"
|
||||||
|
:rules="{
|
||||||
|
required: false,
|
||||||
|
message: '请输入提示词',
|
||||||
|
trigger: 'blur'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>提示词(负向)</span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content
|
||||||
|
>反向提示词,用来描述不希望在画面中看到的内容,可以对画面进行限制。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<MdEditorMagnify
|
||||||
|
@wheel="wheel"
|
||||||
|
title="提示词(负向)"
|
||||||
|
v-model="form_data.negative_prompt"
|
||||||
|
style="height: 150px"
|
||||||
|
@submitDialog="submitDialog"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex-between">
|
||||||
|
<div>历史聊天记录</div>
|
||||||
|
<el-select v-model="form_data.dialogue_type" type="small" style="width: 100px">
|
||||||
|
<el-option label="节点" value="NODE" />
|
||||||
|
<el-option label="工作流" value="WORKFLOW" />
|
||||||
|
</el-select>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-input-number
|
||||||
|
v-model="form_data.dialogue_number"
|
||||||
|
:min="0"
|
||||||
|
:value-on-clear="0"
|
||||||
|
controls-position="right"
|
||||||
|
class="w-full"
|
||||||
|
:step="1"
|
||||||
|
:step-strictly="true"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="返回内容" @click.prevent>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>返回内容<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content>
|
||||||
|
关闭后该节点的内容则不输出给用户。
|
||||||
|
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-switch size="small" v-model="form_data.is_result" />
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
</el-card>
|
||||||
|
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
|
||||||
|
</NodeContainer>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||||
|
import { computed, onMounted, ref } from 'vue'
|
||||||
|
import { groupBy, set } from 'lodash'
|
||||||
|
import { relatedObject } from '@/utils/utils'
|
||||||
|
import type { Provider } from '@/api/type/model'
|
||||||
|
import applicationApi from '@/api/application'
|
||||||
|
import { app } from '@/main'
|
||||||
|
import useStore from '@/stores'
|
||||||
|
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||||
|
import type { FormInstance } from 'element-plus'
|
||||||
|
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
|
||||||
|
|
||||||
|
const { model } = useStore()
|
||||||
|
|
||||||
|
const {
|
||||||
|
params: { id }
|
||||||
|
} = app.config.globalProperties.$route as any
|
||||||
|
|
||||||
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
|
const modelOptions = ref<any>(null)
|
||||||
|
const providerOptions = ref<Array<Provider>>([])
|
||||||
|
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
|
||||||
|
|
||||||
|
const aiChatNodeFormRef = ref<FormInstance>()
|
||||||
|
const validate = () => {
|
||||||
|
return aiChatNodeFormRef.value?.validate().catch((err) => {
|
||||||
|
return Promise.reject({ node: props.nodeModel, errMessage: err })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const wheel = (e: any) => {
|
||||||
|
if (e.ctrlKey === true) {
|
||||||
|
e.preventDefault()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
e.stopPropagation()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultPrompt = `{{开始.question}}`
|
||||||
|
|
||||||
|
const form = {
|
||||||
|
model_id: '',
|
||||||
|
system: '',
|
||||||
|
prompt: defaultPrompt,
|
||||||
|
negative_prompt: '',
|
||||||
|
dialogue_number: 0,
|
||||||
|
dialogue_type: 'NODE',
|
||||||
|
is_result: true,
|
||||||
|
temperature: null,
|
||||||
|
max_tokens: null,
|
||||||
|
image_list: ['start-node', 'image']
|
||||||
|
}
|
||||||
|
|
||||||
|
const form_data = computed({
|
||||||
|
get: () => {
|
||||||
|
if (props.nodeModel.properties.node_data) {
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
} else {
|
||||||
|
set(props.nodeModel.properties, 'node_data', form)
|
||||||
|
}
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
},
|
||||||
|
set: (value) => {
|
||||||
|
set(props.nodeModel.properties, 'node_data', value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
function getModel() {
|
||||||
|
if (id) {
|
||||||
|
applicationApi.getApplicationTTIModel(id).then((res: any) => {
|
||||||
|
modelOptions.value = groupBy(res?.data, 'provider')
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
model.asyncGetModel().then((res: any) => {
|
||||||
|
modelOptions.value = groupBy(res?.data, 'provider')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getProvider() {
|
||||||
|
model.asyncGetProvider().then((res: any) => {
|
||||||
|
providerOptions.value = res?.data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const model_change = () => {
|
||||||
|
if (form_data.value.model_id) {
|
||||||
|
AIModeParamSettingDialogRef.value?.reset_default(form_data.value.model_id, id)
|
||||||
|
} else {
|
||||||
|
refreshParam({})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const openAIParamSettingDialog = (modelId: string) => {
|
||||||
|
if (modelId) {
|
||||||
|
AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshParam(data: any) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
|
||||||
|
}
|
||||||
|
|
||||||
|
function submitDialog(val: string) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'prompt', val)
|
||||||
|
}
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
getModel()
|
||||||
|
getProvider()
|
||||||
|
|
||||||
|
set(props.nodeModel, 'validate', validate)
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped lang="scss"></style>
|
||||||
Loading…
Reference in New Issue
Block a user