feat: aws向量模型
This commit is contained in:
parent
cb3c064463
commit
6618c6baf3
@ -6,12 +6,10 @@ from common.util.file_util import get_file_content
|
|||||||
from setting.models_provider.base_model_provider import (
|
from setting.models_provider.base_model_provider import (
|
||||||
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||||
)
|
)
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
|
||||||
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
|
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
|
||||||
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
|
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -118,10 +116,21 @@ def _initialize_model_info():
|
|||||||
BedrockLLMModelCredential,
|
BedrockLLMModelCredential,
|
||||||
BedrockModel),
|
BedrockModel),
|
||||||
]
|
]
|
||||||
|
embedded_model_info_list = [
|
||||||
|
_create_model_info(
|
||||||
|
'amazon.titan-embed-text-v1',
|
||||||
|
'Titan Embed Text 是 Amazon Titan Embed 系列中最大的嵌入模型,可以处理各种文本嵌入任务,如文本分类、文本相似度计算等。',
|
||||||
|
ModelTypeConst.EMBEDDING,
|
||||||
|
BedrockEmbeddingCredential,
|
||||||
|
BedrockEmbeddingModel
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder() \
|
model_info_manage = ModelInfoManage.builder() \
|
||||||
.append_model_info_list(model_info_list) \
|
.append_model_info_list(model_info_list) \
|
||||||
.append_default_model_info(model_info_list[0]) \
|
.append_default_model_info(model_info_list[0]) \
|
||||||
|
.append_model_info_list(embedded_model_info_list) \
|
||||||
|
.append_default_model_info(embedded_model_info_list[0]) \
|
||||||
.build()
|
.build()
|
||||||
|
|
||||||
return model_info_manage
|
return model_info_manage
|
||||||
|
|||||||
@ -1,64 +1,64 @@
|
|||||||
import json
|
import os
|
||||||
|
import re
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from tencentcloud.common import credential
|
|
||||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
|
||||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
|
||||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm
|
||||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
|
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
@classmethod
|
|
||||||
def _validate_model_type(cls, model_type: str, provider) -> bool:
|
|
||||||
model_type_list = provider.get_model_type_list()
|
|
||||||
if not any(mt.get('value') == model_type for mt in model_type_list):
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
|
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||||
for key in ['SecretId', 'SecretKey']:
|
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||||
if key not in model_credential:
|
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
|
||||||
return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])
|
|
||||||
|
|
||||||
@classmethod
|
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
|
||||||
def _test_credentials(cls, client, model_name: str):
|
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
|
||||||
req = models.GetEmbeddingRequest()
|
content = re.sub(pattern, '', content, flags=re.DOTALL)
|
||||||
params = {
|
|
||||||
"Model": model_name,
|
if not re.search(rf'\[{profile_name}\]', content):
|
||||||
"Input": "测试"
|
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
|
||||||
}
|
|
||||||
req.from_json_string(json.dumps(params))
|
with open(credentials_path, 'w') as file:
|
||||||
try:
|
file.write(content)
|
||||||
res = client.GetEmbedding(req)
|
|
||||||
print(res.to_json_string())
|
|
||||||
except Exception as e:
|
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
raise_exception=True) -> bool:
|
raise_exception=False):
|
||||||
try:
|
model_type_list = provider.get_model_type_list()
|
||||||
self._validate_model_type(model_type, provider)
|
if not any(mt.get('value') == model_type for mt in model_type_list):
|
||||||
cred = self._validate_credential(model_credential)
|
|
||||||
httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
|
|
||||||
clientProfile = ClientProfile(httpProfile=httpProfile)
|
|
||||||
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
|
|
||||||
self._test_credentials(client, model_name)
|
|
||||||
return True
|
|
||||||
except AppApiException as e:
|
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
raise e
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
|
required_keys = ['region_name', 'access_key_id', 'secret_access_key']
|
||||||
encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
|
if not all(key in model_credential for key in required_keys):
|
||||||
return {**model, 'SecretKey': encrypted_secret_key}
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
|
||||||
|
return False
|
||||||
|
|
||||||
SecretId = forms.PasswordInputField('SecretId', required=True)
|
try:
|
||||||
SecretKey = forms.PasswordInputField('SecretKey', required=True)
|
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
|
||||||
|
model_credential['secret_access_key'])
|
||||||
|
model_credential['credentials_profile_name'] = 'aws-profile'
|
||||||
|
model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
aa = model.embed_query('你好')
|
||||||
|
print(aa)
|
||||||
|
except AppApiException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
|
||||||
|
|
||||||
|
region_name = forms.TextInputField('Region Name', required=True)
|
||||||
|
access_key_id = forms.TextInputField('Access Key ID', required=True)
|
||||||
|
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||||
|
|||||||
@ -1,25 +1,56 @@
|
|||||||
|
from langchain_community.embeddings import BedrockEmbeddings
|
||||||
|
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
class TencentEmbeddingModel(MaxKBBaseModel):
|
class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
|
||||||
def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
|
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
|
||||||
self.secret_id = secret_id
|
**kwargs):
|
||||||
self.secret_key = secret_key
|
super().__init__(model_id=model_id, region_name=region_name,
|
||||||
self.api_base = api_base
|
credentials_profile_name=credentials_profile_name, **kwargs)
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
|
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
|
||||||
return TencentEmbeddingModel(
|
**model_kwargs) -> 'BedrockModel':
|
||||||
secret_id=model_credential.get('SecretId'),
|
return cls(
|
||||||
secret_key=model_credential.get('SecretKey'),
|
model_id=model_name,
|
||||||
api_base=model_credential.get('api_base'),
|
region_name=model_credential['region_name'],
|
||||||
model_name=model_name,
|
credentials_profile_name=model_credential['credentials_profile_name'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Compute doc embeddings using a Bedrock model.
|
||||||
|
|
||||||
def _generate_auth_token(self):
|
Args:
|
||||||
# Example method to generate an authentication token for the model API
|
texts: The list of texts to embed
|
||||||
return f"{self.secret_id}:{self.secret_key}"
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for text in texts:
|
||||||
|
response = self._embedding_func(text)
|
||||||
|
|
||||||
|
if self.normalize:
|
||||||
|
response = self._normalize_vector(response)
|
||||||
|
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Compute query embeddings using a Bedrock model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embeddings for the text.
|
||||||
|
"""
|
||||||
|
embedding = self._embedding_func(text)
|
||||||
|
|
||||||
|
if self.normalize:
|
||||||
|
return self._normalize_vector(embedding)
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user