fix: allow kwargs-only forward on PEFT ModulesToSaveWrapper #5
PEFT's ``AuxiliaryTrainingWrapper.forward`` requires a positional ``x``:

    def forward(self, x: torch.Tensor, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        ...
        return self._forward_wrapped(x, *args, **kwargs)

This works for modules called positionally (e.g. ``embed_tokens``) but
crashes when a ``lora_modules_to_save`` entry is invoked with keyword
arguments only. Gemma 4's VLM forward path does this for both
``vision_tower`` and ``embed_vision``:

    vision_outputs = self.vision_tower(pixel_values=..., ...)
    ... = self.embed_vision(inputs_embeds=...)

Producing:

    TypeError: AuxiliaryTrainingWrapper.forward() missing 1 required
    positional argument: 'x'

at the first forward pass (e.g. eval step 0).
Add an idempotent monkeypatch that rewrites ``forward``,
``_forward_wrapped``, and ``_forward_wrapped_passthrough`` to accept
``*args, **kwargs``. ``_mixed_batch_forward`` is untouched — it needs
positional input for sub-batch indexing and only runs when
``adapter_names`` is passed (multi-adapter inference), which doesn't
happen in single-adapter training. ``_check_forward_args`` is
short-circuited when no positional args are provided; its only work is
validating ``len(x) == len(adapter_names)`` which is itself gated on
``adapter_names`` being set.
Applied from ``PatchManager._apply_adapter_patches`` (gated on
``cfg.adapter``) before ``get_peft_model`` runs, so every wrapper
constructed during the training run gets the fixed forward. Backward-compatible:
positional calls still work (verified in tests against real PEFT classes).
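
The idea of an idempotent patch that relaxes a positional-only ``forward`` can be sketched with a toy class. This is a minimal illustration, not the code in ``peft_modules_to_save.py``: ``Wrapper``, ``patch_wrapper_forward``, the ``_kwargs_forward_patched`` marker, and the two stand-in modules are all hypothetical names, and no PEFT or torch classes are involved.

```python
class Wrapper:
    """Toy stand-in for a wrapper whose forward requires a positional ``x``."""

    def __init__(self, module):
        self.module = module

    def forward(self, x, *args, **kwargs):
        return self.module(x, *args, **kwargs)


_PATCH_FLAG = "_kwargs_forward_patched"  # hypothetical marker attribute


def patch_wrapper_forward(cls=Wrapper):
    """Rewrite ``cls.forward`` to accept ``*args, **kwargs``.

    Returns True when the patch was installed and False when it was
    already present, so calling it repeatedly is harmless.
    """
    if getattr(cls, _PATCH_FLAG, False):
        return False

    def forward(self, *args, **kwargs):
        # Pass everything through unchanged; positional callers still work.
        return self.module(*args, **kwargs)

    cls.forward = forward
    setattr(cls, _PATCH_FLAG, True)
    return True


def vision_tower(*, pixel_values):
    """Stand-in for a module that is only ever called with keywords."""
    return [p * 2 for p in pixel_values]


def embed_tokens(x):
    """Stand-in for a module that is called positionally."""
    return [t + 1 for t in x]
```

Because the patch is installed on the class, instances created before or after the call pick up the relaxed signature, and a second call is a no-op.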
Reproduces with:

    adapter: lora
    lora_modules_to_save:
      - vision_tower  # or any module called with keyword args

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Walkthrough
This pull request introduces a new monkeypatch for PEFT's module wrapper to support keyword-argument-based forward calls alongside positional arguments. The patch is installed when adapter support is enabled and includes comprehensive test coverage for various forward signature patterns and adapter states.
Estimated code review effort: 3 (Moderate), ~30 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
Nitpick comments (1)
src/axolotl/monkeypatch/peft_modules_to_save.py (1)
46-50: Consider patching `AuxiliaryTrainingWrapper._forward_wrapped` as well for completeness.

The patch sets `forward` on `AuxiliaryTrainingWrapper` but only patches `_forward_wrapped` and `_forward_wrapped_passthrough` on `ModulesToSaveWrapper`. If `AuxiliaryTrainingWrapper` is ever instantiated directly (rather than through its subclass `ModulesToSaveWrapper`), it would use the patched `forward` but call its unpatched `_forward_wrapped`, potentially breaking kwargs-only forwarding.

This may be intentional if PEFT only instantiates `ModulesToSaveWrapper` for `modules_to_save` entries. However, for robustness, consider also patching `AuxiliaryTrainingWrapper._forward_wrapped` and `_forward_wrapped_passthrough`.

Proposed patch for completeness

```diff
 AuxiliaryTrainingWrapper.forward = _patched_forward
+AuxiliaryTrainingWrapper._forward_wrapped = _patched_forward_wrapped
+AuxiliaryTrainingWrapper._forward_wrapped_passthrough = (
+    _patched_forward_wrapped_passthrough
+)
 ModulesToSaveWrapper._forward_wrapped = _patched_forward_wrapped
 ModulesToSaveWrapper._forward_wrapped_passthrough = (
     _patched_forward_wrapped_passthrough
 )
```

Prompt for AI agents

Verify each finding against the current code and only fix it if needed. In `@src/axolotl/monkeypatch/peft_modules_to_save.py` around lines 46 - 50, the patch currently replaces AuxiliaryTrainingWrapper.forward and sets _forward_wrapped and _forward_wrapped_passthrough only on ModulesToSaveWrapper; update AuxiliaryTrainingWrapper too so direct instantiation won't call the old unpatched helpers: assign AuxiliaryTrainingWrapper._forward_wrapped = _patched_forward_wrapped and AuxiliaryTrainingWrapper._forward_wrapped_passthrough = _patched_forward_wrapped_passthrough (similar to how AuxiliaryTrainingWrapper.forward = _patched_forward is set) so both helper methods are consistently patched alongside forward.
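
The attribute-lookup behavior behind this nitpick can be shown with a toy hierarchy. The sketch below makes no claim about how PEFT actually defines these helpers: `Base`, `Child`, `Sibling`, and `_patched` are hypothetical stand-ins that only demonstrate that assigning a method on a subclass leaves the base class, and any sibling that inherits from it, unpatched.

```python
class Base:
    """Toy base class: forward is already kwargs-friendly, the helper is not."""

    def forward(self, *args, **kwargs):
        return self._forward_wrapped(*args, **kwargs)

    def _forward_wrapped(self, x):
        return ("base-old", x)


class Child(Base):
    """Overrides the helper, like a subclass with its own implementation."""

    def _forward_wrapped(self, x):
        return ("child-old", x)


class Sibling(Base):
    """Inherits the helper from Base without overriding it."""


def _patched(self, *args, **kwargs):
    return ("patched", args, kwargs)


def patch_subclass_only():
    # Only Child's own attribute changes; Base and Sibling are untouched.
    Child._forward_wrapped = _patched


def patch_base_too():
    # Sibling now resolves the helper to the patched function as well.
    Base._forward_wrapped = _patched
```

After `patch_subclass_only()`, a kwargs-only call works on `Child` but still raises `TypeError` on `Sibling`; calling `patch_base_too()` closes that gap. Whether the gap matters in practice depends on which classes are actually instantiated.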
Files selected for processing (3)
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/peft_modules_to_save.py
- tests/monkeypatch/test_peft_modules_to_save.py
Two related fixes landing together, both Codex-flagged after the previous round closed gaps #1 (persistent peak), #2 (T_reduce), and #5 (per-chunk contention). Bundled because both touch the cost-model search path and add regression tests to the same file.

## Fix 1: searcher fast-path applies cap layering (closes #1 fully)

909fc9e fixed cost/memory.py::estimate_peak to layer the hot_iter_peak_cap (cap only the activation portion, preserve model_state_present). The same bug existed in the search/exhaustive.py:649 inline fast path: it still applied the cap as a flat clamp, erasing the Adam state contribution. Codex synthetic repro: search() returned predicted_peak_bytes=78MB while estimate_peak() for the same config returned 7.09GB, a 90× divergence on full-FT shapes.

Extracted shared helper apply_hot_iter_cap(raw_peak, model_state_present, measured_cap, layout) -> int into cost/memory.py. Used at three sites to eliminate drift risk:

- cost/memory.py::estimate_peak (replaces inline 909fc9e code)
- search/exhaustive.py inline fast path (the Codex bug site)
- search/exhaustive.py _cap_dominates probe (also was naive)

The _cap_dominates shortcut at line 522 was widening max_sum based on the false claim "predicted_peak collapses to alpha * hot_cap independent of (n_persist + n_buffer)." Tightened to use the layered cap with n_persist=N_chunk worst-case probe; the shortcut now activates only when the layered worst-case cap fits in capacity. LoRA-shape efficiency win preserved; full-FT shapes correctly exit the shortcut.

New regression test test_search_fast_path_cap_preserves_full_ft_model_state asserts search()'s predicted_peak_bytes agrees with estimate_peak()'s output and clears the alpha * model_state_present floor. Falsification verified: temporarily reverting the inline fix reproduces the exact 74MB-vs-7GB Codex pattern.

## Fix 2: phase-2 measured-wall override gates on n_swap == 0

e8f45fd added the per-chunk timeline-overlap bandwidth model on the analytical path. But cost/runtime.py:703 (forward) and :820 (backward) have a phase-2 measured-chunked-wall override that bypasses the chunk_bw_fwd / chunk_bw_bwd vectors entirely. The phase-2 measurement was captured at n_swap=0 (verified at profiler/phase2.py:117), so when an n_swap > 0 candidate is evaluated using the measured wall, the per-chunk SWAP contention is missed; only the swap-transfer term is charged.

Fix: AND both override conditions with cfg.n_swap == 0. Phase-2 stays applicable in the dominant case (no SWAP, hot-iter measured walls); n_swap > 0 candidates fall through to the analytical path that already handles per-chunk contention correctly. Numerics for n_swap == 0 candidates are bit-identical (the analytical path's per-chunk derate collapses to full bandwidth when n_swap == 0).

New regression test test_phase2_override_routes_n_swap_through_per_chunk_contention constructs n_swap=0 vs n_swap=N_block configs under a fully-populated phase-2 trace, asserts the gap exceeds the swap-transfer-only lower bound (the exact pre-fix value). Falsification verified: temporarily reverting both gates reproduces gap == swap-transfer-only.

## Test results

41/41 pass in tests/protrain/test_cost_search.py (39 prior + 2 new).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
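
The difference between a flat clamp and the layered cap described in Fix 1 can be sketched as follows. This is only an illustration under assumptions: the function names, the omission of the ``layout`` argument and of any alpha fragmentation factor, and the byte values are hypothetical and are not taken from ``cost/memory.py``.

```python
MB = 1 << 20  # bytes per mebibyte, used only to keep the numbers readable


def flat_clamp(raw_peak: int, measured_cap: int) -> int:
    """Clamp the whole predicted peak to the measured cap.

    This erases the model-state share whenever the cap is smaller than it.
    """
    return min(raw_peak, measured_cap)


def layered_cap(raw_peak: int, model_state_present: int, measured_cap: int) -> int:
    """Cap only the activation portion and keep model state intact.

    Assumes raw_peak is model state plus activations; anything that is not
    model state is treated as the activation portion.
    """
    activation = max(raw_peak - model_state_present, 0)
    return model_state_present + min(activation, measured_cap)


# Illustrative numbers only: a large resident model state with a small
# measured activation cap.
model_state = 7000 * MB
raw = model_state + 500 * MB
cap = 80 * MB

flat = flat_clamp(raw, cap)
layered = layered_cap(raw, model_state, cap)
```

With these inputs the flat clamp reports only the cap, while the layered version never drops below the model-state floor, which is the property the regression test in the commit message checks for.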
Five test-quality refinements from CodeRabbit's third-round review.

**R3-#2: deterministic teardown in test_dora.** Wrap the DoRA smoke's wrap → train → assert sequence in ``try/finally`` so ``wrapped.close()`` runs even when the loss-descent assertion fails mid-test. Without this, an early assertion failure leaves hooks, pinned-host borrows, and CPU adapter threads alive into subsequent GPU tests on the same pytest session.

**R3-#3: distinguish hook edges in test_lora_offload_mode recording stub.** The pre-fix ``_RecordingScheduler.ensure_chunks_resident`` recorded every container callback under the same ``"ensure_chunks_resident"`` label. The per-hook tests (pre_forward / post_forward / post_backward fires ``ensure_chunks_resident``) then asserted only call COUNT, so a regression that deleted the pre-forward hook factory while post-forward still fired would still pass the count gates.

Tag each call with its originating hook edge via frame inspection on the caller's ``co_qualname`` (Python 3.11+ guarantees the qualname captures the enclosing ``_make_lora_container_<edge>_hook`` factory). The four LoRA container hooks all funnel through the same ``ensure_chunks_resident`` entry point but their closures live in distinct factory functions, so the qualname uniquely identifies the edge.

Update each per-hook test to filter on the edge-tagged label so a regression in any single edge fails the corresponding test:

* pre_forward test: asserts ``ensure_chunks_resident:pre_forward`` fires ≥ n_blocks times.
* post_forward test: asserts BOTH ``:pre_forward`` AND ``:post_forward`` fire ≥ n_containers times each (the previous bare ≥ 2*n_containers count was satisfied by either edge alone).
* post_backward test: asserts all four edges (pre/post fwd, pre/post bwd) fire ≥ n_containers times each.

The production hook factory layout is unchanged: the stub recovers the edge from the existing closure's frame, and no new arguments thread through ``install_hooks``.

**R3-#4: narrow protrain_model_wrapper exception scope in test_lora_offload_mode:1117.** The bare ``except (ValueError, RuntimeError)`` was treating any wrapper failure as "offload setup unavailable" and skipping. A broken ``protrain_model_wrapper`` runtime path could leave this smoke green. Restrict the suppression to known env-failure substrings (DeepSpeedCPUAdam JIT, CUDA version mismatch, bnb load, ``No module named``, and capacity/searcher gates), the same canonical tuple D8 used at the optimizer-step site below, and re-raise anything else. Real wrapper regressions now surface.

**R3-#5: fail-safe CUDA teardown in test_param_data_shape_preservation.** Eight test functions in this module construct ``mgr / layout / pool / host`` via ``_build_chunk_manager`` and tear them down at the happy-path tail (``mgr.uninstall()`` / ``host.close()`` / ``del pool``). Any earlier assertion failure skipped the teardown, leaking pinned-host borrows + CUDA buffer-pool state into subsequent GPU tests. Add a top-level ``_teardown_chunk_manager(mgr, host, pool)`` helper that does the best-effort 3-call teardown (each call wrapped in its own try/except so a failure in ``uninstall`` doesn't block the ``host.close``), and wrap each test body in ``try: ... finally: _teardown_chunk_manager(...)``. Done programmatically across all 8 tests via a one-shot Python rewrite to keep the diff mechanical and the new structure consistent.

**R3-#8: replace hard-coded n_chunk_estimate=1 in test_trace_skip_on_override.** The trace-skip e2e test hard-coded ``n_chunk_estimate = 1`` based on the assumption that the tiny GPT-2 fixture produces a single chunk. If the layout heuristics (``pick_S_chunk`` default, block-discovery rules) shift such that ``N_chunk > 1``, ``min_n_buffer_for(layout, n_persist=1)`` rejects ``n_buffer_override=0`` BEFORE the wrapper reaches the trace-skip gate the test is supposed to validate, converting this into a flaky non-target failure. Compute ``n_chunk_estimate`` dynamically by running the same ``discover_blocks`` → ``flatten_block_trees`` → ``build_layout`` pipeline the wrapper itself uses (with the wrapper's default S_chunk), and pass the resulting ``layout.N_chunk`` through. ``n_persist_override = n_chunk_estimate`` then keeps the all-persistent invariant the test relies on regardless of any future layout-heuristic shift.

``tests/protrain/`` default-marker sweep: 303 passed / 4 skipped / 0 failed.
GPU-marker sweep on touched files: 40 passed / 2 skipped (single-process Mode-C downgrade for shape-preserving placeholder paths) / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
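
The teardown pattern from R3-#2 and R3-#5 can be sketched in isolation. The helper names below (``best_effort_teardown``, ``run_with_teardown``) are hypothetical and the steps are plain callables rather than the real chunk manager, host, or pool objects; the point is only that each step is isolated and that ``finally`` runs them even when the test body fails.

```python
def best_effort_teardown(*steps):
    """Run every teardown step, collecting errors instead of stopping early."""
    errors = []
    for step in steps:
        try:
            step()
        except Exception as exc:  # one failing step must not block the rest
            errors.append(exc)
    return errors


def run_with_teardown(body, *steps):
    """Run a test body and always tear down, even if the body raises."""
    try:
        return body()
    finally:
        best_effort_teardown(*steps)


calls = []


def uninstall():
    calls.append("uninstall")
    raise RuntimeError("hooks already removed")


def close():
    calls.append("close")


def failing_body():
    raise AssertionError("loss did not descend")
```

Running ``run_with_teardown(failing_body, uninstall, close)`` still propagates the ``AssertionError``, but both teardown steps are attempted first, so later tests in the same session start from a clean state.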
…s (R5)

CodeRabbit R5 review (final pass on c996ce9 + f09be09) flagged Ruff RUF002/RUF003 warnings for confusable unicode glyphs across the new audit-Block-G commentary added by 2fcc1fc / b61f04e / aa0c6ba / the per-dtype alpha lookup work. Same lint family R1-#5 and R2-#3 addressed in narrow scope before; this is the broader pass that sweeps the rest of the protrain subtree.

Replacements (234 substitutions across 7 files):

| File | alpha | x | ∪ | Total |
|--------------------------------------------------------|------:|----:|--:|------:|
| src/axolotl/integrations/protrain/cost/memory.py | 23 | 1 | 0 | 24 |
| src/axolotl/integrations/protrain/api/model_wrapper.py | 39 | 14 | 4 | 57 |
| src/axolotl/integrations/protrain/types.py | 23 | 6 | 2 | 31 |
| src/axolotl/integrations/protrain/DESIGN.md | 19 | 17 | 0 | 36 |
| tests/protrain/test_modec_steady_peak_accuracy.py | 8 | 5 | 1 | 14 |
| tests/protrain/test_init_transient_peak.py | 6 | 7 | 0 | 13 |
| tests/protrain/test_alpha_per_dtype.py | 38 | 0 | 0 | 38 |

Substitution rules:

- Greek small letter alpha (U+03B1) → ``alpha``
- Multiplication sign (U+00D7) → ``x``
- Union operator (U+222A) → ``|`` (also the Python set-union operator, so doubly appropriate)

All replacements are in docstrings, comments, and pytest-parametrize ID strings, with zero changes to function names, type names, control flow, or assertion text. ``param_to_chunk`` typed dict keys, set literals, and any Python operator usage of ``|`` are unaffected. Test parametrize IDs change cosmetically (e.g. ``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → α=1.10]`` ⇒ ``test_alpha_lookup_by_dtype[2.0-1.1-fp16/bf16 weights → alpha=1.10]``); the ``→`` arrow remains unchanged (Ruff doesn't flag it; CodeRabbit flagged only ``α``/``×``/``∪`` explicitly).

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format / mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker sweep: 313 passed / 4 skipped / 162 deselected / 0 failed.
- Ruff RUF002/RUF003 warnings across the seven touched protrain files: 234 → 0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
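
The three substitution rules listed above map directly onto ``str.translate``. The sketch below is an assumption about how such a sweep could be written, not the script that was actually used; ``CONFUSABLES`` and ``ascii_sweep`` are hypothetical names, and unlike the real change it rewrites any text it is given rather than only comments, docstrings, and parametrize IDs.

```python
# Mapping taken from the substitution rules in the commit message.
CONFUSABLES = {
    "\u03b1": "alpha",  # Greek small letter alpha
    "\u00d7": "x",      # multiplication sign
    "\u222a": "|",      # union operator
}

# str.maketrans accepts a dict of single characters to replacement strings.
_TABLE = str.maketrans(CONFUSABLES)


def ascii_sweep(text: str) -> tuple[str, int]:
    """Return the rewritten text and the number of glyphs replaced."""
    count = sum(text.count(glyph) for glyph in CONFUSABLES)
    return text.translate(_TABLE), count


swept, n = ascii_sweep("\u03b1=1.10 \u00d7 2 \u222a s \u2192 t")
```

The right-arrow (U+2192) is not in the mapping, so it passes through untouched, matching the note that the arrow in the parametrize IDs stays as-is.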
…est fixes

Seven Minor items from the CodeRabbit full-diff re-scan on commit ``55377e5d``.

**F-#2: Clarify Mode-A guidance in ``protrain_optimizer_wrapper`` 8-bit warning (``api/optim_wrapper.py:802-815``).** The warning told users to set ``protrain_force_all_persistent: true`` to get end-to-end 8-bit AdamW on CPU-resident chunks, but didn't mention that ``protrain_force_all_persistent`` is ignored while ``protrain_auto_mode`` is on (the auto-mode selector picks the mode itself based on capacity). Expanded the warning to instruct users to set ``protrain_auto_mode: false`` AND ``protrain_force_all_persistent: true`` together.

**F-#4: Unify fragmentation-alpha docs in DESIGN.md.** Module summaries at lines 49 (``cost/memory.py``) and 118 (``memory.py`` module spec) still described a fixed ``alpha=1.10`` while Design Decision 1 documents the per-dtype lookup (``ALPHA_FRAGMENTATION_4BIT = 0.75`` for bnb-4-bit). Aligned both summaries to reference the per-dtype helper (``alpha_fragmentation_for_dtype``) and the design decision section.

**F-#5: Resolve ``use_reentrant`` contradiction in DESIGN.md.** Line 109 (``block/checkpoint.py`` module spec) said ``use_reentrant=False``, which matches the actual implementation (verified via ``grep`` against ``block/checkpoint.py:99``). Line 290 (audit Block G analysis) claimed ``use_reentrant=True, the production wrap``, which was stale and incorrect. Updated the analysis text to acknowledge ``use_reentrant=False`` is the production wrap and re-stated the per-block-input residual mechanism in a form compatible with the non-reentrant variant (each CKPT block's saved-tensors-hooks recompute frame holds the block input, which is what produces the linear-in-N_block activation footprint the audit data exposes).

**F-#8: Centralized CUDA-availability guard in ``tests/protrain/test_adamw8bit_adapter.py::_gpu_device``.** The helper unconditionally returned ``torch.device("cuda:0")``, so a custom marker filter or conftest override that lands the module in a CPU-only context would surface as a torch error before any test body. Added a ``pytest.skip("CUDA not available; ...")`` early-return so every gpu-marked test in the module gets a clean skip.

**F-#9: Replace silent ``try/except: pass`` with ``contextlib.suppress(Exception)`` in ``tests/protrain/test_lora_offload_mode.py``.** Five sites (lines 742-746, 839-843, 906-910, 981-985, 1040-1044) each had the same ``for h in handles: try: h.remove() except Exception: pass`` pattern that Ruff S110 flags. Replaced with ``contextlib.suppress(Exception)`` over the loop. Semantics unchanged (best-effort cleanup, tolerate already-removed handles or torch shutting down mid-test); intent now documented by the context manager.

**F-#10: ASCII ``x`` in ``test_lora_offload_mode.py:1062`` docstring.** Missed in the R5 unicode sweep: ``4×3090`` ⇒ ``4x3090``.

**F-#11: ``try/finally`` for ``wrapped.close()`` in 3 sites of ``test_trace_skip_on_override.py``.** ``test_run_trace_skipped_on_override_full_path`` (L255-282), ``test_run_trace_invoked_without_override`` (L319-337), and ``test_partial_overrides_do_not_skip_trace`` (L381-400) each called ``wrapped.close()`` only on the success path, so assertion failures earlier in the test body would skip the close and leak CUDA + chunk resources into subsequent GPU tests. Wrapped each test body in ``try/finally`` so ``wrapped.close()`` always runs. Done programmatically via a one-shot Python rewrite (8 lines of new indent + 2 lines of try/finally per site) to keep the diff mechanical.

### Test gates

- ``pre-commit run --all-files`` ALL green (ruff / ruff-format / mypy / bandit / yaml / eol / whitespace).
- ``tests/protrain/`` default-marker: 313 passed / 4 skipped / 162 deselected / 0 failed.
- GPU sanity on F-touched files (GPU 5): 43 passed / 2 skipped / 0 failed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
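
The F-#9 cleanup pattern can be sketched with a stand-in handle class. ``Handle`` and ``remove_all`` are hypothetical names used only for illustration; the real handles in the test module are hook handles returned by torch. This sketch assumes the per-handle form, with the context manager inside the loop body.

```python
import contextlib


class Handle:
    """Toy hook handle whose removal may fail (e.g. already removed)."""

    def __init__(self, fail: bool = False):
        self.fail = fail
        self.removed = False

    def remove(self):
        if self.fail:
            raise RuntimeError("handle already removed")
        self.removed = True


def remove_all(handles):
    """Best-effort cleanup: a failing handle does not stop the others."""
    for h in handles:
        # Equivalent to try: h.remove() / except Exception: pass,
        # but the intent is spelled out by the context manager.
        with contextlib.suppress(Exception):
            h.remove()


handles = [Handle(), Handle(fail=True), Handle()]
remove_all(handles)
```

Placing ``contextlib.suppress`` inside the loop keeps each removal independent. A single ``with`` wrapped around the whole ``for`` loop would also swallow the exception, but it would exit the loop at the first failure and leave later handles in place.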
Summary
PEFT's `AuxiliaryTrainingWrapper.forward` requires a positional `x`, so placing a kwargs-only-called module (e.g. Gemma 4's `vision_tower` / `embed_vision`) in `lora_modules_to_save` crashes on the first forward pass with `TypeError: ... missing 1 required positional argument: 'x'`.

This adds an idempotent monkeypatch that rewrites `forward`, `_forward_wrapped`, and `_forward_wrapped_passthrough` on `AuxiliaryTrainingWrapper` / `ModulesToSaveWrapper` to accept `*args, **kwargs`, backward-compatible with existing positional callers (e.g. `embed_tokens`).

- Applied from `PatchManager._apply_adapter_patches()` (gated on `cfg.adapter`) so it runs before `get_peft_model`.
- `_mixed_batch_forward` intentionally left positional: it requires sub-batch indexing and only fires under multi-adapter `adapter_names=...` calls, which don't happen in single-adapter training.
- `_check_forward_args` is short-circuited when no positional input is provided (its only real work validates `len(x) == len(adapter_names)`, which itself requires `adapter_names` to be set).

Reproduces
With `adapter: lora` and a `lora_modules_to_save` entry such as `vision_tower` (or any module called with keyword args).

Test plan
- `pytest tests/monkeypatch/test_peft_modules_to_save.py`: 7/7 passing
- `ruff check` + `ruff format --check` clean on touched files
- Kwargs-only call (`vision_tower` / `embed_vision` shape)
- Positional call (`embed_tokens` shape)
- Disabled-adapter path (`enable_adapters(False)`) also accepts kwargs-only
- `TrainableTokensWrapper` (sibling subclass of `AuxiliaryTrainingWrapper`) end-to-end still works
- `adapter_names` + kwargs-only raises `TypeError` (locks the deliberately-unsupported shape)

🤖 Generated with Claude Code