feat: systemic multimodal assistant-only loss masking + #3617 #4

Closed
thad0ctor wants to merge 24 commits into main from feat/multimodal-assistant-mask-all

Conversation

@thad0ctor

@thad0ctor thad0ctor commented Apr 23, 2026

Owner

Summary

  • Fixes silent ignoring of cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos in the multimodal training path. Before this branch, only Gemma 3n honored these knobs; every other VLM trained on the full sequence regardless of config.
  • Adds cfg.role_boundaries YAML override so users can declare per-role markers without subclassing.
  • Forwards cfg.processor_kwargs to processor_cls.from_pretrained (closes axolotl-ai-cloud/axolotl#3617).

What changed

  • ProcessingStrategy gains a declarative boundary scanner. Each strategy declares per-role start/end markers via _build_role_boundaries; the shared scanner honors train_on_inputs / roles_to_train / train_on_eos (including "last").
  • New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4, Pixtral, Mistral V7 Tekken.
  • Refactored: Gemma 3 (previously no role masking), Gemma 3n (previously ad-hoc scanner, now shared).
  • Strategies whose boundary tokens couldn't be verified offline (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning. Users can enable masking on them via cfg.role_boundaries.
  • Pixtral / Mistral V7 Tekken correctly handle the shared [/INST] token between user-end and assistant-start via include_end=False + scanner rewind.
  • cfg.processor_kwargs (new) merged into processor_cls.from_pretrained kwargs; revision and trust_remote_code remain axolotl-managed.
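
For readers skimming the thread, the declarative boundary scanner can be pictured with a minimal sketch. This is not the code on the branch: `RoleBoundary` borrows field names mentioned in this PR, but `keep_mask`, the token ids, and the boolean-mask return type are assumptions made for illustration. The sketch leaves out `train_on_inputs` and the `"all"` / `"last"` modes of `train_on_eos`, and it assumes start markers are non-empty.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class RoleBoundary:
    role: str
    start_tokens: tuple[int, ...]              # must be non-empty
    end_tokens: tuple[int, ...] | None = None  # None: span runs to end of sequence
    include_start: bool = False
    include_end: bool = True


def _matches(ids, pos, pattern):
    return tuple(ids[pos:pos + len(pattern)]) == tuple(pattern)


def keep_mask(ids, boundaries, roles_to_train=("assistant",), train_on_eos="turn"):
    """Return one bool per token; True means the token contributes to the loss."""
    keep = [False] * len(ids)
    # Try longer start markers first so a marker wins over its own prefix.
    ordered = sorted(boundaries, key=lambda b: len(b.start_tokens), reverse=True)
    i = 0
    while i < len(ids):
        b = next((c for c in ordered if _matches(ids, i, c.start_tokens)), None)
        if b is None:
            i += 1
            continue
        trainable = b.role in roles_to_train
        body = i + len(b.start_tokens)
        if trainable and b.include_start:
            keep[i:body] = [True] * (body - i)
        # Walk to the end marker; a missing marker means the span was truncated.
        j = body
        while j < len(ids) and not (b.end_tokens and _matches(ids, j, b.end_tokens)):
            j += 1
        if trainable:
            keep[body:j] = [True] * (j - body)
        if j < len(ids) and b.include_end:
            stop = j + len(b.end_tokens)
            if trainable and train_on_eos == "turn":
                keep[j:stop] = [True] * (stop - j)
            i = stop
        else:
            i = j  # include_end=False: leave the marker for the next span to match
    return keep


# Made-up ids for a Gemma-style template: 1=user start, 2=model start, 3=end of turn.
gemma_like = [
    RoleBoundary("user", (1,), (3,)),
    RoleBoundary("assistant", (2,), (3,)),
]
gemma_mask = keep_mask([1, 10, 11, 3, 2, 20, 21, 3], gemma_like)

# Made-up ids for a Pixtral-style template: 4=[INST], 5=[/INST], 6=EOS.
# [/INST] both closes the user turn and opens the assistant turn.
pixtral_like = [
    RoleBoundary("user", (4,), (5,), include_end=False),
    RoleBoundary("assistant", (5,), (6,)),
]
pixtral_mask = keep_mask([4, 10, 5, 20, 6], pixtral_like)
```

In this sketch, leaving `include_end=True` on the Pixtral-style user boundary would consume token 5 and mask every position, which is the failure mode the Pixtral / Mistral V7 Tekken bullet describes.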

Design

See docs/multimodal_assistant_mask.md for the full audit table, root-cause analysis, design rationale (why boundary-scanner over return_assistant_tokens_mask or preserving tokenize_prompt labels), and cfg.role_boundaries usage.

Test plan

  • 64 offline unit tests in tests/test_processing_strategies.py pass (no HF Hub access needed)
  • End-to-end verified against real google/gemma-4-E2B-it tokenizer: 13/40 tokens kept on 2-turn chat, correct spans
  • End-to-end verified against real Llama-3.x tokenizer with bundled llama3_2_vision.jinja: 11/64 tokens kept, correct spans
  • End-to-end verified against the real Rashi-OCR Gemma-4-31B tokenizer + dataset, including cfg.role_boundaries override path
  • Text-only training path unaffected (MM path only taken when cfg.processor_type and self.processor)
  • Migration visibility: INFO-log at strategy init + collator assembly surfaces resolved masking config

Commits

20 commits in five logical groups: core refactor + tests, builder/loader plumbing, processor_kwargs (axolotl-ai-cloud#3617), cfg.role_boundaries override, and review-driven fixes (blocking Pixtral/Mistral fix, train_on_eos="last" support, pydantic dataset handling, validation hardening). See git log main..HEAD.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Loss masking now correctly respects training flags (train_on_inputs, roles_to_train, train_on_eos) in multimodal scenarios.
  • New Features

    • Declarative, configurable role-boundary system for precise loss-masking control, with YAML override support.
    • Broader multimodal strategy coverage (Qwen, Gemma, Llama, Pixtral, Mistral).
  • Tests

    • Added extensive end-to-end tests covering masking modes, overrides, and strategy routing.
  • Documentation

    • New design doc describing the role-boundary approach and configuration changes.

thad0ctor and others added 20 commits April 22, 2026 17:32
Adds a parametrized role-boundary scanner to ProcessingStrategy so that
cfg.train_on_inputs / cfg.roles_to_train / cfg.train_on_eos (which were
previously silently ignored for every multimodal path except Gemma 3n)
restrict loss to trainable role spans.

Introduces ``RoleBoundary`` and ``_apply_role_boundaries``. Each strategy
declares its per-role start/end token sequences via ``_build_role_boundaries``;
the base ``_mask_non_assistant`` delegates to the shared scanner and
short-circuits with a one-shot warning when a strategy has no declared
boundaries (preserving legacy behavior for unverified paths).

Per-strategy boundary declarations added for Qwen2-VL, Qwen3.5, Gemma 3
(previously no role masking), Gemma 3n (previously ad-hoc per-strategy
scanner, now shared), Gemma 4 (new), Llama 3.2 Vision (new), Llama 4 (new),
Pixtral (new), and Mistral V7 Tekken (new). Dispatcher routes
chat_template_type to the new subclasses.

Strategies whose boundary tokens we can't verify without loading a real
checkpoint (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, plus the
llava/lfm2vl fallback) retain legacy behavior and emit a one-shot warning
naming the class so the miss is visible in training logs.

Also makes the dispatcher's Mistral3Processor import lazy so that
non-mistral-common installs can still route to non-mistral strategies.

Includes 30 offline unit tests covering scanner semantics
(train_on_inputs/roles_to_train/train_on_eos, longest-prefix start match,
truncated spans), per-strategy masking with fake tokenizers matching real
model id layouts, media-token masking inside assistant spans, and dispatcher
routing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… MM collator

HFCausalTrainerBuilder.build_collator now reads cfg.train_on_inputs and the
first dataset entry's roles_to_train / train_on_eos (mirroring how
ChatTemplateStrategy resolves them for text-only training) and forwards
them to get_processing_strategy so the shared scanner actually runs with
the user's config. Without this, the scanner from the prior commit would
always use its defaults (train_on_inputs=False, roles_to_train=["assistant"],
train_on_eos="turn"), which happens to match most users' intent but
silently ignores explicit config overrides.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…otl-ai-cloud#3617)

Adds a ``processor_kwargs: dict | None`` field to ModelInputConfig and
merges it into the kwargs passed to ``processor_cls.from_pretrained`` in
load_processor. Lets users set min_pixels, max_pixels, num_crops,
do_rescale, patch_size, and similar VLM processor knobs from YAML.

Axolotl-managed keys (``revision``, ``trust_remote_code``) are filtered
out of the user's processor_kwargs with a warning, so passing
``processor_kwargs: {revision: HIJACKED}`` cannot override the
cfg-level revision_of_model or the axolotl-managed trust_remote_code.
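
The forward-and-filter behavior described in this commit message could look roughly like the sketch below. The helper name `merge_processor_kwargs` and the `MANAGED_KEYS` constant are hypothetical; only the two managed key names come from the commit message.

```python
import logging

LOG = logging.getLogger(__name__)

# Keys the commit message describes as axolotl-managed.
MANAGED_KEYS = frozenset({"revision", "trust_remote_code"})


def merge_processor_kwargs(base_kwargs, user_kwargs):
    """Return base_kwargs plus the user's kwargs, minus any managed keys."""
    merged = dict(base_kwargs)
    for key, value in (user_kwargs or {}).items():
        if key in MANAGED_KEYS:
            LOG.warning("Ignoring processor_kwargs[%r]: managed by axolotl", key)
            continue
        merged[key] = value
    return merged


base = {"revision": "main", "trust_remote_code": False}
merged = merge_processor_kwargs(base, {"revision": "HIJACKED", "min_pixels": 256})
```

Here `merged` keeps the original `revision` and gains `min_pixels`; passing `None` for the user kwargs returns an unchanged copy of the base.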

Includes 2 offline unit tests covering the forward-and-filter behavior
and the no-op case where cfg.processor_kwargs is absent.

Closes axolotl-ai-cloud#3617.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Captures the audit of every multimodal processing strategy, the root-cause
analysis, the design rationale for the boundary-scanner approach (vs
preserving tokenize_prompt labels or relying on HF's return_assistant_tokens_mask),
the per-commit summary for this branch, and a draft upstream PR description.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets users declare per-role start/end markers directly in YAML without
subclassing a ProcessingStrategy. Intended uses:

1. Enable role masking on strategies we can't verify against a real
   checkpoint (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, unknown
   fallback) — instead of a one-shot warning + full-sequence loss, the
   user supplies markers for their specific checkpoint's chat template
   and gets proper assistant-only masking.
2. Fine-tune markers for custom / patched chat templates whose tokenized
   output diverges from axolotl's built-ins.

YAML shape:
  role_boundaries:
    - role: assistant
      start: "<|turn>model"
      end: "<turn|>"
    # include_start: false   (default)
    # include_end: true      (default, respects cfg.train_on_eos)
    # end: eos_token         (sentinel → tokenizer.eos_token_id)
    # end: null              (span runs to end of sequence)

Implementation:
- RoleBoundarySpec pydantic model added to MultiModalConfig (so Axolotl's
  standard validation catches typos at config-load time).
- _resolve_role_boundary_override converts strings → token ids at
  strategy init; logs the resolved ids at INFO so mismatches are visible.
- Validation errors (missing role/start, unencodable marker) surface at
  init time rather than as silent mis-masking during training.
- ProcessingStrategy.__init__ now takes role_boundaries_override; when
  set, it REPLACES the subclass's _build_role_boundaries() result
  wholesale (partial overlays would be ambiguous to review).
- Threaded through build_collator → get_processing_strategy → every
  strategy constructor.
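
A rough sketch of how one `role_boundaries` entry might be resolved to token ids at init time. Everything here is illustrative: `resolve_boundary` and `FakeTokenizer` are hypothetical names, the vocab ids are made up, and treating a missing `end` key the same as `end: null` is an assumption. The only tokenizer call it relies on is `encode(text, add_special_tokens=False)`, which Hugging Face tokenizers provide.

```python
class FakeTokenizer:
    """Stand-in for a Hugging Face tokenizer; the vocab below is made up."""

    eos_token_id = 2
    _vocab = {"<|turn>model": [105, 4368], "<turn|>": [106]}

    def encode(self, text, add_special_tokens=False):
        return list(self._vocab.get(text, []))


def resolve_boundary(tokenizer, spec):
    """Turn one role_boundaries entry (a dict) into token-id sequences."""
    if not spec.get("role") or not spec.get("start"):
        raise ValueError(f"role_boundaries entry needs 'role' and 'start': {spec!r}")

    def encode(marker, field):
        ids = tokenizer.encode(marker, add_special_tokens=False)
        if not ids:
            # Fail at init instead of silently mis-masking during training.
            raise ValueError(f"{field} marker {marker!r} encodes to no tokens")
        return ids

    end = spec.get("end")
    if end is None:
        end_ids = None                      # span runs to end of sequence
    elif end == "eos_token":
        end_ids = [tokenizer.eos_token_id]  # sentinel value from the YAML shape
    else:
        end_ids = encode(end, "end")
    return {
        "role": spec["role"],
        "start_ids": encode(spec["start"], "start"),
        "end_ids": end_ids,
        "include_start": spec.get("include_start", False),
        "include_end": spec.get("include_end", True),
    }


resolved = resolve_boundary(
    FakeTokenizer(),
    {"role": "assistant", "start": "<|turn>model", "end": "<turn|>"},
)
```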

Verified end-to-end against the real Rashi-OCR Gemma-4-31B tokenizer +
dataset: pass with built-in boundaries, pass with an assistant-only
override, and pass via the dispatcher path all produce identical correct
masks on real dataset rows.

Includes 7 additional unit tests (override replaces built-in, override
enables unverified strategy, eos_token sentinel, null end, spec
validation errors, Pydantic model input). Total: 39 offline tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests for: batch_size>1 (scanner + Qwen2VL strategy), include_start/include_end
round-trip through RoleBoundarySpec override, pad masking inside trainable spans,
all-pad sequence, multiple consecutive assistant turns, train_on_eos variants on
multi-turn conversations, MultiModalConfig dict->RoleBoundarySpec parsing, Qwen3.5
video_pad masking under train_on_inputs=True, and empty processor_kwargs no-op.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Trims verbose block comments and docstrings across the branch to 1-2
lines each, keeping the non-obvious WHY and dropping the WHAT.
Behavior unchanged; all 39 tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous `ds_cfg.get if hasattr else __getitem__` ladder fell through
to an AttributeError (swallowed) for pydantic dataset entries, silently
ignoring `roles_to_train` / `train_on_eos`. Switch to dict-style `.get`
first, then `getattr(obj, key, None)` — works for DictDefault, dict, and
pydantic SFTDataset alike.
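
A minimal sketch of the lookup order described above, assuming a hypothetical helper named `get_ds_field`. The `AttrEntry` class stands in for a pydantic dataset entry so the example runs without pydantic installed.

```python
def get_ds_field(entry, key, default=None):
    """Read `key` from a dict-like entry first, then fall back to attributes."""
    getter = getattr(entry, "get", None)
    if callable(getter):
        return getter(key, default)
    return getattr(entry, key, default)


class AttrEntry:
    """Attribute-style entry with no .get(), like a pydantic model."""

    train_on_eos = "last"


from_dict = get_ds_field({"roles_to_train": ["assistant"]}, "roles_to_train")
from_attr = get_ds_field(AttrEntry(), "train_on_eos")
missing = get_ds_field(AttrEntry(), "roles_to_train")
```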

Also log resolved collator knobs at INFO so the MM masking config is
visible in training logs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…n assistant spans

[/INST] is shared between user-end and assistant-start on these templates,
so the scanner consumed it for user and never matched it as assistant-start,
producing all-masked labels. Teach the scanner to not consume the end marker
when include_end=False (back up by len(end_tokens)), and declare
user.include_end=False on both Pixtral and Mistral V7 Tekken so assistant
spans are actually detected. Tests now assert the full expected mask.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bump "32 unit tests" references to the current 55, and replace the stale
5-entry commit breakdown with the actual 8-unit sequence (with a pointer
to `git log main..HEAD` as the source of truth).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… at init

'last' now trains on the final trainable turn's end marker only, matching the
text-only ChatTemplateStrategy's documented semantics. Invalid values raise
ValueError at strategy construction instead of silently falling through the
scanner's 'turn' branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Switch falsy check to 'is None' so roles_to_train=[] correctly masks every
role instead of silently defaulting to ['assistant'].
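
The difference between the two checks is easy to miss, so here is a small sketch. The function name is hypothetical; the default of `["assistant"]` is the one stated earlier in this PR.

```python
def resolve_roles_to_train(roles_to_train):
    # `is None` preserves an explicit empty list;
    # `roles_to_train or ["assistant"]` would replace it with the default.
    return ["assistant"] if roles_to_train is None else list(roles_to_train)


unset = resolve_roles_to_train(None)
explicit_empty = resolve_roles_to_train([])
falsy_version = [] or ["assistant"]  # what the old falsy check effectively did
```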

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
``labels == None`` emits a UserWarning and returns all-False. Sweep
process_labels overrides to skip pad/image/audio/video masking when the
corresponding id is None. Added a regression test that promotes warnings
to errors.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…boundaries

Mirrors the existing start-marker guard. Previously a typo'd end would
silently produce an empty end_tokens list and the span would run to end of
sequence instead of reporting the config error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… + debug log

Catch only ImportError/ModuleNotFoundError around the optional
Mistral3Processor / Glm46VProcessor imports and log the caught exception at
DEBUG so users debugging why their strategy wasn't picked up can see the
actual failure without seeing a warning in normal operation.
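
One way to express the narrowed exception handling, as a sketch only. `optional_attr` is a hypothetical helper and does not reflect how the dispatcher is actually structured. Note that `ModuleNotFoundError` is a subclass of `ImportError`, so catching `ImportError` covers both.

```python
import importlib
import logging

LOG = logging.getLogger(__name__)


def optional_attr(module_name, attr):
    """Return module_name.attr, or None if the module cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:  # also catches ModuleNotFoundError
        LOG.debug("optional import of %s failed: %s", module_name, exc)
        return None
    return getattr(module, attr, None)


missing_cls = optional_attr("no_such_module_for_this_sketch", "SomeProcessor")
present_fn = optional_attr("json", "dumps")
```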

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ranch

qwen3_5_moe is a model_config_type, not a ChatTemplate enum value, so the
branch was dead code. MoE variants share the qwen3_5 chat template anyway.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… markers

Matches the trailing-newline convention of Qwen / Gemma3 / Llama markers.
Previously the newline fell inside the span's first content token, which
worked but is fragile if the BPE tokenizer re-merges. Updated the
_FakeGemma4Tokenizer to reflect the expected three-token-sequence shape of
'<|turn>model\n'.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Gemma 3 and Gemma 3n jinja templates fold the system message into the first
user's content prefix and never emit <start_of_turn>system, so the system
RoleBoundary was harmless dead code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Emit a single grep-friendly line on every strategy construction covering
class name, train_on_inputs, roles_to_train, train_on_eos, boundary source,
and — for overrides — the resolved (role, start_ids, end_ids) tuples.
Users debugging why masking isn't firing no longer have to add prints.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pixtral and Mistral V7 Tekken entries no longer carry the "known
limitation" footnote — commit acfe4fe fixed the shared-[/INST] scanner
bug via include_end=False + rewind. Test-count mentions updated to 64.
Drops the qwen3_5_moe alias from the audit table to match the dispatcher
(removed in 77a9d17).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Apr 23, 2026

📝 Walkthrough

Walkthrough

Introduces a declarative role-boundary masking system for multimodal training that honors train_on_inputs, roles_to_train, and train_on_eos. Threads masking config into processing strategy selection, replaces the legacy no-op assistant-mask with a token-span scanner, adds schema hooks for overrides, and includes comprehensive tests and documentation.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Documentation Design (docs/multimodal_assistant_mask.md) | New design doc describing the role-boundary system, RoleBoundary/RoleBoundarySpec, longest-prefix scanner, train_on_eos semantics, processor kwargs threading, and YAML cfg.role_boundaries override. |
| Configuration Schema (src/axolotl/utils/schemas/multimodal.py) | Adds RoleBoundarySpec and MultiModalConfig.role_boundaries to allow declarative overrides of built-in role-boundary markers. |
| Core Processing Logic (src/axolotl/processing_strategies.py) | Replaces _mask_non_assistant with _apply_role_boundaries; ProcessingStrategy now accepts and validates train_on_inputs, roles_to_train, train_on_eos, and role_boundaries_override; resolves built-in vs override boundaries; applies role masking before pad/media masking; adds per-strategy boundary declarations and lazy/robust processor imports. |
| Builder Integration (src/axolotl/core/builders/causal.py) | Extracts roles_to_train and train_on_eos from the first dataset entry and threads role_boundaries_override from cfg.role_boundaries into get_processing_strategy; logs resolved masking knobs. |
| Tests (tests/test_processing_strategies.py) | Adds extensive offline tests covering _apply_role_boundaries behavior across roles, train_on_eos modes, prefix selection, missing end markers, batch consistency, strategy-specific masking rules, dispatcher routing, and override/schema validation. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.38% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: implementing systemic multimodal assistant-only loss masking, with a reference to issue #3617. |
| Linked Issues check | ✅ Passed | All objectives from issue #3617 are met: role-masking knobs now work in multimodal path, generic turn-marker scanner added, new per-template strategies implemented, cfg.role_boundaries override provided, processor_kwargs forwarding added, and comprehensive tests included. |
| Out of Scope Changes check | ✅ Passed | All changes directly support the multimodal loss masking objectives. Documentation, config schema, core processing logic, builder integration, and comprehensive tests are all within scope of addressing the issue. |


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
docs/multimodal_assistant_mask.md (1)

103-104: Consider using a permalink or removing the commit hash reference.

The commit hash acfe4fe4 may become difficult to trace after rebases or squash merges. Consider linking to a stable reference (e.g., the test file line) or removing the hash.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/multimodal_assistant_mask.md` around lines 103 - 104, The doc references
a short commit hash `acfe4fe4` which can become stale; update the line in
multimodal_assistant_mask.md that mentions the commit to instead use a stable
reference or remove the hash—either replace `acfe4fe4` with a permalink to the
specific location in tests/test_processing_strategies.py (or a link to the
repository/tag/PR) or simply remove the commit hash and point readers to the
test file `tests/test_processing_strategies.py` and the per-position assertions
mentioned.
tests/test_processing_strategies.py (1)

353-390: Good test fixture, consider avoiding mutable class attributes.

The VOCAB dict and other mutable attributes like size = {} are class-level. While safe in test code that doesn't subclass these fixtures, using instance initialization would be cleaner.

Note: The S105 "hardcoded password" warnings are false positives — these are token names (e.g., image_token, boi_token), not credentials.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_processing_strategies.py` around lines 353 - 390, The VOCAB
dictionary and other mutable attributes are currently defined at class scope on
_FakeGemma4Tokenizer (VOCAB) and set as class-like attributes on
_FakeGemma4Processor; move VOCAB into the tokenizer's __init__ as a self.vocab =
{ ... } and initialize any mutable state (e.g., token ids or maps) inside
_FakeGemma4Processor.__init__ as instance attributes (self.<name>) rather than
relying on class-level mutation; update references in _FakeGemma4Tokenizer and
_FakeGemma4Processor to use self.vocab and self.<token> so the fixtures no
longer use shared mutable class attributes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 116-132: The mypy error comes from assigning mixed types to
boundaries_repr (list vs string); make boundaries_repr a consistent type
(preferably str) by converting the list branch to a string before logging—e.g.,
stringify [(b.role, b.start_tokens, b.end_tokens) for b in self.role_boundaries]
(via repr(), ", ".join(...), or json.dumps) so boundaries_repr is always a str;
update the code around the ProcessingStrategy initializer where
self.role_boundaries and boundaries_repr are computed and used in the LOG.info
call to avoid the Union type.
- Around line 1115-1124: The mypy suppression on the Mistral3Processor import
uses the wrong error code; change the ignore to match the actual mypy error or
restructure the import: replace "# type: ignore[assignment]" with "# type:
ignore[misc]" (or remove the inline ignore and wrap the import in a try/except
typed with Optional[Type] using typing.cast/Optional) so Mistral3Processor
remains None on ImportError while satisfying mypy; update the import block
referencing Mistral3Processor and keep the LOG.debug message intact.
- Around line 415-417: The mypy complaint is from unpacking
last_trainable_end_span[i] directly after an `is not None` check; introduce a
local variable (e.g., span = last_trainable_end_span[i]) and check `if
train_on_eos == "last" and span is not None:` then unpack `s, e = span` (or use
the walrus operator `if train_on_eos == "last" and (span :=
last_trainable_end_span[i]) is not None:`) so mypy can narrow the type before
assigning to mask[i][s:e] in the processing_strategies logic.
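
The walrus form of that suggestion, sketched on plain Python lists instead of tensors. `unmask_last_spans` is a hypothetical name and the spans are made-up `(start, end)` pairs; the point is only that binding the element to a local name lets a type checker narrow `Optional` before the unpack.

```python
from typing import List, Optional, Tuple

Span = Optional[Tuple[int, int]]


def unmask_last_spans(mask: List[List[bool]], last_spans: List[Span]) -> List[List[bool]]:
    """Set mask[i][s:e] to True for every row that has a recorded span."""
    for i, row in enumerate(mask):
        # Binding to `span` lets the checker see it is not None below.
        if (span := last_spans[i]) is not None:
            s, e = span
            row[s:e] = [True] * (e - s)
    return mask


result = unmask_last_spans([[False] * 4, [False] * 4], [(1, 3), None])
```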


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ca052a96-754c-468f-951d-9c95d87e45be

📥 Commits

Reviewing files that changed from the base of the PR and between 7420fd4 and 08f8e4e.

📒 Files selected for processing (7)
  • docs/multimodal_assistant_mask.md
  • src/axolotl/core/builders/causal.py
  • src/axolotl/loaders/processor.py
  • src/axolotl/processing_strategies.py
  • src/axolotl/utils/schemas/model.py
  • src/axolotl/utils/schemas/multimodal.py
  • tests/test_processing_strategies.py

Comment thread src/axolotl/processing_strategies.py
Comment thread src/axolotl/processing_strategies.py Outdated
Comment thread src/axolotl/processing_strategies.py Outdated
@github-actions

github-actions Bot commented Apr 23, 2026


📖 Documentation Preview:

Deployed on Netlify from commit d444156

thad0ctor and others added 2 commits April 22, 2026 19:12
…n feat/processor-kwargs

This branch scope is multimodal assistant-only loss masking. The
processor_kwargs passthrough for axolotl-ai-cloud#3617 is being carried on a separate
branch (feat/processor-kwargs), so unbundling it here keeps this PR
focused and avoids duplicate history if the other branch lands first.

Removes:
- cfg.processor_kwargs field from ModelInputConfig
- processor_kwargs merge/filter logic in load_processor
- 3 tests (forward, absent, empty-dict) + the _load_processor_module helper
- Design-doc paragraphs and draft-PR bullets that referenced axolotl-ai-cloud#3617

No behavior change relative to main for the processor load path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…bbit

- Restructure Mistral3Processor lazy-import to match the Glm46V inline
  pattern 10 lines below; drops the type: ignore comment whose code
  was wrong anyway (mypy reports [misc] not [assignment]).
- Annotate boundaries_repr as str | list[tuple[...]] so the override
  vs built-in branches no longer trip the [assignment] error.
- Use walrus in the train_on_eos=="last" span unmask so mypy can
  narrow Optional after the check ([misc] "None is not iterable").

No behavior change; 61/61 tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 24, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Two issues surfaced by CI:

- ruff-format (v0.15.8) wanted to reformat
  src/axolotl/processing_strategies.py and tests/test_processing_strategies.py
  into its preferred layout. Applied the formatter; also fixes a missing
  trailing newline and an unsorted import in the test file (caught by
  end-of-file-fixer and the ruff legacy-alias hook).

- test_strategy_init_logs_resolved_masking_config_{builtin,override} and
  test_base_strategy_warns_when_no_boundaries were flaky across the
  PyTest matrix. Root cause: axolotl's logging config sets
  propagate=False on the "axolotl" logger once configure_logging() runs
  on a worker, which prevents pytest's root-attached caplog handler from
  ever seeing records. These tests passed locally when configure_logging
  hadn't fired yet, but failed when another test on the same xdist
  worker triggered it.

Added an `axolotl_caplog` fixture that attaches caplog.handler directly
to the axolotl.processing_strategies logger, making the three tests
robust to either logging state. No production behavior change.
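
The root cause can be reproduced with the standard library alone. In this sketch the logger names are placeholders for "axolotl" and its module logger, and `ListHandler` stands in for pytest's `caplog.handler`; it illustrates the mechanism and is not the fixture added in this commit.

```python
import logging


class ListHandler(logging.Handler):
    """Collects records, standing in for pytest's caplog.handler."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


parent = logging.getLogger("demo_pkg")            # placeholder for "axolotl"
child = logging.getLogger("demo_pkg.strategies")  # placeholder for the module logger
parent.propagate = False                          # the state described above

root_handler, direct_handler = ListHandler(), ListHandler()
logging.getLogger().addHandler(root_handler)
child.addHandler(direct_handler)
try:
    child.warning("no boundaries declared")
finally:
    logging.getLogger().removeHandler(root_handler)
    child.removeHandler(direct_handler)

seen_by_root = len(root_handler.records)     # propagation stops at the parent
seen_directly = len(direct_handler.records)  # handler on the emitting logger fires
```

Attaching the handler to the emitting logger works in either state, because a logger always calls its own handlers before propagation is considered.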

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 24, 2026

✅ Actions performed

Review triggered.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/test_processing_strategies.py (1)

406-422: Avoid shared mutable class vocab in _FakeGemma4Tokenizer.

VOCAB is a mutable class attribute; if mutated in one test path it can leak into others. Copy it when constructing the tokenizer instance.

Proposed fix
 class _FakeGemma4Tokenizer(_Tokenizer):
@@
     def __init__(self):
-        super().__init__(self.VOCAB, pad_id=0, unk_id=3)
+        super().__init__(dict(self.VOCAB), pad_id=0, unk_id=3)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_processing_strategies.py` around lines 406 - 422, The class
_FakeGemma4Tokenizer uses a mutable class attribute VOCAB and passes it directly
to super().__init__, which can leak mutations across tests; update __init__ to
pass a fresh copy (e.g., a shallow copy or deepcopy of VOCAB) to
super().__init__ so each tokenizer instance gets its own independent vocab
(referencing _FakeGemma4Tokenizer, VOCAB, and __init__ in the diff), keeping
pad_id=0 and unk_id=3 unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 194-197: The code currently falls back to
convert_legacy_format(example) when messages is present but None (and
conversations is absent), causing a KeyError; instead explicitly detect when
example.get("messages") is None and "conversations" not in example and raise a
clear validation error (or return a normalized error response) before calling
convert_legacy_format. Update the branch around processed_example assignment
(the if using "messages" and example["messages"]) to add a separate check for
example.get("messages") is None and "conversations" not in example, and
raise/return a descriptive error mentioning missing messages/conversations, then
only call convert_legacy_format(example) in the true legacy case. Ensure
references to convert_legacy_format and processed_example remain consistent.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ff621d22-a660-4821-a295-828f4b8ebd9b

📥 Commits

Reviewing files that changed from the base of the PR and between 08f8e4e and 42a47d7.

📒 Files selected for processing (3)
  • docs/multimodal_assistant_mask.md
  • src/axolotl/processing_strategies.py
  • tests/test_processing_strategies.py
✅ Files skipped from review due to trivial changes (1)
  • docs/multimodal_assistant_mask.md

Comment thread src/axolotl/processing_strategies.py
…ocab handling

Two review findings from CodeRabbit:

- When `messages` is present but None and `conversations` is absent, the
  code fell through to `convert_legacy_format(example)`, which then
  raised `KeyError: 'conversations'` from `example["conversations"]`
  instead of a clear validation error. Split the branch explicitly:
  messages-non-None → deepcopy and use as-is (symmetric with the legacy
  branch which already deepcopies internally); conversations present →
  legacy conversion; otherwise raise a descriptive ValueError that names
  both required keys.

- `_FakeGemma4Tokenizer.VOCAB` is a class-level mutable dict passed by
  reference to `super().__init__`, so every tokenizer instance aliases
  the same dict. No current code path mutates it, but the one-line
  `dict(self.VOCAB)` copy hardens the fake against any future mutation
  leaking across tests.
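
The three-way branch from the first finding can be sketched as below. `select_messages` and `convert_legacy` are hypothetical names, and the `from` / `value` to `role` / `content` mapping is an assumed shape for the legacy format, not a description of the real `convert_legacy_format`.

```python
from copy import deepcopy


def convert_legacy(conversations):
    """Hypothetical stand-in for the legacy converter."""
    return [{"role": t["from"], "content": t["value"]} for t in conversations]


def select_messages(example):
    if example.get("messages") is not None:
        return deepcopy(example["messages"])  # never mutate the dataset row
    if "conversations" in example:
        return convert_legacy(example["conversations"])
    raise ValueError("example needs a non-null 'messages' or a 'conversations' key")


row = {"messages": [{"role": "user", "content": "hi"}]}
copied = select_messages(row)
legacy = select_messages({"conversations": [{"from": "human", "value": "hi"}]})
```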

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
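The three-way branch described in the first finding can be sketched as below. This is a minimal illustration, not the code in `processing_strategies.py`: `extract_messages` is a hypothetical name, and the `convert_legacy_format` shown here is a stand-in with an assumed `from`/`value` schema.

```python
from copy import deepcopy


def convert_legacy_format(example: dict) -> list:
    # Stand-in for the real legacy converter (assumed schema: from/value).
    turns = deepcopy(example["conversations"])
    return [{"role": t["from"], "content": t["value"]} for t in turns]


def extract_messages(example: dict) -> list:
    """Pick the message source explicitly instead of falling through."""
    if example.get("messages") is not None:
        # Copy so later edits never touch the caller's data.
        return deepcopy(example["messages"])
    if "conversations" in example:
        return convert_legacy_format(example)
    raise ValueError(
        "Example needs a non-None 'messages' key or a 'conversations' key."
    )
```

With this shape, `{"messages": None}` fails with a message naming both keys rather than a `KeyError` raised from inside the legacy converter.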
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 24, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/processing_strategies.py`:
- Around line 94-109: The current conditional uses "if
role_boundaries_override:" which treats an explicit empty list as falsy and
ignores it; change the check to test for None (e.g., "if
role_boundaries_override is not None") so that an explicit [] is treated as a
real override; keep the existing call to
_resolve_role_boundary_override(role_boundaries_override,
self.processor.tokenizer) and the assignment of its result to
self.role_boundaries, and preserve the LOG.info call and the "source"
assignment so built-in defaults are only used when no override (None) is
provided.

In `@tests/test_processing_strategies.py`:
- Around line 420-424: The current __init__ uses dict(self.VOCAB) which only
shallow-copies the mapping so the inner ID lists remain shared; change the
constructor to deep-copy the VOCAB before passing it to super (e.g., use
copy.deepcopy(self.VOCAB) or recreate each value list) so per-instance mutations
of self.vocab[...] cannot leak across tests; keep the same pad_id=0 and unk_id=3
parameters when calling super.
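The difference between the shallow `dict(...)` copy and a deep copy is easy to see on a toy mapping. The vocabulary below is made up for illustration; only the "token maps to a list of ids" shape is taken from the finding.

```python
import copy

# Illustrative vocab: each token maps to a mutable list of ids.
VOCAB = {"<bos>": [2], "hello": [10, 11]}

shallow = dict(VOCAB)         # new outer dict, same inner lists
deep = copy.deepcopy(VOCAB)   # new outer dict and new inner lists

shallow["hello"].append(99)   # also visible through VOCAB["hello"]
deep["<bos>"].append(77)      # visible only through the deep copy
```

After these two mutations `VOCAB["hello"]` has grown while `VOCAB["<bos>"]` is untouched, which is the cross-instance leak the finding guards against.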
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5b40dc62-74b4-49f9-93d9-4de91db20a8a

📥 Commits

Reviewing files that changed from the base of the PR and between 42a47d7 and d444156.

📒 Files selected for processing (2)
  • src/axolotl/processing_strategies.py
  • tests/test_processing_strategies.py

Comment on lines +94 to +109
        if role_boundaries_override:
            overridden = _resolve_role_boundary_override(
                role_boundaries_override, self.processor.tokenizer
            )
            LOG.info(
                "%s: overriding built-in role boundaries (%d decls) "
                "with cfg.role_boundaries (%d decls).",
                type(self).__name__,
                len(built_in),
                len(overridden),
            )
            self.role_boundaries: list[RoleBoundary] = overridden
            source = "override"
        else:
            self.role_boundaries = built_in
            source = "built-in"

⚠️ Potential issue | 🟡 Minor

Treat an explicit empty `role_boundaries_override` as a real override.

`if role_boundaries_override:` falls back to the built-ins for `[]`, so `cfg.role_boundaries: []` cannot clear a strategy’s default boundaries even though overrides are otherwise wholesale replacements.

Suggested fix
-        if role_boundaries_override:
+        if role_boundaries_override is not None:
             overridden = _resolve_role_boundary_override(
                 role_boundaries_override, self.processor.tokenizer
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/processing_strategies.py` around lines 94 - 109, The current
conditional uses "if role_boundaries_override:" which treats an explicit empty
list as falsy and ignores it; change the check to test for None (e.g., "if
role_boundaries_override is not None") so that an explicit [] is treated as a
real override; keep the existing call to
_resolve_role_boundary_override(role_boundaries_override,
self.processor.tokenizer) and the assignment of its result to
self.role_boundaries, and preserve the LOG.info call and the "source"
assignment so built-in defaults are only used when no override (None) is
provided.
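A stripped-down sketch of the suggested check, with a hypothetical `pick_boundaries` helper standing in for the constructor logic, shows why `is not None` matters for an empty list:

```python
from typing import Optional


def pick_boundaries(override: Optional[list], built_in: list) -> tuple:
    """Return (boundaries, source); only None means 'no override given'."""
    if override is not None:
        # An explicit [] lands here and replaces the built-ins wholesale.
        return override, "override"
    return built_in, "built-in"
```

With a plain truthiness test, `pick_boundaries([], defaults)` would silently return the defaults instead of the empty override.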

Comment thread tests/test_processing_strategies.py
@thad0ctor

Owner Author

Closing to re-open with squashed history. Will be replaced by a fresh PR with a single squashed commit so CodeRabbit can re-review cleanly.

@thad0ctor thad0ctor closed this Apr 24, 2026
thad0ctor added a commit that referenced this pull request May 12, 2026
Five test-quality refinements from CodeRabbit's third-round review.

**R3-#2 — deterministic teardown in test_dora.**

Wrap the DoRA smoke's wrap → train → assert sequence in
``try/finally`` so ``wrapped.close()`` runs even when the
loss-descent assertion fails mid-test. Without this, an early
assertion failure leaves hooks, pinned-host borrows, and CPU
adapter threads alive into subsequent GPU tests on the same
pytest session.

**R3-#3 — distinguish hook edges in test_lora_offload_mode
recording stub.**

The pre-fix ``_RecordingScheduler.ensure_chunks_resident``
recorded every container callback under the same
``"ensure_chunks_resident"`` label. The per-hook tests
(pre_forward / post_forward / post_backward fires
``ensure_chunks_resident``) then asserted only call COUNT — so a
regression that deleted the pre-forward hook factory while
post-forward still fired would still pass the count gates.

Tag each call with its originating hook edge via frame
inspection on the caller's ``co_qualname`` (Python 3.11+
guarantees the qualname captures the enclosing
``_make_lora_container_<edge>_hook`` factory). The four LoRA
container hooks all funnel through the same
``ensure_chunks_resident`` entry point but their closures live
in distinct factory functions, so the qualname uniquely
identifies the edge.

Update each per-hook test to filter on the edge-tagged label so
a regression in any single edge fails the corresponding test:

* pre_forward test: asserts ``ensure_chunks_resident:pre_forward``
  fires ≥ n_blocks times.
* post_forward test: asserts BOTH ``:pre_forward`` AND
  ``:post_forward`` fire ≥ n_containers times each (the previous
  bare ≥ 2*n_containers count was satisfied by either edge alone).
* post_backward test: asserts all four edges (pre/post fwd, pre/
  post bwd) fire ≥ n_containers times each.

The production hook factory layout is unchanged — the stub
recovers the edge from the existing closure's frame, no new
arguments thread through ``install_hooks``.

**R3-#4 — narrow protrain_model_wrapper exception scope in
test_lora_offload_mode:1117.**

The bare ``except (ValueError, RuntimeError)`` was treating any
wrapper failure as "offload setup unavailable" and skipping. A
broken ``protrain_model_wrapper`` runtime path could leave this
smoke green. Restrict the suppression to known env-failure
substrings (DeepSpeedCPUAdam JIT, CUDA version mismatch, bnb
load, ``No module named``, and capacity/searcher gates) — same
canonical tuple D8 used at the optimizer-step site below — and
re-raise anything else. Real wrapper regressions now surface.
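A minimal sketch of the narrowed suppression is shown below. The substring tuple is illustrative only (the real canonical tuple lives in the test module), and returning `"skipped"` stands in for the `pytest.skip` call a real test would make.

```python
# Illustrative markers; the actual tuple in the test suite may differ.
ENV_FAILURE_SUBSTRINGS = ("DeepSpeedCPUAdam", "CUDA version", "No module named")


def is_env_failure(exc: Exception) -> bool:
    """True when the message matches a known environment failure."""
    return any(marker in str(exc) for marker in ENV_FAILURE_SUBSTRINGS)


def call_or_skip(fn):
    """Run fn; tolerate known environment failures, re-raise everything else."""
    try:
        return fn()
    except (ValueError, RuntimeError) as exc:
        if is_env_failure(exc):
            return "skipped"  # a real test would call pytest.skip(str(exc))
        raise
```

The exception types stay the same as before; only the message filter decides whether a failure counts as "environment unavailable".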

**R3-#5 — fail-safe CUDA teardown in
test_param_data_shape_preservation.**

Eight test functions in this module construct ``mgr / layout /
pool / host`` via ``_build_chunk_manager`` and tear them down at
the happy-path tail (``mgr.uninstall()`` / ``host.close()`` /
``del pool``). Any earlier assertion failure skipped the
teardown, leaking pinned-host borrows + CUDA buffer-pool state
into subsequent GPU tests.

Add a top-level ``_teardown_chunk_manager(mgr, host, pool)``
helper that does the best-effort 3-call teardown (each call
wrapped in its own try/except so a failure in ``uninstall``
doesn't block the ``host.close``), and wrap each test body in
``try: ... finally: _teardown_chunk_manager(...)``. Done
programmatically across all 8 tests via a one-shot Python
rewrite to keep the diff mechanical and the new structure
consistent.

**R3-#8 — replace hard-coded n_chunk_estimate=1 in
test_trace_skip_on_override.**

The trace-skip e2e test hard-coded ``n_chunk_estimate = 1`` based
on the assumption that the tiny GPT-2 fixture produces a single
chunk. If the layout heuristics (``pick_S_chunk`` default,
block-discovery rules) shift such that ``N_chunk > 1``,
``min_n_buffer_for(layout, n_persist=1)`` rejects
``n_buffer_override=0`` BEFORE the wrapper reaches the
trace-skip gate the test is supposed to validate — converting
this into a flaky non-target failure.

Compute ``n_chunk_estimate`` dynamically by running the same
``discover_blocks`` → ``flatten_block_trees`` → ``build_layout``
pipeline the wrapper itself uses (with the wrapper's default
S_chunk), and pass the resulting ``layout.N_chunk`` through.
``n_persist_override = n_chunk_estimate`` then keeps the
all-persistent invariant the test relies on regardless of any
future layout-heuristic shift.

``tests/protrain/`` default-marker sweep: 303 passed / 4 skipped
/ 0 failed. GPU-marker sweep on touched files: 40 passed /
2 skipped (single-process Mode-C downgrade for shape-preserving
placeholder paths) / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
…inor)

Four CodeRabbit findings on commits ``b61f04e0`` (init-transient peak
prediction) + ``aa0c6ba9`` (Mode-C steady CKPT-chain accounting).

**R4-#1 (Critical) — post-block-wrap DDP ignore-set re-registration.**

The M6C-fix-8 ignore-set registration ran BEFORE block-wrap. The
block wrappers (``block/checkpoint.py``, ``block/swap.py``,
``block/offload.py``) all do ``self.block = block``, which means
PyTorch's ``named_parameters()`` traversal inserts a ``.block.``
infix into the parameter namespace
(``layers.0.attn.q_proj.weight`` ⇒
``layers.0.block.attn.q_proj.weight``). The pre-wrap names captured
in ``model._ddp_params_and_buffers_to_ignore`` no longer match the
namespace DDP's ``__init__`` walks at construction time.

The init-time broadcast is irrelevant (M6C-fix-8's
``init_sync=False`` monkey-patch bypasses it wholesale on
chunk-managed models), but DDP's BACKWARD-pass allreduce still
consults the ignore list. A stale ignore set means DDP's backward
allreduce would attempt to all-reduce chunk-managed LoRA factor
gradients, conflicting with ProTrain's per-chunk ``reduce_scatter``
drain.

Add a post-wrap re-registration step after ``install_hooks`` in
``_construct_runtime``: walk the WRAPPED ``model.named_parameters()``
and identify chunk-managed params by OBJECT identity against
``chunk_manager._params_by_id.values()``. Build the post-wrap name
set, merge with the pre-protrain snapshot
(``_protrain_ddp_original_ignore``), overwrite the attribute.
Gated on ``_shape_preserving`` so the single-GPU / replicated path
remains a no-op.

**R4-#2 (Major) — reuse bootstrap init-transient peak instead of
recomputing post-offload.**

``predict_init_transient_peak_bytes(layout, hw, chunk_manager)``
walks ``chunk_manager.model.named_parameters()`` to sum chunk
bytes. By the time the phase-2 post-measurement calibration runs,
``materialize_offload`` has already executed and ``param.data``
points at the zero-size placeholders (replicated path) or
``scratch.expand(slot.shape)`` views (sharded path), so the
byte accounting drifts away from the bootstrap-time full-residence
prediction.

Replace the recompute call at the phase-2 post-measurement
calibration site with ``boot_result.predicted_init_transient_peak_bytes``
— the bootstrap-time value captured at line 1614 before
materialize_offload ran. The downstream consumers (SearchResult
publish, LOG.info diagnostic) get the same authoritative value
without re-walking a now-stale chunk_manager.

**R4-#3 (Major) — meta tensors in ``_stub_chunk_manager`` to avoid
CI OOM.**

``tests/protrain/test_init_transient_peak.py::_stub_chunk_manager``
allocated full CPU tensors sized to model 15–60 GiB chunk
totals. ``predict_init_transient_peak_bytes`` only reads
``param.numel() * param.element_size()``, so meta-device tensors
preserve the byte-accounting metadata without consuming RAM.
Switch the ``nn.Parameter(torch.zeros(numel, dtype, device='cpu'))``
construction to ``nn.Parameter(torch.empty(numel, dtype,
device='meta'), requires_grad=False)``.

**R4-#4 (Minor) — align docstring tolerance with ``TOLERANCE_FRAC = 0.35``.**

``tests/protrain/test_modec_steady_peak_accuracy.py`` docstring
said "±25%" but ``TOLERANCE_FRAC = 0.35`` and the assertion uses
0.35. Update the two docstring mentions to "±35%" so text matches
intent.

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker sweep: 313 passed / 4
  skipped / 162 deselected / 0 failed.
- GPU sanity on touched test files (GPU 5): 24 passed / 2 skipped
  (single-process Mode-C downgrade — expected) / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
…est fixes

Seven Minor items from the CodeRabbit full-diff re-scan on
commit ``55377e5d``.

**F-#2 — Clarify Mode-A guidance in ``protrain_optimizer_wrapper``
8-bit warning (``api/optim_wrapper.py:802-815``).**

The warning told users to set ``protrain_force_all_persistent: true``
to get end-to-end 8-bit AdamW on CPU-resident chunks, but didn't
mention that ``protrain_force_all_persistent`` is ignored while
``protrain_auto_mode`` is on (the auto-mode selector picks the mode
itself based on capacity). Expanded the warning to instruct users
to set ``protrain_auto_mode: false`` AND
``protrain_force_all_persistent: true`` together.

**F-#4 — Unify fragmentation-alpha docs in DESIGN.md.**

Module summaries at lines 49 (``cost/memory.py``) and 118
(``memory.py`` module spec) still described a fixed ``alpha=1.10``
while Design Decision 1 documents the per-dtype lookup
(``ALPHA_FRAGMENTATION_4BIT = 0.75`` for bnb-4-bit). Aligned both
summaries to reference the per-dtype helper
(``alpha_fragmentation_for_dtype``) and the design decision section.

**F-#5 — Resolve ``use_reentrant`` contradiction in DESIGN.md.**

Line 109 (``block/checkpoint.py`` module spec) said
``use_reentrant=False``, which matches the actual implementation
(verified via ``grep`` against ``block/checkpoint.py:99``). Line 290
(audit Block G analysis) claimed ``use_reentrant=True, the
production wrap`` — stale and incorrect. Updated the analysis text
to acknowledge ``use_reentrant=False`` is the production wrap and
re-stated the per-block-input residual mechanism in a form
compatible with the non-reentrant variant (each CKPT block's
saved-tensors-hooks recompute frame holds the block input, which
is what produces the linear-in-N_block activation footprint the
audit data exposes).

**F-#8 — Centralized CUDA-availability guard in
``tests/protrain/test_adamw8bit_adapter.py::_gpu_device``.**

The helper unconditionally returned ``torch.device("cuda:0")``,
so a custom marker filter or conftest override that lands the
module in a CPU-only context would surface as a torch error
before any test body. Added a
``pytest.skip("CUDA not available; ...")`` early-return so every
gpu-marked test in the module gets a clean skip.

**F-#9 — Replace silent ``try/except: pass`` with
``contextlib.suppress(Exception)`` in
``tests/protrain/test_lora_offload_mode.py``.**

Five sites — lines 742-746, 839-843, 906-910, 981-985, 1040-1044
— each had the same ``for h in handles: try: h.remove() except
Exception: pass`` pattern that Ruff S110 flags. Replaced with
``contextlib.suppress(Exception)`` over the loop. Semantics
unchanged (best-effort cleanup, tolerate already-removed handles
or torch shutting down mid-test); intent now documented by the
context manager.
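One detail worth illustrating: to match a per-handle `try/except`, the `contextlib.suppress` block has to sit inside the loop, since a single block around the whole loop would stop at the first failing handle. The `Handle` class below is a made-up stand-in for a hook handle, and which placement the test file uses is not stated in the message.

```python
import contextlib


class Handle:
    """Stand-in for a hook handle whose remove() may fail."""

    def __init__(self, fail: bool = False) -> None:
        self.fail = fail
        self.removed = False

    def remove(self) -> None:
        if self.fail:
            raise RuntimeError("already removed")
        self.removed = True


def remove_all(handles) -> None:
    """Best-effort cleanup: a failing handle does not stop the others."""
    for handle in handles:
        with contextlib.suppress(Exception):
            handle.remove()
```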

**F-#10 — ASCII ``x`` in ``test_lora_offload_mode.py:1062`` docstring.**

Missed in the R5 unicode sweep — ``4×3090`` ⇒ ``4x3090``.

**F-#11 — ``try/finally`` for ``wrapped.close()`` in 3 sites of
``test_trace_skip_on_override.py``.**

``test_run_trace_skipped_on_override_full_path`` (L255-282),
``test_run_trace_invoked_without_override`` (L319-337), and
``test_partial_overrides_do_not_skip_trace`` (L381-400) each
called ``wrapped.close()`` only on the success path — assertion
failures earlier in the test body would skip the close and leak
CUDA + chunk resources into subsequent GPU tests. Wrapped each
test body in ``try/finally`` so ``wrapped.close()`` always
runs. Done programmatically via a one-shot Python rewrite
(8 lines of new indent + 2 lines of try/finally per site) to
keep the diff mechanical.

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker: 313 passed / 4 skipped /
  162 deselected / 0 failed.
- GPU sanity on F-touched files (GPU 5): 43 passed / 2 skipped /
  0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 23, 2026
…g time

Adds a Pydantic validator that emits a warning when ProTrain is active and
the effective batch size (micro_batch_size * gradient_accumulation_steps) is
less than 4. ProTrain's per-iter scheduler walk and chunk-management hook
fan-out carry fixed overhead; below effective bs=4 the per-step overhead
dominates and users see ~5-7x throughput regressions vs vanilla bs=1
(proposal §6.d, hot-path fix tracked in §16 follow-up PR #4).

Lets users self-diagnose 'ProTrain is mysteriously slow' before opening a
support thread. The mitigation is just to raise gradient_accumulation_steps
(same effective batch, much better throughput).
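The check itself reduces to a product and a threshold. The sketch below uses a plain function and the standard `warnings` module rather than the Pydantic validator the commit adds; the function name, the `protrain_active` flag, and the message wording are assumptions.

```python
import warnings

MIN_EFFECTIVE_BATCH = 4  # threshold quoted in the commit message


def check_effective_batch(
    micro_batch_size: int,
    gradient_accumulation_steps: int,
    protrain_active: bool = True,
) -> int:
    """Warn when ProTrain runs below the effective batch size threshold."""
    effective = micro_batch_size * gradient_accumulation_steps
    if protrain_active and effective < MIN_EFFECTIVE_BATCH:
        warnings.warn(
            f"Effective batch size {effective} is below {MIN_EFFECTIVE_BATCH}; "
            "consider raising gradient_accumulation_steps.",
            UserWarning,
            stacklevel=2,
        )
    return effective
```

For example, `micro_batch_size=1` with `gradient_accumulation_steps=2` triggers the warning, while `1` with `4` does not.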
thad0ctor added a commit that referenced this pull request May 24, 2026
…for bs=1 hot path

cProfile on a 4-block synthetic LoRA model identified three per-step
Python overheads that don't amortize at bs=1: (a) cuda.is_available()
fired in ensure_chunks_resident on every LoRA-container hook (~600
us/step at n_blocks=8); (b) per-step set() construction in
post_block_forward; (c) PyTorch's setup_*_hook autograd machinery
firing per forward (~1.4 ms/step) — fixed-cost given the LoRA-container
hook quartet pinned by offload-correctness tests.

This change addresses (a) + (b): caches torch.cuda.is_available() as
self._has_cuda at Scheduler init, precomputes next-block / prev-block
ids and next-chunks frozensets keyed by BlockId, and removes the
per-call torch reimport from ensure_chunks_resident. (c) is GPU-profile-
driven and rescoped under §16 PR #4.

Microbench tests/protrain/test_bs1_hot_path_microbench.py captures the
shape: pre-fix overhead 4.7 ms/step at n_blocks=8 → post-fix ~2.4
ms/step (~49% reduction in Python-attributable per-step overhead).
bs=4 wall-time unchanged within noise; no regressions across
tests/protrain (369 passing).
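The caching in (a) and (b) amounts to moving work from the per-step hooks into the constructor. The sketch below is a simplified model under stated assumptions: integer block ids, a `block_chunks` mapping, and a `has_cuda` argument standing in for the cached `torch.cuda.is_available()` result. None of these names come from the real `Scheduler`.

```python
class Scheduler:
    """Toy scheduler that precomputes per-block lookups once at init."""

    def __init__(self, block_chunks: dict, has_cuda: bool) -> None:
        # The real code caches torch.cuda.is_available() here instead.
        self._has_cuda = has_cuda
        ids = sorted(block_chunks)
        pairs = list(zip(ids, ids[1:]))
        self._next_block = {cur: nxt for cur, nxt in pairs}
        self._prev_block = {nxt: cur for cur, nxt in pairs}
        # Immutable sets built once, so hooks never construct a set per step.
        self._next_chunks = {
            cur: frozenset(block_chunks[nxt]) for cur, nxt in pairs
        }

    def next_chunks(self, block_id: int) -> frozenset:
        """Dictionary lookup on the hot path; empty for the last block."""
        return self._next_chunks.get(block_id, frozenset())
```

The hot path then performs only dictionary lookups, which is the kind of per-step saving the microbenchmark measures.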
thad0ctor added a commit that referenced this pull request May 28, 2026
Five test-quality refinements from CodeRabbit's third-round review.

thad0ctor added a commit that referenced this pull request May 28, 2026
…inor)

thad0ctor added a commit that referenced this pull request May 28, 2026
…est fixes

thad0ctor added a commit that referenced this pull request May 28, 2026
…g time

thad0ctor added a commit that referenced this pull request May 28, 2026
…for bs=1 hot path

Development

Successfully merging this pull request may close these issues.

[feature or bug?] Multimodal collator ignores train_on_inputs / roles_to_train / train_on_eos; every MM model except Gemma3n trains on full conversation

1 participant