
feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries #3625

Merged
winglian merged 10 commits into axolotl-ai-cloud:main from thad0ctor:feat/multimodal-assistant-mask-all
May 5, 2026

Conversation

thad0ctor commented Apr 24, 2026

Description

Fixes silent ignoring of cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos in the multimodal (VLM) training path. Before this PR, only Gemma 3n honored these knobs — every other multimodal model trained on
the entire sequence (system + user + assistant) regardless of config, silently turning assistant-only SFT into full-sequence SFT.

What changed

  • ProcessingStrategy gains a declarative boundary scanner. Each strategy declares per-role start/end markers via _build_role_boundaries; a shared scanner walks the re-tokenized sequence and honors train_on_inputs / roles_to_train / train_on_eos (including "last").
  • New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken.
  • Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner → shared).
  • cfg.role_boundaries YAML override (opt-in): declare per-role markers directly in config without subclassing. Intended as an escape hatch for strategies whose tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava / lfm2vl / unknown template) and for custom chat templates. Leaving the field unset (or setting it to []) uses the strategy's built-in markers; only a non-empty list replaces the built-ins.
  • Visibility: fallback strategies emit a one-shot warning (no silent mis-masking) and every strategy logs its resolved masking config at INFO.
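
The boundary-scanning idea in the list above can be illustrated with a minimal pure-Python sketch. This is not the PR's implementation: `Boundary` and `mask_by_role` are hypothetical names, token ids are plain lists rather than tensors, and `train_on_inputs` / `train_on_eos` handling is left out. End markers of trainable turns are always kept here.

```python
from dataclasses import dataclass
from typing import Optional

IGNORE_INDEX = -100  # label value conventionally skipped by the loss


@dataclass
class Boundary:  # hypothetical stand-in for the PR's RoleBoundary
    role: str
    start: list[int]            # token ids of the start marker (non-empty)
    end: Optional[list[int]]    # None: span runs to the end of the sequence


def mask_by_role(ids, boundaries, roles_to_train=("assistant",)):
    """Keep content + end marker of trainable roles; mask everything else."""
    labels = [IGNORE_INDEX] * len(ids)
    i = 0
    while i < len(ids):
        # Longest-prefix match: a longer start marker beats a shorter one.
        hits = [b for b in boundaries if ids[i:i + len(b.start)] == b.start]
        if not hits:
            i += 1
            continue
        b = max(hits, key=lambda h: len(h.start))
        j = i + len(b.start)  # the start marker itself stays masked
        end = j
        while end < len(ids) and (b.end is None or ids[end:end + len(b.end)] != b.end):
            end += 1
        # A truncated span (no end marker found) simply runs to the end.
        stop = min(len(ids), end + (len(b.end) if b.end else 0))
        if b.role in roles_to_train:
            labels[j:stop] = ids[j:stop]
        i = max(stop, i + 1)
    return labels
```

With a user span `[1, 2] … [9]` followed by an assistant span `[1, 3] … [9]`, only the assistant content and its closing marker keep their ids; every other position becomes -100.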

Usage

Masking knobs are per-dataset (under datasets: / test_datasets:), not top-level. Only train_on_inputs lives at the root:

train_on_inputs: false

datasets:
  - path: data/train.jsonl
    type: chat_template
    roles_to_train: [assistant]
    train_on_eos: turn         # "turn" | "all" | "none" | "last"

test_datasets:
  - path: data/val.jsonl
    type: chat_template
    split: train
    roles_to_train: [assistant]
    train_on_eos: turn

For a fallback strategy or custom template, add cfg.role_boundaries at the root (it replaces any built-in markers):

# Each entry declares one role's token-level span. `role` must match the
# names you use in `roles_to_train` (e.g. assistant, user, system, tool,
# ipython). Any role not declared here is treated as masked by default.
role_boundaries:
  - role: assistant
    start: "<|turn>model"
    end: "<turn|>"
    # Optional per-entry keys (shown with defaults):
    # include_start: false        # start marker never contributes to loss
    # include_end: true           # end marker contributes on trainable turns
    #                             # (further gated by cfg.train_on_eos)

  - role: user
    start: "<|turn>user"
    end: "<turn|>"

  - role: system
    start: "<|turn>system"
    end: "<turn|>"

  - role: tool
    start: "<|turn>tool"
    end: "<turn|>"

  - role: ipython               # e.g. Llama 3.2 Vision / Llama 4 tool-call replies
    start: "<|start_header_id|>ipython<|end_header_id|>\n\n"
    end: "<|eot_id|>"

  # Special `end:` values you can mix in on any entry:
  # - role: assistant            # Pixtral-style: assistant span ends at EOS
  #   start: "[/INST]"
  #   end: eos_token             # sentinel → resolves to tokenizer.eos_token_id
  # - role: assistant            # span runs to end of sequence (no explicit close)
  #   start: "<|turn>model"
  #   end: null

Notes:

  • role_boundaries is an opt-in override. A non-empty list replaces the strategy's built-in markers.
  • start / end are literal strings; they're encoded at strategy init via tokenizer.encode(..., add_special_tokens=False) and the resolved token ids are logged at INFO so mismatches are visible in the training log.
  • cfg.roles_to_train still governs which declared roles contribute to loss. Declaring user / system / tool boundaries lets the scanner correctly identify their spans as masking boundaries.
  • Invalid specs (missing role / start, unencodable markers) raise at strategy init — not silently at loss-compute time.
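
The opt-in and fail-fast rules in these notes can be sketched roughly as follows. This is an illustration under assumptions: `resolve_boundaries` is a hypothetical helper, specs are plain dicts shaped like the YAML entries, and marker encoding through the tokenizer is omitted.

```python
def resolve_boundaries(builtin, override):
    """Pick the effective boundary specs and validate them up front.

    Hypothetical helper: `builtin` and `override` are lists of dicts with
    the same keys as the YAML entries (role, start, end, ...).
    """
    # Opt-in semantics: None or [] falls back to the built-in markers;
    # only a non-empty override replaces them.
    specs = override if override else builtin
    for spec in specs:
        if not spec.get("role") or not spec.get("start"):
            # Fail at init time rather than silently at loss-compute time.
            raise ValueError(f"role boundary needs 'role' and 'start': {spec!r}")
    source = "override" if override else "built-in"
    return specs, source
```

Returning the source alongside the specs makes it easy to log which set of markers ended up in effect.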

See docs/multimodal_assistant_mask.md for the full audit table and per-strategy boundary markers.

Verifying it works

At startup, build_collator and each ProcessingStrategy.__init__ emit INFO lines:

MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
ProcessingStrategy init: class=Gemma4ProcessingStrategy ... boundaries_source=built-in boundaries=3

Motivation and Context

MultiModalChatDataCollator re-tokenizes raw messages at collation time via processor.apply_chat_template, discarding the per-role labels computed by ChatTemplateStrategy.tokenize_prompt in preprocessing. It then called processing_strategy.process_labels(input_ids) to rebuild role-aware labels, but the base _mask_non_assistant was a no-op (return labels), and only Gemma3nProcessingStrategy overrode it. As a result, for every multimodal model except Gemma 3n, the re-tokenized labels were never masked by role.

How has this been tested?

Offline unit tests — 63 passing

pytest tests/test_processing_strategies.py — no HF Hub access required.

Coverage:

  • Scanner semantics: train_on_inputs, roles_to_train (incl. empty list), train_on_eos ("turn" / "all" / "none" / "last"), longest-prefix start match, truncated spans, scanner rewind on shared tokens.
  • Per-strategy masking with fake tokenizers mirroring real-model id layouts: Qwen2-VL, Qwen3.5, Gemma 3 / 3n / 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken.
  • Media-token masking inside assistant spans (image_pad, video_pad, boi/eoi/boa/eoa where applicable).
  • Dispatcher routing including Mistral3 / GLM4V lazy-import fallbacks.
  • cfg.role_boundaries override: non-empty replaces built-ins wholesale, empty list falls through to built-ins (opt-in semantics), enables unverified strategies, eos_token sentinel, null end, spec-validation errors, Pydantic model input.
  • SFTDataset schema accepts every value the scanner honors for train_on_eos ("all" / "turn" / "last" / "none") and rejects bogus values (regression against a prior Literal that dropped "none").
  • Edge cases: batch_size > 1, all-pad sequences, consecutive assistant turns, pad masking inside trainable spans.
  • Init-time INFO log is grep-visible (regression test with caplog).

End-to-end against real tokenizers

  • google/gemma-4-E2B-it: 13/40 tokens kept for a 2-turn chat; decoded preview shows only assistant responses + <turn|> markers retained.
  • axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer with bundled llama3_2_vision.jinja: 11/64 tokens kept; decoded content resolves to "The capital of France is Paris.<|eot_id|>" and "Berlin.<|eot_id|>".
  • Gemma 4 boundary ids verified against the real tokenizer: <|turn>model → [105, 4368], <turn|> → [106], <|image|> → 258880, <|audio|> → 258881, <|video|> → 258884.

Real training run

Live SFT on Gemma 4 E2B, single RTX 5090, FSDP2 + CPU offload,
lora_r=256 rsLoRA, sequence_len=5048:

  • Collator log confirms knobs threaded through:
    MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
    ProcessingStrategy init: class=Gemma4ProcessingStrategy ... boundaries_source=built-in boundaries=3
    
  • Per-step counters confirm masking is active: tokens/total: 59664, tokens/trainable: 13945 (~23.4% unmasked —
    consistent with assistant-only masking on an OCR conversation where user content dominates).
  • Training + eval progress cleanly; no "legacy behavior" warnings.
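
As a quick sanity check on the per-step counters above, the trainable share is just the fraction of label positions that are not -100. A small sketch (`trainable_fraction` is an illustrative helper, not something from the PR):

```python
IGNORE_INDEX = -100


def trainable_fraction(labels):
    """Share of label positions that are not masked out."""
    kept = sum(1 for x in labels if x != IGNORE_INDEX)
    return kept / len(labels) if labels else 0.0


# Ratio of the counters reported in the run above.
reported = 13945 / 59664
print(f"{reported:.1%}")  # prints 23.4%
```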

Back-compat

  • Text-only path unaffected: the MM path is taken only when cfg.processor_type and self.processor are both set.
  • No schema break — defaults when the per-dataset keys are unset match the text-only ChatTemplateStrategy defaults (roles_to_train=["assistant"], train_on_eos="turn").
  • Loss values will change for any multimodal run that was previously seeing full-sequence loss — this is the intended fix. Sample-efficiency should improve; compare eval loss/perplexity on a held-out set to confirm.

AI Usage Disclaimer

Yes. Claude Opus 4.7 (via Claude Code) assisted throughout. All commits are co-authored and were reviewed against real training logs by the submitter before inclusion; changes and testing were also reviewed by Codex. Design choices and approach were human-guided.

Types of changes

  • Bug fix (non-breaking change which fixes an issue) — restores honoring of train_on_inputs / roles_to_train / train_on_eos in the multimodal path.
  • New feature (non-breaking change which adds functionality) — cfg.role_boundaries YAML override, new per-template strategies, and INFO-log visibility of resolved masking config.
  • Documentation update — new docs/multimodal_assistant_mask.md
    (audit table, root cause, design, config guide).

Summary by CodeRabbit

  • Bug Fixes

    • Fixed multimodal training where token masking settings could be ignored during collation by ensuring dataset-level masking params are respected.
  • New Features

    • Role-boundary token masking for precise control of which tokens contribute to loss.
    • Optional config override to supply custom role boundary markers.
    • Added "none" option to disable training on EOS tokens.
  • Documentation

    • New guidance on configuring, verifying, and logging multimodal masking behavior.
  • Tests

    • Comprehensive tests covering masking logic, EOS modes, overrides, and strategy routing.

thad0ctor and others added 5 commits April 24, 2026 11:05
…daries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

coderabbitai Bot commented Apr 24, 2026

Walkthrough

Adds token-level, role-boundary label masking for multimodal assistants, threads dataset-level masking knobs (train_on_inputs, roles_to_train, train_on_eos, role_boundaries) from builder into collators, extends schemas for overrides, and adds comprehensive tests and docs describing configuration and verification.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Documentation: docs/multimodal_assistant_mask.md | New doc describing multimodal assistant-only loss masking, YAML examples, logging expectations, fallback/override semantics, and verification guidance. |
| Builder integration: src/axolotl/core/builders/causal.py | Collator now derives dataset-level roles_to_train/train_on_eos from first dataset and forwards train_on_inputs/roles_to_train/train_on_eos/role_boundaries_override into get_processing_strategy; logs resolved settings during collator build. |
| Processing strategies & masking core: src/axolotl/processing_strategies.py | Introduces RoleBoundary dataclass and generalized _apply_role_boundaries(); adds train_on_inputs/roles_to_train/train_on_eos/role_boundaries_override parameters to ProcessingStrategy and get_processing_strategy; implements role-span masking, built-in boundary declarations for many multimodal strategies, validation and logging, and routing for optional processors. |
| Schema updates: src/axolotl/utils/schemas/datasets.py, src/axolotl/utils/schemas/multimodal.py | SFTDataset.train_on_eos accepts "none"; added RoleBoundarySpec and optional role_boundaries to MultiModalConfig for user overrides. |
| Tests: tests/test_processing_strategies.py | Large CI-safe test suite validating role-boundary masking, train_on_eos modes (turn, all, none, last), longest-prefix matching, batch behavior, override semantics, schema parsing, and strategy dispatch. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Suggested labels

scheduled_release

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 17.95% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The pull request title clearly summarizes the main changes: implementing systemic multimodal assistant-only loss masking and adding cfg.role_boundaries configuration support. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |

thad0ctor commented Apr 24, 2026

Apologies in advance that this is a significant Opus-written PR; I can architect, but my coding is lacking. All 30+ local fork PRs were reviewed by CodeRabbit and Codex as well, and tested thoroughly with training runs and against the existing workflow tests in the repo in my fork.

coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (7)
docs/multimodal_assistant_mask.md (1)

158-160: Nit: add language to fenced code block (MD040).

The collator log example block has no language identifier. A plain text or log identifier satisfies the linter and enables consistent rendering.

📝 Proposed fix
 `build_collator` logs the resolved knobs at INFO:

-```
+```text
 MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
 ```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/multimodal_assistant_mask.md` around lines 158 - 160, Add a language
identifier to the fenced code block containing "MM collator:
train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn
role_boundaries_override=none" so it satisfies the linter (MD040); edit the
opening fence from ``` to ```text (or ```log) to mark the block as plain
text/log and keep the block content unchanged.
src/axolotl/processing_strategies.py (3)

655-659: Hardcoded 262144 for Gemma 3 soft image token is brittle.

The literal id will silently go stale if the tokenizer vocab shifts (checkpoint variant, upstream retokenization, custom fine-tune with added specials). Consider resolving via convert_tokens_to_ids("<image_soft_token>") and only falling back to the literal if the token isn't in vocab — mirrors what Gemma4ProcessingStrategy.process_labels already does for boi/eoi/boa/eoa.

♻️ Proposed fix
     def process_labels(self, input_ids):
         labels = super().process_labels(input_ids)
-        # Gemma3-specific <image_soft_token> id; not exposed as a tokenizer attribute.
-        labels[labels == 262144] = -100
+        # Gemma3-specific <image_soft_token>; resolve via tokenizer, fallback to known id.
+        tok = self.processor.tokenizer
+        soft_id = tok.convert_tokens_to_ids("<image_soft_token>")
+        unk_id = getattr(tok, "unk_token_id", None)
+        if soft_id is not None and soft_id != unk_id:
+            labels[labels == soft_id] = -100
+        else:
+            labels[labels == 262144] = -100
         return labels
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 655 - 659, The
process_labels method hardcodes the Gemma3 image soft token id (262144); change
it to resolve the id from the tokenizer first (e.g., call
tokenizer.convert_tokens_to_ids("<image_soft_token>") or the equivalent lookup
used in Gemma4ProcessingStrategy.process_labels) and only use the numeric
fallback when the token is not present in vocab; update process_labels to
compute image_soft_token_id dynamically, then replace labels[labels ==
image_soft_token_id] = -100, falling back to 262144 if the conversion returns an
unknown id.

338-437: Scanner core _apply_role_boundaries looks correct.

Verified traces for: basic masking, longest-prefix wins, multi-batch, include_start/include_end semantics, train_on_eos in {"turn","all","none","last"}, Pixtral shared-marker rewind (Line 426-427 correctly sets j = end_after - len(end_tokens) so the shared [/INST] gets re-matched as assistant-start), and truncated spans (missing end marker runs to EOS). Tests in tests/test_processing_strategies.py cover each case.

Minor optimization note for a future pass: _match_prefix calls .tolist() on a tensor slice for every position × every boundary, which is a lot of Python-side allocation for long sequences. Converting labels to a Python list once per row before the scan (or using torch.equal on a pre-built tensor per boundary) would amortize it. Deferrable: real-world MM sequences are usually at most a few thousand tokens and markers are 1-5 ids.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 338 - 437, The scan is
correct but _match_prefix repeatedly calls
label[start_pos:start_pos+len(tok_seq)].tolist(), causing many Python list
allocations; to optimize, convert each row once to a Python list (e.g., create
label_list = label.tolist() inside the for i in range(...) loop) and change
_match_prefix to compare slices from that label_list against
boundary.start_tokens (or pre-store boundary.start_tokens as tuples) so
comparisons are pure-Python without tensor-to-list conversions; alternatively,
you can compare tensors with torch.equal against pre-built tensors for each
RoleBoundary.start_tokens, but the simplest fix is to use label_list and update
_match_prefix to use it (references: function _apply_role_boundaries, local name
label, helper _match_prefix, and RoleBoundary.start_tokens).

296-325: One-shot warning for boundary-less strategies is well-designed.

The _ROLE_MASK_WARNED module-level dedupe correctly suppresses per-batch spam. One concern for distributed training: on N ranks, each rank's worker process keeps its own set, so you'll see the warning N times in aggregated logs (one per rank). That's mild noise but may be misread as "something is repeatedly misfiring." If you want rank-0-only behavior, gating on int(os.environ.get("RANK", "0")) == 0 is the usual trick. Deferrable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 296 - 325, The warning for
boundary-less strategies in _mask_non_assistant currently deduplicates
per-process via _ROLE_MASK_WARNED but still emits once per rank; modify the
branch that checks "if key not in _ROLE_MASK_WARNED" to also check the process
rank and only add/log from rank 0 (e.g., gate on int(os.environ.get("RANK",
"0")) == 0) before calling LOG.warning and adding to _ROLE_MASK_WARNED, so other
ranks skip the warning while preserving the single-shot behavior; touch the
_mask_non_assistant function and LOG.warning usage and ensure role_boundaries
behavior remains unchanged.
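
A minimal sketch of the rank-0 gating described above, assuming the launcher exports a `RANK` environment variable. `_WARNED` and `warn_once_on_rank0` are illustrative names rather than the module's actual ones.

```python
import logging
import os

LOG = logging.getLogger(__name__)
_WARNED: set[str] = set()  # hypothetical analogue of _ROLE_MASK_WARNED


def warn_once_on_rank0(key: str, message: str) -> bool:
    """Log `message` at most once per key, and only on rank 0.

    Returns True when the warning was actually emitted.
    """
    if key in _WARNED:
        return False
    _WARNED.add(key)  # recorded on every rank so repeat calls stay cheap
    if int(os.environ.get("RANK", "0")) != 0:
        return False
    LOG.warning(message)
    return True
```

Single-process runs have no `RANK` set, so the default of "0" keeps the existing one-shot behavior there.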
src/axolotl/core/builders/causal.py (2)

524-548: Per-dataset knob extraction looks correct; single-dataset assumption is well-documented.

The _ds_get helper cleanly handles the DictDefault/dict/pydantic trio, and the NOTE comment on Lines 525-526 is exactly the caveat I was going to flag (heterogeneous per-dataset masking knobs silently use the first dataset's values). Docs docs/multimodal_assistant_mask.md also call this out.

One very minor note: _ds_get is defined inside build_collator, so it gets re-created on every call (including the double call above). Hoisting it to module scope would let you unit-test it directly and avoid the re-definition. Optional.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/builders/causal.py` around lines 524 - 548, The helper
_ds_get is being defined inside build_collator causing it to be re-created on
each call; hoist _ds_get to module scope (keep the same signature and behavior:
accept cfg_obj and key, handle None, dict-like .get with safe exception
handling, then getattr) and replace the inner definition in build_collator with
calls to the new module-level function (used by roles_to_train and
train_on_eos); this lets you unit-test _ds_get directly and avoids repeated
re-definition while preserving current behavior.
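
For illustration, a module-level version of such a helper could look like the sketch below. It follows the behavior described in the comment (handle None, try a dict-like `.get`, then fall back to `getattr`), but the name `ds_get` and the `default` parameter are assumptions, not the code in causal.py.

```python
def ds_get(cfg_obj, key, default=None):
    """Read `key` from a dict-like or attribute-style config object."""
    if cfg_obj is None:
        return default
    getter = getattr(cfg_obj, "get", None)
    if callable(getter):
        try:
            value = getter(key)
        except Exception:  # deliberately broad: any lookup failure falls through
            value = None
        if value is not None:
            return value
    value = getattr(cfg_obj, key, None)
    return default if value is None else value
```

At module scope the helper can be unit-tested directly, for example `ds_get(first_dataset, "train_on_eos", "turn")`, where `first_dataset` stands for whatever object the builder holds for the first dataset entry.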

550-557: Resolved-config INFO log will fire twice per build() call.

HFCausalTrainerBuilder.build() invokes self.build_collator(...) twice — once with is_eval=True (Line 400) and once for the training collator (Line 436). Both calls reach this branch when the MM path is taken, so users will see two identical MM collator: ... INFO lines, plus two ProcessingStrategy init: ... lines from each strategy constructor. Not a correctness issue, but when someone is diagnosing "why isn't masking firing?" the duplicate lines add a "did I hit it twice?" moment.

Consider gating the log on is_eval=False, or collapsing to a single log line in the caller. Deferrable — flagging for awareness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/builders/causal.py` around lines 550 - 557, The INFO log "MM
collator: ..." in build_collator is emitted during both eval and train collator
creation because HFCausalTrainerBuilder.build calls build_collator twice (once
with is_eval=True and once for training), causing duplicate log lines; modify
build_collator to only log that INFO when is_eval is False (or add a parameter
to suppress logging) so the message (and the subsequent ProcessingStrategy logs)
are only emitted once for the training collator, referencing the build_collator
method and the LOG.info call that prints "MM collator: train_on_inputs=%s
roles_to_train=%s ...".
tests/test_processing_strategies.py (1)

1-1081: LGTM — thorough, deterministic offline test coverage.

Notable strengths:

  • axolotl_caplog fixture correctly handles propagate=False which is exactly the trap I'd expect to bite someone writing caplog-based tests against axolotl loggers.
  • _ROLE_MASK_WARNED.discard(...) before the no-boundaries test (Line 844) prevents test-order flakes from the module-level dedupe set.
  • test_process_labels_no_warning_when_image_token_id_none (Line 289) is a nice defensive test against the classic tensor == None PyTorch UserWarning.
  • The _FakeGemma4Tokenizer.VOCAB class attribute is correctly protected via {token: list(ids) for ...} copy in __init__ (Line 466) — Ruff RUF012 is a false positive here.
  • Ruff S105 warnings on self.boi_token, self.video_token, etc. are all false positives (multimodal token strings, not credentials).

Two minor suggestions:

  1. The _mistral_common_stub fixture (Line 627-630) always returns None and is never introspected by the tests that receive it. Either remove it or document why it's needed (presumably for ordering to ensure the lazy-import error path has been primed). Deferrable.
  2. Consider adding a test for the train_on_eos="all" + include_end=False non-trainable role case I flagged in processing_strategies.py — to lock in whatever behavior you settle on.
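
To pin down the case in suggestion 2, the gating under discussion can be written as a small decision function. This is a sketch of one possible resolution (the end marker never enters the loss when include_end is False), not the merged behavior. `end_marker_in_loss` is a hypothetical name, and the "last" mode is left out because its exact rule is not spelled out in this thread.

```python
def end_marker_in_loss(role_trainable, include_end, train_on_eos):
    """Decide whether a turn's end marker contributes to loss.

    Assumed semantics: "none" never trains on end markers, "all" trains on
    every turn's end marker, "turn" only on trainable turns. In all modes
    include_end=False keeps the marker masked.
    """
    if train_on_eos not in ("turn", "all", "none"):
        raise ValueError(f"mode not covered by this sketch: {train_on_eos!r}")
    if not include_end or train_on_eos == "none":
        return False
    if train_on_eos == "all":
        return True
    return role_trainable  # "turn"
```

A regression test for the flagged case would then assert that a non-trainable role with `train_on_eos="all"` and `include_end=False` leaves the end marker masked.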
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_processing_strategies.py` around lines 1 - 1081, Remove or
document the unused _mistral_common_stub fixture: either delete the fixture
definition named _mistral_common_stub (the pytest fixture that simply returns
None) from the tests file, or add a one-line docstring above it explaining it's
intentionally present to prime the lazy-import error path for
get_processing_strategy; update any import comment to reference
_mistral_common_stub so future readers know it's purposeful. Also add a small
unit test that exercises the processing_strategies behavior for
train_on_eos="all" combined with a role boundary that sets include_end=False on
a non-trainable role (create a test function that constructs a
ProcessingStrategy with role_boundaries_override where role is non-trainable,
train_on_eos="all" and include_end=False, then assert the output labels follow
the decided behavior), referencing ProcessingStrategy, _apply_role_boundaries,
and RoleBoundary/include_end to pin the intended behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 415-419: The non-trainable-role branch unconditionally marks the
end marker into the loss when train_on_eos == "all", ignoring
best_match.include_end; change the condition in the else branch (the block using
found_end, train_on_eos, end_after and mask[i]) to also require
best_match.include_end so the end token is only masked when include_end is True
(same gating as the trainable branch), ensuring the mask write at
content_end:end_after is skipped when include_end is False.
- Around line 1046-1051: The class Glm4vProcessingStrategy is misnamed relative
to the processor it actually handles (Glm46VProcessor) causing GLM-4V processors
to fall back to ProcessingStrategy; either rename Glm4vProcessingStrategy to
Glm46VProcessingStrategy to match Glm46VProcessor, or (if you intended to
support GLM-4V) update the dispatcher to also register Glm4vProcessor and verify
the token markers for that checkpoint; update any other occurrences noted (the
similar block at lines referenced in the comment) so class names and dispatcher
registrations (Glm4vProcessingStrategy / Glm46VProcessingStrategy,
Glm46VProcessor, Glm4vProcessor, ProcessingStrategy) are consistent.
- Around line 647-653: The special_tokens_map lookup for "boi_token" is
ineffective for real Gemma3/Gemma checkpoints because these tokenizers expose
boi_token as a direct attribute, so replace the special_tokens_map branch with a
direct attribute check on processor.tokenizer (use getattr(processor.tokenizer,
"boi_token", None)); if found, set self.image_token to that value and
self.image_token_id by converting it with
processor.tokenizer.convert_tokens_to_ids (and keep the existing fallback that
checks boi_token_id/boi_token_id-like attributes); update references to
special_tokens_map/boi_token to use the tokenizer attribute to ensure production
models set image_token/image_token_id.
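
The attribute-based lookup described in the last item can be sketched as follows. This is a minimal illustration, not the code in `processing_strategies.py`: `resolve_boi_token` and the `SimpleNamespace` stand-in tokenizer are hypothetical, and the id `7` is made up.

```python
from types import SimpleNamespace


def resolve_boi_token(tokenizer):
    """Return (token, token_id) for the begin-of-image marker, or (None, None)."""
    # Read the attribute directly; special_tokens_map may not list it.
    token = getattr(tokenizer, "boi_token", None)
    if token is None:
        return None, None
    token_id = None
    convert = getattr(tokenizer, "convert_tokens_to_ids", None)
    if convert is not None:
        token_id = convert(token)
    # Fall back to an explicit id attribute when conversion yields nothing.
    if token_id is None:
        token_id = getattr(tokenizer, "boi_token_id", None)
    return token, token_id


# Hypothetical tokenizer: boi_token is an attribute, absent from special_tokens_map.
fake_tokenizer = SimpleNamespace(
    boi_token="<start_of_image>",
    special_tokens_map={"bos_token": "<bos>", "eos_token": "<eos>"},
    convert_tokens_to_ids=lambda tok: {"<start_of_image>": 7}.get(tok),
)

print(fake_tokenizer.special_tokens_map.get("boi_token"))  # the map lookup misses
print(resolve_boi_token(fake_tokenizer))                   # the attribute lookup hits
```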

---

Nitpick comments:
In `@docs/multimodal_assistant_mask.md`:
- Around line 158-160: Add a language identifier to the fenced code block
containing "MM collator: train_on_inputs=False roles_to_train=['assistant']
train_on_eos=turn role_boundaries_override=none" so it satisfies the linter
(MD040); edit the opening fence from ``` to ```text (or ```log) to mark the
block as plain text/log and keep the block content unchanged.

In `@src/axolotl/core/builders/causal.py`:
- Around line 524-548: The helper _ds_get is being defined inside build_collator
causing it to be re-created on each call; hoist _ds_get to module scope (keep
the same signature and behavior: accept cfg_obj and key, handle None, dict-like
.get with safe exception handling, then getattr) and replace the inner
definition in build_collator with calls to the new module-level function (used
by roles_to_train and train_on_eos); this lets you unit-test _ds_get directly
and avoids repeated re-definition while preserving current behavior.
- Around line 550-557: The INFO log "MM collator: ..." in build_collator is
emitted during both eval and train collator creation because
HFCausalTrainerBuilder.build calls build_collator twice (once with is_eval=True
and once for training), causing duplicate log lines; modify build_collator to
only log that INFO when is_eval is False (or add a parameter to suppress
logging) so the message (and the subsequent ProcessingStrategy logs) are only
emitted once for the training collator, referencing the build_collator method
and the LOG.info call that prints "MM collator: train_on_inputs=%s
roles_to_train=%s ...".

In `@src/axolotl/processing_strategies.py`:
- Around line 655-659: The process_labels method hardcodes the Gemma3 image soft
token id (262144); change it to resolve the id from the tokenizer first (e.g.,
call tokenizer.convert_tokens_to_ids("<image_soft_token>") or the equivalent
lookup used in Gemma4ProcessingStrategy.process_labels) and only use the numeric
fallback when the token is not present in vocab; update process_labels to
compute image_soft_token_id dynamically, then replace labels[labels ==
image_soft_token_id] = -100, falling back to 262144 if the conversion returns an
unknown id.
- Around line 338-437: The scan is correct but _match_prefix repeatedly calls
label[start_pos:start_pos+len(tok_seq)].tolist(), causing many Python list
allocations; to optimize, convert each row once to a Python list (e.g., create
label_list = label.tolist() inside the for i in range(...) loop) and change
_match_prefix to compare slices from that label_list against
boundary.start_tokens (or pre-store boundary.start_tokens as tuples) so
comparisons are pure-Python without tensor-to-list conversions; alternatively,
you can compare tensors with torch.equal against pre-built tensors for each
RoleBoundary.start_tokens, but the simplest fix is to use label_list and update
_match_prefix to use it (references: function _apply_role_boundaries, local name
label, helper _match_prefix, and RoleBoundary.start_tokens).
- Around line 296-325: The warning for boundary-less strategies in
_mask_non_assistant currently deduplicates per-process via _ROLE_MASK_WARNED but
still emits once per rank; modify the branch that checks "if key not in
_ROLE_MASK_WARNED" to also check the process rank and only add/log from rank 0
(e.g., gate on int(os.environ.get("RANK", "0")) == 0) before calling LOG.warning
and adding to _ROLE_MASK_WARNED, so other ranks skip the warning while
preserving the single-shot behavior; touch the _mask_non_assistant function and
LOG.warning usage and ensure role_boundaries behavior remains unchanged.

In `@tests/test_processing_strategies.py`:
- Around line 1-1081: Remove or document the unused _mistral_common_stub
fixture: either delete the fixture definition named _mistral_common_stub (the
pytest fixture that simply returns None) from the tests file, or add a one-line
docstring above it explaining it's intentionally present to prime the
lazy-import error path for get_processing_strategy; update any import comment to
reference _mistral_common_stub so future readers know it's purposeful. Also add
a small unit test that exercises the processing_strategies behavior for
train_on_eos="all" combined with a role boundary that sets include_end=False on
a non-trainable role (create a test function that constructs a
ProcessingStrategy with role_boundaries_override where role is non-trainable,
train_on_eos="all" and include_end=False, then assert the output labels follow
the decided behavior), referencing ProcessingStrategy, _apply_role_boundaries,
and RoleBoundary/include_end to pin the intended behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bb1d73db-8316-4cf7-8f0a-ef5a71d9e6c2

📥 Commits

Reviewing files that changed from the base of the PR and between 798c8fb and fb53a08.

📒 Files selected for processing (6)
  • docs/multimodal_assistant_mask.md
  • src/axolotl/core/builders/causal.py
  • src/axolotl/processing_strategies.py
  • src/axolotl/utils/schemas/datasets.py
  • src/axolotl/utils/schemas/multimodal.py
  • tests/test_processing_strategies.py

@codecov

codecov Bot commented Apr 24, 2026

Codecov Report

❌ Patch coverage is 79.83651% with 74 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/processing_strategies.py 83.33% 57 Missing ⚠️
src/axolotl/core/builders/causal.py 0.00% 17 Missing ⚠️


Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.
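
A simplified, list-based sketch of the gating in finding #1 follows. It is an illustration only: `Boundary`, `keep_mask`, and the token ids are hypothetical, the `train_on_eos` semantics are reduced to "turn" and "all" ("last" is not modeled), and the real scanner lives in `_apply_role_boundaries`.

```python
from dataclasses import dataclass, field


@dataclass
class Boundary:
    """Hypothetical stand-in for a per-role marker declaration."""
    role: str
    start_tokens: list
    end_tokens: list = field(default_factory=list)
    include_end: bool = True


def _find_end(ids, start, end_tokens):
    """Return (content_end, found); an empty end marker runs to end of sequence."""
    m = len(end_tokens)
    if m:
        for k in range(start, len(ids) - m + 1):
            if ids[k : k + m] == end_tokens:
                return k, True
    return len(ids), False


def keep_mask(ids, boundaries, roles_to_train, train_on_eos="turn"):
    """Return a list of booleans, True where a token contributes to the loss."""
    keep = [False] * len(ids)
    i = 0
    while i < len(ids):
        # Longest matching start marker wins at this position.
        match = None
        for b in boundaries:
            n = len(b.start_tokens)
            if n and ids[i : i + n] == b.start_tokens:
                if match is None or n > len(match.start_tokens):
                    match = b
        if match is None:
            i += 1
            continue
        content_start = i + len(match.start_tokens)
        content_end, found_end = _find_end(ids, content_start, match.end_tokens)
        end_after = content_end + len(match.end_tokens) if found_end else content_end
        trainable = match.role in roles_to_train
        if trainable:
            keep[content_start:content_end] = [True] * (content_end - content_start)
        eos_wanted = train_on_eos == "all" or (trainable and train_on_eos == "turn")
        # The gate: the end marker is kept only when include_end is True,
        # for trainable and non-trainable roles alike.
        if found_end and match.include_end and eos_wanted:
            keep[content_end:end_after] = [True] * (end_after - content_end)
        # With include_end=False the end marker is left for the next match.
        i = end_after if match.include_end else content_end
    return keep


# Made-up ids: 1 = [INST], 2 = [/INST], 3 = EOS.
boundaries = [
    Boundary("user", [1], [2], include_end=False),
    Boundary("assistant", [2], [3], include_end=True),
]
ids = [1, 10, 11, 2, 20, 21, 3]
print(keep_mask(ids, boundaries, {"assistant"}, train_on_eos="all"))
```

Dropping `match.include_end` from the gate would mark index 3 (the shared `[/INST]` stand-in) as trainable under `train_on_eos="all"`, which is the leak the finding describes.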

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Apr 24, 2026

Contributor

@thad0ctor, thanks for the context! I'll kick off a full review of this PR now.

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/processing_strategies.py (1)

1110-1144: ⚠️ Potential issue | 🟡 Minor

Minor: Glm4vProcessingStrategy should guard media ids against unk_token_id.

convert_tokens_to_ids returns unk_token_id for strings not in vocab. If any of <|image|>, <|begin_of_image|>, …, <|end_of_video|> is missing from a given checkpoint's vocab (e.g. a fine-tune that stripped a marker, or a future upstream rename), the if tok_id is not None check at line 1141 still passes and labels[labels == unk_id] = -100 would mask every genuine unk token in the batch. Gemma4ProcessingStrategy.process_labels (lines 746-753) already uses the token_id is None or token_id == unk_id → skip pattern — mirroring it here keeps the class resilient to checkpoint drift.

🛡️ Proposed fix
     def process_labels(self, input_ids):
         labels = input_ids.clone()
         labels = self._mask_non_assistant(labels)
 
         pad_id = getattr(self.tokenizer, "pad_token_id", None)
         if pad_id is not None:
             labels[labels == pad_id] = -100
 
+        unk_id = getattr(self.tokenizer, "unk_token_id", None)
         for tok_id in (
             self.image_token_id,
             self.begin_image_token_id,
             self.end_image_token_id,
             self.video_token_id,
             self.begin_video_token_id,
             self.end_video_token_id,
         ):
-            if tok_id is not None:
+            if tok_id is not None and tok_id != unk_id:
                 labels[labels == tok_id] = -100
 
         return labels
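
The failure mode behind this guard can be reproduced without a real checkpoint. The sketch below is illustrative: `mask_media_tokens`, the toy vocabulary, and the ids are assumptions, and labels are plain lists rather than tensors.

```python
IGNORE_INDEX = -100


def mask_media_tokens(labels, media_token_ids, unk_id):
    """Mask media-marker ids, skipping any id that is None or resolved to unk."""
    to_mask = {t for t in media_token_ids if t is not None and t != unk_id}
    return [IGNORE_INDEX if tok in to_mask else tok for tok in labels]


# Toy vocabulary: "<|begin_of_image|>" is missing, so it resolves to unk (id 0).
vocab = {"<|image|>": 5}
UNK_ID = 0


def convert(token):
    return vocab.get(token, UNK_ID)


media_ids = [convert("<|image|>"), convert("<|begin_of_image|>")]  # [5, 0]
labels = [0, 5, 9, 0]  # contains two genuine unk tokens

print(mask_media_tokens(labels, media_ids, UNK_ID))  # unk tokens survive
print(mask_media_tokens(labels, media_ids, None))    # unguarded: unk tokens masked too
```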
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 1110 - 1144, The
process_labels method in Glm4vProcessingStrategy currently masks token ids even
when convert_tokens_to_ids returned the tokenizer's unk token id; update the
loop that iterates over self.image_token_id, self.begin_image_token_id,
self.end_image_token_id, self.video_token_id, self.begin_video_token_id,
self.end_video_token_id to first fetch the tokenizer's unk id (e.g., unk_id =
getattr(self.tokenizer, "unk_token_id", None)) and skip masking for any tok_id
that is None or equal to unk_id (i.e., only apply labels[labels == tok_id] =
-100 when tok_id is not None and tok_id != unk_id) so genuine unk tokens are not
incorrectly masked; reference the Glm4vProcessingStrategy.process_labels method
and the image_* / video_* token id attributes when making the change.
🧹 Nitpick comments (1)
src/axolotl/processing_strategies.py (1)

357-373: Optional: avoid per-position .tolist() on tensor slices in the inner scanner.

_match_prefix is called O(n·b) times in the outer loop plus O(n) times inside _find_end per matched span, and each call materializes a Tensor slice via .tolist(). On typical MM sequences (n≈4k, b≈5, 8/batch) this becomes a measurable per-step overhead you pay on every training batch. Converting labels[i] to a Python list once per row eliminates the Python↔C boundary crossings in the hot loop without changing any semantics.

♻️ Proposed fix
-    def _match_prefix(label: Tensor, start_pos: int, tok_seq: list[int]) -> bool:
-        if not tok_seq or start_pos + len(tok_seq) > len(label):
+    def _match_prefix(label: list[int], start_pos: int, tok_seq: list[int]) -> bool:
+        if not tok_seq or start_pos + len(tok_seq) > len(label):
             return False
-        return label[start_pos : start_pos + len(tok_seq)].tolist() == tok_seq
+        return label[start_pos : start_pos + len(tok_seq)] == tok_seq
 
     def _find_end(
-        label: Tensor, start_pos: int, end_tok: list[int]
+        label: list[int], start_pos: int, end_tok: list[int]
     ) -> tuple[int, bool]:
         # Empty end_tok means run to end-of-sequence.
         if not end_tok:
             return len(label), False
@@
     for i in range(labels.shape[0]):
-        label = labels[i]
+        label = labels[i].tolist()
         j = 0
         n = len(label)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 357 - 373, The inner
scanner currently calls _match_prefix which uses label[start_pos : start_pos +
len(tok_seq)].tolist() on every invocation, causing costly Tensor->Python
transitions; change the scanning logic so you convert each label Tensor to a
Python list once before scanning (e.g., at the start of the outer per-row loop
or by accepting a pre-converted list) and then have _match_prefix and _find_end
operate on that Python list (compare slices or use tuple/list comparisons)
instead of calling .tolist() inside the hot loop; update function signatures or
call sites for _match_prefix(label: Tensor, ...) and _find_end(label: Tensor,
...) accordingly to use the pre-converted list representation for label.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/multimodal_assistant_mask.md`:
- Around line 1-31: Change the second heading from h3 to h2 (replace "###
Correct placement" with "## Correct placement") and add a language tag to the
fenced block showing build_collator output (e.g., replace the triple-backtick
fence with ```text) so the docs pass markdownlint rules MD001 and MD040; locate
these edits near the "Correct placement" heading and the code block that begins
with "MM collator: train_on_inputs=False..." in multimodal_assistant_mask.md.

---

Outside diff comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 1110-1144: The process_labels method in Glm4vProcessingStrategy
currently masks token ids even when convert_tokens_to_ids returned the
tokenizer's unk token id; update the loop that iterates over
self.image_token_id, self.begin_image_token_id, self.end_image_token_id,
self.video_token_id, self.begin_video_token_id, self.end_video_token_id to first
fetch the tokenizer's unk id (e.g., unk_id = getattr(self.tokenizer,
"unk_token_id", None)) and skip masking for any tok_id that is None or equal to
unk_id (i.e., only apply labels[labels == tok_id] = -100 when tok_id is not None
and tok_id != unk_id) so genuine unk tokens are not incorrectly masked;
reference the Glm4vProcessingStrategy.process_labels method and the image_* /
video_* token id attributes when making the change.

---

Nitpick comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 357-373: The inner scanner currently calls _match_prefix which
uses label[start_pos : start_pos + len(tok_seq)].tolist() on every invocation,
causing costly Tensor->Python transitions; change the scanning logic so you
convert each label Tensor to a Python list once before scanning (e.g., at the
start of the outer per-row loop or by accepting a pre-converted list) and then
have _match_prefix and _find_end operate on that Python list (compare slices or
use tuple/list comparisons) instead of calling .tolist() inside the hot loop;
update function signatures or call sites for _match_prefix(label: Tensor, ...)
and _find_end(label: Tensor, ...) accordingly to use the pre-converted list
representation for label.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b772d9b1-e0ba-4149-a039-7c08b1424956

📥 Commits

Reviewing files that changed from the base of the PR and between fb53a08 and 954794c.

📒 Files selected for processing (4)
  • docs/multimodal_assistant_mask.md
  • src/axolotl/core/builders/causal.py
  • src/axolotl/processing_strategies.py
  • tests/test_processing_strategies.py

…trings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@winglian winglian merged commit 5352d41 into axolotl-ai-cloud:main May 5, 2026
15 of 16 checks passed
@thad0ctor

Contributor Author

@winglian thank you!

@NanoCode012

Collaborator

Thanks for the PR! I think the next step from here would be to review speed / performance cost for this change and whether the implementation can be improved upon.

@thad0ctor

Contributor Author

> Thanks for the PR! I think the next step from here would be to review speed / performance cost for this change and whether the implementation can be improved upon.

will do, I'll look into benchmarking on my end once my GPUs free up

@thad0ctor

thad0ctor commented May 8, 2026

Contributor Author

> Thanks for the PR! I think the next step from here would be to review speed / performance cost for this change and whether the implementation can be improved upon.

I benchmarked the PR: the scanner costs ~0.21 µs/token (linear, strategy-independent), which works out to ~7 ms at micro-batch 8 × seq 4096, well under 0.1% of step time and fully overlapped with GPU compute via dataloader workers.
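
For reference, the arithmetic behind the ~7 ms figure, using only the numbers quoted above. The implied step time is an inference from the "under 0.1%" claim, not a measured value.

```python
US_PER_TOKEN = 0.21   # scanner cost per token, from the benchmark above
MICRO_BATCH = 8
SEQ_LEN = 4096

tokens_per_batch = MICRO_BATCH * SEQ_LEN
scan_ms = tokens_per_batch * US_PER_TOKEN / 1000.0

# For the scan to stay under 0.1% of a step, the step must take longer than this.
implied_min_step_s = scan_ms / 0.001 / 1000.0

print(f"tokens per micro-batch: {tokens_per_batch}")
print(f"scanner cost per batch: {scan_ms:.2f} ms")
print(f"0.1% bound holds for steps longer than {implied_min_step_s:.2f} s")
```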

Digging deeper, two directions seem viable for reducing the latency further without moving the work into preprocessing (although that could be a viable future decision):

(a) Vectorized scanner: I tested this on a local branch (_apply_role_boundaries_vectorized alongside the original) and it ran ~1.5–2× faster (Gemma4 1.64×, Llama3.2V 1.50×, Pixtral 1.93×). This approach tested well and seems valid.

(b) Native return_assistant_tokens_mask=True from apply_chat_template: HF's tokenizers/processors expose this flag for templates with {% generation %} blocks. Where supported, the tokenizer returns the assistant mask directly, eliminating the scanner, the per-strategy boundary declarations, and the cfg.role_boundaries override. This is closer in spirit to the text path (prompt_strategies/chat_template.py), which defers boundary detection to template machinery (find_turn does a template-diff at preprocess time) rather than running its own per-batch token scan. The flow would probably be: detect support at strategy init, use native when available, current scanner as fallback. A bit more effort, but architecturally consistent.
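
To make option (b) concrete: once a tokenizer returns a 0/1 assistant mask, turning it into labels is a one-liner. The sketch below assumes the mask has already been obtained (in Hugging Face Transformers, `apply_chat_template(..., return_dict=True, return_assistant_tokens_mask=True)` returns it under `assistant_masks` for templates that support it); `labels_from_assistant_mask` is a hypothetical helper and the ids are made up.

```python
IGNORE_INDEX = -100


def labels_from_assistant_mask(input_ids, assistant_mask):
    """Keep token ids where the mask is 1, replace the rest with IGNORE_INDEX."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [
        tok if keep else IGNORE_INDEX
        for tok, keep in zip(input_ids, assistant_mask)
    ]


# Made-up example: positions 4..6 belong to the assistant turn.
input_ids = [1, 10, 11, 2, 20, 21, 3]
assistant_mask = [0, 0, 0, 0, 1, 1, 1]
print(labels_from_assistant_mask(input_ids, assistant_mask))
```

Pad and media-token masking would still need to be applied on top of this, as the strategies' `process_labels` methods do today.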

My take is that the scanner affords maximum flexibility at a noise-level latency tax. Improving it with the vectorized approach makes it a bit more palatable; frankly, multimodal training is so much lengthier than text that a <1% tax is an acceptable trade-off pending a more complex refactor. Making (b) work with the same flexibility (given the diversity of chat templates for models without {% generation %} blocks) would be a strong contender for a third, future commit: native mask as a fast path for the supported common case (roles_to_train=["assistant"], train_on_eos="turn", template has generation blocks), scanner as the fallback for multi-role training, non-default EOS modes, cfg.role_boundaries overrides, and templates without generation blocks.
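
The fast-path/fallback split described above could be decided with a predicate along these lines. Everything here is a hypothetical sketch: `can_use_native_mask` is not an existing function, and detecting generation blocks by searching the template string is a heuristic assumed for illustration.

```python
import re

# Matches "{% generation %}" including Jinja whitespace-control variants.
_GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")


def can_use_native_mask(chat_template, roles_to_train, train_on_eos, role_boundaries=None):
    """Return True only for the common case the native assistant mask covers."""
    has_generation_blocks = bool(_GENERATION_TAG.search(chat_template or ""))
    return (
        has_generation_blocks
        and list(roles_to_train) == ["assistant"]
        and train_on_eos == "turn"
        and not role_boundaries  # any override forces the scanner
    )


template = "{% generation %}{{ message['content'] }}{% endgeneration %}"
print(can_use_native_mask(template, ["assistant"], "turn"))
print(can_use_native_mask(template, ["assistant", "user"], "turn"))
```

Anything that returns False here would keep using the scanner, so multi-role training, non-default EOS modes, and overrides behave as they do now.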

That said, in hindsight perhaps (b) would have been more appropriate; I know your strategy has been to enforce a standard chat template (and rightfully so if you look at the complexity of chat template handling in llama.cpp and the like) and this deviates from it, so part of this may have been my own preferences leaking in.

@NanoCode012

Collaborator

You can try timing option (b), but as mentioned earlier, it was very slow when I first tested it.

thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 20, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
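
A minimal sketch of the "single batched unfold pass" idea described in the scanner bullet above, not the PR's actual `_compute_role_keep_mask_vectorized`. The helper name `find_marker_starts` and its signature are hypothetical. It only shows how a sliding-window view plus an all-over-last-dim comparison can replace a Python prefix-match loop; it is not claimed to reproduce the real scanner's tie-breaking or rewind rules.

```python
import torch


def find_marker_starts(input_ids: torch.Tensor, marker: list[int]) -> torch.Tensor:
    """Hypothetical helper: bool tensor (batch, seq_len), True where `marker` begins."""
    batch, seq_len = input_ids.shape
    m = len(marker)
    hits = torch.zeros(batch, seq_len, dtype=torch.bool, device=input_ids.device)
    if m == 0 or m > seq_len:
        # Assumption: an empty or over-long marker never matches.
        return hits
    # windows[b, i, :] is input_ids[b, i : i + m]; shape (batch, seq_len - m + 1, m).
    windows = input_ids.unfold(1, m, 1)
    pattern = torch.tensor(marker, dtype=input_ids.dtype, device=input_ids.device)
    # A window matches only if every token in it equals the marker token.
    hits[:, : seq_len - m + 1] = (windows == pattern).all(dim=-1)
    return hits


ids = torch.tensor([[5, 1, 2, 9, 1, 2, 7],
                    [1, 2, 1, 2, 0, 0, 0]])
starts = find_marker_starts(ids, [1, 2])
print(starts.nonzero().tolist())  # (row, position) pairs where the marker begins
```

The match positions for all rows come out of one tensor expression, so any remaining per-row Python work only has to walk the matches rather than every token.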
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request May 21, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

ved1beta pushed a commit to thad0ctor/axolotl that referenced this pull request May 25, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

winglian pushed a commit to thad0ctor/axolotl that referenced this pull request May 29, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

winglian pushed a commit to thad0ctor/axolotl that referenced this pull request Jun 1, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

winglian pushed a commit to thad0ctor/axolotl that referenced this pull request Jun 9, 2026
Follow-up to axolotl-ai-cloud#3625 reducing the cost of multimodal assistant-only masking:

winglian added a commit that referenced this pull request Jun 9, 2026
…3672)

* perf(mm-mask): vectorized role-boundary scanner + fused process_labels

Follow-up to #3625 reducing the cost of multimodal assistant-only masking:

- Vectorized role-boundary scanner (_compute_role_keep_mask_vectorized):
  precomputes start/end match positions in a single batched unfold pass,
  collapsing the inner Python prefix-match into a vectorized
  all-over-last-dim check. Byte-identical to the reference scanner, gated
  on torch.get_num_threads() == 1 (the DataLoader-worker default) since
  the per-op dispatch overhead regresses it under multi-threaded torch.
  1.3-1.5x speedup on real Gemma 3 / Gemma 4 / Qwen 2 tokenizers, tie
  on Llama 3 (more boundaries, more tensor ops).

- Fused process_labels: composes role/pad/media masks into one boolean
  keep tensor and does a single labels[~keep] = -100 write instead of
  the previous 3-4 sequential masked writes. Structural cleanup; the
  measured speedup is ~0ms since the prior masked-write cost was already
  dwarfed by the scanner.

- One-shot warning when MultiModalChatDataCollator is built with
  dataloader_num_workers in (None, 0): MM training does heavier
  collator-side work than text and synchronous collation blocks the GPU.

Tests:
- tests/test_vectorized_scanner.py: 2,000-config differential fuzz +
  targeted edge cases (Pixtral [/INST] rewind, longest-prefix tie,
  train_on_eos modes, empty end_tokens, include_end leak gate).
- tests/test_process_labels_fusion.py: 50 random batches per strategy
  across all 9 process_labels-overriding strategies (450 fuzz cases)
  asserting torch.equal between pre-refactor and fused implementation.
- tests/test_mm_collator_warnings.py: num_workers=0 warning regression.
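
The fused `process_labels` bullet above can be illustrated with a short sketch. This is an assumption-laden illustration, not the code from the PR: the function name `fuse_label_masks`, its parameters, and the choice to derive the pad and media masks from the label values are all hypothetical. It only shows the pattern of combining several boolean masks into one `keep` tensor and writing `-100` once.

```python
import torch

IGNORE_INDEX = -100  # label value that the loss function skips


def fuse_label_masks(
    labels: torch.Tensor,
    role_keep: torch.Tensor,
    pad_token_id: int,
    media_token_ids: list[int],
) -> torch.Tensor:
    """Hypothetical helper: apply role, pad and media masking in one write."""
    media_ids = torch.tensor(media_token_ids, dtype=labels.dtype, device=labels.device)
    is_media = torch.isin(labels, media_ids)
    # One boolean tensor: trainable role AND not padding AND not a media token.
    keep = role_keep & (labels != pad_token_id) & ~is_media
    out = labels.clone()
    out[~keep] = IGNORE_INDEX  # single masked write instead of one per mask
    return out


labels = torch.tensor([[10, 11, 99, 12, 0, 0]])
role_keep = torch.tensor([[False, True, True, True, True, True]])
fused = fuse_label_masks(labels, role_keep, pad_token_id=0, media_token_ids=[99])
print(fused.tolist())
```

As the commit message notes, this is mainly a structural cleanup: the result is the same as applying the masks one after another, which is what the parity tests check with `torch.equal`.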

* fix: address coderabbit feedback and lint

* test(mm-mask): add Voxtral/Mistral3/InternVL/Glm4v strategy coverage

- test_process_labels_fusion.py: factories + parity-fuzz entries for the
  four strategies (200 new fused-vs-legacy assertions); no-boundaries
  branch in _build_alternating_turns now injects pad/extras so the
  pad/media masking path is exercised.
- test_processing_strategies.py: behavioral tests (process_labels,
  train_on_inputs, role_boundaries_override, batch handling, InternVL
  error paths) plus dispatch routing for VoxtralProcessor,
  Mistral3Processor, and InternVLProcessor.

572 -> 788 passing across the PR's four test files.

* perf(mm-mask): bisect end-finder + slice-fill in vectorized scanner

The vectorized scanner's per-row loop still did O(span) Python work: a
linear walk over end-match flags to find each turn's end, and
element-wise bytearray writes to fill spans. Replace both:

- Precompute each boundary's end-match start positions (sorted via
  nonzero) and bisect for the next end >= start_of_content.
- Fill keep-spans with bytearray slice assignment and materialize each
  row tensor once via torch.frombuffer.

Byte-identical to the reference scanner; ~2x faster than the prior
vectorized path (≈3x vs reference) on 8x3-6k batches, widening with turn
length. Adds long-span / multi-end-marker parity fuzz that the original
filler<=200 generator under-exercised.
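
A rough sketch of the two replacements described above (bisect for the next end marker, then slice assignment into a bytearray), using only the standard library. The function name `fill_keep_spans` and its arguments are hypothetical, and the handling of an unterminated turn (keep to the end of the row) is an assumption made for this example, not a statement about the PR's behavior.

```python
from bisect import bisect_left


def fill_keep_spans(
    seq_len: int,
    content_starts: list[int],
    end_positions: list[int],
    end_len: int,
    include_end: bool = False,
) -> bytearray:
    """Hypothetical helper: mark [start, end) spans in a per-row keep buffer.

    `end_positions` must be sorted ascending (e.g. taken from a nonzero() call).
    """
    keep = bytearray(seq_len)  # all zeros: nothing kept yet
    for start in content_starts:
        # First end marker that begins at or after the start of the content.
        idx = bisect_left(end_positions, start)
        found = idx < len(end_positions)
        # Assumption: with no end marker, the span runs to the end of the row.
        stop = end_positions[idx] if found else seq_len
        if include_end and found:
            stop = min(stop + end_len, seq_len)
        if stop > start:
            # One slice assignment instead of a per-element Python loop.
            keep[start:stop] = b"\x01" * (stop - start)
    return keep


row = fill_keep_spans(12, content_starts=[2, 8], end_positions=[5, 10], end_len=1)
print(list(row))
```

The finished buffer could then be turned into a tensor in one step, for example with `torch.frombuffer(row, dtype=torch.uint8)`, which is left out here so the sketch has no dependencies.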

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>