feat: 增加xinference模型对接 (#959)

This commit is contained in:
wxg0103 2024-08-12 12:07:15 +08:00 committed by GitHub
parent 9cc24ac508
commit 35d9462689
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 923 additions and 74 deletions

View File

@ -21,6 +21,7 @@ from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine
VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
@ -40,3 +41,4 @@ class ModelProvideConstants(Enum):
model_tencent_provider = TencentModelProvider()
model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()
model_xinference_provider = XinferenceModelProvider()

View File

@ -31,13 +31,56 @@ def _get_aws_bedrock_icon_path():
def _initialize_model_info():
model_info_list = [_create_model_info(
'amazon.titan-text-premier-v1:0',
'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
model_info_list = [
_create_model_info(
'anthropic.claude-v2:1',
'Claude 2 的更新,采用双倍的上下文窗口,并在长文档和 RAG 上下文中提高可靠性、幻觉率和循证准确性。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-v2',
'Anthropic 功能强大的模型,可处理各种任务,从复杂的对话和创意内容生成到详细的指令服从。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-haiku-20240307-v1:0',
'Claude 3 Haiku 是 Anthropic 最快速、最紧凑的模型,具有近乎即时的响应能力。该模型可以快速回答简单的查询和请求。客户将能够构建模仿人类交互的无缝人工智能体验。 Claude 3 Haiku 可以处理图像和返回文本输出,并且提供 200K 上下文窗口。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-sonnet-20240229-v1:0',
'Anthropic 推出的 Claude 3 Sonnet 模型在智能和速度之间取得理想的平衡,尤其是在处理企业工作负载方面。该模型提供最大的效用,同时价格低于竞争产品,并且其经过精心设计,是大规模部署人工智能的可靠选择。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-5-sonnet-20240620-v1:0',
'Claude 3.5 Sonnet提高了智能的行业标准在广泛的评估中超越了竞争对手的型号和Claude 3 Opus具有我们中端型号的速度和成本效益。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-instant-v1',
'一种更快速、更实惠但仍然非常强大的模型,它可以处理一系列任务,包括随意对话、文本分析、摘要和文档问题回答。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-premier-v1:0',
'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-lite-v1',
'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
@ -59,7 +102,7 @@ def _initialize_model_info():
_create_model_info(
'mistral.mistral-7b-instruct-v0:2',
'7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
ModelTypeConst.EMBEDDING,
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(

View File

@ -0,0 +1,78 @@
# coding=utf-8
from typing import List, Dict, Optional, Any, Iterator, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
class BaseChatOpenAI(ChatOpenAI):
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return self.get_last_generation_info().get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
base_generation_info = {}
with response:
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
if token_usage := chunk.get("usage"):
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
logprobs = None
else:
continue
else:
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {**base_generation_info} if is_first_chunk else {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
is_first_chunk = False
yield generation_chunk

View File

@ -8,14 +8,11 @@
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
deepseek_chat_open_ai = DeepSeekChatModel(
@ -25,10 +22,3 @@ class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI):
)
return deepseek_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -8,14 +8,14 @@
"""
from typing import List, Dict
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
kimi_chat_open_ai = KimiChatModel(
@ -25,10 +25,3 @@ class KimiChatModel(MaxKBBaseModel, ChatOpenAI):
)
return kimi_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -9,11 +9,11 @@
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
def get_base_url(url: str):
@ -24,7 +24,7 @@ def get_base_url(url: str):
return result_url[:-1] if result_url.endswith("/") else result_url
class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_base = model_credential.get('api_base', '')
@ -32,11 +32,3 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'))
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -8,27 +8,19 @@
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
azure_chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key')
openai_api_key=model_credential.get('api_key'),
streaming=model_kwargs.get('streaming', False),
max_tokens=model_kwargs.get('max_tokens', 5),
temperature=model_kwargs.get('temperature', 0.5),
)
return azure_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -6,16 +6,15 @@
@date2024/04/19 15:55
@desc:
"""
import json
from typing import List, Optional, Any, Iterator, Dict
from langchain_community.chat_models import ChatSparkLLM
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
ChatSparkLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -31,16 +30,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
spark_llm_domain=model_name
spark_llm_domain=model_name,
temperature=model_kwargs.get('temperature', 0.5),
max_tokens=model_kwargs.get('max_tokens', 5),
)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.get_last_generation_info().get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
@ -58,11 +60,17 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
True,
)
for content in self.client.subscribe(timeout=self.request_timeout):
if "data" not in content:
if "data" in content:
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
cg_chunk = ChatGenerationChunk(message=chunk)
elif "usage" in content:
generation_info = content["usage"]
self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
continue
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
else:
continue
if cg_chunk is not None:
if run_manager:
run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
yield cg_chunk

View File

@ -0,0 +1 @@
# coding=utf-8

View File

@ -0,0 +1,38 @@
# coding=utf-8
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = provider.get_model_info_by_name(model_list, model_name)
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
if len(exist) == 0:
model.start_down_model_thread()
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
model.embed_query('你好')
return True
def encryption_dict(self, model_info: Dict[str, object]):
return model_info
def build_model(self, model_info: Dict[str, object]):
for key in ['model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
return self
api_base = forms.TextInputField('API 域名', required=True)

View File

@ -0,0 +1,41 @@
# coding=utf-8
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'), model_type)
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = provider.get_model_info_by_name(model_list, model_name)
if len(exist) == 0:
raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
return True
def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
def build_model(self, model_info: Dict[str, object]):
for key in ['api_key', 'model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
self.api_key = model_info.get('api_key')
return self
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" width="100%" height="100%" viewBox="0 0 48 48" enable-background="new 0 0 48 48" xml:space="preserve"> <image id="image0" width="48" height="48" x="0" y="0"
href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAALRklEQVR4nNRZCVRTVxq+92UjIQlLEggJimGTQGVRBLVWQQGtVafitNVWavWUtjMdO1Qd67E9tlNtq9MRp1NrVUDFVo9Yi9O6YCluLAXcClYERJA1hhAQCJA9d87LAiF7tHM8859z4d333v3v99373f/+7wYD/ycWQ2Hx/uIbn/qkcTySrfCJfEYU+c4dPpHu9aSxuGVkSADZ/OSsodi/qbL58z940njcMjaRSrsgXHl0aNpG1Bv7rpRHojNtvUd01SETE5A9MG8fQxukGNaKHw4jse53RW20CIqX7+moFUUBBEYCQAB839+wT6QeGrT1rl0CXEJC8GTSyhd4pIQkDil2CsRQAIQQgxAAiAEAIdBACEQ92ppGqab+lkh1tUqkqC7tVtdKHgf8MnZY/C5BylE/Ai0cIAQQgiPZ4uo99t6HljcCCc9OmUPZ9rEvSbgUQoAZwYLR/+bXVveQTgH661pGis7fGTpV0DJ84YYWKF0CToYEmB02b9NqvynbAYJEfOShDoDC3sbsV5vObHBKAANkMI+St3UyOX0rhIBgBdIGgX5tc3mz4uypLvW1q72qOy0AgAEA9T69OBThBA8CA4kVtbcfKGqHHYHnUxiM41HPHY5l8NJx4PqiA/gEDCf8mh/SIO/tttd2VEKpHnnbQkjpH+i7t1NMBPq1nVWlss3r7o6cum7H75BU1dDlfNwBmO0dKDwy5blCDokWgSMGOmgaUXBKcvdrR+BHCQiJq5NCiOlboBlIWwUn0ago3PXTwJ83q3QyjSsAHdl6wfTlH4XMOAggkYmMkgEYAgBBgHRg6LP2yn8486EnMJ2yZade744IAACalWf2nH6YsfFxgTMwAmFPzIJPl3HDNxnkgvSjg3Bp6iWEQKHk7t4GeW+PUwI8bHYkHQtMMJeJLeloobL7Qv/6zY8LPtbLj7s/Lu1bIZ0136R1XC4m+RipyD5rrfzcFX9EHmHObEuZ2JJOs+LsMZmuy+FidGYr+KGJu6emnaZhJI4eOA4XgwYSYEw+Rd0thxpH+qQuEfAmBAeby8Tq2vi/VfnzxccBv1mY+MpmYeJ+iICnPr7jwPW6R8YZMBFByi/bb+521S+RiFEMW7SD6IMXqbqh8VGAs0lk0peJqf9cFBD8zmiItKV7nWE2zotbdpU97Gx1mQAAQO4s+uBFhQZH3AUf4eXDOjz72e8iGKxkYFycoyNtqXsMAKVOI3qv/son7vSBDehaOl0hQCOwfdxxHO3lM/F0Wnp5hLdvMg4WH2nDbBolY6pjhjr+/Iiobsd9+YBbA4X16eprxm1UwPbmxSEJI1x1muIfEFyyML2UQ6VGGIAivS9kAo6DHkcEAaVW2Zn9W8UBcz8JASs4Tgm0q8uqIDTIyBK0+b2J1JS5roCPYXMEh9IWXyTTPIIMQM1AY0ivIGR+z3id3Xzzsy6NalziJJHVo2eDNyxwSECOuuRSbc15Z5tYOC31eTLGcPgJOt3PX3B+SXo5nUIOQhYSGZ1VM8mY6oNaZcdXTTV5lv5ah2qlNCp74ryJmfPsEsD/1Cm/Pe5sLyASyIHhtMV2v0m9KRRi/oKFJzxIBN44ydjSvcXzvJbfvpBZjL7Jihp35i+O3PplAF3oZ5dAg+K7H7VA0W21G4PxHT7tsz7LHoHdc+d8FED3jDfX9TjdQ3M5jdUVSNN7oOlWjj2/I5p+VVP3pR/fSszPJ2Jkq/SfgP/RgBENncBFAZT4NOucH43WPUmcUA1CdR2K8jvmTjKihHPfjZ+aq0+lITQjbtwFgfl6Gv/80L26Dwvbmi7ZI4DboLpHtDA8azed4iP7TVxcaTUDuF0d2vWVCg222lvMpjLf//1cvkeCwNSOR6fTtj0z6xDCAGZXLvjixYyL11jHn8s0qqadNZX/dgQet7bea/UKzeD9eSFvbI/mLoyxSUCm61JUyD7JAuPWArJO7ADwejnwyDEagU3G22UlxK5jUkkC/F1kWpyWYRIbI2YuqS03yt6WqlUqZwSUQAnaB2quQwg83phx4IQ3jU+1IoDbjaE9P7QpLn5tIIDsRiVPSuCM5RNy9hEgGeTV3s7pkctrzcOiVZi0qOPPL4jacr9pqv/ZGXiTDSmlHXhbT7JP+JKIsbVoFRb/05vxTr+2pcxWKDWvhzFS12QITuy41yfv+8PxwuSGvr5ye2HSMoz2KEca15VfshsQbJrhEEHfd1Lo6g2+NB7dJgGlrl9zsvuFFRqgfGA3uTM4BCGMlPdWBRfsahpQ9i858WNKeXt7rmUEQth43QMMDLx1pWS5aGTYrdScRmJyTQRIBAprfvgbq20SwE2qbhCdk2x8BUKkc5YjhXulrH8t7OTeAQVQLf3+XGZ2RfVLAEMDo+vALJQiiNRZpZfTL3Z01Lk1+gAAAWfaVPN9ak7Iy6uAKYzaMomqppVGZA8GUuMXOjtaYXkI4oU+aXH3Bi8X/9RWf/1ic+ux6AC/IC7dUziWXAGQV1v75uc3bha6Cz6aMz9sbuiabeb9UskM3s3Oc/vtEsDt3nBxFY3IQnxafLJdAsZrJpk7eQZ3bQYBEnt+EV+uOHTrdkFDX19JBMuHz6ZRg3Nv3c7aVFqxz13wuGXOyv2K5cl/ymLg4IPBuzVWO5stWxSw6/1E1pvbHREw/z+g7rpe1PbhxhuS41fw9tO5ftxrYon4UcCvjst+PXlyZo6pLwwb26Oq2k994XAGTNY0VFzmRWB38z2n4XLCnBGgEpm8aM7S12L9liUBDPT+2l1Tq0Uq5A5wT4yBrZqWvSkpbO3u0T4t+tUiZY9LM2CyBFbmc4sCPz1CIlB9LQlgZs6BxT2FdrDjdu+5gqaekjP1/ZevDijFcnt9eHlwPWcJXl6aEpK5ieUZGGtvkPDSNVB/xS0CuPGpcbw/CvYf4dIi57ty/Gjj/FQ9pJY2SpX3G/vlnSIIwQiEgEglMfz4XsJIFo0XDTBIsmyLmfkDYwSKXSawZmZEaqAXY8K289cOEiAZe3HS3i2xfiu2QghIDgk4JYjce9/s+kb72QMurQHccjOScxdGTfzr7FD+xAuNbT9VdX9/UaYWXwn3TlpEJJDoj0YAuQzW1r2qppP7XPqRb2mMYHoQh5mM5zCzgv3XFq9bWpkczo+qlhws/VftM9Gtg1ePjv9+sLh2VICN9+20Nd/IAASqKtG5My7NQO6a5Bw23SPMlBQxPcjcF+NC1wZ40dGZupqSSvHhk1JFy4Ug5tSnqCQm3+7ojRtF5KLEbD+r7So5cObWFwVOZ2BJ3KTECJ7PorEPc0M2iTBAolD0p1Na/L2bkuPln1RHJx6tz1wgHqkvsZl6jCaE9jNdu8UsmZQpeu7uK31bf07rdAbyXk/KYzM8Qg1zZzClVifdeLJi+Y6imnwtGgvvOqAFoqHbzeWdOd+Ih+vPMigcGosWFA4xQHRJ38bQC2zcM10/lD+49fHZxYu6ZfclwNZPTONGf+qkmXmZyb/oz+2R4fyyVSor+9PhKxnX2yRtTrUHAAigC1kJ/JdWTuOlr/SnB8+AGLLelFwLwyOXmvJ3flv9/o5h1cDoR5BDAqUfPV8cwfVJ1R/CIqT9rqrl7+8eq9iu1Gjd2lVNNoEZ4x/OnpUcyp75dChr+kyWZ2AU/pXlgICmT955s6b1XOH5uzkHO/sbrH4vsEtgybSg2XlvJZfh4JUqXd/WE9dWHSqtL3oU4PaMQWYT/RiCSb60wAlMCpvFpLKpEABd93CLTK6Wdbb21jb0jXS5fSart5+3Lv5Bkvcaat+7qmOukCf8PYH/z21WhH+k5PBqTfWO9MopE1j8J43HbSvcklZQsCH1azaDQn7SWNy21Dh+1IZlMa8+aRyu2n8DAAD//2mMQDVCqcaMAAAAAElFTkSuQmCC" ></image>
</svg>

After

Width:  |  Height:  |  Size: 4.3 KiB

View File

@ -0,0 +1,24 @@
# coding=utf-8
import threading
from typing import Dict
from langchain_community.embeddings import XinferenceEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return XinferenceEmbedding(
model_uid=model_name,
server_url=model_credential.get('api_base'),
)
def down_model(self):
self.client.launch_model(model_name=self.model_uid, model_type="embedding")
def start_down_model_thread(self):
thread = threading.Thread(target=self.down_model)
thread.daemon = True
thread.start()

View File

@ -0,0 +1,39 @@
# coding=utf-8
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
def get_base_url(url: str):
parse = urlparse(url)
result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
query='',
fragment='').geturl()
return result_url[:-1] if result_url.endswith("/") else result_url
class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
return XinferenceChatModel(
model=model_name,
openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),
streaming=model_kwargs.get('streaming', False),
**optional_params
)

View File

@ -0,0 +1,528 @@
# coding=utf-8
import os
from urllib.parse import urlparse, ParseResult
import requests
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfoManage
from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
XinferenceEmbeddingModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
from smartdoc.conf import PROJECT_DIR
xinference_llm_model_credential = XinferenceLLMModelCredential()
model_info_list = [
ModelInfo(
'aquila2',
'Aquila2 是一个具有 340 亿参数的大规模语言模型,支持中英文双语。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'aquila2-chat',
'Aquila2 Chat 是一个聊天模型版本的 Aquila2支持中英文双语。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'aquila2-chat-16k',
'Aquila2 Chat 16K 是一个聊天模型版本的 Aquila2支持长达 16K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'baichuan',
'Baichuan 是一个大规模语言模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'baichuan-2',
'Baichuan 2 是 Baichuan 的更新版本,具有更高的性能。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'baichuan-2-chat',
'Baichuan 2 Chat 是一个聊天模型版本的 Baichuan 2。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'baichuan-chat',
'Baichuan Chat 是一个聊天模型版本的 Baichuan。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'c4ai-command-r-v01',
'C4AI Command R V01 是一个用于执行命令的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm',
'ChatGLM 是一个聊天模型,特别擅长中文对话。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm2',
'ChatGLM2 是 ChatGLM 的更新版本,具有更好的性能。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm2-32k',
'ChatGLM2 32K 是一个聊天模型版本的 ChatGLM2支持长达 32K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm3',
'ChatGLM3 是 ChatGLM 的第三个版本,具有更高的性能。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm3-128k',
'ChatGLM3 128K 是一个聊天模型版本的 ChatGLM3支持长达 128K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'chatglm3-32k',
'ChatGLM3 32K 是一个聊天模型版本的 ChatGLM3支持长达 32K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'code-llama',
'Code Llama 是一个专门用于代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'code-llama-instruct',
'Code Llama Instruct 是 Code Llama 的指令微调版本,专为执行特定任务而设计。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'code-llama-python',
'Code Llama Python 是一个专门用于 Python 代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codegeex4',
'CodeGeeX4 是一个用于代码生成的语言模型,具有较高的性能。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codeqwen1.5',
'CodeQwen 1.5 是一个用于代码生成的语言模型,具有较高的性能。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codeqwen1.5-chat',
'CodeQwen 1.5 Chat 是一个聊天模型版本的 CodeQwen 1.5。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codeshell',
'CodeShell 是一个用于代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codeshell-chat',
'CodeShell Chat 是一个聊天模型版本的 CodeShell。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'codestral-v0.1',
'CodeStral V0.1 是一个用于代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'cogvlm2',
'CogVLM2 是一个视觉语言模型,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'csg-wukong-chat-v0.1',
'CSG Wukong Chat V0.1 是一个聊天模型版本的 CSG Wukong。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'deepseek',
'Deepseek 是一个大规模语言模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'deepseek-chat',
'Deepseek Chat 是一个聊天模型版本的 Deepseek。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'deepseek-coder',
'Deepseek Coder 是一个专为代码生成设计的模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'deepseek-coder-instruct',
'Deepseek Coder Instruct 是 Deepseek Coder 的指令微调版本,专为执行特定任务而设计。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'deepseek-vl-chat',
'Deepseek VL Chat 是 Deepseek 的视觉语言聊天模型版本,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'falcon',
'Falcon 是一个开源的 Transformer 解码器模型,具有 400 亿参数,旨在生成高质量的文本。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'falcon-instruct',
'Falcon Instruct 是 Falcon 语言模型的指令微调版本,专为执行特定任务而设计。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gemma-2-it',
'GEMMA-2-IT 是一个基于 GEMMA-2 的意大利语模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gemma-it',
'GEMMA-IT 是一个基于 GEMMA 的意大利语模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gpt-3.5-turbo',
'GPT-3.5 Turbo 是一个高效能的通用语言模型,适用于多种应用场景。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gpt-4',
'GPT-4 是一个强大的多模态模型,不仅支持文本输入,还支持图像输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gpt-4-vision-preview',
'GPT-4 Vision Preview 是 GPT-4 的视觉预览版本,支持图像输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'gpt4all',
'GPT4All 是一个开源的多模态模型,支持文本和图像输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'llama2',
'Llama2 是一个具有 700 亿参数的大规模语言模型,支持多种语言。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'llama2-chat',
'Llama2 Chat 是一个聊天模型版本的 Llama2支持多种语言。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'llama2-chat-32k',
'Llama2 Chat 32K 是一个聊天模型版本的 Llama2支持长达 32K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'moss',
'MOSS 是一个大规模语言模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'moss-chat',
'MOSS Chat 是一个聊天模型版本的 MOSS。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen',
'Qwen 是一个大规模语言模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-chat',
'Qwen Chat 是一个聊天模型版本的 Qwen。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-chat-32k',
'Qwen Chat 32K 是一个聊天模型版本的 Qwen支持长达 32K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-code',
'Qwen Code 是一个专门用于代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-code-chat',
'Qwen Code Chat 是一个聊天模型版本的 Qwen Code。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-vl',
'Qwen VL 是 Qwen 的视觉语言模型版本,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'qwen-vl-chat',
'Qwen VL Chat 是 Qwen VL 的聊天模型版本,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2',
'Spark2 是一个大规模语言模型,具有 130 亿参数。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-chat',
'Spark2 Chat 是一个聊天模型版本的 Spark2。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-chat-32k',
'Spark2 Chat 32K 是一个聊天模型版本的 Spark2支持长达 32K 令牌的上下文。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-code',
'Spark2 Code 是一个专门用于代码生成的语言模型。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-code-chat',
'Spark2 Code Chat 是一个聊天模型版本的 Spark2 Code。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-vl',
'Spark2 VL 是 Spark2 的视觉语言模型版本,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
ModelInfo(
'spark2-vl-chat',
'Spark2 VL Chat 是 Spark2 VL 的聊天模型版本,能够处理图像和文本输入。',
ModelTypeConst.LLM,
xinference_llm_model_credential,
XinferenceChatModel
),
]
xinference_embedding_model_credential = XinferenceEmbeddingModelCredential()
# 生成embedding_model_info列表
embedding_model_info = [
ModelInfo('bce-embedding-base_v1', 'BCE 嵌入模型的基础版本。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-base-en', 'BGE 英语基础版本的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-base-en-v1.5', 'BGE 英语基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-base-zh', 'BGE 中文基础版本的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-base-zh-v1.5', 'BGE 中文基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-large-en', 'BGE 英语大型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-large-en-v1.5', 'BGE 英语大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-large-zh', 'BGE 中文大型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-large-zh-noinstruct', 'BGE 中文大型版本的嵌入模型,无指令调整。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-large-zh-v1.5', 'BGE 中文大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-m3', 'BGE M3 版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('bge-small-en-v1.5', 'BGE 英语小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-small-zh', 'BGE 中文小型版本的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('bge-small-zh-v1.5', 'BGE 中文小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('e5-large-v2', 'E5 大型版本 2 的嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('gte-base', 'GTE 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('gte-large', 'GTE 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('jina-embeddings-v2-base-en', 'Jina 嵌入模型的英语基础版本 2。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('jina-embeddings-v2-base-zh', 'Jina 嵌入模型的中文基础版本 2。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('jina-embeddings-v2-small-en', 'Jina 嵌入模型的英语小型版本 2。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('m3e-base', 'M3E 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('m3e-large', 'M3E 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('m3e-small', 'M3E 小型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential,
XinferenceEmbedding),
ModelInfo('multilingual-e5-large', '多语言大型版本的 E5 嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('text2vec-base-chinese', 'Text2Vec 的中文基础版本嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('text2vec-base-chinese-paraphrase', 'Text2Vec 的中文基础版本的同义句嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('text2vec-base-chinese-sentence', 'Text2Vec 的中文基础版本的句子嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('text2vec-base-multilingual', 'Text2Vec 的多语言基础版本嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
]
model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, xinference_llm_model_credential, XinferenceChatModel))
.append_model_info_list(
embedding_model_info).append_default_model_info(
ModelInfo(
'',
'',
ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding))
.build())
def get_base_url(url: str):
parse = urlparse(url)
result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
query='',
fragment='').geturl()
return result_url[:-1] if result_url.endswith("/") else result_url
class XinferenceModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_xinference_provider', name='Xinference', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xinference_model_provider', 'icon',
'xinference_icon_svg')))
@staticmethod
def get_base_model_list(api_base, model_type):
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
r = requests.request(method="GET", url=f"{base_url}/models", timeout=5)
r.raise_for_status()
model_list = r.json().get('data')
return [model for model in model_list if model.get('model_type') == model_type]
@staticmethod
def get_model_info_by_name(model_list, model_name):
if model_list is None:
return []
return [model for model in model_list if model.get('model_name') == model_name]

View File

@ -9,26 +9,100 @@
from typing import List, Dict
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
import json
import logging
import time
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.messages import (
AIMessageChunk,
BaseMessage
)
from langchain_core.outputs import ChatGenerationChunk
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
zhipuai_chat = ZhipuChatModel(
temperature=0.5,
api_key=model_credential.get('api_key'),
model=model_name
model=model_name,
max_tokens=model_kwargs.get('max_tokens', 5)
)
return zhipuai_chat
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.get_last_generation_info().get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the chat response in chunks."""
if self.zhipuai_api_key is None:
raise ValueError("Did not find zhipuai_api_key.")
if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
}
default_chunk_class = AIMessageChunk
import httpx
with httpx.Client(headers=headers, timeout=60) as client:
with connect_sse(
client, "POST", self.zhipuai_api_base, json=payload
) as event_source:
for sse in event_source.iter_sse():
chunk = json.loads(sse.data)
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
generation_info = {}
if "usage" in chunk:
generation_info = chunk["usage"]
self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason", None)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
if finish_reason is not None:
break

View File

@ -49,6 +49,7 @@ gevent = "^24.2.1"
boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
tencentcloud-sdk-python = "^3.0.1205"
xinference-client = "^0.14.0.post1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"