feat: add Xinference model integration (#959)
This commit is contained in:
parent 9cc24ac508 · commit 35d9462689
@@ -21,6 +21,7 @@ from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine
    VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
+from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider

@@ -40,3 +41,4 @@ class ModelProvideConstants(Enum):
    model_tencent_provider = TencentModelProvider()
    model_aws_bedrock_provider = BedrockModelProvider()
    model_local_provider = LocalModelProvider()
+    model_xinference_provider = XinferenceModelProvider()
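With both lines in place, the new backend resolves through the registry like any other provider. A minimal lookup sketch (it assumes ModelProvideConstants is imported from this constants module; get_model_info_manage is defined in the new provider module shown later in this diff):

# Enum name -> provider instance, as wired by the constant added above.
provider = ModelProvideConstants['model_xinference_provider'].value
manage = provider.get_model_info_manage()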
@@ -31,13 +31,56 @@ def _get_aws_bedrock_icon_path():


def _initialize_model_info():
-    model_info_list = [_create_model_info(
-        'amazon.titan-text-premier-v1:0',
-        'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
-        ModelTypeConst.LLM,
-        BedrockLLMModelCredential,
-        BedrockModel
-    ),
+    model_info_list = [
+        _create_model_info(
+            'anthropic.claude-v2:1',
+            'Claude 2 的更新,采用双倍的上下文窗口,并在长文档和 RAG 上下文中提高可靠性、幻觉率和循证准确性。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-v2',
+            'Anthropic 功能强大的模型,可处理各种任务,从复杂的对话和创意内容生成到详细的指令服从。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-haiku-20240307-v1:0',
+            'Claude 3 Haiku 是 Anthropic 最快速、最紧凑的模型,具有近乎即时的响应能力。该模型可以快速回答简单的查询和请求。客户将能够构建模仿人类交互的无缝人工智能体验。 Claude 3 Haiku 可以处理图像和返回文本输出,并且提供 200K 上下文窗口。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-sonnet-20240229-v1:0',
+            'Anthropic 推出的 Claude 3 Sonnet 模型在智能和速度之间取得理想的平衡,尤其是在处理企业工作负载方面。该模型提供最大的效用,同时价格低于竞争产品,并且其经过精心设计,是大规模部署人工智能的可靠选择。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-3-5-sonnet-20240620-v1:0',
+            'Claude 3.5 Sonnet提高了智能的行业标准,在广泛的评估中超越了竞争对手的型号和Claude 3 Opus,具有我们中端型号的速度和成本效益。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'anthropic.claude-instant-v1',
+            '一种更快速、更实惠但仍然非常强大的模型,它可以处理一系列任务,包括随意对话、文本分析、摘要和文档问题回答。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'amazon.titan-text-premier-v1:0',
+            'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel
+        ),
+        _create_model_info(
+            'amazon.titan-text-lite-v1',
+            'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
@@ -59,7 +102,7 @@ def _initialize_model_info():
        _create_model_info(
            'mistral.mistral-7b-instruct-v0:2',
            '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
-            ModelTypeConst.EMBEDDING,
+            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
apps/setting/models_provider/impl/base_chat_open_ai.py (new file, +78)
@@ -0,0 +1,78 @@
# coding=utf-8

from typing import List, Dict, Optional, Any, Iterator, Type

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk


class BaseChatOpenAI(ChatOpenAI):

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return self.get_last_generation_info().get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        return self.get_last_generation_info().get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
            base_generation_info = {}
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
                if len(chunk["choices"]) == 0:
                    if token_usage := chunk.get("usage"):
                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                        logprobs = None
                    else:
                        continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                        if model_name := chunk.get("model"):
                            generation_info["model_name"] = model_name
                        if system_fingerprint := chunk.get("system_fingerprint"):
                            generation_info["system_fingerprint"] = system_fingerprint

                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                    default_chunk_class = message_chunk.__class__
                generation_chunk = ChatGenerationChunk(
                    message=message_chunk, generation_info=generation_info or None
                )
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                yield generation_chunk
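A minimal usage sketch of the new base class (endpoint and key are placeholders; it assumes the OpenAI-compatible server honors stream_options={"include_usage": True}): after a stream completes, token counts are read from the recorded usage payload instead of a local tokenizer.

from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

llm = BaseChatOpenAI(model='chatglm3', openai_api_base='http://127.0.0.1:9997/v1',
                     openai_api_key='EMPTY', streaming=True)
for chunk in llm.stream('你好'):
    print(chunk.content, end='')
print(llm.get_num_tokens_from_messages([]))  # prompt_tokens of the last call
print(llm.get_num_tokens(''))                # completion_tokens of the last call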
@@ -8,14 +8,11 @@
"""
from typing import List, Dict

-from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_openai import ChatOpenAI
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


-class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
+class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        deepseek_chat_open_ai = DeepSeekChatModel(
@@ -25,10 +22,3 @@ class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
        )
        return deepseek_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
@@ -8,14 +8,14 @@
"""
from typing import List, Dict

-from langchain_community.chat_models import ChatOpenAI
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


-class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
+class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        kimi_chat_open_ai = KimiChatModel(
@@ -25,10 +25,3 @@ class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
        )
        return kimi_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
@@ -9,11 +9,11 @@
from typing import List, Dict
from urllib.parse import urlparse, ParseResult

-from langchain_community.chat_models import ChatOpenAI
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


def get_base_url(url: str):
@@ -24,7 +24,7 @@ def get_base_url(url: str):
    return result_url[:-1] if result_url.endswith("/") else result_url


-class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
+class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        api_base = model_credential.get('api_base', '')
@@ -32,11 +32,3 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        return OllamaChatModel(model=model_name, openai_api_base=base_url,
                               openai_api_key=model_credential.get('api_key'))
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
@@ -8,27 +8,19 @@
"""
from typing import List, Dict

-from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_openai import ChatOpenAI
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


-class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
+class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        azure_chat_open_ai = OpenAIChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
-            openai_api_key=model_credential.get('api_key')
+            openai_api_key=model_credential.get('api_key'),
+            streaming=model_kwargs.get('streaming', False),
+            max_tokens=model_kwargs.get('max_tokens', 5),
+            temperature=model_kwargs.get('temperature', 0.5),
        )
        return azure_chat_open_ai
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
-
-    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
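Beyond the base-class swap, new_instance here now forwards streaming, max_tokens and temperature from model_kwargs. A hedged call sketch (base URL and key are placeholders):

chat = OpenAIChatModel.new_instance(
    'LLM', 'gpt-3.5-turbo',
    {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-...'},
    streaming=True, max_tokens=512, temperature=0.7)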
@@ -6,16 +6,15 @@
@date:2024/04/19 15:55
@desc:
"""

import json
from typing import List, Optional, Any, Iterator, Dict

-from langchain_community.chat_models import ChatSparkLLM
-from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
+from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
+    ChatSparkLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
+from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel

@@ -31,16 +30,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
-            spark_llm_domain=model_name
+            spark_llm_domain=model_name,
+            temperature=model_kwargs.get('temperature', 0.5),
+            max_tokens=model_kwargs.get('max_tokens', 5),
        )

+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.get_last_generation_info().get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('completion_tokens', 0)

    def _stream(
            self,
@@ -58,11 +60,17 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
            True,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
-            if "data" not in content:
+            if "data" in content:
+                delta = content["data"]
+                chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+                cg_chunk = ChatGenerationChunk(message=chunk)
+            elif "usage" in content:
+                generation_info = content["usage"]
+                self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
                continue
-            delta = content["data"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+            else:
+                continue
+            if cg_chunk is not None:
+                if run_manager:
+                    run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
                yield cg_chunk
(new file, +1)
@@ -0,0 +1 @@
# coding=utf-8
apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py (new file, +38)
@@ -0,0 +1,38 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding


class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
        exist = provider.get_model_info_by_name(model_list, model_name)
        model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
        if len(exist) == 0:
            model.start_down_model_thread()
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        model.embed_query('你好')
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
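The embedding credential's is_valid has a notable side effect: if the requested model is not in the server's model list, it starts a background launch before failing. A sketch of that flow (the model_type string and form values are illustrative assumptions):

from common.exception.app_exception import AppApiException
from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider

cred = XinferenceEmbeddingModelCredential()
try:
    cred.is_valid('EMBEDDING', 'bge-m3', {'api_base': 'http://127.0.0.1:9997'},
                  XinferenceModelProvider(), raise_exception=True)
except AppApiException:
    pass  # launch was started in a daemon thread; validate again once it finishes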
apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py (new file, +41)
@@ -0,0 +1,41 @@
# coding=utf-8

from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'), model_type)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
        exist = provider.get_model_info_by_name(model_list, model_name)
        if len(exist) == 0:
            raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
        model = provider.get_model(model_type, model_name, model_credential)
        model.invoke([HumanMessage(content='你好')])
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        for key in ['api_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
apps/setting/models_provider/impl/xinference_model_provider/icon/xinference_icon_svg (new file, +5)
@@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="100%" height="100%" viewBox="0 0 48 48" enable-background="new 0 0 48 48" xml:space="preserve"> <image id="image0" width="48" height="48" x="0" y="0"
href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAALRklEQVR4nNRZCVRTVxq+92UjIQlLEggJimGTQGVRBLVWQQGtVafitNVWavWUtjMdO1Qd67E9tlNtq9MRp1NrVUDFVo9Yi9O6YCluLAXcClYERJA1hhAQCJA9d87LAiF7tHM8859z4d333v3v99373f/+7wYD/ycWQ2Hx/uIbn/qkcTySrfCJfEYU+c4dPpHu9aSxuGVkSADZ/OSsodi/qbL58z940njcMjaRSrsgXHl0aNpG1Bv7rpRHojNtvUd01SETE5A9MG8fQxukGNaKHw4jse53RW20CIqX7+moFUUBBEYCQAB839+wT6QeGrT1rl0CXEJC8GTSyhd4pIQkDil2CsRQAIQQgxAAiAEAIdBACEQ92ppGqab+lkh1tUqkqC7tVtdKHgf8MnZY/C5BylE/Ai0cIAQQgiPZ4uo99t6HljcCCc9OmUPZ9rEvSbgUQoAZwYLR/+bXVveQTgH661pGis7fGTpV0DJ84YYWKF0CToYEmB02b9NqvynbAYJEfOShDoDC3sbsV5vObHBKAANkMI+St3UyOX0rhIBgBdIGgX5tc3mz4uypLvW1q72qOy0AgAEA9T69OBThBA8CA4kVtbcfKGqHHYHnUxiM41HPHY5l8NJx4PqiA/gEDCf8mh/SIO/tttd2VEKpHnnbQkjpH+i7t1NMBPq1nVWlss3r7o6cum7H75BU1dDlfNwBmO0dKDwy5blCDokWgSMGOmgaUXBKcvdrR+BHCQiJq5NCiOlboBlIWwUn0ago3PXTwJ83q3QyjSsAHdl6wfTlH4XMOAggkYmMkgEYAgBBgHRg6LP2yn8486EnMJ2yZade744IAACalWf2nH6YsfFxgTMwAmFPzIJPl3HDNxnkgvSjg3Bp6iWEQKHk7t4GeW+PUwI8bHYkHQtMMJeJLeloobL7Qv/6zY8LPtbLj7s/Lu1bIZ0136R1XC4m+RipyD5rrfzcFX9EHmHObEuZ2JJOs+LsMZmuy+FidGYr+KGJu6emnaZhJI4eOA4XgwYSYEw+Rd0thxpH+qQuEfAmBAeby8Tq2vi/VfnzxccBv1mY+MpmYeJ+iICnPr7jwPW6R8YZMBFByi/bb+521S+RiFEMW7SD6IMXqbqh8VGAs0lk0peJqf9cFBD8zmiItKV7nWE2zotbdpU97Gx1mQAAQO4s+uBFhQZH3AUf4eXDOjz72e8iGKxkYFycoyNtqXsMAKVOI3qv/son7vSBDehaOl0hQCOwfdxxHO3lM/F0Wnp5hLdvMg4WH2nDbBolY6pjhjr+/Iiobsd9+YBbA4X16eprxm1UwPbmxSEJI1x1muIfEFyyML2UQ6VGGIAivS9kAo6DHkcEAaVW2Zn9W8UBcz8JASs4Tgm0q8uqIDTIyBK0+b2J1JS5roCPYXMEh9IWXyTTPIIMQM1AY0ivIGR+z3id3Xzzsy6NalziJJHVo2eDNyxwSECOuuRSbc15Z5tYOC31eTLGcPgJOt3PX3B+SXo5nUIOQhYSGZ1VM8mY6oNaZcdXTTV5lv5ah2qlNCp74ryJmfPsEsD/1Cm/Pe5sLyASyIHhtMV2v0m9KRRi/oKFJzxIBN44ydjSvcXzvJbfvpBZjL7Jihp35i+O3PplAF3oZ5dAg+K7H7VA0W21G4PxHT7tsz7LHoHdc+d8FED3jDfX9TjdQ3M5jdUVSNN7oOlWjj2/I5p+VVP3pR/fSszPJ2Jkq/SfgP/RgBENncBFAZT4NOucH43WPUmcUA1CdR2K8jvmTjKihHPfjZ+aq0+lITQjbtwFgfl6Gv/80L26Dwvbmi7ZI4DboLpHtDA8azed4iP7TVxcaTUDuF0d2vWVCg222lvMpjLf//1cvkeCwNSOR6fTtj0z6xDCAGZXLvjixYyL11jHn8s0qqadNZX/dgQet7bea/UKzeD9eSFvbI/mLoyxSUCm61JUyD7JAuPWArJO7ADwejnwyDEagU3G22UlxK5jUkkC/F1kWpyWYRIbI2YuqS03yt6WqlUqZwSUQAnaB2quQwg83phx4IQ3jU+1IoDbjaE9P7QpLn5tIIDsRiVPSuCM5RNy9hEgGeTV3s7pkctrzcOiVZi0qOPPL4jacr9pqv/ZGXiTDSmlHXhbT7JP+JKIsbVoFRb/05vxTr+2pcxWKDWvhzFS12QITuy41yfv+8PxwuSGvr5ye2HSMoz2KEca15VfshsQbJrhEEHfd1Lo6g2+NB7dJgGlrl9zsvuFFRqgfGA3uTM4BCGMlPdWBRfsahpQ9i858WNKeXt7rmUEQth43QMMDLx1pWS5aGTYrdScRmJyTQRIBAprfvgbq20SwE2qbhCdk2x8BUKkc5YjhXulrH8t7OTeAQVQLf3+XGZ2RfVLAEMDo+vALJQiiNRZpZfTL3Z01Lk1+gAAAWfaVPN9ak7Iy6uAKYzaMomqppVGZA8GUuMXOjtaYXkI4oU+aXH3Bi8X/9RWf/1ic+ux6AC/IC7dUziWXAGQV1v75uc3bha6Cz6aMz9sbuiabeb9UskM3s3Oc/vtEsDt3nBxFY3IQnxafLJdAsZrJpk7eQZ3bQYBEnt+EV+uOHTrdkFDX19JBMuHz6ZRg3Nv3c7aVFqxz13wuGXOyv2K5cl/ymLg4IPBuzVWO5stWxSw6/1E1pvbHREw/z+g7rpe1PbhxhuS41fw9tO5ftxrYon4UcCvjst+PXlyZo6pLwwb26Oq2k994XAGTNY0VFzmRWB38z2n4XLCnBGgEpm8aM7S12L9liUBDPT+2l1Tq0Uq5A5wT4yBrZqWvSkpbO3u0T4t+tUiZY9LM2CyBFbmc4sCPz1CIlB9LQlgZs6BxT2FdrDjdu+5gqaekjP1/ZevDijFcnt9eHlwPWcJXl6aEpK5ieUZGGtvkPDSNVB/xS0CuPGpcbw/CvYf4dIi57ty/Gjj/FQ9pJY2SpX3G/vlnSIIwQiEgEglMfz4XsJIFo0XDTBIsmyLmfkDYwSKXSawZmZEaqAXY8K289cOEiAZe3HS3i2xfiu2QghIDgk4JYjce9/s+kb72QMurQHccjOScxdGTfzr7FD+xAuNbT9VdX9/UaYWXwn3TlpEJJDoj0YAuQzW1r2qppP7XPqRb2mMYHoQh5mM5zCzgv3XFq9bWpkczo+qlhws/VftM9Gtg1ePjv9+sLh2VICN9+20Nd/IAASqKtG5My7NQO6a5Bw23SPMlBQxPcjcF+NC1wZ40dGZupqSSvHhk1JFy4Ug5tSnqCQm3+7ojRtF5KLEbD+r7So5cObWFwVOZ2BJ3KTECJ7PorEPc0M2iTBAolD0p1Na/L2bkuPln1RHJx6tz1wgHqkvsZl6jCaE9jNdu8UsmZQpeu7uK31bf07rdAbyXk/KYzM8Qg1zZzClVifdeLJi+Y6imnwtGgvvOqAFoqHbzeWdOd+Ih+vPMigcGosWFA4xQHRJ38bQC2zcM10/lD+49fHZxYu6ZfclwNZPTONGf+qkmXmZyb/oz+2R4fyyVSo
r+9PhKxnX2yRtTrUHAAigC1kJ/JdWTuOlr/SnB8+AGLLelFwLwyOXmvJ3flv9/o5h1cDoR5BDAqUfPV8cwfVJ1R/CIqT9rqrl7+8eq9iu1Gjd2lVNNoEZ4x/OnpUcyp75dChr+kyWZ2AU/pXlgICmT955s6b1XOH5uzkHO/sbrH4vsEtgybSg2XlvJZfh4JUqXd/WE9dWHSqtL3oU4PaMQWYT/RiCSb60wAlMCpvFpLKpEABd93CLTK6Wdbb21jb0jXS5fSart5+3Lv5Bkvcaat+7qmOukCf8PYH/z21WhH+k5PBqTfWO9MopE1j8J43HbSvcklZQsCH1azaDQn7SWNy21Dh+1IZlMa8+aRyu2n8DAAD//2mMQDVCqcaMAAAAAElFTkSuQmCC" ></image>
</svg>
After Width: 48 | Height: 48 | Size: 4.3 KiB
apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py (new file, +24)
@@ -0,0 +1,24 @@
# coding=utf-8
import threading
from typing import Dict

from langchain_community.embeddings import XinferenceEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return XinferenceEmbedding(
            model_uid=model_name,
            server_url=model_credential.get('api_base'),
        )

    def down_model(self):
        self.client.launch_model(model_name=self.model_uid, model_type="embedding")

    def start_down_model_thread(self):
        thread = threading.Thread(target=self.down_model)
        thread.daemon = True
        thread.start()
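A short usage sketch (server URL is a placeholder): new_instance maps MaxKB's credential dict onto langchain_community's XinferenceEmbeddings, and down_model asks the Xinference server to launch the same model uid.

emb = XinferenceEmbedding.new_instance('EMBEDDING', 'bge-m3',
                                       {'api_base': 'http://127.0.0.1:9997'})
vector = emb.embed_query('你好')  # assumes the model is already launched server-side
print(len(vector))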
apps/setting/models_provider/impl/xinference_model_provider/model/llm.py (new file, +39)
@@ -0,0 +1,39 @@
# coding=utf-8

from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


def get_base_url(url: str):
    parse = urlparse(url)
    result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
                             query='',
                             fragment='').geturl()
    return result_url[:-1] if result_url.endswith("/") else result_url


class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        api_base = model_credential.get('api_base', '')
        base_url = get_base_url(api_base)
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        optional_params = {}
        if 'max_tokens' in model_kwargs:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs:
            optional_params['temperature'] = model_kwargs['temperature']
        return XinferenceChatModel(
            model=model_name,
            openai_api_base=base_url,
            openai_api_key=model_credential.get('api_key'),
            streaming=model_kwargs.get('streaming', False),
            **optional_params
        )
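And a matching sketch for the chat side (placeholders again): the api_base is normalized to end in /v1, and only the sampling options actually present in model_kwargs are forwarded.

chat = XinferenceChatModel.new_instance(
    'LLM', 'qwen-chat',
    {'api_base': 'http://127.0.0.1:9997', 'api_key': 'EMPTY'},
    streaming=True, temperature=0.7, max_tokens=512)
print(chat.invoke('你好').content)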
apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py (new file, +528)
@@ -0,0 +1,528 @@
# coding=utf-8
import os
from urllib.parse import urlparse, ParseResult

import requests

from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
    ModelInfoManage
from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
    XinferenceEmbeddingModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
from smartdoc.conf import PROJECT_DIR

xinference_llm_model_credential = XinferenceLLMModelCredential()
model_info_list = [
    ModelInfo('aquila2', 'Aquila2 是一个具有 340 亿参数的大规模语言模型,支持中英文双语。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('aquila2-chat', 'Aquila2 Chat 是一个聊天模型版本的 Aquila2,支持中英文双语。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('aquila2-chat-16k', 'Aquila2 Chat 16K 是一个聊天模型版本的 Aquila2,支持长达 16K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('baichuan', 'Baichuan 是一个大规模语言模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('baichuan-2', 'Baichuan 2 是 Baichuan 的更新版本,具有更高的性能。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('baichuan-2-chat', 'Baichuan 2 Chat 是一个聊天模型版本的 Baichuan 2。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('baichuan-chat', 'Baichuan Chat 是一个聊天模型版本的 Baichuan。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('c4ai-command-r-v01', 'C4AI Command R V01 是一个用于执行命令的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm', 'ChatGLM 是一个聊天模型,特别擅长中文对话。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm2', 'ChatGLM2 是 ChatGLM 的更新版本,具有更好的性能。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm2-32k', 'ChatGLM2 32K 是一个聊天模型版本的 ChatGLM2,支持长达 32K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm3', 'ChatGLM3 是 ChatGLM 的第三个版本,具有更高的性能。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm3-128k', 'ChatGLM3 128K 是一个聊天模型版本的 ChatGLM3,支持长达 128K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('chatglm3-32k', 'ChatGLM3 32K 是一个聊天模型版本的 ChatGLM3,支持长达 32K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('code-llama', 'Code Llama 是一个专门用于代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('code-llama-instruct', 'Code Llama Instruct 是 Code Llama 的指令微调版本,专为执行特定任务而设计。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('code-llama-python', 'Code Llama Python 是一个专门用于 Python 代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codegeex4', 'CodeGeeX4 是一个用于代码生成的语言模型,具有较高的性能。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codeqwen1.5', 'CodeQwen 1.5 是一个用于代码生成的语言模型,具有较高的性能。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codeqwen1.5-chat', 'CodeQwen 1.5 Chat 是一个聊天模型版本的 CodeQwen 1.5。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codeshell', 'CodeShell 是一个用于代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codeshell-chat', 'CodeShell Chat 是一个聊天模型版本的 CodeShell。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('codestral-v0.1', 'CodeStral V0.1 是一个用于代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('cogvlm2', 'CogVLM2 是一个视觉语言模型,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('csg-wukong-chat-v0.1', 'CSG Wukong Chat V0.1 是一个聊天模型版本的 CSG Wukong。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('deepseek', 'Deepseek 是一个大规模语言模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('deepseek-chat', 'Deepseek Chat 是一个聊天模型版本的 Deepseek。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('deepseek-coder', 'Deepseek Coder 是一个专为代码生成设计的模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('deepseek-coder-instruct', 'Deepseek Coder Instruct 是 Deepseek Coder 的指令微调版本,专为执行特定任务而设计。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('deepseek-vl-chat', 'Deepseek VL Chat 是 Deepseek 的视觉语言聊天模型版本,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('falcon', 'Falcon 是一个开源的 Transformer 解码器模型,具有 400 亿参数,旨在生成高质量的文本。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('falcon-instruct', 'Falcon Instruct 是 Falcon 语言模型的指令微调版本,专为执行特定任务而设计。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gemma-2-it', 'GEMMA-2-IT 是一个基于 GEMMA-2 的意大利语模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gemma-it', 'GEMMA-IT 是一个基于 GEMMA 的意大利语模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gpt-3.5-turbo', 'GPT-3.5 Turbo 是一个高效能的通用语言模型,适用于多种应用场景。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gpt-4', 'GPT-4 是一个强大的多模态模型,不仅支持文本输入,还支持图像输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gpt-4-vision-preview', 'GPT-4 Vision Preview 是 GPT-4 的视觉预览版本,支持图像输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('gpt4all', 'GPT4All 是一个开源的多模态模型,支持文本和图像输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('llama2', 'Llama2 是一个具有 700 亿参数的大规模语言模型,支持多种语言。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('llama2-chat', 'Llama2 Chat 是一个聊天模型版本的 Llama2,支持多种语言。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('llama2-chat-32k', 'Llama2 Chat 32K 是一个聊天模型版本的 Llama2,支持长达 32K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('moss', 'MOSS 是一个大规模语言模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('moss-chat', 'MOSS Chat 是一个聊天模型版本的 MOSS。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen', 'Qwen 是一个大规模语言模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-chat', 'Qwen Chat 是一个聊天模型版本的 Qwen。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-chat-32k', 'Qwen Chat 32K 是一个聊天模型版本的 Qwen,支持长达 32K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-code', 'Qwen Code 是一个专门用于代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-code-chat', 'Qwen Code Chat 是一个聊天模型版本的 Qwen Code。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-vl', 'Qwen VL 是 Qwen 的视觉语言模型版本,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('qwen-vl-chat', 'Qwen VL Chat 是 Qwen VL 的聊天模型版本,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2', 'Spark2 是一个大规模语言模型,具有 130 亿参数。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-chat', 'Spark2 Chat 是一个聊天模型版本的 Spark2。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-chat-32k', 'Spark2 Chat 32K 是一个聊天模型版本的 Spark2,支持长达 32K 令牌的上下文。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-code', 'Spark2 Code 是一个专门用于代码生成的语言模型。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-code-chat', 'Spark2 Code Chat 是一个聊天模型版本的 Spark2 Code。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-vl', 'Spark2 VL 是 Spark2 的视觉语言模型版本,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
    ModelInfo('spark2-vl-chat', 'Spark2 VL Chat 是 Spark2 VL 的聊天模型版本,能够处理图像和文本输入。', ModelTypeConst.LLM,
              xinference_llm_model_credential, XinferenceChatModel),
]

xinference_embedding_model_credential = XinferenceEmbeddingModelCredential()

# Build the embedding_model_info list
embedding_model_info = [
    ModelInfo('bce-embedding-base_v1', 'BCE 嵌入模型的基础版本。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-base-en', 'BGE 英语基础版本的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-base-en-v1.5', 'BGE 英语基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-base-zh', 'BGE 中文基础版本的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-base-zh-v1.5', 'BGE 中文基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-large-en', 'BGE 英语大型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-large-en-v1.5', 'BGE 英语大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-large-zh', 'BGE 中文大型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-large-zh-noinstruct', 'BGE 中文大型版本的嵌入模型,无指令调整。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-large-zh-v1.5', 'BGE 中文大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-m3', 'BGE M3 版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('bge-small-en-v1.5', 'BGE 英语小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-small-zh', 'BGE 中文小型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('bge-small-zh-v1.5', 'BGE 中文小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('e5-large-v2', 'E5 大型版本 2 的嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('gte-base', 'GTE 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('gte-large', 'GTE 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('jina-embeddings-v2-base-en', 'Jina 嵌入模型的英语基础版本 2。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('jina-embeddings-v2-base-zh', 'Jina 嵌入模型的中文基础版本 2。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('jina-embeddings-v2-small-en', 'Jina 嵌入模型的英语小型版本 2。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('m3e-base', 'M3E 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('m3e-large', 'M3E 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('m3e-small', 'M3E 小型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
              XinferenceEmbedding),
    ModelInfo('multilingual-e5-large', '多语言大型版本的 E5 嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('text2vec-base-chinese', 'Text2Vec 的中文基础版本嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('text2vec-base-chinese-paraphrase', 'Text2Vec 的中文基础版本的同义句嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('text2vec-base-chinese-sentence', 'Text2Vec 的中文基础版本的句子嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('text2vec-base-multilingual', 'Text2Vec 的多语言基础版本嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
    ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
]

model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
    ModelInfo(
        'phi3',
        'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
        ModelTypeConst.LLM, xinference_llm_model_credential, XinferenceChatModel))
                     .append_model_info_list(embedding_model_info).append_default_model_info(
    ModelInfo(
        '',
        '',
        ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding))
                     .build())


def get_base_url(url: str):
    parse = urlparse(url)
    result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
                             query='',
                             fragment='').geturl()
    return result_url[:-1] if result_url.endswith("/") else result_url


class XinferenceModelProvider(IModelProvider):
    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_xinference_provider', name='Xinference', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xinference_model_provider', 'icon',
                         'xinference_icon_svg')))

    @staticmethod
    def get_base_model_list(api_base, model_type):
        base_url = get_base_url(api_base)
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        r = requests.request(method="GET", url=f"{base_url}/models", timeout=5)
        r.raise_for_status()
        model_list = r.json().get('data')
        return [model for model in model_list if model.get('model_type') == model_type]

    @staticmethod
    def get_model_info_by_name(model_list, model_name):
        if model_list is None:
            return []
        return [model for model in model_list if model.get('model_name') == model_name]
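On the wire, get_base_model_list is just a GET against the server's OpenAI-style model registry, filtered by the model_type field Xinference reports. Reduced to a sketch (the URL is a placeholder, and the exact model_type strings come from the server):

import requests

r = requests.get('http://127.0.0.1:9997/v1/models', timeout=5)
r.raise_for_status()
llms = [m for m in r.json().get('data') if m.get('model_type') == 'LLM']
print([m.get('model_name') for m in llms])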
@@ -9,26 +9,100 @@
-from typing import List, Dict
from langchain_community.chat_models import ChatZhipuAI
+from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
+    _convert_delta_to_message_chunk
-from langchain_core.messages import BaseMessage, get_buffer_string
-
-from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
+import json
+import logging
+import time
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from langchain_core.callbacks import (
+    CallbackManagerForLLMRun,
+)
+
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage
+)
+from langchain_core.outputs import ChatGenerationChunk


class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        zhipuai_chat = ZhipuChatModel(
            temperature=0.5,
            api_key=model_credential.get('api_key'),
-            model=model_name
+            model=model_name,
+            max_tokens=model_kwargs.get('max_tokens', 5)
        )
        return zhipuai_chat

+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.get_last_generation_info().get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('completion_tokens', 0)

+    def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Stream the chat response in chunks."""
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipu_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        _truncate_params(payload)
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        with httpx.Client(headers=headers, timeout=60) as client:
+            with connect_sse(
+                    client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                for sse in event_source.iter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    generation_info = {}
+                    if "usage" in chunk:
+                        generation_info = chunk["usage"]
+                        self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
@@ -49,6 +49,7 @@ gevent = "^24.2.1"
boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
tencentcloud-sdk-python = "^3.0.1205"
+xinference-client = "^0.14.0.post1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
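The new xinference-client dependency is the transport behind langchain_community's XinferenceEmbeddings used above. A quick connectivity check under that assumption (server URL is a placeholder):

from xinference_client import RESTfulClient

client = RESTfulClient('http://127.0.0.1:9997')
print(client.list_models())  # models currently launched on the server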