feat: support multimodal image (base64) input in chat API
Normalize OpenAI-style and LangChain standard image blocks into LangChain standard content blocks so provider block_translators auto-convert for either OpenAI or Anthropic. Flatten multimodal content to plain text when persisting history and computing term embeddings. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
79b2c35d49
commit
13bdd9d40a
@ -90,6 +90,8 @@ async def prepare_checkpoint_message(config, checkpointer):
|
||||
last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None)
|
||||
if last_user_msg:
|
||||
config.messages = [last_user_msg]
|
||||
logger.info(f"Has history, sending last user message: {last_user_msg.get('content', '')[:50]}...")
|
||||
from utils.fastapi_utils import extract_text_from_content
|
||||
preview = extract_text_from_content(last_user_msg.get('content', ''))
|
||||
logger.info(f"Has history, sending last user message: {preview[:50]}...")
|
||||
else:
|
||||
logger.info(f"No history, sending all {len(config.messages)} messages")
|
||||
|
||||
@ -18,7 +18,8 @@ from utils.fastapi_utils import (
|
||||
process_messages,
|
||||
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
|
||||
call_preamble_llm,
|
||||
create_stream_chunk
|
||||
create_stream_chunk,
|
||||
extract_text_from_content
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
|
||||
from utils.settings import MAX_OUTPUT_TOKENS
|
||||
@ -355,9 +356,9 @@ async def create_agent_and_generate_response(
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
usage={
|
||||
"prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
|
||||
"prompt_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages),
|
||||
"completion_tokens": len(response_text),
|
||||
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
|
||||
"total_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages) + len(response_text)
|
||||
}
|
||||
)
|
||||
|
||||
@ -391,6 +392,9 @@ async def _save_user_messages(config: AgentConfig) -> None:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
# Flatten multimodal list content to plain text before persisting,
|
||||
# so base64 image data is not stored in chat history.
|
||||
content = extract_text_from_content(content)
|
||||
if role == "user" and content:
|
||||
# ============ Execute PreSave hooks ============
|
||||
processed_content = await execute_hooks('PreSave', config, content=content, role=role)
|
||||
|
||||
@ -3,12 +3,16 @@
|
||||
API data models and response schemas.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator, Union
|
||||
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
# content can be a plain string, or a list of content blocks for multimodal
|
||||
# input (e.g. text + image). Both OpenAI-style ({"type": "image_url", ...})
|
||||
# and LangChain standard blocks ({"type": "image", ...}) are accepted; they
|
||||
# are normalized later in process_messages.
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class DatasetRequest(BaseModel):
|
||||
|
||||
@ -232,6 +232,55 @@ def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, fin
|
||||
# return full_text
|
||||
|
||||
|
||||
def normalize_content_blocks(content: Union[str, List[Dict[str, Any]]]) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""Normalize multimodal content blocks into LangChain standard content blocks.
|
||||
|
||||
Accepts both OpenAI-style blocks ({"type": "image_url", "image_url": {"url": ...}})
|
||||
and LangChain standard blocks ({"type": "image", "base64"/"url": ...}), and emits
|
||||
LangChain standard blocks so the provider's block_translator can auto-convert for
|
||||
either OpenAI or Anthropic. Plain string content is returned unchanged.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
# Treat a bare string inside the list as a text block.
|
||||
if isinstance(block, str):
|
||||
normalized.append({"type": "text", "text": block})
|
||||
continue
|
||||
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == "text":
|
||||
normalized.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image_url":
|
||||
# OpenAI-style image block: {"type": "image_url", "image_url": {"url": ...}}
|
||||
image_url = block.get("image_url")
|
||||
url = image_url.get("url") if isinstance(image_url, dict) else image_url
|
||||
if not url:
|
||||
continue
|
||||
if isinstance(url, str) and url.startswith("data:"):
|
||||
# data:<mime_type>;base64,<data>
|
||||
try:
|
||||
header, data = url.split(",", 1)
|
||||
mime_type = header.split(";", 1)[0].removeprefix("data:") or "image/jpeg"
|
||||
normalized.append({"type": "image", "base64": data, "mime_type": mime_type})
|
||||
except ValueError:
|
||||
logger.warning("Skipping malformed data URL in image_url block")
|
||||
else:
|
||||
normalized.append({"type": "image", "url": url})
|
||||
elif block_type == "image":
|
||||
# Already a LangChain standard image block; pass through.
|
||||
normalized.append(block)
|
||||
else:
|
||||
# Unknown block type; pass through untouched.
|
||||
normalized.append(block)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
|
||||
"""Process message list, including [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] splitting and language directive addition.
|
||||
|
||||
@ -255,7 +304,7 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
||||
|
||||
# Process each message
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role == ASSISTANT:
|
||||
if msg.role == ASSISTANT and isinstance(msg.content, str):
|
||||
# Determine the position of this ASSISTANT message among all ASSISTANT messages (0-indexed)
|
||||
assistant_position = assistant_indices.index(i)
|
||||
|
||||
@ -315,14 +364,16 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
||||
# If processed content is empty, use original content
|
||||
processed_messages.append({"role": msg.role, "content": msg.content})
|
||||
else:
|
||||
processed_messages.append({"role": msg.role, "content": msg.content})
|
||||
# User/other messages (or assistant messages carrying multimodal list
|
||||
# content) pass through; normalize multimodal blocks to LangChain standard.
|
||||
processed_messages.append({"role": msg.role, "content": normalize_content_blocks(msg.content)})
|
||||
|
||||
# Inverse operation: reassemble messages containing [THINK|TOOL_RESPONSE] back into
|
||||
# msg['role'] == 'function' and msg.get('function_call') format.
|
||||
# This is the inverse of get_content_from_messages.
|
||||
final_messages = []
|
||||
for msg in processed_messages:
|
||||
if msg["role"] == ASSISTANT:
|
||||
if msg["role"] == ASSISTANT and isinstance(msg["content"], str):
|
||||
# Split message content
|
||||
parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
|
||||
|
||||
@ -401,13 +452,32 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
||||
return final_messages
|
||||
|
||||
|
||||
def get_user_last_message_content(messages: list) -> Optional[dict]:
|
||||
"""Get the last message content from a message list."""
|
||||
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract plain text from message content that may be a multimodal block list."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
texts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
texts.append(block)
|
||||
return "\n".join(texts)
|
||||
return ""
|
||||
|
||||
|
||||
def get_user_last_message_content(messages: list) -> Optional[str]:
|
||||
"""Get the last user message's plain text content from a message list.
|
||||
|
||||
Multimodal list content is flattened to text so downstream consumers
|
||||
(e.g. terms embedding) always receive a string.
|
||||
"""
|
||||
if not messages or len(messages) == 0:
|
||||
return ""
|
||||
last_message = messages[-1]
|
||||
if last_message and last_message.get('role') == 'user':
|
||||
return last_message["content"]
|
||||
return extract_text_from_content(last_message.get("content", ""))
|
||||
return ""
|
||||
|
||||
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user