Skip to content
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a904641
feat: systemic multimodal assistant-only loss masking
thad0ctor Apr 23, 2026
6b34a8b
feat: thread cfg.train_on_inputs / roles_to_train / train_on_eos into…
thad0ctor Apr 23, 2026
e916f42
feat: forward cfg.processor_kwargs to processor from_pretrained (#3617)
thad0ctor Apr 23, 2026
16f9007
docs: multimodal assistant-mask design doc
thad0ctor Apr 23, 2026
78aac49
feat: cfg.role_boundaries YAML override for MM role-mask scanner
thad0ctor Apr 23, 2026
60c81cb
test: additional coverage for MM role-mask scanner edge cases
thad0ctor Apr 23, 2026
5b4ccb8
chore: tighten docstrings and comments in multimodal mask refactor
thad0ctor Apr 23, 2026
bf0c0b4
fix: resolve MM per-dataset masking knobs for pydantic SFTDataset
thad0ctor Apr 23, 2026
acfe4fe
fix(processing_strategies): correctly mask Pixtral / Mistral V7 Tekke…
thad0ctor Apr 23, 2026
6ce4465
docs: refresh test count and commit list in MM mask design doc
thad0ctor Apr 23, 2026
7424ce5
feat(processing_strategies): support train_on_eos='last' and validate…
thad0ctor Apr 23, 2026
c70139e
fix(processing_strategies): honor explicit empty roles_to_train list
thad0ctor Apr 23, 2026
a72996c
fix(processing_strategies): guard None tokenizer id comparisons
thad0ctor Apr 23, 2026
f1d7c4c
fix(processing_strategies): raise on unencodable end markers in role_…
thad0ctor Apr 23, 2026
7e207f9
fix(processing_strategies): narrow optional-import exception handling…
thad0ctor Apr 23, 2026
77a9d17
fix(processing_strategies): drop unreachable qwen3_5_moe dispatcher b…
thad0ctor Apr 23, 2026
3b3aff2
fix(processing_strategies): add trailing newline to Gemma4 role-start…
thad0ctor Apr 23, 2026
f4407f9
fix(processing_strategies): drop unreachable Gemma system boundary
thad0ctor Apr 23, 2026
8c3515a
feat(processing_strategies): INFO-log resolved masking config at init
thad0ctor Apr 23, 2026
08f8e4e
docs: reflect Pixtral scanner fix + final test count in audit table
thad0ctor Apr 23, 2026
b8932f3
revert: drop cfg.processor_kwargs / #3617 — tracked on feat/processor…
thad0ctor Apr 23, 2026
245ee73
fix(processing_strategies): clear three mypy errors flagged by CodeRa…
thad0ctor Apr 23, 2026
42a47d7
fix(multimodal-cpt): resolve CI pre-commit + caplog flakiness on PR #4
thad0ctor Apr 24, 2026
d444156
fix(processing_strategies): address CodeRabbit findings on messages/v…
thad0ctor Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 233 additions & 0 deletions docs/multimodal_assistant_mask.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Multimodal assistant-only loss masking

## What this fixes

For multimodal fine-tuning, `cfg.train_on_inputs`, `cfg.roles_to_train`, and
`cfg.train_on_eos` were silently ignored. Every non-pad, non-media token in
the sequence — including system prompts, user turns, and role markers —
contributed to the loss. Only Gemma3n had a working per-role mask; every
other multimodal path (LLaVA, Qwen2-VL, Qwen3.5, Gemma3, Llama-3.2 Vision,
Llama 4, Pixtral, Mistral V7 Tekken, Voxtral, SmolVLM2, Mistral3, InternVL,
GLM4V) trained on the full sequence.

## Root cause

`MultiModalChatDataCollator` re-tokenizes raw `messages` via
`processor.apply_chat_template(...)` at collation time, discarding the
per-role labels already computed by `ChatTemplateStrategy.tokenize_prompt` in
the preprocessing path. It then calls
`processing_strategy.process_labels(input_ids)`, which was supposed to rebuild
role-aware labels — but the base `_mask_non_assistant` was a no-op `return
labels`, and only `Gemma3nProcessingStrategy` overrode it. So for every other
multimodal model, the retokenized labels are never masked by role.

Parallel bug ([#3617](https://github.com/axolotl-ai-cloud/axolotl/issues/3617)):
`cfg.processor_kwargs` was never plumbed through to
`processor_cls.from_pretrained`, so users couldn't pass `min_pixels`,
`max_pixels`, `num_crops`, etc., to VLM processors.

## Design

We make role masking a first-class, declarative capability of the base
`ProcessingStrategy` and thread the masking knobs through from the trainer
builder.

### Why this over alternatives

- **Option (b): preserve the per-role labels from `tokenize_prompt`.**
Rejected. The preprocessing labels were computed against a text-only
tokenization; they don't align with the MM collator's re-tokenization after
image/audio/video placeholders expand into hundreds of placeholder tokens.
Preserving them would require either a second tokenization pass with image
stand-ins, or rewriting the collator to never re-tokenize. Either is
high-blast-radius for an incremental bugfix.
- **Option (c): `apply_chat_template(return_assistant_tokens_mask=True)`.**
Rejected. This requires `{% generation %}` / `{% endgeneration %}` jinja
markers. Only `llava.jinja` and `phi_4.jinja` have them in
`src/axolotl/utils/chat_templates/templates/`. Adding these markers to
upstream-mirrored templates (gemma3, qwen2_vl, llama3_2_vision, etc.)
diverges from the reference templates and is fragile when HF updates them.
- **Option (a): parametrized token-boundary scanner in the base class.**
Chosen. Each strategy declares its per-role boundary markers
(`<|im_start|>assistant\n` ... `<|im_end|>` for Qwen2-VL,
`<|turn>model` ... `<turn|>` for Gemma 4, etc.). The base scanner walks the
re-tokenized sequence, locates role spans, and masks everything outside
`cfg.roles_to_train`. Works with existing jinja templates, is testable
offline with fake tokenizers, and fails visible (unverified strategies emit
a one-shot warning rather than silently mis-masking).

### Components

1. **`RoleBoundary`** dataclass in `src/axolotl/processing_strategies.py`
describing one role's `(start_tokens, end_tokens, include_start, include_end)`.
2. **`_apply_role_boundaries`** function: a longest-prefix-match scanner that
implements `roles_to_train` / `train_on_inputs` / `train_on_eos` (`"turn"`
keeps role-end markers on trainable turns, `"all"` keeps them on every
turn, `"none"` excludes them).
3. **`ProcessingStrategy._build_role_boundaries`**: empty by default;
overridden by each subclass. `_mask_non_assistant` delegates to the
scanner; if no boundaries are declared it short-circuits and emits a
one-shot warning (legacy behavior preserved).
4. **Plumbing**: `cfg.train_on_inputs`, the first dataset's `roles_to_train`
and `train_on_eos` are threaded through `build_collator` →
`get_processing_strategy` → each strategy's constructor.
5. **`cfg.processor_kwargs`** added to `ModelInputConfig` and merged into
kwargs passed to `processor_cls.from_pretrained` in `load_processor`.
Axolotl-managed keys (`revision`, `trust_remote_code`) are filtered out
with a warning if the user tries to override them.

## Audit table

| Strategy / chat template | Honors `roles_to_train`? (before) | (after) | Role-boundary markers | Media tokens masked |
|---|---|---|---|---|
| `ProcessingStrategy` (fallback for `llava`, `lfm2vl`, `mistral_v3_tekken`, unknown) | ✗ | fallback + warn | *unverified* | `image_token_id` if processor exposes it |
| `Qwen2VLProcessingStrategy` (`qwen2_vl`) | ✗ | ✓ | `<\|im_start\|>{role}\n` ... `<\|im_end\|>` | `<\|image_pad\|>` |
| `Qwen3_5ProcessingStrategy` (`qwen3_5`) | ✗ | ✓ | same as Qwen2VL | `<\|image_pad\|>`, `<\|video_pad\|>` |
| `Gemma3ProcessingStrategy` (`gemma3`) | ✗ | ✓ | `<start_of_turn>{model/user/system}\n` ... `<end_of_turn>` | `boi_token`, `<image_soft_token>` (262144) |
| `Gemma3nProcessingStrategy` (`gemma3n`) | ✓ (ad-hoc) | ✓ (shared scanner) | same as Gemma 3 | `image_token_id`, `audio_token_id`, `boi_token_id`, `eoi_token_id` |
| `Gemma4ProcessingStrategy` (`gemma4`) | n/a (new) | ✓ | `<\|turn>{model/user/system}` ... `<turn\|>` | `image_token_id`, `audio_token_id`, `boi/eoi/boa/eoa` (resolved via `convert_tokens_to_ids`), `video_token_id` (on processor) |
| `Llama3_2VisionProcessingStrategy` (`llama3_2_vision`) — **new** | ✗ | ✓ | `<\|start_header_id\|>{role}<\|end_header_id\|>\n\n` ... `<\|eot_id\|>` | `image_token_id` via base |
| `Llama4ProcessingStrategy` (`llama4`) — **new** | ✗ | ✓ | `<\|header_start\|>{role}<\|header_end\|>\n\n` ... `<\|eot\|>` | `image_token_id` via base |
| `PixtralProcessingStrategy` (`pixtral`) — **new** | ✗ | ✓ | user: `[INST]` ... `[/INST]` (`include_end=False`), assistant: `[/INST]` ... `eos_token` | `image_token_id` via base |
| `MistralV7TekkenProcessingStrategy` (`mistral_v7_tekken`) — **new** | ✗ | ✓ | `[SYSTEM_PROMPT]` ... `[/SYSTEM_PROMPT]`, `[INST]` ... `[/INST]` (`include_end=False`), assistant: `[/INST]` ... `eos_token` | `image_token_id` via base |
| `VoxtralProcessingStrategy` | ✗ | fallback + warn | *unverified* (mistral-common tokenizer) | `audio_token`, `begin_audio_token` |
| `SmolVLM2ProcessingStrategy` | ✗ | fallback + warn | *unverified* (checkpoint-dependent default) | `<image>` |
| `Mistral3ProcessingStrategy` | ✗ | fallback + warn | *unverified* (mistral-common tokenizer) | `img`, `img_break`, `img_end` |
| `InternVLProcessingStrategy` | ✗ | fallback + warn | *unverified* (InternLM-family) | `processor.image_ids` |
| `Glm4vProcessingStrategy` | ✗ | fallback + warn | *unverified* | image/video + begin/end markers |

Pixtral and Mistral V7 Tekken share a token (`[/INST]`) between the user-end
and assistant-start markers. The scanner supports this via `include_end=False`
on the user boundary: when the scanner hits an end marker that is also another
boundary's start, it rewinds past it so the next iteration can match the
shared token as the next role's start. See commit `acfe4fe4` and the full
per-position assertions in `tests/test_processing_strategies.py`.

*unverified*: the right boundary markers cannot be confirmed without a real
checkpoint; the fallback preserves the legacy "mask pad + media tokens only"
behavior and emits a one-shot warning naming the strategy class so the miss
is visible in training logs. To enable role masking for one of these models,
subclass the strategy and implement `_build_role_boundaries` — see the Gemma
and Qwen implementations for the pattern.

## Config-based override: `cfg.role_boundaries`

For the "unverified" strategies above, or for custom chat templates that
don't match a built-in strategy's markers, users can declare role boundaries
directly in YAML without subclassing:

```yaml
role_boundaries:
- role: assistant
start: "<|turn>model"
end: "<turn|>"
- role: user
start: "<|turn>user"
end: "<turn|>"
# Optional keys:
# include_start: false # default False
# include_end: true # default True, respects cfg.train_on_eos
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
# end: null # span runs to end of sequence
```

Semantics:

- `start` and `end` are literal strings; axolotl encodes them at strategy
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
resolved token-id sequences at INFO level.
- The special value `end: eos_token` is the portable way to express
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
- When `role_boundaries` is set, it **replaces** the strategy's built-in
declarations wholesale. This is intentional: partial overlays are hard to
reason about at review time.
- `cfg.roles_to_train` still governs which declared roles contribute to
loss. You can declare `user` and `assistant` boundaries and set
`roles_to_train: ["assistant"]` to have the scanner correctly identify
user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
unencodable markers), not silently at loss-compute time.

## Commits on this branch

Run `git log main..HEAD --oneline` for the authoritative sequence. As of
this revision the logical units are:

1. **`feat: systemic multimodal assistant-only loss masking`** — core
refactor of `processing_strategies.py` (`RoleBoundary`,
`_apply_role_boundaries`, `_build_role_boundaries`), per-strategy boundary
declarations, dispatcher routing for new subclasses.
2. **`feat: thread cfg.train_on_inputs / roles_to_train / train_on_eos into
MM collator`** — `build_collator` reads the knobs from `cfg` and the
first dataset entry and passes them to `get_processing_strategy`.
3. **`feat: forward cfg.processor_kwargs to processor from_pretrained (#3617)`**
— schema field added; `load_processor` merges kwargs; axolotl-managed
keys (`revision`, `trust_remote_code`) protected.
4. **`docs: multimodal assistant-mask design doc`** — this file.
5. **`feat: cfg.role_boundaries YAML override for MM role-mask scanner`** —
schema field (`MultiModalConfig.role_boundaries`), resolver that converts
string markers to token ids at strategy init, ``eos_token`` sentinel, and
wiring through ``build_collator`` / ``get_processing_strategy`` /
every strategy constructor.
6. **`test: additional coverage for MM role-mask scanner edge cases`** —
expands the unit test suite to 64 tests covering scanner semantics,
per-strategy masking, media-token masking within assistant spans,
dispatcher routing, processor_kwargs passthrough, override semantics
(replace built-in, enable on unverified strategy, eos_token sentinel,
null end, validation errors, pydantic model input).
7. **`chore: tighten docstrings and comments in multimodal mask refactor`**
— no-behavior-change polish.
8. **`fix: resolve MM per-dataset masking knobs for pydantic SFTDataset`**
— `build_collator` resolver now uses `.get` → `getattr` fallback so
`roles_to_train` / `train_on_eos` are honored when datasets are supplied
as pydantic models (not just `DictDefault`). Adds an INFO log of the
resolved collator knobs.

## Verification

- All 64 unit tests pass offline (`pytest tests/test_processing_strategies.py`).
- End-to-end check against real tokenizers:
- `google/gemma-4-E2B-it`: 13/40 tokens kept for a 2-turn chat; decoded
preview shows only assistant responses + `<turn|>` markers remain.
- `axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer` (with bundled
`llama3_2_vision.jinja`): 11/64 tokens kept; content correctly resolves
to `"The capital of France is Paris.<|eot_id|>"` and `"Berlin.<|eot_id|>"`.
- Verified boundary token ids against the real Gemma 4 tokenizer:
`<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]`, `<|image|>` → `258880`,
`<|audio|>` → `258881`, `<|video|>` → `258884`.

## Draft upstream PR description

> Fix silently-ignored `train_on_inputs` / `roles_to_train` / `train_on_eos`
> in the multimodal training path, and plumb `cfg.processor_kwargs`
> (#3617).
>
> **Why this matters**: for every multimodal model except Gemma 3n, loss was
> computed on the entire sequence (minus pad and media tokens) regardless of
> what `roles_to_train` / `train_on_inputs` the config specified. This
> silently turned assistant-only SFT into full-sequence SFT for thousands of
> users, degrading sample efficiency and introducing spurious gradient signal
> on system and user content.
>
> **What changed**:
> - `ProcessingStrategy._build_role_boundaries` declares per-role start/end
> token sequences. The base `_mask_non_assistant` now consumes those
> declarations via a shared scanner that honors `train_on_inputs`,
> `roles_to_train`, and `train_on_eos`.
> - Per-strategy boundary declarations added for Qwen2-VL, Qwen3.5, Gemma 3,
> Gemma 3n (refactored from ad-hoc scanner), Gemma 4 (new), Llama 3.2
> Vision (new), Llama 4 (new), Pixtral (new), Mistral V7 Tekken (new).
> - Strategies whose boundary tokens we couldn't verify against a real
> tokenizer (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, and the
> llava/lfm2vl/unknown-template fallback) retain legacy behavior but emit a
> one-shot warning so the miss is visible in training logs.
> - `cfg.train_on_inputs` / `cfg.datasets[0].roles_to_train` /
> `cfg.datasets[0].train_on_eos` are threaded through
> `HFCausalTrainerBuilder.build_collator` → `get_processing_strategy` →
> strategy constructor.
> - `cfg.processor_kwargs` (new) is merged into
> `processor_cls.from_pretrained` kwargs; `revision` and `trust_remote_code`
> remain axolotl-managed.
>
> **Testing**: 64 offline unit tests; end-to-end verified with the real
> Gemma 4 and Llama 3.x tokenizers.
37 changes: 37 additions & 0 deletions src/axolotl/core/builders/causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,49 @@ def build_collator(
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
# Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
ds_entries = self.cfg.datasets or []
ds_cfg = ds_entries[0] if ds_entries else None

def _ds_get(cfg_obj, key):
# Handle DictDefault / dict / pydantic uniformly:
# dict-style .get first, then attribute access.
if cfg_obj is None:
return None
if hasattr(cfg_obj, "get"):
try:
return cfg_obj.get(key)
except (AttributeError, KeyError, TypeError):
pass
return getattr(cfg_obj, key, None)

roles_to_train = _ds_get(ds_cfg, "roles_to_train")
train_on_eos = _ds_get(ds_cfg, "train_on_eos")

# cfg.role_boundaries replaces the strategy's built-in markers.
role_boundaries_override = None
if self.cfg.role_boundaries:
role_boundaries_override = list(self.cfg.role_boundaries)

LOG.info(
"MM collator: train_on_inputs=%s roles_to_train=%s "
"train_on_eos=%s role_boundaries_override=%s",
bool(self.cfg.train_on_inputs),
roles_to_train,
train_on_eos,
"set" if role_boundaries_override else "none",
)

kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
train_on_inputs=bool(self.cfg.train_on_inputs),
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
role_boundaries_override=role_boundaries_override,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
Expand Down
17 changes: 15 additions & 2 deletions src/axolotl/loaders/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,21 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)

# Build common kwargs for processor loading
processor_kwargs = {}
# User kwargs go in first so axolotl-managed keys (revision, trust_remote_code) can't be overridden. See #3617.
processor_kwargs: dict = {}
user_processor_kwargs = dict(cfg.processor_kwargs) if cfg.processor_kwargs else {}
_overridden = {
k for k in ("revision", "trust_remote_code") if k in user_processor_kwargs
}
if _overridden:
LOG.warning(
"Ignoring cfg.processor_kwargs keys %s — these are managed by axolotl.",
sorted(_overridden),
)
for k in _overridden:
user_processor_kwargs.pop(k, None)
processor_kwargs.update(user_processor_kwargs)

if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model

Expand Down
Loading
Loading