feat: systemic multimodal assistant-only loss masking + `cfg.role_boundaries` #3625
…daries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` / `cfg.train_on_eos` in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config.

Also adds `cfg.role_boundaries` YAML override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each strategy declares per-role start/end markers via `_build_role_boundaries`; the shared scanner honors `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]` token between user-end and assistant-start via `include_end=False` + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table, root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
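To make the boundary-scanner idea above concrete, here is a minimal sketch in plain Python. It is not the code from this branch: the `RoleBoundary` shape, the `mask_labels` name, and the token ids are assumptions for illustration, it works on lists rather than tensors, and it does not model `train_on_eos` at all (only `include_end` decides whether a trainable role's end marker stays in the loss).

```python
from dataclasses import dataclass

IGNORE_INDEX = -100  # label value ignored by the loss


@dataclass(frozen=True)
class RoleBoundary:  # hypothetical shape, for illustration only
    role: str
    start_tokens: tuple[int, ...]
    end_tokens: tuple[int, ...]
    include_end: bool = True


def _starts_with(ids, pos, seq):
    """True if the non-empty marker `seq` occurs in `ids` at `pos`."""
    return bool(seq) and ids[pos : pos + len(seq)] == list(seq)


def mask_labels(ids, boundaries, roles_to_train=("assistant",), train_on_inputs=False):
    """Return labels where only trainable-role spans keep their token ids."""
    if train_on_inputs:
        return list(ids)
    labels = [IGNORE_INDEX] * len(ids)
    pos = 0
    while pos < len(ids):
        # Longest matching start marker wins.
        matches = [b for b in boundaries if _starts_with(ids, pos, b.start_tokens)]
        if not matches:
            pos += 1
            continue
        b = max(matches, key=lambda m: len(m.start_tokens))
        content_start = pos + len(b.start_tokens)
        # Look for the end marker; a missing marker means the span runs to the end.
        end = content_start
        while end < len(ids) and not _starts_with(ids, end, b.end_tokens):
            end += 1
        found_end = end < len(ids)
        end_after = end + len(b.end_tokens) if found_end else len(ids)
        if b.role in roles_to_train:
            keep_to = end_after if (found_end and b.include_end) else end
            labels[content_start:keep_to] = ids[content_start:keep_to]
        # With include_end=False, rewind so a shared marker can open the next role.
        pos = end_after if (b.include_end or not found_end) else end
    return labels


# Made-up ids: 1 = user start, 2 = assistant start, 3 = end of turn.
BOUNDARIES = [
    RoleBoundary("user", (1,), (3,)),
    RoleBoundary("assistant", (2,), (3,)),
]
example = mask_labels([1, 10, 11, 3, 2, 20, 21, 3], BOUNDARIES)
# example == [-100, -100, -100, -100, -100, 20, 21, 3]
```

Declaring the user boundary with `include_end=False` makes the scanner rewind to the end marker, which is how a shared token such as `[/INST]` can close the user turn and open the assistant turn in the same pass.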
…om/thad0ctor/axolotl into feat/multimodal-assistant-mask-all
- builders/causal.py: add inline NOTE that multi-dataset configs reuse the first dataset's masking knobs (roles_to_train / train_on_eos) for all datasets; heterogeneous per-dataset overrides are not supported in the MM path today.
- processing_strategies.py: annotate inner scanner helpers _match_prefix and _find_end with explicit types (Tensor, int, list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
`_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
`SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
users hit a pydantic ValidationError at config load. Added "none" to
the Literal and updated the description.
2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
ctor treated it as "replace built-ins with empty" while the collator
plumbing treated it as "unset", and both the design doc and the
MultiModalConfig schema help text promised wholesale replacement for
any set value. Aligned on opt-in semantics across all four surfaces —
a non-empty list replaces built-ins wholesale; unset or `[]` falls back
to built-ins. Rationale: honoring `[]` literally yields all-masked
labels and zero gradient, which is almost always a typo or leftover
rather than a deliberate user action. Users who want to disable role
masking should unset the field or use `train_on_inputs: true`.
Also sharpened the fallback one-shot warning for strategies without
built-in boundaries: names the consequence ("only pad and media tokens
are masked, every other token contributes to loss") and points users
at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
of "see axolotl/processing_strategies.py for how to declare
boundaries."
Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
block; updated the fallback-path detection paragraph to quote the new
warning text
- tests/test_processing_strategies.py: +2 regressions
(test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
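The opt-in semantics from item 2 above come down to a truthiness check. The sketch below only illustrates that rule; `resolve_role_boundaries` and the boundary entries are hypothetical names, not the constructor code from this commit.

```python
def resolve_role_boundaries(override, builtins):
    """Pick the boundary list a strategy should scan with.

    A non-empty override replaces the built-ins wholesale; an unset (None)
    or empty override falls back to the built-ins, so a stray `[]` in YAML
    cannot silently mask every label.
    """
    return list(override) if override else list(builtins)


# Placeholder entries; the real boundary schema is not reproduced here.
BUILTINS = ["builtin-user", "builtin-assistant"]

unset = resolve_role_boundaries(None, BUILTINS)
empty = resolve_role_boundaries([], BUILTINS)
custom = resolve_role_boundaries(["custom-assistant"], BUILTINS)
```

Here `unset` and `empty` both resolve to the built-ins, while `custom` contains only the user-declared entry.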
📝 Walkthrough

Adds token-level, role-boundary label masking for multimodal assistants, threads dataset-level masking knobs (train_on_inputs, roles_to_train, train_on_eos, role_boundaries) from the builder into the collators, extends the schemas for overrides, and adds comprehensive tests and docs describing configuration and verification.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Apologies in advance that this is a significant Opus-written PR; I can architect, but my coding is lacking. All 30+ local fork PRs were reviewed by CodeRabbit and Codex as well, and tested thoroughly in my fork with training runs and against the repo's existing workflow tests.
Actionable comments posted: 3
🧹 Nitpick comments (7)
docs/multimodal_assistant_mask.md (1)
158-160: Nit: add language to fenced code block (MD040).

The collator log example block has no language identifier. A plain `text` or `log` identifier satisfies the linter and enables consistent rendering.

📝 Proposed fix

````diff
 `build_collator` logs the resolved knobs at INFO:
-```
+```text
 MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
 ```
````

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@docs/multimodal_assistant_mask.md` around lines 158 - 160, Add a language identifier to the fenced code block containing "MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none" so it satisfies the linter (MD040); edit the opening fence from ``` to ```text (or ```log) to mark the block as plain text/log and keep the block content unchanged.

src/axolotl/processing_strategies.py (3)
655-659: Hardcoded `262144` for Gemma 3 soft image token is brittle.

The literal id will silently go stale if the tokenizer vocab shifts (checkpoint variant, upstream retokenization, custom fine-tune with added specials). Consider resolving via `convert_tokens_to_ids("<image_soft_token>")` and only falling back to the literal if the token isn't in vocab; this mirrors what `Gemma4ProcessingStrategy.process_labels` already does for `boi`/`eoi`/`boa`/`eoa`.

♻️ Proposed fix

```diff
 def process_labels(self, input_ids):
     labels = super().process_labels(input_ids)
-    # Gemma3-specific <image_soft_token> id; not exposed as a tokenizer attribute.
-    labels[labels == 262144] = -100
+    # Gemma3-specific <image_soft_token>; resolve via tokenizer, fallback to known id.
+    tok = self.processor.tokenizer
+    soft_id = tok.convert_tokens_to_ids("<image_soft_token>")
+    unk_id = getattr(tok, "unk_token_id", None)
+    if soft_id is not None and soft_id != unk_id:
+        labels[labels == soft_id] = -100
+    else:
+        labels[labels == 262144] = -100
     return labels
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/processing_strategies.py` around lines 655 - 659, The process_labels method hardcodes the Gemma3 image soft token id (262144); change it to resolve the id from the tokenizer first (e.g., call tokenizer.convert_tokens_to_ids("<image_soft_token>") or the equivalent lookup used in Gemma4ProcessingStrategy.process_labels) and only use the numeric fallback when the token is not present in vocab; update process_labels to compute image_soft_token_id dynamically, then replace labels[labels == image_soft_token_id] = -100, falling back to 262144 if the conversion returns an unknown id.
338-437: Scanner core `_apply_role_boundaries` looks correct.

Verified traces for: basic masking, longest-prefix wins, multi-batch, include_start/include_end semantics, train_on_eos in {"turn","all","none","last"}, Pixtral shared-marker rewind (Line 426-427 correctly sets `j = end_after - len(end_tokens)` so the shared `[/INST]` gets re-matched as assistant-start), and truncated spans (missing end marker runs to EOS). Tests in `tests/test_processing_strategies.py` cover each case.

Minor optimization note for a future pass: `_match_prefix` calls `.tolist()` on a tensor slice for every position × every boundary, which is a lot of Python-side allocation for long sequences. Converting labels to a Python list once per row before the scan (or using `torch.equal` on a pre-built tensor per boundary) would amortize it. Deferrable: real-world MM sequences are usually at most a few thousand tokens and markers are 1-5 ids.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/processing_strategies.py` around lines 338 - 437, The scan is correct but _match_prefix repeatedly calls label[start_pos:start_pos+len(tok_seq)].tolist(), causing many Python list allocations; to optimize, convert each row once to a Python list (e.g., create label_list = label.tolist() inside the for i in range(...) loop) and change _match_prefix to compare slices from that label_list against boundary.start_tokens (or pre-store boundary.start_tokens as tuples) so comparisons are pure-Python without tensor-to-list conversions; alternatively, you can compare tensors with torch.equal against pre-built tensors for each RoleBoundary.start_tokens, but the simplest fix is to use label_list and update _match_prefix to use it (references: function _apply_role_boundaries, local name label, helper _match_prefix, and RoleBoundary.start_tokens).
296-325: One-shot warning for boundary-less strategies is well-designed.

The `_ROLE_MASK_WARNED` module-level dedupe correctly suppresses per-batch spam. One concern for distributed training: on N ranks, each rank's worker process keeps its own set, so you'll see the warning N times in aggregated logs (one per rank). That's mild noise but may be misread as "something is repeatedly misfiring." If you want rank-0-only behavior, gating on `int(os.environ.get("RANK", "0")) == 0` is the usual trick. Deferrable.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/processing_strategies.py` around lines 296 - 325, The warning for boundary-less strategies in _mask_non_assistant currently deduplicates per-process via _ROLE_MASK_WARNED but still emits once per rank; modify the branch that checks "if key not in _ROLE_MASK_WARNED" to also check the process rank and only add/log from rank 0 (e.g., gate on int(os.environ.get("RANK", "0")) == 0) before calling LOG.warning and adding to _ROLE_MASK_WARNED, so other ranks skip the warning while preserving the single-shot behavior; touch the _mask_non_assistant function and LOG.warning usage and ensure role_boundaries behavior remains unchanged.

src/axolotl/core/builders/causal.py (2)
524-548: Per-dataset knob extraction looks correct; single-dataset assumption is well-documented.

The `_ds_get` helper cleanly handles the DictDefault/dict/pydantic trio, and the NOTE comment on Lines 525-526 is exactly the caveat I was going to flag (heterogeneous per-dataset masking knobs silently use the first dataset's values). Docs `docs/multimodal_assistant_mask.md` also call this out.

One very minor note: `_ds_get` is defined inside `build_collator`, so it gets re-created on every call (including the double call above). Hoisting it to module scope would let you unit-test it directly and avoid the re-definition. Optional.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/builders/causal.py` around lines 524 - 548, The helper _ds_get is being defined inside build_collator causing it to be re-created on each call; hoist _ds_get to module scope (keep the same signature and behavior: accept cfg_obj and key, handle None, dict-like .get with safe exception handling, then getattr) and replace the inner definition in build_collator with calls to the new module-level function (used by roles_to_train and train_on_eos); this lets you unit-test _ds_get directly and avoids repeated re-definition while preserving current behavior.
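For reference, a hoisted helper along the lines the comment describes could look like the sketch below. The actual `_ds_get` body is not shown in this thread, so the name `ds_get`, the `default` parameter, and the exact fallback order are assumptions; the sketch only follows the described behavior (handle `None`, try a dict-like `.get` safely, then fall back to `getattr`).

```python
from types import SimpleNamespace


def ds_get(cfg_obj, key, default=None):
    """Read `key` from a dict-like or attribute-style dataset config."""
    if cfg_obj is None:
        return default
    getter = getattr(cfg_obj, "get", None)
    if callable(getter):
        try:
            value = getter(key)
        except Exception:  # tolerate objects whose .get has a different contract
            value = None
        if value is not None:
            return value
    # Attribute-style configs (e.g. pydantic models) are read via getattr.
    value = getattr(cfg_obj, key, None)
    return default if value is None else value


from_dict = ds_get({"train_on_eos": "turn"}, "train_on_eos")
from_attrs = ds_get(SimpleNamespace(roles_to_train=["assistant"]), "roles_to_train")
from_none = ds_get(None, "train_on_eos", default="turn")
```

Because the function lives at module scope, it can be exercised directly in a unit test without constructing a trainer builder.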
550-557: Resolved-config INFO log will fire twice per `build()` call.

`HFCausalTrainerBuilder.build()` invokes `self.build_collator(...)` twice: once with `is_eval=True` (Line 400) and once for the training collator (Line 436). Both calls reach this branch when the MM path is taken, so users will see two identical `MM collator: ...` INFO lines, plus two `ProcessingStrategy init: ...` lines from each strategy constructor. Not a correctness issue, but when someone is diagnosing "why isn't masking firing?" the duplicate lines add a "did I hit it twice?" moment.

Consider gating the log on `is_eval=False`, or collapsing to a single log line in the caller. Deferrable; flagging for awareness.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/builders/causal.py` around lines 550 - 557, The INFO log "MM collator: ..." in build_collator is emitted during both eval and train collator creation because HFCausalTrainerBuilder.build calls build_collator twice (once with is_eval=True and once for training), causing duplicate log lines; modify build_collator to only log that INFO when is_eval is False (or add a parameter to suppress logging) so the message (and the subsequent ProcessingStrategy logs) are only emitted once for the training collator, referencing the build_collator method and the LOG.info call that prints "MM collator: train_on_inputs=%s roles_to_train=%s ...".

tests/test_processing_strategies.py (1)
1-1081: LGTM: thorough, deterministic offline test coverage.

Notable strengths:

- `axolotl_caplog` fixture correctly handles `propagate=False`, which is exactly the trap I'd expect to bite someone writing caplog-based tests against axolotl loggers.
- `_ROLE_MASK_WARNED.discard(...)` before the no-boundaries test (Line 844) prevents test-order flakes from the module-level dedupe set.
- `test_process_labels_no_warning_when_image_token_id_none` (Line 289) is a nice defensive test against the classic `tensor == None` PyTorch UserWarning.
- The `_FakeGemma4Tokenizer.VOCAB` class attribute is correctly protected via `{token: list(ids) for ...}` copy in `__init__` (Line 466); Ruff RUF012 is a false positive here.
- Ruff S105 warnings on `self.boi_token`, `self.video_token`, etc. are all false positives (multimodal token strings, not credentials).

Two minor suggestions:

- The `_mistral_common_stub` fixture (Line 627-630) always returns `None` and is never introspected by the tests that receive it. Either remove it or document why it's needed (presumably for ordering to ensure the lazy-import error path has been primed). Deferrable.
- Consider adding a test for the `train_on_eos="all"` + `include_end=False` non-trainable role case I flagged in `processing_strategies.py`, to lock in whatever behavior you settle on.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/test_processing_strategies.py` around lines 1 - 1081, Remove or document the unused _mistral_common_stub fixture: either delete the fixture definition named _mistral_common_stub (the pytest fixture that simply returns None) from the tests file, or add a one-line docstring above it explaining it's intentionally present to prime the lazy-import error path for get_processing_strategy; update any import comment to reference _mistral_common_stub so future readers know it's purposeful. Also add a small unit test that exercises the processing_strategies behavior for train_on_eos="all" combined with a role boundary that sets include_end=False on a non-trainable role (create a test function that constructs a ProcessingStrategy with role_boundaries_override where role is non-trainable, train_on_eos="all" and include_end=False, then assert the output labels follow the decided behavior), referencing ProcessingStrategy, _apply_role_boundaries, and RoleBoundary/include_end to pin the intended behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 415-419: The non-trainable-role branch unconditionally marks the
end marker into the loss when train_on_eos == "all", ignoring
best_match.include_end; change the condition in the else branch (the block using
found_end, train_on_eos, end_after and mask[i]) to also require
best_match.include_end so the end token is only masked when include_end is True
(same gating as the trainable branch), ensuring the mask write at
content_end:end_after is skipped when include_end is False.
- Around line 1046-1051: The class Glm4vProcessingStrategy is misnamed relative
to the processor it actually handles (Glm46VProcessor) causing GLM-4V processors
to fall back to ProcessingStrategy; either rename Glm4vProcessingStrategy to
Glm46VProcessingStrategy to match Glm46VProcessor, or (if you intended to
support GLM-4V) update the dispatcher to also register Glm4vProcessor and verify
the token markers for that checkpoint; update any other occurrences noted (the
similar block at lines referenced in the comment) so class names and dispatcher
registrations (Glm4vProcessingStrategy / Glm46VProcessingStrategy,
Glm46VProcessor, Glm4vProcessor, ProcessingStrategy) are consistent.
- Around line 647-653: The special_tokens_map lookup for "boi_token" is
ineffective for real Gemma3/Gemma checkpoints because these tokenizers expose
boi_token as a direct attribute, so replace the special_tokens_map branch with a
direct attribute check on processor.tokenizer (use getattr(processor.tokenizer,
"boi_token", None)); if found, set self.image_token to that value and
self.image_token_id by converting it with
processor.tokenizer.convert_tokens_to_ids (and keep the existing fallback that
checks boi_token_id/boi_token_id-like attributes); update references to
special_tokens_map/boi_token to use the tokenizer attribute to ensure production
models set image_token/image_token_id.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bb1d73db-8316-4cf7-8f0a-ef5a71d9e6c2
📒 Files selected for processing (6)
- docs/multimodal_assistant_mask.md
- src/axolotl/core/builders/causal.py
- src/axolotl/processing_strategies.py
- src/axolotl/utils/schemas/datasets.py
- src/axolotl/utils/schemas/multimodal.py
- tests/test_processing_strategies.py
Pre-commit failure: trailing newline missing on docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``. Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end with ``include_end=False``, intentionally re-matched as assistant-start) leaked into loss via the user branch on Pixtral / Mistral V7 Tekken. Fix: gate the non-trainable branch on ``best_match.include_end`` to mirror the trainable branch.
2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``, which never fires on real checkpoints (``special_tokens_map`` only holds HF's standard slots: bos/eos/pad/unk/...). Swap to direct attribute read ``getattr(tokenizer, "boi_token", None)``, matching what ``transformers.models.gemma3.processing_gemma3`` itself does. Updated the ``_gemma_tokenizer`` test fixture to mirror real-model shape so the test exercises the production code path.
3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V / GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell through to the base fallback. Both processors ship identical media-token markers, so register both under the shared ``Glm4vProcessingStrategy`` with independent try/except import blocks. Updated class docstring. +2 dispatcher regressions.
4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token. Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")`` with unk-id guard; fall back to 262144 only if the string isn't in vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.
5. ``build_collator`` was called twice per ``build()`` (eval + train passes), producing two identical ``MM collator: ...`` INFO banners on startup. Gate the log on ``is_eval=False`` so only the training pass emits it.
6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0, always returned ``None``; the dispatcher already handles missing ``mistral_common`` via lazy import + ``try/except``).

Added ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``, a focused scanner-level lock-in for finding #1, independent of any specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
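Finding 1 above reduces to a small decision rule for whether a span's end marker contributes to loss. The sketch below is an illustration under stated assumptions, not the scanner code: the function name is made up, the trainable branch is assumed to train the end marker for `"turn"` and `"all"`, and the `"last"` mode is deliberately left out.

```python
def end_marker_in_loss(role_trainable, include_end, train_on_eos, found_end=True):
    """Decide whether a span's end marker keeps its label (i.e. is trained on).

    After the fix, `include_end` gates both branches, so a marker declared
    with include_end=False never enters the loss through the role that
    closes on it.
    """
    if not found_end or not include_end:
        return False
    if role_trainable:
        return train_on_eos in ("turn", "all")  # assumed trainable-branch rule
    return train_on_eos == "all"


# Pixtral-style user turn: [/INST] is the user end marker with include_end=False.
shared_marker_leaks = end_marker_in_loss(
    role_trainable=False, include_end=False, train_on_eos="all"
)
# shared_marker_leaks is False: the token is left for the assistant-start match.
```

Before the fix, the non-trainable branch behaved as if `include_end` were always true under `train_on_eos="all"`, which is the leak the new regression test pins down.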
✅ Actions performed: review triggered.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/processing_strategies.py (1)
1110-1144: ⚠️ Potential issue | 🟡 Minor

Minor: `Glm4vProcessingStrategy` should guard media ids against `unk_token_id`.

`convert_tokens_to_ids` returns `unk_token_id` for strings not in vocab. If any of `<|image|>`, `<|begin_of_image|>`, …, `<|end_of_video|>` is missing from a given checkpoint's vocab (e.g. a fine-tune that stripped a marker, or a future upstream rename), the `if tok_id is not None` check at line 1141 still passes and `labels[labels == unk_id] = -100` would mask every genuine unk token in the batch. `Gemma4ProcessingStrategy.process_labels` (lines 746-753) already uses the `token_id is None or token_id == unk_id → skip` pattern; mirroring it here keeps the class resilient to checkpoint drift.

🛡️ Proposed fix

```diff
 def process_labels(self, input_ids):
     labels = input_ids.clone()
     labels = self._mask_non_assistant(labels)
     pad_id = getattr(self.tokenizer, "pad_token_id", None)
     if pad_id is not None:
         labels[labels == pad_id] = -100
+    unk_id = getattr(self.tokenizer, "unk_token_id", None)
     for tok_id in (
         self.image_token_id,
         self.begin_image_token_id,
         self.end_image_token_id,
         self.video_token_id,
         self.begin_video_token_id,
         self.end_video_token_id,
     ):
-        if tok_id is not None:
+        if tok_id is not None and tok_id != unk_id:
             labels[labels == tok_id] = -100
     return labels
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/processing_strategies.py` around lines 1110 - 1144, The process_labels method in Glm4vProcessingStrategy currently masks token ids even when convert_tokens_to_ids returned the tokenizer's unk token id; update the loop that iterates over self.image_token_id, self.begin_image_token_id, self.end_image_token_id, self.video_token_id, self.begin_video_token_id, self.end_video_token_id to first fetch the tokenizer's unk id (e.g., unk_id = getattr(self.tokenizer, "unk_token_id", None)) and skip masking for any tok_id that is None or equal to unk_id (i.e., only apply labels[labels == tok_id] = -100 when tok_id is not None and tok_id != unk_id) so genuine unk tokens are not incorrectly masked; reference the Glm4vProcessingStrategy.process_labels method and the image_* / video_* token id attributes when making the change.
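The failure mode described above is easy to reproduce without a real checkpoint. In the sketch below, `FakeTokenizer` is a hypothetical stand-in that, like the behavior the comment describes, maps unknown strings to `unk_token_id`; the helper names and ids are made up, and labels are plain lists rather than tensors.

```python
IGNORE_INDEX = -100


class FakeTokenizer:
    """Hypothetical stand-in: unknown strings resolve to unk_token_id."""

    unk_token_id = 0

    def __init__(self, vocab):
        self._vocab = vocab

    def convert_tokens_to_ids(self, token):
        return self._vocab.get(token, self.unk_token_id)


def resolve_media_ids(tokenizer, tokens):
    """Keep only ids that resolved to a real vocab entry (not None, not unk)."""
    unk_id = getattr(tokenizer, "unk_token_id", None)
    resolved = []
    for token in tokens:
        tok_id = tokenizer.convert_tokens_to_ids(token)
        if tok_id is not None and tok_id != unk_id:
            resolved.append(tok_id)
    return resolved


def mask_media(labels, media_ids):
    """Replace media-token labels with IGNORE_INDEX; leave everything else."""
    drop = set(media_ids)
    return [IGNORE_INDEX if tok in drop else tok for tok in labels]


tokenizer = FakeTokenizer({"<|image|>": 7})  # <|begin_of_image|> is missing
media_ids = resolve_media_ids(tokenizer, ["<|image|>", "<|begin_of_image|>"])
masked = mask_media([0, 7, 5], media_ids)
# masked == [0, -100, 5]: the genuine unk token (id 0) keeps its label.
```

Without the `tok_id != unk_id` guard, the missing marker would resolve to id 0 and the first label in the example would be masked as well.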
🧹 Nitpick comments (1)
src/axolotl/processing_strategies.py (1)
357-373: Optional: avoid per-position `.tolist()` on tensor slices in the inner scanner.

`_match_prefix` is called O(n·b) times in the outer loop plus O(n) times inside `_find_end` per matched span, and each call materializes a Tensor slice via `.tolist()`. On typical MM sequences (n≈4k, b≈5, 8/batch) this becomes a measurable per-step overhead you pay on every training batch. Converting `labels[i]` to a Python list once per row eliminates the Python↔C boundary crossings in the hot loop without changing any semantics.

♻️ Proposed fix

```diff
-    def _match_prefix(label: Tensor, start_pos: int, tok_seq: list[int]) -> bool:
-        if not tok_seq or start_pos + len(tok_seq) > len(label):
+    def _match_prefix(label: list[int], start_pos: int, tok_seq: list[int]) -> bool:
+        if not tok_seq or start_pos + len(tok_seq) > len(label):
             return False
-        return label[start_pos : start_pos + len(tok_seq)].tolist() == tok_seq
+        return label[start_pos : start_pos + len(tok_seq)] == tok_seq

     def _find_end(
-        label: Tensor, start_pos: int, end_tok: list[int]
+        label: list[int], start_pos: int, end_tok: list[int]
     ) -> tuple[int, bool]:
         # Empty end_tok means run to end-of-sequence.
         if not end_tok:
             return len(label), False
@@
     for i in range(labels.shape[0]):
-        label = labels[i]
+        label = labels[i].tolist()
         j = 0
         n = len(label)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/processing_strategies.py` around lines 357 - 373, The inner scanner currently calls _match_prefix which uses label[start_pos : start_pos + len(tok_seq)].tolist() on every invocation, causing costly Tensor->Python transitions; change the scanning logic so you convert each label Tensor to a Python list once before scanning (e.g., at the start of the outer per-row loop or by accepting a pre-converted list) and then have _match_prefix and _find_end operate on that Python list (compare slices or use tuple/list comparisons) instead of calling .tolist() inside the hot loop; update function signatures or call sites for _match_prefix(label: Tensor, ...) and _find_end(label: Tensor, ...) accordingly to use the pre-converted list representation for label.
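A standalone version of the list-based helpers from the proposed fix can be sketched as follows. Only the signatures and the empty-marker rule come from the diff above; the body of `find_end` past the elided `@@` hunk is not shown in this thread, so returning the index where the end marker begins is an assumption. In the real code the row would be converted once with `labels[i].tolist()` before scanning.

```python
def match_prefix(label: list[int], start_pos: int, tok_seq: list[int]) -> bool:
    """True if `tok_seq` occurs in `label` starting at `start_pos`."""
    if not tok_seq or start_pos + len(tok_seq) > len(label):
        return False
    return label[start_pos : start_pos + len(tok_seq)] == tok_seq


def find_end(label: list[int], start_pos: int, end_tok: list[int]) -> tuple[int, bool]:
    """Return (index where the end marker begins, whether it was found)."""
    # Empty end_tok means run to end-of-sequence.
    if not end_tok:
        return len(label), False
    for pos in range(start_pos, len(label) - len(end_tok) + 1):
        if match_prefix(label, pos, end_tok):
            return pos, True
    # Truncated span: the end marker never appears.
    return len(label), False


row = [5, 1, 2, 9, 3, 4]  # stands in for labels[i].tolist()
hit = match_prefix(row, 1, [1, 2])
end_pos, found = find_end(row, 0, [3, 4])
```

Both helpers now compare plain list slices, so the hot loop performs no tensor-to-list conversions.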
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/multimodal_assistant_mask.md`:
- Around line 1-31: Change the second heading from h3 to h2 (replace "###
Correct placement" with "## Correct placement") and add a language tag to the
fenced block showing build_collator output (e.g., replace the triple-backtick
fence with ```text) so the docs pass markdownlint rules MD001 and MD040; locate
these edits near the "Correct placement" heading and the code block that begins
with "MM collator: train_on_inputs=False..." in multimodal_assistant_mask.md.
---
Outside diff comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 1110-1144: The process_labels method in Glm4vProcessingStrategy
currently masks token ids even when convert_tokens_to_ids returned the
tokenizer's unk token id; update the loop that iterates over
self.image_token_id, self.begin_image_token_id, self.end_image_token_id,
self.video_token_id, self.begin_video_token_id, self.end_video_token_id to first
fetch the tokenizer's unk id (e.g., unk_id = getattr(self.tokenizer,
"unk_token_id", None)) and skip masking for any tok_id that is None or equal to
unk_id (i.e., only apply labels[labels == tok_id] = -100 when tok_id is not None
and tok_id != unk_id) so genuine unk tokens are not incorrectly masked;
reference the Glm4vProcessingStrategy.process_labels method and the image_* /
video_* token id attributes when making the change.
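A minimal sketch of the guard this comment asks for, using plain lists instead of tensors. The helper name `mask_media_tokens` and the example ids are hypothetical; the actual change would live in `Glm4vProcessingStrategy.process_labels`.

```python
IGNORE_INDEX = -100  # label value ignored by the loss


def mask_media_tokens(labels, media_token_ids, unk_token_id):
    """Mask media tokens, skipping ids that are None or resolved to the unk id."""
    valid = {
        tok_id
        for tok_id in media_token_ids
        if tok_id is not None and tok_id != unk_token_id
    }
    return [IGNORE_INDEX if tok in valid else tok for tok in labels]


# Hypothetical ids: 77 is a real image token; a second media token was not in
# the vocabulary and resolved to the unk id 0; a third resolved to None.
labels = [0, 5, 77, 0, 78, 9]
masked = mask_media_tokens(labels, media_token_ids=[77, 0, None], unk_token_id=0)
```

Without the guard, the two genuine unk tokens (id 0) in `labels` would be masked along with the image token.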
📒 Files selected for processing (4)
- docs/multimodal_assistant_mask.md
- src/axolotl/core/builders/causal.py
- src/axolotl/processing_strategies.py
- tests/test_processing_strategies.py
…trings

- Scanner perf: convert `labels[i]` to a Python list once per row so `_match_prefix` / `_find_end` operate on list slices instead of re-materializing Tensor slices via `.tolist()` on every probe. Cuts O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2 under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@winglian thank you!
Thanks for the PR! I think the next step from here would be to review speed / performance cost for this change and whether the implementation can be improved upon.
will do, I'll look into benchmarking on my end once my GPUs free up
I benched the PR: the scanner cost is ~0.21 µs/token (linear, strategy-independent), which works out to ~7 ms at micro batch=8 × seq=4096, well under 0.1% of step time and fully overlapped with GPU compute via dataloader workers. Digging deeper, two directions seem viable for reducing the latency further without moving the work into preprocessing (although that could be a future option):

(a) Vectorized scanner: I tested this on a local branch (`_apply_role_boundaries_vectorized` alongside the original) and it is ~1.5–2× faster (Gemma4 1.64×, Llama3.2V 1.50×, Pixtral 1.93×). This approach tested well and seems valid.

(b) Native `return_assistant_tokens_mask=True` from `apply_chat_template`: HF's tokenizers/processors expose this flag for templates with `{% generation %}` blocks. Where supported, the tokenizer returns the assistant mask directly, eliminating the scanner, the per-strategy boundary declarations, and the `cfg.role_boundaries` override. This is closer in spirit to the text path (`prompt_strategies/chat_template.py`), which defers boundary detection to template machinery (`find_turn` does a template-diff at preprocess time) rather than running its own per-batch token scan. The work would probably be: detect support at strategy init, use native when available, and keep the current scanner as fallback. A bit more effort, but architecturally consistent.

My take is that the scanner affords maximum flexibility at a noise-level latency tax. Improving it with the vectorized approach makes it a bit more palatable; frankly, multimodal training is so much lengthier than text that a <1% tax is an acceptable trade-off pending a more complex refactor.

Making (b) work with the same flexibility (given the diversity of chat templates for models without `{% generation %}` blocks) would be a strong contender for a third, future commit: native mask as a fast path for the supported common case (`roles_to_train=["assistant"]`, `train_on_eos="turn"`, template has generation blocks), scanner as the fallback for multi-role training, non-default EOS modes, `cfg.role_boundaries` overrides, and templates without generation blocks. That said, in hindsight perhaps (b) would have been more appropriate; I know your strategy has been to enforce a standard chat template (and rightfully so, given the complexity of chat template handling in llama.cpp and the like) and this deviates from it, so part of this may have been my own preferences leaking in.
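The ~7 ms figure quoted above follows directly from the per-token cost. A quick check, with a hypothetical helper name; the break-even step time is derived here from the stated 0.1% bound and is not a measured value:

```python
def scanner_cost_ms(us_per_token: float, micro_batch: int, seq_len: int) -> float:
    """Per-batch scanner cost in milliseconds for a linear per-token cost."""
    return us_per_token * micro_batch * seq_len / 1000.0


cost_ms = scanner_cost_ms(0.21, micro_batch=8, seq_len=4096)  # about 6.88 ms

# For the cost to stay under 0.1% of step time, a step must take longer than this.
min_step_seconds = cost_ms / 0.001 / 1000.0  # about 6.9 s
```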
You can try timing for option b, but like mentioned earlier, it was very slow when I first tested it. |
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (`_compute_role_keep_mask_vectorized`): precomputes start/end match positions in a single batched unfold pass, collapsing the inner Python prefix-match into a vectorized all-over-last-dim check. Byte-identical to the reference scanner, gated on `torch.get_num_threads() == 1` (the DataLoader-worker default) since the per-op dispatch overhead regresses it under multi-threaded torch. 1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie on Llama 3 (more boundaries, more tensor ops).
- Fused `process_labels`: composes role/pad/media masks into one boolean keep tensor and does a single `labels[~keep] = -100` write instead of the previous 3-4 sequential masked writes. Structural cleanup; the measured speedup is ~0ms since the prior masked-write cost was already dwarfed by the scanner.
- One-shot warning when `MultiModalChatDataCollator` is built with `dataloader_num_workers in (None, 0)`: MM training does heavier collator-side work than text and synchronous collation blocks the GPU.

Tests:

- `tests/test_vectorized_scanner.py`: 2,000-config differential fuzz + targeted edge cases (Pixtral `[/INST]` rewind, longest-prefix tie, `train_on_eos` modes, empty `end_tokens`, `include_end` leak gate).
- `tests/test_process_labels_fusion.py`: 50 random batches per strategy across all 9 `process_labels`-overriding strategies (450 fuzz cases) asserting `torch.equal` between pre-refactor and fused implementation.
- `tests/test_mm_collator_warnings.py`: `num_workers=0` warning regression.
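The fused `process_labels` idea, and the parity check that the fusion tests describe, can be illustrated with plain lists. This is a sketch under assumptions: the function names and token ids are made up, and the real implementation works on tensors with a single `labels[~keep] = -100` write.

```python
IGNORE_INDEX = -100  # label value ignored by the loss


def sequential_mask(labels, role_keep, pad_id, media_ids):
    """Pre-refactor style: one masking pass per concern."""
    out = [tok if keep else IGNORE_INDEX for tok, keep in zip(labels, role_keep)]
    out = [IGNORE_INDEX if tok == pad_id else tok for tok in out]
    for media_id in media_ids:
        out = [IGNORE_INDEX if tok == media_id else tok for tok in out]
    return out


def fused_mask(labels, role_keep, pad_id, media_ids):
    """Fused style: build one boolean keep mask, then apply it once."""
    drop = set(media_ids) | {pad_id}
    keep = [rk and tok not in drop for tok, rk in zip(labels, role_keep)]
    return [tok if k else IGNORE_INDEX for tok, k in zip(labels, keep)]


# Hypothetical ids: 0 is padding, 900 is an image placeholder token.
labels = [0, 7, 900, 8, 9, 0]
role_keep = [False, True, True, True, False, False]
fused = fused_mask(labels, role_keep, pad_id=0, media_ids=[900])
legacy = sequential_mask(labels, role_keep, pad_id=0, media_ids=[900])
```

Both paths should agree on every input, which is the property the fuzz tests assert with `torch.equal`.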
…3672)

* perf(mm-mask): vectorized role-boundary scanner + fused process_labels

Follow-up to #3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (`_compute_role_keep_mask_vectorized`): precomputes start/end match positions in a single batched unfold pass, collapsing the inner Python prefix-match into a vectorized all-over-last-dim check. Byte-identical to the reference scanner, gated on `torch.get_num_threads() == 1` (the DataLoader-worker default) since the per-op dispatch overhead regresses it under multi-threaded torch. 1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie on Llama 3 (more boundaries, more tensor ops).
- Fused `process_labels`: composes role/pad/media masks into one boolean keep tensor and does a single `labels[~keep] = -100` write instead of the previous 3-4 sequential masked writes. Structural cleanup; the measured speedup is ~0ms since the prior masked-write cost was already dwarfed by the scanner.
- One-shot warning when `MultiModalChatDataCollator` is built with `dataloader_num_workers in (None, 0)`: MM training does heavier collator-side work than text and synchronous collation blocks the GPU.

Tests:

- `tests/test_vectorized_scanner.py`: 2,000-config differential fuzz + targeted edge cases (Pixtral `[/INST]` rewind, longest-prefix tie, `train_on_eos` modes, empty `end_tokens`, `include_end` leak gate).
- `tests/test_process_labels_fusion.py`: 50 random batches per strategy across all 9 `process_labels`-overriding strategies (450 fuzz cases) asserting `torch.equal` between pre-refactor and fused implementation.
- `tests/test_mm_collator_warnings.py`: `num_workers=0` warning regression.

* fix: address coderabbit feedback and lint

* test(mm-mask): add Voxtral/Mistral3/InternVL/Glm4v strategy coverage

- `test_process_labels_fusion.py`: factories + parity-fuzz entries for the four strategies (200 new fused-vs-legacy assertions); no-boundaries branch in `_build_alternating_turns` now injects pad/extras so the pad/media masking path is exercised.
- `test_processing_strategies.py`: behavioral tests (`process_labels`, `train_on_inputs`, `role_boundaries_override`, batch handling, InternVL error paths) plus dispatch routing for `VoxtralProcessor`, `Mistral3Processor`, and `InternVLProcessor`.

572 -> 788 passing across the PR's four test files.

* perf(mm-mask): bisect end-finder + slice-fill in vectorized scanner

The vectorized scanner's per-row loop still did O(span) Python work: a linear walk over end-match flags to find each turn's end, and element-wise bytearray writes to fill spans. Replace both:

- Precompute each boundary's end-match start positions (sorted via `nonzero`) and bisect for the next end >= start_of_content.
- Fill keep-spans with bytearray slice assignment and materialize each row tensor once via `torch.frombuffer`.

Byte-identical to the reference scanner; ~2x faster than the prior vectorized path (≈3x vs reference) on 8x3-6k batches, widening with turn length. Adds long-span / multi-end-marker parity fuzz that the original filler<=200 generator under-exercised.

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
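The bisect end-finder and slice-fill described in the last commit can be sketched in pure Python. This is an illustration under assumptions, not the merged code: it handles one start/end marker pair, finds match positions with a list comprehension rather than a batched `unfold`, and returns a `bytearray` where the real scanner materializes a tensor via `torch.frombuffer`.

```python
from bisect import bisect_left


def keep_mask_bisect(row: list[int], start_tok: list[int], end_tok: list[int]) -> bytearray:
    """Return a 0/1 keep mask for spans opened by start_tok and closed by end_tok."""
    n = len(row)

    def match_positions(tok: list[int]) -> list[int]:
        m = len(tok)
        if m == 0:
            return []
        # Positions come out in ascending order, so the list is ready for bisect.
        return [i for i in range(n - m + 1) if row[i : i + m] == tok]

    starts = match_positions(start_tok)
    ends = match_positions(end_tok)
    keep = bytearray(n)  # all zeros: everything masked by default
    cursor = 0
    for s in starts:
        if s < cursor:
            continue  # start marker falls inside a span already consumed
        content = s + len(start_tok)
        idx = bisect_left(ends, content)  # first end marker at or after content
        stop = ends[idx] + len(end_tok) if idx < len(ends) else n
        keep[content:stop] = b"\x01" * (stop - content)  # one slice write per span
        cursor = max(stop, s + 1)
    return keep


# Hypothetical ids: [10, 11] opens an assistant turn, [12] closes it.
row = [5, 6, 10, 11, 40, 41, 12, 5, 7, 10, 11, 42]
mask = list(keep_mask_bisect(row, start_tok=[10, 11], end_tok=[12]))
```

Per span, the work is one binary search plus one slice assignment, rather than a token-by-token walk.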
Description
Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` / `cfg.train_on_eos` in the multimodal (VLM) training path. Before this PR, only Gemma 3n honored these knobs; every other multimodal model trained on the entire sequence (system + user + assistant) regardless of config, silently turning assistant-only SFT into full-sequence SFT.
What changed
- `ProcessingStrategy` gains a declarative boundary scanner. Each strategy declares per-role start/end markers via `_build_role_boundaries`; a shared scanner walks the re-tokenized sequence and honors `train_on_inputs` / `roles_to_train` / `train_on_eos` (including `"last"`).
- `cfg.role_boundaries` YAML override (opt-in): declare per-role markers directly in config without subclassing. Intended as an escape hatch for strategies whose tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava / lfm2vl / unknown template) and for custom chat templates. Leaving the field unset (or setting it to `[]`) uses the strategy's built-in markers; the built-ins are replaced only when you supply a non-empty list.
Usage
Masking knobs are per-dataset (under `datasets:` / `test_datasets:`), not top-level. Only `train_on_inputs` lives at the root.
For a fallback strategy or custom template, add `cfg.role_boundaries` at the root (it replaces any built-in markers).
Notes:
- `role_boundaries` is an opt-in override. A non-empty list replaces the strategy's built-in markers.
- `start` / `end` are literal strings; they're encoded at strategy init via `tokenizer.encode(..., add_special_tokens=False)` and the resolved token ids are logged at INFO so mismatches are visible in the training log.
- `cfg.roles_to_train` still governs which declared roles contribute to loss. Declaring `user` / `system` / `tool` boundaries lets the scanner correctly identify their spans as masking boundaries.
- Invalid specs (missing `role` / `start`, unencodable markers) raise at strategy init, not silently at loss-compute time.

See `docs/multimodal_assistant_mask.md` for the full audit table and per-strategy boundary markers.
Verifying it works
At startup, `build_collator` and each `ProcessingStrategy.__init__` emit INFO lines reporting the resolved masking configuration.
Motivation and Context
`MultiModalChatDataCollator` re-tokenizes raw `messages` at collation time via `processor.apply_chat_template`, discarding the per-role labels computed by `ChatTemplateStrategy.tokenize_prompt` in preprocessing. It then called `processing_strategy.process_labels(input_ids)` to rebuild role-aware labels, but the base `_mask_non_assistant` was a no-op (`return labels`). Only `Gemma3nProcessingStrategy` overrode it. As a result, for every multimodal model except Gemma 3n, the re-tokenized labels were never masked by role.
How has this been tested?
Offline unit tests: 63 passing
`pytest tests/test_processing_strategies.py` (no HF Hub access required).
Coverage:
- `train_on_inputs`, `roles_to_train` (incl. empty list), `train_on_eos` (`"turn"` / `"all"` / `"none"` / `"last"`), longest-prefix start match, truncated spans, scanner rewind on shared tokens.
- `image_pad`, `video_pad`, `boi` / `eoi` / `boa` / `eoa` where applicable.
- `cfg.role_boundaries` override: non-empty replaces built-ins wholesale, empty list falls through to built-ins (opt-in semantics), enables unverified strategies, `eos_token` sentinel, `null` end, spec-validation errors, Pydantic model input.
- `SFTDataset` schema accepts every value the scanner honors for `train_on_eos` (`"all"` / `"turn"` / `"last"` / `"none"`) and rejects bogus values (regression against a prior `Literal` that dropped `"none"`).
- (`caplog`).
End-to-end against real tokenizers
- `google/gemma-4-E2B-it`: 13/40 tokens kept for a 2-turn chat; decoded preview shows only assistant responses + `<turn|>` markers retained.
- `axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer` with bundled `llama3_2_vision.jinja`: 11/64 tokens kept; decoded content resolves to `"The capital of France is Paris.<|eot_id|>"` and `"Berlin.<|eot_id|>"`.
- Gemma 4 boundary ids verified against the real tokenizer: `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]`, `<|image|>` → `258880`, `<|audio|>` → `258881`, `<|video|>` → `258884`.
Real training run
Live SFT on Gemma 4 E2B, single RTX 5090, FSDP2 + CPU offload, `lora_r=256` rsLoRA, `sequence_len=5048`: `tokens/total: 59664`, `tokens/trainable: 13945` (~23.4% unmasked, consistent with assistant-only masking on an OCR conversation where user content dominates).
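The unmasked fraction reported above can be reproduced from the two logged counters, and the same ratio can be computed for any label row. A small sketch; the helper name is hypothetical, and how the trainer itself counts tokens (for example, whether padding is included in `tokens/total`) is not shown in the log.

```python
IGNORE_INDEX = -100  # label value ignored by the loss


def trainable_fraction(labels: list[int]) -> float:
    """Fraction of positions that still contribute to the loss."""
    if not labels:
        return 0.0
    trainable = sum(1 for tok in labels if tok != IGNORE_INDEX)
    return trainable / len(labels)


# Counters taken from the training log quoted above.
reported_pct = round(100 * 13945 / 59664, 1)  # 23.4

# Toy row: two masked prompt tokens followed by two assistant tokens.
toy = trainable_fraction([IGNORE_INDEX, IGNORE_INDEX, 40, 41])
```

A value near 100% on an assistant-only run would suggest the masking knobs are being ignored, which is the failure mode this PR fixes.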
Back-compat
`cfg.processor_type and self.processor`. `ChatTemplateStrategy` defaults (`roles_to_train=["assistant"]`, `train_on_eos="turn"`).
AI Usage Disclaimer
Yes. Claude Opus 4.7 (via Claude Code) assisted throughout. All commits are co-authored and were reviewed against real training logs by the submitter before inclusion; the changes and tests were also reviewed and tested by Codex. Design choices and approach were human-guided.
Types of changes
- `train_on_inputs` / `roles_to_train` / `train_on_eos` in the multimodal path.
- `cfg.role_boundaries` YAML override, new per-template strategies, and INFO-log visibility of resolved masking config.
- `docs/multimodal_assistant_mask.md` (audit table, root cause, design, config guide).