feat: support multimodal image (base64) input in chat API
Normalize OpenAI-style and LangChain standard image blocks into LangChain standard content blocks so provider block_translators auto-convert for either OpenAI or Anthropic. Flatten multimodal content to plain text when persisting history and computing term embeddings. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
79b2c35d49
commit
13bdd9d40a
@ -90,6 +90,8 @@ async def prepare_checkpoint_message(config, checkpointer):
|
|||||||
last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None)
|
last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None)
|
||||||
if last_user_msg:
|
if last_user_msg:
|
||||||
config.messages = [last_user_msg]
|
config.messages = [last_user_msg]
|
||||||
logger.info(f"Has history, sending last user message: {last_user_msg.get('content', '')[:50]}...")
|
from utils.fastapi_utils import extract_text_from_content
|
||||||
|
preview = extract_text_from_content(last_user_msg.get('content', ''))
|
||||||
|
logger.info(f"Has history, sending last user message: {preview[:50]}...")
|
||||||
else:
|
else:
|
||||||
logger.info(f"No history, sending all {len(config.messages)} messages")
|
logger.info(f"No history, sending all {len(config.messages)} messages")
|
||||||
|
|||||||
@ -18,7 +18,8 @@ from utils.fastapi_utils import (
|
|||||||
process_messages,
|
process_messages,
|
||||||
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
|
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
|
||||||
call_preamble_llm,
|
call_preamble_llm,
|
||||||
create_stream_chunk
|
create_stream_chunk,
|
||||||
|
extract_text_from_content
|
||||||
)
|
)
|
||||||
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
|
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
|
||||||
from utils.settings import MAX_OUTPUT_TOKENS
|
from utils.settings import MAX_OUTPUT_TOKENS
|
||||||
@ -355,9 +356,9 @@ async def create_agent_and_generate_response(
|
|||||||
"finish_reason": "stop"
|
"finish_reason": "stop"
|
||||||
}],
|
}],
|
||||||
usage={
|
usage={
|
||||||
"prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
|
"prompt_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages),
|
||||||
"completion_tokens": len(response_text),
|
"completion_tokens": len(response_text),
|
||||||
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
|
"total_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages) + len(response_text)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -391,6 +392,9 @@ async def _save_user_messages(config: AgentConfig) -> None:
|
|||||||
if isinstance(msg, dict):
|
if isinstance(msg, dict):
|
||||||
role = msg.get("role", "")
|
role = msg.get("role", "")
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
|
# Flatten multimodal list content to plain text before persisting,
|
||||||
|
# so base64 image data is not stored in chat history.
|
||||||
|
content = extract_text_from_content(content)
|
||||||
if role == "user" and content:
|
if role == "user" and content:
|
||||||
# ============ Execute PreSave hooks ============
|
# ============ Execute PreSave hooks ============
|
||||||
processed_content = await execute_hooks('PreSave', config, content=content, role=role)
|
processed_content = await execute_hooks('PreSave', config, content=content, role=role)
|
||||||
|
|||||||
@ -3,12 +3,16 @@
|
|||||||
API data models and response schemas.
|
API data models and response schemas.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
from typing import Dict, List, Optional, Any, AsyncGenerator, Union
|
||||||
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
# content can be a plain string, or a list of content blocks for multimodal
|
||||||
|
# input (e.g. text + image). Both OpenAI-style ({"type": "image_url", ...})
|
||||||
|
# and LangChain standard blocks ({"type": "image", ...}) are accepted; they
|
||||||
|
# are normalized later in process_messages.
|
||||||
|
content: Union[str, List[Dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
class DatasetRequest(BaseModel):
|
class DatasetRequest(BaseModel):
|
||||||
|
|||||||
@ -232,6 +232,55 @@ def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, fin
|
|||||||
# return full_text
|
# return full_text
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_content_blocks(content: Union[str, List[Dict[str, Any]]]) -> Union[str, List[Dict[str, Any]]]:
|
||||||
|
"""Normalize multimodal content blocks into LangChain standard content blocks.
|
||||||
|
|
||||||
|
Accepts both OpenAI-style blocks ({"type": "image_url", "image_url": {"url": ...}})
|
||||||
|
and LangChain standard blocks ({"type": "image", "base64"/"url": ...}), and emits
|
||||||
|
LangChain standard blocks so the provider's block_translator can auto-convert for
|
||||||
|
either OpenAI or Anthropic. Plain string content is returned unchanged.
|
||||||
|
"""
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return content
|
||||||
|
|
||||||
|
normalized: List[Dict[str, Any]] = []
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
# Treat a bare string inside the list as a text block.
|
||||||
|
if isinstance(block, str):
|
||||||
|
normalized.append({"type": "text", "text": block})
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_type = block.get("type")
|
||||||
|
|
||||||
|
if block_type == "text":
|
||||||
|
normalized.append({"type": "text", "text": block.get("text", "")})
|
||||||
|
elif block_type == "image_url":
|
||||||
|
# OpenAI-style image block: {"type": "image_url", "image_url": {"url": ...}}
|
||||||
|
image_url = block.get("image_url")
|
||||||
|
url = image_url.get("url") if isinstance(image_url, dict) else image_url
|
||||||
|
if not url:
|
||||||
|
continue
|
||||||
|
if isinstance(url, str) and url.startswith("data:"):
|
||||||
|
# data:<mime_type>;base64,<data>
|
||||||
|
try:
|
||||||
|
header, data = url.split(",", 1)
|
||||||
|
mime_type = header.split(";", 1)[0].removeprefix("data:") or "image/jpeg"
|
||||||
|
normalized.append({"type": "image", "base64": data, "mime_type": mime_type})
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Skipping malformed data URL in image_url block")
|
||||||
|
else:
|
||||||
|
normalized.append({"type": "image", "url": url})
|
||||||
|
elif block_type == "image":
|
||||||
|
# Already a LangChain standard image block; pass through.
|
||||||
|
normalized.append(block)
|
||||||
|
else:
|
||||||
|
# Unknown block type; pass through untouched.
|
||||||
|
normalized.append(block)
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
|
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
|
||||||
"""Process message list, including [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] splitting and language directive addition.
|
"""Process message list, including [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] splitting and language directive addition.
|
||||||
|
|
||||||
@ -255,7 +304,7 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
|||||||
|
|
||||||
# Process each message
|
# Process each message
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
if msg.role == ASSISTANT:
|
if msg.role == ASSISTANT and isinstance(msg.content, str):
|
||||||
# Determine the position of this ASSISTANT message among all ASSISTANT messages (0-indexed)
|
# Determine the position of this ASSISTANT message among all ASSISTANT messages (0-indexed)
|
||||||
assistant_position = assistant_indices.index(i)
|
assistant_position = assistant_indices.index(i)
|
||||||
|
|
||||||
@ -315,14 +364,16 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
|||||||
# If processed content is empty, use original content
|
# If processed content is empty, use original content
|
||||||
processed_messages.append({"role": msg.role, "content": msg.content})
|
processed_messages.append({"role": msg.role, "content": msg.content})
|
||||||
else:
|
else:
|
||||||
processed_messages.append({"role": msg.role, "content": msg.content})
|
# User/other messages (or assistant messages carrying multimodal list
|
||||||
|
# content) pass through; normalize multimodal blocks to LangChain standard.
|
||||||
|
processed_messages.append({"role": msg.role, "content": normalize_content_blocks(msg.content)})
|
||||||
|
|
||||||
# Inverse operation: reassemble messages containing [THINK|TOOL_RESPONSE] back into
|
# Inverse operation: reassemble messages containing [THINK|TOOL_RESPONSE] back into
|
||||||
# msg['role'] == 'function' and msg.get('function_call') format.
|
# msg['role'] == 'function' and msg.get('function_call') format.
|
||||||
# This is the inverse of get_content_from_messages.
|
# This is the inverse of get_content_from_messages.
|
||||||
final_messages = []
|
final_messages = []
|
||||||
for msg in processed_messages:
|
for msg in processed_messages:
|
||||||
if msg["role"] == ASSISTANT:
|
if msg["role"] == ASSISTANT and isinstance(msg["content"], str):
|
||||||
# Split message content
|
# Split message content
|
||||||
parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
|
parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
|
||||||
|
|
||||||
@ -401,13 +452,32 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
|
|||||||
return final_messages
|
return final_messages
|
||||||
|
|
||||||
|
|
||||||
def get_user_last_message_content(messages: list) -> Optional[dict]:
|
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
"""Get the last message content from a message list."""
|
"""Extract plain text from message content that may be a multimodal block list."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
texts = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
texts.append(block.get("text", ""))
|
||||||
|
elif isinstance(block, str):
|
||||||
|
texts.append(block)
|
||||||
|
return "\n".join(texts)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_last_message_content(messages: list) -> Optional[str]:
|
||||||
|
"""Get the last user message's plain text content from a message list.
|
||||||
|
|
||||||
|
Multimodal list content is flattened to text so downstream consumers
|
||||||
|
(e.g. terms embedding) always receive a string.
|
||||||
|
"""
|
||||||
if not messages or len(messages) == 0:
|
if not messages or len(messages) == 0:
|
||||||
return ""
|
return ""
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message and last_message.get('role') == 'user':
|
if last_message and last_message.get('role') == 'user':
|
||||||
return last_message["content"]
|
return extract_text_from_content(last_message.get("content", ""))
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
|
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user