maxkb/apps/models_provider/impl/zhipu_model_provider/model/llm.py
# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/4/28 11:42
@desc:
"""
import json
from collections.abc import Iterator
from typing import Any, Dict, List, Optional
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import (
    _truncate_params,
    _get_jwt_token,
    connect_sse,
    _convert_delta_to_message_chunk,
)
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGenerationChunk
from models_provider.base_model_provider import MaxKBBaseModel


class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
    # Extra model parameters kept on the instance so they can be merged into
    # the streaming request payload.
    optional_params: dict
    # Token usage reported by the most recent streamed generation.
    usage_metadata: dict = {}

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return ZhipuChatModel(
            api_key=model_credential.get('api_key'),
            model=model_name,
            streaming=model_kwargs.get('streaming', False),
            optional_params=optional_params,
            **optional_params,
        )

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Returns the prompt token count cached from the last generation
        # rather than tokenizing the messages directly.
        return self.usage_metadata.get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        # Returns the completion token count cached from the last generation.
        return self.usage_metadata.get('completion_tokens', 0)
    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response in chunks."""
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipuai_api_base.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {
            **params,
            **kwargs,
            **self.optional_params,
            "messages": message_dicts,
            "stream": True,
        }
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }
        default_chunk_class = AIMessageChunk
        # Imported lazily so httpx is only required when streaming is used.
        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            with connect_sse(
                    client, "POST", self.zhipuai_api_base, json=payload
            ) as event_source:
                for sse in event_source.iter_sse():
                    chunk = json.loads(sse.data)
                    if len(chunk["choices"]) == 0:
                        continue
                    choice = chunk["choices"][0]
                    generation_info = {}
                    if "usage" in chunk:
                        # Cache token usage so the get_num_tokens* helpers can
                        # report it after the stream completes.
                        generation_info = chunk["usage"]
                        self.usage_metadata = generation_info
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason", None)
                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=generation_info
                    )
                    yield generation_chunk
                    if run_manager:
                        run_manager.on_llm_new_token(
                            generation_chunk.text, chunk=generation_chunk
                        )
                    if finish_reason is not None:
                        break
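

# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal example of exercising ZhipuChatModel, assuming a valid ZhipuAI API
# key is available. The model name "glm-4", the placeholder key, and the
# temperature kwarg are assumptions; temperature is assumed to pass through
# MaxKBBaseModel.filter_optional_params into the request payload.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    chat = ZhipuChatModel.new_instance(
        model_type="LLM",
        model_name="glm-4",
        model_credential={"api_key": "your-api-key"},  # placeholder credential
        streaming=True,
        temperature=0.7,
    )
    # stream() drives _stream() above; each chunk carries a text delta.
    for piece in chat.stream([HumanMessage(content="Hello")]):
        print(piece.content, end="", flush=True)
    # Token usage cached during streaming is exposed via the helpers above.
    print()
    print(chat.get_last_generation_info())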