Closed
Changes from 35 commits
43 commits
12f1a12
chore(protrain): fix pre-existing mypy on model_wrapper:386 + formatting
thad0ctor May 9, 2026
6c3fcb1
feat(protrain): enable FlashAttention in canonical LoRA example (M4)
thad0ctor May 9, 2026
45a934f
feat(protrain): allow load_in_8bit / load_in_4bit (M2+M3 Mode A)
thad0ctor May 9, 2026
1fe8ddb
feat(protrain): integrate fused LoRA kernels via container hooks (M1)
thad0ctor May 9, 2026
a6dfc37
feat(protrain): add bnb 8-bit AdamW optimizer adapter (M2.5)
thad0ctor May 9, 2026
823b4db
feat(protrain): reject unsupported optimizers at config load (M6B)
thad0ctor May 9, 2026
c857675
test(protrain): PEFT edge-case smoke tests (M6A)
thad0ctor May 10, 2026
3ce55a8
test(protrain): cross-mode (A↔C) resume smoke tests (M6C)
thad0ctor May 10, 2026
6cb5c84
fix(protrain): force_all_persistent suppresses trace-pass on-demand e…
thad0ctor May 10, 2026
a868978
test(protrain): pin bnb 8-bit/4-bit + ProTrain offload-mode (M3 audit…
thad0ctor May 10, 2026
91e0912
fix(p2p): rank-symmetric check_cuda_p2p_support + measure_nccl barrier
thad0ctor May 10, 2026
a00df59
test(protrain): real multi-GPU cross-mode resume xfail tests (M6C)
thad0ctor May 10, 2026
016dac8
docs(protrain): document mode-pinned checkpoints + Mode C plain LoRA gap
thad0ctor May 10, 2026
4856090
feat(protrain): per-container PEFT-LoRA gather in on-demand profiler …
thad0ctor May 10, 2026
a71f26e
feat(protrain): cross-mode resume hook for HF Trainer load_checkpoint…
thad0ctor May 10, 2026
4eb6da6
feat(protrain): skip profiler trace pass when explicit override knobs…
thad0ctor May 10, 2026
32663f3
feat(protrain): runtime-side per-LoRA-container gather hooks (M6C-fix-3)
thad0ctor May 10, 2026
008b62e
docs(protrain): update Mode C PEFT-LoRA section per M6C-fix-3 close
thad0ctor May 10, 2026
b5ffa3d
refactor(protrain): synchronous gather in ensure_chunks_resident (M6C…
thad0ctor May 10, 2026
b787acb
feat(protrain): late-NCCL-re-search skip on overrides + autocast diag…
thad0ctor May 11, 2026
0f44bfb
feat(protrain): per-LoRA-container post-fwd/bwd hooks (M6C-fix-6 hard…
thad0ctor May 11, 2026
55d9237
docs(protrain): formalize M6C-fix end-of-chain limitation in DESIGN.md
thad0ctor May 11, 2026
c0da428
feat(protrain): shape-preserving release-state placeholder (M6C-fix-7…
thad0ctor May 11, 2026
17ffb8d
feat(protrain): close M6C chain — DDP init-sync bypass for chunk-mana…
thad0ctor May 11, 2026
6febed8
docs(protrain): close M6C limitation section — multi-GPU plain LoRA M…
thad0ctor May 11, 2026
2fcc1fc
feat(protrain): per-dtype α fragmentation factor (Coverage audit Bloc…
thad0ctor May 12, 2026
f74c559
test(protrain): regress paged_adamw_8bit + Mode C multi-GPU @ seq=2048
thad0ctor May 12, 2026
d1ef2dd
chore(protrain): address CodeRabbit PR #21 quick-win nits
thad0ctor May 12, 2026
3aff348
chore(protrain): apply CodeRabbit re-review quick-win nits (round 2)
thad0ctor May 12, 2026
09cf657
feat(protrain): in-process rebuild lifecycle (D1/D2/D3) + P2P fail-cl…
thad0ctor May 12, 2026
d7624fb
test(protrain): address remaining CodeRabbit test-quality deferrals (…
thad0ctor May 12, 2026
6961490
feat(protrain): scheduler SWAP-stream safety barrier (R3-#1) + resume…
thad0ctor May 12, 2026
e6d8a1a
test(protrain): CodeRabbit R3 test-quality fixes (R3-#2, #3, #4, #5, #8)
thad0ctor May 12, 2026
b61f04e
feat(protrain): predict iter-1 init-transient peak (audit Block G)
thad0ctor May 12, 2026
aa0c6ba
fix(protrain): Mode-C steady-peak CKPT-chain accounting (audit Block G)
thad0ctor May 12, 2026
c996ce9
fix(protrain): close CodeRabbit R4 review (1 Critical + 2 Major + 1 M…
thad0ctor May 12, 2026
f09be09
chore(protrain): apply ruff-format reformats to cost/runtime + test_c…
thad0ctor May 12, 2026
55377e5
chore(protrain): normalize confusable unicode in commentary/docstring…
thad0ctor May 12, 2026
69eb152
fix(protrain): CodeRabbit full-review Majors — 4 real correctness gap…
thad0ctor May 12, 2026
40bb8ad
chore(protrain): CodeRabbit full-review Minors — docs consistency + t…
thad0ctor May 12, 2026
67372c3
fix(test): test_chunk_optim_shutdown caplog → mock.patch on LOG (CI f…
thad0ctor May 13, 2026
db094b5
chore(protrain): trim non-WHY comments and address CodeRabbit findings
thad0ctor May 21, 2026
cc72ca4
docs(protrain): document deferred non-compute α decomposition (ticket B)
thad0ctor May 9, 2026
3 changes: 2 additions & 1 deletion examples/protrain/3090-8b-lora.yml
@@ -85,7 +85,8 @@ tf32: false
# validator will refuse the config.
gradient_checkpointing: false

-flash_attention: false
+# M0 spike validated FA composes cleanly with ProTrain on this config.
+flash_attention: true
xformers_attention: false

# IMPORTANT: Axolotl auto-enables fused Triton LoRA kernels (q/k/v/o/MLP)
77 changes: 76 additions & 1 deletion src/axolotl/integrations/protrain/DESIGN.md
@@ -275,7 +275,41 @@ Mirrors `plan.md`:

## Design Decisions (previously open questions, now resolved)

1. **α fragmentation factor = 1.10** — matches paper's "up to 10% overestimate" (§3.3). M1 records ground truth; M4 can recalibrate if observed 3090 fragmentation diverges.
1. **α fragmentation factor — per-dtype lookup + Mode-C CKPT-chain accounting** (Coverage audit Block G, Phase 2).

*Per-dtype α (landed in commit `2fcc1fcf`).* The paper's α=1.10 default matches "up to 10% overestimate" (§3.3) when measured against fp16 / bf16 / 8-bit configurations. Block G's empirical re-derivation across the M5 / M0-spike / Block-A matrices showed α=1.10 is mildly conservative for fp16 (α_measured ≈ 0.96) and 8-bit (α_measured ≈ 0.93), but over-predicts bnb-4-bit Mode-A peak by ~37 % (α_measured ≈ 0.70 across four 8B-Llama rows). The cost model now dispatches through `alpha_fragmentation_for_dtype(bpe)` (`cost/memory.py`): fp16 / bf16 / 8-bit (bpe ≥ 1.0) → α=1.10 (`ALPHA_FRAGMENTATION`); bnb 4-bit (bpe = 0.5 via `Params4bit` packing) → α=0.75 (`ALPHA_FRAGMENTATION_4BIT`, slightly conservative vs the 0.70 empirical floor). The dominant bpe is detected in `protrain_model_wrapper` by walking `model.named_parameters()` and picking the bpe class with the largest aggregate logical-element count (bnb `Params4bit` instances are mapped to bpe=0.5 explicitly, since their storage `element_size()` is 1 but each byte packs two 4-bit values). Tests: `tests/protrain/test_alpha_per_dtype.py`.
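The dispatch and dominant-bpe selection described above can be sketched as follows. This is a minimal illustration, not the code in `cost/memory.py`: the constant and function names are taken from the text, but the `bpe < 1.0` cutoff for the 4-bit branch and the `dominant_bpe` helper (which takes plain `(element_count, bpe)` pairs instead of walking `model.named_parameters()`) are assumptions.

```python
from collections import defaultdict

# Names follow the identifiers quoted above; values are the ones stated in the text.
ALPHA_FRAGMENTATION = 1.10
ALPHA_FRAGMENTATION_4BIT = 0.75


def alpha_fragmentation_for_dtype(bpe: float) -> float:
    """Pick the fragmentation factor for a dominant bytes-per-element class.

    Assumption: anything below 1 byte per element is treated as bnb 4-bit.
    """
    return ALPHA_FRAGMENTATION if bpe >= 1.0 else ALPHA_FRAGMENTATION_4BIT


def dominant_bpe(params):
    """Return the bpe class that holds the most logical elements.

    `params` is an iterable of (logical_element_count, bpe) pairs, a
    hypothetical stand-in for walking `model.named_parameters()`.
    """
    totals = defaultdict(int)
    for numel, bpe in params:
        totals[bpe] += numel
    return max(totals, key=totals.get)


# Toy model: a large 4-bit base plus small bf16 LoRA factors and norms.
toy_params = [(6_000_000, 0.5), (4_000_000, 0.5), (200_000, 2.0), (50_000, 2.0)]
bpe = dominant_bpe(toy_params)
alpha = alpha_fragmentation_for_dtype(bpe)
```

Counting logical elements rather than storage bytes is what keeps a 4-bit base from being outvoted by its own packed `uint8` storage.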

*Mode-C steady-peak CKPT-chain accounting (this work).* Block G also observed a seq-dependent under-prediction in bnb-4-bit Mode-C (offload-pool chunk-offload + checkpoint-everywhere) configurations:

| Config (30B Llama, 4-bit Mode-C, n_persist=0, n_buffer=12, n_checkpoint=60) | pred GiB | meas steady | α_steady = meas / pred |
|---|---:|---:|---:|
| seq=512 (`ext_30b_safe.log`) | 2.49 | 2.91 | 1.169 |
| seq=1024 (`ext_30b_seq1024.log`) | 2.50 | 3.50 | 1.400 |
| seq=2048 (`ext_30b_seq2048.log`) | 2.54 | 4.68 | 1.843 |

The α_steady drift with seq is the diagnostic: the predictor's activation contribution was effectively flat across seq for all-CKPT block_maps. Root cause in `cost/memory.py::estimate_peak`: `retained_none_bytes` only accumulates NONE/OFFLOAD blocks, and the per-CKPT-first-op `ckpt_extra` bump is taken as a per-op max — so an all-CKPT cfg paid for ONE block's recompute window but nothing for the *chain* of block-input residuals that the activation-checkpointing framework (`torch.utils.checkpoint` with `use_reentrant=True`, the production wrap) retains across the WHOLE backward window. With 60 CKPT blocks on Llama-30B that chain is `60 × bs × seq × hidden × dtype_bytes` — the missing seq-dependent term.
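To get a feel for the size of the missing chain term, the formula `60 × bs × seq × hidden × dtype_bytes` can be evaluated directly. The batch size and dtype width are not stated for the audit runs, so `batch_size=1` and 2-byte activations are assumptions here; `hidden=6656` is the published LLaMA-30B hidden size.

```python
GIB = 1024 ** 3


def ckpt_chain_bytes(n_ckpt, batch_size, seq_len, hidden, dtype_bytes):
    """Bytes held by the chain of block-input residuals across backward."""
    return n_ckpt * batch_size * seq_len * hidden * dtype_bytes


# Assumptions: batch_size=1 and bf16 activations (2 bytes per element).
chain_gib = {
    seq: ckpt_chain_bytes(n_ckpt=60, batch_size=1, seq_len=seq,
                          hidden=6656, dtype_bytes=2) / GIB
    for seq in (512, 1024, 2048)
}

for seq, gib in chain_gib.items():
    print(f"seq={seq:5d}  chain ~ {gib:.2f} GiB")
```

Under these assumptions the term grows linearly with seq, from roughly 0.38 GiB at seq=512 to roughly 1.52 GiB at seq=2048, which is the kind of seq-dependent growth the flat pre-fix prediction could not express.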

*Fix.* `estimate_peak` now adds a `ckpt_chain_bytes = sum(activation_sizes[bid] for bid in CKPT blocks)` term that:

- Is added to every op-walk candidate as a constant (chain is live for the entire backward, not at any single op).
- Is added to the `raw_peak == 0` static fallback (the explicit-override `synth_trace_from_overrides` skip-trace path the Mode-C audit runs all take — `op_order=()` so the per-op walk doesn't execute).
- Is disjoint by construction from `retained_none_bytes` (NONE/OFFLOAD vs CKPT in the per-block loop above).

To avoid double-counting, the per-CKPT-first-op recompute bump is now sized at the BLOCK-INTERNAL delta only — `ckpt_extra = max(0, saved_bytes_proxy[bid] - activation_sizes[bid])` — since `activation_sizes[bid]` (the block-output / next-block-input residual proxy) is already accounted for by `ckpt_chain_bytes`. The recompute window only materializes block-internal saved tensors (Q/K/V projections, attention scores, FFN intermediates) on top of the persisted chain. In synth / toy traces where `_saved_tensor_bytes_per_block` falls back to `activation_sizes` (no `steady_fwd_block_peak_bytes` data), the internal delta is 0 and `ckpt_chain_bytes` carries the full per-block contribution. The matching enc-dec cross-attention gate (`cross_attn_persist_bytes`) skips its surcharge when the encoder-last block is in CKPT — already covered by the chain term.
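The split between the chain term and the block-internal recompute bump can be sketched on a toy block map. This is an illustration of the accounting only: `ckpt_terms` and the dictionary-based inputs are hypothetical, and the real `estimate_peak` combines these terms with the op walk and other contributions not modelled here.

```python
def ckpt_terms(block_modes, activation_sizes, saved_bytes_proxy):
    """Return (ckpt_chain_bytes, recompute_bump) for a toy block map.

    block_modes:        {block_id: "CKPT" | "NONE" | "OFFLOAD"}
    activation_sizes:   {block_id: block-output residual bytes}
    saved_bytes_proxy:  {block_id: saved-tensor bytes during recompute}
    """
    ckpt_blocks = [bid for bid, mode in block_modes.items() if mode == "CKPT"]

    # Chain term: every CKPT block's residual stays live for the whole backward.
    chain_bytes = sum(activation_sizes[bid] for bid in ckpt_blocks)

    # Recompute bump: block-internal delta only, taken as a max because only
    # one block is being recomputed at any point in the backward pass.
    extras = [
        max(0, saved_bytes_proxy[bid] - activation_sizes[bid])
        for bid in ckpt_blocks
    ]
    return chain_bytes, max(extras, default=0)


modes = {0: "CKPT", 1: "CKPT", 2: "NONE"}
acts = {0: 100, 1: 100, 2: 100}

# Profiled trace: saved tensors exceed the residual, so a delta remains.
chain, bump = ckpt_terms(modes, acts, {0: 350, 1: 400, 2: 400})

# Synth / toy trace: the proxy falls back to activation_sizes, so the delta is 0.
chain_fb, bump_fb = ckpt_terms(modes, acts, dict(acts))
```

In the fallback case the chain term carries the whole per-block contribution, which matches the behaviour described for synth traces.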

*Post-fix accuracy on the audit data points* (`estimate_peak` directly, NOT through the model wrapper's `_calibrate_peak_with_actual_chunk_bytes` post-calibration which adds a further ~0.6–0.9 GiB of actual_persistent_local correction):

| seq | estimate_peak GiB | measured | α_steady |
|----:|-----------------:|--------:|---------:|
| 512 | 2.04 | 2.91 | 1.43 |
| 1024 | 2.80 | 3.50 | 1.25 |
| 2048 | 4.34 | 4.68 | 1.08 |

α_steady is significantly tighter at high seq (1.84 → 1.08) and slightly looser at low seq (1.17 → 1.43, partly the per-dtype α shift from 1.10 to 0.75 since the audit). The chain term gives the per-seq scaling the predictor lacked; absolute accuracy at low seq is bottlenecked by the wrapper-side calibration, which is out of scope for the cost-model fix.
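The ratios quoted in the two tables can be recomputed from the reported GiB figures. The snippet uses only numbers that appear above; the dictionary layout is just for illustration.

```python
# GiB figures copied from the pre-fix and post-fix tables above.
audit_rows = {
    512: {"meas": 2.91, "pred_before": 2.49, "pred_after": 2.04},
    1024: {"meas": 3.50, "pred_before": 2.50, "pred_after": 2.80},
    2048: {"meas": 4.68, "pred_before": 2.54, "pred_after": 4.34},
}

alpha_before = {s: r["meas"] / r["pred_before"] for s, r in audit_rows.items()}
alpha_after = {s: r["meas"] / r["pred_after"] for s, r in audit_rows.items()}

# Growth of the prediction itself between seq=512 and seq=2048.
pred_growth_before = audit_rows[2048]["pred_before"] / audit_rows[512]["pred_before"]
pred_growth_after = audit_rows[2048]["pred_after"] / audit_rows[512]["pred_after"]
meas_growth = audit_rows[2048]["meas"] / audit_rows[512]["meas"]

for seq in audit_rows:
    print(f"seq={seq:5d}  before={alpha_before[seq]:.2f}  after={alpha_after[seq]:.2f}")
```

The growth ratios make the scaling point concrete: the measured peak grows about 1.6× from seq=512 to seq=2048, the pre-fix prediction grows about 1.02×, and the post-fix prediction grows about 2.1×.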

Tests: `tests/protrain/test_modec_steady_peak_accuracy.py` (pins the per-seq scaling + ±35% tolerance against the three audit data points). Existing tests adjusted: none — the `cost/memory.py` op-walk's recompute-bump refinement is backwards-compatible in every fallback regime (`_saved_tensor_bytes_per_block == activation_sizes`); the cap path and all cap-based tests are unchanged.

*Out of scope.* The iter-1 transient observed at bnb-4-bit Mode-C (~6.9× pred during the model-load → `materialize_offload` window) is an init-time chunk-residency phenomenon, not a fragmentation or activation-accounting one, and is documented separately as an "init window" not covered by α. Tracked as the remaining open audit item.
2. **Pinned-memory allocator:** `ctypes` → `cudaHostAlloc` directly. ~50 LOC, zero new deps, matches App B.2 precisely (avoids `CUDAHostAllocator` pow-2 rounding). DeepSpeed's `PinnedMemoryAllocator` rejected: may inherit same wart, adds import-graph weight.
3. **CPU FusedAdam source:** `deepspeed.ops.adam.DeepSpeedCPUAdam`. Paper builds directly on ZeRO-Offload's CPU Adam. Pure-Python reimpl is >10× slower and would collapse the T_bwd / T_cpu_optim overlap window the cost model assumes. DeepSpeed is already in Axolotl's env.
4. **S_chunk grid:** `{32, 64, 128, 256} MB`. 7B Llama blocks are ~200 MB fp16 → chunks want to be block-scale. 16 MB is too fine-grained; per-chunk sync overhead dominates. M2 agent extends the grid if optimum lands at an endpoint.
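The pow-2 rounding concern from item 2 can be illustrated with a few lines of arithmetic. This sketch assumes the allocator rounds each request up to the next power of two and that the pinned pool is requested as one allocation sized to the sum of its chunks; both are assumptions for illustration, and the chunk counts are made up.

```python
MIB = 1024 ** 2


def next_pow2(n: int) -> int:
    """Smallest power of two that is >= n."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def rounding_waste(request_bytes: int) -> int:
    """Bytes lost if a request is rounded up to the next power of two."""
    return next_pow2(request_bytes) - request_bytes


# Hypothetical pool: five 64 MiB chunks requested as a single 320 MiB block.
pool_bytes = 5 * 64 * MIB
waste_mib = rounding_waste(pool_bytes) // MIB

# A request that is already a power of two loses nothing.
exact_waste = rounding_waste(128 * MIB)
```

A precise allocation holds exactly the 320 MiB requested, whereas a pow-2 rounding allocator would pin 512 MiB for the same pool.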
@@ -323,3 +357,44 @@ App B.2 of the paper has **two distinct components**, each addressing a differen
#### Measurement status

Peak-memory delta from the wire-up has not been measured on RTX 3090 reference hardware in this commit (the `α = 1.10` fragmentation factor — item 1 above — was already absorbing the un-wired fragmentation cost in the cost model). To-be-measured in a follow-up: re-run the M1 profiler ground-truth before and after the wire-up; if peak drops by more than ~5% on a 1.5B-param target shape, recalibrate `α` downward. The single-stream wire-up's correctness — the `record_stream` discipline at every cross-stream site — has been validated by the new `tests/protrain/test_single_stream_allocator.py` test (heap-affinity assertion via free-then-reallocate fragmentation probe + nested-stream context-manager composition test). The pinned-host wire-up's correctness — total pool bytes equals the sum of per-chunk aligned bytes — is asserted by `tests/protrain/test_chunk_manager_offload.py::test_materialize_offload_uses_precise_pinned_pool`.

## Known Limitations

### Checkpoint mode handling (Phase 2 M6C)

ProTrain checkpoints encode the mode they were produced under (Mode A all-persistent vs. Mode C sharded-with-offload), so the resume path must reconcile the on-disk layout with the resumed-runtime layout. Two cases:

- **Same-mode resume** (Mode A → Mode A, Mode C → Mode C) is the simple path — the chunk layout and optimizer-state shapes are identical so HF Trainer's `_load_from_checkpoint` copies straight in.
- **Cross-mode resume** (Mode A → Mode C, Mode C → Mode A) is bridged by **M6C-fix-1** (`a71f26e9`): the resume hook in `plugin._install_resume_hook` calls `restore_to_gpu()` on every offloaded chunk BEFORE HF copies the loaded weights into full-shape `param.data` slots, then re-runs `materialize_offload` afterward and rebuilds the per-chunk optimizer adapter. Without this hook HF would write into the zeroed non-persistent slots and ProTrain's first `gather` would overwrite the loaded state with the (still-zero) CPU shadow. The hook is registered as an HF Trainer callback that fires after `_load_from_checkpoint` finishes; ProTrain interleaves its `gather` between weight load and the first forward in plugin code rather than forking HF.


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Clarify when the resume hook actually runs.

This paragraph says ProTrain restores full-shape tensors before HF copies weights, then says the hook fires after `_load_from_checkpoint` finishes. Those are different insertion points, so the recovery sequence is ambiguous as written.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/protrain/DESIGN.md` around lines 368 - 369, The
paragraph is ambiguous about ordering: update the text to clearly state that
plugin._install_resume_hook registers a Trainer callback that runs immediately
after HF's _load_from_checkpoint completes but before HF performs its final
parameter copy into full-shape param.data slots; the hook calls restore_to_gpu()
on offloaded chunks, then ProTrain lets HF finish copying, then calls
materialize_offload and rebuilds per-chunk optimizer adapters so that ProTrain's
first gather sees the restored weights rather than zeroed CPU shadows—reference
plugin._install_resume_hook, restore_to_gpu(), materialize_offload, gather and
_load_from_checkpoint in that clarified sequence.

Real-multigpu cross-mode resume coverage (4×3090, sharded Mode C, Llama-3-8B + LoRA): both `test_real_multigpu_cross_mode_resume_a_to_c` and `test_real_multigpu_cross_mode_resume_c_to_a` PASS as of the full M6C-fix-1..8 chain. See § "Standard PEFT-LoRA in Mode C" below for the chain's other layers (which closed PEFT-LoRA Mode-C correctness on top of the resume-hook fix).
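The hazard that the cross-mode resume hook guards against can be reproduced with a toy model of one offloaded chunk. `ToyChunk` and `load_checkpoint` are hypothetical stand-ins, with method names borrowed from the text; the sketch only shows why the restore, load, re-offload order matters and says nothing about where the hook is attached inside HF Trainer.

```python
class ToyChunk:
    """Hypothetical stand-in for one non-persistent, offloaded chunk."""

    def __init__(self, n):
        self.cpu_shadow = [0.0] * n  # authoritative copy while offloaded
        self.gpu_slot = [0.0] * n    # zeroed full-shape slot while released

    def restore_to_gpu(self):
        self.gpu_slot = list(self.cpu_shadow)

    def materialize_offload(self):
        # Re-offload: the CPU shadow takes whatever the GPU slot now holds.
        self.cpu_shadow = list(self.gpu_slot)
        self.gpu_slot = [0.0] * len(self.gpu_slot)

    def gather(self):
        # First gather before forward: the GPU slot is refilled from the shadow.
        self.gpu_slot = list(self.cpu_shadow)
        return self.gpu_slot


def load_checkpoint(chunk, weights):
    """Stand-in for HF copying loaded weights into `param.data`."""
    chunk.gpu_slot = list(weights)


ckpt = [1.0, 2.0, 3.0]

# Without the hook: the load lands in the slot, then gather clobbers it.
unhooked = ToyChunk(3)
load_checkpoint(unhooked, ckpt)
lost = unhooked.gather()

# With the hook: restore, load, re-offload, then gather sees loaded state.
hooked = ToyChunk(3)
hooked.restore_to_gpu()
load_checkpoint(hooked, ckpt)
hooked.materialize_offload()
kept = hooked.gather()
```

In the unhooked path the loaded weights never reach the CPU shadow, so the first gather returns zeros; in the hooked path the re-offload step is what carries the loaded state into the shadow.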

### Standard PEFT-LoRA in Mode C (Phase 2 M6C)

Plain `peft` LoRA on top of an unquantized base is **supported in single-GPU offload mode** as of `M6C-fix-2` + `M6C-fix-3` (per-PEFT-LoRA-container gather hooks installed at both profiler-trace and runtime-scheduler surfaces). The chain works as follows:

- `profiler/on_demand.py::_find_peft_lora_containers` discovers any module with direct trainable LoRA factors (`lora_A` / `lora_B` / `lora_magnitude_vector` / `lora_embedding_*`). Pre-forward and pre-backward gather hooks are installed at the *container* granularity (parallel to M1's fused-kernel-container strategy), so the LoRA factor sub-chunks are GPU-resident before PEFT's `LoraLayer.forward` casts them to bf16.
- `runtime/hooks.py` + `runtime/scheduler.py::ensure_chunks_resident` install the same container-granularity hooks on the live training scheduler. Without this, the runtime's block-level gather (which assumes per-block chunk granularity) leaves the LoRA sub-chunks released until after the PEFT cast op records its autograd shape, producing the canonical `ToCopyBackward0 returned an invalid gradient at index 0 - got [N, R] but expected shape compatible with [0]` failure.
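Container discovery can be approximated from parameter names alone. The sketch below is not `_find_peft_lora_containers`: it is a name-based approximation that assumes PEFT-style dotted parameter names and takes `(name, requires_grad)` pairs in place of a real module tree.

```python
import re

# Attribute names listed above as trainable LoRA factors.
LORA_FACTOR = re.compile(
    r"^(lora_A|lora_B|lora_magnitude_vector|lora_embedding_\w+)$"
)


def find_lora_containers(named_params):
    """Return module paths that directly own a trainable LoRA factor.

    `named_params` is an iterable of (dotted_name, requires_grad) pairs.
    The container is the path prefix just before the LoRA attribute.
    """
    containers = []
    for name, requires_grad in named_params:
        if not requires_grad:
            continue
        parts = name.split(".")
        for i, part in enumerate(parts):
            if LORA_FACTOR.match(part):
                container = ".".join(parts[:i])
                if container not in containers:
                    containers.append(container)
                break
    return containers


toy_names = [
    ("model.layers.0.self_attn.q_proj.base_layer.weight", False),
    ("model.layers.0.self_attn.q_proj.lora_A.default.weight", True),
    ("model.layers.0.self_attn.q_proj.lora_B.default.weight", True),
    ("model.layers.0.mlp.up_proj.lora_A.default.weight", True),
    ("model.embed_tokens.lora_embedding_A.default", True),
    ("model.norm.weight", True),
]
containers = find_lora_containers(toy_names)
```

Hooking at the container path rather than at `lora_A` / `lora_B` themselves is what lets a single pre-forward gather cover every factor the wrapped layer touches.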

**Multi-GPU sharded mode (`protrain_zero3_shard: true, world_size > 1`) is supported for plain LoRA** as of the M6C-fix-1 through M6C-fix-8 chain (8 commits). Each fix closed a layer of the failure stack:

- **fix-1** (`a71f26e9`) — cross-mode resume hook for HF Trainer `_load_from_checkpoint`.
- **fix-2** (`4856090e`) — per-PEFT-LoRA-container gather hooks in `profiler/on_demand.py`.
- **fix-3** (`32663f30`) — runtime-side per-LoRA-container gather hooks in `runtime/hooks.py`.
- **fix-4** (`b5ffa3d9`) — synchronous gather in `Scheduler.ensure_chunks_resident`.
- **fix-5** (`b787acb5`) — late-NCCL-re-search skip on explicit-override paths + autocast diagnostic.
- **fix-6** (`0f44bfb6`) — pre/post forward+backward quartet hooks per LoRA container.
- **fix-7** (`c0da4282`) — shape-preserving release-state placeholder (closes the `ToCopyBackward0 / TBackward0 ... shape compatible with [0]` autograd shape-capture error class via `scratch.expand(slot.shape)` views that preserve `param.size()` metadata across release/re-gather).
- **fix-8** (`17ffb8d1`) — DDP `init_sync=False` bypass for chunk-managed params (closes the residual `more than one element of the written-to tensor refers to a single memory location` from DDP's construction-time `_sync_module_states._broadcast_coalesced` writing into the expand-view placeholder).

Multi-GPU verification (4×3090, sharded Mode C, Llama-3-8B + LoRA): `test_real_multigpu_cross_mode_resume_a_to_c` PASSES (Phase 1 Mode A 5 steps + Phase 2 Mode C resume steps 6..10; losses 1.093 → 0.832); `test_real_multigpu_cross_mode_resume_c_to_a` PASSES (Phase 1 Mode C 5 steps + Phase 2 Mode A resume steps 6..10).

Architecturally, ProTrain now owns the parallelism contract for chunk-managed parameters end-to-end: per-rank deterministic partition via `materialize_offload`, sharded gather via `_gather_sharded`, `reduce_scatter` on backward via `reduce_grads_and_offload`, and the DDP construction-time broadcast bypass keeps DDP from clobbering the sharded layout with its replicated broadcast assumption.
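The three-part contract (deterministic partition, sharded gather, reduce-scatter on backward) can be simulated in a single process with plain lists. All function names here are hypothetical, the padding scheme is an assumption, and the reduction is a plain sum; whether the real path sums or averages is not stated above.

```python
def partition(flat, world_size):
    """Split a flat chunk into equal per-rank shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards, numel):
    """Rebuild the full chunk from every rank's shard, dropping the padding."""
    return [x for shard in shards for x in shard][:numel]


def reduce_scatter(per_rank_grads, world_size):
    """Sum full-size grads across ranks, then hand each rank its own shard."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return partition(summed, world_size)


world_size = 2
chunk = [1.0, 2.0, 3.0, 4.0, 5.0]

shards = partition(chunk, world_size)        # what each rank keeps resident
full = all_gather(shards, len(chunk))        # what forward/backward compute on
grads = [[1.0] * 5, [2.0] * 5]               # one full-size grad per rank
grad_shards = reduce_scatter(grads, world_size)
```

Each rank ends the step holding only its own shard of weights and reduced gradients, which is the layout a replicated construction-time broadcast would overwrite.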

**Supported configurations (no workaround needed):**

- **Single-GPU plain fp16 / bf16 LoRA in offload mode** — works directly as of M6C-fix-3; no special config beyond `protrain_force_all_persistent: false` and the override knobs.
- **Multi-GPU sharded plain fp16 / bf16 LoRA in offload mode** — works as of the full M6C-fix-1..8 chain. The runtime/profiler-side gather hooks (fix-2, fix-3, fix-4, fix-6), the shape-preserving release-state placeholder (fix-7), and the DDP init-sync bypass (fix-8) together close the chain that previously surfaced as `ToCopyBackward0 ... shape compatible with [0]` and DDP `_sync_module_states._broadcast_coalesced` shared-storage hazards.
- **Quantized base + LoRA** — pair LoRA with bnb 4-bit or 8-bit weight quantization. `bitsandbytes.nn.Linear4bit` / `Linear8bitLt` use typed `param.data` views that survive the non-persistent slot lifecycle in both single- and multi-GPU; the M3 13B headline test exercises this combination.

Coverage: `tests/protrain/test_lora_offload_mode.py` (22 tests, single-GPU plain LoRA Mode C end-to-end, all PASS); `tests/protrain/test_cross_mode_resume.py` real-multigpu tests `_a_to_c` and `_c_to_a` PASS as of M6C-fix-8 (xfail markers removed in commit `17ffb8d1`); `tests/protrain/test_paged_adam_offload_mgpu.py` regresses the bnb 4-bit + paged_adamw_8bit + Mode C at seq=2048 multi-GPU path that M6C-fix-8 also closed. The M6C report under `docs/protrain/` traces the historical failure modes.