19 commits
8756d8f
feat(mm-cpt): multimodal continued pre-training (raw image+text)
thad0ctor Apr 25, 2026
c8f0380
fix: address coderabbit feedback
thad0ctor Apr 25, 2026
98d7e52
fix: address coderabbit comments/nits
thad0ctor Apr 25, 2026
6cd74f0
fix: valid coderabbit nit and caplog test capture
thad0ctor Apr 25, 2026
40ea3b9
fix: gemma 4 regression, expand test coverage per Codecov
thad0ctor Apr 25, 2026
c5b3a81
fix: tests + ruff/lint
thad0ctor Apr 25, 2026
10ca6a2
fix coderabbit nit - use tokenizer for text-only rows in offender ret…
thad0ctor Apr 25, 2026
0dcf37d
fix: test (invert prefers_boi assertion to match new heuristic)
thad0ctor Apr 25, 2026
950785a
fix(mm-cpt): re-append EOS in collator after processor re-tokenization
thad0ctor May 7, 2026
85e0c6d
Align PretrainingDataset text_column/data_files with MultiModalEvalDa…
thad0ctor May 15, 2026
f413bee
Strip speculative image-loader hardening from MM CPT collator
thad0ctor May 15, 2026
c962b72
chore(lint): ruff-format mm_pretrain.py
thad0ctor May 21, 2026
b6711cf
Merge branch 'main' into multimodal-cpt
thad0ctor May 21, 2026
389e0f2
refactor(mm-cpt): scope PR to streaming path; dedupe eval schema
thad0ctor May 22, 2026
9867132
fix(mm-cpt): raise ValueError when multimodal_pretrain is used under …
thad0ctor May 25, 2026
2d85db5
Merge branch 'main' into multimodal-cpt
thad0ctor May 29, 2026
e19db6b
Merge branch 'main' into multimodal-cpt
thad0ctor May 29, 2026
67c432b
Merge branch 'main' into multimodal-cpt
thad0ctor Jun 4, 2026
39713c1
Merge branch 'main' into multimodal-cpt
thad0ctor Jun 11, 2026
140 changes: 140 additions & 0 deletions docs/multimodal.qmd
@@ -364,6 +364,146 @@ Here is an example of a multi-modal dataset:
]
```

## Continued Pre-training (CPT) with images {#sec-multimodal-cpt}

This path runs continued pretraining on raw image+text pairs, with no chat
template and no conversational scaffolding. The model learns to emit raw text
conditioned on visual patches. It is intended for use cases such as
OCR/transcription corpora, where every row is a tight `(image, target_text)`
pair and any user/assistant framing would pollute the learned signal.

::: {.callout-note}
**Currently supported via the streaming `pretraining_dataset` route only.**
A non-streaming `datasets:` route (`type: multimodal_pretrain` under
`datasets:`) is intentionally not wired — the collator selection inside
`build_collator` only routes MM CPT batches under the pretraining branch.
Configure MM CPT under `pretraining_dataset:` with `streaming: true` as
shown below.
:::

### Dataset format (JSONL)

Two keys per row: `text` (the raw string) and `images` (list of local paths).
The `text` must contain the model's placeholder token **once per image**,
placed immediately before the text it describes, followed by a newline:

```json
{"text": "<image>\nInvoice number: 10427. Total due: 148.32 USD.", "images": ["/dataset/crops/doc_14_p2.png"]}
{"text": "<image>\nChapter 3 begins with a description of the storm over the harbor.", "images": ["/dataset/crops/doc_14_p3.png"]}
```

Notes:

- Never wrap the row in `User:` / `Assistant:` / `Transcribe this:` scaffolding. That is the whole point of the CPT path.
- Do not manually append an EOS token. Axolotl appends one during tokenization.
- The newline between the placeholder and the real text preserves the BPE
boundary. Without it, some tokenizers merge the visual-token boundary with
the first real character.
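
As a minimal sketch of the row format above (not part of Axolotl), the helper below builds one `(image, target_text)` row and serializes rows as JSONL. The names `make_row` and `to_jsonl` are hypothetical, and the default placeholder `<image>` is an assumption that only fits families using that token.

```python
import json


def make_row(image_path: str, target_text: str, placeholder: str = "<image>") -> dict:
    """Build one row: placeholder, newline, then the raw target text."""
    if placeholder in target_text:
        # A second placeholder would break the one-placeholder-per-image rule.
        raise ValueError("target text already contains the placeholder")
    # No EOS is appended here; the docs state Axolotl adds it during tokenization.
    return {"text": f"{placeholder}\n{target_text}", "images": [image_path]}


def to_jsonl(rows: list[dict]) -> str:
    """Serialize rows as JSONL, one JSON object per line."""
    return "\n".join(json.dumps(row, ensure_ascii=False) for row in rows) + "\n"


rows = [
    make_row(
        "/dataset/crops/doc_14_p2.png",
        "Invoice number: 10427. Total due: 148.32 USD.",
    )
]
jsonl = to_jsonl(rows)
```

Writing `jsonl` to a shard file gives input in the shape shown in the example rows above.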

### The placeholder token varies by model

| Model family | Placeholder | Notes |
|---|---|---|
| LLaVA-1.5 / 1.6 | `<image>` | |
| SmolVLM / SmolVLM2 / Idefics3 | `<image>` | Processor expands to 1088 tokens (17 tiles × 64) |
| Qwen2-VL / Qwen2.5-VL / Qwen3-VL | `<\|image_pad\|>` | Processor autowraps with `<\|vision_start\|>` / `<\|vision_end\|>` |
| Gemma-3 | `<start_of_image>` | Processor expands to 256 `<image_soft_token>` |
| Gemma-4 | `<\|image\|>` | Processor expands to 256 `<\|image\|>` |

Axolotl autodetects the placeholder from the loaded processor. If autodetection
fails, supply `image_token: <your token>` on the dataset entry.

### YAML example

```yaml
base_model: HuggingFaceTB/SmolVLM-500M-Instruct
processor_type: AutoProcessor

pretraining_dataset:
  - path: /path/to/shards/*.jsonl
    ds_type: json
    type: multimodal_pretrain
    text_column: text
    image_column: images
    image_base_dir: /path/to/images  # optional, for relative paths
    # image_token: "<image>"  # optional override; autodetect by default

streaming: true
sequence_len: 2048
sample_packing: false # REQUIRED — see below
remove_unused_columns: false # auto-set by validator

max_steps: 10000
micro_batch_size: 1
gradient_accumulation_steps: 8
```

### Eval datasets

`test_datasets` accepts multimodal entries (`type: multimodal_pretrain` or
`multimodal: true`). Per-entry `text_column` and `image_column` are honored
independently. When more than one multimodal entry is provided,
`image_base_dir` and `image_token` must be either unset on every entry or
identical across them, because the eval collator resolves both once for the
merged eval stream.

### Mixed image / text-only batches

Rows with no images are allowed. All-text batches bypass the processor and
tokenize via the base tokenizer (no `pixel_values`). Mixed batches are
handed to the processor as-is.
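
The routing rule above can be sketched as a single predicate. This is a simplified illustration under the assumption that rows are plain dicts with an `images` list; `needs_processor` is a hypothetical name, not the collator's actual method.

```python
def needs_processor(batch: list[dict]) -> bool:
    """Return True when at least one row in the batch carries an image."""
    return any(row.get("images") for row in batch)


text_only = [{"text": "plain text row", "images": []}]
mixed = text_only + [{"text": "<image>\ncaption", "images": ["/path/to/a.png"]}]

# All-text batches go to the base tokenizer; anything else goes to the processor.
routes = ["processor" if needs_processor(b) else "tokenizer" for b in (text_only, mixed)]
```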

Verified VLM families:

- HuggingFace `SmolVLM` / `SmolVLM2`
- Google `Gemma-3` (4B), `Gemma-4` (E2B, 31B)
- Alibaba `Qwen2.5-VL`, `Qwen3-VL`

For other families (LLaVA, Qwen2-VL, Idefics3), the collator's per-row
retry isolates and names the offending row on processor failure — failures
are loud, not silent. Smoke-test before long runs, or pre-filter rows
without images.
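
The per-row retry idea can be sketched as follows. This is a hedged illustration, not the collator's code: `find_offending_rows` and `fake_process` are hypothetical, and `fake_process` merely stands in for a processor call that raises on a bad row.

```python
from typing import Callable


def find_offending_rows(
    batch: list[dict], process: Callable[[list[dict]], object]
) -> list[int]:
    """Run the whole batch; on failure, retry each row alone and report failures."""
    try:
        process(batch)
        return []
    except Exception:
        pass  # Fall through to per-row isolation.
    bad = []
    for index, row in enumerate(batch):
        try:
            process([row])
        except Exception:
            bad.append(index)
    return bad


def fake_process(rows: list[dict]) -> int:
    """Stand-in processor: fails when a row has images but no placeholder."""
    for row in rows:
        if row["images"] and "<image>" not in row["text"]:
            raise ValueError(f"missing placeholder in row: {row['text']!r}")
    return len(rows)


good = {"text": "<image>\nok", "images": ["/path/to/a.png"]}
bad_row = {"text": "no placeholder", "images": ["/path/to/b.png"]}
offenders = find_offending_rows([good, bad_row, good], fake_process)
```

Reporting the index of each failing row is what makes a smoke test on a small shard useful before a long run.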

### Gates and rejections

The following combinations are rejected at config-load time with a clear error:

- `sample_packing: true` — cross-row packing would break the 1-to-1 alignment
between text placeholders and `pixel_values`.
- `chat_template` set to anything — defeats the purpose of the CPT path.
- `processor_type` unset — no processor means no image tensors.
- Multiple MM `test_datasets` entries with mismatched `image_base_dir` or `image_token`.
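
The four rejections listed above can be expressed as a small check over a config dict. This is a sketch under stated assumptions: `check_mm_cpt_config` is a hypothetical helper, the error strings are invented for illustration, and Axolotl's real validator may be structured differently.

```python
def check_mm_cpt_config(cfg: dict) -> list[str]:
    """Return one message per rejected setting; an empty list means no rejection."""
    errors = []
    if cfg.get("sample_packing"):
        errors.append("sample_packing must be false for multimodal CPT")
    if cfg.get("chat_template"):
        errors.append("chat_template must be unset for multimodal CPT")
    if not cfg.get("processor_type"):
        errors.append("processor_type is required for multimodal CPT")
    # Multimodal eval entries must agree on these keys (unset counts as None).
    mm_eval = [
        entry
        for entry in cfg.get("test_datasets") or []
        if entry.get("type") == "multimodal_pretrain" or entry.get("multimodal")
    ]
    for key in ("image_base_dir", "image_token"):
        if len({entry.get(key) for entry in mm_eval}) > 1:
            errors.append(f"{key} differs across multimodal test_datasets entries")
    return errors
```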

In addition, the following model families are **not supported** in v1 and will
be rejected when their processor is loaded:

- **Llama-3.2-Vision (Mllama)** — uses cross-attention image injection, not
in-stream placeholders. Use chat-template SFT.
- **Pixtral** — requires `mistral_common` and a different API.
- **InternVL** — ships a custom processor that doesn't produce `pixel_values`.

Per-row validation: at encode time the row's text is tokenized once, and the
number of `image_token_id` occurrences in the resulting token-id list must
equal `len(images)`. Counting by token id rather than by substring avoids
false matches between textually similar markers: a loose match on `<image`,
for example, would also hit `<image_soft_token>`. This guardrail matters
because LLaVA and Qwen-VL processors silently accept rows without
placeholders and drop the image, which looks like successful training but
teaches nothing. If a row fails this check, inspect the tokenized ids rather
than the raw string.
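
To illustrate why the count is taken over token ids, here is a sketch with a toy tokenizer. Everything in it is an assumption for demonstration: the two-entry vocabulary, the ids `1` and `2`, and the helper names `toy_encode` and `placeholder_count_ok`. Real ids come from the loaded tokenizer.

```python
import re

# Hypothetical toy vocabulary; real ids come from the loaded tokenizer.
SPECIAL_IDS = {"<image>": 1, "<image_soft_token>": 2}
# Longest-first alternation so a longer marker is never split by a shorter one.
_SPECIAL_RE = re.compile(
    "|".join(re.escape(tok) for tok in sorted(SPECIAL_IDS, key=len, reverse=True))
)


def toy_encode(text: str) -> list[int]:
    """Map special markers to their ids and every other character to id 0."""
    ids, pos = [], 0
    for match in _SPECIAL_RE.finditer(text):
        ids.extend([0] * (match.start() - pos))
        ids.append(SPECIAL_IDS[match.group()])
        pos = match.end()
    ids.extend([0] * (len(text) - pos))
    return ids


def placeholder_count_ok(text: str, images: list[str], image_token_id: int = 1) -> bool:
    """Check that the image token id appears exactly once per image."""
    return toy_encode(text).count(image_token_id) == len(images)


text = "<image>\nThe marker <image_soft_token> appears in this page."
loose_hits = text.count("<image")    # loose string match also hits the soft token
id_hits = toy_encode(text).count(1)  # exact token-id match
```

Here the loose string match and the token-id count disagree, which is the situation the per-row check is designed to avoid.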

### Why masking image tokens in labels is automatic

For this multimodal CPT path, the collator masks every image-family token id
(`<image>`, `<\|image_pad\|>`, `<\|vision_start\|>`, `<\|vision_end\|>`,
`<start_of_image>`, `<end_of_image>`, `<image_soft_token>`, `<\|image\|>`,
etc.) to `-100` in the labels tensor.

Supported processors expand image placeholders into vision-specific token ids.
If those ids contribute to loss, the model is trained to predict patch or
marker tokens rather than only the target text. On architectures like
Qwen-VL and SmolVLM, leaving them unmasked substantially increases loss and can
destabilize training.
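
The masking step can be sketched on plain Python lists. This is a simplified illustration, not the collator's implementation: the ids `7` and `8` are made up, `mask_image_labels` is a hypothetical name, and the real collator operates on tensors.

```python
IGNORE_INDEX = -100  # label value that the loss ignores, as described above


def mask_image_labels(input_ids: list[int], image_family_ids: set[int]) -> list[int]:
    """Copy input ids into labels, replacing image-family ids with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok in image_family_ids else tok for tok in input_ids]


# Hypothetical ids: 7 and 8 stand for image-family tokens, the rest are text.
input_ids = [7, 8, 8, 8, 101, 102, 103]
labels = mask_image_labels(input_ids, {7, 8})
```

Only the text positions keep their ids, so the loss is computed on the target text alone.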

## FAQ

1. `PIL.UnidentifiedImageError: cannot identify image file ...`
71 changes: 71 additions & 0 deletions src/axolotl/core/builders/causal.py
@@ -44,6 +44,7 @@
    V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.collators.mm_pretrain import MultiModalPretrainDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger

@@ -67,6 +68,27 @@ def _warn_if_num_workers_zero_for_mm(cfg, log) -> None:
)


def _is_multimodal_cpt(cfg) -> bool:
    if not getattr(cfg, "pretraining_dataset", None):
        return False
    ds_first = cfg.pretraining_dataset[0]
    ds_type = None
    mm_flag = None
    if hasattr(ds_first, "type"):
        ds_type = getattr(ds_first, "type", None)
        mm_flag = getattr(ds_first, "multimodal", None)
    elif isinstance(ds_first, dict):
        ds_type = ds_first.get("type")
        mm_flag = ds_first.get("multimodal")
    return (ds_type == "multimodal_pretrain") or bool(mm_flag)


def _mm_cpt_get(pt_cfg, key, default=None):
    if isinstance(pt_cfg, dict):
        return pt_cfg.get(key, default)
    return getattr(pt_cfg, key, default)


class HFCausalTrainerBuilder(TrainerBuilderBase):
    """
    Build the HuggingFace training args/trainer for causal models and reward modeling
@@ -464,13 +486,53 @@ def build(self, total_num_steps):

        return trainer

    def _build_mm_pretrain_collator(self, pad_to_multiple_of=None, is_eval=False):
        from axolotl.prompt_strategies.multimodal_pretrain import (
            build_image_token_spec,
        )

        if is_eval and self.cfg.test_datasets:
            pt_cfg = self.cfg.test_datasets[0]
        elif self.cfg.pretraining_dataset:
            pt_cfg = self.cfg.pretraining_dataset[0]
        else:
            pt_cfg = {}
        spec = build_image_token_spec(
            self.processor, override=_mm_cpt_get(pt_cfg, "image_token")
        )
        max_length = (
            self.cfg.eval_sequence_len
            if is_eval and self.cfg.eval_sequence_len
            else self.cfg.sequence_len
        )
        collator_kwargs = {
            "tokenizer": self.tokenizer,
            "processor": self.processor,
            "image_token_spec": spec,
            "image_base_dir": _mm_cpt_get(pt_cfg, "image_base_dir"),
            "max_length": max_length,
        }
        if pad_to_multiple_of is not None:
            collator_kwargs["pad_to_multiple_of"] = pad_to_multiple_of
        return MultiModalPretrainDataCollator(**collator_kwargs)

    def build_collator(
        self,
        training_args,  # type: "AxolotlTrainingArguments"  # type: ignore
        is_eval=False,
        **kwargs,
    ):
        if training_args.pretraining:
            # Intercept MM CPT before the text-only pretraining branches.
            if (
                self.cfg.processor_type
                and self.processor
                and _is_multimodal_cpt(self.cfg)
            ):
                return self._build_mm_pretrain_collator(
                    pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                    is_eval=is_eval,
                )
            if (
                self.cfg.pretraining_sample_concatenation is False
                or self.cfg.micro_batch_size > 1
@@ -532,6 +594,15 @@ def build_collator(
            else:
                collator = BatchSamplerDataCollatorForSeq2Seq
        else:
            if (
                self.cfg.processor_type
                and self.processor
                and _is_multimodal_cpt(self.cfg)
            ):
                return self._build_mm_pretrain_collator(
                    pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                    is_eval=is_eval,
                )
            if self.cfg.processor_type and self.processor:
                collator = MultiModalChatDataCollator
                # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.