refactor: 优化代码
This commit is contained in:
parent
780a44f368
commit
4c28ff12f5
@ -11,73 +11,26 @@ from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
|
|||||||
|
|
||||||
|
|
||||||
class BaseChatOpenAI(ChatOpenAI):
|
class BaseChatOpenAI(ChatOpenAI):
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('prompt_tokens', 0)
|
return self.usage_metadata.get('input_tokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('completion_tokens', 0)
|
return self.get_last_generation_info().get('output_tokens', 0)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
for chunk in super()._stream(*args, stream_usage=stream_usage, **kwargs):
|
||||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
if chunk.message.usage_metadata is not None:
|
||||||
if self.include_response_headers:
|
self.usage_metadata = chunk.message.usage_metadata
|
||||||
raw_response = self.client.with_raw_response.create(**payload)
|
yield chunk
|
||||||
response = raw_response.parse()
|
|
||||||
base_generation_info = {"headers": dict(raw_response.headers)}
|
|
||||||
else:
|
|
||||||
response = self.client.create(**payload)
|
|
||||||
base_generation_info = {}
|
|
||||||
with response:
|
|
||||||
is_first_chunk = True
|
|
||||||
for chunk in response:
|
|
||||||
if not isinstance(chunk, dict):
|
|
||||||
chunk = chunk.model_dump()
|
|
||||||
if (len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or
|
|
||||||
chunk["choices"][0]["finish_reason"] == "stop") and chunk.get("usage") is not None:
|
|
||||||
if token_usage := chunk.get("usage"):
|
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
|
|
||||||
logprobs = None
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
if choice["delta"] is None:
|
|
||||||
continue
|
|
||||||
message_chunk = _convert_delta_to_message_chunk(
|
|
||||||
choice["delta"], default_chunk_class
|
|
||||||
)
|
|
||||||
generation_info = {**base_generation_info} if is_first_chunk else {}
|
|
||||||
if finish_reason := choice.get("finish_reason"):
|
|
||||||
generation_info["finish_reason"] = finish_reason
|
|
||||||
if model_name := chunk.get("model"):
|
|
||||||
generation_info["model_name"] = model_name
|
|
||||||
if system_fingerprint := chunk.get("system_fingerprint"):
|
|
||||||
generation_info["system_fingerprint"] = system_fingerprint
|
|
||||||
|
|
||||||
logprobs = choice.get("logprobs")
|
|
||||||
if logprobs:
|
|
||||||
generation_info["logprobs"] = logprobs
|
|
||||||
default_chunk_class = message_chunk.__class__
|
|
||||||
generation_chunk = ChatGenerationChunk(
|
|
||||||
message=message_chunk, generation_info=generation_info or None
|
|
||||||
)
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(
|
|
||||||
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
|
|
||||||
)
|
|
||||||
is_first_chunk = False
|
|
||||||
yield generation_chunk
|
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
@ -101,5 +54,5 @@ class BaseChatOpenAI(ChatOpenAI):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
).generations[0][0],
|
).generations[0][0],
|
||||||
).message
|
).message
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
|
self.usage_metadata = chat_result.response_metadata['token_usage']
|
||||||
return chat_result
|
return chat_result
|
||||||
|
|||||||
@ -39,14 +39,16 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
|
|||||||
)
|
)
|
||||||
return chat_tong_yi
|
return chat_tong_yi
|
||||||
|
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('input_tokens', 0)
|
return self.usage_metadata.get('input_tokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('output_tokens', 0)
|
return self.usage_metadata.get('output_tokens', 0)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -69,7 +71,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
|
|||||||
and message["content"] == ""
|
and message["content"] == ""
|
||||||
) or (choice["finish_reason"] == "length"):
|
) or (choice["finish_reason"] == "length"):
|
||||||
token_usage = stream_resp["usage"]
|
token_usage = stream_resp["usage"]
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
|
self.usage_metadata = token_usage
|
||||||
if (
|
if (
|
||||||
choice["finish_reason"] == "null"
|
choice["finish_reason"] == "null"
|
||||||
and message["content"] == ""
|
and message["content"] == ""
|
||||||
@ -108,5 +110,5 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
).generations[0][0],
|
).generations[0][0],
|
||||||
).message
|
).message
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
|
self.usage_metadata = chat_result.response_metadata['token_usage']
|
||||||
return chat_result
|
return chat_result
|
||||||
|
|||||||
@ -54,7 +54,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
|
|
||||||
|
|
||||||
def _convert_delta_to_message_chunk(
|
def _convert_delta_to_message_chunk(
|
||||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
role = _dict.get("Role")
|
role = _dict.get("Role")
|
||||||
content = _dict.get("Content") or ""
|
content = _dict.get("Content") or ""
|
||||||
@ -198,11 +198,11 @@ class ChatHunyuan(BaseChatModel):
|
|||||||
return {**normal_params, **self.model_kwargs}
|
return {**normal_params, **self.model_kwargs}
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
stream_iter = self._stream(
|
stream_iter = self._stream(
|
||||||
@ -213,12 +213,14 @@ class ChatHunyuan(BaseChatModel):
|
|||||||
res = self._chat(messages, **kwargs)
|
res = self._chat(messages, **kwargs)
|
||||||
return _create_chat_result(json.loads(res.to_json_string()))
|
return _create_chat_result(json.loads(res.to_json_string()))
|
||||||
|
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
res = self._chat(messages, **kwargs)
|
res = self._chat(messages, **kwargs)
|
||||||
|
|
||||||
@ -238,9 +240,7 @@ class ChatHunyuan(BaseChatModel):
|
|||||||
default_chunk_class = chunk.__class__
|
default_chunk_class = chunk.__class__
|
||||||
# FinishReason === stop
|
# FinishReason === stop
|
||||||
if choice.get("FinishReason") == "stop":
|
if choice.get("FinishReason") == "stop":
|
||||||
self.__dict__.setdefault("_last_generation_info", {}).update(
|
self.usage_metadata = response.get("Usage", {})
|
||||||
response.get("Usage", {})
|
|
||||||
)
|
|
||||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
||||||
@ -275,4 +275,4 @@ class ChatHunyuan(BaseChatModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "hunyuan-chat"
|
return "hunyuan-chat"
|
||||||
|
|||||||
@ -38,10 +38,10 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
|
|||||||
return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
|
return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('PromptTokens', 0)
|
return self.usage_metadata.get('PromptTokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('CompletionTokens', 0)
|
return self.usage_metadata.get('CompletionTokens', 0)
|
||||||
|
|||||||
@ -37,14 +37,16 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||||||
streaming=model_kwargs.get('streaming', False),
|
streaming=model_kwargs.get('streaming', False),
|
||||||
init_kwargs=optional_params)
|
init_kwargs=optional_params)
|
||||||
|
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('prompt_tokens', 0)
|
return self.usage_metadata.get('prompt_tokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('completion_tokens', 0)
|
return self.usage_metadata.get('completion_tokens', 0)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -63,7 +65,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||||||
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
||||||
if msg.content == "" or res.get("body").get("is_end"):
|
if msg.content == "" or res.get("body").get("is_end"):
|
||||||
token_usage = res.get("body").get("usage")
|
token_usage = res.get("body").get("usage")
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
|
self.usage_metadata = token_usage
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
text=res["result"],
|
text=res["result"],
|
||||||
message=AIMessageChunk( # type: ignore[call-arg]
|
message=AIMessageChunk( # type: ignore[call-arg]
|
||||||
|
|||||||
@ -40,14 +40,16 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
|
|||||||
**optional_params
|
**optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('prompt_tokens', 0)
|
return self.usage_metadata.get('prompt_tokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('completion_tokens', 0)
|
return self.usage_metadata.get('completion_tokens', 0)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -71,7 +73,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
|
|||||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||||
elif "usage" in content:
|
elif "usage" in content:
|
||||||
generation_info = content["usage"]
|
generation_info = content["usage"]
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
|
self.usage_metadata = generation_info
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -47,14 +47,16 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
|
|||||||
)
|
)
|
||||||
return zhipuai_chat
|
return zhipuai_chat
|
||||||
|
|
||||||
|
usage_metadata: dict = {}
|
||||||
|
|
||||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
return self.__dict__.get('_last_generation_info')
|
return self.usage_metadata
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
return self.get_last_generation_info().get('prompt_tokens', 0)
|
return self.usage_metadata.get('prompt_tokens', 0)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
return self.get_last_generation_info().get('completion_tokens', 0)
|
return self.usage_metadata.get('completion_tokens', 0)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -91,7 +93,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
|
|||||||
generation_info = {}
|
generation_info = {}
|
||||||
if "usage" in chunk:
|
if "usage" in chunk:
|
||||||
generation_info = chunk["usage"]
|
generation_info = chunk["usage"]
|
||||||
self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
|
self.usage_metadata = generation_info
|
||||||
chunk = _convert_delta_to_message_chunk(
|
chunk = _convert_delta_to_message_chunk(
|
||||||
choice["delta"], default_chunk_class
|
choice["delta"], default_chunk_class
|
||||||
)
|
)
|
||||||
|
|||||||
@ -52,7 +52,7 @@ const platforms = reactive([
|
|||||||
{
|
{
|
||||||
key: 'wecom',
|
key: 'wecom',
|
||||||
logoSrc: new URL(`../../assets/logo_wechat-work.svg`, import.meta.url).href,
|
logoSrc: new URL(`../../assets/logo_wechat-work.svg`, import.meta.url).href,
|
||||||
name: '企业微信',
|
name: '企业微信应用',
|
||||||
description: '打造企业微信智能应用',
|
description: '打造企业微信智能应用',
|
||||||
isActive: false,
|
isActive: false,
|
||||||
exists: false
|
exists: false
|
||||||
@ -60,7 +60,7 @@ const platforms = reactive([
|
|||||||
{
|
{
|
||||||
key: 'dingtalk',
|
key: 'dingtalk',
|
||||||
logoSrc: new URL(`../../assets/logo_dingtalk.svg`, import.meta.url).href,
|
logoSrc: new URL(`../../assets/logo_dingtalk.svg`, import.meta.url).href,
|
||||||
name: '钉钉',
|
name: '钉钉应用',
|
||||||
description: '打造钉钉智能应用',
|
description: '打造钉钉智能应用',
|
||||||
isActive: false,
|
isActive: false,
|
||||||
exists: false
|
exists: false
|
||||||
@ -76,7 +76,7 @@ const platforms = reactive([
|
|||||||
{
|
{
|
||||||
key: 'feishu',
|
key: 'feishu',
|
||||||
logoSrc: new URL(`../../assets/logo_lark.svg`, import.meta.url).href,
|
logoSrc: new URL(`../../assets/logo_lark.svg`, import.meta.url).href,
|
||||||
name: '飞书',
|
name: '飞书应用',
|
||||||
description: '打造飞书智能应用',
|
description: '打造飞书智能应用',
|
||||||
isActive: false,
|
isActive: false,
|
||||||
exists: false
|
exists: false
|
||||||
|
|||||||
@ -29,8 +29,8 @@
|
|||||||
</el-form-item>
|
</el-form-item>
|
||||||
</template>
|
</template>
|
||||||
<div v-if="configType === 'wechat'" class="flex align-center" style="margin-bottom: 8px">
|
<div v-if="configType === 'wechat'" class="flex align-center" style="margin-bottom: 8px">
|
||||||
<span class="el-form-item__label">是否是订阅号</span>
|
<span class="el-form-item__label">认证通过</span>
|
||||||
<el-switch v-if="configType === 'wechat'" v-model="form[configType].is_personal" />
|
<el-switch v-if="configType === 'wechat'" v-model="form[configType].is_certification" />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<h4 class="title-decoration-1 mb-16">回调地址</h4>
|
<h4 class="title-decoration-1 mb-16">回调地址</h4>
|
||||||
@ -111,7 +111,7 @@ const form = reactive<any>({
|
|||||||
app_secret: '',
|
app_secret: '',
|
||||||
token: '',
|
token: '',
|
||||||
encoding_aes_key: '',
|
encoding_aes_key: '',
|
||||||
is_personal: false,
|
is_certification: false,
|
||||||
callback_url: ''
|
callback_url: ''
|
||||||
},
|
},
|
||||||
dingtalk: { client_id: '', client_secret: '', callback_url: '' },
|
dingtalk: { client_id: '', client_secret: '', callback_url: '' },
|
||||||
@ -184,17 +184,17 @@ const drawerTitle = computed(
|
|||||||
wechat: '公众号配置',
|
wechat: '公众号配置',
|
||||||
dingtalk: '钉钉应用配置',
|
dingtalk: '钉钉应用配置',
|
||||||
wecom: '企业微信应用配置',
|
wecom: '企业微信应用配置',
|
||||||
feishu: '飞书配置'
|
feishu: '飞书应用配置'
|
||||||
})[configType.value]
|
})[configType.value]
|
||||||
)
|
)
|
||||||
|
|
||||||
const infoTitle = computed(
|
const infoTitle = computed(
|
||||||
() =>
|
() =>
|
||||||
({
|
({
|
||||||
wechat: '微信公众号应用信息',
|
wechat: '应用信息',
|
||||||
dingtalk: '钉钉应用信息',
|
dingtalk: '应用信息',
|
||||||
wecom: '企业微信应用信息',
|
wecom: '应用信息',
|
||||||
feishu: '飞书应用信息'
|
feishu: '应用信息'
|
||||||
})[configType.value]
|
})[configType.value]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user