Skip to content

ProTrain integration: chunk-managed weight offload + per-dtype memory cost model#24

Closed
thad0ctor wants to merge 456 commits into
mainfrom
protrain-phase2-integration
Closed

ProTrain integration: chunk-managed weight offload + per-dtype memory cost model#24
thad0ctor wants to merge 456 commits into
mainfrom
protrain-phase2-integration

Conversation

@thad0ctor

@thad0ctor thad0ctor commented May 21, 2026

Copy link
Copy Markdown
Owner

Summary

Add ProTrain as an Axolotl integration: a chunk-managed weight-offload runtime
with a per-dtype memory cost model, designed to train larger models on smaller
GPUs without sacrificing optimizer-state quality.

Highlights

  • 4× weight-memory reduction on 8B + 4-bit configurations
  • 74.6% optimizer-state memory reduction via bnb.AdamW8bit
  • Llama-13B + 4-bit + LoRA trains on a single RTX 3090
  • Llama-30B + 4-bit + LoRA trains on a single RTX 3090 (seq 512 / 1024 / 2048)
  • Multi-GPU 4×3090 support across replicated and sharded modes, including
    cross-mode resume
  • Composes with FlashAttention, fused LoRA kernels, and Liger
  • bnb.AdamW8bit + paged-variant optimizer adapter
  • Per-dtype α fragmentation factor in the memory cost model
    (alpha_fragmentation_for_dtype in cost/memory.py): fp16 / bf16 / 8-bit
    hold α=1.10, bnb-4-bit drops to α=0.75 to match the empirical
    α_measured ≈ 0.70 across the 4-bit matrix.

What's included

  • src/axolotl/integrations/protrain/ — runtime, scheduler, chunk manager,
    per-PEFT-LoRA-container hooks, cost model, profiler, plugin args
  • tests/protrain/ — CPU + GPU regression coverage (303 PASSED / 4 SKIPPED at
    branch tip on default markers; GPU markers verified on a single-GPU rig and
    on a 4×3090 rig)
  • src/axolotl/utils/environment.py — fail-closed P2P probe used by the
    multi-GPU launch path
  • examples/protrain/3090-8b-lora.yml — reference config

Test plan

  • CodeRabbit review
  • tests/protrain/ default-marker regression
  • tests/protrain/ GPU-marker subset on a single-GPU rig (cost-model files,
    chunk manager, bnb sweep, 4-bit offload sweep)
  • Multi-GPU regression: test_real_multigpu_cross_mode_resume_{a_to_c,c_to_a},
    test_paged_adam_offload_mgpu_no_ddp_broadcast_crash
  • Memory-headline benchmarks reproducible from m0_artifacts/

Known limitations

  • Single-GPU bs=1 throughput. ProTrain adds per-iter overhead (chunked
    layout, per-iter scheduler walk, PEFT-LoRA container hooks). At bs=1 the
    memory headroom isn't load-bearing, so the cost isn't amortized. Measured
    ~31.75% slower than fused-no-ProTrain at bs=1 / seq=512 / single-GPU.
    The acceptance criterion is scoped to multi-GPU bs ≥ 4 or to
    configurations where bare DDP OOMs (the 13B / 30B single-3090 headlines).
  • bnb-4-bit Mode-C iter-1 transient (~6.9× predicted) during model load.
    Init-time chunk residency phenomenon, not a fragmentation one; documented
    as an "init window" in DESIGN.md.
  • bnb-4-bit Mode-C steady residual (~1.47×). Activation accounting in the
    offload-mode forward path is under-counted. Tracked as a follow-up cost-model
    accuracy item.

Summary by CodeRabbit

  • New Features
    • Full ProTrain release: plugin, model/optimizer wrappers, chunked-memory runtime with CKPT/SWAP/OFFLOAD, profiler (on‑demand & phase‑2), hardware microbenchmarks, multi‑GPU benchmark, NCCL/PCIe tools, CPU reshard utility, and an RTX‑3090 LoRA training config.
  • Tests
    • Pytest defaults now exclude slow and gpu tests by default.
  • Documentation
    • Expanded ProTrain design & checkpoint guides and Apex install how‑to; added website how‑to page.
  • Chores
    • CI sdist persistent cache disabled to avoid deterministic Py3.12 install issues; .gitignore updated to ignore benchmark outputs.

Review Change Stack

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2068dfba-4007-4168-8d6d-8ae27b4b1d09

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

<review_stack_artifact>

</review_stack_artifact>

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch protrain-phase2-integration

@github-actions

github-actions Bot commented May 21, 2026

Copy link
Copy Markdown

📖 Documentation Preview:

Deployed on Netlify from commit 7f15738

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/integrations/protrain/chunk/buffer_pool.py`:
- Around line 219-342: Add an explicit closed-guard to prevent post-close
operations: at the top of BufferPool.acquire(checking self._closed) raise a
clear RuntimeError like "BufferPool is closed" instead of proceeding (this
prevents recreating oversize buffers or returning misleading exhaustion errors);
also mirror this same guard in the other public entry points that touch pool
state (e.g., acquire_if_resident, release, and any other public API methods) and
ensure close() sets self._closed = True during teardown.
- Around line 532-548: BufferPool.close currently sets self._closed and clears
pool state before calling self.pinned_host.close(), then swallows any exception,
which prevents retries and leaks PinnedHostMemory; change the teardown so
pinned_host.close() is attempted before marking the pool irrevocably closed and
before dropping self.pinned_host: call self.pinned_host.close() inside
BufferPool.close (or a dedicated _teardown_pinned_host helper), on success
proceed to clear buffers/state and set self._closed=True and
self.pinned_host=None, but on exception do not clear the pool state (leave
self._closed False and pinned_host intact) and either re-raise the exception or
return the error after logging—ensure you reference BufferPool.close,
self.pinned_host (type PinnedHostMemory), and preserve ability to retry
teardown.

In `@src/axolotl/integrations/protrain/plugin.py`:
- Around line 795-802: The try/except around existing.close() in the re-wrap
path must not swallow failures; if existing.close() raises, abort re-wrap and
preserve cfg._protrain_wrapped so teardown remains deterministic. Replace the
current behavior that logs and clears cfg._protrain_wrapped with one that logs
the error (using LOG.debug/LOG.error) and then re-raises or returns an
error/exception to stop the wrap operation, ensuring cfg._protrain_wrapped
remains set when close() fails; apply this change to the block referencing
existing.close() and cfg._protrain_wrapped so the close-chain
(WrappedModel.close → ChunkManager.close → backend/threads) is enforced
deterministically.

In `@src/axolotl/integrations/protrain/profiler/batch_factory.py`:
- Around line 102-113: The current logic returns TASK_SEQ2SEQ_LM as soon as
cfg.is_encoder_decoder is true, which prevents the module-class fallback from
detecting concrete heads like T5ForSequenceClassification; change the order so
the concrete class-name check runs before the generic is_encoder_decoder
fallback: use type(model).__name__ (cls_name) and iterate
_ARCHITECTURE_SUFFIX_TASKS first, returning the matched task if any, and only
then fall back to checking getattr(cfg, "is_encoder_decoder", False) to return
TASK_SEQ2SEQ_LM; ensure the same symbols (cls_name, _ARCHITECTURE_SUFFIX_TASKS,
cfg, is_encoder_decoder, TASK_SEQ2SEQ_LM) are used so models with missing
config.architectures are classified by concrete class when possible.

In `@src/axolotl/integrations/protrain/profiler/hw_bench.py`:
- Around line 625-637: The recorded NCCL timings are being stored under the
original requested payload_bytes even though you round down to
elements_per_shard (element_size=4) and bench the smaller combined size; compute
the actual benchmarked size (actual_payload_bytes = elements_per_shard *
world_size * element_size) after calculating elements_per_shard and use that
actual_payload_bytes when labeling/storing results (e.g., where timings/tables
are written later in the loop) so the lookup tables and cost model are keyed by
the true measured payload size instead of the requested payload_bytes; update
any places that read/write using payload_bytes to use actual_payload_bytes
(refer to element_size, elements_per_shard, shard, payload_bytes, world_size).

In `@src/axolotl/integrations/protrain/profiler/on_demand.py`:
- Around line 254-264: The OnDemandTensorMgr currently allows re-entering the
same instance which causes duplicate hook registration and unsafe shared state;
modify __enter__ in class OnDemandTensorMgr to detect re-entry (e.g., check an
_entered flag) and raise an error if already entered, and ensure __exit__ clears
or resets that flag (set _entered = True on successful enter and set _entered =
False on exit) so hooks and _spills bookkeeping cannot be registered twice or
left in an inconsistent state; update both __enter__ and __exit__ methods to
enforce and maintain this guard around hook registration and
saved_tensors_hooks.

In `@src/axolotl/integrations/protrain/profiler/phase2.py`:
- Around line 511-515: This measurement helper currently calls
optimizer.zero_grad(set_to_none=True) and later relies on
model_state/optim_state but does not save/restore param.grad, which silently
drops caller-held accumulated gradients; update the helper (near the
optimizer.zero_grad call and where model_state/optim_state are used) to either
1) fail fast by detecting any parameter with non-None .grad and raising an error
unless grads are already clear, or 2) preserve and restore grads by snapshotting
all param.grad tensors before zeroing and restoring them after the
profiling/restore path; reference the optimizer.zero_grad(set_to_none=True) call
and the model_state/optim_state restore code when making the change so grads are
either validated or round-tripped.
- Around line 75-92: The final return forcing at least 1 skews results when all
chunks are persistent; change the end of this logic so the function returns the
computed need (allowing 0) instead of forcing max(1, need). Specifically, keep
the early check using n_persist and the computation over
layout.effective_persistent_ids and layout.block_to_chunks, but replace the
final return max(1, need) with simply return need (or return max(0, need)) so
fully-persistent layouts can bootstrap with a zero buffer.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: af6b3549-b924-481a-95e0-cc6fdf3ce3e1

📥 Commits

Reviewing files that changed from the base of the PR and between dc8f7c7 and 3ea9489.

📒 Files selected for processing (117)
  • .github/workflows/tests.yml
  • .gitignore
  • examples/protrain/3090-8b-lora.yml
  • pyproject.toml
  • scripts/benchmark_multi_gpu.py
  • scripts/protrain/measure_nccl.py
  • scripts/protrain/reshard_optim.py
  • src/axolotl/integrations/protrain/BLOCK_MODE_OFFLOAD_DESIGN.md
  • src/axolotl/integrations/protrain/CHECKPOINT_DESIGN.md
  • src/axolotl/integrations/protrain/CHECKPOINT_DESIGN_PHASE2.md
  • src/axolotl/integrations/protrain/DESIGN.md
  • src/axolotl/integrations/protrain/__init__.py
  • src/axolotl/integrations/protrain/api/__init__.py
  • src/axolotl/integrations/protrain/api/checkpoint.py
  • src/axolotl/integrations/protrain/api/hardware.py
  • src/axolotl/integrations/protrain/api/model_wrapper.py
  • src/axolotl/integrations/protrain/api/optim_wrapper.py
  • src/axolotl/integrations/protrain/api/reshard.py
  • src/axolotl/integrations/protrain/args.py
  • src/axolotl/integrations/protrain/block/__init__.py
  • src/axolotl/integrations/protrain/block/checkpoint.py
  • src/axolotl/integrations/protrain/block/dispatcher.py
  • src/axolotl/integrations/protrain/block/layout_rules.py
  • src/axolotl/integrations/protrain/block/offload.py
  • src/axolotl/integrations/protrain/block/strategy.py
  • src/axolotl/integrations/protrain/block/swap.py
  • src/axolotl/integrations/protrain/block/swap_pool.py
  • src/axolotl/integrations/protrain/chunk/__init__.py
  • src/axolotl/integrations/protrain/chunk/buffer_pool.py
  • src/axolotl/integrations/protrain/chunk/layout.py
  • src/axolotl/integrations/protrain/chunk/manager.py
  • src/axolotl/integrations/protrain/chunk/optim.py
  • src/axolotl/integrations/protrain/chunk/pinned_alloc.py
  • src/axolotl/integrations/protrain/chunk/sizing.py
  • src/axolotl/integrations/protrain/cost/__init__.py
  • src/axolotl/integrations/protrain/cost/bandwidth.py
  • src/axolotl/integrations/protrain/cost/memory.py
  • src/axolotl/integrations/protrain/cost/runtime.py
  • src/axolotl/integrations/protrain/plugin.py
  • src/axolotl/integrations/protrain/profiler/__init__.py
  • src/axolotl/integrations/protrain/profiler/batch_factory.py
  • src/axolotl/integrations/protrain/profiler/cache.py
  • src/axolotl/integrations/protrain/profiler/hw_bench.py
  • src/axolotl/integrations/protrain/profiler/memory_deltas.py
  • src/axolotl/integrations/protrain/profiler/on_demand.py
  • src/axolotl/integrations/protrain/profiler/phase2.py
  • src/axolotl/integrations/protrain/profiler/trace.py
  • src/axolotl/integrations/protrain/runtime/__init__.py
  • src/axolotl/integrations/protrain/runtime/hooks.py
  • src/axolotl/integrations/protrain/runtime/scheduler.py
  • src/axolotl/integrations/protrain/runtime/streams.py
  • src/axolotl/integrations/protrain/search/__init__.py
  • src/axolotl/integrations/protrain/search/exhaustive.py
  • src/axolotl/integrations/protrain/search/knobs.py
  • src/axolotl/integrations/protrain/types.py
  • src/axolotl/utils/config/__init__.py
  • src/axolotl/utils/environment.py
  • tests/protrain/__init__.py
  • tests/protrain/conftest.py
  • tests/protrain/peft_edge_cases/__init__.py
  • tests/protrain/peft_edge_cases/test_dora.py
  • tests/protrain/peft_edge_cases/test_multi_adapter.py
  • tests/protrain/peft_edge_cases/test_vision_lm_hybrid.py
  • tests/protrain/test_adamw8bit_adapter.py
  • tests/protrain/test_alpha_per_dtype.py
  • tests/protrain/test_api.py
  • tests/protrain/test_auto_wrap.py
  • tests/protrain/test_batch_factory.py
  • tests/protrain/test_block_manager.py
  • tests/protrain/test_bnb_offload.py
  • tests/protrain/test_chunk_manager.py
  • tests/protrain/test_chunk_manager_distributed.py
  • tests/protrain/test_chunk_manager_offload.py
  • tests/protrain/test_chunk_optim_shutdown.py
  • tests/protrain/test_cost_search.py
  • tests/protrain/test_cross_mode_resume.py
  • tests/protrain/test_enc_dec_smoke.py
  • tests/protrain/test_full_ft_smoke.py
  • tests/protrain/test_fused_lora_kernels.py
  • tests/protrain/test_hw_bench.py
  • tests/protrain/test_init_transient_peak.py
  • tests/protrain/test_integration_2b.py
  • tests/protrain/test_integration_7b.py
  • tests/protrain/test_late_nccl_search_skip.py
  • tests/protrain/test_lora_offload_mode.py
  • tests/protrain/test_m5_cli_smoke.py
  • tests/protrain/test_mandatory_persistent.py
  • tests/protrain/test_math_equivalence.py
  • tests/protrain/test_modec_external_baseline.py
  • tests/protrain/test_modec_steady_peak_accuracy.py
  • tests/protrain/test_multi_gpu_7b.py
  • tests/protrain/test_multi_gpu_benchmark.py
  • tests/protrain/test_offload_mode_m1.py
  • tests/protrain/test_offload_mode_m2.py
  • tests/protrain/test_offload_mode_m3.py
  • tests/protrain/test_offload_mode_m4.py
  • tests/protrain/test_optimizer_checkpoint.py
  • tests/protrain/test_paged_adam_offload_mgpu.py
  • tests/protrain/test_param_data_shape_preservation.py
  • tests/protrain/test_peak_calibration.py
  • tests/protrain/test_plugin_args_validators.py
  • tests/protrain/test_plugin_auto_mode.py
  • tests/protrain/test_plugin_e2e.py
  • tests/protrain/test_plugin_early_dist_init.py
  • tests/protrain/test_plugin_nccl_remeasure.py
  • tests/protrain/test_profiler.py
  • tests/protrain/test_quantization.py
  • tests/protrain/test_resume_robustness.py
  • tests/protrain/test_scheduler.py
  • tests/protrain/test_seq_cls_smoke.py
  • tests/protrain/test_sharded_lora_offload.py
  • tests/protrain/test_single_stream_allocator.py
  • tests/protrain/test_steady_state_calibration.py
  • tests/protrain/test_swap.py
  • tests/protrain/test_trace_skip_on_override.py
  • tests/protrain/test_world_size_reshard.py
  • tests/protrain/test_wrapped_model_close.py

Comment thread src/axolotl/integrations/protrain/chunk/buffer_pool.py
Comment thread src/axolotl/integrations/protrain/chunk/buffer_pool.py Outdated
Comment thread src/axolotl/integrations/protrain/plugin.py
Comment thread src/axolotl/integrations/protrain/profiler/batch_factory.py Outdated
Comment thread src/axolotl/integrations/protrain/profiler/hw_bench.py
Comment thread src/axolotl/integrations/protrain/profiler/on_demand.py
Comment thread src/axolotl/integrations/protrain/profiler/phase2.py Outdated
Comment thread src/axolotl/integrations/protrain/profiler/phase2.py Outdated
@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/integrations/protrain/profiler/hw_bench.py (1)

667-718: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use actual_payload_bytes consistently for the rank-0 debug path.

After this change, the timing tables are keyed by actual_payload_bytes, but Lines 717-718 still read them with payload_bytes. On any custom payload list that does not divide evenly into world_size * 4, rank 0 will hit a KeyError after the measurement loop.

Suggested fix
-        if rank == 0:
-            LOG.debug(
-                "measure_nccl payload=%dMiB gather=%.3fms reduce=%.3fms "
-                "(world=%d, %d iters)",
-                payload_bytes >> 20,
-                gather_table[payload_bytes] * 1000,
-                reduce_table[payload_bytes] * 1000,
-                world_size,
-                n_iters,
-            )
+        if rank == 0:
+            LOG.debug(
+                "measure_nccl payload=%d bytes gather=%.3fms reduce=%.3fms "
+                "(requested=%d bytes, world=%d, %d iters)",
+                actual_payload_bytes,
+                gather_table[actual_payload_bytes] * 1000,
+                reduce_table[actual_payload_bytes] * 1000,
+                payload_bytes,
+                world_size,
+                n_iters,
+            )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/profiler/hw_bench.py` around lines 667 -
718, The debug print in the rank==0 block mistakenly indexes gather_table and
reduce_table with payload_bytes (which can be absent) instead of the measurement
key actual_payload_bytes; update the LOG.debug call to use
gather_table[actual_payload_bytes] and reduce_table[actual_payload_bytes] (leave
the displayed payload size expression as-is if you want to report the requested
payload_bytes) so the rank-0 path reads the same keys used during measurement
and avoids KeyError when payloads don't evenly shard.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@src/axolotl/integrations/protrain/profiler/hw_bench.py`:
- Around line 667-718: The debug print in the rank==0 block mistakenly indexes
gather_table and reduce_table with payload_bytes (which can be absent) instead
of the measurement key actual_payload_bytes; update the LOG.debug call to use
gather_table[actual_payload_bytes] and reduce_table[actual_payload_bytes] (leave
the displayed payload size expression as-is if you want to report the requested
payload_bytes) so the rank-0 path reads the same keys used during measurement
and avoids KeyError when payloads don't evenly shard.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e3419305-4861-4767-8ae3-b8375d8b1ad4

📥 Commits

Reviewing files that changed from the base of the PR and between 3ea9489 and 9982739.

📒 Files selected for processing (8)
  • src/axolotl/integrations/protrain/api/model_wrapper.py
  • src/axolotl/integrations/protrain/chunk/buffer_pool.py
  • src/axolotl/integrations/protrain/plugin.py
  • src/axolotl/integrations/protrain/profiler/batch_factory.py
  • src/axolotl/integrations/protrain/profiler/hw_bench.py
  • src/axolotl/integrations/protrain/profiler/on_demand.py
  • src/axolotl/integrations/protrain/profiler/phase2.py
  • tests/protrain/test_modec_steady_peak_accuracy.py
💤 Files with no reviewable changes (1)
  • src/axolotl/integrations/protrain/api/model_wrapper.py

@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
src/axolotl/integrations/protrain/api/checkpoint.py (1)

246-248: 💤 Low value

Redundant import of torch.

torch is already imported at module level (line 13); the local import here is unnecessary.

Suggested fix
 def _estimate_optim_state_bytes(optim: Any) -> int:
     """Estimated bytes for the optimizer's persisted Adam state (cluster-wide under Mode-C)."""
-    import torch
-
     replicated = 0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/api/checkpoint.py` around lines 246 - 248,
The local import of torch inside _estimate_optim_state_bytes is redundant
because torch is imported at module scope; remove the inner "import torch"
statement from the _estimate_optim_state_bytes function so the function uses the
module-level torch import instead (search for the function
_estimate_optim_state_bytes to locate the change).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@src/axolotl/integrations/protrain/api/checkpoint.py`:
- Around line 246-248: The local import of torch inside
_estimate_optim_state_bytes is redundant because torch is imported at module
scope; remove the inner "import torch" statement from the
_estimate_optim_state_bytes function so the function uses the module-level torch
import instead (search for the function _estimate_optim_state_bytes to locate
the change).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8bd49e99-85f1-4891-bfc6-659b38043e66

📥 Commits

Reviewing files that changed from the base of the PR and between d57e264 and 4ec11e8.

📒 Files selected for processing (5)
  • src/axolotl/integrations/protrain/api/checkpoint.py
  • src/axolotl/integrations/protrain/api/optim_wrapper.py
  • src/axolotl/integrations/protrain/plugin.py
  • tests/protrain/test_chunk_optim_shutdown.py
  • tests/protrain/test_trace_skip_on_override.py

@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/integrations/protrain/cost/runtime.py (1)

275-275: ⚡ Quick win

Replace Greek α in comments/docstrings to clear Ruff.

These lines trip RUF002/RUF003. Swapping α for alpha avoids lint noise without changing meaning.

Also applies to: 282-282, 291-291, 319-319, 333-333, 354-354, 385-385, 408-408, 417-417, 421-421, 448-448

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/cost/runtime.py` at line 275, Replace the
Greek letter α used in comments/docstrings with the ASCII word "alpha" to
satisfy Ruff (RUF002/RUF003); search for occurrences of "α" (e.g., the comment
starting "Per-component α clamp bounds; [0.5, 2.0] window..." and the other
similar inline comments) and update them to "alpha" without changing surrounding
wording or code.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/integrations/protrain/cost/runtime.py`:
- Around line 227-237: The code currently treats eff_h2d or eff_d2h == 0 as a
free transfer by using 0.0 for h2d/d2h; instead, fail closed when either
effective PCIe bandwidth is missing by detecting eff_h2d <= 0 or eff_d2h <= 0
and raising a clear exception (e.g., RuntimeError or ValueError with a message
like "missing effective PCIe bandwidth for H2D/D2H") so the estimator rejects
the candidate; update the block that computes h2d/d2h (using S_chunk, eff_h2d,
eff_d2h, nccl_reduce_s, is_backward, buffer_cached, collective) to raise rather
than return 0 when eff_* is non-positive.

---

Nitpick comments:
In `@src/axolotl/integrations/protrain/cost/runtime.py`:
- Line 275: Replace the Greek letter α used in comments/docstrings with the
ASCII word "alpha" to satisfy Ruff (RUF002/RUF003); search for occurrences of
"α" (e.g., the comment starting "Per-component α clamp bounds; [0.5, 2.0]
window..." and the other similar inline comments) and update them to "alpha"
without changing surrounding wording or code.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 373d317e-8857-4268-a197-082a819f3794

📥 Commits

Reviewing files that changed from the base of the PR and between 4ec11e8 and 2fe7629.

📒 Files selected for processing (1)
  • src/axolotl/integrations/protrain/cost/runtime.py

Comment thread src/axolotl/integrations/protrain/cost/runtime.py Outdated
@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@thad0ctor

Copy link
Copy Markdown
Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 21, 2026

Copy link
Copy Markdown
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@thad0ctor

Copy link
Copy Markdown
Owner Author

Overnight CR triage complete.

Cycles: 9 (8 cron fires + 1 manual setup)
Commits during overnight loop: 42
Net line change: 43 files changed, 1201 insertions(+), 13517 deletions(-)

CR threads addressed:

  • Cycle 1: 8 inline findings applied (buffer_pool post-close acquires, plugin stale-wrapper fail-closed, batch_factory task detection, hw_bench payload keying, on_demand re-entry, phase2 zero-buffer + caller-grad)
  • Cycle 5: cost/runtime._comm_time_chunk fail-closed on zero PCIe bandwidth
  • Cleanup pass: collapsed CR cycle 1's multi-line # blocks to one-line WHY (commit ca48336), mopped 7 more (commit f8da137)

Concurrent docstring condensation pass: 40 files, ~16,440 lines removed (model_wrapper -1534, chunk/manager -1546, cost/runtime -1125, api/checkpoint -758, cost/memory -731 leading the way).

CI: pre-commit pass on every commit. PyTest passing on 3/6 matrices so far (3.14 x 2.10, source-dist 3.12 x 2.9, source-dist 3.14 x 2.10 — 11-19 min each). Remaining 3 still running.

CodeRabbit re-review triggers posted at each push + every 30 min idle.

Final branch tip: 2617da6

@thad0ctor

Copy link
Copy Markdown
Owner Author

Benchmark suite complete. Headline result: Meta-Llama-3-8B BF16 LoRA on a single 3090 — vanilla 15.75 GiB → ProTrain Mode A 3.30 GiB (80% memory reduction) with 7x throughput cost at bs=1 seq=256 (matches the PR's documented bs=1 limitation). Phase 1/2/3 tests all passed (313/4-skip plus GPU + multi-GPU regression). 13B/27B BF16 not measured — those need load_in_4bit: true per the PR's headline configs (ProTrain requires the model to fit briefly on GPU before materialize_offload). Full report at /home/rgilbreth/Desktop/protrain-benchmark-report.md.

@thad0ctor

Copy link
Copy Markdown
Owner Author

Bug found and fixed in this branch as commit c838ac1.

Symptom: RuntimeError: expected scalar type BFloat16 but found Float in profiler/trace.py and profiler/phase2.py when wrapping a 4-bit + LoRA Qwen3.5-9B model.

Reproduction: load_in_4bit: true, adapter: qlora, bf16: true on any model with a non-Linear non-quantized parametric layer in its forward body (Qwen3.5 has a Mamba-style linear_attn.conv1d; same bug class would hit Mamba / Falcon-Mamba / Zamba etc.). Stack trace bottoms out at:

transformers/models/qwen3_5/modeling_qwen3_5.py:474 linear_attn.forward
  mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

Root cause: bnb 4-bit quantization only touches nn.Linear. axolotl's 'Converting modules to torch.bfloat16' does not catch every parametric layer either, so a Conv1d weight remains fp32 while upstream activations have been cast to BF16. Vanilla training survives because Trainer.training_step runs forward under accelerator.autocast(...) which casts inputs back. The profiler's run_trace and measure_chunked_steady did naive model forwards without autocast, so the dtype mismatch surfaced immediately.

Fix: wrap each model(**batch) in profiler/trace.py and profiler/phase2.py with torch.autocast(device_type='cuda', dtype=<dominant param dtype>), where the dtype is detected from the first floating-point parameter of the model. No effect on fp32-everywhere or fully-quantized paths (autocast is a no-op when dtype is neither bfloat16 nor float16).

Tested vanilla-baseline data so far: Qwen3.5-9B + 4-bit qlora vanilla 8.39 GiB, Llama-2-13b + 4-bit qlora vanilla 7.91 GiB peak / 50 steps in 35.13s / loss 1.374→1.074 (M3 single-3090 vanilla baseline confirmed fits). Re-running the ProTrain side with this fix in v9.

@thad0ctor

Copy link
Copy Markdown
Owner Author

Follow-up on the Qwen3.5 + ProTrain incompatibility. After fixing the dtype mismatch in commit c838ac1ab, a re-run surfaced two more issues, both now fixed:

1. Autocast detection (commit cf00f5d9f) — the dtype-detection loop in profiler/trace.py:run_trace and profiler/phase2.py:measure_chunked_steady used first param where is_floating_point() which can return a stray fp32 RMSNorm or Conv1d weight; autocast_dtype ended up torch.float32 and the autocast context collapsed to a nullcontext. Narrowed the search to params whose dtype is explicitly torch.bfloat16 or torch.float16.

2. Mamba-style block discovery (commit e57aa6917)block/layout_rules.py:_looks_like_block only recognized .attention / .self_attn. Qwen3.5's DecoderLayer puts its linear-attention module under .linear_attn, so the whole model.layers ModuleList was rejected as 'not transformer blocks' even though it sits at the standard base_model.model.model.layers path. Added hasattr(m, 'linear_attn') to the direct-block check and to the CheckpointedBlock-inner check. Same class of fix would also cover Falcon-Mamba, Zamba, and any other Mamba/hybrid where the attention attribute is non-standard.

Verified data so far (single 3090, BF16 LoRA / 4-bit qlora, seq=256, 50 steps):

  • Meta-Llama-3-8B vanilla 15.75 GiB vs ProTrain Mode A 3.30 GiB (~80% reduction) — v6
  • Qwen3.5-9B + 4-bit qlora vanilla 8.39 GiB rc=0 — v8
  • Llama-2-13b + 4-bit qlora vanilla 7.91 GiB / 50 steps / 35.13s / loss 1.374→1.074 — M3 single-3090 vanilla baseline confirmed, v8

Currently re-running Qwen3.5-9B + 4-bit qlora ProTrain with both fixes; Llama-2-13b + 4-bit ProTrain (which doesn't hit either Qwen3.5-specific bug, so loaded with the pre-fix code) is mid-exhaustive-search (N_chunk=162 N_block=40).

@thad0ctor

Copy link
Copy Markdown
Owner Author

Fix chain verified: Qwen3.5-9B + 4-bit qlora + ProTrain trains end-to-end on a single 3090 with commits c838ac1 + cf00f5d + e57aa69 applied.

materialize_offload: freed 2.24 GB (alloc 13.53 -> 8.48 GB)
Training: 50/50 steps in 22.49s
Loss: 1.758 -> 0.839 (50 steps cosine)
Peak GPU memory (active): 12.60 GiB

For comparison the 9B+4bit vanilla baseline was 8.39 GiB at 1.42 samples/s; ProTrain runs at 2.22 samples/s but uses more GPU memory because the picked Mode A keeps most chunks persistent and adds activation/chunk-gather buffers — a small model with no offload pressure, where ProTrain's overhead is paid without amortization (matches the PR's documented single-GPU bs=1 limitation).

@thad0ctor

Copy link
Copy Markdown
Owner Author

Benchmark v8 (true qlora 4-bit) complete. Key results:

Vanilla baselines (4-bit qlora, single 3090, BF16 LoRA, bs=1 seq=256, 50 steps):

  • Qwen3.5-9B vanilla: 8.39 GiB peak rc=0
  • Llama-2-13b vanilla: 7.91 GiB peak / 50 steps in 35.13s / loss 1.374→1.074 (PR M3 single-3090 vanilla baseline confirmed)
  • Qwen3.5-27B vanilla: OOM (the 4-bit base is ~13.5 GB but activations + Adam + LoRA + buffers blow past 24 GiB)

ProTrain runs: v8's were affected by the Qwen3.5 incompatibilities (dtype + linear_attn block discovery) — fixed in commits c838ac1, cf00f5d, e57aa69. Llama-2-13b + 4-bit + ProTrain hit the per-config timeout (75 min) mid exhaustive search (N_chunk=162 N_block=40 is a much bigger search space than 8B's 130×32). Qwen3.5-9B + 4-bit ProTrain re-runs after the fix chain landed succeeded at 12.6 GiB peak in 22.49s; documented in issuecomment-4513338847.

@thad0ctor

Copy link
Copy Markdown
Owner Author

Save / merge / resume matrix complete (v10) — all 8 scenarios passed rc=0 on a single 3090. Total wall time: 11 minutes.

Model Vanilla resume Vanilla merge ProTrain resume ProTrain merge
Qwen3.5-0.8B (linear_attn) ✅ pre 1.78 / post 0.93 ✅ rc=0 ✅ pre 2.05 / post 0.80 ✅ train 1.37 + merge 0
Qwen3-0.6B (standard attn) ✅ pre 2.23 / post 1.50 ✅ rc=0 ✅ pre 1.41 / post 1.91 ✅ train 0.90 + merge 0

M5/M6 protrain_optim/ checkpoint validated end-to-end: checkpoint-{15,20,30,45,50}/protrain_optim/ directories present, each containing gpu_optim.pt (~18 MB) + metadata.json (format_version 2, persistent_ids 0..7, saved_at_step). Load-side restoration runs on resume (_load_protrain_optim_dir fires).

Loss bump noted in Qwen3-0.6B ProTrain resume (1.41 → 1.91) investigated and dismissed — see issuecomment-next. Same bump appears in vanilla resume; root cause is the cosine LR schedule re-fit when max_steps is extended from 20 to 50 on the resume call. LR at step 25 jumps from 1.23e-6 (pre-save) to 1.13e-4 (post-resume) — 92× higher, applied to a near-converged optimizer state → temporary loss bounce. A no-resume ProTrain control with the same hyperparams trained 50 steps to loss 0.983 cleanly, confirming the issue is hyperparameter scheduling, not a ProTrain code bug.

@thad0ctor

Copy link
Copy Markdown
Owner Author

v11 gap-fill complete (single-GPU + multi-GPU) — see /home/rgilbreth/Desktop/protrain-benchmark-report.md sections 'Gap-fill v11'.

Headlines:

  • Multi-GPU 4× 3090: Llama-13B+4bit Mode A 24.68 sps @ 9.98 GiB/rank; Mode C 23.67 sps @ 9.87 GiB/rank; Qwen3.5-9B+4bit Mode A 32.53 sps @ 9.89 GiB/rank
  • Single 3090: Llama-13B+4bit+ProTrain Mode A (force_all_persistent) 7.91 GiB / 16.72s / loss 0.895 (PR M3 headline reproduced); Qwen3.5-9B+4bit bs=4 ProTrain 55% faster than vanilla; 13B+4bit adamw_bnb_8bit 7.53 vs adamw_torch 7.91 GiB
  • Qwen3.5-27B+4bit+ProTrain at seq=128 on a single 3090: peak 20.08 GiB, 16.95s / 25 steps, loss 1.228 (PR M5 30B-class single-3090 stretch reproduced; auto-mode picked Mode B replicated CPU-offload)

@thad0ctor

Copy link
Copy Markdown
Owner Author

v12 gap-fill complete (single-GPU + multi-GPU with flash_attention: true) — see /home/rgilbreth/Desktop/protrain-benchmark-report.md.

Highlights (all with FA on):

  • bs sweep Qwen3.5-9B+4bit single 3090: vanilla bs=4 10.81 GiB / 1.77 sps; ProTrain bs=4 13.49 GiB / 3.11 sps; bs=8 vanilla 13.92 GiB / 2.53 sps; bs=8 ProTrain OOM'd (Mode-C spill)
  • Long-horizon convergence Llama-13B+4bit 500 steps: vanilla loss 0.906 in 439s @ 7.91 GiB; ProTrain Mode A loss 0.865 in 303s @ 9.34 GiB — ProTrain converges further AND faster at 500 steps
  • Multi-GPU 4×3090 Mode A/C re-runs with FA: 13B Mode A 9.56 GiB / 5.44 sps/rank; 13B Mode C 9.57 GiB / 5.52 sps/rank; Qwen3.5-9B Mode A 9.86 GiB / 7.63 sps/rank
  • Headline 8B BF16 re-run with FA: vanilla 15.88 GiB (matches FA-off 15.83 within noise — activations weren't the dominant cost); ProTrain run hit per-config 20-min timeout in exhaustive search

Additionally: a ProTrain Phase-2 false-positive on bnb 4-bit companion-buffer keys (.absmax/.quant_map/.nested_/.quant_state.) was diagnosed and patched (commit 321408f58); re-validation of the 4 affected configs runs next.

@thad0ctor

Copy link
Copy Markdown
Owner Author

Phase-2 chunked-runtime gate false-positives on bnb 4-bit models.

profiler/phase2.py::measure_chunked_steady, lines 356-364:

if _result.unexpected_keys:
raise RuntimeError(
f"Phase-2 state_dict restore saw {len(_result.unexpected_keys)} "
f"unexpected snapshot keys ..."
"The live model dropped or renamed state during the timed "
"measurement, so rollback is incomplete."
)

bitsandbytes Linear4bit._save_to_state_dict emits <weight>.absmax,
<weight>.quant_map, <weight>.nested_absmax,
<weight>.nested_quant_map, <weight>.quant_state.<algo> per module,
but Linear4bit has no _load_from_state_dict override — QuantState
lives as a Python attribute on the Params4bit parameter, not as a
registered buffer. Result: every bnb-4bit Linear contributes ~5 keys
to _result.unexpected_keys even on a noop round-trip. Repro on a
single 64x64 Linear4bit (nf4) gives 5 unexpected_keys; the v11 27B run
shows 3030 = 606 modules x 5, the v9/v11/v12 9B runs show 1790 = 358 x 5.

This trips the gate, aborts Phase-2, and forces a v8 cost-model
fallback under the warning:

Phase-2 chunked measurement raised RuntimeError: ... 3030 unexpected
snapshot keys ...; falling back to the v8 cost-model path

The fallback is correct but it surrenders the ~10% accuracy improvement
Phase-2 is designed to give. Affects every bnb-4bit + ProTrain
configuration that hits Phase-2 (4 known logs to date — see
re-run list in the linked report).

Suggested fix: build an expected_unexpected_keys set from the
snapshot keys ending in .absmax, .quant_map, .nested_absmax,
.nested_quant_map, or containing .quant_state., and subtract it
from _result.unexpected_keys before the gate raises. Mirrors the
existing expected_missing_keys filter for offloaded placeholders.

Companion correctness: tests/protrain/test_bnb_offload.py lines
212-271 already verify qs.absmax device + bytes are preserved across
chunk gather/offload, so the round-trip these keys would have done
through load_state_dict was redundant — chunk/manager.py rebinds
param.data without touching python attrs (see args.py:73-78).

Not a blocker for shipping the 27B headline (training converged, peak
20.08 GiB, loss 1.228, rc=0), but a wanting-better-not-blocking
follow-up.


Fix applied as 321408f58 on protrain-phase2-integration.
Pre-commit (ruff/mypy/bandit) clean; pytest tests/protrain -k "phase2 or steady or chunked_steady" 24/24 passed.

Post-fix v13 re-validation on GPU 2 confirmed phase2_fallback_fired=0
and phase2_runtime_error_count=0 on the 3 9B-4bit runs that
previously fell back. ProTrain's CostConfig pick changed on 2 of those
3 runs — the searcher now lands on configurations the v8 fallback
couldn't reach:

  • qwen35-9b-4bit bs=1 (50 steps): same pick, peak 12.6 → 11.23 GiB.
  • qwen35-9b-4bit bs=4 (25 steps): n_persist=1 → 16, n_offload=30 → 0
    (fully resident; +3.4 GiB peak but lower loss).
  • qwen35-9b-4bit bs=1 (25 steps): n_persist=0 → 6, n_buffer=15 → 10,
    n_swap=0 → 1, n_offload=30 → 29 (-4.95 GiB peak, +8% throughput).

The 27B-seq=128 re-run on the same 24 GiB 3090 OOMed at model load
(pre-Phase-2; _convert_embedding_modules_dtype requests 4.74 GiB,
allocator only had 4.05 GiB free). v11 fit by ≤1 GiB; today's allocator
state shaved that margin. Unrelated to this patch — recommend
re-attempting the 27B on the 4×3090 multi-GPU set or 5090. Full
head-to-head table in protrain-benchmark-report.md (v13 section).

@thad0ctor thad0ctor force-pushed the protrain-phase2-integration branch from 90f60a5 to fe3f38f Compare May 28, 2026 17:18
@thad0ctor thad0ctor force-pushed the protrain-phase2-integration branch from e553305 to f58966b Compare May 28, 2026 21:25
@thad0ctor

Copy link
Copy Markdown
Owner Author

Closing to reopen a fresh fork PR and force a full CodeRabbit review on the current branch state.

@thad0ctor thad0ctor closed this May 29, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant