feat: systemic multimodal assistant-only loss masking + #3617 #4
thad0ctor wants to merge 24 commits into
Adds a parametrized role-boundary scanner to ProcessingStrategy so that cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos (which were previously silently ignored for every multimodal path except Gemma 3n) restrict loss to trainable role spans.
Introduces ``RoleBoundary`` and ``_apply_role_boundaries``. Each strategy declares its per-role start/end token sequences via ``_build_role_boundaries``; the base ``_mask_non_assistant`` delegates to the shared scanner and short-circuits with a one-shot warning when a strategy has no declared boundaries (preserving legacy behavior for unverified paths).
Per-strategy boundary declarations added for Qwen2-VL, Qwen3.5, Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc per-strategy scanner, now shared), Gemma 4 (new), Llama 3.2 Vision (new), Llama 4 (new), Pixtral (new), and Mistral V7 Tekken (new). Dispatcher routes chat_template_type to the new subclasses.
Strategies whose boundary tokens we can't verify without loading a real checkpoint (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, plus the llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning naming the class so the miss is visible in training logs.
Also makes the dispatcher's Mistral3Processor import lazy so that non-mistral-common installs can still route to non-mistral strategies.
Includes 30 offline unit tests covering scanner semantics (train_on_inputs/roles_to_train/train_on_eos, longest-prefix start match, truncated spans), per-strategy masking with fake tokenizers matching real model id layouts, media-token masking inside assistant spans, and dispatcher routing.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
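The scanner described in this commit can be illustrated with a minimal sketch. It is not the code from `src/axolotl/processing_strategies.py`: the field names mirror the ones mentioned in this PR, but `apply_role_boundaries`, the plain-list token ids, and the omission of `train_on_eos` and pad/media masking are simplifying assumptions.

```python
from dataclasses import dataclass, field

IGNORE_INDEX = -100  # conventional "no loss" label value


@dataclass
class RoleBoundary:
    role: str
    start_tokens: list
    end_tokens: list = field(default_factory=list)  # empty: span runs to end of sequence
    include_start: bool = False
    include_end: bool = True


def _match_at(ids, pos, pattern):
    """True when the non-empty `pattern` occurs in `ids` starting at `pos`."""
    return bool(pattern) and ids[pos:pos + len(pattern)] == pattern


def apply_role_boundaries(ids, boundaries, roles_to_train=("assistant",), train_on_inputs=False):
    """Return labels for one sequence: ids inside trainable spans, IGNORE_INDEX elsewhere."""
    if train_on_inputs:
        return list(ids)  # simplification: no role masking at all
    labels = [IGNORE_INDEX] * len(ids)
    pos = 0
    while pos < len(ids):
        # Longest-prefix match: a longer start marker wins over a shorter one.
        hits = [b for b in boundaries if _match_at(ids, pos, b.start_tokens)]
        if not hits:
            pos += 1
            continue
        b = max(hits, key=lambda x: len(x.start_tokens))
        content = pos + len(b.start_tokens)
        # Look for the end marker; a truncated span simply runs to the end.
        end = content
        while end < len(ids) and not _match_at(ids, end, b.end_tokens):
            end += 1
        found_end = end < len(ids)
        stop = end + len(b.end_tokens) if found_end else len(ids)
        if b.role in roles_to_train:
            lo = pos if b.include_start else content
            hi = stop if (b.include_end and found_end) else end
            labels[lo:hi] = ids[lo:hi]
        pos = stop
    return labels
```

With made-up ids where `[1, 10]` opens a user turn, `[1, 11]` opens an assistant turn, and `[2]` closes either, only the assistant content and its end marker keep their labels.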
… MM collator HFCausalTrainerBuilder.build_collator now reads cfg.train_on_inputs and the first dataset entry's roles_to_train / train_on_eos (mirroring how ChatTemplateStrategy resolves them for text-only training) and forwards them to get_processing_strategy so the shared scanner actually runs with the user's config. Without this, the scanner from the prior commit would always use its defaults (train_on_inputs=False, roles_to_train=["assistant"], train_on_eos="turn"), which happens to match most users' intent but silently ignores explicit config overrides. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…otl-ai-cloud#3617) Adds a ``processor_kwargs: dict | None`` field to ModelInputConfig and merges it into the kwargs passed to ``processor_cls.from_pretrained`` in load_processor. Lets users set min_pixels, max_pixels, num_crops, do_rescale, patch_size, and similar VLM processor knobs from YAML. Axolotl-managed keys (``revision``, ``trust_remote_code``) are filtered out of the user's processor_kwargs with a warning, so passing ``processor_kwargs: {revision: HIJACKED}`` cannot override the cfg-level revision_of_model or the axolotl-managed trust_remote_code. Includes 2 offline unit tests covering the forward-and-filter behavior and the no-op case where cfg.processor_kwargs is absent. Closes axolotl-ai-cloud#3617. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Captures the audit of every multimodal processing strategy, the root-cause analysis, the design rationale for the boundary-scanner approach (vs preserving tokenize_prompt labels or relying on HF's return_assistant_tokens_mask), the per-commit summary for this branch, and a draft upstream PR description. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets users declare per-role start/end markers directly in YAML without
subclassing a ProcessingStrategy. Intended uses:
1. Enable role masking on strategies we can't verify against a real
checkpoint (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, unknown
fallback) — instead of a one-shot warning + full-sequence loss, the
user supplies markers for their specific checkpoint's chat template
and gets proper assistant-only masking.
2. Fine-tune markers for custom / patched chat templates whose tokenized
output diverges from axolotl's built-ins.
YAML shape:
    role_boundaries:
      - role: assistant
        start: "<|turn>model"
        end: "<turn|>"
        # include_start: false   (default)
        # include_end: true      (default, respects cfg.train_on_eos)
        # end: eos_token         (sentinel → tokenizer.eos_token_id)
        # end: null              (span runs to end of sequence)
Implementation:
- RoleBoundarySpec pydantic model added to MultiModalConfig (so Axolotl's
standard validation catches typos at config-load time).
- _resolve_role_boundary_override converts strings → token ids at
strategy init; logs the resolved ids at INFO so mismatches are visible.
- Validation errors (missing role/start, unencodable marker) surface at
init time rather than as silent mis-masking during training.
- ProcessingStrategy.__init__ now takes role_boundaries_override; when
set, it REPLACES the subclass's _build_role_boundaries() result
wholesale (partial overlays would be ambiguous to review).
- Threaded through build_collator → get_processing_strategy → every
strategy constructor.
Verified end-to-end against the real Rashi-OCR Gemma-4-31B tokenizer +
dataset: pass with built-in boundaries, pass with an assistant-only
override, and pass via the dispatcher path all produce identical correct
masks on real dataset rows.
Includes 7 additional unit tests (override replaces built-in, override
enables unverified strategy, eos_token sentinel, null end, spec
validation errors, Pydantic model input). Total: 39 offline tests.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
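A rough sketch of the string-to-token-id resolution this commit describes, under stated assumptions: `FakeTokenizer`, its made-up ids, `ResolvedBoundary`, and `resolve_role_boundaries` are hypothetical stand-ins, not the real `_resolve_role_boundary_override` or `RoleBoundarySpec`.

```python
from dataclasses import dataclass


class FakeTokenizer:
    """Hypothetical tokenizer: maps whole marker strings to made-up ids."""

    eos_token_id = 2
    vocab = {"<|turn>model": [105, 4368], "<turn|>": [106]}

    def encode(self, text, add_special_tokens=False):
        return list(self.vocab.get(text, []))


@dataclass
class ResolvedBoundary:
    role: str
    start_tokens: list
    end_tokens: list
    include_start: bool = False
    include_end: bool = True


def resolve_role_boundaries(specs, tokenizer):
    """Turn YAML-style dict specs into token-id boundaries, failing loudly on bad input."""
    resolved = []
    for spec in specs:
        role, start = spec.get("role"), spec.get("start")
        if not role or not start:
            raise ValueError(f"role_boundaries entry needs 'role' and 'start': {spec!r}")
        start_ids = tokenizer.encode(start, add_special_tokens=False)
        if not start_ids:
            raise ValueError(f"start marker {start!r} encodes to no tokens")
        end = spec.get("end")
        if end is None:
            end_ids = []  # span runs to end of sequence
        elif end == "eos_token":
            end_ids = [tokenizer.eos_token_id]  # sentinel
        else:
            end_ids = tokenizer.encode(end, add_special_tokens=False)
            if not end_ids:
                raise ValueError(f"end marker {end!r} encodes to no tokens")
        resolved.append(
            ResolvedBoundary(
                role,
                start_ids,
                end_ids,
                spec.get("include_start", False),
                spec.get("include_end", True),
            )
        )
    return resolved
```

Raising at resolution time is what keeps a typo'd marker from turning into silent mis-masking later in training.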
Adds tests for: batch_size>1 (scanner + Qwen2VL strategy), include_start/include_end round-trip through RoleBoundarySpec override, pad masking inside trainable spans, all-pad sequence, multiple consecutive assistant turns, train_on_eos variants on multi-turn conversations, MultiModalConfig dict->RoleBoundarySpec parsing, Qwen3.5 video_pad masking under train_on_inputs=True, and empty processor_kwargs no-op. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Trims verbose block comments and docstrings across the branch to 1-2 lines each, keeping the non-obvious WHY and dropping the WHAT. Behavior unchanged; all 39 tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous `ds_cfg.get if hasattr else __getitem__` ladder fell through to an AttributeError (swallowed) for pydantic dataset entries, silently ignoring `roles_to_train` / `train_on_eos`. Switch to dict-style `.get` first, then `getattr(obj, key, None)` — works for DictDefault, dict, and pydantic SFTDataset alike. Also log resolved collator knobs at INFO so the MM masking config is visible in training logs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
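A minimal sketch of the lookup order described above (dict-style `.get` first, then `getattr`). `read_dataset_option` and `AttrEntry` are illustrative names, not the helpers in `build_collator`.

```python
def read_dataset_option(entry, key):
    """Look up `key` on a dict-like or attribute-style dataset entry."""
    getter = getattr(entry, "get", None)
    if callable(getter):
        value = getter(key)
        if value is not None:
            return value
    # Attribute-style objects (for example pydantic models) land here.
    return getattr(entry, key, None)


class AttrEntry:
    """Hypothetical attribute-only entry, standing in for a pydantic model."""

    def __init__(self, **fields):
        self.__dict__.update(fields)


as_dict = {"roles_to_train": ["assistant", "tool"]}
as_attrs = AttrEntry(train_on_eos="last")
```

Both shapes resolve through the same call, and a missing key yields `None` instead of a swallowed `AttributeError`.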
…n assistant spans [/INST] is shared between user-end and assistant-start on these templates, so the scanner consumed it for user and never matched it as assistant-start, producing all-masked labels. Teach the scanner to not consume the end marker when include_end=False (back up by len(end_tokens)), and declare user.include_end=False on both Pixtral and Mistral V7 Tekken so assistant spans are actually detected. Tests now assert the full expected mask. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
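The shared-marker problem can be reproduced in a small standalone sketch. The ids for `[INST]`, `[/INST]`, and `</s>` are made up, and `mask_labels` is a simplified stand-in for the real scanner.

```python
IGNORE_INDEX = -100
INST, END_INST, EOS = 3, 4, 2  # made-up ids for [INST], [/INST], </s>


def mask_labels(ids, user_include_end):
    """Label assistant spans when [/INST] both ends the user turn and starts the assistant turn."""
    boundaries = [
        {"role": "user", "start": [INST], "end": [END_INST], "include_end": user_include_end},
        {"role": "assistant", "start": [END_INST], "end": [EOS], "include_end": True},
    ]
    labels = [IGNORE_INDEX] * len(ids)
    pos = 0
    while pos < len(ids):
        b = next((b for b in boundaries if ids[pos:pos + len(b["start"])] == b["start"]), None)
        if b is None:
            pos += 1
            continue
        content = pos + len(b["start"])
        end = content
        while end < len(ids) and ids[end:end + len(b["end"])] != b["end"]:
            end += 1
        stop = min(end + len(b["end"]), len(ids))
        if b["role"] == "assistant":
            labels[content:stop] = ids[content:stop]
        # With include_end=False the end marker is left unconsumed, so the next
        # iteration can match it as another role's start marker.
        pos = stop if b["include_end"] else end
    return labels
```

For `[3, 50, 51, 4, 60, 61, 2]`, calling `mask_labels(ids, user_include_end=True)` masks every position because the user span swallows `[/INST]`, while `user_include_end=False` leaves `60, 61, 2` trainable.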
Bump "32 unit tests" references to the current 55, and replace the stale 5-entry commit breakdown with the actual 8-unit sequence (with a pointer to `git log main..HEAD` as the source of truth). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… at init 'last' now trains on the final trainable turn's end marker only, matching the text-only ChatTemplateStrategy's documented semantics. Invalid values raise ValueError at strategy construction instead of silently falling through the scanner's 'turn' branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
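A small sketch of the two behaviors named here: early validation and the `last` semantics. Only the `turn` and `last` modes mentioned in this commit are modelled; any other accepted value in the real code is outside this sketch, and `end_spans_to_train` is a hypothetical helper.

```python
MODELLED_MODES = ("turn", "last")  # assumption: only the modes named in this commit


def end_spans_to_train(end_spans, train_on_eos):
    """Pick which trainable-turn end-marker spans stay unmasked."""
    if train_on_eos not in MODELLED_MODES:
        # Fail at construction time instead of silently behaving like "turn".
        raise ValueError(f"unsupported train_on_eos value: {train_on_eos!r}")
    if train_on_eos == "turn":
        return list(end_spans)
    return list(end_spans[-1:])  # "last": only the final trainable turn's end marker
```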
Switch falsy check to 'is None' so roles_to_train=[] correctly masks every role instead of silently defaulting to ['assistant']. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
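The difference between the two checks fits in a few lines. The function names below are illustrative only.

```python
def roles_with_falsy_check(value):
    # Buggy variant: an explicit empty list is falsy, so it falls back to the default.
    return value or ["assistant"]


def roles_with_none_check(value):
    # Fixed variant: only a missing value (None) triggers the default.
    return ["assistant"] if value is None else list(value)
```

With `roles_to_train=[]`, the first variant trains on assistant turns anyway, while the second masks every role as requested.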
``labels == None`` emits a UserWarning and returns all-False. Sweep process_labels overrides to skip pad/image/audio/video masking when the corresponding id is None. Added a regression test that promotes warnings to errors. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…boundaries Mirrors the existing start-marker guard. Previously a typo'd end would silently produce an empty end_tokens list and the span would run to end of sequence instead of reporting the config error. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… + debug log Catch only ImportError/ModuleNotFoundError around the optional Mistral3Processor / Glm46VProcessor imports and log the caught exception at DEBUG so users debugging why their strategy wasn't picked up can see the actual failure without seeing a warning in normal operation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
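The narrow-catch pattern can be sketched as follows. `optional_import` is a hypothetical helper, not the dispatcher code, which uses inline `try`/`except` around the processor imports.

```python
import importlib
import logging

LOG = logging.getLogger(__name__)


def optional_import(module_name, attr):
    """Return `module_name.attr`, or None when the module cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        # DEBUG keeps normal runs quiet while leaving the real cause discoverable.
        LOG.debug("optional import of %s failed: %s", module_name, exc)
        return None
    # Any other failure (for example a missing attribute) is not swallowed.
    return getattr(module, attr)
```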
…ranch qwen3_5_moe is a model_config_type, not a ChatTemplate enum value, so the branch was dead code. MoE variants share the qwen3_5 chat template anyway. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… markers Matches the trailing-newline convention of Qwen / Gemma3 / Llama markers. Previously the newline fell inside the span's first content token, which worked but is fragile if the BPE tokenizer re-merges. Updated the _FakeGemma4Tokenizer to reflect the expected three-token-sequence shape of '<|turn>model\n'. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Gemma 3 and Gemma 3n jinja templates fold the system message into the first user's content prefix and never emit <start_of_turn>system, so the system RoleBoundary was harmless dead code. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Emit a single grep-friendly line on every strategy construction covering class name, train_on_inputs, roles_to_train, train_on_eos, boundary source, and — for overrides — the resolved (role, start_ids, end_ids) tuples. Users debugging why masking isn't firing no longer have to add prints. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pixtral and Mistral V7 Tekken entries no longer carry the "known limitation" footnote — commit acfe4fe fixed the shared-[/INST] scanner bug via include_end=False + rewind. Test-count mentions updated to 64. Drops the qwen3_5_moe alias from the audit table to match the dispatcher (removed in 77a9d17). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
📝 Walkthrough
Introduces a declarative role-boundary masking system for multimodal training that honors
Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
🧹 Nitpick comments (2)
docs/multimodal_assistant_mask.md (1)
103-104: Consider using a permalink or removing the commit hash reference.
The commit hash `acfe4fe4` may become difficult to trace after rebases or squash merges. Consider linking to a stable reference (e.g., the test file line) or removing the hash.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/multimodal_assistant_mask.md` around lines 103 - 104, The doc references a short commit hash `acfe4fe4` which can become stale; update the line in multimodal_assistant_mask.md that mentions the commit to instead use a stable reference or remove the hash: either replace `acfe4fe4` with a permalink to the specific location in tests/test_processing_strategies.py (or a link to the repository/tag/PR) or simply remove the commit hash and point readers to the test file `tests/test_processing_strategies.py` and the per-position assertions mentioned.
tests/test_processing_strategies.py (1)
353-390: Good test fixture, consider avoiding mutable class attributes.
The `VOCAB` dict and other mutable attributes like `size = {}` are class-level. While safe in test code that doesn't subclass these fixtures, using instance initialization would be cleaner.
Note: The S105 "hardcoded password" warnings are false positives; these are token names (e.g., `image_token`, `boi_token`), not credentials.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_processing_strategies.py` around lines 353 - 390, The VOCAB dictionary and other mutable attributes are currently defined at class scope on _FakeGemma4Tokenizer (VOCAB) and set as class-like attributes on _FakeGemma4Processor; move VOCAB into the tokenizer's __init__ as a self.vocab = { ... } and initialize any mutable state (e.g., token ids or maps) inside _FakeGemma4Processor.__init__ as instance attributes (self.<name>) rather than relying on class-level mutation; update references in _FakeGemma4Tokenizer and _FakeGemma4Processor to use self.vocab and self.<token> so the fixtures no longer use shared mutable class attributes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 116-132: The mypy error comes from assigning mixed types to
boundaries_repr (list vs string); make boundaries_repr a consistent type
(preferably str) by converting the list branch to a string before logging—e.g.,
stringify [(b.role, b.start_tokens, b.end_tokens) for b in self.role_boundaries]
(via repr(), ", ".join(...), or json.dumps) so boundaries_repr is always a str;
update the code around the ProcessingStrategy initializer where
self.role_boundaries and boundaries_repr are computed and used in the LOG.info
call to avoid the Union type.
- Around line 1115-1124: The mypy suppression on the Mistral3Processor import
uses the wrong error code; change the ignore to match the actual mypy error or
restructure the import: replace "# type: ignore[assignment]" with "# type:
ignore[misc]" (or remove the inline ignore and wrap the import in a try/except
typed with Optional[Type]" using typing.cast/Optional) so Mistral3Processor
remains None on ImportError while satisfying mypy; update the import block
referencing Mistral3Processor and keep the LOG.debug message intact.
- Around line 415-417: The mypy complaint is from unpacking
last_trainable_end_span[i] directly after an `is not None` check; introduce a
local variable (e.g., span = last_trainable_end_span[i]) and check `if
train_on_eos == "last" and span is not None:` then unpack `s, e = span` (or use
the walrus operator `if train_on_eos == "last" and (span :=
last_trainable_end_span[i]) is not None:`) so mypy can narrow the type before
assigning to mask[i][s:e] in the processing_strategies logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ca052a96-754c-468f-951d-9c95d87e45be
📒 Files selected for processing (7)
docs/multimodal_assistant_mask.md
src/axolotl/core/builders/causal.py
src/axolotl/loaders/processor.py
src/axolotl/processing_strategies.py
src/axolotl/utils/schemas/model.py
src/axolotl/utils/schemas/multimodal.py
tests/test_processing_strategies.py
📖 Documentation Preview: Deployed on Netlify from commit d444156
…n feat/processor-kwargs This branch scope is multimodal assistant-only loss masking. The processor_kwargs passthrough for axolotl-ai-cloud#3617 is being carried on a separate branch (feat/processor-kwargs), so unbundling it here keeps this PR focused and avoids duplicate history if the other branch lands first. Removes: - cfg.processor_kwargs field from ModelInputConfig - processor_kwargs merge/filter logic in load_processor - 3 tests (forward, absent, empty-dict) + the _load_processor_module helper - Design-doc paragraphs and draft-PR bullets that referenced axolotl-ai-cloud#3617 No behavior change relative to main for the processor load path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…bbit - Restructure Mistral3Processor lazy-import to match the Glm46V inline pattern 10 lines below; drops the type: ignore comment whose code was wrong anyway (mypy reports [misc] not [assignment]). - Annotate boundaries_repr as str | list[tuple[...]] so the override vs built-in branches no longer trip the [assignment] error. - Use walrus in the train_on_eos=="last" span unmask so mypy can narrow Optional after the check ([misc] "None is not iterable"). No behavior change; 61/61 tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
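The walrus-based narrowing mentioned in this commit can be sketched on plain lists. `unmask_last_end` and its arguments are illustrative, not the scanner's real signature.

```python
from typing import List, Optional, Tuple


def unmask_last_end(
    mask: List[List[bool]],
    last_trainable_end_span: List[Optional[Tuple[int, int]]],
    train_on_eos: str,
) -> List[List[bool]]:
    """Unmask the final trainable end marker per row when train_on_eos == 'last'."""
    for i in range(len(mask)):
        # Binding the element to a local name lets a type checker narrow
        # Optional[Tuple[int, int]] to Tuple[int, int] before unpacking.
        if train_on_eos == "last" and (span := last_trainable_end_span[i]) is not None:
            s, e = span
            mask[i][s:e] = [True] * (e - s)
    return mask
```

Indexing the list a second time after the `is not None` check would give the type checker nothing to narrow, which is the error the commit describes.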
@coderabbitai review
✅ Actions performed
Review triggered.
Two issues surfaced by CI:
- ruff-format (v0.15.8) wanted to reformat
src/axolotl/processing_strategies.py and tests/test_processing_strategies.py
into its preferred layout. Applied the formatter; also fixes a missing
trailing newline and an unsorted import in the test file (caught by
end-of-file-fixer and the ruff legacy-alias hook).
- test_strategy_init_logs_resolved_masking_config_{builtin,override} and
test_base_strategy_warns_when_no_boundaries were flaky across the
PyTest matrix. Root cause: axolotl's logging config sets
propagate=False on the "axolotl" logger once configure_logging() runs
on a worker, which prevents pytest's root-attached caplog handler from
ever seeing records. These tests passed locally when configure_logging
hadn't fired yet, but failed when another test on the same xdist
worker triggered it.
Added an `axolotl_caplog` fixture that attaches caplog.handler directly
to the axolotl.processing_strategies logger, making the three tests
robust to either logging state. No production behavior change.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
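The propagation issue behind the flaky tests can be reproduced with the standard library alone. This is a sketch under assumptions: the `demo_pkg` logger names and `ListHandler` are made up, and the real fix attaches pytest's `caplog.handler` in a fixture instead of a custom handler.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted messages, playing the role of a capture handler."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def demo():
    root_handler, direct_handler = ListHandler(), ListHandler()
    root = logging.getLogger()
    parent = logging.getLogger("demo_pkg")
    child = logging.getLogger("demo_pkg.strategies")
    old_propagate, old_level = parent.propagate, child.level
    root.addHandler(root_handler)
    child.addHandler(direct_handler)
    parent.propagate = False  # what a package-level logging config may do
    child.setLevel(logging.INFO)
    try:
        child.info("resolved masking config")
    finally:
        # Restore global logging state so the demo has no side effects.
        root.removeHandler(root_handler)
        child.removeHandler(direct_handler)
        parent.propagate, child.level = old_propagate, old_level
    return root_handler.messages, direct_handler.messages
```

The root-attached handler never sees the record once the parent stops propagating, while the handler attached directly to the emitting logger does.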
@coderabbitai review
✅ Actions performed
Review triggered.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/test_processing_strategies.py (1)
406-422: Avoid shared mutable class vocab in `_FakeGemma4Tokenizer`.
`VOCAB` is a mutable class attribute; if mutated in one test path it can leak into others. Copy it when constructing the tokenizer instance.
Proposed fix
 class _FakeGemma4Tokenizer(_Tokenizer):
@@
     def __init__(self):
-        super().__init__(self.VOCAB, pad_id=0, unk_id=3)
+        super().__init__(dict(self.VOCAB), pad_id=0, unk_id=3)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_processing_strategies.py` around lines 406 - 422, The class _FakeGemma4Tokenizer uses a mutable class attribute VOCAB and passes it directly to super().__init__, which can leak mutations across tests; update __init__ to pass a fresh copy (e.g., a shallow copy or deepcopy of VOCAB) to super().__init__ so each tokenizer instance gets its own independent vocab (referencing _FakeGemma4Tokenizer, VOCAB, and __init__ in the diff), keeping pad_id=0 and unk_id=3 unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 194-197: The code currently falls back to
convert_legacy_format(example) when messages is present but None (and
conversations is absent), causing a KeyError; instead explicitly detect when
example.get("messages") is None and "conversations" not in example and raise a
clear validation error (or return a normalized error response) before calling
convert_legacy_format. Update the branch around processed_example assignment
(the if using "messages" and example["messages"]) to add a separate check for
example.get("messages") is None && "conversations" not in example and
raise/return a descriptive error mentioning missing messages/conversations, then
only call convert_legacy_format(example) in the true legacy case. Ensure
references to convert_legacy_format and processed_example remain consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ff621d22-a660-4821-a295-828f4b8ebd9b
📒 Files selected for processing (3)
docs/multimodal_assistant_mask.md
src/axolotl/processing_strategies.py
tests/test_processing_strategies.py
✅ Files skipped from review due to trivial changes (1)
- docs/multimodal_assistant_mask.md
…ocab handling Two review findings from CodeRabbit: - When `messages` is present but None and `conversations` is absent, the code fell through to `convert_legacy_format(example)`, which then raised `KeyError: 'conversations'` from `example["conversations"]` instead of a clear validation error. Split the branch explicitly: messages-non-None → deepcopy and use as-is (symmetric with the legacy branch which already deepcopies internally); conversations present → legacy conversion; otherwise raise a descriptive ValueError that names both required keys. - `_FakeGemma4Tokenizer.VOCAB` is a class-level mutable dict passed by reference to `super().__init__`, so every tokenizer instance aliases the same dict. No current code path mutates it, but the one-line `dict(self.VOCAB)` copy hardens the fake against any future mutation leaking across tests. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
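The three-way branch described in the first finding might look like the sketch below. `normalize_example` is an illustrative name, and this `convert_legacy_format` is a hypothetical stand-in whose `from`/`value` field mapping is an assumption, not the real helper.

```python
import copy


def convert_legacy_format(example):
    """Hypothetical stand-in for the real legacy converter."""
    turns = copy.deepcopy(example["conversations"])
    return {"messages": [{"role": t["from"], "content": t["value"]} for t in turns]}


def normalize_example(example):
    if example.get("messages") is not None:
        # Already in the current format: copy so later edits do not touch the input.
        return copy.deepcopy(example)
    if "conversations" in example:
        return convert_legacy_format(example)
    # Neither key is usable: report both names instead of a bare KeyError.
    raise ValueError("example must contain non-null 'messages' or legacy 'conversations'")
```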
@coderabbitai review
✅ Actions performed
Review triggered.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 94-109: The current conditional uses "if
role_boundaries_override:" which treats an explicit empty list as falsy and
ignores it; change the check to test for None (e.g., "if
role_boundaries_override is not None") so that an explicit [] is treated as a
real override; keep the logic that calls
_resolve_role_boundary_override(self.processor.tokenizer,
role_boundaries_override) (or the existing argument order) and assigns the
result to self.role_boundaries, and preserve the LOG.info call and the "source"
assignment so built-in defaults are only used when no override (None) is
provided.
In `@tests/test_processing_strategies.py`:
- Around line 420-424: The current __init__ uses dict(self.VOCAB) which only
shallow-copies the mapping so the inner ID lists remain shared; change the
constructor to deep-copy the VOCAB before passing it to super (e.g., use
copy.deepcopy(self.VOCAB) or recreate each value list) so per-instance mutations
of self.vocab[...] cannot leak across tests; keep the same pad_id=0 and unk_id=3
parameters when calling super.
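The shallow-versus-deep distinction in the second finding can be shown with a toy vocab. The marker string and ids are made up, and list-valued vocab entries are assumed, as the review comment implies.

```python
import copy

VOCAB = {"<|turn>model": [105, 4368]}  # made-up marker and ids

shallow = dict(VOCAB)         # new outer dict, same inner lists
deep = copy.deepcopy(VOCAB)   # new outer dict and new inner lists

shallow["extra"] = [7]                # does not touch VOCAB
shallow["<|turn>model"].append(999)   # mutates the list VOCAB also holds
```

After these statements `VOCAB` has no `"extra"` key, but its `"<|turn>model"` list ends in `999`, while `deep` is unaffected.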
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 5b40dc62-74b4-49f9-93d9-4de91db20a8a
📒 Files selected for processing (2)
src/axolotl/processing_strategies.py
tests/test_processing_strategies.py
        if role_boundaries_override:
            overridden = _resolve_role_boundary_override(
                role_boundaries_override, self.processor.tokenizer
            )
            LOG.info(
                "%s: overriding built-in role boundaries (%d decls) "
                "with cfg.role_boundaries (%d decls).",
                type(self).__name__,
                len(built_in),
                len(overridden),
            )
            self.role_boundaries: list[RoleBoundary] = overridden
            source = "override"
        else:
            self.role_boundaries = built_in
            source = "built-in"
Treat an explicit empty role_boundaries_override as a real override.
if role_boundaries_override: falls back to the built-ins for [], so cfg.role_boundaries: [] cannot clear a strategy’s default boundaries even though overrides are otherwise wholesale replacements.
Suggested fix
- if role_boundaries_override:
+ if role_boundaries_override is not None:
overridden = _resolve_role_boundary_override(
role_boundaries_override, self.processor.tokenizer
)
Closing to re-open with squashed history. Will be replaced by a fresh PR with a single squashed commit so CodeRabbit can re-review cleanly.
Five test-quality refinements from CodeRabbit's third-round review. **R3-#2 — deterministic teardown in test_dora.** Wrap the DoRA smoke's wrap → train → assert sequence in ``try/finally`` so ``wrapped.close()`` runs even when the loss-descent assertion fails mid-test. Without this, an early assertion failure leaves hooks, pinned-host borrows, and CPU adapter threads alive into subsequent GPU tests on the same pytest session. **R3-#3 — distinguish hook edges in test_lora_offload_mode recording stub.** The pre-fix ``_RecordingScheduler.ensure_chunks_resident`` recorded every container callback under the same ``"ensure_chunks_resident"`` label. The per-hook tests (pre_forward / post_forward / post_backward fires ``ensure_chunks_resident``) then asserted only call COUNT — so a regression that deleted the pre-forward hook factory while post-forward still fired would still pass the count gates. Tag each call with its originating hook edge via frame inspection on the caller's ``co_qualname`` (Python 3.11+ guarantees the qualname captures the enclosing ``_make_lora_container_<edge>_hook`` factory). The four LoRA container hooks all funnel through the same ``ensure_chunks_resident`` entry point but their closures live in distinct factory functions, so the qualname uniquely identifies the edge. Update each per-hook test to filter on the edge-tagged label so a regression in any single edge fails the corresponding test: * pre_forward test: asserts ``ensure_chunks_resident:pre_forward`` fires ≥ n_blocks times. * post_forward test: asserts BOTH ``:pre_forward`` AND ``:post_forward`` fire ≥ n_containers times each (the previous bare ≥ 2*n_containers count was satisfied by either edge alone). * post_backward test: asserts all four edges (pre/post fwd, pre/ post bwd) fire ≥ n_containers times each. The production hook factory layout is unchanged — the stub recovers the edge from the existing closure's frame, no new arguments thread through ``install_hooks``. 
**R3-#4 — narrow protrain_model_wrapper exception scope in test_lora_offload_mode:1117.** The bare ``except (ValueError, RuntimeError)`` was treating any wrapper failure as "offload setup unavailable" and skipping. A broken ``protrain_model_wrapper`` runtime path could leave this smoke green. Restrict the suppression to known env-failure substrings (DeepSpeedCPUAdam JIT, CUDA version mismatch, bnb load, ``No module named``, and capacity/searcher gates) — same canonical tuple D8 used at the optimizer-step site below — and re-raise anything else. Real wrapper regressions now surface. **R3-#5 — fail-safe CUDA teardown in test_param_data_shape_preservation.** Eight test functions in this module construct ``mgr / layout / pool / host`` via ``_build_chunk_manager`` and tear them down at the happy-path tail (``mgr.uninstall()`` / ``host.close()`` / ``del pool``). Any earlier assertion failure skipped the teardown, leaking pinned-host borrows + CUDA buffer-pool state into subsequent GPU tests. Add a top-level ``_teardown_chunk_manager(mgr, host, pool)`` helper that does the best-effort 3-call teardown (each call wrapped in its own try/except so a failure in ``uninstall`` doesn't block the ``host.close``), and wrap each test body in ``try: ... finally: _teardown_chunk_manager(...)``. Done programmatically across all 8 tests via a one-shot Python rewrite to keep the diff mechanical and the new structure consistent. **R3-#8 — replace hard-coded n_chunk_estimate=1 in test_trace_skip_on_override.** The trace-skip e2e test hard-coded ``n_chunk_estimate = 1`` based on the assumption that the tiny GPT-2 fixture produces a single chunk. If the layout heuristics (``pick_S_chunk`` default, block-discovery rules) shift such that ``N_chunk > 1``, ``min_n_buffer_for(layout, n_persist=1)`` rejects ``n_buffer_override=0`` BEFORE the wrapper reaches the trace-skip gate the test is supposed to validate — converting this into a flaky non-target failure. 
Compute ``n_chunk_estimate`` dynamically by running the same ``discover_blocks`` → ``flatten_block_trees`` → ``build_layout`` pipeline the wrapper itself uses (with the wrapper's default S_chunk), and pass the resulting ``layout.N_chunk`` through. ``n_persist_override = n_chunk_estimate`` then keeps the all-persistent invariant the test relies on regardless of any future layout-heuristic shift. ``tests/protrain/`` default-marker sweep: 303 passed / 4 skipped / 0 failed. GPU-marker sweep on touched files: 40 passed / 2 skipped (single-process Mode-C downgrade for shape-preserving placeholder paths) / 0 failed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
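The best-effort teardown described for R3-#5 above can be sketched with fake objects. This is not the helper from the test module: the fake classes, the returned failure list, and `run_with_teardown` are assumptions made for illustration.

```python
class FakeManager:
    def __init__(self, fail=False):
        self.fail = fail
        self.uninstalled = False

    def uninstall(self):
        if self.fail:
            raise RuntimeError("uninstall failed")
        self.uninstalled = True


class FakeHost:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def teardown_chunk_manager(mgr, host, pool):
    """Run every teardown step even if an earlier one raises; return what failed."""
    failures = []
    for label, step in (("uninstall", mgr.uninstall), ("close", host.close)):
        try:
            step()
        except Exception as exc:  # best effort: one failure must not block the rest
            failures.append((label, exc))
    del pool  # drops only this local reference; the caller releases its own
    return failures


def run_with_teardown(mgr, host, pool, body):
    """Run a test body so teardown happens even when an assertion fails."""
    try:
        body()
    finally:
        teardown_chunk_manager(mgr, host, pool)
```

Wrapping each step in its own `try`/`except` is what keeps a failing `uninstall` from skipping `host.close`.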
…inor) Four CodeRabbit findings on commits ``b61f04e0`` (init-transient peak prediction) + ``aa0c6ba9`` (Mode-C steady CKPT-chain accounting). **R4-#1 (Critical) — post-block-wrap DDP ignore-set re-registration.** The M6C-fix-8 ignore-set registration ran BEFORE block-wrap. The block wrappers (``block/checkpoint.py``, ``block/swap.py``, ``block/offload.py``) all do ``self.block = block``, which means PyTorch's ``named_parameters()`` traversal inserts a ``.block.`` infix into the parameter namespace (``layers.0.attn.q_proj.weight`` ⇒ ``layers.0.block.attn.q_proj.weight``). The pre-wrap names captured in ``model._ddp_params_and_buffers_to_ignore`` no longer match the namespace DDP's ``__init__`` walks at construction time. The init-time broadcast is irrelevant (M6C-fix-8's ``init_sync=False`` monkey-patch bypasses it wholesale on chunk-managed models), but DDP's BACKWARD-pass allreduce still consults the ignore list. A stale ignore set means DDP's backward allreduce would attempt to all-reduce chunk-managed LoRA factor gradients, conflicting with ProTrain's per-chunk ``reduce_scatter`` drain. Add a post-wrap re-registration step after ``install_hooks`` in ``_construct_runtime``: walk the WRAPPED ``model.named_parameters()`` and identify chunk-managed params by OBJECT identity against ``chunk_manager._params_by_id.values()``. Build the post-wrap name set, merge with the pre-protrain snapshot (``_protrain_ddp_original_ignore``), overwrite the attribute. Gated on ``_shape_preserving`` so the single-GPU / replicated path remains a no-op. **R4-#2 (Major) — reuse bootstrap init-transient peak instead of recomputing post-offload.** ``predict_init_transient_peak_bytes(layout, hw, chunk_manager)`` walks ``chunk_manager.model.named_parameters()`` to sum chunk bytes. 
By the time the phase-2 post-measurement calibration runs, ``materialize_offload`` has already executed and ``param.data`` points at the zero-size placeholders (replicated path) or ``scratch.expand(slot.shape)`` views (sharded path), so the byte accounting drifts away from the bootstrap-time full-residence prediction. Replace the recompute call at the phase-2 post-measurement calibration site with ``boot_result.predicted_init_transient_peak_bytes`` — the bootstrap-time value captured at line 1614 before materialize_offload ran. The downstream consumers (SearchResult publish, LOG.info diagnostic) get the same authoritative value without re-walking a now-stale chunk_manager. **R4-#3 (Major) — meta tensors in ``_stub_chunk_manager`` to avoid CI OOM.** ``tests/protrain/test_init_transient_peak.py::_stub_chunk_manager`` allocated full CPU tensors sized to model 15–60 GiB chunk totals. ``predict_init_transient_peak_bytes`` only reads ``param.numel() * param.element_size()``, so meta-device tensors preserve the byte-accounting metadata without consuming RAM. Switch the ``nn.Parameter(torch.zeros(numel, dtype, device='cpu'))`` construction to ``nn.Parameter(torch.empty(numel, dtype, device='meta'), requires_grad=False)``. **R4-#4 (Minor) — align docstring tolerance with ``TOLERANCE_FRAC = 0.35``.** ``tests/protrain/test_modec_steady_peak_accuracy.py`` docstring said "±25%" but ``TOLERANCE_FRAC = 0.35`` and the assertion uses 0.35. Update the two docstring mentions to "±35%" so text matches intent. ### Test gates - ``pre-commit run --all-files`` ALL green (ruff / ruff-format / mypy / bandit / yaml / eol / whitespace). - ``tests/protrain/`` default-marker sweep: 313 passed / 4 skipped / 162 deselected / 0 failed. - GPU sanity on touched test files (GPU 5): 24 passed / 2 skipped (single-process Mode-C downgrade — expected) / 0 failed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…est fixes Seven Minor items from the CodeRabbit full-diff re-scan on commit ``55377e5d``. **F-#2 — Clarify Mode-A guidance in ``protrain_optimizer_wrapper`` 8-bit warning (``api/optim_wrapper.py:802-815``).** The warning told users to set ``protrain_force_all_persistent: true`` to get end-to-end 8-bit AdamW on CPU-resident chunks, but didn't mention that ``protrain_force_all_persistent`` is ignored while ``protrain_auto_mode`` is on (the auto-mode selector picks the mode itself based on capacity). Expanded the warning to instruct users to set ``protrain_auto_mode: false`` AND ``protrain_force_all_persistent: true`` together. **F-#4 — Unify fragmentation-alpha docs in DESIGN.md.** Module summaries at lines 49 (``cost/memory.py``) and 118 (``memory.py`` module spec) still described a fixed ``alpha=1.10`` while Design Decision 1 documents the per-dtype lookup (``ALPHA_FRAGMENTATION_4BIT = 0.75`` for bnb-4-bit). Aligned both summaries to reference the per-dtype helper (``alpha_fragmentation_for_dtype``) and the design decision section. **F-#5 — Resolve ``use_reentrant`` contradiction in DESIGN.md.** Line 109 (``block/checkpoint.py`` module spec) said ``use_reentrant=False``, which matches the actual implementation (verified via ``grep`` against ``block/checkpoint.py:99``). Line 290 (audit Block G analysis) claimed ``use_reentrant=True, the production wrap`` — stale and incorrect. Updated the analysis text to acknowledge ``use_reentrant=False`` is the production wrap and re-stated the per-block-input residual mechanism in a form compatible with the non-reentrant variant (each CKPT block's saved-tensors-hooks recompute frame holds the block input, which is what produces the linear-in-N_block activation footprint the audit data exposes). 
**F-#8 — Centralized CUDA-availability guard in ``tests/protrain/test_adamw8bit_adapter.py::_gpu_device``.** The helper unconditionally returned ``torch.device("cuda:0")``, so a custom marker filter or conftest override that lands the module in a CPU-only context would surface as a torch error before any test body. Added a ``pytest.skip("CUDA not available; ...")`` early-return so every gpu-marked test in the module gets a clean skip. **F-#9 — Replace silent ``try/except: pass`` with ``contextlib.suppress(Exception)`` in ``tests/protrain/test_lora_offload_mode.py``.** Five sites — lines 742-746, 839-843, 906-910, 981-985, 1040-1044 — each had the same ``for h in handles: try: h.remove() except Exception: pass`` pattern that Ruff S110 flags. Replaced with ``contextlib.suppress(Exception)`` over the loop. Semantics unchanged (best-effort cleanup, tolerate already-removed handles or torch shutting down mid-test); intent now documented by the context manager. **F-#10 — ASCII ``x`` in ``test_lora_offload_mode.py:1062`` docstring.** Missed in the R5 unicode sweep — ``4×3090`` ⇒ ``4x3090``. **F-#11 — ``try/finally`` for ``wrapped.close()`` in 3 sites of ``test_trace_skip_on_override.py``.** ``test_run_trace_skipped_on_override_full_path`` (L255-282), ``test_run_trace_invoked_without_override`` (L319-337), and ``test_partial_overrides_do_not_skip_trace`` (L381-400) each called ``wrapped.close()`` only on the success path — assertion failures earlier in the test body would skip the close and leak CUDA + chunk resources into subsequent GPU tests. Wrapped each test body in ``try/finally`` so ``wrapped.close()`` always runs. Done programmatically via a one-shot Python rewrite (8 lines of new indent + 2 lines of try/finally per site) to keep the diff mechanical. ### Test gates - ``pre-commit run --all-files`` ALL green (ruff / ruff-format / mypy / bandit / yaml / eol / whitespace). - ``tests/protrain/`` default-marker: 313 passed / 4 skipped / 162 deselected / 0 failed. 
- GPU sanity on F-touched files (GPU 5): 43 passed / 2 skipped / 0 failed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
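The two cleanup patterns from F-#9 and F-#11 can be sketched in isolation. This is a minimal illustration and not the test code itself: `FakeHandle` and `FakeWrapped` are hypothetical stand-ins for the hook handles and the wrapped model. Suppressing per handle, rather than around the whole loop, is an assumption about how the "best-effort" semantics were kept.

```python
import contextlib


class FakeHandle:
    """Hypothetical stand-in for a hook handle; remove() may raise."""

    def __init__(self, fail=False):
        self.fail = fail
        self.removed = False

    def remove(self):
        if self.fail:
            raise RuntimeError("handle already removed")
        self.removed = True


class FakeWrapped:
    """Hypothetical stand-in for the object returned by the model wrapper."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def remove_all(handles):
    # Best-effort cleanup: suppress per handle so one failure does not
    # prevent the remaining handles from being removed.
    for h in handles:
        with contextlib.suppress(Exception):
            h.remove()


def run_test_body(wrapped, should_fail):
    # try/finally guarantees close() runs even when an assertion fails.
    try:
        if should_fail:
            raise AssertionError("simulated assertion failure")
    finally:
        wrapped.close()
```

With this shape, a failing assertion still propagates to pytest, but the resources held by the wrapped object are released first.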
…g time Adds a Pydantic validator that emits a warning when ProTrain is active and the effective batch size (micro_batch_size * gradient_accumulation_steps) is less than 4. ProTrain's per-iter scheduler walk and chunk-management hook fan-out carry fixed overhead; below effective bs=4 the per-step overhead dominates and users see ~5-7x throughput regressions vs vanilla bs=1 (proposal §6.d, hot-path fix tracked in §16 follow-up PR #4). Lets users self-diagnose 'ProTrain is mysteriously slow' before opening a support thread. The mitigation is just to raise gradient_accumulation_steps (same effective batch, much better throughput).
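The check described in this commit can be approximated with a plain function. This is a sketch under stated assumptions, not the validator from the commit: it uses the standard `warnings` module instead of Pydantic, and `protrain_active` is a hypothetical parameter name. Only `micro_batch_size`, `gradient_accumulation_steps`, and the threshold of 4 come from the message above.

```python
import warnings

MIN_EFFECTIVE_BATCH = 4  # threshold named in the commit message


def check_effective_batch(protrain_active, micro_batch_size, gradient_accumulation_steps):
    """Return the effective batch size; warn when ProTrain is active and it is below 4."""
    effective = micro_batch_size * gradient_accumulation_steps
    if protrain_active and effective < MIN_EFFECTIVE_BATCH:
        warnings.warn(
            f"ProTrain effective batch size is {effective} "
            f"(< {MIN_EFFECTIVE_BATCH}); fixed per-step overhead may dominate. "
            "Consider raising gradient_accumulation_steps.",
            UserWarning,
            stacklevel=2,
        )
    return effective
```

For example, `check_effective_batch(True, 1, 2)` warns because 1 * 2 is below the threshold, while `check_effective_batch(True, 2, 2)` does not.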
…for bs=1 hot path cProfile on a 4-block synthetic LoRA model identified three per-step Python overheads that don't amortize at bs=1: (a) cuda.is_available() fired in ensure_chunks_resident on every LoRA-container hook (~600 us/step at n_blocks=8); (b) per-step set() construction in post_block_forward; (c) PyTorch's setup_*_hook autograd machinery firing per forward (~1.4 ms/step); this is fixed-cost given the LoRA-container hook quartet pinned by offload-correctness tests. This change addresses (a) + (b): caches torch.cuda.is_available() as self._has_cuda at Scheduler init, precomputes next-block / prev-block ids and next-chunks frozensets keyed by BlockId, and removes the per-call torch reimport from ensure_chunks_resident. (c) is GPU-profile-driven and rescoped under §16 PR #4. Microbench tests/protrain/test_bs1_hot_path_microbench.py captures the shape: pre-fix overhead 4.7 ms/step at n_blocks=8 → post-fix ~2.4 ms/step (~49% reduction in Python-attributable per-step overhead). bs=4 wall-time unchanged within noise; no regressions across tests/protrain (369 passing).
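The precompute-at-init idea behind fixes (a) and (b) can be sketched as follows. This is an illustration only: `SchedulerSketch` and its attribute names are hypothetical, block ids are plain integers rather than `BlockId`, and the CUDA flag is passed in as an argument so the sketch runs without torch.

```python
class SchedulerSketch:
    """Hypothetical scheduler that precomputes per-block lookups at init."""

    def __init__(self, block_chunks, has_cuda):
        # block_chunks: ordered mapping of block id -> iterable of chunk ids.
        # The flag is cached once instead of being queried on every hook call.
        self._has_cuda = has_cuda
        ids = list(block_chunks)
        self._next_block = {b: n for b, n in zip(ids, ids[1:])}
        self._prev_block = {n: b for b, n in zip(ids, ids[1:])}
        # Chunks needed by the following block, frozen once per block id.
        self._next_chunks = {
            b: frozenset(block_chunks[n]) for b, n in self._next_block.items()
        }

    def post_block_forward(self, block_id):
        # Hot path: a dictionary lookup only, no per-step set() construction.
        return self._next_chunks.get(block_id, frozenset())
```

The trade-off is a small amount of memory held for the lifetime of the scheduler in exchange for removing allocations from a path that runs once per block per step.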
Summary
- Honors ``cfg.train_on_inputs`` / ``cfg.roles_to_train`` / ``cfg.train_on_eos`` in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config.
- Adds a ``cfg.role_boundaries`` YAML override so users can declare per-role markers without subclassing.
- Forwards ``cfg.processor_kwargs`` to ``processor_cls.from_pretrained`` (closes axolotl-ai-cloud/axolotl#3617).

What changed
- ``ProcessingStrategy`` gains a declarative boundary scanner. Each strategy declares per-role start/end markers via ``_build_role_boundaries``; the shared scanner honors ``train_on_inputs`` / ``roles_to_train`` / ``train_on_eos`` (including ``"last"``).
- Strategies whose boundary tokens can't be verified without loading a real checkpoint (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, plus the ``llava``/``lfm2vl`` fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via ``cfg.role_boundaries``.
- ``[/INST]`` token between user-end and assistant-start via ``include_end=False`` + scanner rewind.
- ``cfg.processor_kwargs`` (new) merged into ``processor_cls.from_pretrained`` kwargs; ``revision`` and ``trust_remote_code`` remain axolotl-managed.

Design
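The boundary-scanner idea can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the PR's implementation: `Boundary` and `mask_labels` are hypothetical names, token ids are plain integers, only the ``"turn"`` and ``"last"`` values of ``train_on_eos`` are modeled (any other value masks every end marker), and the ``include_end=False`` rewind is omitted.

```python
from dataclasses import dataclass

IGNORE_INDEX = -100  # label value skipped by PyTorch cross-entropy by default


@dataclass(frozen=True)
class Boundary:
    """Hypothetical per-role marker pair, expressed as token-id sequences."""
    role: str
    start: tuple
    end: tuple


def _match(ids, pos, seq):
    # True when the non-empty marker `seq` occurs in `ids` at `pos`.
    return len(seq) > 0 and tuple(ids[pos:pos + len(seq)]) == tuple(seq)


def mask_labels(input_ids, boundaries, roles_to_train=("assistant",),
                train_on_inputs=False, train_on_eos="turn"):
    """Return labels with IGNORE_INDEX outside trainable role spans."""
    if train_on_inputs:
        return list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    # Try longer start markers first so a shorter marker that is a prefix
    # of a longer one cannot shadow it.
    ordered = sorted(boundaries, key=lambda b: len(b.start), reverse=True)
    end_spans = []  # (begin, stop) of end markers closing trainable turns
    i = 0
    while i < len(input_ids):
        hit = next((b for b in ordered if _match(input_ids, i, b.start)), None)
        if hit is None:
            i += 1
            continue
        j = i + len(hit.start)  # content begins after the start marker
        k = j
        while k < len(input_ids) and not _match(input_ids, k, hit.end):
            k += 1
        trainable = hit.role in roles_to_train
        if trainable:
            labels[j:k] = input_ids[j:k]
        if k < len(input_ids):  # end marker found, span is not truncated
            stop = k + len(hit.end)
            if trainable:
                end_spans.append((k, stop))
            i = stop
        else:
            i = k
    if train_on_eos == "turn":
        keep = end_spans
    elif train_on_eos == "last":
        keep = end_spans[-1:]
    else:
        keep = []
    for a, b in keep:
        labels[a:b] = input_ids[a:b]
    return labels
```

In this sketch the start marker itself is always masked, and a span with no end marker (a truncated sequence) is trained up to the end of the input.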
See ``docs/multimodal_assistant_mask.md`` for the full audit table, root-cause analysis, design rationale (why boundary-scanner over ``return_assistant_tokens_mask`` or preserving ``tokenize_prompt`` labels), and ``cfg.role_boundaries`` usage.

Test plan
- ``tests/test_processing_strategies.py`` pass (no HF Hub access needed)
- ``google/gemma-4-E2B-it`` tokenizer: 13/40 tokens kept on 2-turn chat, correct spans
- ``llama3_2_vision.jinja``: 11/64 tokens kept, correct spans
- ``cfg.role_boundaries`` override path
- ``cfg.processor_type and self.processor``)

Commits
20 commits in 4 logical groups: core refactor + tests, builder/loader plumbing, ``processor_kwargs`` (axolotl-ai-cloud#3617), ``cfg.role_boundaries`` override, review-driven fixes (blocking Pixtral/Mistral fix, ``train_on_eos="last"`` support, pydantic dataset handling, validation hardening). See ``git log main..HEAD``.

🤖 Generated with Claude Code
Summary by CodeRabbit
Bug Fixes
New Features
Tests
Documentation