ProTrain integration: chunk-managed weight offload + per-dtype memory cost model#24
ProTrain integration: chunk-managed weight offload + per-dtype memory cost model#24thad0ctor wants to merge 456 commits into
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough<review_stack_artifact> </review_stack_artifact> ✨ Finishing Touches🧪 Generate unit tests (beta)
|
|
📖 Documentation Preview: Deployed on Netlify from commit 7f15738 |
There was a problem hiding this comment.
Actionable comments posted: 8
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/integrations/protrain/chunk/buffer_pool.py`:
- Around line 219-342: Add an explicit closed-guard to prevent post-close
operations: at the top of BufferPool.acquire(checking self._closed) raise a
clear RuntimeError like "BufferPool is closed" instead of proceeding (this
prevents recreating oversize buffers or returning misleading exhaustion errors);
also mirror this same guard in the other public entry points that touch pool
state (e.g., acquire_if_resident, release, and any other public API methods) and
ensure close() sets self._closed = True during teardown.
- Around line 532-548: BufferPool.close currently sets self._closed and clears
pool state before calling self.pinned_host.close(), then swallows any exception,
which prevents retries and leaks PinnedHostMemory; change the teardown so
pinned_host.close() is attempted before marking the pool irrevocably closed and
before dropping self.pinned_host: call self.pinned_host.close() inside
BufferPool.close (or a dedicated _teardown_pinned_host helper), on success
proceed to clear buffers/state and set self._closed=True and
self.pinned_host=None, but on exception do not clear the pool state (leave
self._closed False and pinned_host intact) and either re-raise the exception or
return the error after logging—ensure you reference BufferPool.close,
self.pinned_host (type PinnedHostMemory), and preserve ability to retry
teardown.
In `@src/axolotl/integrations/protrain/plugin.py`:
- Around line 795-802: The try/except around existing.close() in the re-wrap
path must not swallow failures; if existing.close() raises, abort re-wrap and
preserve cfg._protrain_wrapped so teardown remains deterministic. Replace the
current behavior that logs and clears cfg._protrain_wrapped with one that logs
the error (using LOG.debug/LOG.error) and then re-raises or returns an
error/exception to stop the wrap operation, ensuring cfg._protrain_wrapped
remains set when close() fails; apply this change to the block referencing
existing.close() and cfg._protrain_wrapped so the close-chain
(WrappedModel.close → ChunkManager.close → backend/threads) is enforced
deterministically.
In `@src/axolotl/integrations/protrain/profiler/batch_factory.py`:
- Around line 102-113: The current logic returns TASK_SEQ2SEQ_LM as soon as
cfg.is_encoder_decoder is true, which prevents the module-class fallback from
detecting concrete heads like T5ForSequenceClassification; change the order so
the concrete class-name check runs before the generic is_encoder_decoder
fallback: use type(model).__name__ (cls_name) and iterate
_ARCHITECTURE_SUFFIX_TASKS first, returning the matched task if any, and only
then fall back to checking getattr(cfg, "is_encoder_decoder", False) to return
TASK_SEQ2SEQ_LM; ensure the same symbols (cls_name, _ARCHITECTURE_SUFFIX_TASKS,
cfg, is_encoder_decoder, TASK_SEQ2SEQ_LM) are used so models with missing
config.architectures are classified by concrete class when possible.
In `@src/axolotl/integrations/protrain/profiler/hw_bench.py`:
- Around line 625-637: The recorded NCCL timings are being stored under the
original requested payload_bytes even though you round down to
elements_per_shard (element_size=4) and bench the smaller combined size; compute
the actual benchmarked size (actual_payload_bytes = elements_per_shard *
world_size * element_size) after calculating elements_per_shard and use that
actual_payload_bytes when labeling/storing results (e.g., where timings/tables
are written later in the loop) so the lookup tables and cost model are keyed by
the true measured payload size instead of the requested payload_bytes; update
any places that read/write using payload_bytes to use actual_payload_bytes
(refer to element_size, elements_per_shard, shard, payload_bytes, world_size).
In `@src/axolotl/integrations/protrain/profiler/on_demand.py`:
- Around line 254-264: The OnDemandTensorMgr currently allows re-entering the
same instance which causes duplicate hook registration and unsafe shared state;
modify __enter__ in class OnDemandTensorMgr to detect re-entry (e.g., check an
_entered flag) and raise an error if already entered, and ensure __exit__ clears
or resets that flag (set _entered = True on successful enter and set _entered =
False on exit) so hooks and _spills bookkeeping cannot be registered twice or
left in an inconsistent state; update both __enter__ and __exit__ methods to
enforce and maintain this guard around hook registration and
saved_tensors_hooks.
In `@src/axolotl/integrations/protrain/profiler/phase2.py`:
- Around line 511-515: This measurement helper currently calls
optimizer.zero_grad(set_to_none=True) and later relies on
model_state/optim_state but does not save/restore param.grad, which silently
drops caller-held accumulated gradients; update the helper (near the
optimizer.zero_grad call and where model_state/optim_state are used) to either
1) fail fast by detecting any parameter with non-None .grad and raising an error
unless grads are already clear, or 2) preserve and restore grads by snapshotting
all param.grad tensors before zeroing and restoring them after the
profiling/restore path; reference the optimizer.zero_grad(set_to_none=True) call
and the model_state/optim_state restore code when making the change so grads are
either validated or round-tripped.
- Around line 75-92: The final return forcing at least 1 skews results when all
chunks are persistent; change the end of this logic so the function returns the
computed need (allowing 0) instead of forcing max(1, need). Specifically, keep
the early check using n_persist and the computation over
layout.effective_persistent_ids and layout.block_to_chunks, but replace the
final return max(1, need) with simply return need (or return max(0, need)) so
fully-persistent layouts can bootstrap with a zero buffer.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: af6b3549-b924-481a-95e0-cc6fdf3ce3e1
📒 Files selected for processing (117)
.github/workflows/tests.yml.gitignoreexamples/protrain/3090-8b-lora.ymlpyproject.tomlscripts/benchmark_multi_gpu.pyscripts/protrain/measure_nccl.pyscripts/protrain/reshard_optim.pysrc/axolotl/integrations/protrain/BLOCK_MODE_OFFLOAD_DESIGN.mdsrc/axolotl/integrations/protrain/CHECKPOINT_DESIGN.mdsrc/axolotl/integrations/protrain/CHECKPOINT_DESIGN_PHASE2.mdsrc/axolotl/integrations/protrain/DESIGN.mdsrc/axolotl/integrations/protrain/__init__.pysrc/axolotl/integrations/protrain/api/__init__.pysrc/axolotl/integrations/protrain/api/checkpoint.pysrc/axolotl/integrations/protrain/api/hardware.pysrc/axolotl/integrations/protrain/api/model_wrapper.pysrc/axolotl/integrations/protrain/api/optim_wrapper.pysrc/axolotl/integrations/protrain/api/reshard.pysrc/axolotl/integrations/protrain/args.pysrc/axolotl/integrations/protrain/block/__init__.pysrc/axolotl/integrations/protrain/block/checkpoint.pysrc/axolotl/integrations/protrain/block/dispatcher.pysrc/axolotl/integrations/protrain/block/layout_rules.pysrc/axolotl/integrations/protrain/block/offload.pysrc/axolotl/integrations/protrain/block/strategy.pysrc/axolotl/integrations/protrain/block/swap.pysrc/axolotl/integrations/protrain/block/swap_pool.pysrc/axolotl/integrations/protrain/chunk/__init__.pysrc/axolotl/integrations/protrain/chunk/buffer_pool.pysrc/axolotl/integrations/protrain/chunk/layout.pysrc/axolotl/integrations/protrain/chunk/manager.pysrc/axolotl/integrations/protrain/chunk/optim.pysrc/axolotl/integrations/protrain/chunk/pinned_alloc.pysrc/axolotl/integrations/protrain/chunk/sizing.pysrc/axolotl/integrations/protrain/cost/__init__.pysrc/axolotl/integrations/protrain/cost/bandwidth.pysrc/axolotl/integrations/protrain/cost/memory.pysrc/axolotl/integrations/protrain/cost/runtime.pysrc/axolotl/integrations/protrain/plugin.pysrc/axolotl/integrations/protrain/profiler/__init__.pysrc/axolotl/integrations/protrain/profiler/batch_factory.pysrc/axolotl/integrations/protrain/profiler/cache.pysrc/axolotl/integrations/protrain/profiler/hw_bench.pysrc/axolotl/integrations/protrain/profiler/memory_deltas.pysrc/axolotl/integrations/protrain/profiler/on_demand.pysrc/axolotl/integrations/protrain/profiler/phase2.pysrc/axolotl/integrations/protrain/profiler/trace.pysrc/axolotl/integrations/protrain/runtime/__init__.pysrc/axolotl/integrations/protrain/runtime/hooks.pysrc/axolotl/integrations/protrain/runtime/scheduler.pysrc/axolotl/integrations/protrain/runtime/streams.pysrc/axolotl/integrations/protrain/search/__init__.pysrc/axolotl/integrations/protrain/search/exhaustive.pysrc/axolotl/integrations/protrain/search/knobs.pysrc/axolotl/integrations/protrain/types.pysrc/axolotl/utils/config/__init__.pysrc/axolotl/utils/environment.pytests/protrain/__init__.pytests/protrain/conftest.pytests/protrain/peft_edge_cases/__init__.pytests/protrain/peft_edge_cases/test_dora.pytests/protrain/peft_edge_cases/test_multi_adapter.pytests/protrain/peft_edge_cases/test_vision_lm_hybrid.pytests/protrain/test_adamw8bit_adapter.pytests/protrain/test_alpha_per_dtype.pytests/protrain/test_api.pytests/protrain/test_auto_wrap.pytests/protrain/test_batch_factory.pytests/protrain/test_block_manager.pytests/protrain/test_bnb_offload.pytests/protrain/test_chunk_manager.pytests/protrain/test_chunk_manager_distributed.pytests/protrain/test_chunk_manager_offload.pytests/protrain/test_chunk_optim_shutdown.pytests/protrain/test_cost_search.pytests/protrain/test_cross_mode_resume.pytests/protrain/test_enc_dec_smoke.pytests/protrain/test_full_ft_smoke.pytests/protrain/test_fused_lora_kernels.pytests/protrain/test_hw_bench.pytests/protrain/test_init_transient_peak.pytests/protrain/test_integration_2b.pytests/protrain/test_integration_7b.pytests/protrain/test_late_nccl_search_skip.pytests/protrain/test_lora_offload_mode.pytests/protrain/test_m5_cli_smoke.pytests/protrain/test_mandatory_persistent.pytests/protrain/test_math_equivalence.pytests/protrain/test_modec_external_baseline.pytests/protrain/test_modec_steady_peak_accuracy.pytests/protrain/test_multi_gpu_7b.pytests/protrain/test_multi_gpu_benchmark.pytests/protrain/test_offload_mode_m1.pytests/protrain/test_offload_mode_m2.pytests/protrain/test_offload_mode_m3.pytests/protrain/test_offload_mode_m4.pytests/protrain/test_optimizer_checkpoint.pytests/protrain/test_paged_adam_offload_mgpu.pytests/protrain/test_param_data_shape_preservation.pytests/protrain/test_peak_calibration.pytests/protrain/test_plugin_args_validators.pytests/protrain/test_plugin_auto_mode.pytests/protrain/test_plugin_e2e.pytests/protrain/test_plugin_early_dist_init.pytests/protrain/test_plugin_nccl_remeasure.pytests/protrain/test_profiler.pytests/protrain/test_quantization.pytests/protrain/test_resume_robustness.pytests/protrain/test_scheduler.pytests/protrain/test_seq_cls_smoke.pytests/protrain/test_sharded_lora_offload.pytests/protrain/test_single_stream_allocator.pytests/protrain/test_steady_state_calibration.pytests/protrain/test_swap.pytests/protrain/test_trace_skip_on_override.pytests/protrain/test_world_size_reshard.pytests/protrain/test_wrapped_model_close.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/integrations/protrain/profiler/hw_bench.py (1)
667-718:⚠️ Potential issue | 🟠 Major | ⚡ Quick winUse
actual_payload_bytesconsistently for the rank-0 debug path.After this change, the timing tables are keyed by
actual_payload_bytes, but Lines 717-718 still read them withpayload_bytes. On any custom payload list that does not divide evenly intoworld_size * 4, rank 0 will hit aKeyErrorafter the measurement loop.Suggested fix
- if rank == 0: - LOG.debug( - "measure_nccl payload=%dMiB gather=%.3fms reduce=%.3fms " - "(world=%d, %d iters)", - payload_bytes >> 20, - gather_table[payload_bytes] * 1000, - reduce_table[payload_bytes] * 1000, - world_size, - n_iters, - ) + if rank == 0: + LOG.debug( + "measure_nccl payload=%d bytes gather=%.3fms reduce=%.3fms " + "(requested=%d bytes, world=%d, %d iters)", + actual_payload_bytes, + gather_table[actual_payload_bytes] * 1000, + reduce_table[actual_payload_bytes] * 1000, + payload_bytes, + world_size, + n_iters, + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/protrain/profiler/hw_bench.py` around lines 667 - 718, The debug print in the rank==0 block mistakenly indexes gather_table and reduce_table with payload_bytes (which can be absent) instead of the measurement key actual_payload_bytes; update the LOG.debug call to use gather_table[actual_payload_bytes] and reduce_table[actual_payload_bytes] (leave the displayed payload size expression as-is if you want to report the requested payload_bytes) so the rank-0 path reads the same keys used during measurement and avoids KeyError when payloads don't evenly shard.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@src/axolotl/integrations/protrain/profiler/hw_bench.py`:
- Around line 667-718: The debug print in the rank==0 block mistakenly indexes
gather_table and reduce_table with payload_bytes (which can be absent) instead
of the measurement key actual_payload_bytes; update the LOG.debug call to use
gather_table[actual_payload_bytes] and reduce_table[actual_payload_bytes] (leave
the displayed payload size expression as-is if you want to report the requested
payload_bytes) so the rank-0 path reads the same keys used during measurement
and avoids KeyError when payloads don't evenly shard.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e3419305-4861-4767-8ae3-b8375d8b1ad4
📒 Files selected for processing (8)
src/axolotl/integrations/protrain/api/model_wrapper.pysrc/axolotl/integrations/protrain/chunk/buffer_pool.pysrc/axolotl/integrations/protrain/plugin.pysrc/axolotl/integrations/protrain/profiler/batch_factory.pysrc/axolotl/integrations/protrain/profiler/hw_bench.pysrc/axolotl/integrations/protrain/profiler/on_demand.pysrc/axolotl/integrations/protrain/profiler/phase2.pytests/protrain/test_modec_steady_peak_accuracy.py
💤 Files with no reviewable changes (1)
- src/axolotl/integrations/protrain/api/model_wrapper.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
🧹 Nitpick comments (1)
src/axolotl/integrations/protrain/api/checkpoint.py (1)
246-248: 💤 Low valueRedundant import of
torch.
torchis already imported at module level (line 13); the local import here is unnecessary.Suggested fix
def _estimate_optim_state_bytes(optim: Any) -> int: """Estimated bytes for the optimizer's persisted Adam state (cluster-wide under Mode-C).""" - import torch - replicated = 0🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/protrain/api/checkpoint.py` around lines 246 - 248, The local import of torch inside _estimate_optim_state_bytes is redundant because torch is imported at module scope; remove the inner "import torch" statement from the _estimate_optim_state_bytes function so the function uses the module-level torch import instead (search for the function _estimate_optim_state_bytes to locate the change).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@src/axolotl/integrations/protrain/api/checkpoint.py`:
- Around line 246-248: The local import of torch inside
_estimate_optim_state_bytes is redundant because torch is imported at module
scope; remove the inner "import torch" statement from the
_estimate_optim_state_bytes function so the function uses the module-level torch
import instead (search for the function _estimate_optim_state_bytes to locate
the change).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8bd49e99-85f1-4891-bfc6-659b38043e66
📒 Files selected for processing (5)
src/axolotl/integrations/protrain/api/checkpoint.pysrc/axolotl/integrations/protrain/api/optim_wrapper.pysrc/axolotl/integrations/protrain/plugin.pytests/protrain/test_chunk_optim_shutdown.pytests/protrain/test_trace_skip_on_override.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/integrations/protrain/cost/runtime.py (1)
275-275: ⚡ Quick winReplace Greek
αin comments/docstrings to clear Ruff.These lines trip RUF002/RUF003. Swapping
αforalphaavoids lint noise without changing meaning.Also applies to: 282-282, 291-291, 319-319, 333-333, 354-354, 385-385, 408-408, 417-417, 421-421, 448-448
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/protrain/cost/runtime.py` at line 275, Replace the Greek letter α used in comments/docstrings with the ASCII word "alpha" to satisfy Ruff (RUF002/RUF003); search for occurrences of "α" (e.g., the comment starting "Per-component α clamp bounds; [0.5, 2.0] window..." and the other similar inline comments) and update them to "alpha" without changing surrounding wording or code.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/integrations/protrain/cost/runtime.py`:
- Around line 227-237: The code currently treats eff_h2d or eff_d2h == 0 as a
free transfer by using 0.0 for h2d/d2h; instead, fail closed when either
effective PCIe bandwidth is missing by detecting eff_h2d <= 0 or eff_d2h <= 0
and raising a clear exception (e.g., RuntimeError or ValueError with a message
like "missing effective PCIe bandwidth for H2D/D2H") so the estimator rejects
the candidate; update the block that computes h2d/d2h (using S_chunk, eff_h2d,
eff_d2h, nccl_reduce_s, is_backward, buffer_cached, collective) to raise rather
than return 0 when eff_* is non-positive.
---
Nitpick comments:
In `@src/axolotl/integrations/protrain/cost/runtime.py`:
- Line 275: Replace the Greek letter α used in comments/docstrings with the
ASCII word "alpha" to satisfy Ruff (RUF002/RUF003); search for occurrences of
"α" (e.g., the comment starting "Per-component α clamp bounds; [0.5, 2.0]
window..." and the other similar inline comments) and update them to "alpha"
without changing surrounding wording or code.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 373d317e-8857-4268-a197-082a819f3794
📒 Files selected for processing (1)
src/axolotl/integrations/protrain/cost/runtime.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
|
Overnight CR triage complete. Cycles: 9 (8 cron fires + 1 manual setup) CR threads addressed:
Concurrent docstring condensation pass: 40 files, ~16,440 lines removed (model_wrapper -1534, chunk/manager -1546, cost/runtime -1125, api/checkpoint -758, cost/memory -731 leading the way). CI: pre-commit pass on every commit. PyTest passing on 3/6 matrices so far (3.14 x 2.10, source-dist 3.12 x 2.9, source-dist 3.14 x 2.10 — 11-19 min each). Remaining 3 still running. CodeRabbit re-review triggers posted at each push + every 30 min idle. Final branch tip: 2617da6 |
|
Benchmark suite complete. Headline result: Meta-Llama-3-8B BF16 LoRA on a single 3090 — vanilla 15.75 GiB → ProTrain Mode A 3.30 GiB (80% memory reduction) with 7x throughput cost at bs=1 seq=256 (matches the PR's documented bs=1 limitation). Phase 1/2/3 tests all passed (313/4-skip plus GPU + multi-GPU regression). 13B/27B BF16 not measured — those need |
|
Bug found and fixed in this branch as commit c838ac1. Symptom: Reproduction: Root cause: bnb 4-bit quantization only touches Fix: wrap each Tested vanilla-baseline data so far: Qwen3.5-9B + 4-bit qlora vanilla 8.39 GiB, Llama-2-13b + 4-bit qlora vanilla 7.91 GiB peak / 50 steps in 35.13s / loss 1.374→1.074 (M3 single-3090 vanilla baseline confirmed fits). Re-running the ProTrain side with this fix in v9. |
|
Follow-up on the Qwen3.5 + ProTrain incompatibility. After fixing the dtype mismatch in commit 1. Autocast detection (commit 2. Mamba-style block discovery (commit Verified data so far (single 3090, BF16 LoRA / 4-bit qlora, seq=256, 50 steps):
Currently re-running Qwen3.5-9B + 4-bit qlora ProTrain with both fixes; Llama-2-13b + 4-bit ProTrain (which doesn't hit either Qwen3.5-specific bug, so loaded with the pre-fix code) is mid-exhaustive-search ( |
|
Fix chain verified: Qwen3.5-9B + 4-bit qlora + ProTrain trains end-to-end on a single 3090 with commits c838ac1 + cf00f5d + e57aa69 applied. For comparison the 9B+4bit vanilla baseline was 8.39 GiB at 1.42 samples/s; ProTrain runs at 2.22 samples/s but uses more GPU memory because the picked Mode A keeps most chunks persistent and adds activation/chunk-gather buffers — a small model with no offload pressure, where ProTrain's overhead is paid without amortization (matches the PR's documented |
|
Benchmark v8 (true qlora 4-bit) complete. Key results: Vanilla baselines (4-bit qlora, single 3090, BF16 LoRA, bs=1 seq=256, 50 steps):
ProTrain runs: v8's were affected by the Qwen3.5 incompatibilities (dtype + linear_attn block discovery) — fixed in commits c838ac1, cf00f5d, e57aa69. Llama-2-13b + 4-bit + ProTrain hit the per-config timeout (75 min) mid exhaustive search (N_chunk=162 N_block=40 is a much bigger search space than 8B's 130×32). Qwen3.5-9B + 4-bit ProTrain re-runs after the fix chain landed succeeded at 12.6 GiB peak in 22.49s; documented in issuecomment-4513338847. |
|
Save / merge / resume matrix complete (v10) — all 8 scenarios passed rc=0 on a single 3090. Total wall time: 11 minutes.
M5/M6 protrain_optim/ checkpoint validated end-to-end: Loss bump noted in Qwen3-0.6B ProTrain resume (1.41 → 1.91) investigated and dismissed — see issuecomment-next. Same bump appears in vanilla resume; root cause is the cosine LR schedule re-fit when |
|
v11 gap-fill complete (single-GPU + multi-GPU) — see /home/rgilbreth/Desktop/protrain-benchmark-report.md sections 'Gap-fill v11'. Headlines:
|
|
v12 gap-fill complete (single-GPU + multi-GPU with flash_attention: true) — see /home/rgilbreth/Desktop/protrain-benchmark-report.md. Highlights (all with FA on):
Additionally: a ProTrain Phase-2 false-positive on bnb 4-bit companion-buffer keys (.absmax/.quant_map/.nested_/.quant_state.) was diagnosed and patched (commit |
|
Phase-2 chunked-runtime gate false-positives on bnb 4-bit models.
if _result.unexpected_keys: bitsandbytes This trips the gate, aborts Phase-2, and forces a v8 cost-model Phase-2 chunked measurement raised RuntimeError: ... 3030 unexpected The fallback is correct but it surrenders the ~10% accuracy improvement Suggested fix: build an Companion correctness: Not a blocker for shipping the 27B headline (training converged, peak Fix applied as Post-fix v13 re-validation on GPU 2 confirmed
The 27B-seq=128 re-run on the same 24 GiB 3090 OOMed at model load |
90f60a5 to
fe3f38f
Compare
e553305 to
f58966b
Compare
|
Closing to reopen a fresh fork PR and force a full CodeRabbit review on the current branch state. |
Summary
Add ProTrain as an Axolotl integration: a chunk-managed weight-offload runtime
with a per-dtype memory cost model, designed to train larger models on smaller
GPUs without sacrificing optimizer-state quality.
Highlights
bnb.AdamW8bitcross-mode resume
bnb.AdamW8bit+ paged-variant optimizer adapter(
alpha_fragmentation_for_dtypeincost/memory.py): fp16 / bf16 / 8-bithold α=1.10, bnb-4-bit drops to α=0.75 to match the empirical
α_measured ≈ 0.70 across the 4-bit matrix.
What's included
src/axolotl/integrations/protrain/— runtime, scheduler, chunk manager,per-PEFT-LoRA-container hooks, cost model, profiler, plugin args
tests/protrain/— CPU + GPU regression coverage (303 PASSED / 4 SKIPPED atbranch tip on default markers; GPU markers verified on a single-GPU rig and
on a 4×3090 rig)
src/axolotl/utils/environment.py— fail-closed P2P probe used by themulti-GPU launch path
examples/protrain/3090-8b-lora.yml— reference configTest plan
tests/protrain/default-marker regressiontests/protrain/GPU-marker subset on a single-GPU rig (cost-model files,chunk manager, bnb sweep, 4-bit offload sweep)
test_real_multigpu_cross_mode_resume_{a_to_c,c_to_a},test_paged_adam_offload_mgpu_no_ddp_broadcast_crashm0_artifacts/Known limitations
layout, per-iter scheduler walk, PEFT-LoRA container hooks). At bs=1 the
memory headroom isn't load-bearing, so the cost isn't amortized. Measured
~31.75% slower than fused-no-ProTrain at bs=1 / seq=512 / single-GPU.
The acceptance criterion is scoped to multi-GPU bs ≥ 4 or to
configurations where bare DDP OOMs (the 13B / 30B single-3090 headlines).
Init-time chunk residency phenomenon, not a fragmentation one; documented
as an "init window" in
DESIGN.md.offload-mode forward path is under-counted. Tracked as a follow-up cost-model
accuracy item.
Summary by CodeRabbit