Skip to content

feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries#7

Merged
thad0ctor merged 7 commits into
mainfrom
feat/multimodal-assistant-mask-all
Apr 24, 2026
Merged

feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries#7
thad0ctor merged 7 commits into
mainfrom
feat/multimodal-assistant-mask-all

Conversation

@thad0ctor

@thad0ctor thad0ctor commented Apr 24, 2026

Copy link
Copy Markdown
Owner

Squashed re-submission of the work originally proposed in #4 (closed). Single commit on top of main so CodeRabbit can re-review against a clean history.

Summary

  • Fixes silent ignoring of cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config.
  • Adds cfg.role_boundaries YAML override so users can declare per-role markers without subclassing.

What changed

  • ProcessingStrategy gains a declarative boundary scanner. Each strategy declares per-role start/end markers via _build_role_boundaries; the shared scanner honors train_on_inputs / roles_to_train / train_on_eos (including "last").
  • New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken.
  • Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner, now shared).
  • Strategies whose boundary tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via cfg.role_boundaries.
  • Pixtral / Mistral V7 Tekken correctly handle the shared [/INST] token between user-end and assistant-start via include_end=False + scanner rewind.

Design

See docs/multimodal_assistant_mask.md for the full audit table, root-cause analysis, design rationale (why boundary-scanner over return_assistant_tokens_mask or preserving tokenize_prompt labels), and cfg.role_boundaries usage.

Test plan

  • 64 offline unit tests in tests/test_processing_strategies.py pass (no HF Hub access needed)
  • End-to-end verified against real google/gemma-4-E2B-it tokenizer: 13/40 tokens kept on 2-turn chat, correct spans
  • End-to-end verified against real Llama-3.x tokenizer with bundled llama3_2_vision.jinja: 11/64 tokens kept, correct spans
  • End-to-end verified against the real Rashi-OCR Gemma-4-31B tokenizer + dataset, including cfg.role_boundaries override path
  • Text-only training path unaffected (MM path only taken when cfg.processor_type and self.processor)
  • Migration visibility: INFO-log at strategy init + collator assembly surfaces resolved masking config

History

Previous incremental history (25 commits including review-driven fixes for Pixtral/Mistral V7 Tekken, train_on_eos="last" support, pydantic dataset handling, and validation hardening) was squashed into a single commit for review clarity.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added configurable role-boundary masking for multimodal models with YAML override support.
    • Introduced granular training control parameters (train_on_inputs, roles_to_train, train_on_eos) for fine-grained loss masking.
    • Enhanced role masking support across multiple multimodal architectures with declarative configuration.
  • Documentation

    • Added comprehensive guide on multimodal loss masking improvements, configuration options, and behavior audit table.

…daries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Apr 24, 2026

Copy link
Copy Markdown

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ea74fe77-6382-4d38-9113-ce980fb74fdc

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Introduces a declarative role-boundary masking system for multimodal chat models. Replaces per-model assistant-only masking with a shared _apply_role_boundaries() scanner driven by RoleBoundary declarations, configuration parameters (train_on_inputs, roles_to_train, train_on_eos), and optional YAML-based overrides (role_boundaries_override). Updates builder, processing strategies, schemas, and adds comprehensive test coverage.

Changes

Cohort / File(s) Summary
Documentation & Schema
docs/multimodal_assistant_mask.md, src/axolotl/utils/schemas/multimodal.py
Adds documentation of the role-boundary masking design, audit tables, and verification results. Introduces new RoleBoundarySpec model and extends MultiModalConfig with optional role_boundaries field to enable runtime marker override.
Processing Strategies Core
src/axolotl/processing_strategies.py
Adds role-boundary masking parameters to ProcessingStrategy.__init__, implements shared _apply_role_boundaries() token scanner with longest-prefix and turn-based trainable span logic, declares built-in role boundaries for multiple model architectures (Qwen, Gemma, Llama, Pixtral, Mistral), and threads new parameters through get_processing_strategy.
Builder Integration
src/axolotl/core/builders/causal.py
Extracts per-dataset masking parameters (roles_to_train, train_on_eos) from first configured dataset, reads optional top-level role_boundaries_override, and passes these to get_processing_strategy and logging calls.
Test Coverage
tests/test_processing_strategies.py
Comprehensive offline test suite validating role-boundary masking logic, train_on_eos modes, per-model strategy configuration, longest-prefix selection, batch behavior, and role_boundaries_override handling across multiple architectures.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 16.52% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main changes: introducing systemic multimodal assistant-only loss masking and a new cfg.role_boundaries configuration override mechanism.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/multimodal-assistant-mask-all

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
src/axolotl/core/builders/causal.py (1)

524-541: Consider documenting the first-dataset-only limitation.

The code reads roles_to_train and train_on_eos from only the first dataset entry (ds_entries[0]). While this mirrors ChatTemplateStrategy, it may surprise users with heterogeneous multi-dataset configs. A brief inline comment or LOG.debug noting this behavior would improve clarity.

📝 Suggested documentation
                 # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
                 ds_entries = self.cfg.datasets or []
                 ds_cfg = ds_entries[0] if ds_entries else None
+                # NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/core/builders/causal.py` around lines 524 - 541, The code
currently reads per-dataset masking knobs from only the first dataset
(ds_entries -> ds_cfg) via _ds_get and assigns roles_to_train and train_on_eos
based on that single entry, which can be surprising for multi-dataset configs;
add a concise inline comment above the ds_entries/ds_cfg block stating this
"first-dataset-only" behavior and add a LOG.debug (or existing logger) line
mentioning that roles_to_train and train_on_eos are being sourced from the first
dataset entry (ds_cfg) and that other dataset entries are ignored, referencing
ChatTemplateStrategy to indicate parity.
src/axolotl/processing_strategies.py (1)

347-361: Consider adding type hints to inner functions for clarity.

The _match_prefix and _find_end inner functions work correctly but lack type hints, making the code slightly harder to follow. This is a minor nitpick.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 347 - 361, Add type hints
to the inner functions: annotate _match_prefix(label, start_pos, tok_seq) as
(label: Sequence[int], start_pos: int, tok_seq: Sequence[int]) -> bool and
annotate _find_end(label, start_pos, end_tok) as (label: Sequence[int],
start_pos: int, end_tok: Sequence[int]) -> tuple[int, bool]; also ensure
Sequence is imported from typing (or use a suitable type like list[int] /
torch.Tensor if preferred) so the parameter and return types for _match_prefix
and _find_end are explicit and consistent with the rest of the module.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/multimodal_assistant_mask.md`:
- Around line 142-168: The numbered commit list under "Commits on this branch"
skips 3; update the sequence to be consecutive (e.g., change the item currently
labeled "4." to "3." and renumber subsequent items to 4–8) so items read 1–8;
ensure any internal references or cross-links that relied on the old numbering
are updated to the new numbers and keep the commit titles (e.g., "feat: thread
cfg.train_on_inputs ...", "docs: multimodal assistant-mask design doc")
unchanged.

---

Nitpick comments:
In `@src/axolotl/core/builders/causal.py`:
- Around line 524-541: The code currently reads per-dataset masking knobs from
only the first dataset (ds_entries -> ds_cfg) via _ds_get and assigns
roles_to_train and train_on_eos based on that single entry, which can be
surprising for multi-dataset configs; add a concise inline comment above the
ds_entries/ds_cfg block stating this "first-dataset-only" behavior and add a
LOG.debug (or existing logger) line mentioning that roles_to_train and
train_on_eos are being sourced from the first dataset entry (ds_cfg) and that
other dataset entries are ignored, referencing ChatTemplateStrategy to indicate
parity.

In `@src/axolotl/processing_strategies.py`:
- Around line 347-361: Add type hints to the inner functions: annotate
_match_prefix(label, start_pos, tok_seq) as (label: Sequence[int], start_pos:
int, tok_seq: Sequence[int]) -> bool and annotate _find_end(label, start_pos,
end_tok) as (label: Sequence[int], start_pos: int, end_tok: Sequence[int]) ->
tuple[int, bool]; also ensure Sequence is imported from typing (or use a
suitable type like list[int] / torch.Tensor if preferred) so the parameter and
return types for _match_prefix and _find_end are explicit and consistent with
the rest of the module.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 31b19ebf-0c33-418e-b7d2-71260e3773f9

📥 Commits

Reviewing files that changed from the base of the PR and between 798c8fb and 9416d40.

📒 Files selected for processing (5)
  • docs/multimodal_assistant_mask.md
  • src/axolotl/core/builders/causal.py
  • src/axolotl/processing_strategies.py
  • src/axolotl/utils/schemas/multimodal.py
  • tests/test_processing_strategies.py

Comment thread docs/multimodal_assistant_mask.md Outdated
…daries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions

github-actions Bot commented Apr 24, 2026

Copy link
Copy Markdown

📖 Documentation Preview:

Deployed on Netlify from commit 954794c

thad0ctor and others added 5 commits April 24, 2026 11:14
- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>


Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@thad0ctor thad0ctor merged commit 218018c into main Apr 24, 2026
15 checks passed
thad0ctor added a commit that referenced this pull request Apr 24, 2026
Bring in d76d66e chore(mm-mask): hoist .tolist() out of scanner;
shorten comments/docstrings — the single commit landed after PR #7
was merged.
thad0ctor added a commit that referenced this pull request May 12, 2026
… tests exercise D2/D3 hot paths (R3-#6 + R3-#7)

Three lifecycle / correctness fixes from CodeRabbit's third-round
review on PR #21.

**R3-#1 — scheduler.ensure_chunks_resident SWAP-stream barrier.**

M6C-fix-4 routes the LoRA-container synchronous gather onto the
compute stream (so the all_gather completes before autograd's
``_to_copy`` op records its source shape against the rebound
``param.data``). That bypass also skipped the
``compute.wait_stream(_swap_stream)`` barrier that
``_gather_on_prefetch_stream`` performs to protect pool buffers
from being overwritten while a SWAP D2H is still reading them.
On the SWAP + LoRA path that reopens the same cross-stream
buffer race the prefetch-stream barrier closes, just shifted onto
the compute stream.

Add the compute-stream wait_stream on ``_swap_stream`` before the
synchronous gather loop in
``Scheduler.ensure_chunks_resident``. Cost is one event-record /
event-wait pair per LoRA container hook fire; on the steady-state
fast path the wait completes immediately (no SWAP in flight on
the pool buffers being gathered) and is dominated by the gather's
H2D / all_gather work.

**R3-#6 — test_cpu_optim_replaced_calls_shutdown_on_previous no
longer self-skips.**

The pre-fix test used ``force_all_persistent=True`` which produces
``n_persist == N_chunk`` on the tiny model — no chunks offloaded
→ no ``CpuFusedAdamAdapter`` constructed → the test's "no CPU
adapter to swap" skip fires 100 % of the time. The D3 invariant
was effectively never exercised by this test.

Switch to ``force_all_persistent=False`` + explicit overrides
(``n_persist_override=0``, ``n_offload_override=N_layers``,
``small_chunk=True``) so the tiny model actually produces
non-persistent offloaded chunks and the per-chunk CPU adapter is
built. Probe ``DeepSpeedCPUAdam`` JIT-load up front and skip
cleanly if the env can't even build a CPU adapter — that's a
real env-skip, not a self-skip.

**R3-#7 — test_resume_hook_inprocess_cycle_continues_training
actually offloads.**

Same root cause: with ``force_all_persistent=True``, the
``materialize_offload()`` call inside the simulated resume cycle
was a no-op (no non-persistent chunks to offload). The D2 hot
path the test claims to cover (second ``materialize_offload`` on
the same chunk manager → snapshot-and-rebuild lifecycle) was
never exercised.

Switch to the same offload-mode override pattern as R3-#6 so the
second materialize_offload moves actual bytes (~7 non-persistent
chunks per the layout this produces). Also restructure the save /
load step to capture the state_dict AFTER ``restore_to_gpu``
rather than while chunks are offloaded — saving while offloaded
captured ``Size([0])`` placeholder shapes that wouldn't match
the restored model's full-storage tensors. This matches the
production HF Trainer save path (checkpoints are taken after
the resume hook restores chunks to GPU).

``_wrap_protrain`` now accepts forwarded override knobs +
``small_chunk=True`` (monkey-patches ``pick_S_chunk`` to 1 MiB
matching the working pattern in ``test_lora_offload_mode``) so
the tiny test model actually produces N_chunk > 1 chunks.

Test results after the fixes:

* GPU-marker sweep on resume robustness suite: 3 passed
  (cpu-optim-shutdown invariant, D1 marker cleanup, end-to-end
  resume cycle) / 2 skipped (single-process Mode-C downgrade —
  shape-preserving placeholders not engaged, multi-GPU coverage
  in ``test_real_multigpu_cross_mode_resume_*``).
* ``tests/protrain/`` default-marker sweep: 303 passed / 4
  skipped / 162 deselected / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
…s in prior fixes

CodeRabbit's full-diff re-scan on commit 55377e5 surfaced four
Major correctness gaps in prior triage commits that the incremental
reviews missed.

**F-#1 — Filter post-wrap ignore set to non-persistent chunks only.**

My R4-#1 fix at ``api/model_wrapper.py:2227-2246`` built
``chunk_managed_param_ids`` from ALL
``chunk_manager._params_by_id.values()``, but persistent chunks
should NEVER be in ``_ddp_params_and_buffers_to_ignore`` — they
need normal DDP broadcast and backward allreduce (see
``ChunkManager.chunk_managed_param_names``'s docstring: "Persistent
chunks are excluded — their params stay GPU-resident, do not pass
through the released-state placeholder, and DO need the standard
DDP broadcast for correctness."). The over-broad filter silently
swept persistent params into the ignore set, breaking gradient
sync on the chunks DDP IS supposed to handle.

Restrict the OBJECT-identity set to params backed by
``_non_persistent_ids`` only — iterate ``_cpu_slots[cid]`` for
each non-persistent ``cid`` and pull the param ref from
``_params_by_id``. Renamed the local loop vars (``_cpu_slot``
instead of ``slot``) to avoid shadowing the earlier
``for slot, child in enumerate(parent)`` block-wrap site that
binds ``slot`` to ``int``.

**F-#3 — Abort optim swap on CPU-adapter teardown failure.**

My D3 fix at ``api/optim_wrapper.py:951`` wrapped
``_old_cpu_optim.shutdown()`` in a try/except that warned and
continued. The whole point of D3 is the deterministic-cleanup
invariant — masking a real teardown failure (``ThreadPoolExecutor``
hung, DeepSpeed C-state corrupted) puts the failed adapter back on
the GC path AND silently accepts an inconsistent state-machine on
the rebuild side. Removed the try/except so a shutdown failure
aborts the swap rather than papering over it.

**F-#6 — Also fence compute stream against ``_prefetch_stream``.**

My R3-#1 fix at ``runtime/scheduler.py::ensure_chunks_resident``
added ``compute.wait_stream(_swap_stream)`` before the
synchronous gather loop to close the SWAP D2H race. CodeRabbit
caught that the symmetric prefetch race is still open: if a chunk
is being prefetched and ``ChunkManager.gather()`` hits the
``_active_chunks`` resident fast path, ``param.data`` rebinds
while the prefetch's H2D / ``all_gather_into_tensor`` is still
running on ``_prefetch_stream`` — gather returns BEFORE the chunk
is compute-stream-safe, and a LoRA forward consuming
``param.data`` reads stale / not-yet-written bytes.

Add ``compute.wait_stream(_prefetch_stream)`` alongside the
existing ``compute.wait_stream(_swap_stream)`` so both
cross-stream barriers fire when present. Cost: one extra
event-record / event-wait per LoRA container hook fire; no-op when
``_prefetch_stream`` isn't running anything.

**F-#7 — Broaden exception scope in ``check_cuda_p2p_support``.**

My D9 fix at ``utils/environment.py:96`` caught only
``AssertionError`` from ``torch.cuda.can_device_access_peer``.
Per the PyTorch 2.6 docs path, the Python wrapper validates
device indices with ``AssertionError``, but the C++ binding
``_cuda_canDeviceAccessPeer`` it delegates to can surface
exceptions from the CUDA runtime (``RuntimeError`` wrapping
``cudaErrorInvalidDevice``, peer-access machinery errors) that
``AssertionError`` wouldn't catch. An unhandled exception would
propagate out of the helper and break the fail-closed contract —
ranks would disagree about ``NCCL_P2P_DISABLE`` which is exactly
the SIGSEGV class commit ``91e0912e`` set out to prevent.

Widened to ``except Exception`` (with ``noqa: BLE001`` annotation
explicitly documenting the fail-closed rationale).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ved1beta pushed a commit that referenced this pull request May 14, 2026
…daries` (axolotl-ai-cloud#3625)

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR axolotl-ai-cloud#3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
thad0ctor added a commit that referenced this pull request May 28, 2026
… tests exercise D2/D3 hot paths (R3-#6 + R3-#7)

Three lifecycle / correctness fixes from CodeRabbit's third-round
review on PR #21.

**R3-#1 — scheduler.ensure_chunks_resident SWAP-stream barrier.**

M6C-fix-4 routes the LoRA-container synchronous gather onto the
compute stream (so the all_gather completes before autograd's
``_to_copy`` op records its source shape against the rebound
``param.data``). That bypass also skipped the
``compute.wait_stream(_swap_stream)`` barrier that
``_gather_on_prefetch_stream`` performs to protect pool buffers
from being overwritten while a SWAP D2H is still reading them.
On the SWAP + LoRA path that reopens the same cross-stream
buffer race the prefetch-stream barrier closes, just shifted onto
the compute stream.

Add the compute-stream wait_stream on ``_swap_stream`` before the
synchronous gather loop in
``Scheduler.ensure_chunks_resident``. Cost is one event-record /
event-wait pair per LoRA container hook fire; on the steady-state
fast path the wait completes immediately (no SWAP in flight on
the pool buffers being gathered) and is dominated by the gather's
H2D / all_gather work.

**R3-#6 — test_cpu_optim_replaced_calls_shutdown_on_previous no
longer self-skips.**

The pre-fix test used ``force_all_persistent=True`` which produces
``n_persist == N_chunk`` on the tiny model — no chunks offloaded
→ no ``CpuFusedAdamAdapter`` constructed → the test's "no CPU
adapter to swap" skip fires 100 % of the time. The D3 invariant
was effectively never exercised by this test.

Switch to ``force_all_persistent=False`` + explicit overrides
(``n_persist_override=0``, ``n_offload_override=N_layers``,
``small_chunk=True``) so the tiny model actually produces
non-persistent offloaded chunks and the per-chunk CPU adapter is
built. Probe ``DeepSpeedCPUAdam`` JIT-load up front and skip
cleanly if the env can't even build a CPU adapter — that's a
real env-skip, not a self-skip.

**R3-#7 — test_resume_hook_inprocess_cycle_continues_training
actually offloads.**

Same root cause: with ``force_all_persistent=True``, the
``materialize_offload()`` call inside the simulated resume cycle
was a no-op (no non-persistent chunks to offload). The D2 hot
path the test claims to cover (second ``materialize_offload`` on
the same chunk manager → snapshot-and-rebuild lifecycle) was
never exercised.

Switch to the same offload-mode override pattern as R3-#6 so the
second materialize_offload moves actual bytes (~7 non-persistent
chunks per the layout this produces). Also restructure the save /
load step to capture the state_dict AFTER ``restore_to_gpu``
rather than while chunks are offloaded — saving while offloaded
captured ``Size([0])`` placeholder shapes that wouldn't match
the restored model's full-storage tensors. This matches the
production HF Trainer save path (checkpoints are taken after
the resume hook restores chunks to GPU).

``_wrap_protrain`` now accepts forwarded override knobs +
``small_chunk=True`` (monkey-patches ``pick_S_chunk`` to 1 MiB
matching the working pattern in ``test_lora_offload_mode``) so
the tiny test model actually produces N_chunk > 1 chunks.

Test results after the fixes:

* GPU-marker sweep on resume robustness suite: 3 passed
  (cpu-optim-shutdown invariant, D1 marker cleanup, end-to-end
  resume cycle) / 2 skipped (single-process Mode-C downgrade —
  shape-preserving placeholders not engaged, multi-GPU coverage
  in ``test_real_multigpu_cross_mode_resume_*``).
* ``tests/protrain/`` default-marker sweep: 303 passed / 4
  skipped / 162 deselected / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 28, 2026
…s in prior fixes

CodeRabbit's full-diff re-scan on commit 55377e5 surfaced four
Major correctness gaps in prior triage commits that the incremental
reviews missed.

**F-#1 — Filter post-wrap ignore set to non-persistent chunks only.**

My R4-#1 fix at ``api/model_wrapper.py:2227-2246`` built
``chunk_managed_param_ids`` from ALL
``chunk_manager._params_by_id.values()``, but persistent chunks
should NEVER be in ``_ddp_params_and_buffers_to_ignore`` — they
need normal DDP broadcast and backward allreduce (see
``ChunkManager.chunk_managed_param_names``'s docstring: "Persistent
chunks are excluded — their params stay GPU-resident, do not pass
through the released-state placeholder, and DO need the standard
DDP broadcast for correctness."). The over-broad filter silently
swept persistent params into the ignore set, breaking gradient
sync on the chunks DDP IS supposed to handle.

Restrict the OBJECT-identity set to params backed by
``_non_persistent_ids`` only — iterate ``_cpu_slots[cid]`` for
each non-persistent ``cid`` and pull the param ref from
``_params_by_id``. Renamed the local loop vars (``_cpu_slot``
instead of ``slot``) to avoid shadowing the earlier
``for slot, child in enumerate(parent)`` block-wrap site that
binds ``slot`` to ``int``.

**F-#3 — Abort optim swap on CPU-adapter teardown failure.**

My D3 fix at ``api/optim_wrapper.py:951`` wrapped
``_old_cpu_optim.shutdown()`` in a try/except that warned and
continued. The whole point of D3 is the deterministic-cleanup
invariant — masking a real teardown failure (``ThreadPoolExecutor``
hung, DeepSpeed C-state corrupted) puts the failed adapter back on
the GC path AND silently accepts an inconsistent state-machine on
the rebuild side. Removed the try/except so a shutdown failure
aborts the swap rather than papering over it.

**F-#6 — Also fence compute stream against ``_prefetch_stream``.**

My R3-#1 fix at ``runtime/scheduler.py::ensure_chunks_resident``
added ``compute.wait_stream(_swap_stream)`` before the
synchronous gather loop to close the SWAP D2H race. CodeRabbit
caught that the symmetric prefetch race is still open: if a chunk
is being prefetched and ``ChunkManager.gather()`` hits the
``_active_chunks`` resident fast path, ``param.data`` rebinds
while the prefetch's H2D / ``all_gather_into_tensor`` is still
running on ``_prefetch_stream`` — gather returns BEFORE the chunk
is compute-stream-safe, and a LoRA forward consuming
``param.data`` reads stale / not-yet-written bytes.

Add ``compute.wait_stream(_prefetch_stream)`` alongside the
existing ``compute.wait_stream(_swap_stream)`` so both
cross-stream barriers fire when present. Cost: one extra
event-record / event-wait per LoRA container hook fire; no-op when
``_prefetch_stream`` isn't running anything.

**F-#7 — Broaden exception scope in ``check_cuda_p2p_support``.**

My D9 fix at ``utils/environment.py:96`` caught only
``AssertionError`` from ``torch.cuda.can_device_access_peer``.
Per the PyTorch 2.6 docs path, the Python wrapper validates
device indices with ``AssertionError``, but the C++ binding
``_cuda_canDeviceAccessPeer`` it delegates to can surface
exceptions from the CUDA runtime (``RuntimeError`` wrapping
``cudaErrorInvalidDevice``, peer-access machinery errors) that
``AssertionError`` wouldn't catch. An unhandled exception would
propagate out of the helper and break the fail-closed contract —
ranks would disagree about ``NCCL_P2P_DISABLE`` which is exactly
the SIGSEGV class commit ``91e0912e`` set out to prevent.

Widened to ``except Exception`` (with ``noqa: BLE001`` annotation
explicitly documenting the fail-closed rationale).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant