fix: token counting errors in private deployments (#284)
parent 9d808b4ccd
commit 29427a0ad6
apps/common/config/tokenizer_manage_config.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: tokenizer_manage_config.py
+@date:2024/4/28 10:17
+@desc:
+"""
+
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
+                'gpt2',
+                cache_dir="/opt/maxkb/model/tokenizer",
+                local_files_only=True,
+                resume_download=False,
+                force_download=False)
+        return TokenizerManage.tokenizer
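The decisive change relative to the inline copies this file replaces is local_files_only=True: the tokenizer is resolved purely from the bundled cache, so transformers never tries to reach huggingface.co. A minimal usage sketch, not part of the commit, assuming the gpt2 tokenizer files were pre-downloaded into /opt/maxkb/model/tokenizer (for example during the image build, on a machine with network access):

    # Assumed pre-seeding step at image-build time (network available):
    #   from transformers import GPT2TokenizerFast
    #   GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer")

    # At runtime, local_files_only=True resolves entirely from that cache:
    from common.config.tokenizer_manage_config import TokenizerManage

    tokenizer = TokenizerManage.get_tokenizer()          # first call loads from disk
    assert tokenizer is TokenizerManage.get_tokenizer()  # later calls reuse the cached instance
    print(len(tokenizer.encode("你好, MaxKB")))           # token count with no network I/O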
@@ -19,6 +19,7 @@ from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
     ModelInfo, \
     ModelTypeConst, ValidCode
+from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -119,8 +120,8 @@ class AzureModelProvider(IModelProvider):
 
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
         model_info: ModelInfo = model_dict.get(model_name)
-        azure_chat_open_ai = AzureChatOpenAI(
-            openai_api_base=model_credential.get('api_base'),
+        azure_chat_open_ai = AzureChatModel(
+            azure_endpoint=model_credential.get('api_base'),
             openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
                 'api_version'),
             deployment_name=model_credential.get('deployment_name'),
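Two things change in get_model: the constructed class becomes the token-counting subclass AzureChatModel (added below), and the endpoint keyword moves from the deprecated openai_api_base to azure_endpoint, which newer langchain-openai releases require for Azure endpoints. A hedged construction sketch, all values are placeholders:

    from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel

    llm = AzureChatModel(
        azure_endpoint="https://my-resource.openai.azure.com",  # formerly openai_api_base=
        openai_api_version="2024-02-01",                        # placeholder API version
        deployment_name="my-gpt4-deployment",                   # placeholder deployment
        openai_api_key="***",
    )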
apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: azure_chat_model.py
+@date:2024/4/28 11:45
+@desc:
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai import AzureChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class AzureChatModel(AzureChatOpenAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
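Each wrapper overrides LangChain's two token-counting hooks. For OpenAI-family chat models the default get_num_tokens path goes through tiktoken, which downloads its BPE vocabulary on first use, a likely trigger of the token-counting error in air-gapped deployments. Counting with the bundled GPT2 tokenizer is only approximate for non-OpenAI vocabularies, but it is deterministic and fully offline. A sketch of what the per-message sum computes (get_buffer_string([m]) renders one message as, for example, "Human: hi", so role prefixes are included in the count):

    from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

    from common.config.tokenizer_manage_config import TokenizerManage

    msgs = [HumanMessage(content="hi"), AIMessage(content="hello there")]
    tok = TokenizerManage.get_tokenizer()
    total = sum(len(tok.encode(get_buffer_string([m]))) for m in msgs)
    # matches what AzureChatModel.get_num_tokens_from_messages returns for msgs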
@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class KimiChatModel(ChatOpenAI):
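The block removed above was one of four identical inline copies of TokenizerManage (kimi, ollama, openai, qianfan). Besides duplicating the class, those copies did not pass local_files_only=True, so from_pretrained could still probe huggingface.co even with a warm cache, depending on the transformers version, and fail on hosts without outbound network. A hedged illustration of the old failure mode:

    from transformers import GPT2TokenizerFast

    try:
        # the removed call: no local_files_only=True
        GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer")
    except OSError as err:
        # on an air-gapped host this can surface as a connection error or
        # "Can't load tokenizer for 'gpt2'..."
        print(err)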
@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class OllamaChatModel(ChatOpenAI):
@@ -11,19 +11,7 @@ from typing import List
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_openai import ChatOpenAI
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class OpenAIChatModel(ChatOpenAI):
apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: qwen_chat_model.py
+@date:2024/4/28 11:44
+@desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatTongyi
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class QwenChatModel(ChatTongyi):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -66,7 +67,7 @@ class QwenModelProvider(IModelProvider):
         return 3
 
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
-        chat_tong_yi = ChatTongyi(
+        chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key')
         )
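The signature still advertises ChatTongyi, which remains valid because QwenChatModel subclasses it; only the token-counting behavior changes. A usage sketch, model name and key are placeholders:

    from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel

    model = QwenChatModel(model_name="qwen-turbo", dashscope_api_key="sk-***")
    n = model.get_num_tokens("你好")  # answered by the local GPT2 tokenizer, no API round-trip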
@@ -18,19 +18,7 @@ from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
 from langchain_community.chat_models import QianfanChatEndpoint
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class QianfanChatModel(QianfanChatEndpoint):
@@ -12,11 +12,21 @@ from typing import List, Optional, Any, Iterator
 from langchain_community.chat_models import ChatSparkLLM
 from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
 from langchain_core.outputs import ChatGenerationChunk
 
+from common.config.tokenizer_manage_config import TokenizerManage
+
 
 class XFChatSparkLLM(ChatSparkLLM):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
+
     def _stream(
         self,
         messages: List[BaseMessage],
apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: zhipu_chat_model.py
+@date:2024/4/28 11:42
+@desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class ZhipuChatModel(ChatZhipuAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -66,7 +67,7 @@ class ZhiPuModelProvider(IModelProvider):
         return 3
 
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
-        zhipuai_chat = ChatZhipuAI(
+        zhipuai_chat = ZhipuChatModel(
             temperature=0.5,
             api_key=model_credential.get('api_key'),
             model=model_name