Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions docs/multimodal.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ chat_template: qwen2_vl # same as qwen2-vl
base_model: Qwen/Qwen3-VL-4B-Instruct

chat_template: qwen2_vl # same as qwen2-vl

# For datasets with mixed content types (string/array)
datasets:
- path: your_dataset.jsonl
type: chat_template
mixed_content_messages: true # Enable mixed content handling
```

### SmolVLM2 {#sec-smolvlm2}
Expand Down Expand Up @@ -285,6 +291,47 @@ Here is an example of a multi-modal dataset:
]
```

### Mixed Content Types

Some datasets, particularly for Qwen-VL models, may have mixed content types where system messages have string content while user messages have array content for multimodal data. This can cause JSON parsing errors.

To handle such datasets, enable the `mixed_content_messages` flag:

```yaml
datasets:
- path: your_dataset.jsonl
type: chat_template
mixed_content_messages: true # Enable mixed content handling
```

This feature automatically normalizes content types during loading, ensuring compatibility with Axolotl's training pipeline.

#### Example Mixed Content Dataset

```json
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant." // String content
},
{
"role": "user",
"content": [ // Array content for multimodal
{"type": "image", "image": "path/to/image.jpg"},
{"type": "text", "text": "What's in this image?"}
]
},
{
"role": "assistant",
"content": "I can see a beautiful landscape with mountains." // String content
}
]
}
```

The mixed content handler will normalize all content to a consistent format internally while preserving the original semantics.

## FAQ

1. `PIL.UnidentifiedImageError: cannot identify image file ...`
Expand Down
218 changes: 145 additions & 73 deletions src/axolotl/processing_strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""

import json
from copy import deepcopy
from typing import Optional

Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(
self.chat_template = chat_template
self.image_token = None
self.image_token_id = None
self.supports_multi_images = False # Override in subclasses that support multiple images

self.image_size = image_size
self.image_resize_algorithm = (
Expand Down Expand Up @@ -83,26 +85,41 @@ def convert_legacy_format(example: dict) -> dict:
result["messages"] = messages
return result

def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
def convert_messages_to_multimedia_messages(messages: list[dict], deserialize_json_content: bool = False) -> list[dict]:
"""Convert regular messages format to Messages format with content type"""

new_messages = []
for message in messages:
if isinstance(message["content"], str):
content = message["content"]

# Only try to deserialize JSON-encoded content when deserialize_json_content is True
# This is because we normalized mixed content to JSON strings during loading
if deserialize_json_content and isinstance(content, str):
# Improved JSON detection: check for valid JSON structure
content_stripped = content.strip()
if content_stripped and (
(content_stripped.startswith('[') and content_stripped.endswith(']')) or
(content_stripped.startswith('{') and content_stripped.endswith('}'))
):
try:
content = json.loads(content)
except json.JSONDecodeError:
# Not valid JSON, treat as regular string
pass

if isinstance(content, str):
new_messages.append(
{
"role": message["role"],
"content": [
{
"type": "text",
"text": message["content"],
"text": content,
}
],
}
)
elif isinstance(message["content"], list):
content = message["content"]

elif isinstance(content, list):
new_messages.append(
{
"role": message["role"],
Expand All @@ -129,8 +146,9 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:

# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
# Check if this model supports multi-images and needs special handling
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
processed_example["messages"], deserialize_json_content=self.supports_multi_images
)

# find the image key if it exists
Expand All @@ -141,80 +159,133 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
image_key = key
break

# if the image key exists, add the image to the first user message
# if the image key exists, add the images to the message
if image_key is not None and processed_example[image_key] is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
# Check if we should handle multiple images
# Debug logging
if len(processed_example[image_key]) > 1:
LOG.warning(
f"Found {len(processed_example[image_key])} images in a sample. Using the first one."
"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages."
"See https://docs.axolotl.ai/docs/multimodal.html#dataset-format"
)

image_value = processed_example[image_key][0]

# Handle image loading (Image, url, path, base64)
image_value = load_image(image_value)
LOG.debug(f"Multiple images detected. Strategy type: {type(self).__name__}, supports_multi_images={self.supports_multi_images}")

if self.supports_multi_images and len(processed_example[image_key]) > 1:
# Qwen2-VL: Load all images
loaded_images = []
for img in processed_example[image_key]:
loaded_img = load_image(img)
loaded_images.append(loaded_img)

# Log multi-image usage for debugging
LOG.debug(f"Processing {len(loaded_images)} images in sample for Qwen2-VL")
else:
# Original behavior: take first image and warn if multiple
if len(processed_example[image_key]) > 1:
LOG.warning(
f"Found {len(processed_example[image_key])} images in a sample. Using the first one. "
"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages. "
"See https://docs.axolotl.ai/docs/multimodal.html#dataset-format"
)

image_value = processed_example[image_key][0]
# Handle image loading (Image, url, path, base64)
image_value = load_image(image_value)
loaded_images = [image_value]

# Resize all loaded images if needed
if self.image_size is not None:
assert hasattr(image_value, "resize"), (
"Image does not have a resize method"
)

if isinstance(self.image_size, tuple):
image_value = image_value.resize(
self.image_size, self.image_resize_algorithm
)
else:
# Set the padding value; here we use black (0, 0, 0) for RGB images
padding_color = (0, 0, 0)

# When image_size is an int (square target), preserve aspect ratio then pad
# This is to prevent aspect ratio distortion when resizing to square
image_value = ImageOps.pad(
image_value,
(self.image_size, self.image_size),
method=self.image_resize_algorithm,
color=padding_color,
resized_images = []
for image_value in loaded_images:
assert hasattr(image_value, "resize"), (
"Image does not have a resize method"
)

# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
msg_ind_to_add = None
ind_to_add = None
first_user_idx = None

for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
if isinstance(self.image_size, tuple):
resized_img = image_value.resize(
self.image_size, self.image_resize_algorithm
)
else:
# Set the padding value; here we use black (0, 0, 0) for RGB images
padding_color = (0, 0, 0)

# When image_size is an int (square target), preserve aspect ratio then pad
# This is to prevent aspect ratio distortion when resizing to square
resized_img = ImageOps.pad(
image_value,
(self.image_size, self.image_size),
method=self.image_resize_algorithm,
color=padding_color,
)
resized_images.append(resized_img)
loaded_images = resized_images

# Look for image placeholders in messages
if self.supports_multi_images and len(loaded_images) > 1:
# Qwen2-VL: Map multiple images to their placeholders
image_placeholders = []
first_user_idx = None

for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break

# If an image type is found, add the image to that index
if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][msg_ind_to_add]["content"][
ind_to_add
]["image"] = image_value
# Find image placeholders
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
image_placeholders.append((msg_idx, i))

# Map loaded images to placeholders
if image_placeholders:
# If we have placeholders, map images to them in order
for idx, (msg_idx, content_idx) in enumerate(image_placeholders):
if idx < len(loaded_images):
processed_example["messages"][msg_idx]["content"][content_idx]["image"] = loaded_images[idx]
else:
# If no placeholders found, add all images to end of first user message
if first_user_idx is None:
first_user_idx = 0
for image_value in loaded_images:
processed_example["messages"][first_user_idx]["content"].append(
{
"type": "image",
"image": image_value,
}
)
else:
# if no image type is found, add it to end of the first user message
if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
{
"type": "image",
"image": image_value,
}
)
# Original single image behavior
msg_ind_to_add = None
ind_to_add = None
first_user_idx = None

for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break

# If an image type is found, add the image to that index
if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][msg_ind_to_add]["content"][
ind_to_add
]["image"] = loaded_images[0]
else:
# if no image type is found, add it to end of the first user message
if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
{
"type": "image",
"image": loaded_images[0],
}
)

processed_examples.append(remove_none_values(processed_example))

Expand Down Expand Up @@ -252,6 +323,7 @@ def __init__(
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.supports_multi_images = True # Qwen2-VL supports multiple images
self.image_token = "<|image_pad|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
Expand Down
Loading