refactor: openai

wxg0103 2025-04-25 17:53:22 +08:00
parent 57c6c9916e
commit 7f492b4d92
3 changed files with 142 additions and 107 deletions

View File

@@ -140,20 +140,21 @@ class PermissionConstants(Enum):
     TOOL_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
                                                                                    RoleConstants.USER])
     TOOL_DEBUG = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN,
                                                                               RoleConstants.USER])
     TOOL_IMPORT = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN,
                                                                                RoleConstants.USER])
     TOOL_EXPORT = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN,
                                                                                RoleConstants.USER])
     KNOWLEDGE_MODULE_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
                                                                                                    RoleConstants.USER])
     KNOWLEDGE_MODULE_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
                                                                                                RoleConstants.USER])
     KNOWLEDGE_MODULE_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
                                                                                                RoleConstants.USER])
     KNOWLEDGE_MODULE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
                                                                                                    RoleConstants.USER])

     def get_workspace_application_permission(self):
         return lambda r, kwargs: Permission(group=self.value.group, operate=self.value.operate,
                                             resource_path=

View File

@@ -100,7 +100,10 @@ class MaxKBBaseModel(ABC):
         optional_params = {}
         for key, value in model_kwargs.items():
             if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
-                optional_params[key] = value
+                if key == 'extra_body' and isinstance(value, dict):
+                    optional_params = {**optional_params, **value}
+                else:
+                    optional_params[key] = value
         return optional_params
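
For context, a minimal sketch (not part of this commit) of what the new extra_body branch does: keys nested under extra_body are merged into the top-level params instead of being passed through as a nested dict. The input values below are made up, and the surrounding loop is copied from the hunk above.

    # Illustrative only; hypothetical model_kwargs mirroring the loop in the hunk above.
    model_kwargs = {
        'model_id': '42',                            # filtered out (internal key)
        'temperature': 0.7,                          # kept as-is
        'extra_body': {'enable_thinking': False},    # flattened into the result
    }

    optional_params = {}
    for key, value in model_kwargs.items():
        if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
            if key == 'extra_body' and isinstance(value, dict):
                optional_params = {**optional_params, **value}
            else:
                optional_params[key] = value

    print(optional_params)  # -> {'temperature': 0.7, 'enable_thinking': False}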

View File

@@ -1,15 +1,16 @@
 # coding=utf-8
-import warnings
-from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
-import openai
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 from langchain_core.runnables import RunnableConfig, ensure_config
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.tools import BaseTool
 from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata
 from common.config.tokenizer_manage_config import TokenizerManage
@@ -19,6 +20,65 @@ def custom_get_token_ids(text: str):
     return tokenizer.encode(text)


+def _convert_delta_to_message_chunk(
+        _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    id_ = _dict.get("id")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {}
+    if 'reasoning_content' in _dict:
+        additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)  # type: ignore
+
+
 class BaseChatOpenAI(ChatOpenAI):
     usage_metadata: dict = {}
     custom_get_token_ids = custom_get_token_ids
@@ -26,7 +86,13 @@ class BaseChatOpenAI(ChatOpenAI):
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
         return self.usage_metadata

-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+            self,
+            messages: list[BaseMessage],
+            tools: Optional[
+                Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
+            ] = None,
+    ) -> int:
         if self.usage_metadata is None or self.usage_metadata == {}:
             try:
                 return super().get_num_tokens_from_messages(messages)
@@ -44,114 +110,77 @@ class BaseChatOpenAI(ChatOpenAI):
                 return len(tokenizer.encode(text))
         return self.get_last_generation_info().get('output_tokens', 0)

-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        """Set default stream_options."""
-        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info = {}
-        if "response_format" in payload and is_basemodel_subclass(
-                payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-                generation_chunk = super()._convert_chunk_to_generation_chunk(
-                    chunk,
-                    default_chunk_class,
-                    base_generation_info if is_first_chunk else {},
-                )
-                if generation_chunk is None:
-                    continue
-                # custom code
-                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
-                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
-                        'reasoning_content']
-                default_chunk_class = generation_chunk.message.__class__
-                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                # custom code
-                if generation_chunk.message.usage_metadata is not None:
-                    self.usage_metadata = generation_chunk.message.usage_metadata
-                yield generation_chunk
-
-    def _create_chat_result(self,
-                            response: Union[dict, openai.BaseModel],
-                            generation_info: Optional[Dict] = None):
-        result = super()._create_chat_result(response, generation_info)
-        try:
-            reasoning_content = ''
-            reasoning_content_enable = False
-            for res in response.choices:
-                if 'reasoning_content' in res.message.model_extra:
-                    reasoning_content_enable = True
-                    _reasoning_content = res.message.model_extra.get('reasoning_content')
-                    if _reasoning_content is not None:
-                        reasoning_content += _reasoning_content
-            if reasoning_content_enable:
-                result.llm_output['reasoning_content'] = reasoning_content
-        except Exception as e:
-            pass
-        return result
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
+            self,
+            chunk: dict,
+            default_chunk_class: type,
+            base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+                chunk.get("choices", [])
+                # from beta.chat.completions.stream
+                or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+            if model_name := chunk.get("model"):
+                generation_info["model_name"] = model_name
+            if system_fingerprint := chunk.get("system_fingerprint"):
+                generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk

     def invoke(
             self,
             input: LanguageModelInput,
             config: Optional[RunnableConfig] = None,
             *,
-            stop: Optional[List[str]] = None,
+            stop: Optional[list[str]] = None,
             **kwargs: Any,
     ) -> BaseMessage:
         config = ensure_config(config)
         chat_result = cast(
-            ChatGeneration,
+            "ChatGeneration",
             self.generate_prompt(
                 [self._convert_input(input)],
                 stop=stop,
@@ -162,7 +191,9 @@ class BaseChatOpenAI(ChatOpenAI):
                 run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
         ).message
         self.usage_metadata = chat_result.response_metadata[
             'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
         return chat_result
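
For context, a brief usage sketch (not part of this commit) of how the simplified _stream override behaves: it forces stream_usage on, delegates chunk handling to the parent class, and records the streamed usage_metadata on the instance. The model name and API key below are placeholders.

    # Hypothetical usage; model and api_key are placeholders.
    llm = BaseChatOpenAI(model='gpt-4o-mini', api_key='sk-...', streaming=True)

    for chunk in llm.stream('Say hi'):  # drives the overridden _stream()
        print(chunk.content, end='', flush=True)

    # The token usage captured from the final streamed chunk is then
    # available without issuing a second request:
    print(llm.get_last_generation_info())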