# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: llm.py
    @date:2023/11/10 17:45
    @desc: Baidu Qianfan chat models. api_version "v1" uses the native Qianfan
           endpoint (QianfanChatEndpoint); "v2" uses the OpenAI-compatible
           endpoint (BaseChatOpenAI).
"""
from typing import List, Dict, Optional, Any, Iterator

from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGenerationChunk

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
    """Qianfan chat model using the native Qianfan endpoint (api_version "v1").

    Token usage reported by the service while streaming is kept in
    ``usage_metadata`` and returned by the token-counting methods, instead of
    tokenizing text locally.
    """

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials.

        :param model_type: not used here; part of the shared ``new_instance`` signature.
        :param model_name: passed through as ``model``.
        :param model_credential: reads ``api_key`` (-> ``qianfan_ak``) and
            ``secret_key`` (-> ``qianfan_sk``).
        :param model_kwargs: ``streaming`` (default False) is passed through;
            the result of ``MaxKBBaseModel.filter_optional_params(model_kwargs)``
            becomes ``init_kwargs``.
        :return: a new ``QianfanChatModelQianfan``.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelQianfan(model=model_name,
                                       qianfan_ak=model_credential.get('api_key'),
                                       qianfan_sk=model_credential.get('secret_key'),
                                       streaming=model_kwargs.get('streaming', False),
                                       init_kwargs=optional_params)

    # Usage dict captured from the most recent streamed response; the token
    # getters below read 'prompt_tokens' / 'completion_tokens' from it.
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict recorded by the last streamed call."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the service-reported prompt token count of the last call.

        ``messages`` is not inspected; 0 when no usage has been recorded.
        """
        return (self.usage_metadata or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the service-reported completion token count of the last call.

        ``text`` is not inspected; 0 when no usage has been recorded.
        """
        return (self.usage_metadata or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat chunks from the Qianfan client.

        Usage from the closing chunk (empty content or ``is_end``) is stored in
        ``usage_metadata``; each non-empty response is yielded as a
        ``ChatGenerationChunk`` and reported to ``run_manager`` if given.
        """
        # Per-call kwargs override the defaults captured at construction time.
        kwargs = {**self.init_kwargs, **kwargs}
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if not res:
                continue
            msg = _convert_dict_to_message(res)
            additional_kwargs = msg.additional_kwargs.get("function_call", {})
            # Guard against a missing body so the usage lookup below cannot crash.
            body = res.get("body") or {}
            if msg.content == "" or body.get("is_end"):
                # Store {} rather than None when the chunk carries no usage, so the
                # token getters never call .get on None.
                self.usage_metadata = body.get("usage") or {}
            chunk = ChatGenerationChunk(
                text=res["result"],
                message=AIMessageChunk(  # type: ignore[call-arg]
                    content=msg.content,
                    role="assistant",
                    additional_kwargs=additional_kwargs,
                ),
                generation_info=msg.additional_kwargs,
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk


class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
    """Qianfan chat model using the OpenAI-compatible endpoint (api_version "v2")."""

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials.

        :param model_type: not used here; part of the shared ``new_instance`` signature.
        :param model_name: passed through as ``model``.
        :param model_credential: reads ``api_base`` and ``api_key``.
        :param model_kwargs: the result of
            ``MaxKBBaseModel.filter_optional_params(model_kwargs)`` is sent as ``extra_body``.
        :return: a new ``QianfanChatModelOpenai``.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelOpenai(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            extra_body=optional_params
        )


class QianfanChatModel(MaxKBBaseModel):
    """Factory choosing the Qianfan implementation from the credential's ``api_version``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Dispatch on ``model_credential['api_version']`` (default "v1").

        :return: ``QianfanChatModelQianfan`` for "v1", ``QianfanChatModelOpenai`` for "v2".
        :raises ValueError: if ``api_version`` is any other value.
        """
        api_version = model_credential.get('api_version', 'v1')
        if api_version == "v1":
            return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
        elif api_version == "v2":
            return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
        # Fail fast: silently returning None would only surface later as an
        # obscure error on first use of the model.
        raise ValueError(f"Unsupported Qianfan api_version: {api_version!r}")