feat: add multimodal continued pre-training (raw image+text)#3629
feat: add multimodal continued pre-training (raw image+text)#3629thad0ctor wants to merge 18 commits into
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds end-to-end multimodal continued-pretraining (CPT): new ImageTokenSpec and multimodal tokenization strategy, streaming multimodal encoder, MM pretrain data collator with hardened image handling and label masking, schema/config validation gates, causal builder routing to MM collator, documentation, and comprehensive tests. ChangesMultimodal CPT feature
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (7)
tests/utils/data/test_mm_cpt_eval.py (1)
93-93: Nit: prefix unused unpacked variable.
trainis never read; rename to_trainto silence RuffRUF059and signal intent.♻️ Suggested change
- train, eval_ds, _, _ = _prepare_streaming_dataset( + _train, eval_ds, _, _ = _prepare_streaming_dataset( cfg, tokenizer=None, processor=None )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/data/test_mm_cpt_eval.py` at line 93, Unpack the return of _prepare_streaming_dataset using a prefixed underscore for the unused variable so the intent is clear and Ruff RUF059 is silenced: change the left-hand side of the call that currently uses "train, eval_ds, _, _ = _prepare_streaming_dataset(...)" to use "_train, eval_ds, _, _ = _prepare_streaming_dataset(...)" so the unused first value is named _train while keeping eval_ds and the other placeholders unchanged.src/axolotl/utils/collators/mm_pretrain.py (3)
211-233: All-text fallback usesself.tokenizer, but image batches go throughself.processor.When all rows are text-only, this path bypasses the processor and tokenizes via
self.tokenizerdirectly. That implicitly assumesself.tokenizer is self.processor.tokenizer; if a caller ever constructs the collator with a different tokenizer (e.g. a wrapped one that adds chat scaffolding), text-only batches and image batches will tokenize differently for the same input. Two low-cost mitigations:
- Assert the invariant in
__post_init__:if getattr(self.processor, "tokenizer", None) is not self.tokenizer: LOG.warning( "MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer; " "all-text and image batches may tokenize inconsistently." )- Or route the all-text path through
self.processor(text=texts, ...)if the processor supports text-only invocation.Not a correctness bug today, but it's a footgun that won't surface until someone wires it differently.
I found an important concern with the mixed-batch path — let me add that to the review.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 211 - 233, The all-text fallback uses self.tokenizer while image batches go through self.processor, which can cause inconsistent tokenization; add a check in MultiModalPretrainDataCollator.__post_init__ to warn if getattr(self.processor, "tokenizer", None) is not self.tokenizer (log a clear warning mentioning potential inconsistent tokenization between all-text and image batches), and/or change the all-text branch in the collate code to call self.processor(text=texts, return_tensors=self.return_tensors, padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) when the processor supports a text-only invocation so both paths use the same tokenizer pipeline.
245-282: Retry path issues N additional processor calls per failed batch.When a batch fails, the loop replays each row individually until the first reproducer is found. For a 32-row batch with the offender at row 31, that's 31 successful processor calls before the diagnostic — non-trivial on GPU/CPU when the processor does real image preprocessing. Suggest at least logging the retry kick-off at WARNING so users see why a single failure triggers a noticeable stall, and consider bisecting (binary search) instead of linear scan.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 245 - 282, The retry loop inside MultiModalPretrainDataCollator that replays each row individually (the for-loop that calls self.processor(...) to set offender_idx) can cause N extra heavy processor calls; before starting the per-row retry, emit a processLogger.warning noting the retry/bisect kickoff and batch size so users see why processing stalls, and replace the linear scan with a binary-search (bisect) strategy that tests halves of the remaining index range by calling self.processor(text=[...], images=[...], **retry_kwargs) to locate the offending index in O(log N) calls instead of O(N) — keep offender_idx, retry_ok, and retry_kwargs semantics and preserve the existing RuntimeError message when locating or failing to locate the offender.
61-68: Closed scheme list misses cloud/HF URIs (s3://, gs://, az://, hf://, …).The current allowlist rejects
http(s),ftp(s),file,data, and UNC, buts3://bucket/key.png,gs://…,azure://…,hf://…would not match. Withoutimage_base_dirthey'd fall through toos.openand fail with a confusingFileNotFoundErrorrather than the explicit “Non-local image path scheme” error users (and tests) expect. Prefer a generic URI-scheme detector — anything matching^[a-z][a-z0-9+.-]*:(case-insensitive) other than a Windows drive letter (C:) is non-local.🛡️ Suggested change
+import re + +# Generic URI scheme: <scheme>:<...> per RFC 3986 (excluding Windows drive letters). +_URI_SCHEME_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9+.\-]*:") + ... - p_lower = p.lower() - if p_lower.startswith( - ("http://", "https://", "ftp://", "ftps://", "file://", "data:") - ) or p.startswith(("\\\\", "//")): + m = _URI_SCHEME_RE.match(p) + # Exclude Windows drive letters like "C:\..." (single-letter scheme). + is_uri = m is not None and len(m.group(0).rstrip(":")) > 1 + if is_uri or p.startswith(("\\\\", "//")): raise ValueError( f"Non-local image path scheme is not supported in v1 " f"multimodal CPT (got {p!r})." )If you adopt this, also add
s3://…/gs://…totest_collator_rejects_remote_urlsso the contract is locked down.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 61 - 68, The current scheme check in mm_pretrain (using p and p_lower) misses cloud/third-party URIs like s3://, gs://, az://, hf://; replace the explicit startswith list with a generic URI-scheme detector: if p matches the regex ^[a-z][a-z0-9+.-]*: (case-insensitive) treat it as non-local, but exempt Windows drive letters (e.g., patterns like C: followed by slash/backslash). Update the code path that raises the ValueError (the block that currently raises "Non-local image path scheme..." when p_lower.startswith(...)) to use this new check, and add s3:// and gs:// (and other cloud URIs as needed) to test_collator_rejects_remote_urls to lock down the behavior.src/axolotl/prompt_strategies/multimodal_pretrain.py (3)
244-249: Note:len(ids) > self.max_lengthis a text-side check; image-token expansion happens later.The encoder rejects rows whose text tokenizes beyond
sequence_len, but each placeholder expands at the processor into many patch tokens, so the actual model input can still exceedsequence_len. The collator warns post-hoc (mm_pretrain.py:286-292), which is fine — but consider clarifying the error here so users don't conclude the budget refers to post-expansion length. Optional wording tweak: append “(text-side; image patch expansion happens at the processor and will further inflate the sequence)”.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 244 - 249, The ValueError message raised when checking if len(ids) > self.max_length is misleading because it only refers to the text-side token count while image placeholders expand into many patch tokens later; update the error text in the multimodal_pretrain.py check (the block that examines ids and self.max_length) to clearly state this is a text-side limit and that image patch expansion at the processor can further inflate the sequence (e.g., append “(text-side; image patch expansion happens at the processor and will further inflate the sequence)”), so users don't mistake this for the post-expansion budget.
19-35: Hoistimport importlibout of the per-class loop.Cheap import-cache hit, but this is a module-level helper running at import time and it's clearer to put
import importlibat the top of the module alongside the other stdlib imports.♻️ Suggested change
-def _get_incompatible_processor_classes() -> tuple[type, ...]: - classes: list[type] = [] - for mod_path, name in ( - ("transformers.models.mllama", "MllamaProcessor"), - ("transformers.models.pixtral", "PixtralProcessor"), - ("transformers.models.internvl", "InternVLProcessor"), - ): - try: - import importlib - - mod = importlib.import_module(mod_path) +def _get_incompatible_processor_classes() -> tuple[type, ...]: + import importlib + + classes: list[type] = [] + for mod_path, name in ( + ("transformers.models.mllama", "MllamaProcessor"), + ("transformers.models.pixtral", "PixtralProcessor"), + ("transformers.models.internvl", "InternVLProcessor"), + ): + try: + mod = importlib.import_module(mod_path) cls = getattr(mod, name, None) if cls is not None: classes.append(cls) except ImportError: continue return tuple(classes)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 19 - 35, The helper _get_incompatible_processor_classes currently imports importlib inside the loop for each (mod_path, name) tuple; move the importlib import to the module top (with other stdlib imports) and remove the per-iteration import so the loop simply uses the already-imported importlib to import modules by mod_path and getattr the class name, keeping logic around mod = importlib.import_module(mod_path), cls = getattr(mod, name, None), classes.append(cls) and the ImportError handling unchanged.
109-117: Don't silently swallow errors fromget_added_vocab().
except Exception: passmasks tokenizer bugs (e.g. AttributeError on a custom wrapper) and silently degrades the special-token set used to validate overrides and build the family mask. At minimum, log at DEBUG so operators can correlate “image_token override rejected as not a registered special token” with the underlying cause.♻️ Suggested change
- known_special_tokens: set[str] = set() - try: - known_special_tokens |= set(tokenizer.get_added_vocab().keys()) - except Exception: - pass + known_special_tokens: set[str] = set() + try: + known_special_tokens |= set(tokenizer.get_added_vocab().keys()) + except Exception as exc: # noqa: BLE001 + LOG.debug("get_added_vocab() failed on %s: %s", type(tokenizer).__name__, exc) known_special_tokens |= set(getattr(tokenizer, "all_special_tokens", None) or [])🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 109 - 117, The try/except around tokenizer.get_added_vocab() is silently swallowing errors which hides tokenizer bugs and causes silent degradation of known_special_tokens; replace the bare except with catching Exception as e and emit a DEBUG-level log including the exception details and context (e.g. "failed to call tokenizer.get_added_vocab()") before continuing so known_special_tokens is still populated from getattr(tokenizer, 'all_special_tokens'...) and getattr(tokenizer, 'additional_special_tokens'...) as before; reference the tokenizer.get_added_vocab() call and the known_special_tokens update in multimodal_pretrain.py and use the module or existing logger (logger or logging.getLogger(__name__)) to record the exception.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 234-282: The collator (MultiModalPretrainDataCollator) currently
passes mixed `images` (some empty lists) directly to `self.processor(texts,
images=images...)`, which may break or be undefined for many VLM processors;
modify the collator's processing path (where `batch =
self.processor(**proc_kwargs)` is called, e.g., inside the torch_call /
processor invocation) to detect mixed batches (any(len(im)==0 for im in images)
and not all(len(im)==0 for im in images)) and handle them by splitting into
homogeneous sub-batches (text-only rows vs image-containing rows) or falling
back to per-row `self.processor(text=[t], images=[imgs], **retry_kwargs)`
aggregation of the returned tensors into a single batched output, ensuring
shapes/padding/return_tensors remain consistent; additionally add doc notes in
docs/multimodal.qmd listing which processors support mixed batches or require
homogeneous input and expand tests to include other VLM processors to assert
behavior.
In `@src/axolotl/utils/data/sft.py`:
- Around line 211-226: The multimodal normalization helpers drop the
trust_remote_code flag so load_dataset never receives it; update
_pretraining_config_from_entry to include "trust_remote_code":
entry.get("trust_remote_code") in the returned DictDefault (and make the same
change in the analogous eval helper referenced around lines 278-284) so
MultiModalEvalDataset's trust_remote_code is preserved and propagated to
load_dataset.
- Around line 167-204: The multimodal eval branch reuses the training
pretraining config and therefore ignores cfg.eval_sequence_len; fix by, for each
multimodal eval entry (in the loop that builds eval_streams), after eval_config
= _pretraining_config_from_entry(entry) override the sequence length used for
evaluation by setting eval_config.sequence_len = cfg.eval_sequence_len (or the
appropriate eval length field if the pretraining config uses a different
attribute name) when cfg.eval_sequence_len is present, then pass that
eval_config into _load_streaming_dataset so the streaming eval loader uses the
intended eval length.
In `@src/axolotl/utils/schemas/datasets.py`:
- Around line 241-307: PretrainingDataset and MultiModalEvalDataset are missing
the documented ds_type loader hint so configs using
pretraining_dataset[].ds_type: json are dropped; add a ds_type: str | None =
Field(default=None, json_schema_extra={"description":"Loader type for the
dataset (e.g. 'json', 'csv') used to choose the dataset loader"}) to both
PretrainingDataset and MultiModalEvalDataset (and any other dataset model
variants that accept file-backed inputs) so the schema preserves the loader hint
during parsing; reference the PretrainingDataset and MultiModalEvalDataset class
definitions and add the Field in each model.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1413-1428: The code currently only ensures shared multimodal
fields across multimodal test entries but allows mixing modalities between train
and test or within test entries; detect modality per entry using _entry_is_mm
and reject mixed modalities: compute mm_train/text_train from
data.get("train_datasets") and mm_test/text_test from data.get("test_datasets")
(use the same isinstance(dict) + _entry_is_mm checks used for mm_test), then
raise a ValueError if (mm_train and text_test) or (mm_test and text_train) to
block train/eval modality mismatches, and also raise if mm_test and any
text-only entries exist in test_datasets to prevent mixed test entries; include
clear messages referencing train_datasets/test_datasets and multimodal vs
text-only to guide the user.
In `@tests/utils/schemas/validation/test_multimodal_cpt.py`:
- Around line 250-258: CI failed because ruff-format reformatted
tests/utils/schemas/validation/test_multimodal_cpt.py; re-run the formatter and
commit the changes. Run the formatter (e.g., pre-commit run --all-files or ruff
format .), review the updated formatting for the test that references caplog,
matches, logging.INFO and "remove_unused_columns", and commit the resulting file
so the lint hook is satisfied.
---
Nitpick comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 244-249: The ValueError message raised when checking if len(ids) >
self.max_length is misleading because it only refers to the text-side token
count while image placeholders expand into many patch tokens later; update the
error text in the multimodal_pretrain.py check (the block that examines ids and
self.max_length) to clearly state this is a text-side limit and that image patch
expansion at the processor can further inflate the sequence (e.g., append
“(text-side; image patch expansion happens at the processor and will further
inflate the sequence)”), so users don't mistake this for the post-expansion
budget.
- Around line 19-35: The helper _get_incompatible_processor_classes currently
imports importlib inside the loop for each (mod_path, name) tuple; move the
importlib import to the module top (with other stdlib imports) and remove the
per-iteration import so the loop simply uses the already-imported importlib to
import modules by mod_path and getattr the class name, keeping logic around mod
= importlib.import_module(mod_path), cls = getattr(mod, name, None),
classes.append(cls) and the ImportError handling unchanged.
- Around line 109-117: The try/except around tokenizer.get_added_vocab() is
silently swallowing errors which hides tokenizer bugs and causes silent
degradation of known_special_tokens; replace the bare except with catching
Exception as e and emit a DEBUG-level log including the exception details and
context (e.g. "failed to call tokenizer.get_added_vocab()") before continuing so
known_special_tokens is still populated from getattr(tokenizer,
'all_special_tokens'...) and getattr(tokenizer, 'additional_special_tokens'...)
as before; reference the tokenizer.get_added_vocab() call and the
known_special_tokens update in multimodal_pretrain.py and use the module or
existing logger (logger or logging.getLogger(__name__)) to record the exception.
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 211-233: The all-text fallback uses self.tokenizer while image
batches go through self.processor, which can cause inconsistent tokenization;
add a check in MultiModalPretrainDataCollator.__post_init__ to warn if
getattr(self.processor, "tokenizer", None) is not self.tokenizer (log a clear
warning mentioning potential inconsistent tokenization between all-text and
image batches), and/or change the all-text branch in the collate code to call
self.processor(text=texts, return_tensors=self.return_tensors,
padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) when the
processor supports a text-only invocation so both paths use the same tokenizer
pipeline.
- Around line 245-282: The retry loop inside MultiModalPretrainDataCollator that
replays each row individually (the for-loop that calls self.processor(...) to
set offender_idx) can cause N extra heavy processor calls; before starting the
per-row retry, emit a processLogger.warning noting the retry/bisect kickoff and
batch size so users see why processing stalls, and replace the linear scan with
a binary-search (bisect) strategy that tests halves of the remaining index range
by calling self.processor(text=[...], images=[...], **retry_kwargs) to locate
the offending index in O(log N) calls instead of O(N) — keep offender_idx,
retry_ok, and retry_kwargs semantics and preserve the existing RuntimeError
message when locating or failing to locate the offender.
- Around line 61-68: The current scheme check in mm_pretrain (using p and
p_lower) misses cloud/third-party URIs like s3://, gs://, az://, hf://; replace
the explicit startswith list with a generic URI-scheme detector: if p matches
the regex ^[a-z][a-z0-9+.-]*: (case-insensitive) treat it as non-local, but
exempt Windows drive letters (e.g., patterns like C: followed by
slash/backslash). Update the code path that raises the ValueError (the block
that currently raises "Non-local image path scheme..." when
p_lower.startswith(...)) to use this new check, and add s3:// and gs:// (and
other cloud URIs as needed) to test_collator_rejects_remote_urls to lock down
the behavior.
In `@tests/utils/data/test_mm_cpt_eval.py`:
- Line 93: Unpack the return of _prepare_streaming_dataset using a prefixed
underscore for the unused variable so the intent is clear and Ruff RUF059 is
silenced: change the left-hand side of the call that currently uses "train,
eval_ds, _, _ = _prepare_streaming_dataset(...)" to use "_train, eval_ds, _, _ =
_prepare_streaming_dataset(...)" so the unused first value is named _train while
keeping eval_ds and the other placeholders unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4a5ecec3-baad-4185-89f4-7f39775023c8
📒 Files selected for processing (14)
docs/multimodal.qmdsrc/axolotl/core/builders/causal.pysrc/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_pretrain.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/streaming.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/datasets.pysrc/axolotl/utils/schemas/validation.pytests/conftest.pytests/prompt_strategies/test_multimodal_pretrain.pytests/test_multimodal_streaming.pytests/utils/data/test_mm_cpt_eval.pytests/utils/schemas/validation/test_multimodal_cpt.py
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ (Gemma-3's `image_token` is the post-expansion soft token, not the user-facing placeholder). Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset now mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]` — keeps malformed rows from silently turning into text-only samples. Schema completeness ------------------- - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now actually reaches `load_dataset`; previously dropped at validation). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass it to `load_dataset` (was silently dropped). - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len` when unset. Validation tightening (config-load time) ---------------------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - The redundant runtime check in `sft.py` is removed; schema is the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see "Non-local scheme" instead of a confusing FileNotFoundError. - Warn when `MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error to note image-patch expansion may push the final length higher. Tests ----- +8 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetection, eval_sequence_len (encoder + collator, including the fallback case), `trust_remote_code` and `ds_type` preservation through validation, three modality-mismatch validation cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass across `tests/test_multimodal_streaming.py`, `tests/prompt_strategies/test_multimodal_pretrain.py`, `tests/utils/schemas/validation/test_multimodal_cpt.py`, `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/utils/data/streaming.py (1)
268-296:⚠️ Potential issue | 🟠 MajorHonor
effective_seq_lenin the packed eval path too.Line 268 computes
effective_seq_len, but Lines 281-292 still hardcodecfg.sequence_len. Withsample_packing: true, eval therefore ignoreseval_sequence_len, which changes truncation/padding behavior and can unexpectedly blow up eval memory.♻️ Proposed fix
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( tokenizer, return_tensors="pt", padding=True, - pad_to_multiple_of=cfg.sequence_len, + pad_to_multiple_of=effective_seq_len, multipack_attn=multipack_attn, ) encode = functools.partial( encode_packed_streaming, collate_fn, ds_wrapper_fn, - max_seq_length=cfg.sequence_len, + max_seq_length=effective_seq_len, batch_size=cfg.micro_batch_size, multipack_attn=multipack_attn, bin_size=cfg.sample_packing_bin_size, )Based on learnings: if
cfg.eval_sequence_lenis not set, eval must fall back tocfg.sequence_len; otherwise the explicit eval override should be honored.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/data/streaming.py` around lines 268 - 296, The packed-sample eval path is still using cfg.sequence_len instead of the computed effective_seq_len, so update the PretrainingBatchSamplerDataCollatorForSeq2Seq instantiation and the functools.partial for encode_packed_streaming to use effective_seq_len (not cfg.sequence_len) for pad_to_multiple_of and max_seq_length; keep the existing multipack_attn logic (True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn) and other args unchanged so eval falls back to cfg.sequence_len only when cfg.eval_sequence_len is not set.
🧹 Nitpick comments (5)
tests/prompt_strategies/test_multimodal_pretrain.py (1)
72-84: Assert the resolved token id here, not just the chosen token string.This only proves that
boi_tokenwins the name-selection step. Ifbuild_image_token_spec()regresses to returning the right token string but the wrongimage_token_id/family ids, this test still passes. Please make the fixture register a distinctboi_tokenid and assert the resolved id mapping too.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/prompt_strategies/test_multimodal_pretrain.py` around lines 72 - 84, The test currently only asserts the chosen token string; update it to also assert the resolved token id so regressions in id mapping are caught: arrange the smolvlm_processor.tokenizer fixture (used by _FakeGemma3Like) to register a distinct id for the fake boi_token ("<fake_token_around_image>") different from "<image>", call build_image_token_spec(_FakeGemma3Like()) as before, and add an assertion that spec.image_token_id (and any related family ids such as spec.image_padding_token_family_id if present) equals the tokenizer id you registered for "<fake_token_around_image>" to ensure the returned spec maps to the expected token id.src/axolotl/prompt_strategies/multimodal_pretrain.py (4)
299-312: Optional: storeimage_token_specin__init__rather than monkey-patching.Setting
strat.image_token_spec = specafter construction with atype: ignore[attr-defined]works but bypasses the class contract — readers ofMultimodalPretrainTokenizationStrategywon't see this attribute on the type. Sincecausal.pyreads it back, lifting it into the constructor would clean up the type ignore and document the dependency.Proposed change
class MultimodalPretrainTokenizationStrategy(PretrainTokenizationStrategy): def __init__( self, *args: Any, image_token: str, image_token_id: int, image_column: str = "images", image_base_dir: str | None = None, + image_token_spec: ImageTokenSpec | None = None, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.image_token = image_token self.image_token_id = image_token_id self.image_column = image_column self.image_base_dir = image_base_dir + self.image_token_spec = image_token_specstrat = MultimodalPretrainTokenizationStrategy( PretrainTokenizer(), tokenizer, cfg.train_on_inputs, cfg.sequence_len, text_column=text_column, image_column=image_column, image_base_dir=image_base_dir, image_token=spec.image_token, image_token_id=spec.image_token_id, max_length=cfg.sequence_len, + image_token_spec=spec, ) - strat.image_token_spec = spec # type: ignore[attr-defined] return strat🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 299 - 312, Currently the code monkey-patches image_token_spec onto the MultimodalPretrainTokenizationStrategy instance after construction; instead add an image_token_spec parameter (optional) to MultimodalPretrainTokenizationStrategy.__init__ and assign it to self.image_token_spec there, update the constructor signature where the class is defined, and remove the post-construction assignment and its type: ignore in the caller (the instantiation code shown) so the attribute is part of the class contract; ensure any reads in causal.py continue to work with the attribute on the instance.
48-48: Nit: replace ambiguous×(U+00D7) with ASCIIx.Ruff RUF003 flags the multiplication sign as confusable with the Latin letter
x. Trivial fix:Proposed fix
-# Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM. +# Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` at line 48, Replace the Unicode multiplication sign U+00D7 in the comment "Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM." with the ASCII letter 'x' so it reads "Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM."; update that inline comment in multimodal_pretrain.py (the quoted comment) to avoid the confusable character flagged by Ruff RUF003.
266-268:n_chunksis structurally always 1 — multiplication is dead code.
_tokenizewrapsinput_idsin a single-element list, solen(res["input_ids"])is always 1. The* n_chunksreplication is harmless but obscures intent. Either drop the variable or assert it as an invariant so a future change to_tokenizedoesn't silently produce mis-alignedimages/_mm_textlists.Proposed simplification
- n_chunks = len(res["input_ids"]) - res["images"] = [list(images)] * n_chunks - res["_mm_text"] = [text] * n_chunks + # `_tokenize` produces exactly one chunk; replicate per-chunk for schema parity. + assert len(res["input_ids"]) == 1 + res["images"] = [list(images)] + res["_mm_text"] = [text] return res🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 266 - 268, The code currently computes n_chunks = len(res["input_ids"]) and multiplies images/_mm_text by it, but _tokenize always wraps input_ids in a single-element list so this replication is dead code; change the block in multimodal_pretrain.py to assert that len(res["input_ids"]) == 1 (to catch future changes to _tokenize) and then set res["images"] = [list(images)] and res["_mm_text"] = [text] (no multiplication), referencing the variables n_chunks, res["input_ids"], res["images"], res["_mm_text"], and the _tokenize behavior so the intent and invariant are explicit.
221-232:_tokenizesilently ignores itsadd_eos_tokenandstrip_bos_tokenparameters.The signature accepts
add_eos_tokenandstrip_bos_token, but the body unconditionally appends EOS and never inspects either flag. If anything inPretrainTokenizationStrategy(or a future caller) invokes_tokenize(prompt, add_eos_token=False), EOS will be appended anyway, masking the intent.Either honor the flags, or drop them from the signature and add a brief docstring explaining the override is fixed-behavior on purpose.
Honor-the-flag variant
def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False, ) -> BatchEncoding: # No truncation: collator re-tokenizes the full text without truncation; # truncating here decouples the stored ids from what the model receives. res = self.tokenizer(prompt, add_special_tokens=True) - res["input_ids"] = [res["input_ids"] + [self.tokenizer.eos_token_id]] - res["attention_mask"] = [res["attention_mask"] + [1]] + ids = res["input_ids"] + mask = res["attention_mask"] + if strip_bos_token and ids and ids[0] == self.tokenizer.bos_token_id: + ids = ids[1:] + mask = mask[1:] + if add_eos_token and self.tokenizer.eos_token_id is not None: + ids = ids + [self.tokenizer.eos_token_id] + mask = mask + [1] + res["input_ids"] = [ids] + res["attention_mask"] = [mask] return res🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 221 - 232, The _tokenize method currently ignores add_eos_token and strip_bos_token; update _tokenize to respect these flags by: call self.tokenizer(prompt, add_special_tokens=True) as before, then if strip_bos_token is True and the first id equals self.tokenizer.bos_token_id (or bos_token_id is not None), remove that first id and its attention mask entry; then if add_eos_token is True append self.tokenizer.eos_token_id and a 1 to attention_mask, otherwise leave them out; preserve the existing no-truncation behavior and the outer-list wrapping of input_ids and attention_mask, and keep returning the BatchEncoding.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/test_multimodal_streaming.py`:
- Around line 107-123: The test test_encode_counts_placeholders_on_full_text
doesn't trigger truncation; adjust the input so the last image placeholder is
only included when text is fully considered: either increase long_filler (e.g.,
"lorem ipsum " * 200) or reduce max_tokens (e.g., set max_tokens to a small
value like 20) in the call to encode_streaming_multimodal so that a truncating
implementation would drop the last placeholder; keep the assertion and
image_token/image_token_id usage the same.
---
Outside diff comments:
In `@src/axolotl/utils/data/streaming.py`:
- Around line 268-296: The packed-sample eval path is still using
cfg.sequence_len instead of the computed effective_seq_len, so update the
PretrainingBatchSamplerDataCollatorForSeq2Seq instantiation and the
functools.partial for encode_packed_streaming to use effective_seq_len (not
cfg.sequence_len) for pad_to_multiple_of and max_seq_length; keep the existing
multipack_attn logic (True if not cfg.pretraining_dataset else
cfg.pretrain_multipack_attn) and other args unchanged so eval falls back to
cfg.sequence_len only when cfg.eval_sequence_len is not set.
---
Nitpick comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 299-312: Currently the code monkey-patches image_token_spec onto
the MultimodalPretrainTokenizationStrategy instance after construction; instead
add an image_token_spec parameter (optional) to
MultimodalPretrainTokenizationStrategy.__init__ and assign it to
self.image_token_spec there, update the constructor signature where the class is
defined, and remove the post-construction assignment and its type: ignore in the
caller (the instantiation code shown) so the attribute is part of the class
contract; ensure any reads in causal.py continue to work with the attribute on
the instance.
- Line 48: Replace the Unicode multiplication sign U+00D7 in the comment
"Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM." with the
ASCII letter 'x' so it reads "Without masking these in labels, loss blows up
~10x on Qwen/SmolVLM."; update that inline comment in multimodal_pretrain.py
(the quoted comment) to avoid the confusable character flagged by Ruff RUF003.
- Around line 266-268: The code currently computes n_chunks =
len(res["input_ids"]) and multiplies images/_mm_text by it, but _tokenize always
wraps input_ids in a single-element list so this replication is dead code;
change the block in multimodal_pretrain.py to assert that len(res["input_ids"])
== 1 (to catch future changes to _tokenize) and then set res["images"] =
[list(images)] and res["_mm_text"] = [text] (no multiplication), referencing the
variables n_chunks, res["input_ids"], res["images"], res["_mm_text"], and the
_tokenize behavior so the intent and invariant are explicit.
- Around line 221-232: The _tokenize method currently ignores add_eos_token and
strip_bos_token; update _tokenize to respect these flags by: call
self.tokenizer(prompt, add_special_tokens=True) as before, then if
strip_bos_token is True and the first id equals self.tokenizer.bos_token_id (or
bos_token_id is not None), remove that first id and its attention mask entry;
then if add_eos_token is True append self.tokenizer.eos_token_id and a 1 to
attention_mask, otherwise leave them out; preserve the existing no-truncation
behavior and the outer-list wrapping of input_ids and attention_mask, and keep
returning the BatchEncoding.
In `@tests/prompt_strategies/test_multimodal_pretrain.py`:
- Around line 72-84: The test currently only asserts the chosen token string;
update it to also assert the resolved token id so regressions in id mapping are
caught: arrange the smolvlm_processor.tokenizer fixture (used by
_FakeGemma3Like) to register a distinct id for the fake boi_token
("<fake_token_around_image>") different from "<image>", call
build_image_token_spec(_FakeGemma3Like()) as before, and add an assertion that
spec.image_token_id (and any related family ids such as
spec.image_padding_token_family_id if present) equals the tokenizer id you
registered for "<fake_token_around_image>" to ensure the returned spec maps to
the expected token id.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6879bd64-1080-433c-a971-e8c6c0fa3a33
📒 Files selected for processing (12)
docs/multimodal.qmdsrc/axolotl/core/builders/causal.pysrc/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_pretrain.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/streaming.pysrc/axolotl/utils/schemas/datasets.pysrc/axolotl/utils/schemas/validation.pytests/prompt_strategies/test_multimodal_pretrain.pytests/test_multimodal_streaming.pytests/utils/data/test_mm_cpt_eval.pytests/utils/schemas/validation/test_multimodal_cpt.py
✅ Files skipped from review due to trivial changes (1)
- docs/multimodal.qmd
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/core/builders/causal.py
- src/axolotl/utils/data/sft.py
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ. Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]`. - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of silently ignoring them. Schema ------ - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now reaches `load_dataset`). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass to `load_dataset`. - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len`. Validation (config-load time) ----------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - Removed the redundant runtime check in `sft.py`; the schema is now the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see the explicit "Non-local scheme" error instead of a confusing FileNotFoundError. - Warn at construction when `MultiModalPretrainDataCollator.tokenizer` is not `processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error in both encoder paths to note image-patch expansion may push the final length higher. Code quality ------------ - Lift `image_token_spec` into `MultimodalPretrainTokenizationStrategy. __init__` instead of post-construction monkey-patch + `type: ignore`. - Hoist `import importlib` out of the per-class loop. - Drop dead `n_chunks` multiplication; replace with explicit invariant assertion. - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and the user-facing pixel-cap error. Tests ----- +15 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetect (with id-mapping assertion), `eval_sequence_len` on encoder + collator (set + unset-fallback), `trust_remote_code` and `ds_type` preservation, three modality-mismatch validation cases, tokenizer-mismatch warning, `remove_unused_columns` auto-set log, cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`, `test_multimodal_pretrain`, `test_multimodal_cpt`, and `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
🧹 Nitpick comments (3)
src/axolotl/prompt_strategies/multimodal_pretrain.py (1)
278-284: Convert the alignment invariant fromassertto an explicit raise.This
assertguards a non-trivial correctness property: if_tokenizeever returns more than one chunk (e.g., a future change re-introduces overflow striding),images=[list(images)]and_mm_text=[text]will be silently misaligned withinput_ids[1:], sending unpaired text into the model. Withpython -O(or any optimized deployment),assertstatements are stripped, turning this into silent data corruption rather than a clear failure. The neighboring_tokenizedoesn't passtruncation=True/stride=...today, but this guard is precisely meant to catch a future regression — so it shouldn't depend on__debug__.♻️ Proposed fix
- # `_tokenize` produces exactly one chunk; the assert keeps that - # invariant explicit so a future change there can't silently - # mis-align `images` / `_mm_text` against `input_ids`. - assert len(res["input_ids"]) == 1 + # `_tokenize` produces exactly one chunk; this guard keeps the + # invariant explicit so a future change there can't silently + # mis-align `images` / `_mm_text` against `input_ids`. Use a real + # raise (not assert) so it survives `python -O`. + if len(res["input_ids"]) != 1: + raise RuntimeError( + "MultimodalPretrainTokenizationStrategy._tokenize produced " + f"{len(res['input_ids'])} chunks; multimodal CPT requires " + "exactly one chunk per row to keep `images`/`_mm_text` " + "aligned with `input_ids`." + ) res["images"] = [list(images)] res["_mm_text"] = [text] return res🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 278 - 284, Replace the fragile assert in multimodal_pretrain.py that checks len(res["input_ids"]) == 1 with an explicit runtime check that raises a clear exception (e.g., ValueError or RuntimeError) if the invariant is violated; perform this check before mutating res by setting res["images"] and res["_mm_text"], and include the actual length in the error message so callers of the function (the code path around the _tokenize call and the block that assigns res["images"] = [list(images)] and res["_mm_text"] = [text]) will fail loudly instead of silently misaligning inputs under optimized Python.src/axolotl/utils/collators/mm_pretrain.py (2)
116-145: O_NOFOLLOW only protects the final path component.
os.open(resolved, O_RDONLY | O_NOFOLLOW)closes the final-link TOCTOU window afterrealpath(), but parent directories alongresolvedcan still be replaced with symlinks betweenrealpath()andos.open(). Combined with thecommonpathcontainment check, this is a reasonable trade-off for v1, but worth flagging indocs/multimodal.qmdso operators know the residual risk on shared/multi-tenant filesystems. A future hardening pass could useos.openat()on aimage_base_dirfd to walk components atomically.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 116 - 145, The _open_image_hardened method uses os.open(..., O_NOFOLLOW) which only prevents TOCTOU on the final path component; parent directories can still be swapped after realpath() and before os.open(), so add a note to docs/multimodal.qmd describing this residual risk on shared/multi-tenant filesystems and that the current commonpath+O_NOFOLLOW approach is an acceptable v1 tradeoff; also add a TODO to consider a future hardening that walks components from an image_base_dir fd using os.openat() (or similar) to eliminate the parent-directory symlink race.
233-256: Inconsistent return type vs the imaged path.The all-text fallback returns
dict(batch)while the imaged path at the bottom returns theBatchEncodingdirectly (line 337). Both are mapping-compatible, butBatchEncodingcarries.to(device),.convert_to_tensors(), and other helpers that callers of the collator (HFTrainer._prepare_inputs, custom devices, accelerate handlers) sometimes rely on. Returning a plaindictfrom one branch and aBatchEncodingfrom the other is an avoidable surprise.♻️ Proposed fix
batch["labels"] = tok_labels - return dict(batch) + return batch🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 233 - 256, The all-text branch converts the tokenizer's BatchEncoding to a plain dict before returning, causing an inconsistent return type versus the imaged path; instead of returning dict(batch) keep and return the original BatchEncoding object from self.tokenizer so callers (e.g., Trainer._prepare_inputs, accelerate) retain .to(), .convert_to_tensors(), etc. Locate the all-text branch where `batch = self.tokenizer(**tok_kwargs)` and `batch["labels"] = tok_labels` are set and replace `return dict(batch)` with `return batch`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 278-284: Replace the fragile assert in multimodal_pretrain.py that
checks len(res["input_ids"]) == 1 with an explicit runtime check that raises a
clear exception (e.g., ValueError or RuntimeError) if the invariant is violated;
perform this check before mutating res by setting res["images"] and
res["_mm_text"], and include the actual length in the error message so callers
of the function (the code path around the _tokenize call and the block that
assigns res["images"] = [list(images)] and res["_mm_text"] = [text]) will fail
loudly instead of silently misaligning inputs under optimized Python.
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 116-145: The _open_image_hardened method uses os.open(...,
O_NOFOLLOW) which only prevents TOCTOU on the final path component; parent
directories can still be swapped after realpath() and before os.open(), so add a
note to docs/multimodal.qmd describing this residual risk on shared/multi-tenant
filesystems and that the current commonpath+O_NOFOLLOW approach is an acceptable
v1 tradeoff; also add a TODO to consider a future hardening that walks
components from an image_base_dir fd using os.openat() (or similar) to eliminate
the parent-directory symlink race.
- Around line 233-256: The all-text branch converts the tokenizer's
BatchEncoding to a plain dict before returning, causing an inconsistent return
type versus the imaged path; instead of returning dict(batch) keep and return
the original BatchEncoding object from self.tokenizer so callers (e.g.,
Trainer._prepare_inputs, accelerate) retain .to(), .convert_to_tensors(), etc.
Locate the all-text branch where `batch = self.tokenizer(**tok_kwargs)` and
`batch["labels"] = tok_labels` are set and replace `return dict(batch)` with
`return batch`.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 31fbfd9a-34c9-4e3a-9046-c3ac9bf7a022
📒 Files selected for processing (4)
src/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_pretrain.pytests/prompt_strategies/test_multimodal_pretrain.pytests/test_multimodal_streaming.py
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
|
@coderabbitai review , your last suggestion introduced a regression in gemma4, resolved in last commit |
|
Triggering a review of the latest commit now. ✏️ Learnings added
🧠 Learnings used✅ Actions performedReview triggered.
|
|
✏️ Learnings added
🧠 Learnings used✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/axolotl/utils/collators/mm_pretrain.py (1)
268-313: Per-row retry can mislabel a text-only row as the offender in mixed batches.The all-text bypass at line 234 already documents that some processors can't be called cleanly with
images=[[]]. In the per-row retry on line 290, text-only rows are sent toself.processor(text=[t], images=[[]], ...)regardless. If the original batch failure was caused by an image row but a co-located text-only row also raises on standalone processor call, the loop'sbreakon the first exception setsoffender_idxto that text-only row and shadows the actual culprit in the diagnostic message. The original exception is still chained viafrom exc, so this is a misleading-error concern (not a correctness one), but it does undercut the value of the retry pinpointing.A small mirror of the line-234 fast-path — routing text-only retry rows through
self.tokenizer— keeps the retry diagnostic aligned with the production tokenization path without re-opening the deferred per-subset homogeneous-fallback discussion.🛠️ Suggested narrow change to retry path
for i, (t, imgs) in enumerate(zip(texts, images, strict=True)): try: - self.processor(text=[t], images=[imgs], **retry_kwargs) + if len(imgs) == 0: + # Mirror the all-text bypass; some processors raise on images=[[]]. + self.tokenizer(text=[t], **retry_kwargs) + else: + self.processor(text=[t], images=[imgs], **retry_kwargs) except Exception as retry_exc:Based on learnings: the collator is documented as passing mixed batches directly to the processor for the tested set (Gemma 3/4, Qwen 3.5, SmolVLM/SmolVLM2, Qwen 2.5VL), with per-subset fallback deferred to a future PR — this suggestion is intentionally scoped to the diagnostic-only retry path so it doesn't touch that deferred work.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/collators/mm_pretrain.py` around lines 268 - 313, The per-row retry loop in MultiModalPretrainDataCollator currently calls self.processor(...) for every row, which can misattribute failures from image rows to text-only rows; change the retry loop so that if a row is text-only (e.g., imgs is empty or falsy / images element is []), call self.tokenizer(text=[t], **retry_kwargs) instead of self.processor, otherwise call self.processor(text=[t], images=[imgs], **retry_kwargs); keep retry_kwargs (return_tensors, padding, pad_to_multiple_of) the same and do not alter the main mixed-batch processing path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/test_multimodal_streaming.py`:
- Around line 188-193: The file was reformatted by ruff-format; update the
constructor call to match ruff's preferred style by collapsing the dict literal
onto one line for the _StubTokenizer(...) invocation and ensure the subsequent
_StubProcessor(...) call remains formatted correctly; specifically modify the
tok = _StubTokenizer({...}) line (and keep proc = _StubProcessor(...) as is) to
conform to ruff format rules, then re-run pre-commit or ruff format and commit
the change.
In `@tests/utils/schemas/validation/test_multimodal_cpt.py`:
- Around line 321-334: The test
test_remove_unused_columns_already_false_does_not_log incorrectly sets propagate
on the child logger; change the monkeypatch target to the parent logger
logging.getLogger("axolotl") so caplog can capture records emitted by
descendants; keep the rest of the test flow (create cfg with
_mm_cpt_cfg(remove_unused_columns=False), call validate_config(cfg), and assert
no "Auto-set" messages) so the negative assertion actually fails if the auto-set
log is emitted.
---
Nitpick comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 268-313: The per-row retry loop in MultiModalPretrainDataCollator
currently calls self.processor(...) for every row, which can misattribute
failures from image rows to text-only rows; change the retry loop so that if a
row is text-only (e.g., imgs is empty or falsy / images element is []), call
self.tokenizer(text=[t], **retry_kwargs) instead of self.processor, otherwise
call self.processor(text=[t], images=[imgs], **retry_kwargs); keep retry_kwargs
(return_tensors, padding, pad_to_multiple_of) the same and do not alter the main
mixed-batch processing path.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b0c90ff4-69ff-446b-b76d-13459369affe
📒 Files selected for processing (4)
src/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/collators/mm_pretrain.pytests/test_multimodal_streaming.pytests/utils/schemas/validation/test_multimodal_cpt.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/test_multimodal_streaming.py (1)
224-375: 💤 Low valueLGTM — eval-vs-train sequence-len fallback coverage is solid.
The three-case split (eval w/
eval_sequence_lenset → 1024; train → 4096; eval w/eval_sequence_lenunset → falls back to 4096) directly encodes the documented schema contract and would loudly catch any regression that drops the fallback.One small fragility note:
monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", ...)only intercepts calls ifstreaming.pyreferencesfunctools.partialvia attribute access — if anyone refactors tofrom functools import partial; partial(...), the patch becomes a silent no-op and these assertions still pass via the lambda passthrough. Worth a comment in the test or rebinding toaxolotl.utils.data.streaming.partialif that import style is ever introduced.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_multimodal_streaming.py` around lines 224 - 375, The test monkeypatch targets "axolotl.utils.data.streaming.functools.partial", which only works if streaming.py references functools via attribute access; change the monkeypatch to bind to the actual symbol the module uses (e.g., patch "axolotl.utils.data.streaming.partial" if streaming.py does "from functools import partial", or keep the current target but add a comment warning about the import style fragility), by updating the monkeypatch.setattr call in test_wrap_streaming_dataset_uses_pretraining_config_arg and test_wrap_streaming_dataset_eval_honors_eval_sequence_len to patch the exact name exported/used by streaming.py (or use raising=False) so the fake_partial reliably replaces the encoder partial regardless of import style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/utils/collators/mm_pretrain.py`:
- Around line 117-126: In _open_image_hardened, if os.open succeeds but
os.fdopen raises, the raw fd is leaked; wrap the os.fdopen call in a try/except
that closes the fd on failure (or use contextlib.closing) and re-raise a
ValueError/appropriate exception with the original error as cause so the
descriptor is always closed; adjust the error handling around os.open and
os.fdopen in _open_image_hardened to ensure fd is closed when fdopen fails.
---
Nitpick comments:
In `@tests/test_multimodal_streaming.py`:
- Around line 224-375: The test monkeypatch targets
"axolotl.utils.data.streaming.functools.partial", which only works if
streaming.py references functools via attribute access; change the monkeypatch
to bind to the actual symbol the module uses (e.g., patch
"axolotl.utils.data.streaming.partial" if streaming.py does "from functools
import partial", or keep the current target but add a comment warning about the
import style fragility), by updating the monkeypatch.setattr call in
test_wrap_streaming_dataset_uses_pretraining_config_arg and
test_wrap_streaming_dataset_eval_honors_eval_sequence_len to patch the exact
name exported/used by streaming.py (or use raising=False) so the fake_partial
reliably replaces the encoder partial regardless of import style.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a3c86584-ecd7-4bf6-890b-f60285e9be60
📒 Files selected for processing (6)
src/axolotl/utils/collators/mm_pretrain.pysrc/axolotl/utils/schemas/validation.pytests/conftest.pytests/prompt_strategies/test_multimodal_pretrain.pytests/test_multimodal_streaming.pytests/utils/schemas/validation/test_multimodal_cpt.py
✅ Files skipped from review due to trivial changes (1)
- tests/conftest.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/utils/schemas/validation.py
Adds a streaming-first multimodal CPT path: raw `(text, images[])` rows
are tokenized once with a placeholder-count guardrail, batched through a
hardened collator, and fed to a VLM with image-family tokens masked out
of labels. Gated by `type: multimodal_pretrain` (or `multimodal: true`)
on a `pretraining_dataset` entry; works end-to-end for train and eval,
including multi-entry eval and mixed image/text batches.
Features
--------
- Streaming MM CPT encoder (`encode_streaming_multimodal`): counts
placeholders by token id (not substring), enforces
`placeholders == len(images)` per row, and rejects rows that exceed
`sequence_len` instead of silently truncating mid-placeholder.
- MM CPT collator (`MultiModalPretrainDataCollator`): security-hardened
image loader (path traversal / NUL byte / remote URL / multi-frame
bomb / pixel cap rejection), per-row image cap, processor-call retry
that pinpoints the offending row, and label-side masking of every
image-family token id.
- Mixed image/text batches: text-only rows in a batch take a
tokenizer-only fallback (no `pixel_values`); rows with images go
through the processor as usual.
- Eval support: `test_datasets` accepts MM entries via a dedicated
`MultiModalEvalDataset` model so per-entry `text_column` /
`image_column` / `image_base_dir` / `image_token` survive validation.
Multi-entry MM eval streams are concatenated.
- `dispatch_batches: true` support: non-main ranks get a placeholder
dataset that mirrors the configured text + image columns.
- Config validation gates: `processor_type` required, `sample_packing:
false` enforced, `chat_template` rejected, single
`pretraining_dataset` entry required, MM eval entries must share
`image_base_dir` / `image_token`, mixed MM/non-MM eval rejected,
incompatible processor classes (Mllama, Pixtral, InternVL) rejected
at startup. `remove_unused_columns` is auto-set to `false` with an
INFO log.
- Docs: new section in `docs/multimodal.qmd` covering the YAML shape,
placeholder-token table, eval contract, and supported/rejected model
families.
YAML example
------------
base_model: HuggingFaceTB/SmolVLM-500M-Instruct
processor_type: AutoProcessor
pretraining_dataset:
- path: /path/to/shards/*.jsonl
ds_type: json
type: multimodal_pretrain
text_column: text
image_column: images
image_base_dir: /path/to/images
streaming: true
sequence_len: 2048
sample_packing: false
Tests
-----
59 tests across four suites covering the encoder, collator (including
mixed/all-text batches and security gates), prompt strategy, schema
preservation, multi-entry eval merge, eval homogeneity validation,
eval-aware collator, dispatch-batches placeholder shape, and the
auto-set log record.
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ (Gemma-3's `image_token` is the post-expansion soft token, not the user-facing placeholder). Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset now mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]` — keeps malformed rows from silently turning into text-only samples. Schema completeness ------------------- - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now actually reaches `load_dataset`; previously dropped at validation). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass it to `load_dataset` (was silently dropped). - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len` when unset. Validation tightening (config-load time) ---------------------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - The redundant runtime check in `sft.py` is removed; schema is the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see "Non-local scheme" instead of a confusing FileNotFoundError. - Warn when `MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error to note image-patch expansion may push the final length higher. Tests ----- +8 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetection, eval_sequence_len (encoder + collator, including the fallback case), `trust_remote_code` and `ds_type` preservation through validation, three modality-mismatch validation cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass across `tests/test_multimodal_streaming.py`, `tests/prompt_strategies/test_multimodal_pretrain.py`, `tests/utils/schemas/validation/test_multimodal_cpt.py`, `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
Addresses CodeRabbit review on PR axolotl-ai-cloud#3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ. Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]`. - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of silently ignoring them. Schema ------ - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now reaches `load_dataset`). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass to `load_dataset`. - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len`. Validation (config-load time) ----------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - Removed the redundant runtime check in `sft.py`; the schema is now the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see the explicit "Non-local scheme" error instead of a confusing FileNotFoundError. - Warn at construction when `MultiModalPretrainDataCollator.tokenizer` is not `processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error in both encoder paths to note image-patch expansion may push the final length higher. Code quality ------------ - Lift `image_token_spec` into `MultimodalPretrainTokenizationStrategy. __init__` instead of post-construction monkey-patch + `type: ignore`. - Hoist `import importlib` out of the per-class loop. - Drop dead `n_chunks` multiplication; replace with explicit invariant assertion. - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and the user-facing pixel-cap error. Tests ----- +15 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetect (with id-mapping assertion), `eval_sequence_len` on encoder + collator (set + unset-fallback), `trust_remote_code` and `ds_type` preservation, three modality-mismatch validation cases, tokenizer-mismatch warning, `remove_unused_columns` auto-set log, cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`, `test_multimodal_pretrain`, `test_multimodal_cpt`, and `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream pre-commit pin).
- mm_pretrain.py: return BatchEncoding (not dict) from all-text branch so
it matches the imaged path.
- test_multimodal_cpt.py, test_multimodal_streaming.py: monkeypatch
axolotl logger propagate=True so caplog can capture records (axolotl's
logging config sets propagate=False, blocking root capture in CI).
multimodal_pretrain.py: scope the boi_token swap in build_image_token_spec
to processors whose `image_token` name contains "soft_token" (the Gemma-3
convention). Without this, Gemma-4 (`image_token=<|image|>`,
`boi_token=<|image>`) gets the wrong placeholder autodetected and every
row fails validation with a 0-vs-N placeholder/image mismatch.
test_multimodal_streaming.py: 6 new tests
- Two for the new autodetection behavior (Gemma-4 keeps image_token,
Gemma-3 still swaps to boi_token), using stub processors.
- Three branch-coverage tests for build_image_token_spec failure modes:
override not registered as special token, override resolves to unk,
nothing autodetectable.
- Three collator-path tests: skip_bad_images drops a row and continues,
all-rows-dropped surfaces a RuntimeError, multi-frame GIF triggers
the animation-bomb guard via _open_image_hardened.
fix(test): patch parent `axolotl` logger so negative caplog assertion has teeth
The previous monkeypatch targeted `axolotl.utils.schemas.validation`, which
is already propagate=True by inheritance — the actual block sits one level
up at the `axolotl` logger (propagate=False from logging_config.py). The
result: caplog never received any records, and `assert not any("Auto-set"
... in caplog.records)` would have passed even if the regression fired.
Mirror the positive test by flipping propagate on `logging.getLogger("axolotl")`
and add a comment explaining why the leaf isn't the right target.
…ry loop Some HF processors reject `images=[[]]`, which made the per-row retry flag innocent text-only rows as the offender. Mirror the all-text bypass — diagnostic-only path, mainline unchanged.
Fix text following Gemma 4 regression fix
MultiModalPretrainDataCollator.torch_call calls processor(text=...) which re-tokenizes _mm_text from scratch, discarding the EOS that encode_streaming_multimodal appended to input_ids. Without this, labels never contain EOS at end-of-document and the model never learns to emit a stop token — symptoms: non-terminating / repetitive generation. Match the text CPT contract (encode_streaming keeps EOS in both input_ids and labels) by appending EOS to _mm_text, idempotently, gated on a new add_eos_token field (default True).
|
No blockers, mostly bloat-trim. |
I just committed based on your feedback with the cleanup and stripping out the associated tests |
Aligns mm_pretrain.py with mm_chat.py's image-loading posture. Drops NUL/URL/path-traversal/pixel-cap/multi-frame/per-row-count guards that defended against threat models that don't apply to a CLI trainer loading its own dataset. Routes image loading through transformers.image_utils.load_image, matching the chat path. Keeps image_base_dir join, skip_bad_images, label masking, processor compatibility check, and the tokenizer/processor mismatch warning.
|
@ved1beta sorry forgot to tag you on my last response |
|
lint please |
@ved1beta just committed lint fixes |
|
@ved1beta assuming test failure is unrelated to PR since it previously passed and only failed after syncing with main |
The non-streaming `datasets:` MM CPT route was never wired through `build_collator`, which only routes MM batches under the pretraining branch — `datasets:` entries would emit `images`/`_mm_text` rows into a text-only collator. Strip the strategy class + `load()` and their unit tests; keep `ImageTokenSpec`, `build_image_token_spec`, and `check_processor_compatibility` since the streaming collator imports them. Add a docs callout that only the streaming `pretraining_dataset` route is currently wired. Fold `MultiModalEvalDataset` into `PretrainingDataset` via inheritance; the only intentional divergence is the `type` default and the `_require_mm_markers` validator. Drops ~60 lines of duplicated `Field` declarations the reviewer flagged. Tighten the collator `KeyError` message to mention only `encode_streaming_multimodal` now that the strategy class is gone.
|
@ved1beta Pushed the carve-out (389e0f2) which strips the non-streaming strategy class + load() and associated unit tests. Addressed MultiModalEvalDataset into PretrainingDataset via inheritance (only the type default and the _require_mm_markers validator differ now) and added a docs callout that only the streaming pretraining_dataset route is currently wired. Impacted tests passed (60/60) and ran multi-step test on a local Qwen3-VL completes end-to-end with the streaming MM CPT path |
|
mostly good , config using type: multimodal_pretrain under datasets: will now fail at strategy load time with a raw AttributeError. make it a small ValueError("multimodal_pretrain is only supported via pretraining_dataset with streaming: true — see docs/multimodal.qmd") |
ved1beta
left a comment
There was a problem hiding this comment.
looks good after the val error , requesting nano for quick look if i missed anything LGTM
…datasets: Previously failed with a raw AttributeError at strategy load time. Now raises a small ValueError pointing users to the supported entry point.
thank you, just added the valueerror per your request I do have a follow up PR planned to address the non-streaming path and improve resumption speed on streaming path @NanoCode012 |
Description
Adds a streaming-first multimodal CPT path: raw
(text, images[])rows are tokenized once with a placeholder-count guardrail, batched through a hardened collator, and fed to a VLM with image-family tokens masked out of labels. Gated bytype: multimodal_pretrain(ormultimodal: true) on apretraining_datasetentry; works end-to-end for train and eval, including multi-entry eval and mixed image/text batches.Features
encode_streaming_multimodal): counts placeholders by token id (not substring), enforcesplaceholders == len(images)per row, and rejects rows that exceedsequence_leninstead of silently truncating mid-placeholder.MultiModalPretrainDataCollator): security-hardened image loader (path traversal / NUL byte / remote URL / multi-frame bomb / pixel cap rejection), per-row image cap, processor-call retry that pinpoints the offending row, and label-side masking of every image-family token id.pixel_values); rows with images go through the processor as usual.test_datasetsaccepts MM entries via a dedicatedMultiModalEvalDatasetmodel so per-entrytext_column/image_column/image_base_dir/image_tokensurvive validation. Multi-entry MM eval streams are concatenated.dispatch_batches: truesupport: non-main ranks get a placeholder dataset that mirrors the configured text + image columns.processor_typerequired,sample_packing: falseenforced,chat_templaterejected, singlepretraining_datasetentry required, MM eval entries must shareimage_base_dir/image_token, mixed MM/non-MM eval rejected, incompatible processor classes (Mllama, Pixtral, InternVL) rejected at startup.remove_unused_columnsis auto-set tofalsewith an INFO log.docs/multimodal.qmdcovering the YAML shape, placeholder-token table, eval contract, and supported/rejected model families.YAML example
Motivation and Context
Existing axolotl multimodal paths assume conversational SFT (chat templates with image placeholders inside a structured turn format). Raw image+text continued pre-training — feeding a VLM unstructured documents that interleave text and images — wasn't supported. This PR adds that path without disturbing the existing chat-template SFT flow: it's a separate
type: multimodal_pretrainopt-in that reuses the streaming pretraining infrastructure and the existing processor/tokenizer wiring.How has this been tested?
59 unit tests across four suites:
tests/prompt_strategies/test_multimodal_pretrain.py— encoder/strategy: placeholder count guardrail, sequence-length enforcement, falsy-image rejection, processor compatibility gate.tests/test_multimodal_streaming.py— streaming encoder + collator: end-to-end batch construction, label masking, mixed-batch and all-text-batch handling, security gates (path traversal, NUL byte, remote URLs, image cap, error sanitization).tests/utils/schemas/validation/test_multimodal_cpt.py— config validation: processor/packing/template gates, single pretraining-entry rule, MM eval schema preservation, eval homogeneity (image_base_dir/image_token),remove_unused_columnsauto-set INFO log.tests/utils/data/test_mm_cpt_eval.py— eval data path: placeholder schema fordispatch_batches, multi-entry eval merge, mixed MM/non-MM rejection, eval-aware collator config source.Tested with
HuggingFaceTB/SmolVLM-500M-Instruct, Gemma 4, Gemma 3, Qwen 2.5VL and Qwen 3.5, as the reference processor. Lint (ruff/ruff-format) and mypy clean against the patches added in this branch.Real world testing:
In addition to the unit suite, the path was validated against two end-to-end LoRA-CPT runs on a OCR corpus (raw
\ntext rows, streamed via type: multimodal_pretrain):
Both runs trained to multi-thousand-step stability without divergence or NaN, with periodic eval against a separate test_datasets entry and image-family label masking applied throughout. Confirms the streaming encoder + collator + eval path work end-to-end in real training under FSDP and single-GPU configs.
AI Usage Disclaimer
Yes — Claude Code (Opus 4.7) was used for design discussion, code review, refactoring suggestions, and test generation. All code was reviewed and validated before commit. Codex was used for additional vetting, testing and review. Coderabbit reviews were addressed on personal fork for every commit.
Types of changes
Summary by CodeRabbit
New Features
Documentation
Configuration & Validation
Tests