fix: resolve the error raised by non-streaming responses
parent 90a7a9d085
commit 88f6e336e7
@@ -1,9 +1,11 @@
 # coding=utf-8
 
-from typing import List, Dict, Optional, Any, Iterator, Type
+from typing import List, Dict, Optional, Any, Iterator, Type, cast
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
 
@@ -76,3 +78,28 @@ class BaseChatOpenAI(ChatOpenAI):
             )
             is_first_chunk = False
             yield generation_chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
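Note (not part of the diff): the new invoke override routes non-streaming calls through generate_prompt and stashes the provider's reported token usage on the instance before returning the message. Below is a minimal usage sketch, assuming an OpenAI-compatible endpoint; the model name, base URL, and key are placeholders, not values from this commit.

from langchain_core.messages import HumanMessage

# Placeholders only; any OpenAI-compatible endpoint behaves the same way.
model = BaseChatOpenAI(
    model="gpt-4o-mini",
    openai_api_base="https://api.example.com/v1",
    openai_api_key="sk-placeholder",
)

# Non-streaming call: goes through the overridden invoke().
message = model.invoke([HumanMessage(content="Hello")])
print(message.content)

# Token usage captured by the override (keys depend on what the provider reports).
print(model.__dict__.get('_last_generation_info', {}))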
@@ -6,13 +6,15 @@
     @date:2024/4/28 11:44
     @desc:
 """
-from typing import List, Dict, Optional, Iterator, Any
+from typing import List, Dict, Optional, Iterator, Any, cast
 
 from langchain_community.chat_models import ChatTongyi
 from langchain_community.llms.tongyi import generate_with_last_element_mark
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -83,3 +85,28 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
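Note (not part of the diff): QwenChatModel gets the same override. Writing through self.__dict__.setdefault(...) rather than plain attribute assignment presumably sidesteps pydantic's rejection of undeclared attributes on the model object. A minimal sketch of reading the captured usage afterwards, assuming QwenChatModel accepts ChatTongyi keyword arguments; the model name and key are placeholders.

# Placeholders only.
model = QwenChatModel(model_name="qwen-turbo", dashscope_api_key="sk-placeholder")

# Non-streaming call through the overridden invoke().
message = model.invoke("Summarize this sentence in five words.")

# Usage stored by the override; contents depend on what the DashScope API reports.
usage = model.__dict__.get('_last_generation_info', {})
print(message.content, usage)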