feat: add multimodal continued pre-training (raw image+text)#3629

Open
thad0ctor wants to merge 18 commits into axolotl-ai-cloud:main from thad0ctor:multimodal-cpt

Conversation

@thad0ctor

@thad0ctor thad0ctor commented Apr 25, 2026

Description

Adds a streaming-first multimodal CPT path: raw (text, images[]) rows are tokenized once with a placeholder-count guardrail, batched through a hardened collator, and fed to a VLM with image-family tokens masked out of labels. Gated by type: multimodal_pretrain (or multimodal: true) on a pretraining_dataset entry; works end-to-end for train and eval, including multi-entry eval and mixed image/text batches.

Features

  • Streaming MM CPT encoder (encode_streaming_multimodal): counts placeholders by token id (not substring), enforces placeholders == len(images) per row, and rejects rows that exceed sequence_len instead of silently truncating mid-placeholder.
  • MM CPT collator (MultiModalPretrainDataCollator): security-hardened image loader (path traversal / NUL byte / remote URL / multi-frame bomb / pixel cap rejection), per-row image cap, processor-call retry that pinpoints the offending row, and label-side masking of every image-family token id.
  • Mixed image/text batches: text-only rows in a batch take a tokenizer-only fallback (no pixel_values); rows with images go through the processor as usual.
  • Eval support: test_datasets accepts MM entries via a dedicated MultiModalEvalDataset model so per-entry text_column / image_column / image_base_dir / image_token survive validation. Multi-entry MM eval streams are concatenated.
  • dispatch_batches: true support: non-main ranks get a placeholder dataset that mirrors the configured text + image columns.
  • Config validation gates: processor_type required, sample_packing: false enforced, chat_template rejected, single pretraining_dataset entry required, MM eval entries must share image_base_dir / image_token, mixed MM/non-MM eval rejected, incompatible processor classes (Mllama, Pixtral, InternVL) rejected at startup. remove_unused_columns is auto-set to false with an INFO log.
  • Docs: new section in docs/multimodal.qmd covering the YAML shape, placeholder-token table, eval contract, and supported/rejected model families.
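The first two guardrails in the list above can be sketched in a few lines. This is a minimal illustration, not the PR's `encode_streaming_multimodal` or collator code: the function names, the argument layout, and the `-100` ignore index are assumptions made for the example.

```python
from typing import Iterable, Sequence

# Assumption: -100 is used as the ignore index, as is conventional for
# cross-entropy loss; the PR text only says image-family tokens are masked.
IGNORE_INDEX = -100


def check_mm_row(
    input_ids: Sequence[int],
    num_images: int,
    image_token_id: int,
    sequence_len: int,
) -> None:
    """Raise ValueError if a tokenized row violates either guardrail."""
    # Count placeholders by token id, not by substring match on the raw text.
    n_placeholders = sum(1 for tok in input_ids if tok == image_token_id)
    if n_placeholders != num_images:
        raise ValueError(
            f"placeholder count {n_placeholders} != number of images {num_images}"
        )
    # Reject over-length rows rather than truncating, which could cut a placeholder.
    if len(input_ids) > sequence_len:
        raise ValueError(
            f"row has {len(input_ids)} tokens, exceeds sequence_len={sequence_len}"
        )


def mask_image_labels(
    input_ids: Sequence[int], image_family_ids: Iterable[int]
) -> list[int]:
    """Copy input_ids to labels, replacing image-family ids with IGNORE_INDEX."""
    family = set(image_family_ids)
    return [IGNORE_INDEX if tok in family else tok for tok in input_ids]
```

Counting by token id avoids false positives when the placeholder string happens to appear inside ordinary text that tokenizes differently.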

YAML example

base_model: HuggingFaceTB/SmolVLM-500M-Instruct
processor_type: AutoProcessor

pretraining_dataset:
  - path: /path/to/shards/*.jsonl
    ds_type: json
    type: multimodal_pretrain
    text_column: text
    image_column: images
    image_base_dir: /path/to/images

streaming: true
sequence_len: 2048
sample_packing: false
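A row that fits the `text_column` / `image_column` settings above might be produced as follows. This is a sketch under assumptions: the `<image>` placeholder string is model-specific (the PR documents a per-family table), the helper name `make_row` is hypothetical, and the image path is assumed to be relative to `image_base_dir`.

```python
import json

# Assumption: the placeholder string depends on the model family.
IMAGE_TOKEN = "<image>"


def make_row(text_parts: list[str], image_paths: list[str]) -> dict:
    """Interleave text parts with exactly one placeholder per image."""
    if len(text_parts) != len(image_paths) + 1:
        raise ValueError("need exactly one more text part than images")
    return {"text": IMAGE_TOKEN.join(text_parts), "images": list(image_paths)}


row = make_row(
    ["Figure 1 shows the setup: ", " The results follow."],
    ["figs/setup.png"],  # hypothetical path, assumed relative to image_base_dir
)
line = json.dumps(row)  # one line of a .jsonl shard
```

Building the text by joining on the placeholder keeps the placeholder count equal to the number of images by construction.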

Motivation and Context

Existing axolotl multimodal paths assume conversational SFT (chat templates with image placeholders inside a structured turn format). Raw image+text continued pre-training — feeding a VLM unstructured documents that interleave text and images — wasn't supported. This PR adds that path without disturbing the existing chat-template SFT flow: it's a separate type: multimodal_pretrain opt-in that reuses the streaming pretraining infrastructure and the existing processor/tokenizer wiring.

How has this been tested?

59 unit tests across four suites:

  • tests/prompt_strategies/test_multimodal_pretrain.py — encoder/strategy: placeholder count guardrail, sequence-length enforcement, falsy-image rejection, processor compatibility gate.
  • tests/test_multimodal_streaming.py — streaming encoder + collator: end-to-end batch construction, label masking, mixed-batch and all-text-batch handling, security gates (path traversal, NUL byte, remote URLs, image cap, error sanitization).
  • tests/utils/schemas/validation/test_multimodal_cpt.py — config validation: processor/packing/template gates, single pretraining-entry rule, MM eval schema preservation, eval homogeneity (image_base_dir/image_token), remove_unused_columns auto-set INFO log.
  • tests/utils/data/test_mm_cpt_eval.py — eval data path: placeholder schema for dispatch_batches, multi-entry eval merge, mixed MM/non-MM rejection, eval-aware collator config source.

Tested with HuggingFaceTB/SmolVLM-500M-Instruct, Gemma 4, Gemma 3, Qwen 2.5VL, and Qwen 3.5 as the reference processors. Lint (ruff/ruff-format) and mypy are clean against the patches added in this branch.

Real world testing:

In addition to the unit suite, the path was validated against two end-to-end LoRA-CPT runs on an OCR corpus (raw \ntext rows, streamed via type: multimodal_pretrain):

  • Gemma-4-31B-it, 2× GPU FSDP: LoRA r=256/α=256 (LM proj layers) with vision_tower, embed_vision, and embed_tokens in modules_to_save. 4,400 steps of a 49,000-step plan; eval_loss 9.42 → 1.53 (ppl 12,330 → 4.64), train loss 0.41 at step 4,400, ~78 GB device-reserved at peak. Eight checkpoints written at 200-step cadence.
  • Gemma-4-E2B-it, single RTX 5090 — same LoRA shape, smaller base. 10,600 steps of a 98,000-step plan; eval_loss 9.82 → 0.99 (ppl 18,350 → 2.70), train loss 0.33 at step 10,600, ~18 GB peak.

Both runs trained to multi-thousand-step stability without divergence or NaN, with periodic eval against a separate test_datasets entry and image-family label masking applied throughout. This confirms that the streaming encoder, collator, and eval path work end-to-end in real training under both FSDP and single-GPU configs.
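The perplexity figures quoted for the two runs are consistent with the usual relation ppl = exp(eval_loss). A quick check, assuming the losses are natural-log cross-entropy averaged per token and that the small gaps come from the losses being rounded to two decimals:

```python
import math


def perplexity(loss: float) -> float:
    """Perplexity of a mean per-token cross-entropy loss given in nats."""
    return math.exp(loss)


# (eval_loss, perplexity as reported in the PR description)
reported = [(9.42, 12330.0), (1.53, 4.64), (9.82, 18350.0), (0.99, 2.70)]
for loss, ppl in reported:
    rel_err = abs(perplexity(loss) - ppl) / ppl
    print(f"loss={loss:.2f} exp(loss)={perplexity(loss):.2f} rel_err={rel_err:.4f}")
```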

AI Usage Disclaimer

Yes: Claude Code (Opus 4.7) was used for design discussion, code review, refactoring suggestions, and test generation. All code was reviewed and validated before commit. Codex was used for additional vetting, testing, and review. CodeRabbit reviews were addressed on a personal fork for every commit.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation update

Summary by CodeRabbit

  • New Features

    • Multimodal continued pretraining: train/evaluate on (image, text) pairs with processor-aware tokenization, streaming encoding, and a collator that handles mixed image/text batches, processor compatibility checks, hardened image-loading, and masks image tokens in loss.
  • Documentation

    • Comprehensive multimodal guide: data format, placeholder/token rules, example configs, eval behavior, and loss‑masking semantics.
  • Configuration & Validation

    • Schema fields and prevalidation gates for multimodal datasets; rejects incompatible settings (sample packing, chat templates), enforces one-entry multimodal train, and validates eval/train modality consistency.
  • Tests

    • Extensive end-to-end and unit tests covering tokenization, streaming, collation, validation, image-loading security, fixtures, and error cases.

@coderabbitai

coderabbitai Bot commented Apr 25, 2026


Walkthrough

Adds end-to-end multimodal continued-pretraining (CPT): new ImageTokenSpec and multimodal tokenization strategy, streaming multimodal encoder, MM pretrain data collator with hardened image handling and label masking, schema/config validation gates, causal builder routing to MM collator, documentation, and comprehensive tests.

Changes

Multimodal CPT feature

| Layer | File(s) | Summary |
|---|---|---|
| Schema / Data Shape | src/axolotl/utils/schemas/datasets.py, src/axolotl/utils/schemas/config.py | Adds ds_type, multimodal, image_column, image_base_dir, image_token to PretrainingDataset; adds MultiModalEvalDataset; extends AxolotlInputConfig.test_datasets union to include multimodal eval entries. |
| Config Validation (pre-load) | src/axolotl/utils/schemas/validation.py | Adds check_multimodal_cpt (before-validator) enforcing multimodal gates: requires processor_type, forbids sample_packing and chat_template, forces remove_unused_columns=false, disallows multi-entry multimodal pretraining, and enforces consistency across multimodal eval entries. |
| Tokenization / Strategy | src/axolotl/prompt_strategies/multimodal_pretrain.py | New ImageTokenSpec, runtime processor compatibility checks, placeholder autodetection/override, image-family token id set, MultimodalPretrainTokenizationStrategy with placeholder-count validation and load() factory. |
| Streaming Encoder | src/axolotl/utils/data/streaming.py | Adds encode_streaming_multimodal, updates wrap_streaming_dataset signature to accept processor/pretraining_config/is_eval, derives image token spec for multimodal entries, and routes multimodal datasets to the new encoder. |
| SFT / Dataset Loading Wiring | src/axolotl/utils/data/sft.py | Threads processor into streaming loaders, multimodal-aware eval streaming, introduces _pretraining_config_from_entry, and normalizes pretraining-config extraction. |
| Collation / Batching | src/axolotl/utils/collators/mm_pretrain.py | Adds MultiModalPretrainDataCollator (dataclass) with hardened path resolution, secure open, max-pixels/frame checks, per-row image loading and skip/drop semantics, processor-based text+image batching, all-text fallback, and masking of pad + image-family token IDs in labels. |
| Causal Builder Routing | src/axolotl/core/builders/causal.py | Adds multimodal CPT detection helpers and routes pretraining/eval collator construction to the new MM pretrain collator (early-return in build_collator). |
| Documentation | docs/multimodal.qmd | New multimodal CPT documentation: JSONL schema, strict per-image placeholder rules, model-family placeholder mappings and autodetection/override behavior, YAML example forbidding sample packing, eval handling rules, and labeling/masking semantics. |
| Tests / Fixtures | tests/conftest.py, tests/prompt_strategies/test_multimodal_pretrain.py, tests/test_multimodal_streaming.py, tests/utils/..., tests/utils/schemas/validation/test_multimodal_cpt.py | Adds session-scoped HF offline fixture; comprehensive tests for image-token detection, processor compatibility, multimodal tokenization strategy, streaming encoder, collator behavior (security/robustness and masking), eval-stream handling, and config validation gates. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • winglian
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 31.06% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title clearly summarizes the main change: adding multimodal continued pre-training support for raw image+text data. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@thad0ctor thad0ctor changed the title feat(mm-cpt): multimodal continued pre-training (raw image+text) feat: add multimodal continued pre-training (raw image+text) Apr 25, 2026

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (7)
tests/utils/data/test_mm_cpt_eval.py (1)

93-93: Nit: prefix unused unpacked variable.

train is never read; rename to _train to silence Ruff RUF059 and signal intent.

♻️ Suggested change
-    train, eval_ds, _, _ = _prepare_streaming_dataset(
+    _train, eval_ds, _, _ = _prepare_streaming_dataset(
         cfg, tokenizer=None, processor=None
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/data/test_mm_cpt_eval.py` at line 93, Unpack the return of
_prepare_streaming_dataset using a prefixed underscore for the unused variable
so the intent is clear and Ruff RUF059 is silenced: change the left-hand side of
the call that currently uses "train, eval_ds, _, _ =
_prepare_streaming_dataset(...)" to use "_train, eval_ds, _, _ =
_prepare_streaming_dataset(...)" so the unused first value is named _train while
keeping eval_ds and the other placeholders unchanged.
src/axolotl/utils/collators/mm_pretrain.py (3)

211-233: All-text fallback uses self.tokenizer, but image batches go through self.processor.

When all rows are text-only, this path bypasses the processor and tokenizes via self.tokenizer directly. That implicitly assumes self.tokenizer is self.processor.tokenizer; if a caller ever constructs the collator with a different tokenizer (e.g. a wrapped one that adds chat scaffolding), text-only batches and image batches will tokenize differently for the same input. Two low-cost mitigations:

  1. Assert the invariant in __post_init__:
    if getattr(self.processor, "tokenizer", None) is not self.tokenizer:
        LOG.warning(
            "MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer; "
            "all-text and image batches may tokenize inconsistently."
        )
  2. Or route the all-text path through self.processor(text=texts, ...) if the processor supports text-only invocation.

Not a correctness bug today, but it's a footgun that won't surface until someone wires it differently.

I found an important concern with the mixed-batch path — let me add that to the review.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 211 - 233, The
all-text fallback uses self.tokenizer while image batches go through
self.processor, which can cause inconsistent tokenization; add a check in
MultiModalPretrainDataCollator.__post_init__ to warn if getattr(self.processor,
"tokenizer", None) is not self.tokenizer (log a clear warning mentioning
potential inconsistent tokenization between all-text and image batches), and/or
change the all-text branch in the collate code to call
self.processor(text=texts, return_tensors=self.return_tensors,
padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) when the
processor supports a text-only invocation so both paths use the same tokenizer
pipeline.

245-282: Retry path issues N additional processor calls per failed batch.

When a batch fails, the loop replays each row individually until the first reproducer is found. For a 32-row batch with the offender at row 31, that's 31 successful processor calls before the diagnostic — non-trivial on GPU/CPU when the processor does real image preprocessing. Suggest at least logging the retry kick-off at WARNING so users see why a single failure triggers a noticeable stall, and consider bisecting (binary search) instead of linear scan.
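The bisection idea can be sketched as below. This is not the collator's code: `find_offender` and the `fails` callback are hypothetical names, and `fails(indices)` stands in for "run the processor on only these rows and report whether it raised". The sketch assumes a subset fails exactly when it contains a bad row; failures caused by interactions between rows would still need the linear scan.

```python
from typing import Callable, Optional, Sequence


def find_offender(
    n_rows: int,
    fails: Callable[[Sequence[int]], bool],
) -> Optional[int]:
    """Locate one failing row index by bisection, or None if nothing fails."""
    lo, hi = 0, n_rows  # invariant: some offender lies in [lo, hi)
    if n_rows == 0 or not fails(range(lo, hi)):
        return None
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fails(range(lo, mid)):
            hi = mid  # an offender is in the left half
        else:
            lo = mid  # left half is clean, so the offender is on the right
    return lo


# Usage: a 32-row batch whose last row is bad.
calls = []


def fake_fails(indices: Sequence[int]) -> bool:
    calls.append(list(indices))
    return 31 in indices


offender = find_offender(32, fake_fails)
```

For a 32-row batch this issues one confirming call plus five halving steps, instead of up to 32 single-row calls, although each call in the bisection processes more rows than a single-row retry.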

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 245 - 282, The retry
loop inside MultiModalPretrainDataCollator that replays each row individually
(the for-loop that calls self.processor(...) to set offender_idx) can cause N
extra heavy processor calls; before starting the per-row retry, emit a
processLogger.warning noting the retry/bisect kickoff and batch size so users
see why processing stalls, and replace the linear scan with a binary-search
(bisect) strategy that tests halves of the remaining index range by calling
self.processor(text=[...], images=[...], **retry_kwargs) to locate the offending
index in O(log N) calls instead of O(N) — keep offender_idx, retry_ok, and
retry_kwargs semantics and preserve the existing RuntimeError message when
locating or failing to locate the offender.

61-68: Closed scheme list misses cloud/HF URIs (s3://, gs://, az://, hf://, …).

The current allowlist rejects http(s), ftp(s), file, data, and UNC, but s3://bucket/key.png, gs://…, azure://…, hf://… would not match. Without image_base_dir they'd fall through to os.open and fail with a confusing FileNotFoundError rather than the explicit “Non-local image path scheme” error users (and tests) expect. Prefer a generic URI-scheme detector — anything matching ^[a-z][a-z0-9+.-]*: (case-insensitive) other than a Windows drive letter (C:) is non-local.

🛡️ Suggested change
+import re
+
+# Generic URI scheme: <scheme>:<...> per RFC 3986 (excluding Windows drive letters).
+_URI_SCHEME_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9+.\-]*:")
+
 ...
-        p_lower = p.lower()
-        if p_lower.startswith(
-            ("http://", "https://", "ftp://", "ftps://", "file://", "data:")
-        ) or p.startswith(("\\\\", "//")):
+        m = _URI_SCHEME_RE.match(p)
+        # Exclude Windows drive letters like "C:\..." (single-letter scheme).
+        is_uri = m is not None and len(m.group(0).rstrip(":")) > 1
+        if is_uri or p.startswith(("\\\\", "//")):
             raise ValueError(
                 f"Non-local image path scheme is not supported in v1 "
                 f"multimodal CPT (got {p!r})."
             )

If you adopt this, also add s3://… / gs://… to test_collator_rejects_remote_urls so the contract is locked down.
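A self-contained version of the suggested check, useful for trying out edge cases, might look like this. The function name `is_non_local` is hypothetical; the regex and the single-letter drive exemption follow the suggestion above.

```python
import re

# Generic URI scheme prefix, e.g. "s3:", "hf:", "data:".
_URI_SCHEME_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9+.\-]*:")


def is_non_local(path: str) -> bool:
    """Return True for URI-style or UNC paths, False for plain local paths."""
    # UNC shares (\\server\share) and protocol-relative paths (//host/...).
    if path.startswith(("\\\\", "//")):
        return True
    match = _URI_SCHEME_RE.match(path)
    if match is None:
        return False
    scheme = match.group(0)[:-1]  # drop the trailing ":"
    # A single-letter "scheme" is treated as a Windows drive letter (C:\...).
    return len(scheme) > 1
```

One trade-off of the generic detector is that a relative filename containing a colon after two or more scheme-like characters (for example `ab:c.png`) is also rejected.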

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 61 - 68, The current
scheme check in mm_pretrain (using p and p_lower) misses cloud/third-party URIs
like s3://, gs://, az://, hf://; replace the explicit startswith list with a
generic URI-scheme detector: if p matches the regex ^[a-z][a-z0-9+.-]*:
(case-insensitive) treat it as non-local, but exempt Windows drive letters
(e.g., patterns like C: followed by slash/backslash). Update the code path that
raises the ValueError (the block that currently raises "Non-local image path
scheme..." when p_lower.startswith(...)) to use this new check, and add s3://
and gs:// (and other cloud URIs as needed) to test_collator_rejects_remote_urls
to lock down the behavior.
src/axolotl/prompt_strategies/multimodal_pretrain.py (3)

244-249: Note: len(ids) > self.max_length is a text-side check; image-token expansion happens later.

The encoder rejects rows whose text tokenizes beyond sequence_len, but each placeholder expands at the processor into many patch tokens, so the actual model input can still exceed sequence_len. The collator warns post-hoc (mm_pretrain.py:286-292), which is fine — but consider clarifying the error here so users don't conclude the budget refers to post-expansion length. Optional wording tweak: append “(text-side; image patch expansion happens at the processor and will further inflate the sequence)”.
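The gap between the text-side check and the real model input is simple arithmetic. The sketch below assumes every image expands to a fixed number of tokens; `tokens_per_image` is a hypothetical parameter, and for many processors the real count varies with image resolution.

```python
def expanded_length(text_side_len: int, num_images: int, tokens_per_image: int) -> int:
    """Sequence length after each 1-token placeholder expands to tokens_per_image."""
    return text_side_len + num_images * (tokens_per_image - 1)


# A row that passes a text-side budget of 2048 but overflows after expansion.
sequence_len = 2048
text_side_len = 1900
post_expansion = expanded_length(text_side_len, num_images=2, tokens_per_image=256)
passes_text_check = text_side_len <= sequence_len
overflows_after_expansion = post_expansion > sequence_len
```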

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 244 - 249,
The ValueError message raised when checking if len(ids) > self.max_length is
misleading because it only refers to the text-side token count while image
placeholders expand into many patch tokens later; update the error text in the
multimodal_pretrain.py check (the block that examines ids and self.max_length)
to clearly state this is a text-side limit and that image patch expansion at the
processor can further inflate the sequence (e.g., append “(text-side; image
patch expansion happens at the processor and will further inflate the
sequence)”), so users don't mistake this for the post-expansion budget.

19-35: Hoist import importlib out of the per-class loop.

Cheap import-cache hit, but this is a module-level helper running at import time and it's clearer to put import importlib at the top of the module alongside the other stdlib imports.

♻️ Suggested change
-def _get_incompatible_processor_classes() -> tuple[type, ...]:
-    classes: list[type] = []
-    for mod_path, name in (
-        ("transformers.models.mllama", "MllamaProcessor"),
-        ("transformers.models.pixtral", "PixtralProcessor"),
-        ("transformers.models.internvl", "InternVLProcessor"),
-    ):
-        try:
-            import importlib
-
-            mod = importlib.import_module(mod_path)
+def _get_incompatible_processor_classes() -> tuple[type, ...]:
+    import importlib
+
+    classes: list[type] = []
+    for mod_path, name in (
+        ("transformers.models.mllama", "MllamaProcessor"),
+        ("transformers.models.pixtral", "PixtralProcessor"),
+        ("transformers.models.internvl", "InternVLProcessor"),
+    ):
+        try:
+            mod = importlib.import_module(mod_path)
             cls = getattr(mod, name, None)
             if cls is not None:
                 classes.append(cls)
         except ImportError:
             continue
     return tuple(classes)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 19 - 35,
The helper _get_incompatible_processor_classes currently imports importlib
inside the loop for each (mod_path, name) tuple; move the importlib import to
the module top (with other stdlib imports) and remove the per-iteration import
so the loop simply uses the already-imported importlib to import modules by
mod_path and getattr the class name, keeping logic around mod =
importlib.import_module(mod_path), cls = getattr(mod, name, None),
classes.append(cls) and the ImportError handling unchanged.

109-117: Don't silently swallow errors from get_added_vocab().

except Exception: pass masks tokenizer bugs (e.g. AttributeError on a custom wrapper) and silently degrades the special-token set used to validate overrides and build the family mask. At minimum, log at DEBUG so operators can correlate “image_token override rejected as not a registered special token” with the underlying cause.

♻️ Suggested change
-    known_special_tokens: set[str] = set()
-    try:
-        known_special_tokens |= set(tokenizer.get_added_vocab().keys())
-    except Exception:
-        pass
+    known_special_tokens: set[str] = set()
+    try:
+        known_special_tokens |= set(tokenizer.get_added_vocab().keys())
+    except Exception as exc:  # noqa: BLE001
+        LOG.debug("get_added_vocab() failed on %s: %s", type(tokenizer).__name__, exc)
     known_special_tokens |= set(getattr(tokenizer, "all_special_tokens", None) or [])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 109 - 117,
The try/except around tokenizer.get_added_vocab() is silently swallowing errors
which hides tokenizer bugs and causes silent degradation of
known_special_tokens; replace the bare except with catching Exception as e and
emit a DEBUG-level log including the exception details and context (e.g. "failed
to call tokenizer.get_added_vocab()") before continuing so known_special_tokens
is still populated from getattr(tokenizer, 'all_special_tokens'...) and
getattr(tokenizer, 'additional_special_tokens'...) as before; reference the
tokenizer.get_added_vocab() call and the known_special_tokens update in
multimodal_pretrain.py and use the module or existing logger (logger or
logging.getLogger(__name__)) to record the exception.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 234-282: The collator (MultiModalPretrainDataCollator) currently
passes mixed `images` (some empty lists) directly to `self.processor(texts,
images=images...)`, which may break or be undefined for many VLM processors;
modify the collator's processing path (where `batch =
self.processor(**proc_kwargs)` is called, e.g., inside the torch_call /
processor invocation) to detect mixed batches (any(len(im)==0 for im in images)
and not all(len(im)==0 for im in images)) and handle them by splitting into
homogeneous sub-batches (text-only rows vs image-containing rows) or falling
back to per-row `self.processor(text=[t], images=[imgs], **retry_kwargs)`
aggregation of the returned tensors into a single batched output, ensuring
shapes/padding/return_tensors remain consistent; additionally add doc notes in
docs/multimodal.qmd listing which processors support mixed batches or require
homogeneous input and expand tests to include other VLM processors to assert
behavior.

In `@src/axolotl/utils/data/sft.py`:
- Around line 211-226: The multimodal normalization helpers drop the
trust_remote_code flag so load_dataset never receives it; update
_pretraining_config_from_entry to include "trust_remote_code":
entry.get("trust_remote_code") in the returned DictDefault (and make the same
change in the analogous eval helper referenced around lines 278-284) so
MultiModalEvalDataset's trust_remote_code is preserved and propagated to
load_dataset.
- Around line 167-204: The multimodal eval branch reuses the training
pretraining config and therefore ignores cfg.eval_sequence_len; fix by, for each
multimodal eval entry (in the loop that builds eval_streams), after eval_config
= _pretraining_config_from_entry(entry) override the sequence length used for
evaluation by setting eval_config.sequence_len = cfg.eval_sequence_len (or the
appropriate eval length field if the pretraining config uses a different
attribute name) when cfg.eval_sequence_len is present, then pass that
eval_config into _load_streaming_dataset so the streaming eval loader uses the
intended eval length.

In `@src/axolotl/utils/schemas/datasets.py`:
- Around line 241-307: PretrainingDataset and MultiModalEvalDataset are missing
the documented ds_type loader hint so configs using
pretraining_dataset[].ds_type: json are dropped; add a ds_type: str | None =
Field(default=None, json_schema_extra={"description":"Loader type for the
dataset (e.g. 'json', 'csv') used to choose the dataset loader"}) to both
PretrainingDataset and MultiModalEvalDataset (and any other dataset model
variants that accept file-backed inputs) so the schema preserves the loader hint
during parsing; reference the PretrainingDataset and MultiModalEvalDataset class
definitions and add the Field in each model.

In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1413-1428: The code currently only ensures shared multimodal
fields across multimodal test entries but allows mixing modalities between train
and test or within test entries; detect modality per entry using _entry_is_mm
and reject mixed modalities: compute mm_train/text_train from
data.get("train_datasets") and mm_test/text_test from data.get("test_datasets")
(use the same isinstance(dict) + _entry_is_mm checks used for mm_test), then
raise a ValueError if (mm_train and text_test) or (mm_test and text_train) to
block train/eval modality mismatches, and also raise if mm_test and any
text-only entries exist in test_datasets to prevent mixed test entries; include
clear messages referencing train_datasets/test_datasets and multimodal vs
text-only to guide the user.

In `@tests/utils/schemas/validation/test_multimodal_cpt.py`:
- Around line 250-258: CI failed because ruff-format reformatted
tests/utils/schemas/validation/test_multimodal_cpt.py; re-run the formatter and
commit the changes. Run the formatter (e.g., pre-commit run --all-files or ruff
format .), review the updated formatting for the test that references caplog,
matches, logging.INFO and "remove_unused_columns", and commit the resulting file
so the lint hook is satisfied.
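The mixed-batch detection described in the first inline comment can be sketched as follows. The helper names are hypothetical, and merging the two sub-batch outputs back into one padded batch in the original row order is left out because it depends on the processor's output keys.

```python
from typing import Any, Sequence


def split_by_modality(
    images: Sequence[Sequence[Any]],
) -> tuple[list[int], list[int]]:
    """Return (indices of rows with images, indices of text-only rows)."""
    image_rows = [i for i, imgs in enumerate(images) if len(imgs) > 0]
    text_rows = [i for i, imgs in enumerate(images) if len(imgs) == 0]
    return image_rows, text_rows


def is_mixed(images: Sequence[Sequence[Any]]) -> bool:
    """True when a batch contains both image rows and text-only rows."""
    image_rows, text_rows = split_by_modality(images)
    return bool(image_rows) and bool(text_rows)


# Usage: rows 0 and 2 carry images, row 1 is text-only.
batch_images = [["a.png"], [], ["b.png", "c.png"]]
image_rows, text_rows = split_by_modality(batch_images)
```

Keeping the original indices makes it possible to restore row order after the two homogeneous sub-batches are processed separately.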

---

Nitpick comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 244-249: The ValueError message raised when checking if len(ids) >
self.max_length is misleading because it only refers to the text-side token
count while image placeholders expand into many patch tokens later; update the
error text in the multimodal_pretrain.py check (the block that examines ids and
self.max_length) to clearly state this is a text-side limit and that image patch
expansion at the processor can further inflate the sequence (e.g., append
“(text-side; image patch expansion happens at the processor and will further
inflate the sequence)”), so users don't mistake this for the post-expansion
budget.
- Around line 19-35: The helper _get_incompatible_processor_classes currently
imports importlib inside the loop for each (mod_path, name) tuple; move the
importlib import to the module top (with other stdlib imports) and remove the
per-iteration import so the loop simply uses the already-imported importlib to
import modules by mod_path and getattr the class name, keeping logic around mod
= importlib.import_module(mod_path), cls = getattr(mod, name, None),
classes.append(cls) and the ImportError handling unchanged.
- Around line 109-117: The try/except around tokenizer.get_added_vocab() is
silently swallowing errors which hides tokenizer bugs and causes silent
degradation of known_special_tokens; replace the bare except with catching
Exception as e and emit a DEBUG-level log including the exception details and
context (e.g. "failed to call tokenizer.get_added_vocab()") before continuing so
known_special_tokens is still populated from getattr(tokenizer,
'all_special_tokens'...) and getattr(tokenizer, 'additional_special_tokens'...)
as before; reference the tokenizer.get_added_vocab() call and the
known_special_tokens update in multimodal_pretrain.py and use the module or
existing logger (logger or logging.getLogger(__name__)) to record the exception.

In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 211-233: The all-text fallback uses self.tokenizer while image
batches go through self.processor, which can cause inconsistent tokenization;
add a check in MultiModalPretrainDataCollator.__post_init__ to warn if
getattr(self.processor, "tokenizer", None) is not self.tokenizer (log a clear
warning mentioning potential inconsistent tokenization between all-text and
image batches), and/or change the all-text branch in the collate code to call
self.processor(text=texts, return_tensors=self.return_tensors,
padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) when the
processor supports a text-only invocation so both paths use the same tokenizer
pipeline.
- Around line 245-282: The retry loop inside MultiModalPretrainDataCollator that
replays each row individually (the for-loop that calls self.processor(...) to
set offender_idx) can cause N extra heavy processor calls; before starting the
per-row retry, emit a processLogger.warning noting the retry/bisect kickoff and
batch size so users see why processing stalls, and replace the linear scan with
a binary-search (bisect) strategy that tests halves of the remaining index range
by calling self.processor(text=[...], images=[...], **retry_kwargs) to locate
the offending index in O(log N) calls instead of O(N) — keep offender_idx,
retry_ok, and retry_kwargs semantics and preserve the existing RuntimeError
message when locating or failing to locate the offender.
- Around line 61-68: The current scheme check in mm_pretrain (using p and
p_lower) misses cloud/third-party URIs like s3://, gs://, az://, hf://; replace
the explicit startswith list with a generic URI-scheme detector: if p matches
the regex ^[a-z][a-z0-9+.-]*: (case-insensitive) treat it as non-local, but
exempt Windows drive letters (e.g., patterns like C: followed by
slash/backslash). Update the code path that raises the ValueError (the block
that currently raises "Non-local image path scheme..." when
p_lower.startswith(...)) to use this new check, and add s3:// and gs:// (and
other cloud URIs as needed) to test_collator_rejects_remote_urls to lock down
the behavior.
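
A minimal sketch of the generic scheme check described in the last item above. The helper name `looks_non_local` and the two pattern names are hypothetical; the actual check in `mm_pretrain.py` may be structured differently. The regex is the one quoted in the comment, and single-letter Windows drive prefixes followed by a slash or backslash are exempted.

```python
import re

# URI scheme: a letter, then letters, digits, "+", "." or "-", then ":".
_SCHEME_RE = re.compile(r"^[a-z][a-z0-9+.-]*:", re.IGNORECASE)
# Windows drive prefix: one letter, a colon, then a slash or backslash.
_WINDOWS_DRIVE_RE = re.compile(r"^[a-z]:[\\/]", re.IGNORECASE)


def looks_non_local(path: str) -> bool:
    """Return True when `path` starts with a URI scheme (hypothetical helper)."""
    if _WINDOWS_DRIVE_RE.match(path):
        # "C:\\images\\a.png" or "c:/images/a.png" is a local drive path.
        return False
    return _SCHEME_RE.match(path) is not None
```

One trade-off of this approach: a relative file name that itself contains a colon before any slash (for example `a:b.png`) is also treated as non-local, so an explicit allow-list may still be preferable in some setups.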

In `@tests/utils/data/test_mm_cpt_eval.py`:
- Line 93: Unpack the return of _prepare_streaming_dataset using a prefixed
underscore for the unused variable so the intent is clear and Ruff RUF059 is
silenced: change the left-hand side of the call that currently uses "train,
eval_ds, _, _ = _prepare_streaming_dataset(...)" to use "_train, eval_ds, _, _ =
_prepare_streaming_dataset(...)" so the unused first value is named _train while
keeping eval_ds and the other placeholders unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4a5ecec3-baad-4185-89f4-7f39775023c8

📥 Commits

Reviewing files that changed from the base of the PR and between 798c8fb and 986a25a.

📒 Files selected for processing (14)
  • docs/multimodal.qmd
  • src/axolotl/core/builders/causal.py
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/collators/mm_pretrain.py
  • src/axolotl/utils/data/sft.py
  • src/axolotl/utils/data/streaming.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/datasets.py
  • src/axolotl/utils/schemas/validation.py
  • tests/conftest.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/test_multimodal_streaming.py
  • tests/utils/data/test_mm_cpt_eval.py
  • tests/utils/schemas/validation/test_multimodal_cpt.py

thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request Apr 25, 2026
  Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the
  happy path; expands schema, hardens fallbacks, tightens validation.

  Bug fixes
  ---------
  - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token`
    when they differ (Gemma-3's `image_token` is the post-expansion soft
    token, not the user-facing placeholder). Without this, MM CPT crashed
    on the first batch with "Prompt contained 0 image tokens".
  - `dispatch_batches: true` placeholder dataset now mirrors the
    configured `image_column` so worker ranks don't KeyError on empty
    rows.
  - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`,
    `False`) instead of coercing to `[]` — keeps malformed rows from
    silently turning into text-only samples.

  Schema completeness
  -------------------
  - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset`
    (the documented `ds_type: json` shape now actually reaches
    `load_dataset`; previously dropped at validation).
  - Preserve `trust_remote_code` through `_pretraining_config_from_entry`
    and pass it to `load_dataset` (was silently dropped).
  - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator)
    with documented fallback to `cfg.sequence_len` when unset.

  Validation tightening (config-load time)
  ----------------------------------------
  - Reject mixed multimodal/text entries in `test_datasets`.
  - Reject MM `test_datasets` paired with non-MM training.
  - Reject non-MM `test_datasets` paired with MM training.
  - The redundant runtime check in `sft.py` is removed; schema is the
    single source of truth.

  Hardening / observability
  -------------------------
  - Mixed/all-text batch handling: collator routes all-text batches to
    the tokenizer (no `pixel_values`); mixed batches go through the
    processor as-is. Documented per-VLM compatibility (verified on
    SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL).
  - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`,
    `azure://`, `hf://`) in image paths so users see "Non-local scheme"
    instead of a confusing FileNotFoundError.
  - Warn when `MultiModalPretrainDataCollator.tokenizer is not
    processor.tokenizer` (all-text vs image batches could otherwise
    tokenize the same text differently).
  - Warn at retry kickoff when a processor call fails on a batch, so
    users see why processing stalls during per-row diagnosis.
  - INFO log when `remove_unused_columns` is auto-set to `false` for
    MM CPT.
  - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass).
  - Clarify "exceeds sequence_len" error to note image-patch expansion
    may push the final length higher.

  Tests
  -----
  +8 regression tests across the four MM CPT suites covering: Gemma-3
  boi_token autodetection, eval_sequence_len (encoder + collator,
  including the fallback case), `trust_remote_code` and `ds_type`
  preservation through validation, three modality-mismatch validation
  cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass
  across `tests/test_multimodal_streaming.py`,
  `tests/prompt_strategies/test_multimodal_pretrain.py`,
  `tests/utils/schemas/validation/test_multimodal_cpt.py`,
  `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff
  v0.15.8 (upstream pre-commit pin).
@thad0ctor
Contributor Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/utils/data/streaming.py (1)

268-296: ⚠️ Potential issue | 🟠 Major

Honor effective_seq_len in the packed eval path too.

Line 268 computes effective_seq_len, but Lines 281-292 still hardcode cfg.sequence_len. With sample_packing: true, eval therefore ignores eval_sequence_len, which changes truncation/padding behavior and can unexpectedly blow up eval memory.

♻️ Proposed fix
         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=cfg.sequence_len,
+            pad_to_multiple_of=effective_seq_len,
             multipack_attn=multipack_attn,
         )
         encode = functools.partial(
             encode_packed_streaming,
             collate_fn,
             ds_wrapper_fn,
-            max_seq_length=cfg.sequence_len,
+            max_seq_length=effective_seq_len,
             batch_size=cfg.micro_batch_size,
             multipack_attn=multipack_attn,
             bin_size=cfg.sample_packing_bin_size,
         )

Based on learnings: if cfg.eval_sequence_len is not set, eval must fall back to cfg.sequence_len; otherwise the explicit eval override should be honored.
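
The fallback rule can be sketched as follows. This is an illustration only: `resolve_eval_seq_len` is a hypothetical helper, and `SimpleNamespace` stands in for the real config object, whose attribute access may behave differently.

```python
from types import SimpleNamespace


def resolve_eval_seq_len(cfg) -> int:
    """Pick the eval sequence length, falling back to the training length."""
    eval_len = getattr(cfg, "eval_sequence_len", None)
    # None means "unset": use the training sequence length instead.
    return cfg.sequence_len if eval_len is None else eval_len


cfg_default = SimpleNamespace(sequence_len=2048, eval_sequence_len=None)
cfg_override = SimpleNamespace(sequence_len=2048, eval_sequence_len=512)
print(resolve_eval_seq_len(cfg_default))   # 2048
print(resolve_eval_seq_len(cfg_override))  # 512
```

Passing the resolved value to both `pad_to_multiple_of` and `max_seq_length` keeps the packed and non-packed eval paths consistent.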

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/data/streaming.py` around lines 268 - 296, The
packed-sample eval path is still using cfg.sequence_len instead of the computed
effective_seq_len, so update the PretrainingBatchSamplerDataCollatorForSeq2Seq
instantiation and the functools.partial for encode_packed_streaming to use
effective_seq_len (not cfg.sequence_len) for pad_to_multiple_of and
max_seq_length; keep the existing multipack_attn logic (True if not
cfg.pretraining_dataset else cfg.pretrain_multipack_attn) and other args
unchanged so eval falls back to cfg.sequence_len only when cfg.eval_sequence_len
is not set.
🧹 Nitpick comments (5)
tests/prompt_strategies/test_multimodal_pretrain.py (1)

72-84: Assert the resolved token id here, not just the chosen token string.

This only proves that boi_token wins the name-selection step. If build_image_token_spec() regresses to returning the right token string but the wrong image_token_id/family ids, this test still passes. Please make the fixture register a distinct boi_token id and assert the resolved id mapping too.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/prompt_strategies/test_multimodal_pretrain.py` around lines 72 - 84,
The test currently only asserts the chosen token string; update it to also
assert the resolved token id so regressions in id mapping are caught: arrange
the smolvlm_processor.tokenizer fixture (used by _FakeGemma3Like) to register a
distinct id for the fake boi_token ("<fake_token_around_image>") different from
"<image>", call build_image_token_spec(_FakeGemma3Like()) as before, and add an
assertion that spec.image_token_id (and any related family ids such as
spec.image_padding_token_family_id if present) equals the tokenizer id you
registered for "<fake_token_around_image>" to ensure the returned spec maps to
the expected token id.
src/axolotl/prompt_strategies/multimodal_pretrain.py (4)

299-312: Optional: store image_token_spec in __init__ rather than monkey-patching.

Setting strat.image_token_spec = spec after construction with a type: ignore[attr-defined] works but bypasses the class contract — readers of MultimodalPretrainTokenizationStrategy won't see this attribute on the type. Since causal.py reads it back, lifting it into the constructor would clean up the type ignore and document the dependency.

Proposed change
 class MultimodalPretrainTokenizationStrategy(PretrainTokenizationStrategy):
     def __init__(
         self,
         *args: Any,
         image_token: str,
         image_token_id: int,
         image_column: str = "images",
         image_base_dir: str | None = None,
+        image_token_spec: ImageTokenSpec | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
         self.image_token = image_token
         self.image_token_id = image_token_id
         self.image_column = image_column
         self.image_base_dir = image_base_dir
+        self.image_token_spec = image_token_spec

     strat = MultimodalPretrainTokenizationStrategy(
         PretrainTokenizer(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
         text_column=text_column,
         image_column=image_column,
         image_base_dir=image_base_dir,
         image_token=spec.image_token,
         image_token_id=spec.image_token_id,
         max_length=cfg.sequence_len,
+        image_token_spec=spec,
     )
-    strat.image_token_spec = spec  # type: ignore[attr-defined]
     return strat
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 299 - 312,
Currently the code monkey-patches image_token_spec onto the
MultimodalPretrainTokenizationStrategy instance after construction; instead add
an image_token_spec parameter (optional) to
MultimodalPretrainTokenizationStrategy.__init__ and assign it to
self.image_token_spec there, update the constructor signature where the class is
defined, and remove the post-construction assignment and its type: ignore in the
caller (the instantiation code shown) so the attribute is part of the class
contract; ensure any reads in causal.py continue to work with the attribute on
the instance.

48-48: Nit: replace ambiguous × (U+00D7) with ASCII x.

Ruff RUF003 flags the multiplication sign as confusable with the Latin letter x. Trivial fix:

Proposed fix
-# Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM.
+# Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` at line 48, Replace the
Unicode multiplication sign U+00D7 in the comment "Without masking these in
labels, loss blows up ~10× on Qwen/SmolVLM." with the ASCII letter 'x' so it
reads "Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM.";
update that inline comment in multimodal_pretrain.py (the quoted comment) to
avoid the confusable character flagged by Ruff RUF003.

266-268: n_chunks is structurally always 1 — multiplication is dead code.

_tokenize wraps input_ids in a single-element list, so len(res["input_ids"]) is always 1. The * n_chunks replication is harmless but obscures intent. Either drop the variable or assert it as an invariant so a future change to _tokenize doesn't silently produce mis-aligned images/_mm_text lists.

Proposed simplification
-        n_chunks = len(res["input_ids"])
-        res["images"] = [list(images)] * n_chunks
-        res["_mm_text"] = [text] * n_chunks
+        # `_tokenize` produces exactly one chunk; replicate per-chunk for schema parity.
+        assert len(res["input_ids"]) == 1
+        res["images"] = [list(images)]
+        res["_mm_text"] = [text]
         return res
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 266 - 268,
The code currently computes n_chunks = len(res["input_ids"]) and multiplies
images/_mm_text by it, but _tokenize always wraps input_ids in a single-element
list so this replication is dead code; change the block in
multimodal_pretrain.py to assert that len(res["input_ids"]) == 1 (to catch
future changes to _tokenize) and then set res["images"] = [list(images)] and
res["_mm_text"] = [text] (no multiplication), referencing the variables
n_chunks, res["input_ids"], res["images"], res["_mm_text"], and the _tokenize
behavior so the intent and invariant are explicit.

221-232: _tokenize silently ignores its add_eos_token and strip_bos_token parameters.

The signature accepts add_eos_token and strip_bos_token, but the body unconditionally appends EOS and never inspects either flag. If anything in PretrainTokenizationStrategy (or a future caller) invokes _tokenize(prompt, add_eos_token=False), EOS will be appended anyway, masking the intent.

Either honor the flags, or drop them from the signature and add a brief docstring explaining the override is fixed-behavior on purpose.

Honor-the-flag variant
     def _tokenize(
         self,
         prompt: str,
         add_eos_token: bool = True,
         strip_bos_token: bool = False,
     ) -> BatchEncoding:
         # No truncation: collator re-tokenizes the full text without truncation;
         # truncating here decouples the stored ids from what the model receives.
         res = self.tokenizer(prompt, add_special_tokens=True)
-        res["input_ids"] = [res["input_ids"] + [self.tokenizer.eos_token_id]]
-        res["attention_mask"] = [res["attention_mask"] + [1]]
+        ids = res["input_ids"]
+        mask = res["attention_mask"]
+        if strip_bos_token and ids and ids[0] == self.tokenizer.bos_token_id:
+            ids = ids[1:]
+            mask = mask[1:]
+        if add_eos_token and self.tokenizer.eos_token_id is not None:
+            ids = ids + [self.tokenizer.eos_token_id]
+            mask = mask + [1]
+        res["input_ids"] = [ids]
+        res["attention_mask"] = [mask]
         return res
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 221 - 232,
The _tokenize method currently ignores add_eos_token and strip_bos_token; update
_tokenize to respect these flags: call self.tokenizer(prompt,
add_special_tokens=True) as before; then, if strip_bos_token is True,
bos_token_id is not None, and the first id equals self.tokenizer.bos_token_id,
remove that first id and its attention-mask entry; then, if add_eos_token is
True and eos_token_id is not None, append self.tokenizer.eos_token_id to the ids
and a 1 to attention_mask, otherwise leave them unchanged. Preserve the existing
no-truncation behavior and the outer-list wrapping of input_ids and
attention_mask, and keep returning the BatchEncoding.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/test_multimodal_streaming.py`:
- Around line 107-123: The test test_encode_counts_placeholders_on_full_text
doesn't trigger truncation; adjust the input so the last image placeholder is
only included when text is fully considered: either increase long_filler (e.g.,
"lorem ipsum " * 200) or reduce max_tokens (e.g., set max_tokens to a small
value like 20) in the call to encode_streaming_multimodal so that a truncating
implementation would drop the last placeholder; keep the assertion and
image_token/image_token_id usage the same.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6879bd64-1080-433c-a971-e8c6c0fa3a33

📥 Commits

Reviewing files that changed from the base of the PR and between 986a25a and a6def82.

📒 Files selected for processing (12)
  • docs/multimodal.qmd
  • src/axolotl/core/builders/causal.py
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/collators/mm_pretrain.py
  • src/axolotl/utils/data/sft.py
  • src/axolotl/utils/data/streaming.py
  • src/axolotl/utils/schemas/datasets.py
  • src/axolotl/utils/schemas/validation.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/test_multimodal_streaming.py
  • tests/utils/data/test_mm_cpt_eval.py
  • tests/utils/schemas/validation/test_multimodal_cpt.py
✅ Files skipped from review due to trivial changes (1)
  • docs/multimodal.qmd
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/core/builders/causal.py
  • src/axolotl/utils/data/sft.py

thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request Apr 25, 2026
  Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the
  happy path; expands schema, hardens fallbacks, tightens validation.

  Bug fixes
  ---------
  - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token`
    when they differ. Without this, MM CPT crashed on the first batch
    with "Prompt contained 0 image tokens".
  - `dispatch_batches: true` placeholder dataset mirrors the configured
    `image_column` so worker ranks don't KeyError on empty rows.
  - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`,
    `False`) instead of coercing to `[]`.
  - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of
    silently ignoring them.

  Schema
  ------
  - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset`
    (the documented `ds_type: json` shape now reaches `load_dataset`).
  - Preserve `trust_remote_code` through `_pretraining_config_from_entry`
    and pass to `load_dataset`.
  - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder +
    collator) with documented fallback to `cfg.sequence_len`.

  Validation (config-load time)
  -----------------------------
  - Reject mixed multimodal/text entries in `test_datasets`.
  - Reject MM `test_datasets` paired with non-MM training.
  - Reject non-MM `test_datasets` paired with MM training.
  - Removed the redundant runtime check in `sft.py`; the schema is now
    the single source of truth.

  Hardening / observability
  -------------------------
  - Mixed/all-text batch handling: collator routes all-text batches to
    the tokenizer (no `pixel_values`); mixed batches go through the
    processor as-is. Documented per-VLM compatibility (verified on
    SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL).
  - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`,
    `az://`, `azure://`, `hf://`) in image paths so users see the
    explicit "Non-local scheme" error instead of a confusing
    FileNotFoundError.
  - Warn at construction when `MultiModalPretrainDataCollator.tokenizer`
    is not `processor.tokenizer` (all-text vs image batches could
    otherwise tokenize the same text differently).
  - Warn at retry kickoff when a processor call fails on a batch, so
    users see why processing stalls during per-row diagnosis.
  - INFO log when `remove_unused_columns` is auto-set to `false` for
    MM CPT.
  - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass).
  - Clarify "exceeds sequence_len" error in both encoder paths to note
    image-patch expansion may push the final length higher.

  Code quality
  ------------
  - Lift `image_token_spec` into
    `MultimodalPretrainTokenizationStrategy.__init__` instead of
    post-construction monkey-patch + `type: ignore`.
  - Hoist `import importlib` out of the per-class loop.
  - Drop dead `n_chunks` multiplication; replace with explicit invariant
    assertion.
  - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and
    the user-facing pixel-cap error.

  Tests
  -----
  +15 regression tests across the four MM CPT suites covering: Gemma-3
  boi_token autodetect (with id-mapping assertion), `eval_sequence_len`
  on encoder + collator (set + unset-fallback), `trust_remote_code` and
  `ds_type` preservation, three modality-mismatch validation cases,
  tokenizer-mismatch warning, `remove_unused_columns` auto-set log,
  cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`,
  `test_multimodal_pretrain`, `test_multimodal_cpt`, and
  `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream
  pre-commit pin).
@thad0ctor
Contributor Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (3)
src/axolotl/prompt_strategies/multimodal_pretrain.py (1)

278-284: Convert the alignment invariant from assert to an explicit raise.

This assert guards a non-trivial correctness property: if _tokenize ever returns more than one chunk (e.g., a future change re-introduces overflow striding), images=[list(images)] and _mm_text=[text] will be silently misaligned with input_ids[1:], sending unpaired text into the model. With python -O (or any optimized deployment), assert statements are stripped, turning this into silent data corruption rather than a clear failure. The neighboring _tokenize doesn't pass truncation=True/stride=... today, but this guard is precisely meant to catch a future regression — so it shouldn't depend on __debug__.

♻️ Proposed fix
-        # `_tokenize` produces exactly one chunk; the assert keeps that
-        # invariant explicit so a future change there can't silently
-        # mis-align `images` / `_mm_text` against `input_ids`.
-        assert len(res["input_ids"]) == 1
+        # `_tokenize` produces exactly one chunk; this guard keeps the
+        # invariant explicit so a future change there can't silently
+        # mis-align `images` / `_mm_text` against `input_ids`. Use a real
+        # raise (not assert) so it survives `python -O`.
+        if len(res["input_ids"]) != 1:
+            raise RuntimeError(
+                "MultimodalPretrainTokenizationStrategy._tokenize produced "
+                f"{len(res['input_ids'])} chunks; multimodal CPT requires "
+                "exactly one chunk per row to keep `images`/`_mm_text` "
+                "aligned with `input_ids`."
+            )
         res["images"] = [list(images)]
         res["_mm_text"] = [text]
         return res
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 278 - 284,
Replace the fragile assert in multimodal_pretrain.py that checks
len(res["input_ids"]) == 1 with an explicit runtime check that raises a clear
exception (e.g., ValueError or RuntimeError) if the invariant is violated;
perform this check before mutating res by setting res["images"] and
res["_mm_text"], and include the actual length in the error message so callers
of the function (the code path around the _tokenize call and the block that
assigns res["images"] = [list(images)] and res["_mm_text"] = [text]) will fail
loudly instead of silently misaligning inputs under optimized Python.
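
To illustrate why the explicit raise matters, here is a small self-contained sketch. The function names are hypothetical and not taken from the PR. Python drops `assert` statements when run with `-O`, so only the second guard is guaranteed to fire in an optimized deployment.

```python
def attach_images_assert(res: dict, images: list, text: str) -> dict:
    # Stripped under `python -O`: a multi-chunk result would pass through.
    assert len(res["input_ids"]) == 1
    res["images"] = [list(images)]
    res["_mm_text"] = [text]
    return res


def attach_images_checked(res: dict, images: list, text: str) -> dict:
    # Ordinary control flow, so the check survives `python -O`.
    n_chunks = len(res["input_ids"])
    if n_chunks != 1:
        raise RuntimeError(f"expected exactly one chunk per row, got {n_chunks}")
    res["images"] = [list(images)]
    res["_mm_text"] = [text]
    return res
```

Both variants check the invariant before mutating `res`, so a failed check leaves the input untouched.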
src/axolotl/utils/collators/mm_pretrain.py (2)

116-145: O_NOFOLLOW only protects the final path component.

os.open(resolved, O_RDONLY | O_NOFOLLOW) closes the final-link TOCTOU window after realpath(), but parent directories along resolved can still be replaced with symlinks between realpath() and os.open(). Combined with the commonpath containment check, this is a reasonable trade-off for v1, but worth flagging in docs/multimodal.qmd so operators know the residual risk on shared/multi-tenant filesystems. A future hardening pass could use openat(2) semantics (in Python, os.open(..., dir_fd=...), since there is no os.openat()) on an image_base_dir fd to walk the path one component at a time.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 116 - 145, The
_open_image_hardened method uses os.open(..., O_NOFOLLOW) which only prevents
TOCTOU on the final path component; parent directories can still be swapped
after realpath() and before os.open(), so add a note to docs/multimodal.qmd
describing this residual risk on shared/multi-tenant filesystems and that the
current commonpath+O_NOFOLLOW approach is an acceptable v1 tradeoff; also add a
TODO to consider a future hardening that walks components from an image_base_dir
fd using os.open(..., dir_fd=...) (Python's interface to openat(2)) to eliminate
the parent-directory symlink race.
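
One possible shape for that future hardening, sketched under assumptions: it relies on `os.open(..., dir_fd=...)`, Python's interface to `openat(2)`, which is available on POSIX platforms only. The helper name `open_beneath` is hypothetical and this is not the PR's implementation. Each component is opened relative to the previous directory descriptor with `O_NOFOLLOW`, so a symlink swapped in at any level below the base makes the open fail instead of being followed. The base directory itself is assumed to be trusted.

```python
import os


def open_beneath(base_dir: str, rel_path: str) -> int:
    """Open rel_path under base_dir without following symlinks in rel_path.

    POSIX-only sketch; returns a read-only fd that the caller must close.
    """
    parts = [p for p in rel_path.split("/") if p not in ("", ".")]
    if not parts or ".." in parts:
        raise ValueError(f"unsafe relative path: {rel_path!r}")
    dir_fd = os.open(base_dir, os.O_RDONLY | os.O_DIRECTORY)
    try:
        for name in parts[:-1]:
            # O_NOFOLLOW makes the open fail if this component is a symlink.
            next_fd = os.open(
                name, os.O_RDONLY | os.O_DIRECTORY | os.O_NOFOLLOW, dir_fd=dir_fd
            )
            os.close(dir_fd)
            dir_fd = next_fd
        # The final component is opened relative to its verified parent.
        return os.open(parts[-1], os.O_RDONLY | os.O_NOFOLLOW, dir_fd=dir_fd)
    finally:
        os.close(dir_fd)
```

The returned descriptor could then be wrapped with `os.fdopen(fd, "rb")` before handing it to an image loader.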

233-256: Inconsistent return type vs the imaged path.

The all-text fallback returns dict(batch) while the imaged path at the bottom returns the BatchEncoding directly (line 337). Both are mapping-compatible, but BatchEncoding carries .to(device), .convert_to_tensors(), and other helpers that callers of the collator (HF Trainer._prepare_inputs, custom devices, accelerate handlers) sometimes rely on. Returning a plain dict from one branch and a BatchEncoding from the other is an avoidable surprise.

♻️ Proposed fix
             batch["labels"] = tok_labels
-            return dict(batch)
+            return batch
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 233 - 256, The
all-text branch converts the tokenizer's BatchEncoding to a plain dict before
returning, causing an inconsistent return type versus the imaged path; instead
of returning dict(batch) keep and return the original BatchEncoding object from
self.tokenizer so callers (e.g., Trainer._prepare_inputs, accelerate) retain
.to(), .convert_to_tensors(), etc. Locate the all-text branch where `batch =
self.tokenizer(**tok_kwargs)` and `batch["labels"] = tok_labels` are set and
replace `return dict(batch)` with `return batch`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 278-284: Replace the fragile assert in multimodal_pretrain.py that
checks len(res["input_ids"]) == 1 with an explicit runtime check that raises a
clear exception (e.g., ValueError or RuntimeError) if the invariant is violated;
perform this check before mutating res by setting res["images"] and
res["_mm_text"], and include the actual length in the error message so callers
of the function (the code path around the _tokenize call and the block that
assigns res["images"] = [list(images)] and res["_mm_text"] = [text]) will fail
loudly instead of silently misaligning inputs under optimized Python.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 31fbfd9a-34c9-4e3a-9046-c3ac9bf7a022

📥 Commits

Reviewing files that changed from the base of the PR and between a6def82 and c6cd14c.

📒 Files selected for processing (4)
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/collators/mm_pretrain.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/test_multimodal_streaming.py

@codecov

codecov Bot commented Apr 25, 2026

@thad0ctor
Contributor Author

thad0ctor commented Apr 25, 2026

@coderabbitai review. Your last suggestion introduced a regression in Gemma-4; it is resolved in the latest commit.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

@thad0ctor — apologies for the regression! I appreciate you flagging it and resolving it in the follow-up commit. Based on the commit messages, it looks like the Gemma-3/4 boi_token autodetect logic was affected — I'll make sure to keep that in mind for any future suggestions touching the image token spec detection path.

Triggering a review of the latest commit now.


✏️ Learnings added
Learnt from: thad0ctor
URL: https://github.com/axolotl-ai-cloud/axolotl/pull/3629

Timestamp: 2026-04-25T23:17:32.434Z
Learning: In `src/axolotl/prompt_strategies/multimodal_pretrain.py`, a previous CodeRabbit suggestion introduced a regression in Gemma-4 related to `boi_token` autodetection. For Gemma-3/4, `boi_token` must be preferred over `image_token` when they map to different token ids. The fix was applied in a follow-up commit. Treat any suggestions touching `build_image_token_spec` or Gemma-3/4 BOI token logic with extra caution.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: thad0ctor
Repo: axolotl-ai-cloud/axolotl PR: 3629
File: src/axolotl/utils/collators/mm_pretrain.py:234-282
Timestamp: 2026-04-25T21:40:48.932Z
Learning: In `src/axolotl/utils/collators/mm_pretrain.py`, the `MultiModalPretrainDataCollator` passes mixed batches (some rows with images, some text-only, i.e., `images=[[img1], [], [img2]]`) directly to `self.processor(text=texts, images=images, ...)`. This has been tested and confirmed to work with Gemma 3/4, Qwen 3.5, SmolVLM/SmolVLM2, and Qwen 2.5VL. A per-subset homogeneous fallback (splitting image rows vs text-only rows into sub-batches) is deferred to a future PR due to `pixel_values` shape variance across processors (Qwen2-VL uses `image_grid_thw`, others vary in ndim). Documentation of this limitation and tested processors is planned for `docs/multimodal.qmd`.
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

@thad0ctor — Apologies for the regression! I'll trigger a fresh review of the latest commit now so the fix is properly validated.


✏️ Learnings added
Learnt from: thad0ctor
URL: https://github.com/axolotl-ai-cloud/axolotl/pull/3629

Timestamp: 2026-04-25T23:17:45.124Z
Learning: In `src/axolotl/prompt_strategies/multimodal_pretrain.py`, the Gemma-3/Gemma-4 `boi_token` autodetect logic is sensitive. A prior CodeRabbit suggestion introduced a regression in Gemma-4 by incorrectly handling the `boi_token` vs `image_token` distinction during `build_image_token_spec`. Be conservative with suggestions that touch this autodetect path and always verify Gemma-3/Gemma-4 compatibility before recommending changes.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
src/axolotl/utils/collators/mm_pretrain.py (1)

268-313: Per-row retry can mislabel a text-only row as the offender in mixed batches.

The all-text bypass at line 234 already documents that some processors can't be called cleanly with images=[[]]. In the per-row retry on line 290, text-only rows are sent to self.processor(text=[t], images=[[]], ...) regardless. If the original batch failure was caused by an image row but a co-located text-only row also raises on standalone processor call, the loop's break on the first exception sets offender_idx to that text-only row and shadows the actual culprit in the diagnostic message. The original exception is still chained via from exc, so this is a misleading-error concern (not a correctness one), but it does undercut the value of the retry pinpointing.

A small mirror of the line-234 fast-path — routing text-only retry rows through self.tokenizer — keeps the retry diagnostic aligned with the production tokenization path without re-opening the deferred per-subset homogeneous-fallback discussion.

🛠️ Suggested narrow change to retry path
             for i, (t, imgs) in enumerate(zip(texts, images, strict=True)):
                 try:
-                    self.processor(text=[t], images=[imgs], **retry_kwargs)
+                    if len(imgs) == 0:
+                        # Mirror the all-text bypass; some processors raise on images=[[]].
+                        self.tokenizer(text=[t], **retry_kwargs)
+                    else:
+                        self.processor(text=[t], images=[imgs], **retry_kwargs)
                 except Exception as retry_exc:

Based on learnings: the collator is documented as passing mixed batches directly to the processor for the tested set (Gemma 3/4, Qwen 3.5, SmolVLM/SmolVLM2, Qwen 2.5VL), with per-subset fallback deferred to a future PR — this suggestion is intentionally scoped to the diagnostic-only retry path so it doesn't touch that deferred work.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 268 - 313, The
per-row retry loop in MultiModalPretrainDataCollator currently calls
self.processor(...) for every row, which can misattribute failures from image
rows to text-only rows; change the retry loop so that if a row is text-only
(e.g., imgs is empty or falsy / images element is []), call
self.tokenizer(text=[t], **retry_kwargs) instead of self.processor, otherwise
call self.processor(text=[t], images=[imgs], **retry_kwargs); keep retry_kwargs
(return_tensors, padding, pad_to_multiple_of) the same and do not alter the main
mixed-batch processing path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/test_multimodal_streaming.py`:
- Around line 188-193: The file was reformatted by ruff-format; update the
constructor call to match ruff's preferred style by collapsing the dict literal
onto one line for the _StubTokenizer(...) invocation and ensure the subsequent
_StubProcessor(...) call remains formatted correctly; specifically modify the
tok = _StubTokenizer({...}) line (and keep proc = _StubProcessor(...) as is) to
conform to ruff format rules, then re-run pre-commit or ruff format and commit
the change.

In `@tests/utils/schemas/validation/test_multimodal_cpt.py`:
- Around line 321-334: The test
test_remove_unused_columns_already_false_does_not_log incorrectly sets propagate
on the child logger; change the monkeypatch target to the parent logger
logging.getLogger("axolotl") so caplog can capture records emitted by
descendants; keep the rest of the test flow (create cfg with
_mm_cpt_cfg(remove_unused_columns=False), call validate_config(cfg), and assert
no "Auto-set" messages) so the negative assertion actually fails if the auto-set
log is emitted.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b0c90ff4-69ff-446b-b76d-13459369affe

📥 Commits

Reviewing files that changed from the base of the PR and between c6cd14c and 5a59904.

📒 Files selected for processing (4)
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/collators/mm_pretrain.py
  • tests/test_multimodal_streaming.py
  • tests/utils/schemas/validation/test_multimodal_cpt.py

Comment thread tests/test_multimodal_streaming.py Outdated
Comment thread tests/utils/schemas/validation/test_multimodal_cpt.py
@thad0ctor
Contributor Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented May 7, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/test_multimodal_streaming.py (1)

224-375: 💤 Low value

LGTM — eval-vs-train sequence-len fallback coverage is solid.

The three-case split (eval w/ eval_sequence_len set → 1024; train → 4096; eval w/ eval_sequence_len unset → falls back to 4096) directly encodes the documented schema contract and would loudly catch any regression that drops the fallback.

One small fragility note: monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", ...) only intercepts calls if streaming.py references functools.partial via attribute access — if anyone refactors to from functools import partial; partial(...), the patch becomes a silent no-op and these assertions still pass via the lambda passthrough. Worth a comment in the test or rebinding to axolotl.utils.data.streaming.partial if that import style is ever introduced.
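The import-style fragility can be reproduced in isolation. The sketch below builds two throwaway modules (hypothetical stand-ins for `streaming.py`) that differ only in how they import `partial`, then swaps `functools.partial` the way an attribute-path monkeypatch would.

```python
import functools
import types


def make_module(name: str, source: str) -> types.ModuleType:
    """Build an in-memory module from source (stand-in for streaming.py)."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    return module


attr_style = make_module(
    "attr_style",
    "import functools\ndef wrap(fn):\n    return functools.partial(fn, 1)\n",
)
from_style = make_module(
    "from_style",
    "from functools import partial\ndef wrap(fn):\n    return partial(fn, 1)\n",
)

calls = []
real_partial = functools.partial


def fake_partial(fn, *args, **kwargs):
    # Record the interception, then pass through to the real partial.
    calls.append(fn.__name__)
    return real_partial(fn, *args, **kwargs)


def identity(x):
    return x


functools.partial = fake_partial  # what patching "<module>.functools.partial" does
try:
    attr_style.wrap(identity)  # looks up functools.partial at call time
    from_style.wrap(identity)  # uses the name bound at import time
finally:
    functools.partial = real_partial
```

Only the attribute-access module is intercepted, so `calls` ends up with a single entry. The `from functools import partial` module bypasses the patch silently.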

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_multimodal_streaming.py` around lines 224 - 375, The test
monkeypatch targets "axolotl.utils.data.streaming.functools.partial", which only
works if streaming.py references functools via attribute access; change the
monkeypatch to bind to the actual symbol the module uses (e.g., patch
"axolotl.utils.data.streaming.partial" if streaming.py does "from functools
import partial", or keep the current target but add a comment warning about the
import style fragility), by updating the monkeypatch.setattr call in
test_wrap_streaming_dataset_uses_pretraining_config_arg and
test_wrap_streaming_dataset_eval_honors_eval_sequence_len to patch the exact
name exported/used by streaming.py (or use raising=False) so the fake_partial
reliably replaces the encoder partial regardless of import style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 117-126: In _open_image_hardened, if os.open succeeds but
os.fdopen raises, the raw fd is leaked; wrap the os.fdopen call in a try/except
that closes the fd on failure (or use contextlib.closing) and re-raise a
ValueError/appropriate exception with the original error as cause so the
descriptor is always closed; adjust the error handling around os.open and
os.fdopen in _open_image_hardened to ensure fd is closed when fdopen fails.
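A minimal sketch of the descriptor-safe pattern follows. `open_binary_no_leak` is a hypothetical name, and the real `_open_image_hardened` may differ in flags and error messages.

```python
import contextlib
import os


def open_binary_no_leak(path: str):
    """Open path for binary reading without leaking the raw fd on failure."""
    flags = os.O_RDONLY | getattr(os, "O_NOFOLLOW", 0)
    try:
        fd = os.open(path, flags)
    except OSError as exc:
        raise ValueError(f"cannot open image {path!r}: {exc}") from exc
    try:
        # On success the returned file object owns fd and closes it.
        return os.fdopen(fd, "rb")
    except Exception as exc:
        # fdopen failed, so nothing owns fd yet: close it ourselves.
        # suppress() covers the case where it was already closed.
        with contextlib.suppress(OSError):
            os.close(fd)
        raise ValueError(f"cannot wrap descriptor for {path!r}: {exc}") from exc
```

On Linux, passing a directory exercises the second branch: `os.open` succeeds, `os.fdopen` raises, and the descriptor is closed before the `ValueError` propagates.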


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a3c86584-ecd7-4bf6-890b-f60285e9be60

📥 Commits

Reviewing files that changed from the base of the PR and between 5a59904 and 5ceec56.

📒 Files selected for processing (6)
  • src/axolotl/utils/collators/mm_pretrain.py
  • src/axolotl/utils/schemas/validation.py
  • tests/conftest.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/test_multimodal_streaming.py
  • tests/utils/schemas/validation/test_multimodal_cpt.py
✅ Files skipped from review due to trivial changes (1)
  • tests/conftest.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/schemas/validation.py

Comment thread src/axolotl/utils/collators/mm_pretrain.py Outdated
thad0ctor added 9 commits May 14, 2026 15:14
Adds a streaming-first multimodal CPT path: raw `(text, images[])` rows
are tokenized once with a placeholder-count guardrail, batched through a
hardened collator, and fed to a VLM with image-family tokens masked out
of labels. Gated by `type: multimodal_pretrain` (or `multimodal: true`)
on a `pretraining_dataset` entry; works end-to-end for train and eval,
including multi-entry eval and mixed image/text batches.

Features
--------
- Streaming MM CPT encoder (`encode_streaming_multimodal`): counts
  placeholders by token id (not substring), enforces
  `placeholders == len(images)` per row, and rejects rows that exceed
  `sequence_len` instead of silently truncating mid-placeholder.
- MM CPT collator (`MultiModalPretrainDataCollator`): security-hardened
  image loader (path traversal / NUL byte / remote URL / multi-frame
  bomb / pixel cap rejection), per-row image cap, processor-call retry
  that pinpoints the offending row, and label-side masking of every
  image-family token id.
- Mixed image/text batches: text-only rows in a batch take a
  tokenizer-only fallback (no `pixel_values`); rows with images go
  through the processor as usual.
- Eval support: `test_datasets` accepts MM entries via a dedicated
  `MultiModalEvalDataset` model so per-entry `text_column` /
  `image_column` / `image_base_dir` / `image_token` survive validation.
  Multi-entry MM eval streams are concatenated.
- `dispatch_batches: true` support: non-main ranks get a placeholder
  dataset that mirrors the configured text + image columns.
- Config validation gates: `processor_type` required, `sample_packing:
  false` enforced, `chat_template` rejected, single
  `pretraining_dataset` entry required, MM eval entries must share
  `image_base_dir` / `image_token`, mixed MM/non-MM eval rejected,
  incompatible processor classes (Mllama, Pixtral, InternVL) rejected
  at startup. `remove_unused_columns` is auto-set to `false` with an
  INFO log.
- Docs: new section in `docs/multimodal.qmd` covering the YAML shape,
  placeholder-token table, eval contract, and supported/rejected model
  families.
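The placeholder guardrail and label masking described in this list can be sketched as below. The function names are hypothetical and the real encoder and collator operate on tokenizer output and tensors. Only the rules (count by token id, `placeholders == len(images)`, reject over-length rows, mask image-family ids in labels) are taken from the description.

```python
IGNORE_INDEX = -100  # label value that Hugging Face loss functions skip by convention


def validate_row(input_ids, images, image_token_id, sequence_len):
    """Reject rows whose placeholder count or length violates the contract."""
    # Count by token id, not by substring, so look-alike text cannot inflate it.
    n_placeholders = sum(1 for tok in input_ids if tok == image_token_id)
    if n_placeholders != len(images):
        raise ValueError(
            f"row has {n_placeholders} image placeholders but {len(images)} images"
        )
    if len(input_ids) > sequence_len:
        # Rejecting avoids truncating in the middle of a placeholder run.
        raise ValueError(
            f"row has {len(input_ids)} tokens, exceeds sequence_len={sequence_len}"
        )


def mask_image_labels(input_ids, image_family_ids):
    """Copy input_ids into labels, masking every image-family token id."""
    return [IGNORE_INDEX if tok in image_family_ids else tok for tok in input_ids]
```

For instance, with `image_family_ids = {98, 99}`, the ids `[1, 99, 2, 98]` become the labels `[1, -100, 2, -100]`.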

YAML example
------------
    base_model: HuggingFaceTB/SmolVLM-500M-Instruct
    processor_type: AutoProcessor

    pretraining_dataset:
      - path: /path/to/shards/*.jsonl
        ds_type: json
        type: multimodal_pretrain
        text_column: text
        image_column: images
        image_base_dir: /path/to/images

    streaming: true
    sequence_len: 2048
    sample_packing: false

Tests
-----
59 tests across four suites covering the encoder, collator (including
mixed/all-text batches and security gates), prompt strategy, schema
preservation, multi-entry eval merge, eval homogeneity validation,
eval-aware collator, dispatch-batches placeholder shape, and the
auto-set log record.
  Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the
  happy path; expands schema, hardens fallbacks, tightens validation.

  Bug fixes
  ---------
  - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token`
    when they differ (Gemma-3's `image_token` is the post-expansion soft
    token, not the user-facing placeholder). Without this, MM CPT crashed
    on the first batch with "Prompt contained 0 image tokens".
  - `dispatch_batches: true` placeholder dataset now mirrors the
    configured `image_column` so worker ranks don't KeyError on empty
    rows.
  - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`,
    `False`) instead of coercing to `[]` — keeps malformed rows from
    silently turning into text-only samples.

  Schema completeness
  -------------------
  - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset`
    (the documented `ds_type: json` shape now actually reaches
    `load_dataset`; previously dropped at validation).
  - Preserve `trust_remote_code` through `_pretraining_config_from_entry`
    and pass it to `load_dataset` (was silently dropped).
  - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator)
    with documented fallback to `cfg.sequence_len` when unset.

  Validation tightening (config-load time)
  ----------------------------------------
  - Reject mixed multimodal/text entries in `test_datasets`.
  - Reject MM `test_datasets` paired with non-MM training.
  - Reject non-MM `test_datasets` paired with MM training.
  - The redundant runtime check in `sft.py` is removed; schema is the
    single source of truth.

  Hardening / observability
  -------------------------
  - Mixed/all-text batch handling: collator routes all-text batches to
    the tokenizer (no `pixel_values`); mixed batches go through the
    processor as-is. Documented per-VLM compatibility (verified on
    SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL).
  - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`,
    `azure://`, `hf://`) in image paths so users see "Non-local scheme"
    instead of a confusing FileNotFoundError.
  - Warn when `MultiModalPretrainDataCollator.tokenizer is not
    processor.tokenizer` (all-text vs image batches could otherwise
    tokenize the same text differently).
  - Warn at retry kickoff when a processor call fails on a batch, so
    users see why processing stalls during per-row diagnosis.
  - INFO log when `remove_unused_columns` is auto-set to `false` for
    MM CPT.
  - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass).
  - Clarify "exceeds sequence_len" error to note image-patch expansion
    may push the final length higher.

  Tests
  -----
  +8 regression tests across the four MM CPT suites covering: Gemma-3
  boi_token autodetection, eval_sequence_len (encoder + collator,
  including the fallback case), `trust_remote_code` and `ds_type`
  preservation through validation, three modality-mismatch validation
  cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass
  across `tests/test_multimodal_streaming.py`,
  `tests/prompt_strategies/test_multimodal_pretrain.py`,
  `tests/utils/schemas/validation/test_multimodal_cpt.py`,
  `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff
  v0.15.8 (upstream pre-commit pin).
  Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the
  happy path; expands schema, hardens fallbacks, tightens validation.

  Bug fixes
  ---------
  - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token`
    when they differ. Without this, MM CPT crashed on the first batch
    with "Prompt contained 0 image tokens".
  - `dispatch_batches: true` placeholder dataset mirrors the configured
    `image_column` so worker ranks don't KeyError on empty rows.
  - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`,
    `False`) instead of coercing to `[]`.
  - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of
    silently ignoring them.

  Schema
  ------
  - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset`
    (the documented `ds_type: json` shape now reaches `load_dataset`).
  - Preserve `trust_remote_code` through `_pretraining_config_from_entry`
    and pass to `load_dataset`.
  - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder +
    collator) with documented fallback to `cfg.sequence_len`.

  Validation (config-load time)
  -----------------------------
  - Reject mixed multimodal/text entries in `test_datasets`.
  - Reject MM `test_datasets` paired with non-MM training.
  - Reject non-MM `test_datasets` paired with MM training.
  - Removed the redundant runtime check in `sft.py`; the schema is now
    the single source of truth.

  Hardening / observability
  -------------------------
  - Mixed/all-text batch handling: collator routes all-text batches to
    the tokenizer (no `pixel_values`); mixed batches go through the
    processor as-is. Documented per-VLM compatibility (verified on
    SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL).
  - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`,
    `az://`, `azure://`, `hf://`) in image paths so users see the
    explicit "Non-local scheme" error instead of a confusing
    FileNotFoundError.
  - Warn at construction when `MultiModalPretrainDataCollator.tokenizer`
    is not `processor.tokenizer` (all-text vs image batches could
    otherwise tokenize the same text differently).
  - Warn at retry kickoff when a processor call fails on a batch, so
    users see why processing stalls during per-row diagnosis.
  - INFO log when `remove_unused_columns` is auto-set to `false` for
    MM CPT.
  - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass).
  - Clarify "exceeds sequence_len" error in both encoder paths to note
    image-patch expansion may push the final length higher.

  Code quality
  ------------
  - Lift `image_token_spec` into `MultimodalPretrainTokenizationStrategy.
    __init__` instead of post-construction monkey-patch + `type: ignore`.
  - Hoist `import importlib` out of the per-class loop.
  - Drop dead `n_chunks` multiplication; replace with explicit invariant
    assertion.
  - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and
    the user-facing pixel-cap error.

  Tests
  -----
  +15 regression tests across the four MM CPT suites covering: Gemma-3
  boi_token autodetect (with id-mapping assertion), `eval_sequence_len`
  on encoder + collator (set + unset-fallback), `trust_remote_code` and
  `ds_type` preservation, three modality-mismatch validation cases,
  tokenizer-mismatch warning, `remove_unused_columns` auto-set log,
  cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`,
  `test_multimodal_pretrain`, `test_multimodal_cpt`, and
  `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream
  pre-commit pin).
  - mm_pretrain.py: return BatchEncoding (not dict) from all-text branch so
    it matches the imaged path.
  - test_multimodal_cpt.py, test_multimodal_streaming.py: monkeypatch
    axolotl logger propagate=True so caplog can capture records (axolotl's
    logging config sets propagate=False, blocking root capture in CI).
  multimodal_pretrain.py: scope the boi_token swap in build_image_token_spec
  to processors whose `image_token` name contains "soft_token" (the Gemma-3
  convention). Without this, Gemma-4 (`image_token=<|image|>`,
  `boi_token=<|image>`) gets the wrong placeholder autodetected and every
  row fails validation with a 0-vs-N placeholder/image mismatch.
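The scoping rule can be sketched with stub processors. `pick_placeholder` is a hypothetical reduction of `build_image_token_spec`. The Gemma-4 token strings are the ones quoted in this commit message, and the Gemma-3 strings are illustrative stub values.

```python
from types import SimpleNamespace


def pick_placeholder(processor) -> str:
    """Choose the user-facing image placeholder for a processor (sketch)."""
    image_token = getattr(processor, "image_token", None)
    boi_token = getattr(processor, "boi_token", None)
    if image_token is None:
        raise ValueError("no image placeholder could be autodetected")
    # Swap to boi_token only for the Gemma-3 convention, where image_token
    # names the post-expansion soft token, not the placeholder users type.
    if boi_token and boi_token != image_token and "soft_token" in image_token:
        return boi_token
    return image_token


# Stub processors: attribute values are assumptions for illustration.
gemma3_like = SimpleNamespace(image_token="<image_soft_token>", boi_token="<start_of_image>")
gemma4_like = SimpleNamespace(image_token="<|image|>", boi_token="<|image>")
```

Under this rule the Gemma-3-like stub resolves to its `boi_token`, while the Gemma-4-like stub keeps its `image_token`.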

  test_multimodal_streaming.py: 8 new tests
  - Two for the new autodetection behavior (Gemma-4 keeps image_token,
    Gemma-3 still swaps to boi_token), using stub processors.
  - Three branch-coverage tests for build_image_token_spec failure modes:
    override not registered as special token, override resolves to unk,
    nothing autodetectable.
  - Three collator-path tests: skip_bad_images drops a row and continues,
    all-rows-dropped surfaces a RuntimeError, multi-frame GIF triggers
    the animation-bomb guard via _open_image_hardened.
  fix(test): patch parent `axolotl` logger so negative caplog assertion has teeth

  The previous monkeypatch targeted `axolotl.utils.schemas.validation`, which
  is already propagate=True by inheritance — the actual block sits one level
  up at the `axolotl` logger (propagate=False from logging_config.py). The
  result: caplog never received any records, and `assert not any("Auto-set"
  ... in caplog.records)` would have passed even if the regression fired.

  Mirror the positive test by flipping propagate on `logging.getLogger("axolotl")`
  and add a comment explaining why the leaf isn't the right target.
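The propagation behaviour behind this fix can be reproduced with the standard library alone. In the sketch below, `demo_pkg` is a hypothetical stand-in for the `axolotl` logger tree, and a list-backed handler on the root logger plays the role of `caplog`.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted messages, standing in for pytest's caplog."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def capture_at_root(flip_logger_name: str) -> list:
    """Flip propagate=True on one logger and report what reaches the root."""
    parent = logging.getLogger("demo_pkg")
    child = logging.getLogger("demo_pkg.schemas.validation")
    parent.propagate = False  # mimics the package-level logging config
    child.setLevel(logging.INFO)
    handler = ListHandler()
    root = logging.getLogger()
    root.addHandler(handler)
    target = logging.getLogger(flip_logger_name)
    original = target.propagate
    target.propagate = True
    try:
        child.info("Auto-set remove_unused_columns=False")
    finally:
        target.propagate = original
        root.removeHandler(handler)
    return handler.messages
```

Flipping the leaf logger captures nothing, because the record still stops at the parent. Flipping the parent lets it through, so only the second variant gives a negative assertion any teeth.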
…ry loop

  Some HF processors reject `images=[[]]`, which made the per-row retry
  flag innocent text-only rows as the offender. Mirror the all-text
  bypass — diagnostic-only path, mainline unchanged.
MultiModalPretrainDataCollator.torch_call calls processor(text=...) which
re-tokenizes _mm_text from scratch, discarding the EOS that
encode_streaming_multimodal appended to input_ids. Without this, labels
never contain EOS at end-of-document and the model never learns to emit a
stop token — symptoms: non-terminating / repetitive generation. Match the
text CPT contract (encode_streaming keeps EOS in both input_ids and
labels) by appending EOS to _mm_text, idempotently, gated on a new
add_eos_token field (default True).
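The idempotent append can be sketched as a small pure function. `append_eos` is a hypothetical name. Only the `add_eos_token` flag and its default of `True` come from the commit message.

```python
def append_eos(text: str, eos_token: str, add_eos_token: bool = True) -> str:
    """Append eos_token to text unless disabled or already present at the end."""
    if not add_eos_token or not eos_token:
        return text
    if text.endswith(eos_token):
        # Idempotent: a second pass must not produce a doubled EOS.
        return text
    return text + eos_token
```

Because the append happens on the text side, the EOS survives when the processor re-tokenizes `_mm_text` from scratch.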
Comment thread src/axolotl/utils/schemas/datasets.py Outdated
Comment thread src/axolotl/prompt_strategies/multimodal_pretrain.py Outdated
Comment thread src/axolotl/utils/collators/mm_pretrain.py Outdated
@ved1beta
Member

No blockers, mostly bloat-trim.

@thad0ctor
Contributor Author

No blockers, mostly bloat-trim.

I just committed the cleanup based on your feedback and stripped out the associated tests.

@thad0ctor thad0ctor requested a review from ved1beta May 15, 2026 06:28
thad0ctor added 2 commits May 14, 2026 23:34
Aligns mm_pretrain.py with mm_chat.py's image-loading posture. Drops
NUL/URL/path-traversal/pixel-cap/multi-frame/per-row-count guards that
defended against threat models that don't apply to a CLI trainer
loading its own dataset. Routes image loading through
transformers.image_utils.load_image, matching the chat path.

Keeps image_base_dir join, skip_bad_images, label masking, processor
compatibility check, and the tokenizer/processor mismatch warning.
@thad0ctor
Contributor Author

@ved1beta sorry, I forgot to tag you in my last response.

@ved1beta
Member

lint please

@thad0ctor
Contributor Author

> lint please

@ved1beta just committed lint fixes

@thad0ctor
Contributor Author

@ved1beta I'm assuming the test failure is unrelated to this PR, since it previously passed and only failed after syncing with main

Comment thread src/axolotl/utils/schemas/datasets.py Outdated
Comment thread docs/multimodal.qmd
The non-streaming `datasets:` MM CPT route was never wired through
`build_collator`, which only routes MM batches under the pretraining
branch — `datasets:` entries would emit `images`/`_mm_text` rows into a
text-only collator. Strip the strategy class + `load()` and their unit
tests; keep `ImageTokenSpec`, `build_image_token_spec`, and
`check_processor_compatibility` since the streaming collator imports
them. Add a docs callout that only the streaming `pretraining_dataset`
route is currently wired.

Fold `MultiModalEvalDataset` into `PretrainingDataset` via inheritance;
the only intentional divergence is the `type` default and the
`_require_mm_markers` validator. Drops ~60 lines of duplicated `Field`
declarations the reviewer flagged.
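
The inheritance fold can be illustrated with the sketch below. It uses stdlib dataclasses in place of the project's Pydantic models so that it runs without extra dependencies. The field defaults and the exact rule inside `_require_mm_markers` are assumptions, inferred from the PR description's "`type: multimodal_pretrain` (or `multimodal: true`)" gate.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PretrainingDataset:
    # Field names follow the PR description; defaults are assumed.
    path: str
    type: str = "pretrain"
    multimodal: bool = False
    text_column: str = "text"
    image_column: str = "images"
    image_base_dir: Optional[str] = None
    image_token: Optional[str] = None


@dataclass
class MultiModalEvalDataset(PretrainingDataset):
    # Divergence 1: a different default for `type`.
    type: str = "multimodal_pretrain"

    def __post_init__(self) -> None:
        self._require_mm_markers()

    def _require_mm_markers(self) -> None:
        # Divergence 2: the entry must be marked as multimodal.
        if self.type != "multimodal_pretrain" and not self.multimodal:
            raise ValueError(
                "eval entry must set type: multimodal_pretrain or multimodal: true"
            )
```

Every shared field is declared once on the base class, which is the duplication the fold removes.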

Tighten the collator `KeyError` message to mention only
`encode_streaming_multimodal` now that the strategy class is gone.
@thad0ctor
Contributor Author

@ved1beta Pushed the carve-out (389e0f2) which strips the non-streaming strategy class + load() and associated unit tests.

Folded MultiModalEvalDataset into PretrainingDataset via inheritance (only the type default and the _require_mm_markers validator differ now) and added a docs callout that only the streaming pretraining_dataset route is currently wired.

Impacted tests pass (60/60), and a multi-step test on a local Qwen3-VL completes end-to-end on the streaming MM CPT path.

@ved1beta
Member

Mostly good. A config using type: multimodal_pretrain under datasets: will now fail at strategy load time with a raw AttributeError. Make it a small ValueError("multimodal_pretrain is only supported via pretraining_dataset with streaming: true — see docs/multimodal.qmd").

@ved1beta ved1beta self-requested a review May 25, 2026 05:08
Member

@ved1beta ved1beta left a comment

Looks good after the ValueError change. Requesting Nano for a quick look in case I missed anything. LGTM.

@ved1beta ved1beta requested a review from NanoCode012 May 25, 2026 05:10
…datasets:

Previously failed with a raw AttributeError at strategy load time. Now
raises a small ValueError pointing users to the supported entry point.
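
A guard of the kind this commit describes could look like the following sketch. The function name `check_dataset_type` and its `section` argument are hypothetical; only the intent (raise a `ValueError` that points to `pretraining_dataset` instead of surfacing a raw `AttributeError`) comes from the review thread.

```python
def check_dataset_type(dataset_type: str, section: str) -> str:
    """Reject multimodal_pretrain entries placed under `datasets:`."""
    if section == "datasets" and dataset_type == "multimodal_pretrain":
        raise ValueError(
            "multimodal_pretrain is only supported via pretraining_dataset "
            "with streaming: true; see docs/multimodal.qmd"
        )
    return dataset_type
```

Failing early with a message that names the supported entry point saves users from decoding an attribute lookup error raised deep inside strategy loading.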
@thad0ctor
Contributor Author

thad0ctor commented May 25, 2026

> looks good after the val error , requesting nano for quick look if i missed anything LGTM

Thank you, I just added the ValueError per your request.

I do have a follow-up PR planned to address the non-streaming path and improve resumption speed on the streaming path. @NanoCode012
