diff --git a/agent/checkpoint_utils.py b/agent/checkpoint_utils.py index 704b040..b23b268 100644 --- a/agent/checkpoint_utils.py +++ b/agent/checkpoint_utils.py @@ -90,6 +90,8 @@ async def prepare_checkpoint_message(config, checkpointer): last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None) if last_user_msg: config.messages = [last_user_msg] - logger.info(f"Has history, sending last user message: {last_user_msg.get('content', '')[:50]}...") + from utils.fastapi_utils import extract_text_from_content + preview = extract_text_from_content(last_user_msg.get('content', '')) + logger.info(f"Has history, sending last user message: {preview[:50]}...") else: logger.info(f"No history, sending all {len(config.messages)} messages") diff --git a/routes/chat.py b/routes/chat.py index 0f03759..f06981b 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -18,7 +18,8 @@ from utils.fastapi_utils import ( process_messages, create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, call_preamble_llm, - create_stream_chunk + create_stream_chunk, + extract_text_from_content ) from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage from utils.settings import MAX_OUTPUT_TOKENS @@ -355,9 +356,9 @@ async def create_agent_and_generate_response( "finish_reason": "stop" }], usage={ - "prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages), + "prompt_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages), "completion_tokens": len(response_text), - "total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text) + "total_tokens": sum(len(extract_text_from_content(msg.get("content", ""))) for msg in config.messages) + len(response_text) } ) @@ -391,6 +392,9 @@ async def _save_user_messages(config: AgentConfig) -> None: if isinstance(msg, dict): role = msg.get("role", "") content = 
def normalize_content_blocks(content: Union[str, List[Dict[str, Any]]]) -> Union[str, List[Dict[str, Any]]]:
    """Normalize multimodal content blocks into LangChain standard content blocks.

    Accepts both OpenAI-style blocks
    (``{"type": "image_url", "image_url": {"url": ...}}``) and LangChain
    standard blocks (``{"type": "image", "base64"/"url": ...}``) and emits
    LangChain standard blocks.

    Args:
        content: Either a plain string or a list of content blocks. List items
            may be dicts (typed blocks) or bare strings.

    Returns:
        ``content`` unchanged when it is not a list; otherwise a new list where
        - bare strings become ``{"type": "text", "text": <string>}``,
        - ``text`` blocks are rebuilt with only ``type``/``text`` (a missing or
          ``None`` text becomes ``""``),
        - ``image_url`` blocks become ``{"type": "image", "url": ...}`` or, for
          ``data:`` URLs, ``{"type": "image", "base64": ..., "mime_type": ...}``,
        - every other dict block is passed through untouched,
        - items that are neither dict nor str, and ``image_url`` blocks with an
          empty or malformed URL, are dropped.
    """
    # Function-scope imports keep this helper self-contained.
    import base64
    from urllib.parse import unquote_to_bytes

    if not isinstance(content, list):
        return content

    normalized: List[Dict[str, Any]] = []
    for block in content:
        if isinstance(block, str):
            # Treat a bare string inside the list as a text block.
            normalized.append({"type": "text", "text": block})
            continue
        if not isinstance(block, dict):
            continue

        block_type = block.get("type")

        if block_type == "text":
            # Never forward ``None`` as the text of a text block.
            normalized.append({"type": "text", "text": block.get("text") or ""})
            continue

        if block_type != "image_url":
            # LangChain standard image blocks and unknown block types pass
            # through untouched.
            normalized.append(block)
            continue

        # OpenAI-style image block: {"type": "image_url", "image_url": {"url": ...}}
        # (the URL may also be given directly as the "image_url" value).
        image_url = block.get("image_url")
        url = image_url.get("url") if isinstance(image_url, dict) else image_url
        if not url:
            continue
        if not (isinstance(url, str) and url.startswith("data:")):
            normalized.append({"type": "image", "url": url})
            continue

        # Data URL layout: data:<mime type>[;params][;base64],<payload>
        header, comma, payload = url.partition(",")
        if not comma or not payload:
            logger.warning("Skipping malformed data URL in image_url block")
            continue

        mime_type = header.split(";", 1)[0].removeprefix("data:") or "image/jpeg"
        if not header.lower().endswith(";base64"):
            # Without the ";base64" flag the payload is percent-encoded raw
            # data (RFC 2397); re-encode it so the "base64" field really
            # holds base64.
            payload = base64.b64encode(unquote_to_bytes(payload)).decode("ascii")
        normalized.append({"type": "image", "base64": payload, "mime_type": mime_type})

    return normalized
@@ -255,7 +304,7 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li # Process each message for i, msg in enumerate(messages): - if msg.role == ASSISTANT: + if msg.role == ASSISTANT and isinstance(msg.content, str): # Determine the position of this ASSISTANT message among all ASSISTANT messages (0-indexed) assistant_position = assistant_indices.index(i) @@ -315,14 +364,16 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li # If processed content is empty, use original content processed_messages.append({"role": msg.role, "content": msg.content}) else: - processed_messages.append({"role": msg.role, "content": msg.content}) + # User/other messages (or assistant messages carrying multimodal list + # content) pass through; normalize multimodal blocks to LangChain standard. + processed_messages.append({"role": msg.role, "content": normalize_content_blocks(msg.content)}) # Inverse operation: reassemble messages containing [THINK|TOOL_RESPONSE] back into # msg['role'] == 'function' and msg.get('function_call') format. # This is the inverse of get_content_from_messages. 
def extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Extract plain text from message content that may be a multimodal block list.

    Args:
        content: A plain string, or a list whose items are content-block dicts
            and/or bare strings.

    Returns:
        ``content`` itself when it is a string. For a list, the text of every
        ``{"type": "text"}`` block and every bare string item, joined with
        newlines; all other blocks (e.g. images) are ignored. Any other input
        yields ``""``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    texts: List[str] = []
    for block in content:
        if isinstance(block, str):
            texts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            text = block.get("text", "")
            # Skip non-string text (e.g. None); str.join would raise TypeError.
            if isinstance(text, str):
                texts.append(text)
    return "\n".join(texts)


def get_user_last_message_content(messages: list) -> Optional[str]:
    """Return the plain-text content of the final message when it is a user message.

    Only the last element of ``messages`` is inspected. Multimodal list content
    is flattened to text via ``extract_text_from_content`` so callers always
    receive a string.

    Args:
        messages: List of message dicts with ``role`` and ``content`` keys.

    Returns:
        The flattened text of the last message if its role is ``'user'``;
        otherwise (or when ``messages`` is empty) ``""``.
    """
    if not messages:
        return ""
    last_message = messages[-1]
    if last_message and last_message.get('role') == 'user':
        return extract_text_from_content(last_message.get("content", ""))
    return ""