fix: allow kwargs-only forward on PEFT ModulesToSaveWrapper #5

Closed
thad0ctor wants to merge 1 commit into main from fix/peft-modules-to-save-kwargs

@thad0ctor (Owner) commented Apr 23, 2026

Summary

PEFT's AuxiliaryTrainingWrapper.forward requires a positional x. Any module that is called with keyword arguments only (e.g. Gemma 4's vision_tower / embed_vision) therefore crashes on its first forward pass when it is listed in lora_modules_to_save, with TypeError: ... missing 1 required positional argument: 'x'.

This adds an idempotent monkeypatch that rewrites forward, _forward_wrapped, and _forward_wrapped_passthrough on AuxiliaryTrainingWrapper / ModulesToSaveWrapper to accept *args, **kwargs — backward-compatible with existing positional callers (e.g. embed_tokens).

  • Patch installed from PatchManager._apply_adapter_patches() (gated on cfg.adapter) so it runs before get_peft_model.
  • _mixed_batch_forward intentionally left positional — it requires sub-batch indexing and only fires under multi-adapter adapter_names=... calls, which don't happen in single-adapter training.
  • _check_forward_args is short-circuited when no positional input is provided (its only real work validates len(x) == len(adapter_names), which itself requires adapter_names to be set).
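The shape of such a patch can be sketched without PEFT at all. The block below is a minimal illustration, not the code in this PR: `Wrapper`, `patch_forward_kwargs`, and the `_kwargs_forward_patched` marker are hypothetical stand-ins, and the real patch also has to deal with `adapter_names` and the passthrough branch.

```python
class Wrapper:
    """Stand-in for a PEFT-style wrapper; NOT the real PEFT class."""

    def __init__(self, module):
        self.module = module

    def forward(self, x, *args, **kwargs):
        # Original shape: a required positional `x`.
        return self.module(x, *args, **kwargs)


_PATCH_FLAG = "_kwargs_forward_patched"  # hypothetical marker attribute


def patch_forward_kwargs(cls):
    """Swap in a forward with no required positional; return True if applied."""
    if getattr(cls, _PATCH_FLAG, False):
        return False  # already patched: calling again changes nothing

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    cls.forward = forward
    setattr(cls, _PATCH_FLAG, True)
    return True


def vision_tower(*, pixel_values):
    # Keyword-only callee, like a module invoked as module(pixel_values=...).
    return [value * 2 for value in pixel_values]


def embed_tokens(x):
    # Positional callee, like an embedding lookup.
    return x + 1
```

Calling `patch_forward_kwargs(Wrapper)` twice installs the replacement once; afterwards both `Wrapper(vision_tower).forward(pixel_values=[1, 2])` and `Wrapper(embed_tokens).forward(3)` succeed.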

Reproduces

adapter: lora
lora_modules_to_save:
  - vision_tower   # or any module called with keyword args

Test plan

  • pytest tests/monkeypatch/test_peft_modules_to_save.py — 7/7 passing
  • ruff check + ruff format --check clean on touched files
  • kwargs-only forward (Gemma 4 vision_tower / embed_vision shape)
  • Positional forward still works (embed_tokens shape)
  • Mixed args + kwargs forward
  • Passthrough branch (enable_adapters(False)) also accepts kwargs-only
  • TrainableTokensWrapper (sibling subclass of AuxiliaryTrainingWrapper) end-to-end still works
  • adapter_names + kwargs-only raises TypeError (locks the deliberately-unsupported shape)
  • Patch is idempotent

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes
    • Improved adapter support to correctly process model calls with keyword arguments.
    • Enhanced compatibility when using adapters with advanced training configurations.

PEFT's ``AuxiliaryTrainingWrapper.forward`` requires a positional ``x``:

    def forward(self, x: torch.Tensor, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        ...
        return self._forward_wrapped(x, *args, **kwargs)

This works for modules called positionally (e.g. ``embed_tokens``) but
crashes when a ``lora_modules_to_save`` entry is invoked with keyword
arguments only. Gemma 4's VLM forward path does this for both
``vision_tower`` and ``embed_vision``:

    vision_outputs = self.vision_tower(pixel_values=..., ...)
    ... = self.embed_vision(inputs_embeds=...)

Producing:

    TypeError: AuxiliaryTrainingWrapper.forward() missing 1 required
    positional argument: 'x'

at the first forward pass (e.g. eval step 0).
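
The failure mode is plain Python argument binding and can be reproduced without PEFT or torch. A minimal sketch, where `PositionalWrapper` and `fake_vision_tower` are hypothetical stand-ins that only copy the signature shape quoted above:

```python
class PositionalWrapper:
    """Stand-in with the same forward signature shape as the quoted method."""

    def __init__(self, module):
        self.module = module

    def forward(self, x, *args, **kwargs):
        return self.module(x, *args, **kwargs)


def fake_vision_tower(*, pixel_values):
    # Keyword-only callee, mimicking module(pixel_values=...).
    return pixel_values


wrapper = PositionalWrapper(fake_vision_tower)

try:
    # No positional argument is supplied, so `x` cannot be bound.
    wrapper.forward(pixel_values=[0.1, 0.2])
except TypeError as exc:
    error_message = str(exc)
else:
    error_message = None
```

The call never reaches the wrapped module: binding fails at the wrapper's own `forward`, which is why the error names `x` rather than anything in the vision tower.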

Add an idempotent monkeypatch that rewrites ``forward``,
``_forward_wrapped``, and ``_forward_wrapped_passthrough`` to accept
``*args, **kwargs``. ``_mixed_batch_forward`` is untouched — it needs
positional input for sub-batch indexing and only runs when
``adapter_names`` is passed (multi-adapter inference), which doesn't
happen in single-adapter training. ``_check_forward_args`` is
short-circuited when no positional args are provided; its only work is
validating ``len(x) == len(adapter_names)`` which is itself gated on
``adapter_names`` being set.

Applied from ``PatchManager._apply_adapter_patches`` (gated on
``cfg.adapter``) before ``get_peft_model`` runs, so every wrapper
constructed this training run gets the fixed forward. Backward-compatible:
positional calls still work (verified in tests against real PEFT classes).

Reproduces with:

    adapter: lora
    lora_modules_to_save:
      - vision_tower   # or any module called with keyword args

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@coderabbitai (Bot) commented Apr 23, 2026

📝 Walkthrough

This pull request introduces a new monkeypatch for PEFT's module wrapper to support keyword-argument-based forward calls alongside positional arguments. The patch is installed when adapter support is enabled and includes comprehensive test coverage for various forward signature patterns and adapter states.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Adapter Patching Integration: src/axolotl/loaders/patch_manager.py | Integrates the new PEFT monkeypatch by calling patch_peft_modules_to_save_kwargs() within the adapter initialization flow when cfg.adapter is enabled. |
| PEFT Modules-to-Save Monkeypatch: src/axolotl/monkeypatch/peft_modules_to_save.py | New monkeypatch module that modifies PEFT wrappers to handle forward calls with keyword arguments. Provides patched forward functions that validate args conditionally, extract adapter_names from kwargs, and route to appropriate forwarding paths. Includes idempotency tracking via a private attribute. |
| Monkeypatch Tests: tests/monkeypatch/test_peft_modules_to_save.py | Comprehensive test suite validating kwargs-only, positional-only, and mixed forward signatures; adapter disabled states; compatibility with TrainableTokensWrapper; unsupported edge cases; and patch idempotency. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.31% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately and specifically describes the main change: adding support for kwargs-only forward calls on PEFT's ModulesToSaveWrapper, which is the core fix implemented across the three modified files. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/peft_modules_to_save.py (1)

46-50: Consider patching AuxiliaryTrainingWrapper._forward_wrapped as well for completeness.

The patch sets forward on AuxiliaryTrainingWrapper but only patches _forward_wrapped and _forward_wrapped_passthrough on ModulesToSaveWrapper. If AuxiliaryTrainingWrapper is ever instantiated directly (rather than through its subclass ModulesToSaveWrapper), it would use the patched forward but call its unpatched _forward_wrapped, potentially breaking kwargs-only forwarding.

This may be intentional if PEFT only instantiates ModulesToSaveWrapper for modules_to_save entries. However, for robustness, consider also patching AuxiliaryTrainingWrapper._forward_wrapped and _forward_wrapped_passthrough.

Proposed patch for completeness
     AuxiliaryTrainingWrapper.forward = _patched_forward
+    AuxiliaryTrainingWrapper._forward_wrapped = _patched_forward_wrapped
+    AuxiliaryTrainingWrapper._forward_wrapped_passthrough = (
+        _patched_forward_wrapped_passthrough
+    )
     ModulesToSaveWrapper._forward_wrapped = _patched_forward_wrapped
     ModulesToSaveWrapper._forward_wrapped_passthrough = (
         _patched_forward_wrapped_passthrough
     )
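
Why the base-class helpers matter comes down to attribute lookup along the MRO. The sketch below uses hypothetical `Base` / `Child` classes, not PEFT's real hierarchy: `Base.forward` plays the role of the already-patched, kwargs-tolerant `forward`, and only the subclass helper is replaced, as in the original hunk.

```python
class Base:
    def forward(self, *args, **kwargs):
        # Kwargs-tolerant entry point (the "patched forward" in this sketch).
        return self._forward_wrapped(*args, **kwargs)

    def _forward_wrapped(self, x, *args, **kwargs):
        # Unpatched helper: still demands a positional `x`.
        return ("base-helper", x)


class Child(Base):
    def _forward_wrapped(self, x, *args, **kwargs):
        return ("child-helper", x)


def patched_helper(self, *args, **kwargs):
    return ("patched", args, kwargs)


# Patch only the subclass helper.
Child._forward_wrapped = patched_helper

child_result = Child().forward(pixel_values=1)

try:
    Base().forward(pixel_values=1)
    base_error = None
except TypeError as exc:
    # Base (or any sibling subclass that does not override the helper)
    # still resolves to the unpatched method.
    base_error = exc
```

Assigning the helper on the base class as well, as the suggestion proposes, makes every class that inherits it pick up the kwargs-tolerant version unless it overrides the method itself.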
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/monkeypatch/peft_modules_to_save.py` around lines 46 - 50, The
patch currently replaces AuxiliaryTrainingWrapper.forward and sets
_forward_wrapped and _forward_wrapped_passthrough only on ModulesToSaveWrapper;
update AuxiliaryTrainingWrapper too so direct instantiation won't call the old
unpatched helpers: assign AuxiliaryTrainingWrapper._forward_wrapped =
_patched_forward_wrapped and
AuxiliaryTrainingWrapper._forward_wrapped_passthrough =
_patched_forward_wrapped_passthrough (similar to how
AuxiliaryTrainingWrapper.forward = _patched_forward is set) so both helper
methods are consistently patched alongside forward.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bc44d88e-287b-464f-bd33-32d671fd4761

📥 Commits

Reviewing files that changed from the base of the PR and between 7420fd4 and a22a61d.

📒 Files selected for processing (3)
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/peft_modules_to_save.py
  • tests/monkeypatch/test_peft_modules_to_save.py

@thad0ctor thad0ctor closed this Apr 24, 2026
thad0ctor added a commit that referenced this pull request May 6, 2026
Two related fixes landing together — both Codex-flagged after the
previous round closed gaps #1 (persistent peak), #2 (T_reduce), and #5
(per-chunk contention). Bundled because both touch the cost-model
search path and add regression tests to the same file.

## Fix 1: searcher fast-path applies cap layering (closes #1 fully)

909fc9e fixed cost/memory.py::estimate_peak to layer the
hot_iter_peak_cap (cap only the activation portion, preserve
model_state_present). The same bug existed in search/exhaustive.py:649
inline fast path — it still applied the cap as a flat clamp, erasing
the Adam state contribution.

Codex synthetic repro: search() returned predicted_peak_bytes=78MB
while estimate_peak() for the same config returned 7.09GB — a 90×
divergence on full-FT shapes.
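
The difference between a flat clamp and a layered cap can be sketched as follows. This is a simplified illustration under stated assumptions, not the `apply_hot_iter_cap` helper itself: the function names are hypothetical, the `layout` argument and the alpha factor are omitted, and the byte counts are made up rather than taken from the Codex repro.

```python
def flat_clamp(raw_peak: int, measured_cap: int) -> int:
    """Pre-fix behaviour as described: clamp the whole prediction."""
    return min(raw_peak, measured_cap)


def layered_cap(raw_peak: int, model_state_present: int, measured_cap: int) -> int:
    """Cap only the activation portion; model state is always charged."""
    activation = max(raw_peak - model_state_present, 0)
    return model_state_present + min(activation, measured_cap)


# Illustrative full-FT-like numbers (bytes): model state dominates the peak.
MODEL_STATE = 7_000_000_000
RAW_PEAK = 7_500_000_000
HOT_CAP = 78_000_000

flat = flat_clamp(RAW_PEAK, HOT_CAP)
layered = layered_cap(RAW_PEAK, MODEL_STATE, HOT_CAP)
```

With these numbers the flat clamp reports only the cap (78 MB) while the layered version reports model state plus the capped activations (about 7.08 GB), which is the kind of gap the repro describes.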

Extracted shared helper apply_hot_iter_cap(raw_peak, model_state_present,
measured_cap, layout) -> int into cost/memory.py. Used at three sites
to eliminate drift risk:
  - cost/memory.py::estimate_peak (replaces inline 909fc9e code)
  - search/exhaustive.py inline fast path (the Codex bug site)
  - search/exhaustive.py _cap_dominates probe (also was naive)

The _cap_dominates shortcut at line 522 was widening max_sum based on
the false claim "predicted_peak collapses to alpha * hot_cap
independent of (n_persist + n_buffer)." Tightened to use the layered
cap with n_persist=N_chunk worst-case probe — shortcut now activates
only when the layered worst-case cap fits in capacity. LoRA-shape
efficiency win preserved; full-FT shapes correctly exit the shortcut.

New regression test test_search_fast_path_cap_preserves_full_ft_model_state
asserts search()'s predicted_peak_bytes agrees with estimate_peak()'s
output and clears the alpha * model_state_present floor. Falsification
verified: temporarily reverting the inline fix reproduces the exact
78MB-vs-7.09GB Codex pattern.

## Fix 2: phase-2 measured-wall override gates on n_swap == 0

e8f45fd added the per-chunk timeline-overlap bandwidth model on the
analytical path. But cost/runtime.py:703 (forward) and :820 (backward)
have a phase-2 measured-chunked-wall override that bypasses the
chunk_bw_fwd / chunk_bw_bwd vectors entirely. The phase-2 measurement
was captured at n_swap=0 (verified at profiler/phase2.py:117), so when
an n_swap > 0 candidate is evaluated using the measured wall, the
per-chunk SWAP contention is missed — only the swap-transfer term is
charged.

Fix: AND both override conditions with cfg.n_swap == 0. Phase-2 stays
applicable in the dominant case (no SWAP, hot-iter measured walls);
n_swap > 0 candidates fall through to the analytical path that
already handles per-chunk contention correctly.

Numerics for n_swap == 0 candidates are bit-identical (the analytical
path's per-chunk derate collapses to full bandwidth when n_swap == 0).
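
A minimal sketch of the gating logic, assuming the override is a simple either/or choice between a measured and an analytical wall time. `forward_wall_seconds` and its parameters are hypothetical names, not the code in cost/runtime.py.

```python
from typing import Optional


def forward_wall_seconds(
    measured_wall: Optional[float],
    analytical_wall: float,
    n_swap: int,
) -> float:
    """Use the measured wall only when it was captured under matching conditions."""
    # The measurement was taken at n_swap == 0, so it is only trusted there.
    if measured_wall is not None and n_swap == 0:
        return measured_wall
    # Any n_swap > 0 candidate falls through to the analytical estimate,
    # which is where per-chunk contention is modelled.
    return analytical_wall
```

The same condition would be applied on the backward path; candidates without a measurement behave as before.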

New regression test test_phase2_override_routes_n_swap_through_per_chunk_contention
constructs n_swap=0 vs n_swap=N_block configs under a fully-populated
phase-2 trace, asserts the gap exceeds the swap-transfer-only lower
bound (the exact pre-fix value). Falsification verified: temporarily
reverting both gates reproduces gap == swap-transfer-only.

## Test results

41/41 pass in tests/protrain/test_cost_search.py (39 prior + 2 new).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
Five test-quality refinements from CodeRabbit's third-round review.

**R3-#2 — deterministic teardown in test_dora.**

Wrap the DoRA smoke's wrap → train → assert sequence in
``try/finally`` so ``wrapped.close()`` runs even when the
loss-descent assertion fails mid-test. Without this, an early
assertion failure leaves hooks, pinned-host borrows, and CPU
adapter threads alive into subsequent GPU tests on the same
pytest session.
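
The pattern in isolation, as a sketch with a hypothetical `FakeWrapped` stand-in (the real test wraps and trains a model first):

```python
class FakeWrapped:
    """Stand-in for a wrapped model that owns hooks and other resources."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def run_smoke(wrapped, losses):
    """Check loss descent, but always release the wrapper's resources."""
    try:
        if not losses[-1] < losses[0]:
            raise AssertionError("loss did not descend")
    finally:
        # Runs on success and on failure alike.
        wrapped.close()


wrapped = FakeWrapped()
try:
    run_smoke(wrapped, [1.0, 1.2])  # deliberately failing run
    smoke_failed = False
except AssertionError:
    smoke_failed = True
```

The assertion still propagates and fails the test; the only change is that cleanup no longer depends on reaching the end of the test body.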

**R3-#3 — distinguish hook edges in test_lora_offload_mode
recording stub.**

The pre-fix ``_RecordingScheduler.ensure_chunks_resident``
recorded every container callback under the same
``"ensure_chunks_resident"`` label. The per-hook tests
(pre_forward / post_forward / post_backward fires
``ensure_chunks_resident``) then asserted only call COUNT — so a
regression that deleted the pre-forward hook factory while
post-forward still fired would still pass the count gates.

Tag each call with its originating hook edge via frame
inspection on the caller's ``co_qualname`` (Python 3.11+
guarantees the qualname captures the enclosing
``_make_lora_container_<edge>_hook`` factory). The four LoRA
container hooks all funnel through the same
``ensure_chunks_resident`` entry point but their closures live
in distinct factory functions, so the qualname uniquely
identifies the edge.
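
A rough sketch of the tagging idea, assuming two stand-in factories named after the ``_make_lora_container_<edge>_hook`` pattern; the module-level ``calls`` list and the stub itself are hypothetical and much simpler than the real recording scheduler.

```python
import sys

calls = []


def ensure_chunks_resident():
    """Recording stub: label each call with the caller's enclosing factory."""
    code = sys._getframe(1).f_code  # frame of the hook that called us
    # co_qualname exists on Python 3.11+; older versions only expose co_name,
    # which cannot distinguish the factories (both inner functions are "hook").
    qualname = getattr(code, "co_qualname", code.co_name)
    edge = "unknown"
    for candidate in ("pre_forward", "post_forward"):
        if f"_make_lora_container_{candidate}_hook" in qualname:
            edge = candidate
    calls.append(f"ensure_chunks_resident:{edge}")


def _make_lora_container_pre_forward_hook():
    def hook():
        ensure_chunks_resident()

    return hook


def _make_lora_container_post_forward_hook():
    def hook():
        ensure_chunks_resident()

    return hook


_make_lora_container_pre_forward_hook()()
_make_lora_container_post_forward_hook()()
```

On Python 3.11+ the inner function's qualname has the form ``<factory>.<locals>.hook``, so the two calls are recorded under different labels even though both closures are called ``hook``.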

Update each per-hook test to filter on the edge-tagged label so
a regression in any single edge fails the corresponding test:

* pre_forward test: asserts ``ensure_chunks_resident:pre_forward``
  fires ≥ n_blocks times.
* post_forward test: asserts BOTH ``:pre_forward`` AND
  ``:post_forward`` fire ≥ n_containers times each (the previous
  bare ≥ 2*n_containers count was satisfied by either edge alone).
* post_backward test: asserts all four edges (pre/post fwd, pre/
  post bwd) fire ≥ n_containers times each.

The production hook factory layout is unchanged — the stub
recovers the edge from the existing closure's frame, no new
arguments thread through ``install_hooks``.

**R3-#4 — narrow protrain_model_wrapper exception scope in
test_lora_offload_mode:1117.**

The broad ``except (ValueError, RuntimeError)`` was treating any
wrapper failure as "offload setup unavailable" and skipping. A
broken ``protrain_model_wrapper`` runtime path could leave this
smoke green. Restrict the suppression to known env-failure
substrings (DeepSpeedCPUAdam JIT, CUDA version mismatch, bnb
load, ``No module named``, and capacity/searcher gates) — same
canonical tuple D8 used at the optimizer-step site below — and
re-raise anything else. Real wrapper regressions now surface.
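
The narrowing can be sketched as below. Everything here is illustrative: the substring tuple is a two-entry sample rather than the canonical one, and ``SkipTest`` stands in for pytest's skip mechanism so the block has no dependencies.

```python
# Illustrative substrings only; the real tuple lives in the test module.
KNOWN_ENV_FAILURES = ("No module named", "CUDA version mismatch")


class SkipTest(Exception):
    """Stand-in for pytest's skip signal, so the sketch has no dependencies."""


def build_or_skip(build):
    """Call build(); skip on known environment failures, re-raise the rest."""
    try:
        return build()
    except (ValueError, RuntimeError) as exc:
        if any(marker in str(exc) for marker in KNOWN_ENV_FAILURES):
            raise SkipTest(str(exc)) from exc
        raise  # a real regression: let it fail the test


def outcome(build):
    """Classify a build attempt as 'ok', 'skipped', or 'failed'."""
    try:
        build_or_skip(build)
    except SkipTest:
        return "skipped"
    except (ValueError, RuntimeError):
        return "failed"
    return "ok"


def missing_dependency():
    raise RuntimeError("No module named 'bitsandbytes'")


def wrapper_regression():
    raise ValueError("unexpected tensor shape in wrapper")
```

Here ``outcome(missing_dependency)`` is classified as a skip, while ``outcome(wrapper_regression)`` surfaces as a failure instead of being silently skipped.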

**R3-#5 — fail-safe CUDA teardown in
test_param_data_shape_preservation.**

Eight test functions in this module construct ``mgr / layout /
pool / host`` via ``_build_chunk_manager`` and tear them down at
the happy-path tail (``mgr.uninstall()`` / ``host.close()`` /
``del pool``). Any earlier assertion failure skipped the
teardown, leaking pinned-host borrows + CUDA buffer-pool state
into subsequent GPU tests.

Add a top-level ``_teardown_chunk_manager(mgr, host, pool)``
helper that does the best-effort 3-call teardown (each call
wrapped in its own try/except so a failure in ``uninstall``
doesn't block the ``host.close``), and wrap each test body in
``try: ... finally: _teardown_chunk_manager(...)``. Done
programmatically across all 8 tests via a one-shot Python
rewrite to keep the diff mechanical and the new structure
consistent.
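
A generic sketch of a best-effort teardown in this spirit. ``teardown_best_effort`` is a hypothetical helper that takes arbitrary cleanup callables; the helper described above takes ``mgr``, ``host``, and ``pool`` directly.

```python
def teardown_best_effort(*steps):
    """Run every cleanup callable; collect failures instead of raising."""
    errors = []
    for step in steps:
        try:
            step()
        except Exception as exc:  # best-effort: keep going to the next step
            errors.append(exc)
    return errors


log = []


def failing_uninstall():
    raise RuntimeError("uninstall failed")


def close_host():
    log.append("host closed")


errors = teardown_best_effort(failing_uninstall, close_host)
```

Because each step is guarded separately, the failing first step does not prevent the second one from running; the collected errors can be logged or ignored by the caller.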

**R3-#8 — replace hard-coded n_chunk_estimate=1 in
test_trace_skip_on_override.**

The trace-skip e2e test hard-coded ``n_chunk_estimate = 1`` based
on the assumption that the tiny GPT-2 fixture produces a single
chunk. If the layout heuristics (``pick_S_chunk`` default,
block-discovery rules) shift such that ``N_chunk > 1``,
``min_n_buffer_for(layout, n_persist=1)`` rejects
``n_buffer_override=0`` BEFORE the wrapper reaches the
trace-skip gate the test is supposed to validate — converting
this into a flaky non-target failure.

Compute ``n_chunk_estimate`` dynamically by running the same
``discover_blocks`` → ``flatten_block_trees`` → ``build_layout``
pipeline the wrapper itself uses (with the wrapper's default
S_chunk), and pass the resulting ``layout.N_chunk`` through.
``n_persist_override = n_chunk_estimate`` then keeps the
all-persistent invariant the test relies on regardless of any
future layout-heuristic shift.

``tests/protrain/`` default-marker sweep: 303 passed / 4 skipped
/ 0 failed. GPU-marker sweep on touched files: 40 passed /
2 skipped (single-process Mode-C downgrade for shape-preserving
placeholder paths) / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
…s (R5)

CodeRabbit R5 review (final pass on c996ce9 + f09be09) flagged Ruff
RUF002/RUF003 warnings for confusable unicode glyphs across the new
audit-Block-G commentary added by 2fcc1fc / b61f04e / aa0c6ba /
the per-dtype alpha lookup work. Same lint family R1-#5 and R2-#3
addressed in narrow scope before; this is the broader pass that
sweeps the rest of the protrain subtree.

Replacements (234 substitutions across 7 files):

| File                                                   | alpha |  x |  ∪ | Total |
|--------------------------------------------------------|------:|---:|---:|------:|
| src/axolotl/integrations/protrain/cost/memory.py       |    23 |  1 |  0 |    24 |
| src/axolotl/integrations/protrain/api/model_wrapper.py |    39 | 14 |  4 |    57 |
| src/axolotl/integrations/protrain/types.py             |    23 |  6 |  2 |    31 |
| src/axolotl/integrations/protrain/DESIGN.md            |    19 | 17 |  0 |    36 |
| tests/protrain/test_modec_steady_peak_accuracy.py      |     8 |  5 |  1 |    14 |
| tests/protrain/test_init_transient_peak.py             |     6 |  7 |  0 |    13 |
| tests/protrain/test_alpha_per_dtype.py                 |    38 |  0 |  0 |    38 |

Substitution rules:
- Greek small letter alpha (U+03B1) → ``alpha``
- Multiplication sign (U+00D7) → ``x``
- Union operator (U+222A) → ``|`` (also the Python set-union operator,
  so doubly appropriate)
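
The three rules map directly onto a translation table. A minimal sketch, assuming a plain whole-string replacement; ``asciify`` is a hypothetical name, and unlike the real sweep it makes no attempt to restrict itself to docstrings and comments.

```python
# One entry per substitution rule listed above.
CONFUSABLES = str.maketrans(
    {
        "\u03b1": "alpha",  # Greek small letter alpha
        "\u00d7": "x",      # multiplication sign
        "\u222a": "|",      # union
    }
)


def asciify(text: str) -> str:
    """Replace the three confusable glyphs; leave everything else untouched."""
    return text.translate(CONFUSABLES)
```

Characters outside the table, such as the ``→`` arrow, pass through unchanged.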

All replacements are in docstrings, comments, and pytest-parametrize
ID strings — zero changes to function names, type names, control
flow, or assertion text. ``param_to_chunk`` typed dict keys, set
literals, and any Python operator usage of ``|`` are unaffected.

Test parametrize IDs change cosmetically (e.g.
``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → α=1.10]`` ⇒
``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → alpha=1.10]``)
— the ``→`` arrow remains unchanged (Ruff doesn't flag it; CodeRabbit
flagged only ``α``/``×``/``∪`` explicitly).

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker sweep: 313 passed / 4 skipped /
  162 deselected / 0 failed.
- Ruff RUF002/RUF003 warnings across the seven touched protrain
  files: 234 → 0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 12, 2026
…est fixes

Seven Minor items from the CodeRabbit full-diff re-scan on
commit ``55377e5d``.

**F-#2 — Clarify Mode-A guidance in ``protrain_optimizer_wrapper``
8-bit warning (``api/optim_wrapper.py:802-815``).**

The warning told users to set ``protrain_force_all_persistent: true``
to get end-to-end 8-bit AdamW on CPU-resident chunks, but didn't
mention that ``protrain_force_all_persistent`` is ignored while
``protrain_auto_mode`` is on (the auto-mode selector picks the mode
itself based on capacity). Expanded the warning to instruct users
to set ``protrain_auto_mode: false`` AND
``protrain_force_all_persistent: true`` together.

**F-#4 — Unify fragmentation-alpha docs in DESIGN.md.**

Module summaries at lines 49 (``cost/memory.py``) and 118
(``memory.py`` module spec) still described a fixed ``alpha=1.10``
while Design Decision 1 documents the per-dtype lookup
(``ALPHA_FRAGMENTATION_4BIT = 0.75`` for bnb-4-bit). Aligned both
summaries to reference the per-dtype helper
(``alpha_fragmentation_for_dtype``) and the design decision section.

**F-#5 — Resolve ``use_reentrant`` contradiction in DESIGN.md.**

Line 109 (``block/checkpoint.py`` module spec) said
``use_reentrant=False``, which matches the actual implementation
(verified via ``grep`` against ``block/checkpoint.py:99``). Line 290
(audit Block G analysis) claimed ``use_reentrant=True, the
production wrap`` — stale and incorrect. Updated the analysis text
to acknowledge ``use_reentrant=False`` is the production wrap and
re-stated the per-block-input residual mechanism in a form
compatible with the non-reentrant variant (each CKPT block's
saved-tensors-hooks recompute frame holds the block input, which
is what produces the linear-in-N_block activation footprint the
audit data exposes).

**F-#8 — Centralized CUDA-availability guard in
``tests/protrain/test_adamw8bit_adapter.py::_gpu_device``.**

The helper unconditionally returned ``torch.device("cuda:0")``,
so a custom marker filter or conftest override that lands the
module in a CPU-only context would surface as a torch error
before any test body. Added a
``pytest.skip("CUDA not available; ...")`` early-return so every
gpu-marked test in the module gets a clean skip.

**F-#9 — Replace silent ``try/except: pass`` with
``contextlib.suppress(Exception)`` in
``tests/protrain/test_lora_offload_mode.py``.**

Five sites — lines 742-746, 839-843, 906-910, 981-985, 1040-1044
— each had the same ``for h in handles: try: h.remove() except
Exception: pass`` pattern that Ruff S110 flags. Replaced with
``contextlib.suppress(Exception)`` over the loop. Semantics
unchanged (best-effort cleanup, tolerate already-removed handles
or torch shutting down mid-test); intent now documented by the
context manager.
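
Where the context manager sits relative to the loop affects behaviour, so the two placements are sketched side by side. ``Handle`` and both function names are hypothetical, and the commit message does not show which placement the test file uses.

```python
import contextlib


class Handle:
    """Stand-in for a hook handle whose remove() may fail."""

    def __init__(self, fail=False):
        self.fail = fail
        self.removed = False

    def remove(self):
        if self.fail:
            raise RuntimeError("already removed")
        self.removed = True


def remove_all_per_handle(handles):
    # suppress inside the loop: one failure does not stop later removals,
    # matching a per-iteration try/except: pass.
    for handle in handles:
        with contextlib.suppress(Exception):
            handle.remove()


def remove_all_whole_loop(handles):
    # suppress around the loop: the first failure ends the loop silently.
    with contextlib.suppress(Exception):
        for handle in handles:
            handle.remove()


per_handle = [Handle(), Handle(fail=True), Handle()]
whole_loop = [Handle(), Handle(fail=True), Handle()]
remove_all_per_handle(per_handle)
remove_all_whole_loop(whole_loop)
```

With a failing handle in the middle, the per-handle form still removes the last handle while the whole-loop form leaves it in place, so only the former is a strict drop-in for a ``try``/``except`` inside the loop.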

**F-#10 — ASCII ``x`` in ``test_lora_offload_mode.py:1062`` docstring.**

Missed in the R5 unicode sweep — ``4×3090`` ⇒ ``4x3090``.

**F-#11 — ``try/finally`` for ``wrapped.close()`` in 3 sites of
``test_trace_skip_on_override.py``.**

``test_run_trace_skipped_on_override_full_path`` (L255-282),
``test_run_trace_invoked_without_override`` (L319-337), and
``test_partial_overrides_do_not_skip_trace`` (L381-400) each
called ``wrapped.close()`` only on the success path — assertion
failures earlier in the test body would skip the close and leak
CUDA + chunk resources into subsequent GPU tests. Wrapped each
test body in ``try/finally`` so ``wrapped.close()`` always
runs. Done programmatically via a one-shot Python rewrite
(8 lines of new indent + 2 lines of try/finally per site) to
keep the diff mechanical.

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker: 313 passed / 4 skipped /
  162 deselected / 0 failed.
- GPU sanity on F-touched files (GPU 5): 43 passed / 2 skipped /
  0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 28, 2026
thad0ctor added a commit that referenced this pull request May 28, 2026
thad0ctor added a commit that referenced this pull request May 28, 2026
…s (R5)

CodeRabbit R5 review (final pass on c996ce9 + f09be09) flagged Ruff
RUF002/RUF003 warnings for confusable unicode glyphs across the new
audit-Block-G commentary added by 2fcc1fc / b61f04e / aa0c6ba /
the per-dtype alpha lookup work. Same lint family R1-#5 and R2-#3
addressed in narrow scope before; this is the broader pass that
sweeps the rest of the protrain subtree.

Replacements (234 substitutions across 7 files):

| File                                                   | alpha (U+03B1) | x (U+00D7) | ∪ (U+222A) | Total |
|--------------------------------------------------------|---------------:|-----------:|-----------:|------:|
| src/axolotl/integrations/protrain/cost/memory.py       |             23 |          1 |          0 |    24 |
| src/axolotl/integrations/protrain/api/model_wrapper.py |             39 |         14 |          4 |    57 |
| src/axolotl/integrations/protrain/types.py             |             23 |          6 |          2 |    31 |
| src/axolotl/integrations/protrain/DESIGN.md            |             19 |         17 |          0 |    36 |
| tests/protrain/test_modec_steady_peak_accuracy.py      |              8 |          5 |          1 |    14 |
| tests/protrain/test_init_transient_peak.py             |              6 |          7 |          0 |    13 |
| tests/protrain/test_alpha_per_dtype.py                 |             38 |          0 |          0 |    38 |

Substitution rules:
- Greek small letter alpha (U+03B1) → ``alpha``
- Multiplication sign (U+00D7) → ``x``
- Union operator (U+222A) → ``|`` (also the Python set-union operator,
  so doubly appropriate)
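
The three rules can be sketched as a single translation table. This is an
illustrative helper (``asciify`` and ``CONFUSABLES`` are hypothetical names),
and unlike the sweep described here it rewrites every occurrence in a string
rather than only docstrings, comments, and parametrize IDs.

```python
# Map each flagged code point to its ASCII replacement.
CONFUSABLES = str.maketrans(
    {
        "\u03b1": "alpha",  # Greek small letter alpha
        "\u00d7": "x",      # multiplication sign
        "\u222a": "|",      # union operator
    }
)


def asciify(text: str) -> str:
    """Replace the three confusable glyphs; leave everything else as-is."""
    return text.translate(CONFUSABLES)
```

Characters outside the table, such as the rightwards arrow (U+2192), pass
through untouched.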

All replacements are in docstrings, comments, and pytest-parametrize
ID strings — zero changes to function names, type names, control
flow, or assertion text. ``param_to_chunk`` typed dict keys, set
literals, and any Python operator usage of ``|`` are unaffected.

Test parametrize IDs change cosmetically (e.g.
``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → α=1.10]`` ⇒
``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → alpha=1.10]``)
— the ``→`` arrow remains unchanged (Ruff doesn't flag it; CodeRabbit
flagged only ``α``/``×``/``∪`` explicitly).

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker sweep: 313 passed / 4 skipped /
  162 deselected / 0 failed.
- Ruff RUF002/RUF003 warnings across the seven touched protrain
  files: 234 → 0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit that referenced this pull request May 28, 2026
…est fixes

Seven Minor items from the CodeRabbit full-diff re-scan on
commit ``55377e5d``.

**F-#2 — Clarify Mode-A guidance in ``protrain_optimizer_wrapper``
8-bit warning (``api/optim_wrapper.py:802-815``).**

The warning told users to set ``protrain_force_all_persistent: true``
to get end-to-end 8-bit AdamW on CPU-resident chunks, but didn't
mention that ``protrain_force_all_persistent`` is ignored while
``protrain_auto_mode`` is on (the auto-mode selector picks the mode
itself based on capacity). Expanded the warning to instruct users
to set ``protrain_auto_mode: false`` AND
``protrain_force_all_persistent: true`` together.

**F-#4 — Unify fragmentation-alpha docs in DESIGN.md.**

Module summaries at lines 49 (``cost/memory.py``) and 118
(``memory.py`` module spec) still described a fixed ``alpha=1.10``
while Design Decision 1 documents the per-dtype lookup
(``ALPHA_FRAGMENTATION_4BIT = 0.75`` for bnb-4-bit). Aligned both
summaries to reference the per-dtype helper
(``alpha_fragmentation_for_dtype``) and the design decision section.

**F-#5 — Resolve ``use_reentrant`` contradiction in DESIGN.md.**

Line 109 (``block/checkpoint.py`` module spec) said
``use_reentrant=False``, which matches the actual implementation
(verified via ``grep`` against ``block/checkpoint.py:99``). Line 290
(audit Block G analysis) claimed ``use_reentrant=True, the
production wrap`` — stale and incorrect. Updated the analysis text
to acknowledge ``use_reentrant=False`` is the production wrap and
re-stated the per-block-input residual mechanism in a form
compatible with the non-reentrant variant (each CKPT block's
saved-tensors-hooks recompute frame holds the block input, which
is what produces the linear-in-N_block activation footprint the
audit data exposes).

**F-#8 — Centralized CUDA-availability guard in
``tests/protrain/test_adamw8bit_adapter.py::_gpu_device``.**

The helper unconditionally returned ``torch.device("cuda:0")``,
so a custom marker filter or conftest override that lands the
module in a CPU-only context would surface as a torch error
before any test body ran. Added an early
``pytest.skip("CUDA not available; ...")`` call in the helper so
every gpu-marked test in the module gets a clean skip.

**F-#9 — Replace silent ``try/except: pass`` with
``contextlib.suppress(Exception)`` in
``tests/protrain/test_lora_offload_mode.py``.**

Five sites — lines 742-746, 839-843, 906-910, 981-985, 1040-1044
— each had the same ``for h in handles: try: h.remove() except
Exception: pass`` pattern that Ruff S110 flags. Replaced with
``contextlib.suppress(Exception)`` over the loop. Semantics
unchanged (best-effort cleanup, tolerate already-removed handles
or torch shutting down mid-test); intent now documented by the
context manager.
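
Where the context manager sits matters for best-effort cleanup. The sketch
below uses a hypothetical ``FakeHandle`` class (not the torch hook handle) to
contrast the two placements: suppressing around the whole loop stops at the
first exception, while suppressing per iteration keeps the behaviour of the
original per-handle ``try/except``.

```python
import contextlib


class FakeHandle:
    """Hypothetical stand-in for a removable hook handle."""

    def __init__(self, fail=False):
        self.fail = fail
        self.removed = False

    def remove(self):
        if self.fail:
            raise RuntimeError("handle already removed")
        self.removed = True


def remove_all_outer(handles):
    """Suppress around the loop: the first failure ends the loop."""
    with contextlib.suppress(Exception):
        for handle in handles:
            handle.remove()


def remove_all_inner(handles):
    """Suppress per handle: later handles are still removed."""
    for handle in handles:
        with contextlib.suppress(Exception):
            handle.remove()
```

With three handles where the second one raises, ``remove_all_outer`` leaves
the third handle in place and ``remove_all_inner`` removes it.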

**F-#10 — ASCII ``x`` in ``test_lora_offload_mode.py:1062`` docstring.**

Missed in the R5 unicode sweep — ``4×3090`` ⇒ ``4x3090``.

**F-#11 — ``try/finally`` for ``wrapped.close()`` in 3 sites of
``test_trace_skip_on_override.py``.**

``test_run_trace_skipped_on_override_full_path`` (L255-282),
``test_run_trace_invoked_without_override`` (L319-337), and
``test_partial_overrides_do_not_skip_trace`` (L381-400) each
called ``wrapped.close()`` only on the success path — assertion
failures earlier in the test body would skip the close and leak
CUDA + chunk resources into subsequent GPU tests. Wrapped each
test body in ``try/finally`` so ``wrapped.close()`` always
runs. Done programmatically via a one-shot Python rewrite
(8 lines of new indent + 2 lines of try/finally per site) to
keep the diff mechanical.

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format /
  mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker: 313 passed / 4 skipped /
  162 deselected / 0 failed.
- GPU sanity on F-touched files (GPU 5): 43 passed / 2 skipped /
  0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>