feat: 讯飞图片模型
This commit is contained in:
parent
f318f2da40
commit
ddad340534
@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
|
|||||||
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
||||||
STT = {'code': 'STT', 'message': '语音识别'}
|
STT = {'code': 'STT', 'message': '语音识别'}
|
||||||
TTS = {'code': 'TTS', 'message': '语音合成'}
|
TTS = {'code': 'TTS', 'message': '语音合成'}
|
||||||
|
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
|
||||||
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
14
apps/setting/models_provider/impl/base_image.py
Normal file
14
apps/setting/models_provider/impl/base_image.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseImage(BaseModel):
    """Abstract interface that image-understanding model implementations provide."""

    @abstractmethod
    def check_auth(self):
        """Check that the configured credentials work.

        Credential validation calls this on a freshly built model and treats
        any raised exception as a failed check.
        """
        return None

    @abstractmethod
    def image_understand(self, image_file, text):
        """Answer ``text`` about the image supplied in ``image_file``.

        :param image_file: the image to analyse
        :param text: the prompt or question about the image
        """
        return None
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the XunFei Spark image model."""

    spark_api_url = forms.TextInputField('API 域名', required=True,
                                         default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling its ``check_auth``.

        :param model_type: model type code that the provider must list as supported
        :param model_name: name of the model to instantiate
        :param model_credential: mapping holding the four ``spark_*`` fields
        :param provider: provider exposing ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when every check passes, otherwise ``False``
        :raises AppApiException: always for an unsupported model type or when the
            model check itself raises one; for other failures only when
            ``raise_exception`` is true
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        for key in required_keys:
            if key in model_credential:
                continue
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
            return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except AppApiException:
            # Application-level errors propagate unchanged, whatever the flag says.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        encrypted = dict(model)
        encrypted['spark_api_secret'] = super().encryption(model.get('spark_api_secret', ''))
        return encrypted

    def get_model_params_setting_form(self, model_name):
        """Return no parameter-setting form for this model."""
        return None
|
||||||
@ -0,0 +1,170 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_image import BaseImage
|
||||||
|
|
||||||
|
# TLS context for the Spark WebSocket connection. Certificate and hostname
# verification stay enabled (the create_default_context defaults): connections
# made with this context carry the signed API credentials, so an unverified
# peer could intercept them.
ssl_context = ssl.create_default_context()
|
||||||
|
|
||||||
|
|
||||||
|
class XFSparkImage(MaxKBBaseModel, BaseImage):
    """XunFei Spark image-understanding client.

    Each request opens a WebSocket to ``spark_api_url`` using a signed URL,
    sends one frame containing the base64 image plus the question, and
    concatenates the streamed answer fragments.
    """
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    params: dict

    def __init__(self, **kwargs):
        """Initialise the model from keyword arguments matching the declared fields."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping.

        Every keyword argument except ``model_id``, ``use_local`` and
        ``streaming`` is collected into the ``params`` dict.
        """
        excluded = ('model_id', 'use_local', 'streaming')
        params = {key: value for key, value in model_kwargs.items() if key not in excluded}
        return XFSparkImage(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query appended.

        The signature is an HMAC-SHA256 (keyed with ``spark_api_secret``) over
        the host, an RFC 1123 date and the HTTP request line. The request line
        uses the path parsed from ``spark_api_url``, so a custom endpoint path
        is signed correctly rather than always ``/v2.1/image``.

        :raises ValueError: if ``spark_api_url`` has no host component
        """
        url = self.spark_api_url
        parsed = urlparse(url)
        host = parsed.hostname
        if not host:
            raise ValueError(f'spark_api_url 无效: {url}')
        path = parsed.path or '/'

        # RFC 1123 formatted timestamp.
        date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')

        # String to sign: host, date and request line.
        signature_origin = f'host: {host}\ndate: {date}\nGET {path} HTTP/1.1'
        digest = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                          digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters are passed in the query string.
        query = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Send the bundled sample image to the service to validate the credentials.

        Unlike ``image_understand``, a non-zero response code raises, so an
        API-level rejection is not mistaken for a successful check.

        :raises RuntimeError: if the service answers with a non-zero ``header.code``
        """
        sample = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'img_1.png')
        with open(sample, 'rb') as f:
            asyncio.run(self._request(f, "一句话概述这个图片", raise_on_error=True))

    def image_understand(self, image_file, question):
        """Ask ``question`` about ``image_file`` and return the answer text.

        :param image_file: binary file-like object; it is read in full
        :param question: prompt sent alongside the image
        :return: the concatenated answer, or a ``请求错误: ...`` string when the
            service responds with a non-zero code
        """
        return asyncio.run(self._request(image_file, question))

    async def _request(self, image_file, question, raise_on_error=False):
        """Open the WebSocket, send one request and collect the reply."""
        async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
            await self.send(ws, image_file, question)
            return await self.handle_message(ws, raise_on_error=raise_on_error)

    @staticmethod
    async def handle_message(ws, raise_on_error=False):
        """Read frames from ``ws`` until the final one and join their text.

        :param ws: open WebSocket connection
        :param raise_on_error: raise instead of returning the error text when a
            frame carries a non-zero ``header.code``
        :return: the concatenated answer, or the error text
        :raises RuntimeError: on a non-zero code when ``raise_on_error`` is true
        """
        fragments = []
        while True:
            data = json.loads(await ws.recv())
            code = data['header']['code']
            if code != 0:
                message = f'请求错误: {code}, {data}'
                if raise_on_error:
                    raise RuntimeError(message)
                return message
            choices = data["payload"]["choices"]
            status = choices["status"]
            fragments.append(choices["text"][0]["content"])
            # This loop treats status 2 as the final frame.
            if status == 2:
                break
        return ''.join(fragments)

    async def send(self, ws, image_file, question):
        """Send the request frame: base64-encoded image followed by the question."""
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
        text = [
            {"role": "user", "content": encoded_image, "content_type": "image"},
            {"role": "user", "content": question}
        ]

        data = {
            "header": {
                "app_id": self.spark_app_id
            },
            "parameter": {
                "chat": {
                    "domain": "image",
                    "temperature": 0.5,
                    "top_k": 4,
                    "max_tokens": 2028,
                    "auditing": "default"
                }
            },
            "payload": {
                "message": {
                    "text": text
                }
            }
        }

        await ws.send(json.dumps(data))

    def is_cache_model(self):
        """Return ``False``."""
        return False

    @staticmethod
    def get_len(text):
        """Return the summed length of the ``content`` values in ``text``."""
        return sum(len(item["content"]) for item in text)

    def check_len(self, text):
        """Trim ``text`` in place until the entries after the first fit the limit.

        Entries at index 1 are deleted while the combined ``content`` length of
        ``text[1:]`` exceeds 8000; the first entry is never removed.

        :return: the same (mutated) list
        """
        while self.get_len(text[1:]) > 8000:
            del text[1]
        return text
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 354 KiB |
@ -10,7 +10,7 @@ import hmac
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime, UTC
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from urllib.parse import urlencode, urlparse
|
from urllib.parse import urlencode, urlparse
|
||||||
import ssl
|
import ssl
|
||||||
@ -63,7 +63,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
|||||||
host = urlparse(url).hostname
|
host = urlparse(url).hostname
|
||||||
# 生成RFC1123格式的时间戳
|
# 生成RFC1123格式的时间戳
|
||||||
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
||||||
date = datetime.utcnow().strftime(gmt_format)
|
date = datetime.now(UTC).strftime(gmt_format)
|
||||||
|
|
||||||
# 拼接字符串
|
# 拼接字符串
|
||||||
signature_origin = "host: " + host + "\n"
|
signature_origin = "host: " + host + "\n"
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import hmac
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime, UTC
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from urllib.parse import urlencode, urlparse
|
from urllib.parse import urlencode, urlparse
|
||||||
import ssl
|
import ssl
|
||||||
@ -67,7 +67,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
|
|||||||
host = urlparse(url).hostname
|
host = urlparse(url).hostname
|
||||||
# 生成RFC1123格式的时间戳
|
# 生成RFC1123格式的时间戳
|
||||||
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
||||||
date = datetime.utcnow().strftime(gmt_format)
|
date = datetime.now(UTC).strftime(gmt_format)
|
||||||
|
|
||||||
# 拼接字符串
|
# 拼接字符串
|
||||||
signature_origin = "host: " + host + "\n"
|
signature_origin = "host: " + host + "\n"
|
||||||
|
|||||||
@ -13,10 +13,12 @@ from common.util.file_util import get_file_content
|
|||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
|
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.xf_model_provider.credential.image import XunFeiImageModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.image import XFSparkImage
|
||||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||||
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||||
@ -26,6 +28,7 @@ ssl._create_default_https_context = ssl.create_default_context()
|
|||||||
|
|
||||||
qwen_model_credential = XunFeiLLMModelCredential()
|
qwen_model_credential = XunFeiLLMModelCredential()
|
||||||
stt_model_credential = XunFeiSTTModelCredential()
|
stt_model_credential = XunFeiSTTModelCredential()
|
||||||
|
image_model_credential = XunFeiImageModelCredential()
|
||||||
tts_model_credential = XunFeiTTSModelCredential()
|
tts_model_credential = XunFeiTTSModelCredential()
|
||||||
embedding_model_credential = XFEmbeddingCredential()
|
embedding_model_credential = XFEmbeddingCredential()
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
@ -34,6 +37,7 @@ model_info_list = [
|
|||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||||
|
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
|
||||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -132,6 +132,7 @@
|
|||||||
<el-option label="重排模型" value="RERANKER" />
|
<el-option label="重排模型" value="RERANKER" />
|
||||||
<el-option label="语音识别" value="STT" />
|
<el-option label="语音识别" value="STT" />
|
||||||
<el-option label="语音合成" value="TTS" />
|
<el-option label="语音合成" value="TTS" />
|
||||||
|
<el-option label="图片理解" value="IMAGE" />
|
||||||
</el-select>
|
</el-select>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user