fix(protrain): auto-offload memory-bound VL models; compose lora_mlp_kernel with offload#39

Merged
thad0ctor merged 2 commits into fix/protrain-costmodel-followups from fix/protrain-vl-automode-offload
Jun 1, 2026

@thad0ctor (Owner) commented Jun 1, 2026

Summary

Three fixes so ProTrain auto-mode trains a memory-bound VL model (e.g. Qwen3.5-27B QLoRA on 2×24 GB) with no manual tuning. Before this, auto-mode OOM'd, and the only working route was a hand-tuned 4-knob override + lora_mlp_kernel: false + ddp_find_unused_parameters: true.

Fixes

A. VL vision tower invisible to the cost model → auto-mode OOMs (api/model_wrapper.py)

The profiling batch is text-only (no pixel_values) and discover_blocks only finds the text-decoder tree, so estimate_peak never counts the model.visual.* footprint (~3.7 GiB here). The searcher concludes "fits resident", picks n_offload=0, goes runtime-inert, then OOMs in the forward (predicted 19.76 vs real ~23.5 GiB on a 24 GB card).
Fix: _has_unprofiled_vision_tower() reserves _VL_VISION_HEADROOM_BYTES (4 GiB) of extra capacity headroom when a vision_config / visual module is present, so the searcher offloads instead. (Follow-up: actually profile the vision tower by injecting a representative image, removing the under-count at source.)
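
The effect of fix A can be sketched as below. This is a minimal, torch-free illustration rather than the code in api/model_wrapper.py: the detection heuristic follows the helper quoted in the review thread, while `effective_capacity_bytes`, `StubModel`, and the 22 GiB budget are hypothetical and chosen only to make the example self-contained. It assumes the searcher compares the predicted peak against a capacity figure reduced by the headroom.

```python
from types import SimpleNamespace

GIB = 1024**3
# 4 GiB, the value stated in the PR description.
_VL_VISION_HEADROOM_BYTES = 4 * GIB


def has_unprofiled_vision_tower(model) -> bool:
    """Detect a vision tower via config.vision_config or a module-name heuristic."""
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "vision_config", None) is not None:
        return True
    for name, _ in model.named_modules():
        if name.rsplit(".", 1)[-1] in {"visual", "vision_tower", "vision_model"}:
            return True
    return False


def effective_capacity_bytes(model, capacity_bytes: int) -> int:
    """Hypothetical helper: reserve extra headroom when a vision tower is present."""
    if has_unprofiled_vision_tower(model):
        return capacity_bytes - _VL_VISION_HEADROOM_BYTES
    return capacity_bytes


class StubModel:
    """Stand-in for nn.Module exposing only what the heuristic touches."""

    def __init__(self, module_names, config=None):
        self.config = config
        self._names = module_names

    def named_modules(self):
        return [(name, object()) for name in self._names]


vl_model = StubModel(["model", "model.visual", "model.language_model"])
text_model = StubModel(["model", "model.layers.0"])
cfg_model = StubModel(["model"], config=SimpleNamespace(vision_config={}))

predicted_peak = int(19.76 * GIB)  # text-only estimate quoted in the PR
budget = 22 * GIB  # illustrative assumption, not a figure from the PR

# Without the headroom the estimate "fits resident"; with it, it does not.
fits_text = predicted_peak <= effective_capacity_bytes(text_model, budget)
fits_vl = predicted_peak <= effective_capacity_bytes(vl_model, budget)
print(fits_text, fits_vl)
```

The headroom is a fixed reservation, so it trades some usable memory on every detected VL model for not having to profile the vision tower, which is why the PR lists real vision-tower profiling as a follow-up.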

B. lora_mlp_kernel => no-offload pin is stale (plugin.py, args.py)

The v61 LoRA_MLPBackward shape-mismatch this pin guards against is already fixed by the unconditional shape-preserving placeholders (commit df3dad1ab), so the fused MLP kernel now composes with offload. The pin was never lifted.
Fix: default the forbid to off; re-enable it via the new opt-in knob protrain_lora_mlp_forbid_offload. Validated: lora_mlp_kernel: true + offload trains cleanly, ~3% faster per step than kernel-off.
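
A rough sketch of how the new knob gates the search, under stated assumptions: the `getattr` read with a `False` default matches the plugin.py hunk quoted in the review thread, but `Candidate` and `filter_candidates` are hypothetical names, and the real searcher may apply the restriction differently.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class Candidate:
    """Hypothetical searcher candidate; only n_offload matters here."""

    n_offload: int


def forbid_activation_offload(cfg) -> bool:
    # Reads the new knob; lora_mlp_kernel no longer influences this flag.
    return bool(getattr(cfg, "protrain_lora_mlp_forbid_offload", False))


def filter_candidates(candidates, cfg):
    """Drop n_offload>0 candidates only when the knob is explicitly enabled."""
    if forbid_activation_offload(cfg):
        return [c for c in candidates if c.n_offload == 0]
    return list(candidates)


candidates = [Candidate(0), Candidate(2), Candidate(5)]

default_cfg = SimpleNamespace(lora_mlp_kernel=True)
pinned_cfg = SimpleNamespace(
    lora_mlp_kernel=True, protrain_lora_mlp_forbid_offload=True
)

print(filter_candidates(candidates, default_cfg))
print(filter_candidates(candidates, pinned_cfg))
```

With the default config the fused kernel and offload candidates coexist; only the explicit opt-in restores the old restriction.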

C. Partial residency override silently no-ops (plugin.py, args.py)

The override path bypasses the searcher only when all four of n_persist/n_buffer/n_swap/n_checkpoint are set; a partial set silently falls back to the auto search (which OOMs). Undocumented.
Fix: complete a partial override with safe defaults (n_persist=0, n_buffer=2, n_swap=0, n_checkpoint=0) + a clear warning; document the all-four requirement on the override field descriptions.
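
The completion rule can be illustrated with a short sketch. The default values are the ones listed above; the function name `complete_partial_override`, the dict-based interface, and the logger name are assumptions made for the example and are not taken from plugin.py.

```python
import logging

LOG = logging.getLogger("protrain.sketch")

# Safe defaults listed in the PR description.
_OVERRIDE_DEFAULTS = {"n_persist": 0, "n_buffer": 2, "n_swap": 0, "n_checkpoint": 0}


def complete_partial_override(overrides):
    """Return a full 4-knob override, or None when no knob is set (auto search)."""
    given = {
        key: value
        for key, value in overrides.items()
        if key in _OVERRIDE_DEFAULTS and value is not None
    }
    if not given:
        return None  # nothing set: leave the decision to the searcher
    completed = {**_OVERRIDE_DEFAULTS, **given}
    if len(given) < len(_OVERRIDE_DEFAULTS):
        # Partial set: fill the rest and tell the user what was assumed.
        LOG.warning(
            "partial residency override (set: %s); filled unset knobs: %s",
            ", ".join(sorted(given)),
            completed,
        )
    return completed


print(complete_partial_override({"n_persist": 4}))
print(complete_partial_override({}))
```

Returning `None` for an empty set keeps the auto-search path distinct from the override path, so only a genuinely partial set triggers the warning.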

Validation (Qwen3.5-27B QLoRA, seq 1024, 2×3090 Ti NVLink)

| Config | s/it | runtime/15 | notes |
|---|---|---|---|
| manual recipe (kernel off, n_persist=4) | 6.28 | 96.3 s | prior hand-tuned route |
| manual + kernel on (fix B) | 6.05 | 93.4 s | fused kernel composes with offload |
| auto-mode (fixes A and B, no tuning) | 5.05 | 77.4 s | searcher offloads n_persist=34 on its own |

Auto-mode now beats hand-tuning by ~20% because it offloads only what's needed. 0.6B text smoke unaffected (happy path unchanged). ruff check clean.
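
The relative gains can be recomputed from the validation figures with a few lines of arithmetic. The numbers are copied from the table above; the helper `pct_faster` is just an illustrative name.

```python
# (s/it, runtime in seconds) per configuration, from the validation table.
rows = {
    "manual": (6.28, 96.3),
    "manual_kernel_on": (6.05, 93.4),
    "auto": (5.05, 77.4),
}


def pct_faster(baseline: float, new: float) -> float:
    """Percentage reduction of `new` relative to `baseline`."""
    return 100.0 * (baseline - new) / baseline


auto_vs_manual_step = pct_faster(rows["manual"][0], rows["auto"][0])
auto_vs_manual_runtime = pct_faster(rows["manual"][1], rows["auto"][1])
kernel_gain_step = pct_faster(rows["manual"][0], rows["manual_kernel_on"][0])
kernel_gain_runtime = pct_faster(rows["manual"][1], rows["manual_kernel_on"][1])

print(f"auto vs manual: {auto_vs_manual_step:.1f}% per step, "
      f"{auto_vs_manual_runtime:.1f}% runtime")
print(f"kernel on vs off: {kernel_gain_step:.1f}% per step, "
      f"{kernel_gain_runtime:.1f}% runtime")
```

Both the per-step and runtime columns put auto-mode a little under 20% ahead of the manual recipe, and the kernel-on gain between 3% and 4%, consistent with the rounded claims in the description.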

Deliberately out of scope

The global safety_fraction (prefer offload when predicted peak is near capacity for any model) is not included — it changes searcher selection for every near-capacity config. The VL-headroom fix covers the failure we hit; the general margin is a good follow-up.

No linked issues — this PR is self-contained ProTrain memory-management work.

…kernel with offload

Three fixes so ProTrain auto-mode trains a memory-bound VL model (e.g.
Qwen3.5-27B QLoRA on 2x24GB) with no manual tuning:

- The VL vision tower is invisible to the text-only profiling batch, so
  estimate_peak under-counts ~3.7 GiB and the searcher picks an
  all-resident config that OOMs. Reserve extra capacity headroom when a
  vision_config / visual module is present so it offloads instead.
- Lift the lora_mlp_kernel => no-offload pin (now opt-in via
  protrain_lora_mlp_forbid_offload). The v61 LoRA_MLPBackward shape
  mismatch it guarded is already fixed by the unconditional
  shape-preserving placeholders, so the fused MLP kernel composes with
  offload (~3% faster per step, validated).
- Complete a partial residency override with safe defaults + warn
  instead of silently falling back to the auto search; document the
  all-four-knob requirement on the override fields.

Validated on Qwen3.5-27B QLoRA seq-1024, 2x3090 Ti NVLink: auto-mode now
offloads (n_persist=34) and trains at 5.05 s/it, vs OOM before.

coderabbitai Bot commented Jun 1, 2026

Walkthrough

ProTrain integration receives three coordinated enhancements: vision tower detection to reserve additional GPU capacity headroom during auto-search, a new boolean configuration to forbid offloading when using LoRA MLP kernels, and partial override completion logic that fills missing residency knob values with defaults while warning the user.

Changes

ProTrain Vision Tower Detection and Override Handling

| Layer / File(s) | Summary |
|---|---|
| Vision tower detection and capacity headroom (src/axolotl/integrations/protrain/api/model_wrapper.py) | Adds _has_unprofiled_vision_tower helper to detect vision towers via HF config.vision_config or module-name heuristics. When capacity is unspecified, the wrapper adds 4 GiB of extra headroom if a vision tower is detected to prevent OOM during vision forward paths. |
| ProTrain configuration options (src/axolotl/integrations/protrain/args.py) | Clarifies 4-knob override behavior in schema docs: the searcher is bypassed only when all four of n_persist/n_buffer/n_swap/n_checkpoint are set; partial sets trigger auto-completion with defaults and a warning. Adds new protrain_lora_mlp_forbid_offload boolean to block n_offload>0 candidates when the LoRA MLP kernel is enabled. |
| Partial override completion and gating (src/axolotl/integrations/protrain/plugin.py) | In post_model_load, handles incomplete override sets by filling missing knobs with safe defaults and logging a warning. Updates forbid_activation_offload to read from protrain_lora_mlp_forbid_offload instead of the older lora_mlp_kernel flag, with updated logging. |

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Linked Issues check | ⚠️ Warning | The PR addresses vision-language (VL) model auto-offload and lora_mlp_kernel offload composition, but the linked issues #1 and #2 concern Liger kernel support for Qwen3-VL and JSON parsing refactoring, which are not present in the actual code changes. | Verify that the correct linked issues are associated with this PR. The current changes address ProTrain memory management and vision tower detection, not the stated issues about Liger kernels or JSON parsing. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title directly summarizes the main changes: auto-offload for memory-bound VL models and composing lora_mlp_kernel with offload, which are core objectives of the PR. |
| Out of Scope Changes check | ✅ Passed | All changes are scoped to ProTrain integration (model_wrapper.py, args.py, plugin.py) and directly address the PR's stated objectives of auto-offload for VL models and lora_mlp_kernel offload composition. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |

@thad0ctor (Owner, Author)

@coderabbitai review

coderabbitai Bot commented Jun 1, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
src/axolotl/integrations/protrain/plugin.py (1)

1645-1648: ⚡ Quick win

Condense these new inline comments to one short line each.

Both added rationale comments are multi-line; compress to single-line WHY comments.

As per coding guidelines src/**/*.py: "Keep comments to one short line maximum".

Also applies to: 1697-1699

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/plugin.py` around lines 1645 - 1648, Two
multi-line rationale comments near the wrapper logic (referencing the variables
n_persist, n_buffer, n_swap, n_checkpoint and the wrapper/searcher behavior)
should be condensed into single-line comments each; replace the two multi-line
WHY blocks around the wrapper completion logic (also present near the same
pattern at the later location) with concise one-line comments that state: (1)
the wrapper bypasses the searcher only when all four flags are set, and (2) we
fill missing flags with safe defaults so partial overrides behave as intended.
src/axolotl/integrations/protrain/api/model_wrapper.py (1)

72-74: ⚡ Quick win

Shorten the new inline rationale to a single short line.

This new comment is multi-line; keep it to one concise line per repo convention.

Proposed edit
-# A vision tower the text-only profiling batch never exercises is invisible to
-# estimate_peak, so reserve extra capacity to force offload instead of OOM.
+# Reserve extra GPU headroom when a vision tower is unprofiled by text-only tracing.

As per coding guidelines src/**/*.py: "Keep comments to one short line maximum".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/api/model_wrapper.py` around lines 72 - 74,
Replace the multi-line rationale above _VL_VISION_HEADROOM_BYTES with a single
short line comment (one sentence) explaining purpose, e.g., "Reserve extra
headroom to force vision offload and avoid OOM during text-only profiling." Keep
it concise and retain the constant name _VL_VISION_HEADROOM_BYTES and the
numeric value.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4e554c56-2725-46f6-92c5-92dd88b2078e

📥 Commits

Reviewing files that changed from the base of the PR and between 3402746 and 461f3d2.

📒 Files selected for processing (3)
  • src/axolotl/integrations/protrain/api/model_wrapper.py
  • src/axolotl/integrations/protrain/args.py
  • src/axolotl/integrations/protrain/plugin.py

Comment on lines +77 to +83
def _has_unprofiled_vision_tower(model: "nn.Module") -> bool:
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "vision_config", None) is not None:
        return True
    for name, _ in model.named_modules():
        if name.rsplit(".", 1)[-1] in {"visual", "vision_tower", "vision_model"}:
            return True

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Apply ruff format for this file before merge.

CI is currently failing on this file due to formatting drift; please commit the formatter output.

🧰 Tools
🪛 GitHub Actions: lint / 0_pre-commit.txt

[error] 83-83: ruff-format reformatting changed this file (2 files reformatted). Apply formatting changes committed by the hook.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/api/model_wrapper.py` around lines 77 - 83,
The file fails CI due to formatting drift; run the ruff formatter on the file
and commit the changes. Specifically, run `ruff format` (or your project's ruff
command) against src/axolotl/integrations/protrain/api/model_wrapper.py and
ensure the function _has_unprofiled_vision_tower and surrounding code are
reformatted, then add and commit the reformatted file so CI passes.

Comment on lines +1665 to +1673
LOG.warning(
    "ProTrain: partial residency override (set: %s). The override "
    "path needs all four of n_persist/n_buffer/n_swap/n_checkpoint "
    "or it silently falls back to the auto search. Filled unset "
    "knobs: n_persist=%d n_buffer=%d n_swap=%d n_checkpoint=%d. Set "
    "all four explicitly to suppress this.",
    ", ".join(_set), n_persist_override, n_buffer_override,
    n_swap_override, n_checkpoint_override,
)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Apply ruff format for this file before merge.

CI indicates formatting changes are still required in this file.

🧰 Tools
🪛 GitHub Actions: lint / 0_pre-commit.txt

[error] 1668-1668: ruff-format reformatting changed this file (2 files reformatted). Apply formatting changes committed by the hook.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/plugin.py` around lines 1665 - 1673, The
file fails the linter/formatter; run the ruff formatter on
src/axolotl/integrations/protrain/plugin.py (or apply ruff format to the working
tree) to fix formatting issues around the LOG.warning call and surrounding code
(e.g., the LOG.warning invocation that references _set, n_persist_override,
n_buffer_override, n_swap_override, n_checkpoint_override). Ensure the formatted
output conforms to project ruff settings and then re-run CI.

Comment on lines +1700 to 1707
 forbid_activation_offload = bool(
     getattr(cfg, "protrain_lora_mlp_forbid_offload", False)
 )
 if forbid_activation_offload:
     LOG.info(
-        "ProTrain: cfg.lora_mlp_kernel=True; searcher will refuse "
-        "n_offload>0 candidates."
+        "ProTrain: protrain_lora_mlp_forbid_offload=True; searcher will "
+        "refuse n_offload>0 candidates."
     )

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the wrapper docstring to match this new knob source.

This hunk now uses protrain_lora_mlp_forbid_offload, but protrain_model_wrapper docs still describe forbid_activation_offload as coming from cfg.lora_mlp_kernel.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/plugin.py` around lines 1700 - 1707, Update
the protrain_model_wrapper docstring to reflect that the
forbid_activation_offload flag is read from cfg.protrain_lora_mlp_forbid_offload
(mapped to local variable forbid_activation_offload) instead of
cfg.lora_mlp_kernel; edit the docstring text and any parameter descriptions to
mention protrain_lora_mlp_forbid_offload and the behavior ("refuse n_offload>0
candidates") so the docs match the code path that sets
forbid_activation_offload.

- ruff format model_wrapper.py and plugin.py (CI formatting drift)
- update protrain_model_wrapper docstring: forbid_activation_offload now
  reads from protrain_lora_mlp_forbid_offload, not cfg.lora_mlp_kernel
- add docstring to _has_unprofiled_vision_tower

@thad0ctor (Owner, Author)

@coderabbitai these three findings were already addressed in 5cc43e545 (pushed just before this review ran):

  • ruff format was applied to both model_wrapper.py and plugin.py; ruff format --check now reports both files already formatted, and the pre-commit CI job passes on this commit.
  • The protrain_model_wrapper docstring was updated to reference cfg.protrain_lora_mlp_forbid_offload instead of cfg.lora_mlp_kernel.

The re-flag cites the lint run from the prior commit 461f3d2f9; the line numbers shifted by the docstring additions. Please re-check the latest commit.

coderabbitai Bot commented Jun 1, 2026

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the latest commit on the PR branch
git log --oneline -5

# Check formatting of the relevant files
echo "=== model_wrapper.py (relevant sections) ==="
cat src/axolotl/integrations/protrain/api/model_wrapper.py

echo ""
echo "=== plugin.py (relevant sections) ==="
cat src/axolotl/integrations/protrain/plugin.py

Length of output: 218050



@thad0ctor: Confirmed — all three items are verified in commit 5cc43e545:

  1. model_wrapper.py formatting — file is ruff-clean; indentation and blank lines are consistent throughout.
  2. plugin.py formatting — same, no formatting issues present.
  3. protrain_model_wrapper docstring — correctly reads cfg.protrain_lora_mlp_forbid_offload (default False) with the explanation that the fused MLP backward kernel now composes with offload via shape-preserving placeholders. The old cfg.lora_mlp_kernel reference is gone.

The earlier flags were from the lint snapshot of 461f3d2f9; the current HEAD is clean. No action needed on your end.

@thad0ctor thad0ctor merged commit 9c329e5 into fix/protrain-costmodel-followups Jun 1, 2026
12 checks passed
@thad0ctor thad0ctor deleted the fix/protrain-vl-automode-offload branch June 1, 2026 19:13