fix(protrain): auto-offload memory-bound VL models; compose lora_mlp_kernel with offload #39
…kernel with offload

Three fixes so ProTrain auto-mode trains a memory-bound VL model (e.g. Qwen3.5-27B QLoRA on 2x24GB) with no manual tuning:

- The VL vision tower is invisible to the text-only profiling batch, so estimate_peak under-counts by ~3.7 GiB and the searcher picks an all-resident config that OOMs. Reserve extra capacity headroom when a vision_config / visual module is present so it offloads instead.
- Lift the lora_mlp_kernel => no-offload pin (now opt-in via protrain_lora_mlp_forbid_offload). The v61 LoRA_MLPBackward shape mismatch it guarded against is already fixed by the unconditional shape-preserving placeholders, so the fused MLP kernel composes with offload (~3% faster per step, validated).
- Complete a partial residency override with safe defaults + warn instead of silently falling back to the auto search; document the all-four-knob requirement on the override fields.

Validated on Qwen3.5-27B QLoRA seq-1024, 2x3090 Ti NVLink: auto-mode now offloads (n_persist=34) and trains at 5.05 s/it, vs OOM before.
📝 Walkthrough
ProTrain integration receives three coordinated enhancements: vision tower detection to reserve additional GPU capacity headroom during auto-search, a new boolean configuration to forbid offloading when using LoRA MLP kernels, and partial override completion logic that fills missing residency knob values with defaults while warning the user.
Changes
ProTrain Vision Tower Detection and Override Handling
🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 3
🧹 Nitpick comments (2)
src/axolotl/integrations/protrain/plugin.py (1)
1645-1648: ⚡ Quick win
Condense these new inline comments to one short line each.
Both added rationale comments are multi-line; compress to single-line WHY comments.
As per coding guidelines
src/**/*.py: "Keep comments to one short line maximum".
Also applies to: 1697-1699
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/protrain/plugin.py` around lines 1645 - 1648, Two multi-line rationale comments near the wrapper logic (referencing the variables n_persist, n_buffer, n_swap, n_checkpoint and the wrapper/searcher behavior) should be condensed into single-line comments each; replace the two multi-line WHY blocks around the wrapper completion logic (also present near the same pattern at the later location) with concise one-line comments that state: (1) the wrapper bypasses the searcher only when all four flags are set, and (2) we fill missing flags with safe defaults so partial overrides behave as intended.
src/axolotl/integrations/protrain/api/model_wrapper.py (1)
72-74: ⚡ Quick win
Shorten the new inline rationale to a single short line.
This new comment is multi-line; keep it to one concise line per repo convention.
Proposed edit
-# A vision tower the text-only profiling batch never exercises is invisible to
-# estimate_peak, so reserve extra capacity to force offload instead of OOM.
+# Reserve extra GPU headroom when a vision tower is unprofiled by text-only tracing.
As per coding guidelines
src/**/*.py: "Keep comments to one short line maximum".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/protrain/api/model_wrapper.py` around lines 72 - 74, Replace the multi-line rationale above _VL_VISION_HEADROOM_BYTES with a single short line comment (one sentence) explaining purpose, e.g., "Reserve extra headroom to force vision offload and avoid OOM during text-only profiling." Keep it concise and retain the constant name _VL_VISION_HEADROOM_BYTES and the numeric value.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/integrations/protrain/api/model_wrapper.py`:
- Around line 77-83: The file fails CI due to formatting drift; run the ruff
formatter on the file and commit the changes. Specifically, run `ruff format`
(or your project's ruff command) against
src/axolotl/integrations/protrain/api/model_wrapper.py and ensure the function
_has_unprofiled_vision_tower and surrounding code are reformatted, then add and
commit the reformatted file so CI passes.
In `@src/axolotl/integrations/protrain/plugin.py`:
- Around line 1665-1673: The file fails the linter/formatter; run the ruff
formatter on src/axolotl/integrations/protrain/plugin.py (or apply ruff format
to the working tree) to fix formatting issues around the LOG.warning call and
surrounding code (e.g., the LOG.warning invocation that references _set,
n_persist_override, n_buffer_override, n_swap_override, n_checkpoint_override).
Ensure the formatted output conforms to project ruff settings and then re-run
CI.
- Around line 1700-1707: Update the protrain_model_wrapper docstring to reflect
that the forbid_activation_offload flag is read from
cfg.protrain_lora_mlp_forbid_offload (mapped to local variable
forbid_activation_offload) instead of cfg.lora_mlp_kernel; edit the docstring
text and any parameter descriptions to mention protrain_lora_mlp_forbid_offload
and the behavior ("refuse n_offload>0 candidates") so the docs match the code
path that sets forbid_activation_offload.
---
Nitpick comments:
In `@src/axolotl/integrations/protrain/api/model_wrapper.py`:
- Around line 72-74: Replace the multi-line rationale above
_VL_VISION_HEADROOM_BYTES with a single short line comment (one sentence)
explaining purpose, e.g., "Reserve extra headroom to force vision offload and
avoid OOM during text-only profiling." Keep it concise and retain the constant
name _VL_VISION_HEADROOM_BYTES and the numeric value.
In `@src/axolotl/integrations/protrain/plugin.py`:
- Around line 1645-1648: Two multi-line rationale comments near the wrapper
logic (referencing the variables n_persist, n_buffer, n_swap, n_checkpoint and
the wrapper/searcher behavior) should be condensed into single-line comments
each; replace the two multi-line WHY blocks around the wrapper completion logic
(also present near the same pattern at the later location) with concise one-line
comments that state: (1) the wrapper bypasses the searcher only when all four
flags are set, and (2) we fill missing flags with safe defaults so partial
overrides behave as intended.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4e554c56-2725-46f6-92c5-92dd88b2078e
📒 Files selected for processing (3)
src/axolotl/integrations/protrain/api/model_wrapper.py
src/axolotl/integrations/protrain/args.py
src/axolotl/integrations/protrain/plugin.py
def _has_unprofiled_vision_tower(model: "nn.Module") -> bool:
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "vision_config", None) is not None:
        return True
    for name, _ in model.named_modules():
        if name.rsplit(".", 1)[-1] in {"visual", "vision_tower", "vision_model"}:
            return True
Apply ruff format for this file before merge.
CI is currently failing on this file due to formatting drift; please commit the formatter output.
🧰 Tools
🪛 GitHub Actions: lint / 0_pre-commit.txt
[error] 83-83: ruff-format reformatting changed this file (2 files reformatted). Apply formatting changes committed by the hook.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/integrations/protrain/api/model_wrapper.py` around lines 77 - 83,
The file fails CI due to formatting drift; run the ruff formatter on the file
and commit the changes. Specifically, run `ruff format` (or your project's ruff
command) against src/axolotl/integrations/protrain/api/model_wrapper.py and
ensure the function _has_unprofiled_vision_tower and surrounding code are
reformatted, then add and commit the reformatted file so CI passes.
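The vision-tower check excerpted in this thread can be exercised without loading a real model. The following is a minimal sketch, not the PR's implementation: `FakeModel` and `search_capacity_bytes` are hypothetical names introduced for illustration, the trailing `return False` is assumed because the excerpt stops at the hunk boundary, and the 4 GiB value is taken from the PR description rather than from the source file.

```python
from types import SimpleNamespace

GIB = 1024**3
# Assumed value: the PR description states the headroom constant is 4 GiB.
VL_VISION_HEADROOM_BYTES = 4 * GIB


class FakeModel:
    """Hypothetical stand-in exposing only what the check reads."""

    def __init__(self, module_names, config=None):
        self.config = config
        self._module_names = module_names

    def named_modules(self):
        # Yields (name, module) pairs like torch.nn.Module.named_modules.
        for name in self._module_names:
            yield name, object()


def has_unprofiled_vision_tower(model) -> bool:
    # Same logic as the excerpt; the final `return False` is an assumption.
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "vision_config", None) is not None:
        return True
    for name, _ in model.named_modules():
        if name.rsplit(".", 1)[-1] in {"visual", "vision_tower", "vision_model"}:
            return True
    return False


def search_capacity_bytes(model, capacity_bytes: int) -> int:
    """Hypothetical helper: shrink the budget the searcher sees for VL models."""
    if has_unprofiled_vision_tower(model):
        return max(capacity_bytes - VL_VISION_HEADROOM_BYTES, 0)
    return capacity_bytes


text_only = FakeModel(["", "model", "model.layers.0"])
vl_by_name = FakeModel(["", "model", "model.visual", "model.visual.blocks.0"])
vl_by_config = FakeModel(["", "model"], config=SimpleNamespace(vision_config={}))
```

Only the last path component is matched, so `model.visual` triggers the check while `model.visual.blocks.0` on its own would not. How the real searcher consumes the reserved headroom is not shown in this thread, so the subtraction above is only one plausible reading.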
LOG.warning(
    "ProTrain: partial residency override (set: %s). The override "
    "path needs all four of n_persist/n_buffer/n_swap/n_checkpoint "
    "or it silently falls back to the auto search. Filled unset "
    "knobs: n_persist=%d n_buffer=%d n_swap=%d n_checkpoint=%d. Set "
    "all four explicitly to suppress this.",
    ", ".join(_set), n_persist_override, n_buffer_override,
    n_swap_override, n_checkpoint_override,
)
Apply ruff format for this file before merge.
CI indicates formatting changes are still required in this file.
🧰 Tools
🪛 GitHub Actions: lint / 0_pre-commit.txt
[error] 1668-1668: ruff-format reformatting changed this file (2 files reformatted). Apply formatting changes committed by the hook.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/integrations/protrain/plugin.py` around lines 1665 - 1673, The
file fails the linter/formatter; run the ruff formatter on
src/axolotl/integrations/protrain/plugin.py (or apply ruff format to the working
tree) to fix formatting issues around the LOG.warning call and surrounding code
(e.g., the LOG.warning invocation that references _set, n_persist_override,
n_buffer_override, n_swap_override, n_checkpoint_override). Ensure the formatted
output conforms to project ruff settings and then re-run CI.
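The completion step that precedes this warning is not visible in the excerpt. A minimal sketch of how it might look, assuming the safe defaults quoted in the PR description (`n_persist=0, n_buffer=2, n_swap=0, n_checkpoint=0`); `complete_residency_override` and `SAFE_DEFAULTS` are hypothetical names, and the real code in `plugin.py` may be structured differently:

```python
from __future__ import annotations

import logging
from typing import Optional

LOG = logging.getLogger("protrain.sketch")

# Safe defaults as quoted in the PR description.
SAFE_DEFAULTS = {"n_persist": 0, "n_buffer": 2, "n_swap": 0, "n_checkpoint": 0}


def complete_residency_override(
    overrides: dict[str, Optional[int]],
) -> Optional[dict[str, int]]:
    """Return a full four-knob override, or None when no knob is set."""
    set_knobs = [k for k in SAFE_DEFAULTS if overrides.get(k) is not None]
    if not set_knobs:
        return None  # nothing set: the auto search should run as usual

    # Keep user-provided values and fill the rest from the defaults.
    completed = {
        knob: overrides[knob] if overrides.get(knob) is not None else default
        for knob, default in SAFE_DEFAULTS.items()
    }
    if len(set_knobs) < len(SAFE_DEFAULTS):
        LOG.warning(
            "partial residency override (set: %s); filled: "
            "n_persist=%d n_buffer=%d n_swap=%d n_checkpoint=%d",
            ", ".join(set_knobs),
            completed["n_persist"],
            completed["n_buffer"],
            completed["n_swap"],
            completed["n_checkpoint"],
        )
    return completed
```

Returning `None` for the empty case keeps the auto-search path untouched, which matches the PR's note that the happy path is unchanged.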
 forbid_activation_offload = bool(
     getattr(cfg, "protrain_lora_mlp_forbid_offload", False)
 )
 if forbid_activation_offload:
     LOG.info(
-        "ProTrain: cfg.lora_mlp_kernel=True; searcher will refuse "
-        "n_offload>0 candidates."
+        "ProTrain: protrain_lora_mlp_forbid_offload=True; searcher will "
+        "refuse n_offload>0 candidates."
     )
Update the wrapper docstring to match this new knob source.
This hunk now uses protrain_lora_mlp_forbid_offload, but protrain_model_wrapper docs still describe forbid_activation_offload as coming from cfg.lora_mlp_kernel.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/integrations/protrain/plugin.py` around lines 1700 - 1707, Update
the protrain_model_wrapper docstring to reflect that the
forbid_activation_offload flag is read from cfg.protrain_lora_mlp_forbid_offload
(mapped to local variable forbid_activation_offload) instead of
cfg.lora_mlp_kernel; edit the docstring text and any parameter descriptions to
mention protrain_lora_mlp_forbid_offload and the behavior ("refuse n_offload>0
candidates") so the docs match the code path that sets
forbid_activation_offload.
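To make the behavior change concrete, here is a small sketch of how the new knob decouples the pin from `lora_mlp_kernel`. The `getattr` read mirrors the excerpt above; `filter_candidates` and the dict-shaped candidates are hypothetical, since the searcher's real candidate type does not appear in this thread.

```python
from types import SimpleNamespace


def read_forbid_activation_offload(cfg) -> bool:
    # Mirrors the excerpt: a missing attribute means the pin is off.
    return bool(getattr(cfg, "protrain_lora_mlp_forbid_offload", False))


def filter_candidates(candidates, forbid_activation_offload: bool):
    """Hypothetical searcher filter: drop n_offload>0 configs when pinned."""
    if not forbid_activation_offload:
        return list(candidates)
    return [c for c in candidates if c["n_offload"] == 0]


# lora_mlp_kernel alone no longer pins the searcher to resident-only configs.
cfg_default = SimpleNamespace(lora_mlp_kernel=True)
cfg_pinned = SimpleNamespace(
    lora_mlp_kernel=True, protrain_lora_mlp_forbid_offload=True
)
candidates = [{"n_offload": 0}, {"n_offload": 3}]

allowed_default = filter_candidates(
    candidates, read_forbid_activation_offload(cfg_default)
)
allowed_pinned = filter_candidates(
    candidates, read_forbid_activation_offload(cfg_pinned)
)
```

With the default config both candidates survive; setting the knob restores the old resident-only behavior.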
- ruff format model_wrapper.py and plugin.py (CI formatting drift)
- update protrain_model_wrapper docstring: forbid_activation_offload now reads from protrain_lora_mlp_forbid_offload, not cfg.lora_mlp_kernel
- add docstring to _has_unprofiled_vision_tower
@coderabbitai these three findings were already addressed in
The re-flag cites the lint run from the prior commit |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check the latest commit on the PR branch
git log --oneline -5
# Check formatting of the relevant files
echo "=== model_wrapper.py (relevant sections) ==="
cat src/axolotl/integrations/protrain/api/model_wrapper.py
echo ""
echo "=== plugin.py (relevant sections) ==="
cat src/axolotl/integrations/protrain/plugin.py
Length of output: 218050
The earlier flags were from the lint snapshot of |
Summary
Three fixes so ProTrain auto-mode trains a memory-bound VL model (e.g. Qwen3.5-27B QLoRA on 2×24 GB) with no manual tuning. Before this, auto-mode OOM'd, and the only working route was a hand-tuned 4-knob override + `lora_mlp_kernel: false` + `ddp_find_unused_parameters: true`.
Fixes
A. VL vision tower invisible to the cost model → auto-mode OOMs (`api/model_wrapper.py`)
The profiling batch is text-only (no `pixel_values`) and `discover_blocks` only finds the text-decoder tree, so `estimate_peak` never counts the `model.visual.*` footprint (~3.7 GiB here). The searcher concludes "fits resident", picks `n_offload=0`, goes runtime-inert, then OOMs in the forward pass (predicted 19.76 GiB vs. real ~23.5 GiB on a 24 GB card).
Fix: `_has_unprofiled_vision_tower()` reserves `_VL_VISION_HEADROOM_BYTES` (4 GiB) of extra capacity headroom when a `vision_config` / `visual` module is present, so the searcher offloads instead. (Follow-up: actually profile the vision tower by injecting a representative image, removing the under-count at source.)
B. `lora_mlp_kernel` => no-offload pin is stale (`plugin.py`, `args.py`)
The v61 `LoRA_MLPBackward` shape mismatch this pin guards against is already fixed by the unconditional shape-preserving placeholders (commit df3dad1ab), so the fused MLP kernel now composes with offload. The pin was never lifted.
Fix: default the forbid to off; re-enable it via the new opt-in knob `protrain_lora_mlp_forbid_offload`. Validated: `lora_mlp_kernel: true` + offload trains cleanly, ~3% faster per step than kernel-off.
C. Partial residency override silently no-ops (`plugin.py`, `args.py`)
The override path bypasses the searcher only when all four of `n_persist`/`n_buffer`/`n_swap`/`n_checkpoint` are set; a partial set silently falls back to the auto search (which OOMs). This requirement was undocumented.
Fix: complete a partial override with safe defaults (`n_persist=0, n_buffer=2, n_swap=0, n_checkpoint=0`) + a clear warning; document the all-four requirement on the override field descriptions.
Validation (Qwen3.5-27B QLoRA, seq 1024, 2×3090 Ti NVLink)
Auto-mode now beats hand-tuning by ~20% because it offloads only what's needed. 0.6B text smoke unaffected (happy path unchanged).
`ruff check` clean.
Deliberately out of scope
The global `safety_fraction` (prefer offload when predicted peak is near capacity for any model) is not included: it changes searcher selection for every near-capacity config. The VL-headroom fix covers the failure we hit; the general margin is a good follow-up.
No linked issues: this PR is self-contained ProTrain memory-management work.