fix: Change the Tokenizer loading order
commit fb7dfba567 (parent 7933b15a38)
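Both changed modules previously instantiated GPT2TokenizerFast at module level, so merely importing them blocked on a tokenizer download. This commit defers that work to the first call of TokenizerManage.get_tokenizer() and caches the instance on the class. A simplified sketch of the pattern being applied (the full version appears in both hunks below):

    # Lazy-singleton tokenizer loader, as introduced by this commit.
    class TokenizerManage:
        tokenizer = None

        @staticmethod
        def get_tokenizer():
            # Deferred import: transformers is only touched on first call.
            from transformers import GPT2TokenizerFast
            if TokenizerManage.tokenizer is None:
                TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
                    'gpt2',
                    cache_dir="/opt/maxkb/model/tokenizer",
                    resume_download=False,
                    force_download=False)
            return TokenizerManage.tokenizer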
@@ -10,15 +10,27 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
-from transformers import GPT2TokenizerFast
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class OllamaChatModel(ChatOpenAI):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
 
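A quick usage sketch (illustrative only, not part of the commit): the first call performs the load, and every later call reuses the cached instance, so token counting stays cheap.

    # Assumes the TokenizerManage class from the hunk above.
    tok = TokenizerManage.get_tokenizer()          # first call: loads GPT-2 from cache_dir
    assert tok is TokenizerManage.get_tokenizer()  # later calls: same cached instance
    print(len(tok.encode("hello world")))          # token count for a sample string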
@@ -9,7 +9,6 @@
 from typing import Optional, List, Any, Iterator, cast
 
 from langchain.callbacks.manager import CallbackManager
-from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.chat_models.base import BaseChatModel
 from langchain.load import dumpd
 from langchain.schema import LLMResult
@@ -17,18 +16,31 @@ from langchain.schema.language_model import LanguageModelInput
 from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
-from transformers import GPT2TokenizerFast
+from langchain_community.chat_models import QianfanChatEndpoint
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class QianfanChatModel(QianfanChatEndpoint):
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
 
     def stream(