feat: 支持讯飞向量模型
This commit is contained in:
parent
97cfd60346
commit
f85ce4a745
@ -0,0 +1,43 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/10/17 15:40
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XFEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the iFlytek (XunFei) embedding model.

    The form fields declared at the bottom of the class describe what the user
    must supply; ``is_valid`` checks those values by issuing one real
    embedding request through the provider.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by embedding a short probe sentence.

        :param model_type: model type value to look up in the provider's
            supported model type list.
        :param model_name: model name handed to ``provider.get_model``.
        :param model_credential: credential values submitted by the user.
        :param provider: object exposing ``get_model_type_list`` and
            ``get_model``.
        :param raise_exception: when true, a failed probe is re-raised as an
            ``AppApiException`` instead of being reported as ``False``.
        :return: ``True`` when the probe succeeds; ``False`` when it fails and
            ``raise_exception`` is false.
        :raises AppApiException: when the model type is not supported (always,
            regardless of ``raise_exception``), when the probe itself raises an
            ``AppApiException``, or when the probe fails and
            ``raise_exception`` is true.
        """
        model_type_list = provider.get_model_type_list()
        # Test for an actual match rather than the truthiness of matched items.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        self.valid_form(model_credential)
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # A real request is the only way to prove the credential works.
            model.embed_query('你好')
        except AppApiException:
            # Application-level errors already carry the right code and message.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return the credential dict as-is.

        NOTE(review): the API key and secret are declared as password fields
        but are returned here without any masking or encryption; confirm that
        this is intended.
        """
        return model

    # Service endpoint, pre-filled with the default iFlytek embedding host.
    base_url = forms.TextInputField('API 域名', required=True, default_value="https://emb-cn-huabei-1.xf-yun.com/")
    # Application id of the iFlytek account.
    spark_app_id = forms.TextInputField('APP ID', required=True)
    # API key, entered through a password input.
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    # API secret, entered through a password input.
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/10/17 15:29
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
    """iFlytek (XunFei) Spark text-embedding model wired into MaxKB."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``XFEmbedding`` from the stored credential values.

        :param model_type: unused here; part of the factory signature.
        :param model_name: unused here; part of the factory signature.
        :param model_credential: mapping that provides ``spark_app_id``,
            ``spark_api_key``, ``spark_api_secret`` and, optionally,
            ``base_url``.
        :param model_kwargs: accepted for signature compatibility; not used.
        :return: a configured ``XFEmbedding`` instance.
        """
        init_kwargs = {
            'spark_app_id': model_credential.get('spark_app_id'),
            'spark_api_key': model_credential.get('spark_api_key'),
            'spark_api_secret': model_credential.get('spark_api_secret'),
        }
        # Honour the endpoint configured in the credential form. When it is
        # missing or empty, leave it out so the client's own default applies.
        base_url = model_credential.get('base_url')
        if base_url:
            init_kwargs['base_url'] = base_url
        return XFEmbedding(**init_kwargs)

    @staticmethod
    def _parser_message(
            message: str,
    ) -> Optional[ndarray]:
        """Decode one JSON response message into an embedding vector.

        :param message: JSON text containing ``header.code`` and, on success,
            a base64 string at ``payload.feature.text``.
        :return: the decoded little-endian float32 values, cut to at most
            2560 elements.
        :raises Exception: when ``header.code`` is not ``0``.
        """
        max_dimensions = 2560
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            # iFlytek's QPS limit surfaces here as an error, which is why the
            # iFlytek embedding model is not recommended.
            raise Exception(f"Request error: {code}, {data}")
        raw = base64.b64decode(data["payload"]["feature"]["text"])
        # The payload is read as little-endian 32-bit floats.
        vector = np.frombuffer(raw, dtype=np.dtype("<f4"))
        # Slicing is a no-op for shorter vectors, so no length check is needed.
        return vector[:max_dimensions]
|
||||||
@ -12,9 +12,11 @@ import ssl
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
|
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
||||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||||
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||||
@ -25,12 +27,14 @@ ssl._create_default_https_context = ssl.create_default_context()
|
|||||||
qwen_model_credential = XunFeiLLMModelCredential()
|
qwen_model_credential = XunFeiLLMModelCredential()
|
||||||
stt_model_credential = XunFeiSTTModelCredential()
|
stt_model_credential = XunFeiSTTModelCredential()
|
||||||
tts_model_credential = XunFeiTTSModelCredential()
|
tts_model_credential = XunFeiTTSModelCredential()
|
||||||
|
embedding_model_credential = XFEmbeddingCredential()
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||||
|
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||||
]
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user