diff --git a/examples/protrain/3090-8b-lora.yml b/examples/protrain/3090-8b-lora.yml index a521379f1b..20a40f5464 100644 --- a/examples/protrain/3090-8b-lora.yml +++ b/examples/protrain/3090-8b-lora.yml @@ -85,7 +85,8 @@ tf32: false # validator will refuse the config. gradient_checkpointing: false -flash_attention: false +# M0 spike validated FA composes cleanly with ProTrain on this config. +flash_attention: true xformers_attention: false # IMPORTANT: Axolotl auto-enables fused Triton LoRA kernels (q/k/v/o/MLP) diff --git a/src/axolotl/integrations/protrain/DESIGN.md b/src/axolotl/integrations/protrain/DESIGN.md index d40dea9ea4..abe8caccb9 100644 --- a/src/axolotl/integrations/protrain/DESIGN.md +++ b/src/axolotl/integrations/protrain/DESIGN.md @@ -46,7 +46,7 @@ src/axolotl/integrations/protrain/ ├── cost/ │ ├── __init__.py │ ├── runtime.py # Eqs. 2–7, per-chunk max(compute, comm) roofline -│ ├── memory.py # Eqs. 8–11, op-walk peak + α=1.10 fragmentation +│ ├── memory.py # Eqs. 8–11, op-walk peak + per-dtype fragmentation alpha (see Design Decision 1) │ └── bandwidth.py # contention model when n_swap>0 competes with prefetch ├── search/ │ ├── __init__.py @@ -108,14 +108,14 @@ Every entry: Inputs · Outputs · Paper ref · Milestone. - `dispatcher.py` — `wrap_block(block: nn.Module, mode: BlockMode) -> nn.Module`. §3.1.2. - `checkpoint.py` — thin wrapper over `torch.utils.checkpoint.checkpoint` (use_reentrant=False). §3.1.2. - `swap.py` — `SwappedBlock`: wraps the block's forward in a `torch.autograd.graph.saved_tensors_hooks` context so **every autograd-saved tensor** (not just the block output) is D2H-copied to a pinned-host slot on `_swap_stream` in forward and H2D-copied back on `_swap_stream` in backward, with cross-stream event handshake against the default compute stream. 
Pool + stream are injected post-construction via `attach_runtime`; wrapper lifetime spans one fwd+bwd pair, and memory accounting must charge the sum of saved-tensor bytes (activations, RNG state, intermediate tensors), not just the block output. §3.1.2. -- `swap_pool.py` — `ActivationSwapPool`: pinned-host slot pool sized to `n_swap × prefetch_depth × max_act_bytes`. Backed by one `PinnedHostMemory` allocation; slot acquire/release tracked Python-side. §3.1.2. +- `swap_pool.py` — `ActivationSwapPool`: pinned-host slot pool sized to `n_swap x prefetch_depth x max_act_bytes`. Backed by one `PinnedHostMemory` allocation; slot acquire/release tracked Python-side. §3.1.2. - `offload.py` — Option B path: runs a non-persistent chunk's owning block under `BlockMode.OFFLOAD` (no recompute), re-gathering the chunk for backward and offloading after fwd. See `BLOCK_MODE_OFFLOAD_DESIGN.md` §3 / §6 for the storage-ptr book-keeping and runtime hook contract. - `layout_rules.py` — `assign_modes(n_swap, n_checkpoint, n_offload, N_block) -> BlockStrategyMap`. Swap-early / unopt-late / interleave; `n_offload` honors the unopt-late rule (`BLOCK_MODE_OFFLOAD_DESIGN.md` §5.1). §3.1.2. ### cost/ (M4) - `runtime.py` — `estimate_runtime(cfg, trace, layout) -> float`. Implements **Eqs. 2–7**: `T_iter = T_fwd + max(T_bwd + T_gpu_optim, T_cpu_optim)`, per-chunk `max(compute, comm)` roofline. §3.3, App A.1. -- `memory.py` — `estimate_peak(cfg, trace, layout, block_map) -> int`. Implements **Eqs. 8–10** (op-walk) and **Eq. 11** (α = 1.10 fragmentation). Bumps are added at the first op of each `BlockMode.CKPT` block (recompute) and additionally at the first op of each `BlockMode.OFFLOAD` block (Option B backward gather), so both block types contribute to the per-block backward memory bump. §3.3, App A.2. +- `memory.py` — `estimate_peak(cfg, trace, layout, block_map) -> int`. Implements **Eqs. 8–10** (op-walk) and **Eq. 
11** (per-dtype fragmentation alpha — `ALPHA_FRAGMENTATION = 1.10` for fp16 / bf16 / 8-bit; `ALPHA_FRAGMENTATION_4BIT = 0.75` for bnb 4-bit via `alpha_fragmentation_for_dtype(hw.dominant_param_bytes_per_element)`; see Design Decision 1 for the audit-data-driven calibration). Bumps are added at the first op of each `BlockMode.CKPT` block (recompute) and additionally at the first op of each `BlockMode.OFFLOAD` block (Option B backward gather), so both block types contribute to the per-block backward memory bump. §3.3, App A.2. - `bandwidth.py` — `effective_bw(cfg, hw) -> float`. Derates prefetch BW when `n_swap > 0`. §3.3. ### search/ (M4) @@ -275,11 +275,45 @@ Mirrors `plan.md`: ## Design Decisions (previously open questions, now resolved) -1. **α fragmentation factor = 1.10** — matches paper's "up to 10% overestimate" (§3.3). M1 records ground truth; M4 can recalibrate if observed 3090 fragmentation diverges. +1. **alpha fragmentation factor — per-dtype lookup + Mode-C CKPT-chain accounting** (Coverage audit Block G, Phase 2). + + *Per-dtype alpha (landed in commit `2fcc1fcf`).* The paper's alpha=1.10 default matches "up to 10% overestimate" (§3.3) when measured against fp16 / bf16 / 8-bit configurations. Block G's empirical re-derivation across the M5 / M0-spike / Block-A matrices showed alpha=1.10 is mildly conservative for fp16 (alpha_measured ≈ 0.96) and 8-bit (alpha_measured ≈ 0.93), but over-predicts bnb-4-bit Mode-A peak by ~37 % (alpha_measured ≈ 0.70 across four 8B-Llama rows). The cost model now dispatches through `alpha_fragmentation_for_dtype(bpe)` (`cost/memory.py`): fp16 / bf16 / 8-bit (bpe ≥ 1.0) → alpha=1.10 (`ALPHA_FRAGMENTATION`); bnb 4-bit (bpe = 0.5 via `Params4bit` packing) → alpha=0.75 (`ALPHA_FRAGMENTATION_4BIT`, slightly conservative vs the 0.70 empirical floor). 
The dominant bpe is detected in `protrain_model_wrapper` by walking `model.named_parameters()` and picking the bpe class with the largest aggregate logical-element count (bnb `Params4bit` instances are mapped to bpe=0.5 explicitly, since their storage `element_size()` is 1 but each byte packs two 4-bit values). Tests: `tests/protrain/test_alpha_per_dtype.py`. + + *Mode-C steady-peak CKPT-chain accounting (this work).* Block G also observed a seq-dependent under-prediction in bnb-4-bit Mode-C (offload-pool chunk-offload + checkpoint-everywhere) configurations: + + | Config (30B Llama, 4-bit Mode-C, n_persist=0, n_buffer=12, n_checkpoint=60) | pred GiB | meas steady | alpha_steady = meas / pred | + |---|---:|---:|---:| + | seq=512 (`ext_30b_safe.log`) | 2.49 | 2.91 | 1.169 | + | seq=1024 (`ext_30b_seq1024.log`) | 2.50 | 3.50 | 1.400 | + | seq=2048 (`ext_30b_seq2048.log`) | 2.54 | 4.68 | 1.843 | + + The alpha_steady drift with seq is the diagnostic: the predictor's activation contribution was effectively flat across seq for all-CKPT block_maps. Root cause in `cost/memory.py::estimate_peak`: `retained_none_bytes` only accumulates NONE/OFFLOAD blocks, and the per-CKPT-first-op `ckpt_extra` bump is taken as a per-op max — so an all-CKPT cfg paid for ONE block's recompute window but nothing for the per-block-input residual that survives across the backward window. (Production uses `use_reentrant=False` per `block/checkpoint.py`; the non-reentrant variant still retains a linear-in-N_block activation footprint across the backward window because each CKPT block's saved-tensors-hooks recompute frame holds the block input — `block_input.requires_grad` and the autograd graph keep it pinned until the upstream backward completes.) With 60 CKPT blocks on Llama-30B that chain term is `60 x bs x seq x hidden x dtype_bytes` — the missing seq-dependent term the audit data exposes. 
+ + *Fix.* `estimate_peak` now adds a `ckpt_chain_bytes = sum(activation_sizes[bid] for bid in CKPT blocks)` term that: + + - Is added to every op-walk candidate as a constant (chain is live for the entire backward, not at any single op). + - Is added to the `raw_peak == 0` static fallback (the explicit-override `synth_trace_from_overrides` skip-trace path the Mode-C audit runs all take — `op_order=()` so the per-op walk doesn't execute). + - Is disjoint by construction from `retained_none_bytes` (NONE/OFFLOAD vs CKPT in the per-block loop above). + + To avoid double-counting, the per-CKPT-first-op recompute bump is now sized at the BLOCK-INTERNAL delta only — `ckpt_extra = max(0, saved_bytes_proxy[bid] - activation_sizes[bid])` — since `activation_sizes[bid]` (the block-output / next-block-input residual proxy) is already accounted for by `ckpt_chain_bytes`. The recompute window only materializes block-internal saved tensors (Q/K/V projections, attention scores, FFN intermediates) on top of the persisted chain. In synth / toy traces where `_saved_tensor_bytes_per_block` falls back to `activation_sizes` (no `steady_fwd_block_peak_bytes` data), the internal delta is 0 and `ckpt_chain_bytes` carries the full per-block contribution. The matching enc-dec cross-attention gate (`cross_attn_persist_bytes`) skips its surcharge when the encoder-last block is in CKPT — already covered by the chain term. 
+ + *Post-fix accuracy on the audit data points* (`estimate_peak` directly, NOT through the model wrapper's `_calibrate_peak_with_actual_chunk_bytes` post-calibration which adds a further ~0.6–0.9 GiB of actual_persistent_local correction): + + | seq | estimate_peak GiB | measured | alpha_steady | + |----:|-----------------:|--------:|---------:| + | 512 | 2.04 | 2.91 | 1.43 | + | 1024 | 2.80 | 3.50 | 1.25 | + | 2048 | 4.34 | 4.68 | 1.08 | + + alpha_steady is significantly tighter at high seq (1.84 → 1.08) and slightly looser at low seq (1.17 → 1.43, partly the per-dtype alpha shift from 1.10 to 0.75 since the audit). The chain term gives the per-seq scaling the predictor lacked; absolute accuracy at low seq is bottlenecked by the wrapper-side calibration, which is out of scope for the cost-model fix. + + Tests: `tests/protrain/test_modec_steady_peak_accuracy.py` (pins the per-seq scaling + ±35% tolerance against the three audit data points). Existing tests adjusted: none — the `cost/memory.py` op-walk's recompute-bump refinement is backwards-compatible in every fallback regime (`_saved_tensor_bytes_per_block == activation_sizes`); the cap path and all cap-based tests are unchanged. + + *Out of scope.* The iter-1 transient observed at bnb-4-bit Mode-C (~6.9x pred during the model-load → `materialize_offload` window) is an init-time chunk-residency phenomenon, not a fragmentation or activation-accounting one, and is documented separately as an "init window" not covered by alpha. Tracked as the remaining open audit item. 2. **Pinned-memory allocator:** `ctypes` → `cudaHostAlloc` directly. ~50 LOC, zero new deps, matches App B.2 precisely (avoids `CUDAHostAllocator` pow-2 rounding). DeepSpeed's `PinnedMemoryAllocator` rejected: may inherit same wart, adds import-graph weight. -3. **CPU FusedAdam source:** `deepspeed.ops.adam.DeepSpeedCPUAdam`. Paper builds directly on ZeRO-Offload's CPU Adam. 
Pure-Python reimpl is >10× slower and would collapse the T_bwd / T_cpu_optim overlap window the cost model assumes. DeepSpeed is already in Axolotl's env. +3. **CPU FusedAdam source:** `deepspeed.ops.adam.DeepSpeedCPUAdam`. Paper builds directly on ZeRO-Offload's CPU Adam. Pure-Python reimpl is >10x slower and would collapse the T_bwd / T_cpu_optim overlap window the cost model assumes. DeepSpeed is already in Axolotl's env. 4. **S_chunk grid:** `{32, 64, 128, 256} MB`. 7B Llama blocks are ~200 MB fp16 → chunks want to be block-scale. 16 MB is too fine-grained; per-chunk sync overhead dominates. M2 agent extends the grid if optimum lands at an endpoint. -5. **SWAP path:** paper-real D2H/H2D wrapper on `_swap_stream`, backed by `ActivationSwapPool` (pinned host slots sized `n_swap × prefetch_depth × max_act_bytes`). Searcher's CPU-feasibility gate refuses `n_swap > 0` candidates whose pool would not fit `cpu_capacity_bytes`. On RTX 3090 / 3090 Ti (12 GB/s PCIe ceiling, no NVLink) the searcher rarely selects `n_swap > 0` — paper §3.1.2 — so the path is tested-but-unused infrastructure on this hardware class. Validated end-to-end via the wrapper-injection path with `n_swap_override`. +5. **SWAP path:** paper-real D2H/H2D wrapper on `_swap_stream`, backed by `ActivationSwapPool` (pinned host slots sized `n_swap x prefetch_depth x max_act_bytes`). Searcher's CPU-feasibility gate refuses `n_swap > 0` candidates whose pool would not fit `cpu_capacity_bytes`. On RTX 3090 / 3090 Ti (12 GB/s PCIe ceiling, no NVLink) the searcher rarely selects `n_swap > 0` — paper §3.1.2 — so the path is tested-but-unused infrastructure on this hardware class. Validated end-to-end via the wrapper-injection path with `n_swap_override`. ### Memory Allocation Strategy (App B.2 — WIRED) @@ -297,7 +331,7 @@ App B.2 of the paper has **two distinct components**, each addressing a differen - **Heap routing vs. 
kernel scheduling.** App B.2 governs *which heap an allocation comes from*, not which stream a kernel runs on. The wire-up keeps the dedicated `_prefetch_stream` and `_swap_stream` for PCIe-vs-compute overlap (those streams are about *kernel launch ordering*) but routes the *allocations* underneath them through the default-stream heap via `SingleStreamAllocator`. Cross-stream tensor consumption stays correct because every wrapped allocation that hands a buffer to a non-default stream calls `tensor.record_stream(non_default_stream)` immediately after exiting the allocator context, defering allocator reuse until the consuming stream has retired the work. - **Wired call sites.** - - `chunk/buffer_pool.py::BufferPool.__init__` — pre-allocates every pool slot (n_buffer × S_chunk bytes) on the default-stream heap. **Highest-leverage single change** — pool slots are the dominant sustained GPU allocation in ProTrain. No `record_stream` needed: pool slots' lifetimes are owned by the pool and only return to the allocator at teardown. + - `chunk/buffer_pool.py::BufferPool.__init__` — pre-allocates every pool slot (n_buffer x S_chunk bytes) on the default-stream heap. **Highest-leverage single change** — pool slots are the dominant sustained GPU allocation in ProTrain. No `record_stream` needed: pool slots' lifetimes are owned by the pool and only return to the allocator at teardown. - `chunk/manager.py::_ensure_persistent_buffer` — long-lived persistent-chunk GPU buffers. No `record_stream` (long-lived). - `chunk/manager.py::_empty_placeholder` — cached zero-element `param.data` sentinel. No `record_stream` (process-lived, not a kernel consumer). - `chunk/manager.py::_gather_sharded` — per-region `my_shard_gpu` and `gather_scratch` scratch tensors. **Critical wrap** — this method is called from `Scheduler._gather_on_prefetch_stream` inside `with torch.cuda.stream(self._prefetch_stream):`. 
Without the wrap, scratch tensors would land on the prefetch-stream heap and fragment the allocator. `record_stream(current_stream)` discipline applied: the scratch buffers are tied to whichever stream is actually consuming them (the prefetch stream in steady-state, the default stream in synchronous fallback). @@ -311,15 +345,56 @@ App B.2 of the paper has **two distinct components**, each addressing a differen - **Paper's design.** PyTorch's `torch.empty(pin_memory=True)` routes through `CUDAHostAllocator`, which rounds the requested byte count up to the next power of two. For a 24 MB chunk that's a 32 MB allocation; for the trailing chunk of a 7B-param model the round-up can waste tens of MB across the offload set. ProTrain implements its own pinned allocator (`chunk/pinned_alloc.py::PinnedHostMemory`) that calls `cudaHostAlloc` directly via `ctypes` with the exact byte count, avoiding the rounding waste entirely. -- **PinnedHostMemory contract.** `PinnedHostMemory(n_buffer, S_chunk)` allocates `n_buffer × S_chunk` bytes pinned-host. `buffer(i)` returns a zero-copy `torch.Tensor` view over slot `i`; `release_buffer(i)` decrements the borrow refcount. `close()` raises if any borrow is still outstanding (use-after-free guard). The `__del__` path leaks rather than free under outstanding borrows, on the basis that a destructor-time leak is preferable to a dangling-pointer free. If `libcudart` cannot be loaded via `ctypes`, the allocator falls back to `torch.empty(size, pin_memory=True)` and exposes `is_precise_size = False` so tests can detect the regression. +- **PinnedHostMemory contract.** `PinnedHostMemory(n_buffer, S_chunk)` allocates `n_buffer x S_chunk` bytes pinned-host. `buffer(i)` returns a zero-copy `torch.Tensor` view over slot `i`; `release_buffer(i)` decrements the borrow refcount. `close()` raises if any borrow is still outstanding (use-after-free guard). 
The `__del__` path leaks rather than free under outstanding borrows, on the basis that a destructor-time leak is preferable to a dangling-pointer free. If `libcudart` cannot be loaded via `ctypes`, the allocator falls back to `torch.empty(size, pin_memory=True)` and exposes `is_precise_size = False` so tests can detect the regression. - **Wired call sites (pinned host).** - - `chunk/buffer_pool.py::BufferPool.__init__` — backing pinned-host region for the GPU buffer pool's H2D staging slots (`n_buffer × S_chunk`). One `PinnedHostMemory` per pool. + - `chunk/buffer_pool.py::BufferPool.__init__` — backing pinned-host region for the GPU buffer pool's H2D staging slots (`n_buffer x S_chunk`). One `PinnedHostMemory` per pool. - `chunk/manager.py::materialize_offload` — TWO unified `PinnedHostMemory` regions per manager: one for every non-persistent chunk's param shadow (replicated) or per-rank shard bytes (sharded), one for trainable-param grad shadows. Sized to the precise sum of per-chunk aligned bytes plus a 16-byte inter-chunk alignment pad. Per-chunk views into the pools are `narrow()` slices; the BUG 2 intra-chunk dtype-region alignment is preserved per-chunk under the unified layout. Closed via `_close_cpu_pools` from `restore_to_gpu` (deterministic teardown) or `__del__` (GC safety net). See `tests/protrain/test_chunk_manager_offload.py::test_materialize_offload_uses_precise_pinned_pool` for the precise-sizing assertion. - - `block/swap_pool.py::ActivationSwapPool` — backing pinned-host region for activation swap slots (`n_swap × prefetch_depth × max_act_bytes`). One `PinnedHostMemory` per pool. + - `block/swap_pool.py::ActivationSwapPool` — backing pinned-host region for activation swap slots (`n_swap x prefetch_depth x max_act_bytes`). One `PinnedHostMemory` per pool. - **Allocation sites still on `torch.empty(pin_memory=True)` (unintentional).** *None* in the wired ProTrain runtime as of this commit. 
If a follow-up adds a new pinned-host allocation site it should default to `PinnedHostMemory` for paper fidelity. #### Measurement status -Peak-memory delta from the wire-up has not been measured on RTX 3090 reference hardware in this commit (the `α = 1.10` fragmentation factor — item 1 above — was already absorbing the un-wired fragmentation cost in the cost model). To-be-measured in a follow-up: re-run the M1 profiler ground-truth before and after the wire-up; if peak drops by more than ~5% on a 1.5B-param target shape, recalibrate `α` downward. The single-stream wire-up's correctness — the `record_stream` discipline at every cross-stream site — has been validated by the new `tests/protrain/test_single_stream_allocator.py` test (heap-affinity assertion via free-then-reallocate fragmentation probe + nested-stream context-manager composition test). The pinned-host wire-up's correctness — total pool bytes equals the sum of per-chunk aligned bytes — is asserted by `tests/protrain/test_chunk_manager_offload.py::test_materialize_offload_uses_precise_pinned_pool`. +Peak-memory delta from the wire-up has not been measured on RTX 3090 reference hardware in this commit (the `alpha = 1.10` fragmentation factor — item 1 above — was already absorbing the un-wired fragmentation cost in the cost model). To-be-measured in a follow-up: re-run the M1 profiler ground-truth before and after the wire-up; if peak drops by more than ~5% on a 1.5B-param target shape, recalibrate `alpha` downward. The single-stream wire-up's correctness — the `record_stream` discipline at every cross-stream site — has been validated by the new `tests/protrain/test_single_stream_allocator.py` test (heap-affinity assertion via free-then-reallocate fragmentation probe + nested-stream context-manager composition test). 
The pinned-host wire-up's correctness — total pool bytes equals the sum of per-chunk aligned bytes — is asserted by `tests/protrain/test_chunk_manager_offload.py::test_materialize_offload_uses_precise_pinned_pool`. + +## Known Limitations + +### Checkpoint mode handling (Phase 2 M6C) + +ProTrain checkpoints encode the mode they were produced under (Mode A all-persistent vs. Mode C sharded-with-offload), so the resume path must reconcile the on-disk layout with the resumed-runtime layout. Two cases: + +- **Same-mode resume** (Mode A → Mode A, Mode C → Mode C) is the simple path — the chunk layout and optimizer-state shapes are identical so HF Trainer's `_load_from_checkpoint` copies straight in. +- **Cross-mode resume** (Mode A → Mode C, Mode C → Mode A) is bridged by **M6C-fix-1** (`a71f26e9`): the resume hook in `plugin._install_resume_hook` calls `restore_to_gpu()` on every offloaded chunk BEFORE HF copies the loaded weights into full-shape `param.data` slots, then re-runs `materialize_offload` afterward and rebuilds the per-chunk optimizer adapter. Without this hook HF would write into the zeroed non-persistent slots and ProTrain's first `gather` would overwrite the loaded state with the (still-zero) CPU shadow. The hook is installed by monkey-patching `trainer._load_from_checkpoint` with a wrapper that runs `restore_to_gpu()` *before* delegating to the original HF method and runs `materialize_offload()` + optimizer rebuild *after* it returns, all inside the same patched call. ProTrain therefore brackets HF's weight load with its own `restore_to_gpu()` / `materialize_offload()` steps in plugin code rather than forking HF. + +Real-multigpu cross-mode resume coverage (4x3090, sharded Mode C, Llama-3-8B + LoRA): both `test_real_multigpu_cross_mode_resume_a_to_c` and `test_real_multigpu_cross_mode_resume_c_to_a` PASS as of the full M6C-fix-1..8 chain. 
See § "Standard PEFT-LoRA in Mode C" below for the chain's other layers (which closed PEFT-LoRA Mode-C correctness on top of the resume-hook fix). + +### Standard PEFT-LoRA in Mode C (Phase 2 M6C) + +Plain `peft` LoRA on top of an unquantized base is **supported in single-GPU offload mode** as of `M6C-fix-2` + `M6C-fix-3` (per-PEFT-LoRA-container gather hooks installed at both profiler-trace and runtime-scheduler surfaces). The chain works as follows: + +- `profiler/on_demand.py::_find_peft_lora_containers` discovers any module with direct trainable LoRA factors (`lora_A` / `lora_B` / `lora_magnitude_vector` / `lora_embedding_*`). Pre-forward and pre-backward gather hooks are installed at the *container* granularity (parallel to M1's fused-kernel-container strategy), so the LoRA factor sub-chunks are GPU-resident before PEFT's `LoraLayer.forward` casts them to bf16. +- `runtime/hooks.py` + `runtime/scheduler.py::ensure_chunks_resident` install the same container-granularity hooks on the live training scheduler. Without this, the runtime's block-level gather (which assumes per-block chunk granularity) leaves the LoRA sub-chunks released until after the PEFT cast op records its autograd shape, producing the canonical `ToCopyBackward0 returned an invalid gradient at index 0 - got [N, R] but expected shape compatible with [0]` failure. + +**Multi-GPU sharded mode (`protrain_zero3_shard: true, world_size > 1`) is supported for plain LoRA** as of the M6C-fix-1 through M6C-fix-8 chain (8 commits). Each fix closed a layer of the failure stack: + +- **fix-1** (`a71f26e9`) — cross-mode resume hook for HF Trainer `_load_from_checkpoint`. +- **fix-2** (`4856090e`) — per-PEFT-LoRA-container gather hooks in `profiler/on_demand.py`. +- **fix-3** (`32663f30`) — runtime-side per-LoRA-container gather hooks in `runtime/hooks.py`. +- **fix-4** (`b5ffa3d9`) — synchronous gather in `Scheduler.ensure_chunks_resident`. 
+- **fix-5** (`b787acb5`) — late-NCCL-re-search skip on explicit-override paths + autocast diagnostic. +- **fix-6** (`0f44bfb6`) — pre/post forward+backward quartet hooks per LoRA container. +- **fix-7** (`c0da4282`) — shape-preserving release-state placeholder (closes the `ToCopyBackward0 / TBackward0 ... shape compatible with [0]` autograd shape-capture error class via `scratch.expand(slot.shape)` views that preserve `param.size()` metadata across release/re-gather). +- **fix-8** (`17ffb8d1`) — DDP `init_sync=False` bypass for chunk-managed params (closes the residual `more than one element of the written-to tensor refers to a single memory location` from DDP's construction-time `_sync_module_states._broadcast_coalesced` writing into the expand-view placeholder). + +Multi-GPU verification (4x3090, sharded Mode C, Llama-3-8B + LoRA): `test_real_multigpu_cross_mode_resume_a_to_c` PASSES (Phase 1 Mode A 5 steps + Phase 2 Mode C resume steps 6..10; losses 1.093 → 0.832); `test_real_multigpu_cross_mode_resume_c_to_a` PASSES (Phase 1 Mode C 5 steps + Phase 2 Mode A resume steps 6..10). + +Architecturally, ProTrain now owns the parallelism contract for chunk-managed parameters end-to-end: per-rank deterministic partition via `materialize_offload`, sharded gather via `_gather_sharded`, `reduce_scatter` on backward via `reduce_grads_and_offload`, and the DDP construction-time broadcast bypass keeps DDP from clobbering the sharded layout with its replicated broadcast assumption. + +**Supported configurations (no workaround needed):** + +- **Single-GPU plain fp16 / bf16 LoRA in offload mode** — works directly as of M6C-fix-3; no special config beyond `protrain_force_all_persistent: false` and the override knobs. +- **Multi-GPU sharded plain fp16 / bf16 LoRA in offload mode** — works as of the full M6C-fix-1..8 chain. 
The runtime/profiler-side gather hooks (fix-2, fix-3, fix-4, fix-6), the shape-preserving release-state placeholder (fix-7), and the DDP init-sync bypass (fix-8) together close the chain that previously surfaced as `ToCopyBackward0 ... shape compatible with [0]` and DDP `_sync_module_states._broadcast_coalesced` shared-storage hazards. +- **Quantized base + LoRA** — pair LoRA with bnb 4-bit or 8-bit weight quantization. `bitsandbytes.nn.Linear4bit` / `Linear8bitLt` use typed `param.data` views that survive the non-persistent slot lifecycle in both single- and multi-GPU; the M3 13B headline test exercises this combination. + +Coverage: `tests/protrain/test_lora_offload_mode.py` (22 tests, single-GPU plain LoRA Mode C end-to-end, all PASS); `tests/protrain/test_cross_mode_resume.py` real-multigpu tests `_a_to_c` and `_c_to_a` PASS as of M6C-fix-8 (xfail markers removed in commit `17ffb8d1`); `tests/protrain/test_paged_adam_offload_mgpu.py` regression-tests the bnb 4-bit + paged_adamw_8bit + Mode C at seq=2048 multi-GPU path that M6C-fix-8 also closed. The M6C report under `docs/protrain/` traces the historical failure modes. diff --git a/src/axolotl/integrations/protrain/api/checkpoint.py b/src/axolotl/integrations/protrain/api/checkpoint.py index 297fa10052..3180617c80 100644 --- a/src/axolotl/integrations/protrain/api/checkpoint.py +++ b/src/axolotl/integrations/protrain/api/checkpoint.py @@ -2062,7 +2062,10 @@ def install_load_hook( The closed-over ``optim`` is captured at install time (in ``post_trainer_create``, BEFORE Accelerate.prepare wraps the optimizer), so it's already raw. We unwrap defensively in case - the caller hands in a wrapper. + the caller hands in a wrapper. At ``_patched()`` runtime we + re-resolve from ``trainer.optimizer`` so a cross-mode resume + rebuild that swaps the facade lands the load into the live + instance (falls back to the install-time raw if that unwrap yields None). 
The ``allow_online_reshard`` flag plumbs through to :func:`_load_protrain_optim_dir`. Default False keeps the Mode-C @@ -2071,8 +2074,8 @@ def install_load_hook( dir, all ranks barrier and load). See CHECKPOINT_DESIGN_PHASE2.md §4.1. """ - raw = _unwrap_protrain_optim(optim) - if raw is None: + raw_at_install = _unwrap_protrain_optim(optim) + if raw_at_install is None: # Caller passed something that isn't a ProTrain optimizer — # silently no-op rather than installing a hook that would # never fire. @@ -2081,6 +2084,11 @@ def install_load_hook( original = trainer._load_optimizer_and_scheduler def _patched(checkpoint: str | None) -> None: + # Re-resolve from ``trainer.optimizer`` so the cross-mode resume rebuild + # (which swaps trainer.optimizer = new_optim) loads into the live instance. + raw = _unwrap_protrain_optim(getattr(trainer, "optimizer", None)) + if raw is None: + raw = raw_at_install # Failure protocol: ``original(checkpoint)`` (the native HF # optimizer/scheduler load) is outside any cluster-wide status # handling, but the patched method still executes a distributed diff --git a/src/axolotl/integrations/protrain/api/model_wrapper.py b/src/axolotl/integrations/protrain/api/model_wrapper.py index c4479d7425..462ef7d25f 100644 --- a/src/axolotl/integrations/protrain/api/model_wrapper.py +++ b/src/axolotl/integrations/protrain/api/model_wrapper.py @@ -48,7 +48,10 @@ ) from axolotl.integrations.protrain.profiler.cache import ProfilerCacheKey from axolotl.integrations.protrain.profiler.hw_bench import measure_compute_rate -from axolotl.integrations.protrain.profiler.trace import _arch_hash +from axolotl.integrations.protrain.profiler.trace import ( + _arch_hash, + synth_trace_from_overrides, +) from axolotl.integrations.protrain.runtime.hooks import install_hooks from axolotl.integrations.protrain.runtime.scheduler import Scheduler from axolotl.integrations.protrain.search import search @@ -95,6 +98,73 @@ def _sku(device: "torch.device | str") -> str: return 
"cpu" +def _detect_dominant_param_bytes_per_element(model: nn.Module) -> float: + """Return the modal logical bytes-per-element across the model's params.""" + # Best-effort detection of bnb 4-bit param class. The import is + # behind a try/except because bitsandbytes is an optional dep — + # CPU-only test rigs and minimal installs may not have it. + _Params4bit: type | None = None + try: + import bitsandbytes.nn as _bnb_nn # type: ignore[import-untyped] + except Exception as _bnb_exc: # noqa: BLE001 — defensive; bnb is optional + LOG.debug( + "bitsandbytes.nn import failed (%s); 4-bit dtype detection " + "skipped — params classify by storage element_size().", + _bnb_exc, + ) + else: + _Params4bit = getattr(_bnb_nn, "Params4bit", None) + + # Aggregate logical-element counts keyed by bytes-per-element. + # The unit of "logical element" is one weight value as the + # autograd graph sees it — for ``Params4bit`` that's twice the + # storage numel. + by_bpe: dict[float, int] = {} + for _, param in model.named_parameters(): + try: + storage_numel = int(param.numel()) + except Exception as _exc: # noqa: BLE001 — defensive, missing/meta params + LOG.debug( + "param.numel() failed during dtype detection (%s); skipping param.", + _exc, + ) + continue + if storage_numel <= 0: + continue + if _Params4bit is not None and isinstance(param, _Params4bit): + # Each stored uint8 byte holds two 4-bit logical values. + logical_numel = storage_numel * 2 + bpe = 0.5 + else: + try: + bpe = float(int(param.element_size())) + except Exception as _exc: # noqa: BLE001 — defensive + LOG.debug( + "param.element_size() failed during dtype detection " + "(%s); skipping param.", + _exc, + ) + continue + logical_numel = storage_numel + by_bpe[bpe] = by_bpe.get(bpe, 0) + logical_numel + + if not by_bpe: + return 2.0 + + # Pick the bpe class with the largest aggregate logical-element + # count. Ties resolve in favour of the smaller bpe (i.e. 
the more + # aggressive quantization) so the searcher's alpha picks the + # tighter-budget regime when the model is genuinely mixed. + dominant_bpe = min( + by_bpe.keys(), + key=lambda b: ( + -by_bpe[b], + b, + ), # primary: descending count; secondary: smallest bpe + ) + return float(dominant_bpe) + + def _dummy_batch( model: nn.Module, batch_size: int, @@ -281,6 +351,56 @@ def _chunk_bytes(layout, chunk_manager) -> dict[int, int]: return out +def predict_init_transient_peak_bytes( + layout, + hw: HardwareProfile, + chunk_manager=None, +) -> int: + """Predict the GPU high-water mark during the init transient window.""" + # Local import to avoid a module-level cost.memory dependency cycle + # at import time (cost.memory pulls in profiler/types which would + # otherwise drag this api module in via Python's circular import + # resolution if it ever gets imported eagerly during cost.memory init). + from axolotl.integrations.protrain.cost.memory import ALPHA_FRAGMENTATION + + n_chunk = int(getattr(layout, "N_chunk", 0)) + s_chunk = int(getattr(layout, "S_chunk", 0)) + if n_chunk <= 0 or s_chunk <= 0: + return 0 + + if chunk_manager is not None: + try: + cb = _chunk_bytes(layout, chunk_manager) + except Exception as exc: # noqa: BLE001 — defensive, broken stub + LOG.debug( + "predict_init_transient_peak_bytes: _chunk_bytes failed " + "(%s); falling back to N_chunk * S_chunk upper bound.", + exc, + ) + sum_chunk_bytes = n_chunk * s_chunk + else: + sum_chunk_bytes = sum(int(v) for v in cb.values()) + # Defensive: if the chunk_manager's model has no overlap with + # the layout's param ids (e.g. tests pass a stub with empty + # named_parameters) the sum collapses to 0. Fall back to the + # layout upper bound so the caller still gets a non-zero + # prediction. Real models always populate the sum. 
+ if sum_chunk_bytes <= 0: + sum_chunk_bytes = n_chunk * s_chunk + else: + sum_chunk_bytes = n_chunk * s_chunk + + # The hw argument is reserved for a future per-dtype iter-1 alpha + # refinement once more empirical data is available. Today alpha=1.10 + # holds across the audit's fp16 / 8-bit / 4-bit Mode-C data points + # (the 4-bit Mode-A configs have no separable transient because + # the persistent set IS the full chunk set). Touch hw to silence + # the unused-arg lint and make the future-extension intent clear. + _ = hw.dominant_param_bytes_per_element + + return int(sum_chunk_bytes * ALPHA_FRAGMENTATION) + + def _calibrate_peak_with_actual_chunk_bytes( original_peak: int, layout, @@ -322,11 +442,11 @@ def _calibrate_peak_with_actual_chunk_bytes( ---------------------------- The reverse-out below uses the SAME ``persistent_factor`` / ``buffer_factor`` as :func:`model_state_present_bytes`, NOT the - legacy 1.0×-flat assumption. The previous implementation reversed + legacy 1.0x-flat assumption. The previous implementation reversed out only ``(n_persist + n_buffer) * S`` (params-only), which left the per-chunk full-state multiplier hiding inside ``f_bm`` and then re-added only the param bytes — under full FT (where - ``persistent_factor`` can be 4-7×) that systematically under-stated + ``persistent_factor`` can be 4-7x) that systematically under-stated calibrated peak by roughly ``(persistent_factor - 1) * actual_persistent``. Mismatch was harmless under LoRA-with-frozen- base (``persistent_factor ≈ 1``); now corrected for both regimes. 
@@ -383,7 +503,7 @@ def _reconstruct_f_bm(bmap) -> tuple[int, int]: for bid_, mode_ in bmap.items(): if mode_ is BlockMode.NONE or mode_ is BlockMode.OFFLOAD: live_none_bytes += int( - saved_bytes_proxy.get(bid_, act_sizes_full.get(bid_, 0)) + saved_bytes_proxy.get(bid_, act_sizes_full.get(bid_, 0)) or 0 ) n_ckpt_ = sum(1 for m in bmap.values() if m is BlockMode.CKPT) max_ckpt_act_ = 0 @@ -478,7 +598,7 @@ def _structural_calibrated( # chunks are persistent (n_persist_eff ≈ N_chunk), the cost # model's post-cap raw_peak collapses to roughly # ``profile_time_model_state + small_activation_residual``. - # The reverse-out ``original_peak / α - n_persist_eff * S`` + # The reverse-out ``original_peak / alpha - n_persist_eff * S`` # then yields ``f_bm = 0`` because the chunk-padding waste in # the cost model's model-state term consumes the activation # headroom — even though the runtime DOES allocate activations @@ -805,7 +925,7 @@ def _structural_calibrated( phase2_peak, ) LOG.info( - "ProTrain peak cfg-delta (legacy α-strip): " + "ProTrain peak cfg-delta (legacy alpha-strip): " "phase2_peak=%.2f GB phase2_anal=%.2f GB " "prod_anal=%.2f GB delta_raw=%.2f GB " "floor=%.2f GB calibrated=%.2f GB", @@ -1159,7 +1279,7 @@ def _construct_runtime( # partitioning + the ChunkManager construction agree on which # chunks are persistent. # - # The runtime resident set is ``{0..n_persist-1} ∪ + # The runtime resident set is ``{0..n_persist-1} | # layout.mandatory_persistent``. ``layout.mandatory_persistent`` is # populated once by :func:`build_layout` and records every chunk # containing at least one non-block param (e.g. 
``model.norm.weight``, @@ -1184,7 +1304,7 @@ def _construct_runtime( LOG.info( "ProTrain: %d chunks %s pinned by layout.mandatory_persistent " "(non-block params the block-granularity scheduler cannot " - "gather on its own); residency = prefix[0..%d) ∪ mandatory", + "gather on its own); residency = prefix[0..%d) | mandatory", len(layout.mandatory_persistent), sorted(layout.mandatory_persistent), n_persist, @@ -1254,6 +1374,16 @@ def _construct_runtime( zero3_shard, ) + # Shape-preserving release-state placeholders close a multi-GPU + # sharded PEFT race where autograd recorded ``torch.Size([0])`` on + # the placeholder before the per-container gather hook rebound it, + # yielding ``ToCopyBackward0`` shape mismatches at backward. The + # zero-stride view over a per-dtype scratch keeps ``param.size()`` + # reporting the real logical shape regardless of gather ordering. + # Engaged only on the multi-GPU sharded zero3_shard path so existing + # single-GPU / replicated tests asserting ``param.data.numel() == 0`` + # post-offload continue to hold. + _shape_preserving = bool(_zero3) chunk_manager = ChunkManager( model=model, layout=layout, @@ -1265,6 +1395,7 @@ def _construct_runtime( world_size=_ws, rank=_rank, zero3_shard=_zero3, + shape_preserving_placeholders=_shape_preserving, ) # The non-block-chunk pinning that earlier versions performed here @@ -1321,13 +1452,20 @@ def _construct_runtime( block_map=result.block_map, hw=hardware_profile, ) - if calibrated_peak != result.predicted_peak_bytes: - LOG.info( - "ProTrain: peak prediction calibrated %.2f -> %.2f GB " - "using actual per-chunk byte footprint", - result.predicted_peak_bytes / (1 << 30), - calibrated_peak / (1 << 30), - ) + # Predict the GPU high-water mark during the brief window between + # full-model GPU construction and ``materialize_offload`` so the + # searcher / telemetry can flag init-window OOM ahead of iter 1. 
+ init_transient_peak = predict_init_transient_peak_bytes( + layout, hardware_profile, chunk_manager + ) + if calibrated_peak != result.predicted_peak_bytes or init_transient_peak > 0: + if calibrated_peak != result.predicted_peak_bytes: + LOG.info( + "ProTrain: peak prediction calibrated %.2f -> %.2f GB " + "using actual per-chunk byte footprint", + result.predicted_peak_bytes / (1 << 30), + calibrated_peak / (1 << 30), + ) # ``cfg.n_persist`` continues to mean "prefix length the search # chose". Earlier versions of this site collapsed it into # ``len(chunk_manager._persistent_ids)`` — the augmented set @@ -1355,7 +1493,23 @@ def _construct_runtime( block_map=result.block_map, predicted_peak_bytes=calibrated_peak, predicted_iter_s=result.predicted_iter_s, + predicted_init_transient_peak_bytes=init_transient_peak, ) + # Log the iter-1 transient alongside the steady peak so operators + # see both numbers in the standard ProTrain bootstrap output. The + # ratio surfaces the Mode-C ~6x under-prediction at search time + # rather than at iter-1 OOM. + LOG.info( + "ProTrain: predicted peaks: steady=%.2f GiB iter1_transient=%.2f GiB " + "(ratio=%.2fx; > 2x suggests Mode-C offload regime)", + result.predicted_peak_bytes / (1 << 30), + init_transient_peak / (1 << 30), + ( + init_transient_peak / max(result.predicted_peak_bytes, 1) + if init_transient_peak > 0 + else 0.0 + ), + ) # ---- 4.5: materialize the init-time chunk offload (M4.5 Gap 1) ----- # Physically move every non-persistent chunk's param data to pinned @@ -1383,6 +1537,259 @@ def _construct_runtime( ) _sys2.stderr.flush() + # ---- 4.5b: DDP-ignore the chunk-managed params --------------------- + # On the multi-GPU sharded path we engaged + # ``shape_preserving_placeholders=True`` above. The released-state + # ``param.data`` is now a ``scratch.expand(slot.shape)`` zero-stride + # view: shape-preserving (autograd-safe) but NOT write-safe (multiple + # logical positions share one physical element). 
+ # + # Downstream, ``transformers.Trainer._prepare_for_training`` calls + # ``self.accelerator.prepare(model, optimizer)`` which wraps the + # model in :class:`torch.nn.parallel.DistributedDataParallel`. + # DDP's ``__init__`` runs ``_sync_module_states`` which iterates + # ``module.named_parameters()`` and broadcasts each rank-0 tensor + # into every rank's storage via ``dist._broadcast_coalesced``. The + # broadcast is an IN-PLACE WRITE; on the expanded placeholder it + # trips PyTorch's shared-storage hazard: + # + # RuntimeError: unsupported operation: more than one element + # of the written-to tensor refers to a single memory location. + # Please clone() the tensor before performing the operation. + # + # Failure is universal across all 4 ranks at DDP construction time, + # BEFORE the trainer's training loop starts. + # + # Architecturally the fix is a no-op on correctness: ProTrain owns + # the parallelism contract for chunk-managed params. Init-time + # sharding is performed by ``materialize_offload`` (each rank + # populates its own shard from the same rank-0-loaded weights via + # the Trainer's pre-wrap path); gather-time reconstruction uses + # ``all_gather_into_tensor``; grad-time drain uses + # ``reduce_scatter``. DDP's per-param broadcast at construction + # time would CORRUPT the per-rank shards (each rank's CPU shard + # holds different bytes, so broadcasting rank-0's bytes to every + # rank would overwrite rank-N's shard with rank-0's shard). DDP's + # backward-pass allreduce on these params would also conflict with + # the chunk manager's reduce_scatter drain. + # + # The supported opt-out hook is + # ``module._ddp_params_and_buffers_to_ignore`` — DDP's + # ``__init__`` reads it at construction time + # (torch/nn/parallel/distributed.py ~line 718) and excludes those + # named params from BOTH the init broadcast AND the backward + # allreduce. 
Persistent chunks are intentionally NOT included: + # their params stay GPU-resident through the released window, + # never pass through the expand placeholder, and DO need the + # standard DDP broadcast/allreduce for correctness (they are + # replicated across ranks, not sharded). + # + # Default OFF (single-GPU / multi-GPU replicated): no-op. The + # ``_shape_preserving`` gate guarantees we only set the ignore + # attribute on the path that needs it. + if _shape_preserving: + # Empirically, registering + # ``model._ddp_params_and_buffers_to_ignore`` is INSUFFICIENT + # on the production multi-GPU sharded path even when 100 % of + # chunk-managed names match ``model.named_parameters()`` + # (verified at INFO time via "live match: N/N"). The + # ``_sync_module_states`` broadcast STILL trips the shared- + # storage hazard, suggesting either a name-resolution + # discrepancy inside DDP's C++ filter, an accelerate-side + # transformation that re-introduces the placeholders, or a + # buffer the filter does not reach. Rather than continue + # fighting the filter at the symptom layer, we bypass the + # init-time broadcast entirely. + # + # Architectural justification: ProTrain owns the parallelism + # contract for chunk-managed params (init shard via + # ``materialize_offload``, gather via + # ``all_gather_into_tensor``, grad reduce via + # ``reduce_scatter``). DDP's init-time broadcast is REDUNDANT + # for replicated params (every rank already loaded the same + # checkpoint) and INCORRECT for sharded params (each rank + # holds a different shard, broadcasting one rank's bytes to + # all ranks would corrupt the other ranks' shards). The + # init-broadcast contract is "make all ranks agree on the + # initial state"; on the sharded ProTrain path that contract + # is satisfied by every rank loading from the SAME local + # ``modelA_ckpt`` checkpoint and going through the same + # materialize_offload partition rule — the broadcast adds + # nothing. 
+ # + # Mechanism: monkey-patch + # ``torch.nn.parallel.DistributedDataParallel.__init__`` to + # auto-inject ``init_sync=False`` whenever the wrapped module + # carries our marker attribute + # ``_protrain_ddp_skip_init_sync``. This skips + # ``_verify_param_shape_across_processes`` (which would + # gather() shape metadata even for ignored params and could + # itself trip on the placeholder) AND the + # ``_sync_module_states`` broadcast. Backward-pass allreduce + # remains gated by ``parameters_to_ignore`` (still filled + # from ``_ddp_params_and_buffers_to_ignore`` — see DDP + # __init__ line ~718) so chunk-managed params are also + # skipped at backward, matching ProTrain's reduce_scatter + # contract. + # + # The monkey-patch is idempotent: we attach a sentinel + # attribute on the DDP class so repeat + # ``protrain_model_wrapper`` calls (test reruns, fixtures) + # don't stack patches. The patch is GATED on the marker — + # any DDP construction WITHOUT our marker (other models in + # the same process, future use cases) is untouched. + try: + import torch.nn.parallel as _tnp + + _ddp_cls = _tnp.DistributedDataParallel + if not getattr(_ddp_cls, "_protrain_init_sync_patched", False): + _orig_init = _ddp_cls.__init__ + + def _patched_init(self, module, *args, **kwargs): + # Detect our marker on the wrapped module (or any + # ancestor reached via ``module.module`` for + # nested-DDP edge cases). When present, override + # ``init_sync`` to False so the init-time + # broadcast skips the chunk-manager-managed + # placeholders. 
+ _walk = module + _seen: set[int] = set() + while _walk is not None and id(_walk) not in _seen: + _seen.add(id(_walk)) + if getattr(_walk, "_protrain_ddp_skip_init_sync", False): + kwargs["init_sync"] = False + LOG.info( + "ProTrain (M6C-fix-8): " + "DistributedDataParallel.__init__ " + "patched-injection of init_sync=False " + "for chunk-managed model — " + "_sync_module_states broadcast and " + "_verify_param_shape_across_processes " + "are bypassed (every rank already " + "agreed on init state via " + "materialize_offload's deterministic " + "partition).", + ) + break + _walk = getattr(_walk, "module", None) + return _orig_init(self, module, *args, **kwargs) + + _ddp_cls.__init__ = _patched_init + _ddp_cls._protrain_init_sync_patched = True + + # Mark the model so the patch detects it. Persistent + # across the model lifetime — the marker is harmless if + # DDP is never wrapped around it (no patch fires). + model._protrain_ddp_skip_init_sync = True # type: ignore[attr-defined] + except Exception as _patch_exc: # noqa: BLE001 — defensive + LOG.warning( + "ProTrain (M6C-fix-8): failed to install " + "DistributedDataParallel init_sync bypass patch: %s. " + "Multi-GPU sharded path may still trip the shared-" + "storage hazard at DDP construction time.", + _patch_exc, + ) + + ignore = chunk_manager.chunk_managed_param_names() + # Cross-check: every registered name must resolve through + # ``model.named_parameters()`` — if it doesn't, DDP's + # ``_sync_module_states`` filter ``if name not in ignore`` will + # not match (DDP iterates the full recursive name; we register + # whatever ``slot.param_id`` carried). Mismatch is the silent- + # failure mode that would let the broadcast still target the + # expand placeholder. Surface a count that aligns the two + # vocabularies so any future drift is caught at INFO time. 
+ live_names = {n for n, _ in model.named_parameters()} + unmatched = ignore - live_names + if unmatched: + LOG.warning( + "ProTrain (M6C-fix-8): %d/%d chunk-managed names do NOT " + "match model.named_parameters() — DDP broadcast filter " + "will MISS them. Sample mismatches: %s", + len(unmatched), + len(ignore), + sorted(unmatched)[:5], + ) + existing = getattr(model, "_ddp_params_and_buffers_to_ignore", None) + if existing is None: + model._ddp_params_and_buffers_to_ignore = list(ignore) # type: ignore[attr-defined] + else: + # Preserve any names a caller (or earlier integration) already + # registered; merge ours on top so neither side is lost. + merged = set(existing) | ignore + model._ddp_params_and_buffers_to_ignore = list(merged) # type: ignore[attr-defined] + LOG.info( + "ProTrain (M6C-fix-8): registered %d chunk-managed param " + "names in model._ddp_params_and_buffers_to_ignore (live " + "match: %d/%d) so DDP's _sync_module_states broadcast " + "skips the shape-preserving expand placeholders (write " + "would trip the shared-storage hazard on the expanded " + "view).", + len(ignore), + len(ignore - unmatched), + len(ignore), + ) + else: + # D1 (rebuild lifecycle): non-shape-preserving rebuild path — + # if the model still carries DDP-skip state from a prior + # shape-preserving wrap (Mode C bootstrap → Mode A/B rebuild + # without an explicit close in between), strip it so the + # downstream DDP wrap performs the normal init_sync broadcast + # and backward allreduce. Leaving the marker / ignore list in + # place would silently desynchronize weights or gradients on + # the rebuilt runtime because: + # + # - ``_protrain_ddp_skip_init_sync`` ⇒ the monkey- + # patch on ``DDP.__init__`` skips ``init_sync`` entirely on + # the rebuilt model, even though replicated Mode A NEEDS + # the init-time broadcast (every rank loaded the same + # weights but DDP's contract is to make that authoritative). 
+ # - ``_ddp_params_and_buffers_to_ignore`` carries the chunk- + # managed name set from the prior Mode-C wrap; if the + # rebuilt Mode-A runtime keeps the same param names, DDP's + # backward allreduce would still skip them and per-rank + # gradients would diverge. + # + # The pre-protrain snapshot (``_protrain_ddp_original_ignore``) + # was taken by ChunkManager.materialize_offload's D2 lifecycle + # logic on the FIRST wrap; restoring from it here is the + # symmetric teardown that + # ``ChunkManager._restore_protrain_ddp_ignore_snapshot`` runs + # on ``close()``, applied inline so the rebuild path doesn't + # require the caller to close the prior chunk manager first. + if getattr(model, "_protrain_ddp_skip_init_sync", False): + try: + delattr(model, "_protrain_ddp_skip_init_sync") + except AttributeError: + pass + if hasattr(model, "_protrain_ddp_original_ignore"): + try: + _original = model._protrain_ddp_original_ignore # type: ignore[attr-defined] + if _original is None: + if hasattr(model, "_ddp_params_and_buffers_to_ignore"): + try: + delattr(model, "_ddp_params_and_buffers_to_ignore") + except AttributeError: + pass + else: + model._ddp_params_and_buffers_to_ignore = list(_original) # type: ignore[attr-defined] + try: + delattr(model, "_protrain_ddp_original_ignore") + except AttributeError: + pass + LOG.info( + "ProTrain (D1): rebuild path detected — stripped stale " + "DDP skip state from model so the rebuilt " + "runtime (non-shape-preserving) receives normal " + "init_sync + backward allreduce semantics." 
+ ) + except Exception as _exc: # noqa: BLE001 — defensive + LOG.warning( + "ProTrain (D1): failed to strip stale DDP skip state on " + "rebuild: %s", + _exc, + ) + # ---- 4.6: build the CPU FusedAdam adapter (post-offload) ------------ # BUG 3 FIX: now that ``materialize_offload`` has allocated the pinned # CPU shards and installed per-param grad hooks, build the CPU Adam @@ -1495,7 +1902,7 @@ def _construct_runtime( # * Linear-layer weight tensors (``F.linear`` saves ``weight`` # for the input-grad recompute), which for transformer FFNs # can dwarf the block-output size (Llama-7B's gate/up_proj - # weight = hidden_size × intermediate_size ≈ 86 MB at bf16, + # weight = hidden_size x intermediate_size ≈ 86 MB at bf16, # vs. block output of 2 MB at bs=1 seq=256). # * Attention probabilities upcast to fp32, intermediate FFN # activations, etc. @@ -1603,7 +2010,7 @@ def _construct_runtime( if getattr(block, "_protrain_wrapped_mode", None) is _BM_swap.SWAP: block.attach_runtime(swap_pool, scheduler.swap_stream) LOG.info( - "ProTrain: SWAP pool wired — %d slots × %d bytes = %.2f MB " + "ProTrain: SWAP pool wired — %d slots x %d bytes = %.2f MB " "pinned (slot sized from max(act=%.2f MB, intra_op=%.2f MB, " "param=%.2f MB))", swap_pool.n_slot, @@ -1622,6 +2029,90 @@ def _construct_runtime( scheduler=scheduler, ) + # ---- 6.5: post-wrap re-registration of ``_ddp_params_and_buffers_to_ignore`` + # + # The earlier ignore-set registration used pre-block-wrap param names + # from ``chunk_manager.chunk_managed_param_names()``. Block wrappers + # (``block/checkpoint.py``, ``block/swap.py``, ``block/offload.py``) + # rebind the wrapped module as ``self.block = block``, so PyTorch's + # ``named_parameters()`` now injects a ``.block.`` infix + # (``layers.0.attn.q_proj.weight`` ⇒ + # ``layers.0.block.attn.q_proj.weight``). 
DDP's backward allreduce + # consults ``_ddp_params_and_buffers_to_ignore`` using post-wrap + # names, so a stale ignore set would let DDP all-reduce + # chunk-managed grads in conflict with ProTrain's per-chunk + # ``reduce_scatter`` drain. + # + # The chunk_manager's slot.param_id strings can't be rebuilt + # safely (other call sites still rely on them being stable), so + # rebuild the model attribute from the WRAPPED model by + # parameter-OBJECT identity: every chunk-managed + # ``nn.Parameter`` lives in ``chunk_manager._params_by_id``, + # so we walk the live ``model.named_parameters()`` and pick + # names whose param OBJECT matches one we own. + if _shape_preserving: + try: + # F-#1 fix: restrict the ignore-set membership to params + # backed by NON-PERSISTENT chunks. Persistent chunks + # explicitly need normal DDP broadcast / backward allreduce + # — see ``ChunkManager.chunk_managed_param_names``'s + # docstring (Returns section lines 2008-2011): "Persistent + # chunks are excluded — their params stay GPU-resident, + # do not pass through the released-state placeholder, and + # DO need the standard DDP broadcast for correctness." The + # initial R4-#1 patch built ``chunk_managed_param_ids`` from + # ALL ``_params_by_id.values()`` which silently swept the + # persistent params into the ignore set, breaking + # gradient sync on the chunks DDP IS supposed to handle. + chunk_managed_param_ids: set[int] = set() + for _cid in chunk_manager._non_persistent_ids: + _slots = chunk_manager._cpu_slots.get(_cid) + if not _slots: + continue + for _cpu_slot in _slots: + # ``_cpu_slot`` is renamed from a more natural + # ``slot`` to avoid shadowing the ``slot`` int + # binding the block-wrap site uses earlier in + # this function (``for slot, child in + # enumerate(parent)``). mypy carries the int type + # forward across the function scope and would + # otherwise flag this iteration as + # ``Incompatible types in assignment``. 
+ _p = chunk_manager._params_by_id.get(_cpu_slot.param_id) + if _p is not None: + chunk_managed_param_ids.add(id(_p)) + post_wrap_ignore: set[str] = { + live_name + for live_name, live_param in model.named_parameters() + if id(live_param) in chunk_managed_param_ids + } + # Combine with the pre-protrain snapshot (the D2 lifecycle + # invariant — see ``ChunkManager.materialize_offload``) + # so any caller-registered ignore name survives. + _original = getattr(model, "_protrain_ddp_original_ignore", None) + if _original is None: + model._ddp_params_and_buffers_to_ignore = list(post_wrap_ignore) # type: ignore[attr-defined] + else: + model._ddp_params_and_buffers_to_ignore = list( # type: ignore[attr-defined] + set(_original) | post_wrap_ignore + ) + LOG.info( + "ProTrain (M6C-fix-8 / R4 post-wrap): re-registered " + "%d chunk-managed param names in " + "model._ddp_params_and_buffers_to_ignore using " + "post-block-wrap named_parameters() (DDP's backward " + "allreduce filter sees the .block.-infixed names).", + len(post_wrap_ignore), + ) + except Exception as _exc: # noqa: BLE001 — defensive + LOG.warning( + "ProTrain (M6C-fix-8 / R4 post-wrap): failed to " + "re-register _ddp_params_and_buffers_to_ignore after " + "block-wrap: %s. DDP's backward allreduce may attempt " + "to reduce chunk-managed param gradients.", + _exc, + ) + # ``capacity_bytes`` is unused inside the helper — kept in the # signature for symmetry with the wrapper's call site so a future # extension that derates by capacity (e.g. peak vs. budget headroom) @@ -1856,8 +2347,73 @@ def protrain_model_wrapper( sku=_sku(device), world=hardware_profile.gpu_count, ) + # Trace-pass override-skip gate. When the user has supplied all four + # explicit-override knobs (n_persist / n_buffer / n_swap / n_checkpoint) + # the searcher AND the cost model are bypassed downstream by the + # ``all_overrides_set`` branch. The trace pass itself becomes wasted + # work — and on big-model offload configurations (e.g. 
30B + 4-bit, + # or 8B + 4-bit at seq=2048) the un-offloaded trace OOMs the device + # *before* chunk offload could engage. We therefore short-circuit + # the trace pass on this exact path: build a synthetic ProfilerTrace + # via ``synth_trace_from_overrides`` (op_order=(), analytical + # activation_sizes per discovered block, model_state_bytes from + # _count_model_state_bytes, measured pcie if CUDA is available) and + # bypass ``run_trace`` entirely. This mirrors the existing + # ``force_all_persistent`` short-circuit in trace.py:609-625 (which + # only suppresses on-demand engagement WITHIN the trace) by going one + # step further and skipping the trace itself when there is nothing + # the trace would inform. + # + # The synthetic trace is NOT saved to the on-disk cache — its + # activation_sizes are placeholders (analytical, not measured) and + # caching them would risk a future non-override run picking them up + # as if they were real. The cache key falls back to a normal + # cache-miss + run_trace on subsequent override-cleared runs. + _override_skip_trace = ( + n_persist_override is not None + and n_buffer_override is not None + and n_swap_override is not None + and n_checkpoint_override is not None + ) trace = load_cached_trace(cache_key, cache_dir=cache_dir) - if trace is None: + if trace is None and _override_skip_trace: + import sys as _sys + + LOG.info( + "ProTrain: explicit knob override path with cache miss — " + "synthesizing ProfilerTrace from defaults and SKIPPING the " + "trace pass (n_persist=%s n_buffer=%s n_swap=%s n_checkpoint=%s " + "n_offload=%s). 
This avoids the trace-pass OOM on big-model " + "offload configurations where the un-offloaded forward+backward " + "exceeds device memory before chunk offload can engage.", + n_persist_override, + n_buffer_override, + n_swap_override, + n_checkpoint_override, + n_offload_override, + ) + _sys.stderr.write( + "[protrain] override path: skipping trace pass, " + "synthesizing ProfilerTrace from defaults\n" + ) + _sys.stderr.flush() + trace = synth_trace_from_overrides( + model, + batch_size=batch_size, + seq_len=seq_len, + device=device, + world_size=int(hardware_profile.gpu_count), + ) + _sys.stderr.write( + f"[protrain] synth trace done: {len(trace.activation_sizes)} blocks " + f"(no op_order, no measured activations)\n" + ) + _sys.stderr.flush() + # Deliberately do NOT save to cache: the synthetic activation + # sizes are analytical placeholders, not measured values. A + # future non-override run on the same arch+bs+seq+sku+world key + # must not pick these up as real measurements. + elif trace is None: import sys as _sys LOG.info( @@ -1899,6 +2455,7 @@ def protrain_model_wrapper( device=str(device), include_backward=True, on_demand=True, + force_all_persistent=bool(force_all_persistent), world_size=int(hardware_profile.gpu_count), ) batch = _dummy_batch(model, batch_size, seq_len, device) @@ -2070,7 +2627,7 @@ def protrain_model_wrapper( ) # PCIe rates: overwrite the caller's hardcoded prior (usually 13e9 = # Gen3) with the profiler's measured H2D/D2H. A 3090 on PCIe Gen4 x16 - # sits around 50-56 GB/s — 4× the conservative default — and the + # sits around 50-56 GB/s — 4x the conservative default — and the # cost model's per-chunk comm is S_chunk / eff_h2d, so this flow- # through directly corrects the 7B over-prediction. 
if ( @@ -2080,6 +2637,18 @@ def protrain_model_wrapper( _hw_updates["pcie_h2d_bps"] = trace.pcie_h2d_bps if hardware_profile.pcie_d2h_bps <= 13e9 + 1e6 and trace.pcie_d2h_bps > 13e9 + 1e6: _hw_updates["pcie_d2h_bps"] = trace.pcie_d2h_bps + # Detect dominant param dtype to drive the per-dtype alpha + # fragmentation lookup. Default 2.0 (fp16/bf16) → alpha=1.10; + # bnb-4-bit weights drop bpe to 0.5 → alpha=0.75. Only stamp the + # profile when the detection differs from the caller-provided + # value AND the caller passed the default — so tests that + # explicitly hand-craft a profile with a specific bpe keep it. + _detected_bpe = _detect_dominant_param_bytes_per_element(model) + if ( + abs(hardware_profile.dominant_param_bytes_per_element - 2.0) < 1e-9 + and abs(_detected_bpe - 2.0) > 1e-9 + ): + _hw_updates["dominant_param_bytes_per_element"] = _detected_bpe if _hw_updates: hardware_profile = _replace(hardware_profile, **_hw_updates) @@ -2192,7 +2761,7 @@ def protrain_model_wrapper( # Replicate the searcher's two runtime-safety invariants. Without # these, the override path can ship configs that the searcher # would never select — e.g. an n_buffer too small for the - # scheduler's lookahead prefetch (current-block ∪ next-block + # scheduler's lookahead prefetch (current-block | next-block # non-persistent chunks must fit simultaneously) or a block_map # where a NONE block owns offloaded chunks (no activation-save # mechanism — autograd's saved tensors hold direct GPU storage @@ -2201,7 +2770,7 @@ def protrain_model_wrapper( # recomputes; OFFLOAD re-gathers via saved-tensors-hook; SWAP # persists each saved tensor to a pinned-CPU pool slot decoupled # from param.data — see ``block_map_runtime_admissible`` and - # the §6.6 SWAP × non-persistent lift in + # the §6.6 SWAP x non-persistent lift in # ``BLOCK_MODE_OFFLOAD_DESIGN.md``). 
min_buffer = min_n_buffer_for(layout, n_persist) if n_buffer < min_buffer: @@ -2507,7 +3076,7 @@ def protrain_model_wrapper( # are consumed by: # # * ``cost.runtime.estimate_runtime`` to derive - # α = phase2_iter_s / phase2_analytical_iter_s and scale + # alpha = phase2_iter_s / phase2_analytical_iter_s and scale # analytical-path predictions when the production cfg # bypasses the chunked-wall override (e.g. ``n_swap > 0``). # * ``_calibrate_peak_with_actual_chunk_bytes`` to apply @@ -2541,9 +3110,9 @@ def protrain_model_wrapper( ) ) # Per-component analytical decomposition at boot cfg - # (TRACE_VERSION 21). The per-component α calibration in + # (TRACE_VERSION 21). The per-component alpha calibration in # ``_compose_t_iter_with_alpha_calibration`` derives three - # independent scales — αfwd / αbwd / αopt — from the + # independent scales — alphafwd / alphabwd / alphaopt — from the # measured-vs-analytical ratios at the boot cfg. The # measured side is ``(fwd_s, bwd_s, step_s)`` from # ``measure_chunked_steady`` above; the analytical side is @@ -2569,8 +3138,8 @@ def protrain_model_wrapper( # measured step wall ≈ t_gpu_optim + (CPU-Adam tail). For # calibration we use the simpler additive # ``t_gpu_optim + t_cpu_optim`` as the analytical-step - # denominator — the αopt ratio absorbs the bwd-overlap - # difference uniformly so it's consistent with how αopt + # denominator — the alphaopt ratio absorbs the bwd-overlap + # difference uniformly so it's consistent with how alphaopt # is applied in :func:`_compose_t_iter_with_alpha_calibration`. phase2_analytical_fwd_s_val = float(t_fwd_boot) phase2_analytical_bwd_s_val = float(t_bwd_boot) @@ -2583,23 +3152,23 @@ def protrain_model_wrapper( phase2_iter_s_val = float(fwd_s + bwd_s + step_s) # Per-component-prediction anchor (TRACE_VERSION 22) for - # the residual-α multiplier. Compute what the per-component + # the residual-alpha multiplier. 
Compute what the per-component # formula in :func:`_compose_t_iter_with_alpha_calibration` - # WOULD predict at the boot cfg under the same αfwd / - # αbwd / αopt values that the cost model derives from the + # WOULD predict at the boot cfg under the same alphafwd / + # alphabwd / alphaopt values that the cost model derives from the # measured-vs-analytical ratios above. Crucially, this - # anchor uses the analytical-path composition (αfwd and - # αbwd both applied) — NOT the chunked-wall-override path + # anchor uses the analytical-path composition (alphafwd and + # alphabwd both applied) — NOT the chunked-wall-override path # the boot cfg's ``n_swap == 0`` would normally trigger — - # because the residual α generalises across cfgs that DO + # because the residual alpha generalises across cfgs that DO # take the analytical path (any prod cfg with ``n_swap > # 0``). At boot the override and analytical paths agree - # within αfwd/αbwd ≈ 1 anyway since the αs are calibrated + # within alphafwd/alphabwd ≈ 1 anyway since the alphas are calibrated # *against* the boot measurement; the residual captures # whatever whole-iter overhead bias remains after that # per-component correction. # - # Clamp αs to match the runtime composer's clamp so the + # Clamp alphas to match the runtime composer's clamp so the # anchor stays consistent with what the production path # actually applies (otherwise an out-of-clamp boot ratio # would skew the residual). @@ -2609,9 +3178,7 @@ def protrain_model_wrapper( ) def _clamp_for_anchor(x: float) -> float: - return max( - _PHASE2_ALPHA_CLAMP_MIN, min(_PHASE2_ALPHA_CLAMP_MAX, x) - ) + return max(_PHASE2_ALPHA_CLAMP_MIN, min(_PHASE2_ALPHA_CLAMP_MAX, x)) if ( phase2_analytical_fwd_s_val > 0.0 @@ -2639,7 +3206,7 @@ def _clamp_for_anchor(x: float) -> float: ) else: # Per-component baselines unavailable — leave the - # anchor zero so the residual α collapses to no-op. + # anchor zero so the residual alpha collapses to no-op. 
phase2_per_comp_pred_iter_s_val = 0.0 from dataclasses import replace as _replace @@ -2664,7 +3231,7 @@ def _clamp_for_anchor(x: float) -> float: phase2_analytical_fwd_s=phase2_analytical_fwd_s_val, phase2_analytical_bwd_s=phase2_analytical_bwd_s_val, phase2_analytical_step_s=phase2_analytical_step_s_val, - # Residual-α anchor (TRACE_VERSION 22). + # Residual-alpha anchor (TRACE_VERSION 22). phase2_per_comp_pred_iter_s=phase2_per_comp_pred_iter_s_val, ) try: @@ -2753,7 +3320,7 @@ def _clamp_for_anchor(x: float) -> float: # search's raw new pick (new_result.cfg) — NOT the # calibrated boot_result.cfg. The two used to diverge # because ``_construct_runtime`` widened ``cfg.n_persist`` - # to ``len(_persistent_ids)`` (the prefix ∪ non-block-chunk + # to ``len(_persistent_ids)`` (the prefix | non-block-chunk # pin set) post-calibration; that collapse has since been # removed (the augmented set is now plumbed through # ``layout.mandatory_persistent`` so the prefix is preserved @@ -2780,7 +3347,20 @@ def _clamp_for_anchor(x: float) -> float: block_map=new_result.block_map, hw=hardware_profile, ) - if calibrated_peak != new_result.predicted_peak_bytes: + # The init transient window has already passed by the + # time post-measurement calibration runs, so we REUSE + # the bootstrap-time prediction rather than recomputing + # from the post-offload chunk_manager — its + # ``_chunk_bytes()`` walk now sees zero-size placeholders + # (replicated path) or ``scratch.expand(slot.shape)`` + # views (sharded path) rather than full-residence + # tensors that drove the init-time peak. 
+ init_transient_peak = boot_result.predicted_init_transient_peak_bytes + if ( + calibrated_peak != new_result.predicted_peak_bytes + or init_transient_peak + != new_result.predicted_init_transient_peak_bytes + ): # Preserve the search's prefix — see the matching # comment in ``_construct_runtime`` for why # ``len(_persistent_ids)`` (the augmented set) is @@ -2801,6 +3381,7 @@ def _clamp_for_anchor(x: float) -> float: block_map=new_result.block_map, predicted_peak_bytes=calibrated_peak, predicted_iter_s=new_result.predicted_iter_s, + predicted_init_transient_peak_bytes=init_transient_peak, ) LOG.info( "Phase-2: post-measurement search picked the same cfg " @@ -2876,7 +3457,8 @@ def _clamp_for_anchor(x: float) -> float: LOG.info( "ProTrain config: n_persist=%d n_buffer=%d n_swap=%d n_checkpoint=%d " - "S_chunk=%d N_chunk=%d peak=%.2f GiB iter=%.3f s capacity=%.2f GiB", + "S_chunk=%d N_chunk=%d peak=%.2f GiB iter1_transient=%.2f GiB " + "iter=%.3f s capacity=%.2f GiB", result.cfg.n_persist, result.cfg.n_buffer, result.cfg.n_swap, @@ -2884,6 +3466,7 @@ def _clamp_for_anchor(x: float) -> float: layout.S_chunk, layout.N_chunk, result.predicted_peak_bytes / (1 << 30), + result.predicted_init_transient_peak_bytes / (1 << 30), result.predicted_iter_s, capacity_bytes / (1 << 30), ) @@ -2915,6 +3498,11 @@ def _clamp_for_anchor(x: float) -> float: # Carry the user-supplied cache_dir so post_trainer_create's NCCL # re-measure path can persist the spliced trace under the same root. wrapped._cache_dir = cache_dir # type: ignore[attr-defined] + # Carry the override-skip flag so the plugin's late NCCL re-search + # also short-circuits when the user pinned every layout knob via + # explicit overrides — the runtime is already wired for the + # bootstrap plan and cannot be rebuilt mid-flight. 
+ wrapped._override_skip_trace = bool(_override_skip_trace) # type: ignore[attr-defined] return wrapped @@ -3057,4 +3645,8 @@ def _find_block_parent_map( return out -__all__ = ["auto_wrap", "protrain_model_wrapper"] +__all__ = [ + "auto_wrap", + "predict_init_transient_peak_bytes", + "protrain_model_wrapper", +] diff --git a/src/axolotl/integrations/protrain/api/optim_wrapper.py b/src/axolotl/integrations/protrain/api/optim_wrapper.py index 009a1f8f11..5519401a61 100644 --- a/src/axolotl/integrations/protrain/api/optim_wrapper.py +++ b/src/axolotl/integrations/protrain/api/optim_wrapper.py @@ -35,6 +35,7 @@ from axolotl.integrations.protrain.chunk import ( CpuFusedAdamAdapter, + GpuAdamW8bitAdapter, GpuFusedAdamAdapter, ) from axolotl.integrations.protrain.types import ChunkId, WrappedModel @@ -59,7 +60,7 @@ class _ProTrainOptimizer(torch.optim.Optimizer): def __init__( self, - gpu_optim: GpuFusedAdamAdapter | None, + gpu_optim: GpuFusedAdamAdapter | GpuAdamW8bitAdapter | None, cpu_optim: CpuFusedAdamAdapter | None, params: list["nn.Parameter"], defaults: dict[str, Any], @@ -602,6 +603,27 @@ def _split_optim_param_groups( inner.param_groups = new_groups +#: Axolotl / HF Trainer optimizer-name strings that route the persistent +#: chunk set through ``GpuAdamW8bitAdapter`` instead of +#: ``GpuFusedAdamAdapter``. ``adamw_8bit`` and ``adamw_bnb_8bit`` are +#: aliases in HF's ``OptimizerNames`` (training_args.py:128-129) that both +#: dispatch to ``bnb.optim.AdamW`` with ``optim_bits=8``; we accept both +#: spellings so users carrying configs from either origin work without +#: edits. ``paged_adamw_8bit`` selects the paged variant (UVM-backed +#: state) for the same set. 
+_BNB_8BIT_OPTIMIZERS: frozenset[str] = frozenset( + {"adamw_8bit", "adamw_bnb_8bit", "paged_adamw_8bit"} +) +_BNB_8BIT_PAGED_OPTIMIZERS: frozenset[str] = frozenset({"paged_adamw_8bit"}) + + +def _normalize_optimizer_name(name: str | None) -> str | None: + """Lower-case + strip whitespace, unwrapping ``OptimizerNames`` enums via ``.value``.""" + if name is None: + return None + return str(getattr(name, "value", name)).strip().lower() + + def protrain_optimizer_wrapper( wrapped: WrappedModel, *, @@ -609,6 +631,7 @@ def protrain_optimizer_wrapper( betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, + optimizer_name: str | None = None, ) -> torch.optim.Optimizer: """Rebuild the GPU/CPU FusedAdam adapters at user-specified hyperparams. @@ -695,16 +718,40 @@ def protrain_optimizer_wrapper( else: cpu_params_per_chunk[ChunkId(cid)] = chunk_params - gpu_optim: GpuFusedAdamAdapter | None = None + # bnb 8-bit Adam kernels are CUDA-only, so only the persistent + # (GPU-resident) chunk set can use the 8-bit adapter; non-persistent + # CPU shards keep the 32-bit DeepSpeedCPUAdam path. 
+ normalized_optim_name = _normalize_optimizer_name(optimizer_name) + use_bnb_8bit = normalized_optim_name in _BNB_8BIT_OPTIMIZERS + use_paged_8bit = normalized_optim_name in _BNB_8BIT_PAGED_OPTIMIZERS + + gpu_optim: GpuFusedAdamAdapter | GpuAdamW8bitAdapter | None = None cpu_optim: CpuFusedAdamAdapter | None = None if persistent_params: - gpu_optim = GpuFusedAdamAdapter( - params=persistent_params, - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - ) + if use_bnb_8bit: + LOG.info( + "protrain_optimizer_wrapper: routing %d persistent params " + "through bnb %s (optimizer_name=%s)", + len(persistent_params), + "PagedAdamW8bit" if use_paged_8bit else "AdamW8bit", + optimizer_name, + ) + gpu_optim = GpuAdamW8bitAdapter( + params=persistent_params, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + paged=use_paged_8bit, + ) + else: + gpu_optim = GpuFusedAdamAdapter( + params=persistent_params, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) # M7: for sharded non-persistent chunks the CPU Adam updates each # :class:`_DtypeRegion`'s flat shard_param (one per rank slice per @@ -722,6 +769,37 @@ def protrain_optimizer_wrapper( else: cpu_params_per_chunk_for_optim[cid] = chunk_params + if use_bnb_8bit and any( + params for params in cpu_params_per_chunk_for_optim.values() + ): + # bnb 8-bit Adam requires CUDA tensors; non-persistent chunks + # live on CPU. We keep the + # 32-bit CpuFusedAdamAdapter on those chunks so training stays + # correct (and the user still gets the persistent-chunk 8-bit + # win from above). Surface this once, loudly, so users + # configuring `adamw_8bit` aren't surprised by the partial + # adoption. + n_cpu_chunks = sum( + 1 for params in cpu_params_per_chunk_for_optim.values() if params + ) + LOG.warning( + "protrain_optimizer_wrapper: optimizer_name=%s requested 8-bit " + "AdamW, but %d non-persistent chunk(s) live on CPU and bnb's " + "8-bit Adam kernels are CUDA-only. 
Those chunks will keep " + "using 32-bit DeepSpeedCPUAdam (still correct, but the " + "optimizer-state memory win applies only to the persistent " + "set). To get end-to-end 8-bit, configure ProTrain to force " + "all chunks persistent (Mode A): set " + "``protrain_auto_mode: false`` AND " + "``protrain_force_all_persistent: true`` together — " + "``protrain_force_all_persistent`` is ignored while " + "``protrain_auto_mode`` is on (the auto-mode selector picks " + "the mode itself based on capacity), so disabling auto-mode " + "first is required for the Mode-A override to take effect.", + optimizer_name, + n_cpu_chunks, + ) + if any(params for params in cpu_params_per_chunk_for_optim.values()): try: cpu_optim = CpuFusedAdamAdapter( @@ -827,9 +905,40 @@ def protrain_optimizer_wrapper( # Swap the freshly-built adapters into the chunk manager so the # scheduler's post_block_backward -> reduce_grads_and_offload -> - # cpu_optim.step_async chain uses them. + # cpu_optim.step_async chain uses them. The chunk manager's + # ``gpu_optim`` slot is typed ``GpuFusedAdamAdapter | None`` (the + # legacy adapter); the ``GpuAdamW8bitAdapter`` is duck-compat + # at the call sites that consume the slot (``.step()``, + # ``.zero_grad()``, ``.state_dict()`` — see + # :class:`GpuAdamW8bitAdapter`). We assign through a typing cast + # rather than widening the chunk manager's type signature, which + # would touch a read-only file from this milestone's perspective. + # + # D3 lifecycle (shutdown-before-swap): ``CpuFusedAdamAdapter`` owns + # a live ``ThreadPoolExecutor`` and per-chunk DeepSpeedCPUAdam + # C-state; overwriting ``chunk_manager.cpu_optim`` without first + # tearing the old adapter down leaks executor threads + DeepSpeed + # state on every re-wrap (e.g. the resume hook's "Step 1" tears + # the adapter down at the plugin layer, but a direct second + # ``protrain_optimizer_wrapper`` invocation — e.g. 
user reruns the + # wrapper after changing optim hyperparams without going through + # the HF Trainer resume path — would otherwise GC-time the + # cleanup). Mirrors the same teardown the resume hook performs + # before ``restore_to_gpu``. + _old_cpu_optim = getattr(chunk_manager, "cpu_optim", None) + if _old_cpu_optim is not None and _old_cpu_optim is not cpu_optim: + # F-#3 (Major): let ``shutdown()`` failures abort the swap + # rather than warning-and-continuing. The whole point of + # calling ``shutdown()`` here is the D3 deterministic-cleanup + # invariant — masking a real teardown failure (e.g., + # ``ThreadPoolExecutor`` hung, DeepSpeed C-state corrupted) + # puts the failed adapter back on the GC path AND silently + # accepts a broken state-machine on the rebuild side. If the + # shutdown raises, the rebuild is in an inconsistent state + # and the call should fail rather than silently degrading. + _old_cpu_optim.shutdown() chunk_manager.cpu_optim = cpu_optim - chunk_manager.gpu_optim = gpu_optim + chunk_manager.gpu_optim = cast("GpuFusedAdamAdapter | None", gpu_optim) # Build the flat param list for the Optimizer base class. all_params: list["nn.Parameter"] = list(persistent_params) diff --git a/src/axolotl/integrations/protrain/args.py b/src/axolotl/integrations/protrain/args.py index 391db46d38..3def484d5e 100644 --- a/src/axolotl/integrations/protrain/args.py +++ b/src/axolotl/integrations/protrain/args.py @@ -64,6 +64,33 @@ ) +# Strict allow-list of Axolotl/HF optimizer names that ProTrain's chunk +# manager + per-chunk adapters can drive correctly. The set is the union +# of names dispatched by ``api/optim_wrapper.protrain_optimizer_wrapper``: +# +# * ``adamw_torch`` / ``adamw_torch_fused`` — default route through +# ``GpuFusedAdamAdapter`` (Apex FusedAdam, falls back to +# ``torch.optim.AdamW``) for persistent chunks and +# ``CpuFusedAdamAdapter`` (DeepSpeedCPUAdam) for non-persistent chunks. 
+# * ``adamw_8bit`` / ``adamw_bnb_8bit`` / ``paged_adamw_8bit`` — +# route persistent chunks through ``GpuAdamW8bitAdapter`` +# (``bnb.optim.AdamW8bit`` / ``bnb.optim.PagedAdamW8bit``). +# +# All other optimizer names (Lion, Adafactor, GaLore, Sophia, Muon, +# torchao, plain SGD, etc.) have state shapes that do not match the +# AdamW-shaped adapters and are silently broken — the validator below +# rejects them at config-load time. +_SUPPORTED_OPTIMIZERS: frozenset[str] = frozenset( + { + "adamw_torch", + "adamw_torch_fused", + "adamw_8bit", + "adamw_bnb_8bit", + "paged_adamw_8bit", + } +) + + def _has_protrain_plugin(plugins) -> bool: """Return True iff the iterable contains an explicit ProTrain plugin id. @@ -121,8 +148,12 @@ class ProTrainArgs(BaseModel): "trainer. Requires " "``plugins: [axolotl.integrations.protrain.ProTrainPlugin]``. " "Mutually exclusive with DeepSpeed, FSDP, gradient_checkpointing, " - "TP/CP/SP > 1, and load_in_8bit/load_in_4bit (see " - "`_reject_incompatible_features`)." + "and TP/CP/SP > 1 (see `_reject_incompatible_features`). " + "Composes with bitsandbytes ``load_in_8bit`` / ``load_in_4bit`` " + "(M2/M3 validated; ``Params4bit`` / ``Int8Params`` survive the " + "chunk gather/offload path because ``quant_state`` lives as a " + "Python attribute on the param and ``chunk/manager.py`` rebinds " + "``param.data`` without touching python attrs)." ) }, ) @@ -269,10 +300,7 @@ class ProTrainArgs(BaseModel): }, ) - # ------------------------------------------------------------------ - # Optimizer-state checkpoint/resume (CHECKPOINT_DESIGN.md Phase 1, - # CHECKPOINT_DESIGN_PHASE2.md Modes B + C) - # ------------------------------------------------------------------ + # Optimizer-state checkpoint/resume. 
protrain_save_optimizer_state: bool | None = Field( default=False, @@ -426,10 +454,17 @@ def _reject_incompatible_features(cls, data): ``sequence_parallel_degree`` > 1 — scope-excluded per plan.md (M6 single-3090 focus); the chunk layout does not shard correctly across TP/CP ranks in this milestone. - * ``load_in_8bit`` / ``load_in_4bit`` — bnb weight quantization - wraps ``nn.Linear.weight`` in a non-owning proxy. The chunk - manager reads unquantized storage for gather / offload and - cannot reason about the 8-bit / 4-bit packed buffers. + + Note: ``load_in_8bit`` / ``load_in_4bit`` are NOT in this mutex + list. M0 spike + M2/M3 audit validation established that bnb + weight quantization composes with ProTrain in both Mode A + (all-persistent) AND offload mode — ``Params4bit.data`` and + ``Int8Params.data`` are uint8/int8 storage tensors, so the + chunk manager's ``numel * element_size`` byte math handles them + correctly, and ``quant_state`` lives as a Python attribute on + the param instance and survives ``param.data`` rebinding (see + ``chunk/manager.py``). Pinned by + ``tests/protrain/test_bnb_offload.py``. Each rejection surfaces at config-load time rather than as a silent mis-training run. @@ -500,19 +535,44 @@ def _reject_incompatible_features(cls, data): "(scope-excluded per plan.md — single-3090 target). Set " "sequence_parallel_degree=1 or remove the ProTrain plugin." ) - if data.get("load_in_8bit"): - raise ValueError( - "ProTrain is incompatible with load_in_8bit=true (bitsandbytes " - "8-bit quantization wraps nn.Linear.weight in a non-owning proxy; " - "the chunk manager operates on unquantized storage for gather / " - "offload). Set load_in_8bit=false or remove the ProTrain plugin." 
- ) - if data.get("load_in_4bit"): + # M0 spike + M3 audit validation: bnb 8-bit / 4-bit weights compose with + # ProTrain in BOTH Mode A (all-persistent) AND offload mode (Mode C / single-GPU + # n_persist_override None: if n_persist < 0 or n_persist > layout.N_chunk: raise ValueError( @@ -541,6 +542,10 @@ def __init__( # tensor per param (cheap but not free). self._empty_by_dtype: dict["torch.dtype", "torch.Tensor"] = {} + # Opt-in: bind released param.data to a zero-stride view of the real shape so autograd captures the logical shape instead of [0]. + self._shape_preserving_placeholders: bool = bool(shape_preserving_placeholders) + self._shape_scratch_by_dtype: dict["torch.dtype", "torch.Tensor"] = {} + # Per-chunk grad-drain counter: decremented by _offload_grad for # every trainable param in the chunk; when it hits zero we kick # off the async CPU Adam step (Gap 2). @@ -646,15 +651,7 @@ def mark_persistent(self, first_n: int) -> None: for i in range(self.layout.N_chunk) if cast(ChunkId, i) not in new_persistent_ids } - # CodeRabbit R2-04 fix: once chunks have been materialized into - # CPU placeholder slots or persistent GPU buffers, the residency - # split is baked into the runtime state — a previously offloaded - # chunk newly tagged persistent would early-return in ``gather`` - # while its params still point at empty GPU placeholders, and a - # previously persistent chunk newly tagged non-persistent would - # have no ``_cpu_slots`` to drain grads into. Reject the change - # so the failure surfaces immediately rather than as silent - # weight corruption many steps later. + # After materialization the residency split is baked in; flipping it would silently corrupt weights since gather/offload paths skip already-resident chunks. 
if (self._cpu_slots or self._persistent_buffers) and ( new_persistent_ids != self._persistent_ids or new_non_persistent_ids != self._non_persistent_ids @@ -861,9 +858,7 @@ def _align_up(n: int, a: int) -> int: if param is None: continue dtype_here = param.data.dtype - # CodeRabbit R07 fix: split regions on requires_grad - # in addition to dtype so each region is uniformly - # trainable or uniformly frozen. + # Region must be uniformly trainable or uniformly frozen so grad allocation matches. trainable_here = bool(param.requires_grad) param_end = off + nbytes if cur_dtype is None: @@ -1130,9 +1125,11 @@ def _align_up(n: int, a: int) -> int: cpu_param = cpu_view.view(dtype).view(shape) cpu_param.copy_(orig_data) - # Release GPU storage by rebinding .data to an empty - # placeholder of the same dtype. - param.data = self._empty_placeholder(dtype) + # Release GPU storage; opt-in shape-preserving placeholder keeps param.size() correct for autograd while released. + if self._shape_preserving_placeholders: + param.data = self._shape_preserving_placeholder(shape, dtype) + else: + param.data = self._empty_placeholder(dtype) # Pinned CPU grad shadow for trainable params (replicated # only). In sharded mode the per-region shard buffer @@ -1231,12 +1228,7 @@ def _align_up(n: int, a: int) -> int: ) region_param_off += r_shard_bytes - # CodeRabbit R07 fix: only allocate the pinned grad - # shard for trainable regions. Frozen-only regions - # never receive a reduce/copy in - # :meth:`reduce_grads_and_offload`; binding a - # zero-grad view as ``shard_param.grad`` would - # let Adam's weight-decay rewrite frozen bytes. + # Frozen regions get no grad shard; otherwise Adam's weight decay would rewrite frozen bytes. 
cpu_region_grad: "torch.Tensor | None" = None if r_is_trainable: assert chunk_grad_view is not None @@ -1318,6 +1310,39 @@ def _align_up(n: int, a: int) -> int: precise_grad, freed / 1e9, ) + + # Rebuild model._ddp_params_and_buffers_to_ignore from the pre-protrain snapshot + current chunk-managed names so a re-materialize after resume cannot accumulate stale names. + if self._shape_preserving_placeholders and self.model is not None: + try: + protrain_set = self.chunk_managed_param_names() + if not hasattr(self.model, "_protrain_ddp_original_ignore"): + _pre_existing = getattr( + self.model, "_ddp_params_and_buffers_to_ignore", None + ) + # Distinguish unset (None) from empty-list so teardown can restore exactly. + self.model._protrain_ddp_original_ignore = ( # type: ignore[attr-defined] + None if _pre_existing is None else list(_pre_existing) + ) + _original = self.model._protrain_ddp_original_ignore # type: ignore[attr-defined] + if _original is None: + self.model._ddp_params_and_buffers_to_ignore = list(protrain_set) # type: ignore[attr-defined] + else: + self.model._ddp_params_and_buffers_to_ignore = list( # type: ignore[attr-defined] + set(_original) | protrain_set + ) + LOG.info( + "ChunkManager.materialize_offload: rebuilt " + "model._ddp_params_and_buffers_to_ignore from snapshot " + "+ %d chunk-managed names (pre-protrain original: %s)", + len(protrain_set), + "" if _original is None else f"{len(_original)} names", + ) + except Exception as _exc: # noqa: BLE001 — defensive + LOG.warning( + "ChunkManager.materialize_offload: failed to register " + "_ddp_params_and_buffers_to_ignore on model: %s", + _exc, + ) return freed def _close_cpu_pools(self) -> None: @@ -1634,6 +1659,8 @@ def _alloc_empty(shape, dtype): # placeholders are unreferenced from torch's perspective. Drop # the dict so the next gather builds fresh ones if needed. self._empty_by_dtype.clear() + # Symmetric teardown with _empty_by_dtype; the rebind above already dropped any aliases. 
+ self._shape_scratch_by_dtype.clear() # Release + close the unified pinned pools. # @@ -1728,6 +1755,40 @@ def _empty_placeholder(self, dtype: "torch.dtype") -> "torch.Tensor": self._empty_by_dtype[dtype] = t return t + def _shape_preserving_placeholder( + self, + shape: "torch.Size | tuple[int, ...]", + dtype: "torch.dtype", + ) -> "torch.Tensor": + """Return a zero-stride view of ``shape``/``dtype`` so released params keep their real shape for autograd.""" + import torch + + from axolotl.integrations.protrain.runtime.streams import ( + SingleStreamAllocator, + ) + + # Materialize-or-fetch the per-dtype 1-element scratch. + scratch = self._shape_scratch_by_dtype.get(dtype) + if scratch is None: + if self.device.type == "cuda" and torch.cuda.is_available(): + with SingleStreamAllocator(): + scratch = torch.empty(1, device=self.device, dtype=dtype) + else: + scratch = torch.empty(1, device=self.device, dtype=dtype) + self._shape_scratch_by_dtype[dtype] = scratch + + if shape == torch.Size([]): + return scratch.view(()) + return scratch.expand(tuple(shape)) + + def chunk_managed_param_names(self) -> set[str]: + """Return param names backed by released (non-persistent) chunks so DDP can be told to ignore them.""" + names: set[str] = set() + for cid in self._non_persistent_ids: + for slot in self._cpu_slots.get(cid, []): + names.add(str(slot.param_id)) + return names + def _make_grad_offload_hook(self, chunk_id: ChunkId, slot: _CpuParamSlot): """Build a post-accumulate grad hook for one trainable non-persistent param. @@ -1807,18 +1868,7 @@ def _hook(param: "nn.Parameter") -> None: remaining = cm._grad_remaining.get(captured_cid, 0) - 1 cm._grad_remaining[captured_cid] = remaining if remaining == 0: - # All of the chunk's trainable params are drained. 
The - # CPU FusedAdam adapter is responsible for actually - # updating the offloaded weights — without it, the CPU - # master shards never advance and every offloaded chunk - # silently retains its iter-0 weights forever. - # - # CodeRabbit R2-05 fix: fail fast the FIRST time an - # offloaded chunk reaches its CPU-step path with no - # ``cpu_optim`` attached. Prior code skipped the - # ``step_async`` and just reset ``_grad_remaining`` so - # the next backward could fire again — which masked the - # missing optimizer behind silently stale weights. + # Fail fast on missing cpu_optim; skipping it would silently retain iter-0 weights on every offloaded chunk. if cm.cpu_optim is None: raise RuntimeError( "ChunkManager: missing CPU optimizer for offloaded " @@ -1919,7 +1969,12 @@ def _repoint() -> None: # trainable slots round-trip through this callback. if param.data.device.type != "cpu": continue - param.data = cm._empty_placeholder(slot.dtype) + if cm._shape_preserving_placeholders: + param.data = cm._shape_preserving_placeholder( + slot.shape, slot.dtype + ) + else: + param.data = cm._empty_placeholder(slot.dtype) # Also clear grad: we've consumed it in the CPU step, # and leaving param.grad pointing at the CPU grad shard # means iter N+1's autograd would accumulate new GPU @@ -2407,7 +2462,10 @@ def offload(self, chunk_id: ChunkId) -> None: # post-step repoint will null it back to a GPU placeholder. 
if param.data.device.type == "cpu": continue - param.data = self._empty_placeholder(slot.dtype) + if self._shape_preserving_placeholders: + param.data = self._shape_preserving_placeholder(slot.shape, slot.dtype) + else: + param.data = self._empty_placeholder(slot.dtype) self.buffer_pool.release(chunk_id) # Symmetric with the ``_active_chunks.add`` in ``gather()``: # the gather-side lease has been released, so the next gather @@ -2447,25 +2505,7 @@ def reduce_grads_and_offload(self, chunk_id: ChunkId) -> None: # when it detects DDP composition) tells us to leave the # grads alone. # - # In the non-DDP distributed path (e.g. a bare ZeRO-3 run - # or Mode-A-no-DDP / Mode-C-no-DDP) the flag is False and - # we own the cross-rank reduction. To minimize NCCL launch - # latency on small persistent chunks (Item 5 profiling - # showed ~19 ops × 17MB unbucketed on a Llama-3B 4-GPU run, - # ~30 ms / 1300 ms iter), we COALESCE every same-dtype grad - # in the chunk into a single flat buffer and issue one - # ``all_reduce`` per dtype group. PyTorch's - # ``_flatten_dense_tensors`` / ``_unflatten_dense_tensors`` - # is the same primitive DDP uses internally; it handles - # the contiguous-buffer staging and the per-tensor view - # restoration without any copy back when the grads were - # already contiguous (the common case). - # - # Mixed-dtype chunks (e.g. fp16 attention weights next to - # fp32 layernorm scales in a Llama block) issue ONE - # all_reduce per dtype run, not one per param. Homogeneous - # chunks issue exactly one collective — the structurally - # cleanest case. + # When ProTrain owns the cross-rank reduction (no outer DDP), coalesce same-dtype grads into one all_reduce per dtype to cut NCCL launch latency. 
if ( torch.distributed.is_available() and torch.distributed.is_initialized() @@ -2490,29 +2530,7 @@ def reduce_grads_and_offload(self, chunk_id: ChunkId) -> None: self.offload(chunk_id) def _coalesced_all_reduce_persistent_grads(self, chunk_id: ChunkId) -> None: - """Bucket persistent-chunk grads by dtype and issue one all_reduce per bucket. - - Replaces the per-param ``dist.all_reduce`` loop that dominated - launch latency on the Mode-C / Mode-A-no-DDP path (Item 5 - profiling: 19 ops × 17MB unbucketed → ~30 ms/iter). Equivalent - to PyTorch DDP's internal bucketed allreduce (which uses the - same ``_flatten_dense_tensors`` primitive). - - Algorithm: - - 1. Group every live ``param.grad`` in ``chunk_id`` by dtype. - 2. For each dtype group: flatten into one contiguous buffer, - ``all_reduce(op=AVG)`` it once, then unflatten back to - per-param views and copy each view into the original - ``param.grad``. The copy_back handles the case where - ``_flatten_dense_tensors`` materialized a fresh buffer (it - always does — the input grads' storage is independent). - - Mixed-dtype chunks (Llama: fp16 weights + fp32 RMSNorm scales) - issue one collective per dtype run, exactly like the sharded - path's per-region collectives. Empty chunks issue zero - collectives. - """ + """Bucket persistent-chunk grads by dtype and issue one all_reduce per bucket.""" import torch.distributed as dist from torch._utils import ( _flatten_dense_tensors, @@ -2632,18 +2650,7 @@ def _reduce_scatter_and_offload_shard( d2h_event = None any_trainable_region = False for region in shard_state.regions: - # CodeRabbit R07 fix: skip frozen-only regions outright. - # Their ``shard_param`` was constructed with - # ``requires_grad=False`` and ``cpu_shard_grad_bytes=None``; - # there is nothing to reduce or D2H here. 
Running the - # collective + binding a zero-grad view as - # ``shard_param.grad`` would re-introduce the original - # bug — Adam's weight-decay path would mutate frozen - # bytes against a silently-zero grad. The trainability - # flag is authoritative because region segmentation in - # :meth:`materialize_offload` splits on ``requires_grad``, - # so any param contributing bytes to a frozen region is - # guaranteed itself frozen and will never produce a grad. + # Frozen regions have no grad shard; reducing here would let weight-decay mutate frozen bytes. if not region.is_trainable: continue any_trainable_region = True @@ -2738,16 +2745,7 @@ def _reduce_scatter_and_offload_shard( else: region.shard_param.grad.copy_(my_shard_grad_gpu) # type: ignore[union-attr] - # CodeRabbit R2-05 fix: if we just reduce_scatter'd / D2H'd grads - # for at least one trainable region but no CPU optimizer is - # attached, the offloaded master weights would silently never - # advance. Raise BEFORE resetting ``_grad_remaining`` so the - # next backward fires the same condition again rather than - # silently masking the bad state. Distinct from the R07 - # frozen-region guard above (which is about ``is_trainable`` - # per region — purely a routing concern within this loop): - # this check fires when at least one trainable region exists - # and the chunk-level ``cpu_optim`` hook is missing entirely. + # Raise before resetting ``_grad_remaining`` so a missing cpu_optim re-fires next backward instead of silently retaining stale weights. 
if any_trainable_region and self.cpu_optim is None: raise RuntimeError( "ChunkManager: missing CPU optimizer for offloaded " @@ -2797,27 +2795,40 @@ def uninstall(self) -> None: LOG.debug("ChunkManager.uninstall: hook remove failed: %s", exc) self._grad_hook_handles.clear() + def _restore_protrain_ddp_ignore_snapshot(self) -> None: + """Restore ``model._ddp_params_and_buffers_to_ignore`` to its pre-protrain snapshot so teardown leaves no residue.""" + model = self.model + if model is None: + return + if not hasattr(model, "_protrain_ddp_original_ignore"): + return + try: + _original = model._protrain_ddp_original_ignore # type: ignore[attr-defined] + if _original is None: + if hasattr(model, "_ddp_params_and_buffers_to_ignore"): + try: + delattr(model, "_ddp_params_and_buffers_to_ignore") + except AttributeError: + pass + else: + model._ddp_params_and_buffers_to_ignore = list(_original) # type: ignore[attr-defined] + try: + delattr(model, "_protrain_ddp_original_ignore") + except AttributeError: + pass + LOG.info( + "ChunkManager: restored model._ddp_params_and_buffers_to_ignore " + "to pre-protrain snapshot (%s)", + "absent" if _original is None else f"{len(_original)} names", + ) + except Exception as exc: # noqa: BLE001 — best-effort + LOG.debug( + "ChunkManager._restore_protrain_ddp_ignore_snapshot failed: %s", + exc, + ) + def close(self) -> None: - """Tear down every manager-owned resource. Idempotent. - - Cascade order matters: - - 1. Drain + shut down the CPU optimizer worker pool so no - background thread can touch ``_cpu_slots`` / ``_cpu_grad_pool`` - bytes after we drop them. - 2. ``uninstall()`` — drop the per-param grad hooks so a - late-firing autograd path cannot reach into the freed pools. - 3. 
Clear ``_cpu_slots`` / ``_chunk_shards`` / ``_persistent_buffers`` - and the various per-chunk bookkeeping dicts BEFORE freeing - the pinned pools — every per-slot ``cpu_data`` / ``cpu_grad`` - view borrows from the unified pool, and live borrows would - block ``PinnedHostMemory.close``. - 4. ``_close_cpu_pools()`` — release the borrow on slot 0 and - free both pinned regions. - 5. Close the GPU buffer pool (drops its slot tensors and the - paired pinned-host region). - 6. Drop adapter references. - """ + """Tear down every manager-owned resource. Idempotent.""" if self._closed: return self._closed = True @@ -2844,6 +2855,7 @@ def close(self) -> None: self._grad_initial.clear() self._chunk_bytes_by_id.clear() self._empty_by_dtype.clear() + self._shape_scratch_by_dtype.clear() try: self._close_cpu_pools() @@ -2860,6 +2872,20 @@ def close(self) -> None: self.cpu_optim = None self.gpu_optim = None + # D2 lifecycle teardown: restore the model's pre-protrain + # ``_ddp_params_and_buffers_to_ignore`` snapshot so a future + # non-protrain DDP wrap of the same model is not constrained + # by our ignore set. No-op if we never snapshotted (single-GPU + # / replicated paths where ``shape_preserving_placeholders`` is + # False). + try: + self._restore_protrain_ddp_ignore_snapshot() + except Exception as exc: # noqa: BLE001 — best-effort + LOG.debug( + "ChunkManager.close: snapshot restore failed: %s", + exc, + ) + def __del__(self) -> None: # noqa: D401 try: self.uninstall() @@ -3017,19 +3043,7 @@ def shard_bytes_for(self, chunk_id: ChunkId) -> int: return 0 if s is None else s.shard_bytes def per_rank_cpu_bytes(self) -> int: - """Total pinned CPU bytes this rank holds across every sharded chunk. - - Sums BOTH the per-region shard buffer (``cpu_shard_bytes``) and - the per-region grad buffer (``cpu_shard_grad_bytes``) when - present. 
``cpu_shard_bytes`` is allocated for every sharded - region; ``cpu_shard_grad_bytes`` is allocated only for trainable - regions (frozen-only regions skip it as part of the CodeRabbit - R07 fix — no Adam step, no need for the pinned grad shard). - Convenience accessor for the 4-GPU sharding test which asserts - per-rank CPU footprint roughly equals - ``total_non_persistent_bytes / world_size`` and for benchmark - scripts reporting Mode-C host RAM. - """ + """Total pinned CPU bytes this rank holds across every sharded chunk (shard buffers plus per-trainable-region grad buffers).""" total = 0 for shard_state in self._chunk_shards.values(): for region in shard_state.regions: diff --git a/src/axolotl/integrations/protrain/chunk/optim.py b/src/axolotl/integrations/protrain/chunk/optim.py index ba9f135c19..9b073c6d22 100644 --- a/src/axolotl/integrations/protrain/chunk/optim.py +++ b/src/axolotl/integrations/protrain/chunk/optim.py @@ -504,11 +504,133 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: @property def underlying(self) -> Any: - """The wrapped optimizer instance (useful for LR schedulers). + """Return the wrapped optimizer (None when adapter has no persistent params).""" + return self._optim - ``None`` when the adapter wraps an empty persistent param set. - """ + +# bnb 8-bit Adam kernels are CUDA-only, so this adapter is restricted to persistent (GPU-resident) chunks; non-persistent chunks must use the CPU FusedAdam adapter. 
+ + +class GpuAdamW8bitAdapter: + """Synchronous bitsandbytes 8-bit AdamW for persistent (GPU-resident) chunks.""" + + def __init__( + self, + params: Iterable["nn.Parameter"], + lr: float, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + paged: bool = False, + ) -> None: + """Build the underlying ``bnb.optim.AdamW8bit`` (or paged variant) over ``params``.""" + param_list = [p for p in params if p is not None] + + self.lr = float(lr) + self.betas = (float(betas[0]), float(betas[1])) + self.eps = float(eps) + self.weight_decay = float(weight_decay) + self.paged = bool(paged) + + if len(param_list) == 0: + self._optim = None + return + + # Defer the bitsandbytes import: ``optim_wrapper`` only constructs + # this adapter when the user explicitly opts into an 8-bit + # optimizer name, so we must not pay the bnb import cost (it + # JIT-loads CUDA libraries) on every protrain bring-up. + try: + from bitsandbytes.optim import ( # type: ignore[import-not-found] + AdamW8bit, + PagedAdamW8bit, + ) + except (ImportError, RuntimeError) as err: + # ``bitsandbytes`` JIT-loads CUDA libraries on import; if the + # extension cannot be linked against the active CUDA toolkit + # the failure surfaces as ``RuntimeError`` rather than the + # canonical ``ImportError``. Catch both so callers see the + # adapter-level message instead of an opaque loader trace. + # Mirrors :class:`GpuFusedAdamAdapter`'s apex-import guard + # earlier in this module. + raise ImportError( + "GpuAdamW8bitAdapter requires `bitsandbytes` (>=0.41) for " + "the 8-bit AdamW kernels. Install via " + "`pip install bitsandbytes`." + ) from err + + # Sanity check: bnb 8-bit Adam will crash inside the CUDA kernel + # if any param tensor lives on CPU (the per-param state tensors + # are allocated on the same device as the param). 
Catch this at + # construction time so callers see a comprehensible error + # instead of a downstream "All input tensors need to be on the + # same GPU" RuntimeError from inside ``optimizer_update_8bit``. + for p in param_list: + if not p.is_cuda: + raise RuntimeError( + "GpuAdamW8bitAdapter received a parameter on device " + f"{p.device}; bitsandbytes' 8-bit AdamW kernels run " + "on CUDA only. Non-persistent (CPU-resident) chunks " + "must continue to use CpuFusedAdamAdapter " + "(DeepSpeedCPUAdam) - only persistent (GPU) chunks " + "may use the 8-bit adapter." + ) + + cls = PagedAdamW8bit if self.paged else AdamW8bit + self._optim = cls( + param_list, + lr=self.lr, + betas=self.betas, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + # ---- step interface ------------------------------------------------- + + def step(self) -> None: + """Synchronous bnb 8-bit AdamW step over persistent-chunk params.""" + optim = self._optim + if optim is None: + return + optim.step() + + def zero_grad(self, set_to_none: bool = True) -> None: + """Zero gradients on every persistent-chunk parameter.""" + optim = self._optim + if optim is None: + return + optim.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> dict[str, Any]: + """Return the wrapped 8-bit optimizer's state dict (empty when no-op).""" + optim = self._optim + if optim is None: + return {"state": {}, "param_groups": []} + return optim.state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load state into the wrapped optimizer (no-op when adapter is empty).""" + optim = self._optim + if optim is None: + if state_dict.get("state") or state_dict.get("param_groups"): + raise ValueError( + "Cannot load non-empty optimizer state into an empty " + "GpuAdamW8bitAdapter: this layout has no persistent-chunk " + "params but the checkpoint contains optimizer state " + "(likely a Mode-A/Mode-C config mismatch on resume)." 
+ ) + return + optim.load_state_dict(state_dict) + + @property + def underlying(self) -> Any: + """Return the wrapped optimizer (None when adapter has no persistent params).""" return self._optim -__all__ = ["CpuFusedAdamAdapter", "GpuFusedAdamAdapter"] +__all__ = [ + "CpuFusedAdamAdapter", + "GpuAdamW8bitAdapter", + "GpuFusedAdamAdapter", +] diff --git a/src/axolotl/integrations/protrain/cost/memory.py b/src/axolotl/integrations/protrain/cost/memory.py index 23ccf62c4b..2ab2ae6ff2 100644 --- a/src/axolotl/integrations/protrain/cost/memory.py +++ b/src/axolotl/integrations/protrain/cost/memory.py @@ -9,7 +9,11 @@ Design contract (see DESIGN.md §Design Decisions): - ``ALPHA_FRAGMENTATION = 1.10`` matches the paper's "up to 10% - overestimate on best-selected configurations" claim. + overestimate on best-selected configurations" claim. Per-dtype + refinement lives in :func:`alpha_fragmentation_for_dtype`: fp16 / + bf16 / 8-bit keep alpha=1.10; bnb 4-bit drops to + ``ALPHA_FRAGMENTATION_4BIT = 0.75`` because alpha=1.10 empirically + over-predicts bnb-4-bit Mode-A peak by ~37%. - SWAP blocks do not contribute to the op-walk peak: the paper argues swap-in "only fires when memory is available", so activation swapping is assumed to trade runtime for zero steady-state peak. @@ -157,8 +161,31 @@ def _saved_tensor_bytes_per_block(trace: ProfilerTrace) -> dict[BlockId, int]: #: the BUG-1-4 fixes in ``chunk/manager.py``) the op-walk matches #: measured peaks tightly enough to restore the paper value — see #: DESIGN.md §Design Decisions point 1. +#: +#: Treat as the fp16/bf16/8-bit default; per-dtype overrides live in +#: :func:`alpha_fragmentation_for_dtype`. The constant is retained +#: (rather than fully replaced) so callers that legitimately want the +#: fp16 ceiling — e.g. 
the model_wrapper's peak-calibration clamp, +#: which is computing a "what would the cost model have said under +#: pure fp16" baseline — can keep depending on the literal 1.10 +#: value, while estimate_peak now dispatches through the per-dtype +#: lookup. ALPHA_FRAGMENTATION: float = 1.10 +#: alpha floor for bnb-4-bit weights; empirical alpha_measured ~= 0.70 (Mode-A 8B Llama sweeps), 0.75 keeps a small cushion. +ALPHA_FRAGMENTATION_4BIT: float = 0.75 + + +def alpha_fragmentation_for_dtype(bytes_per_element: float) -> float: + """Return ALPHA_FRAGMENTATION_4BIT for sub-byte dtypes, else ALPHA_FRAGMENTATION. + + Args: + bytes_per_element: logical bytes per element (0.5 for bnb 4-bit, 1.0 for int8, 2.0 for fp16/bf16). + """ + if bytes_per_element < 1.0: + return ALPHA_FRAGMENTATION_4BIT + return ALPHA_FRAGMENTATION + def _group_ops_by_block(trace: ProfilerTrace) -> dict[BlockId, list[int]]: """Return ``{block_id -> [op_positions]}`` for forward ops only. @@ -278,12 +305,16 @@ def cross_attn_persist_bytes( (OFFLOAD retains forward activations on GPU symmetrically to NONE — see the ``retained_none_bytes`` / ``cumulative_none`` construction below), so we return ``0`` to avoid double-counting. - - When that block is in CKPT or SWAP mode its activations are not - in ``live_none``; CKPT discards the BLOCK INTERNALS but the - OUTPUT hidden tensor passed to the decoder cannot be discarded - (the cross-attention layers reference it). Same for SWAP — the - saved-state output isn't part of the swap-band's offload set. - We therefore return the full ``activation_sizes`` upper bound. + - When that block is in CKPT mode ``ckpt_chain_bytes`` already + covers the block-input residual that the checkpoint framework + retains across the backward window; return 0 to avoid + double-counting. 
+ - When that block is in SWAP mode its block-output IS evicted to + pinned CPU (the swap pool offloads saved tensors including the + block boundary); the cross-attention reference forces it back to + GPU for the entire decoder window, so the bytes are NOT already + counted elsewhere. Return the full ``activation_sizes`` upper + bound for SWAP. Returns 0 when the trace looks single-tree (no decoder ops), when no encoder block_ids resolve, or when we lack activation bytes for @@ -302,6 +333,28 @@ def cross_attn_persist_bytes( # OFFLOAD-only bump is the per-block backward chunk gather, # tracked separately via ``offload_bump_op`` in estimate_peak). return 0 + if last_enc_mode is BlockMode.CKPT: + # CKPT chain bytes already cover this block's residual; avoid double-count. + return 0 + return int(trace.activation_sizes.get(last_enc_bid, 0)) + + +def cross_attn_handoff_bytes( + trace: ProfilerTrace, + block_map: BlockStrategyMap, + tree_index_map: dict[BlockId, int], +) -> int: + """Return encoder-decoder handoff bytes regardless of encoder-last mode (cap-path use).""" + if not _has_multiple_trees(tree_index_map): + return 0 + encoder_bids = sorted(bid for bid, idx in tree_index_map.items() if idx == 0) + if not encoder_bids: + return 0 + last_enc_bid = encoder_bids[-1] + last_enc_mode = block_map.get(last_enc_bid, BlockMode.NONE) + # NONE/OFFLOAD already retain the full block bytes on GPU so the cap need not preserve them again. + if last_enc_mode is BlockMode.NONE or last_enc_mode is BlockMode.OFFLOAD: + return 0 return int(trace.activation_sizes.get(last_enc_bid, 0)) @@ -469,7 +522,8 @@ def hot_iter_peak_cap( # traces are unaffected because ``cross_attn_persist_bytes`` # returns 0 outside the multi-tree path. tree_index_map = block_tree_index_map(trace) - cross_attn_bytes_for_cap = cross_attn_persist_bytes( + # Cap path must preserve handoff bytes even when encoder-last is CKPT (op-walk's zero is double-count avoidance, not absence). 
+ cross_attn_bytes_for_cap = cross_attn_handoff_bytes( trace, block_map, tree_index_map ) encoder_last_bid: BlockId | None = None @@ -852,7 +906,7 @@ def estimate_peak( trace: ProfilerTrace, layout: ChunkLayout, block_map: BlockStrategyMap, - hw: HardwareProfile, # noqa: ARG001 - accepted for API symmetry with runtime + hw: HardwareProfile, ) -> int: """Estimate steady-state peak GPU memory in bytes. @@ -998,17 +1052,11 @@ def estimate_peak( forward_ops_by_block = _group_ops_by_block(trace) tree_index_map = block_tree_index_map(trace) cross_attn_bytes = cross_attn_persist_bytes(trace, block_map, tree_index_map) + # Block-internal saved tensors only; the block-input residual lives in ``ckpt_chain_bytes``. + saved_bytes_proxy_for_op_walk = _saved_tensor_bytes_per_block(trace) - # Resolve "first op index" for each CKPT block; used to schedule the - # checkpoint recomputation bump. If the block has no ops (degenerate - # test input) the bump lands at op index -1 and is ignored below. ckpt_bump_op: dict[int, int] = {} - # Resolve "last op index" for each OFFLOAD block; used to schedule the - # backward-window chunk-gather bump (§4.1). The last forward op is the - # closest forward index to the block's first backward op — backward - # walks blocks in reverse forward order, so the OFFLOAD-block gather - # peak materializes at that op-walk position when the forward - # activations are still resident. + # OFFLOAD bump fires at the last forward op (closest to the block's backward window). offload_bump_op: dict[int, int] = {} for block_id, op_idxs in forward_ops_by_block.items(): if not op_idxs: @@ -1019,23 +1067,20 @@ def estimate_peak( elif mode is BlockMode.OFFLOAD: offload_bump_op[op_idxs[-1]] = int(block_id) - # Retained-activation contribution from NONE + OFFLOAD blocks — - # constant across the op-walk (these activations are live from their - # first op through the end of forward). 
OFFLOAD retains activations - # symmetrically to NONE; the additional chunk-gather bump fires only - # at the per-block backward window via ``offload_bump_op``. retained_none_bytes = 0 + # CKPT blocks retain the block-input boundary tensor across the full backward window; sum once per CKPT block, separate from the per-op recompute bump in ``ckpt_extra``. + ckpt_chain_bytes = 0 for block_id_raw, act_sz in trace.activation_sizes.items(): - # ``activation_sizes`` is typed ``dict[BlockId, int]`` but - # pickled maps may use int keys; normalize. bid = BlockId(int(block_id_raw)) mode = block_map.get(bid, BlockMode.NONE) if mode is BlockMode.NONE or mode is BlockMode.OFFLOAD: retained_none_bytes += act_sz - # CKPT: only live during its recomputation window -> handled - # by the per-op bump below. + elif mode is BlockMode.CKPT: + ckpt_chain_bytes += act_sz # SWAP: live only during the block's forward compute; assumed - # to overlap free GPU memory (§3.3). + # to overlap free GPU memory (§3.3). The CKPT-chain term + # does NOT apply because SWAP evicts the block-output + # tensor to the pinned-CPU swap pool (see swap_pool.py). # --- Op walk ------------------------------------------------------- raw_peak = 0 @@ -1086,22 +1131,19 @@ def _none_live_at(op_idx: int) -> int: for i, op in enumerate(trace.op_order): if not op.is_forward: - # Backward-only ops are out of scope for the forward - # op-walk. Eq. 8-10 explicitly walk forward ops. continue intra = trace.intra_op_delta.get(op.op_id, 0) inter = trace.inter_op_delta.get(op.op_id, 0) live_none = _none_live_at(i) - # CKPT bump: when we hit the first op of a CKPT block, the - # recomputation materializes that block's activations *in - # addition to* any retained activations. This models the peak - # during the backward-driven recomp window that lines up with - # this op's forward-equivalent workload. + # CKPT recompute bump = internal saved-tensor delta; block-input residual already in ``ckpt_chain_bytes``. 
ckpt_extra = 0 if i in ckpt_bump_op: - ckpt_extra = trace.activation_sizes.get(BlockId(ckpt_bump_op[i]), 0) + bid = BlockId(ckpt_bump_op[i]) + block_act = trace.activation_sizes.get(bid, 0) + block_saved = int(saved_bytes_proxy_for_op_walk.get(bid, block_act)) + ckpt_extra = max(0, block_saved - block_act) # OFFLOAD backward-gather bump (Option B §4.1): the chunk is # re-gathered into the buffer pool for this block's backward @@ -1123,6 +1165,7 @@ def _none_live_at(op_idx: int) -> int: candidate = ( model_state_present + live_none + + ckpt_chain_bytes + ckpt_extra + offload_extra + op_cross_attn @@ -1132,10 +1175,9 @@ def _none_live_at(op_idx: int) -> int: if candidate > raw_peak: raw_peak = candidate - # If the trace has no forward ops (degenerate test input) fall back - # to a static estimate. This keeps the function total. + # Degenerate trace (no forward ops): static estimate. ckpt_chain_bytes and retained_none_bytes are disjoint by construction so summing both does not double-count. if raw_peak == 0: - raw_peak = model_state_present + retained_none_bytes + raw_peak = model_state_present + retained_none_bytes + ckpt_chain_bytes # Ground-truth forward cap from the profiler's hook-less steady pass. 
# @@ -1206,7 +1248,8 @@ def _none_live_at(op_idx: int) -> int: measured_cap = hot_iter_peak_cap(trace, block_map, cfg, layout) raw_peak = apply_hot_iter_cap(raw_peak, model_state_present, measured_cap, layout) - scaled = int(ALPHA_FRAGMENTATION * raw_peak) + alpha = alpha_fragmentation_for_dtype(hw.dominant_param_bytes_per_element) + scaled = int(alpha * raw_peak) LOG.debug( "estimate_peak: n_persist=%d n_buffer=%d n_swap=%d n_ckpt=%d n_offload=%d " "raw=%dB alpha=%.2f -> %dB", @@ -1216,7 +1259,7 @@ def _none_live_at(op_idx: int) -> int: cfg.n_checkpoint, cfg.n_offload, raw_peak, - ALPHA_FRAGMENTATION, + alpha, scaled, ) return scaled @@ -1224,8 +1267,11 @@ def _none_live_at(op_idx: int) -> int: __all__ = [ "ALPHA_FRAGMENTATION", + "ALPHA_FRAGMENTATION_4BIT", + "alpha_fragmentation_for_dtype", "_saved_tensor_bytes_per_block", "block_tree_index_map", + "cross_attn_handoff_bytes", "cross_attn_persist_bytes", "estimate_cpu_footprint", "estimate_peak", diff --git a/src/axolotl/integrations/protrain/cost/runtime.py b/src/axolotl/integrations/protrain/cost/runtime.py index 8bcb92a146..2681f7db91 100644 --- a/src/axolotl/integrations/protrain/cost/runtime.py +++ b/src/axolotl/integrations/protrain/cost/runtime.py @@ -753,6 +753,36 @@ def _structure_match( Boot's ``n_swap`` is always 0 by phase-2 spec (:func:`profiler.phase2.bootstrap_config`), so we compare prod's ``cfg.n_swap`` to 0 directly without needing a ``phase2_n_swap`` field. + + DEFERRED: a TRACE_VERSION 23 refactor attempted to make this gate + obsolete by decomposing each analytical component into a roofline- + compute fraction (cfg-invariant) and a synthetic non-compute / per- + block-dispatch predictor (``N_block × tau`` derived from + ``hooked_fwd_wall_s - steady_fwd_wall_s``). The per-component α + would calibrate against the non-compute fraction only, making it + cfg-invariant by construction and dropping the gate. That direction + foundered on two issues empirically: + + 1. 
The analytical full pred is often dominated by the compute + fraction at boot (compute > comm per chunk on small chunks), + leaving the non-compute residual ``measured - analytical`` near + zero or negative. Solving for α produces values pinned to the + clamp floor, after which the residual α machinery has to absorb + the bulk of the bias — degenerating into the v22 gate's + behaviour with extra plumbing. + 2. The chunked-wall override path at prod cfg returns measurement- + anchored predictions; adding a synthetic non-compute term on top + double-counts the dispatch overhead the chunked wall already + contains, while subtracting the boot's nc_pred via a delta + over-corrects when n_checkpoint changes (the override path + already rebuilds ``t_bwd_recompute`` for prod's cfg). + + The gate stays in place pending a deeper rework that captures the + per-block dispatch overhead at prod-cfg-aware granularity (e.g. a + per-block runtime hook microbench rather than a constant tau, or a + decomposition that distinguishes "Python interpreter overhead per + iter" from "per-chunk PCIe roofline overhead"). See ticket B + deferred report for details. """ boot_n_persist = int(getattr(trace, "phase2_n_persist", -1)) boot_n_checkpoint = int(getattr(trace, "phase2_n_checkpoint", -1)) @@ -816,10 +846,7 @@ def _clamp_residual_alpha(alpha: float) -> float: we still clamp to keep the prediction bounded but warn once so the regression is visible (the brief's "anti-hack guard"). 
""" - if ( - alpha < _PHASE2_RESIDUAL_NOISE_FLOOR - or alpha > _PHASE2_RESIDUAL_NOISE_CEILING - ): + if alpha < _PHASE2_RESIDUAL_NOISE_FLOOR or alpha > _PHASE2_RESIDUAL_NOISE_CEILING: global _WARNED_PHASE2_RESIDUAL_NOISY if not _WARNED_PHASE2_RESIDUAL_NOISY: LOG.warning( @@ -835,9 +862,7 @@ def _clamp_residual_alpha(alpha: float) -> float: _PHASE2_RESIDUAL_CLAMP_MAX, ) _WARNED_PHASE2_RESIDUAL_NOISY = True - return max( - _PHASE2_RESIDUAL_CLAMP_MIN, min(_PHASE2_RESIDUAL_CLAMP_MAX, alpha) - ) + return max(_PHASE2_RESIDUAL_CLAMP_MIN, min(_PHASE2_RESIDUAL_CLAMP_MAX, alpha)) def _compose_t_iter_with_alpha_calibration( diff --git a/src/axolotl/integrations/protrain/plugin.py b/src/axolotl/integrations/protrain/plugin.py index 59f88b9ca7..b065a0f02b 100644 --- a/src/axolotl/integrations/protrain/plugin.py +++ b/src/axolotl/integrations/protrain/plugin.py @@ -296,6 +296,19 @@ def _remeasure_nccl_and_research(wrapped) -> tuple[bool, bool]: if trace.nccl_gather_s and trace.nccl_reduce_s and trace.world == world_size: return (False, False) + # Skip late NCCL re-search when all explicit overrides pin the plan, to avoid + # re-running search() and raising on a cost-optimal cfg that differs from the + # synthesized bootstrap cfg. + if bool(getattr(wrapped, "_override_skip_trace", False)): + LOG.info( + "ProTrain: late NCCL re-search skipped — explicit override knobs " + "are fully set so the bootstrap cfg is pinned. 
world_size=%d, " + "bootstrap cfg=%s.", + world_size, + wrapped.search_result.cfg, + ) + return (False, False) + from axolotl.integrations.protrain.profiler import measure_nccl from axolotl.integrations.protrain.profiler.cache import ( ProfilerCacheKey, @@ -427,6 +440,197 @@ def _remeasure_nccl_and_research(wrapped) -> tuple[bool, bool]: return (True, cfg_changed) +def _install_resume_hook(trainer, cfg, wrapped) -> None: + """Wrap ``trainer._load_from_checkpoint`` so cross-mode resume gathers offloaded chunks before reload.""" + if getattr(trainer, "_protrain_resume_hook_installed", False): + LOG.debug( + "ProTrain: resume hook already installed on this trainer; " + "skipping duplicate patch (idempotent path)." + ) + return + + original_load = getattr(trainer, "_load_from_checkpoint", None) + if original_load is None: + # Test harness without an HF Trainer instance — nothing to patch. + LOG.debug( + "ProTrain: trainer has no _load_from_checkpoint attribute; " + "skipping resume-hook install." + ) + return + + # Snapshot the optimizer-rebuild hyperparams now so the wrapped + # closure doesn't have to re-read them off ``trainer.args`` later + # (Accelerate.prepare may have wrapped the optimizer by then and + # the hyperparam read becomes ambiguous about which inner optim's + # values to mirror). Captured as discrete locals (not a kwargs dict) + # so mypy sees the precise types at the rebuild call site — + # ``protrain_optimizer_wrapper``'s signature is positional-named + # with mixed value types (float, tuple[float, float], str | None) + # and a heterogeneous ``dict[str, object]`` ``**unpack`` flunks + # type-narrowing. 
+ args = trainer.args + rebuild_lr = float(args.learning_rate) + rebuild_betas = (float(args.adam_beta1), float(args.adam_beta2)) + rebuild_eps = float(args.adam_epsilon) + rebuild_weight_decay = float(args.weight_decay) + rebuild_optimizer_name = _resolve_optimizer_name(args, cfg) + + def _patched(resume_from_checkpoint, model=None) -> None: + # Resolve the chunk manager LAZILY: by the time the patched + # method fires the wrapper is fully constructed (post_model_load + # ran), but at install time (post_trainer_create) the + # chunk_manager attribute IS already present — read it through + # ``wrapped`` so a future reorder can't strand the closure. + chunk_manager = getattr(wrapped, "chunk_manager", None) + if chunk_manager is None: + LOG.debug( + "ProTrain resume hook: wrapped.chunk_manager is None; " + "delegating to the original _load_from_checkpoint." + ) + return original_load(resume_from_checkpoint, model) + + # Detection: does the chunk manager actually have offloaded + # chunks live right now? Both ``_cpu_slots`` and + # ``_chunk_shards`` are populated by ``materialize_offload``; + # neither is populated under Mode A / all-persistent. Check + # both so the gate covers replicated AND sharded offload. + has_offload = bool( + getattr(chunk_manager, "_cpu_slots", None) + or getattr(chunk_manager, "_chunk_shards", None) + ) + if not has_offload: + LOG.debug( + "ProTrain resume hook: chunk manager has no offloaded " + "state (Mode A / all-persistent); delegating to the " + "original _load_from_checkpoint." + ) + return original_load(resume_from_checkpoint, model) + + LOG.info( + "ProTrain resume hook: gathering %d non-persistent chunk(s) " + "to GPU for cross-mode load_adapter (PEFT load_state_dict " + "needs full-shape destination tensors).", + len(getattr(chunk_manager, "_cpu_slots", {}) or {}) + + len(getattr(chunk_manager, "_chunk_shards", {}) or {}), + ) + + # Tear down the CPU adapter before restore_to_gpu invalidates the shard views it holds. 
+ cpu_optim = getattr(chunk_manager, "cpu_optim", None) + if cpu_optim is not None: + try: + cpu_optim.shutdown() + except Exception: # noqa: BLE001 — fail closed + LOG.exception( + "ProTrain resume hook: cpu_optim.shutdown failed; " + "aborting before restore_to_gpu invalidates shard views." + ) + raise + chunk_manager.cpu_optim = None + # Drop the GPU adapter ref too — we'll rebuild it after the + # load. Persistent params keep their data across restore_to_gpu + # (only standalone-GPU rebind happens), but the GPU adapter's + # ``param_groups`` dict references the same Parameter instances + # so the rebuild closes the loop cleanly. + chunk_manager.gpu_optim = None + + # Step 2: restore_to_gpu rebinds every param.data to standalone + # GPU storage at full shape. After this, model.load_adapter's + # PEFT load_state_dict sees real shapes and the size-mismatch + # error class is gone. + try: + chunk_manager.restore_to_gpu() + except Exception: + LOG.exception( + "ProTrain resume hook: chunk_manager.restore_to_gpu " + "failed; the cross-mode resume cannot proceed. Re-" + "raising — the alternative (running load against the " + "zeroed param.data slots) would crash inside HF's load " + "with the same shape-mismatch error this hook exists " + "to prevent." + ) + raise + + # Step 3: run the original load. HF's _load_from_checkpoint + # signature varies across transformers versions; we forward + # ``model`` only when it was provided (to match the both-sides + # signature in transformers/trainer.py:3280). + if model is None: + original_load(resume_from_checkpoint) + else: + original_load(resume_from_checkpoint, model) + + # Step 4: re-build the offload state. ``materialize_offload`` + # reads ``param.data`` (now the freshly-loaded weights from + # the checkpoint) and copies into newly-allocated pinned + # pools, then resets ``param.data`` to the empty placeholder + # — restoring the same offload contract the wrapper installed + # at post_model_load time. 
Idempotency: not relevant here + # because ``restore_to_gpu`` cleared ``_cpu_slots`` / + # ``_cpu_param_pool``, so the materialize check passes. + try: + chunk_manager.materialize_offload() + except Exception: + LOG.exception( + "ProTrain resume hook: chunk_manager.materialize_offload " + "failed after the resume load; runtime is now in an " + "inconsistent state (params on standalone GPU storage " + "but no offload pinned pool). Re-raising." + ) + raise + + # Step 5: rebuild the optimizer adapters. The cpu_optim refs + # into the OLD pinned region were dropped in step 1; the GPU + # adapter held no chunk-manager-internal refs. A fresh wrap + # via ``protrain_optimizer_wrapper`` constructs adapters + # against the NEW pinned pool's ``shard_param`` views and + # against the (unchanged-identity) persistent ``Parameter`` + # objects. + from axolotl.integrations.protrain.api import protrain_optimizer_wrapper + + try: + new_optim = protrain_optimizer_wrapper( + wrapped, + lr=rebuild_lr, + betas=rebuild_betas, + eps=rebuild_eps, + weight_decay=rebuild_weight_decay, + optimizer_name=rebuild_optimizer_name, + ) + except Exception: + LOG.exception( + "ProTrain resume hook: protrain_optimizer_wrapper rebuild " + "failed after materialize_offload; runtime can't continue " + "without an optimizer. Re-raising." + ) + raise + + # ``trainer.optimizer`` was the pre-resume ``_ProTrainOptimizer`` + # facade. Replace it in-place. Accelerate.prepare hasn't run yet + # (it runs in _inner_training_loop, downstream of train()'s + # _load_from_checkpoint call site at transformers/trainer.py + # ~1413), so the swap is safe — there is no upstream wrapper + # we'd be invalidating. + trainer.optimizer = new_optim + LOG.info( + "ProTrain resume hook: optimizer adapter rebuilt and " + "installed on trainer.optimizer; cross-mode resume complete." 
+ ) + + trainer._load_from_checkpoint = _patched # type: ignore[method-assign] + trainer._protrain_resume_hook_installed = True # type: ignore[attr-defined] + LOG.debug( + "ProTrain: cross-mode resume hook installed on trainer._load_from_checkpoint" + ) + + +def _resolve_optimizer_name(args, cfg) -> str | None: + """Return the optimizer name, preferring HF ``args.optim`` over ``cfg.optimizer``.""" + optimizer_name = getattr(args, "optim", None) or getattr(cfg, "optimizer", None) + if optimizer_name is not None and not isinstance(optimizer_name, str): + optimizer_name = getattr(optimizer_name, "value", str(optimizer_name)) + return optimizer_name + + def _is_plugin_active(cfg) -> bool: """Return True iff both the plugin is registered and auto_memory is on. @@ -792,13 +996,18 @@ def create_optimizer(self, cfg, trainer: "Trainer") -> "Optimizer | None": betas = (float(args.adam_beta1), float(args.adam_beta2)) eps = float(args.adam_epsilon) weight_decay = float(args.weight_decay) + # Forward the optimizer name so the wrapper can route 8-bit-bnb to GpuAdamW8bitAdapter. 
+ optimizer_name = getattr(args, "optim", None) or getattr(cfg, "optimizer", None) + if optimizer_name is not None and not isinstance(optimizer_name, str): + optimizer_name = getattr(optimizer_name, "value", str(optimizer_name)) LOG.info( - "ProTrain.create_optimizer: lr=%.3e betas=%s eps=%.1e wd=%.3e", + "ProTrain.create_optimizer: lr=%.3e betas=%s eps=%.1e wd=%.3e optimizer=%s", lr, betas, eps, weight_decay, + optimizer_name, ) return protrain_optimizer_wrapper( @@ -807,6 +1016,7 @@ def create_optimizer(self, cfg, trainer: "Trainer") -> "Optimizer | None": betas=betas, eps=eps, weight_decay=weight_decay, + optimizer_name=optimizer_name, ) def post_trainer_create(self, cfg, trainer: "Trainer") -> None: @@ -854,12 +1064,16 @@ def post_trainer_create(self, cfg, trainer: "Trainer") -> None: from axolotl.integrations.protrain.api import protrain_optimizer_wrapper args = trainer.args + optimizer_name = getattr(args, "optim", None) or getattr(cfg, "optimizer", None) + if optimizer_name is not None and not isinstance(optimizer_name, str): + optimizer_name = getattr(optimizer_name, "value", str(optimizer_name)) optim = protrain_optimizer_wrapper( wrapped, lr=float(args.learning_rate), betas=(float(args.adam_beta1), float(args.adam_beta2)), eps=float(args.adam_epsilon), weight_decay=float(args.weight_decay), + optimizer_name=optimizer_name, ) # ``_ProTrainOptimizer.state_dict`` / ``load_state_dict`` already @@ -878,6 +1092,11 @@ def post_trainer_create(self, cfg, trainer: "Trainer") -> None: float(args.weight_decay), ) + # Patch _load_from_checkpoint so PEFT/HF load sees full-shape param.data + # (offloaded LoRA factors have size (0,) and would size-mismatch otherwise); + # cycle: restore_to_gpu -> original load -> materialize_offload -> rebuild optimizer. + _install_resume_hook(trainer, cfg, wrapped) + # ---- Optimizer-state checkpoint/resume (CHECKPOINT_DESIGN.md) ---- # Opt-in via protrain_save_optimizer_state. 
The save side is a # TrainerCallback (on_save fires after HF writes its standard diff --git a/src/axolotl/integrations/protrain/profiler/hw_bench.py b/src/axolotl/integrations/protrain/profiler/hw_bench.py index 230995d533..15e35b25fd 100644 --- a/src/axolotl/integrations/protrain/profiler/hw_bench.py +++ b/src/axolotl/integrations/protrain/profiler/hw_bench.py @@ -612,6 +612,16 @@ def measure_nccl( gather_table: dict[int, float] = {} reduce_table: dict[int, float] = {} + # surface communicator-config asymmetry as a debuggable barrier hang instead of a SIGSEGV inside the first collective + try: + dist.barrier(device_ids=[device_idx]) + except Exception as exc: # pragma: no cover - defensive + raise RuntimeError( + "measure_nccl: pre-collective dist.barrier() failed — your ranks " + "likely have asymmetric NCCL communicator config. Set " + "TORCH_DISTRIBUTED_DEBUG=DETAIL and re-run to inspect." + ) from exc + for payload_bytes in payload_sizes_bytes: # all_gather_into_tensor: each rank contributes one shard of size # payload/world_size, output is the full payload on every rank. 
diff --git a/src/axolotl/integrations/protrain/profiler/on_demand.py b/src/axolotl/integrations/protrain/profiler/on_demand.py index 86e6c1b581..29db6cd41e 100644 --- a/src/axolotl/integrations/protrain/profiler/on_demand.py +++ b/src/axolotl/integrations/protrain/profiler/on_demand.py @@ -32,6 +32,7 @@ from __future__ import annotations +import types from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Iterable @@ -45,6 +46,99 @@ LOG = get_logger(__name__) +def _fused_kernel_func_names() -> frozenset[str]: + """Names of fused LoRA apply_* functions whose direct-attribute weight reads bypass per-Linear gather hooks; listed by name (not import) so a missing kernel module stays non-fatal.""" + return frozenset( + { + "apply_lora_mlp_swiglu", + "apply_lora_mlp_geglu", + "apply_lora_qkv", + "apply_lora_qk", + "apply_lora_o", + "apply_lora_embedding", + } + ) + + +def _is_fused_method(attr: Any) -> bool: + """True iff ``attr`` is an instance-bound method whose underlying function is one of the fused-kernel apply_* entries.""" + if not isinstance(attr, types.MethodType): + return False + fn = getattr(attr, "__func__", None) + name = getattr(fn, "__name__", None) + return name in _fused_kernel_func_names() + + +def _find_fused_kernel_containers(model: "nn.Module") -> "list[nn.Module]": + """Return modules with at least one fused-kernel method binding; deterministic ``model.modules()`` order so tests can rely on stable enumeration.""" + out: list["nn.Module"] = [] + for sub in model.modules(): + for attr_name in ("forward", "apply_qkv", "apply_o"): + attr = getattr(sub, attr_name, None) + if _is_fused_method(attr): + out.append(sub) + break + return out + + +# PEFT trainable-factor parameter name fragments. These are the canonical +# attribute names PEFT uses for trainable LoRA factors on a wrapped layer. 
+# We match by substring against ``named_parameters(recurse=False)`` so the +# detector covers both bare ``lora_A`` and the ParameterDict-wrapped +# ``lora_A.default`` form (PEFT serialises the active adapter under the +# adapter-name key, defaulting to "default"). ``lora_magnitude_vector`` +# covers DoRA's per-output-channel magnitude scalar. +_PEFT_LORA_NAME_TAGS: frozenset[str] = frozenset( + { + "lora_A", + "lora_B", + "lora_embedding_A", + "lora_embedding_B", + "lora_magnitude_vector", + } +) + + +def _has_peft_lora_factor( + module: "nn.Module", *, recurse_children: bool = True +) -> bool: + """True iff ``module`` *directly* owns a trainable LoRA factor (parameter attribute or one-level-child by tag name); grandparents are excluded because PEFT's direct-attribute reads happen on the LoraLayer itself.""" + # Direct-Parameter scope: catches the bare ``nn.Parameter`` form. + for name, p in module.named_parameters(recurse=False): + if not p.requires_grad: + continue + if any(tag in name for tag in _PEFT_LORA_NAME_TAGS): + return True + if not recurse_children: + return False + # Direct-child-module scope: PEFT's ParameterDict / wrapped-Linear + # form. The child's *attribute name on this module* carries the + # PEFT tag (``lora_A`` etc.). Verify the child actually contains + # at least one trainable parameter so we don't tag a frozen-only + # subtree as a container (the M6C bug only matters for params + # that produce gradients). + for child_name, child in module.named_children(): + if not any(tag in child_name for tag in _PEFT_LORA_NAME_TAGS): + continue + for _pname, p in child.named_parameters(recurse=True): + if p.requires_grad: + return True + return False + + +def _find_peft_lora_containers(model: "nn.Module") -> "list[nn.Module]": + """Return modules that directly own trainable LoRA factors; excludes fused-kernel containers (their hooks already cover the same subtree). 
Deterministic ``model.modules()`` order.""" + fused = set(id(m) for m in _find_fused_kernel_containers(model)) + out: list["nn.Module"] = [] + for sub in model.modules(): + if id(sub) in fused: + continue + if not _has_peft_lora_factor(sub, recurse_children=True): + continue + out.append(sub) + return out + + @dataclass class _ParamSpill: """Bookkeeping for one parameter that's been spilled to CPU. @@ -149,6 +243,11 @@ def __init__( self._sthook_ctx: Any = None self._entered = False self._n_pin_failures = 0 + # Populated by ``__enter__`` after fused-kernel detection. Tests + # may inspect this to verify per-container hook installation. + self._fused_containers: list["nn.Module"] = [] + # PEFT-LoRA containers needing subtree gather/release so param.data stays live across backward + self._peft_lora_containers: list["nn.Module"] = [] # ---- context-manager protocol -------------------------------------- @@ -268,6 +367,90 @@ def __enter__(self) -> "OnDemandTensorMgr": sub.register_full_backward_hook(self._post_release_bwd) ) + # container-level gather/release for fused-kernel modules whose patched forward bypasses the per-Linear hooks; prepend=True so the gather precedes the trace driver's snapshot pre-hook + self._fused_containers = _find_fused_kernel_containers(self.model) + if self._fused_containers: + LOG.debug( + "OnDemandTensorMgr: %d fused-kernel container(s) " + "detected; installing per-container gather hooks", + len(self._fused_containers), + ) + for container in self._fused_containers: + self._handles.append( + container.register_forward_pre_hook( + self._pre_gather_subtree, prepend=True + ) + ) + self._handles.append( + container.register_forward_hook(self._post_release_subtree) + ) + # Backward hooks: the fused autograd Function (LoRA_MLP / + # LoRA_QKV / LoRA_O) stores raw weight Tensor refs as a + # plain Python attribute on ``ctx`` (e.g. ``ctx.weights``, + # not ``ctx.save_for_backward``), so the saved-tensors + # pack/unpack path does NOT spill them. 
By backward time + # the forward post-release has reset every base + # ``param.data`` to a length-0 placeholder, and the + # autograd backward's matmul against ``ctx.weights[i]`` + # raises the same ``size mismatch ... vec (0)`` the M0 + # spike captured — but firing in ``LoRA_MLP.backward`` + # instead of forward (the fix's forward-only first cut + # got the trace forward past the failure but tripped on + # the backward equivalent during the trace's + # ``loss.backward()`` call). Re-gathering the container's + # subtree before its backward enters, then releasing + # after, makes the fused autograd Function's backward + # see real weights again. Symmetric with the forward pair. + self._handles.append( + container.register_full_backward_pre_hook( + self._pre_gather_subtree_bwd, prepend=True + ) + ) + self._handles.append( + container.register_full_backward_hook( + self._post_release_subtree_bwd + ) + ) + + # PEFT-LoRA containers: subtree gather keeps both LoRA factors and the wrapped base weight live across forward+backward so autograd shape-derivation sees real sizes + self._peft_lora_containers = _find_peft_lora_containers(self.model) + if self._peft_lora_containers: + LOG.debug( + "OnDemandTensorMgr: %d PEFT-LoRA container(s) " + "detected; installing per-container gather hooks", + len(self._peft_lora_containers), + ) + for container in self._peft_lora_containers: + self._handles.append( + container.register_forward_pre_hook( + self._pre_gather_subtree, prepend=True + ) + ) + self._handles.append( + container.register_forward_hook(self._post_release_subtree) + ) + # Symmetric backward hooks: the PEFT LoRA forward path's + # autograd graph is built against the gathered tensors; + # at backward time the same shape-derivation step that + # bites at forward (``ToCopyBackward0`` reading + # ``param.size()``) bites again. 
Without this pair, the + # per-Linear post-release would clear ``base_layer.weight`` + # before the LoRA backward runs and grad accumulation + # against the saved-shape activation would see a length-0 + # placeholder weight. Mirror the fused-kernel container's + # backward hooks so the LoRA backward window sees real + # weights too. + self._handles.append( + container.register_full_backward_pre_hook( + self._pre_gather_subtree_bwd, prepend=True + ) + ) + self._handles.append( + container.register_full_backward_hook( + self._post_release_subtree_bwd + ) + ) + # Saved-for-backward tensors spill to CPU. Without this, autograd # would keep the gathered GPU param alive via the saved-for- # backward slot of the linear's grad_fn, defeating post_release. @@ -392,6 +575,8 @@ def _restore_after_partial_setup(self) -> None: ) self._spills.clear() self._active_param_users.clear() + self._fused_containers = [] + self._peft_lora_containers = [] def __exit__(self, exc_type, exc, tb) -> None: """Remove hooks and restore parameters from their pinned-CPU spill copies.""" @@ -504,6 +689,8 @@ def __exit__(self, exc_type, exc, tb) -> None: ) self._spills.clear() self._active_param_users.clear() + self._fused_containers = [] + self._peft_lora_containers = [] # ---- spill / restore helpers --------------------------------------- @@ -752,6 +939,30 @@ def _post_release(self, module: "nn.Module", inputs: Any, output: Any) -> None: except Exception as exc: # noqa: BLE001 - defensive LOG.debug("OnDemandTensorMgr post-release no-op (%s)", exc) + def _pre_gather_subtree(self, module: "nn.Module", inputs: Any) -> None: + """Run ``_pre_gather`` over every submodule so the fused/PEFT container's whole subtree is GPU-resident before the patched forward reads weights by direct attribute access.""" + for sub in module.modules(): + self._pre_gather(sub, inputs) + + def _post_release_subtree( + self, module: "nn.Module", inputs: Any, output: Any + ) -> None: + """Mirror of ``_pre_gather_subtree`` but 
walks submodules in reverse so the active-user refcounts unwind LIFO (matches the tied-param ownership pattern).""" + for sub in reversed(list(module.modules())): + self._post_release(sub, inputs, output) + + def _pre_gather_subtree_bwd(self, module: "nn.Module", grad_output: Any) -> None: + """Backward-pre subtree gather; needed because fused autograd Functions stash raw weight refs on ``ctx`` (bypassing ``save_for_backward``), so the forward post-release left them as empty placeholders.""" + for sub in module.modules(): + self._pre_gather(sub, grad_output) + + def _post_release_subtree_bwd( + self, module: "nn.Module", grad_input: Any, grad_output: Any + ) -> None: + """Backward-post subtree release; defers to ``_post_release_bwd`` per submodule so the ``inputs_have_grad`` premature-fire guard still applies (otherwise embeddings would clear their weight mid-AccumulateGrad).""" + for sub in reversed(list(module.modules())): + self._post_release_bwd(sub, grad_input, grad_output) + def _pre_gather_bwd(self, module: "nn.Module", grad_output: Any) -> None: """Backward-pre hook: gather direct params before this module's bwd. @@ -916,4 +1127,10 @@ def live_tensor_ids(self) -> Iterable[int]: return tuple(self._spills.keys()) -__all__ = ["OnDemandTensorMgr"] +__all__ = [ + "OnDemandTensorMgr", + "_find_fused_kernel_containers", + "_find_peft_lora_containers", + "_has_peft_lora_factor", + "_is_fused_method", +] diff --git a/src/axolotl/integrations/protrain/profiler/trace.py b/src/axolotl/integrations/protrain/profiler/trace.py index 790b0fae6f..1a7d851d41 100644 --- a/src/axolotl/integrations/protrain/profiler/trace.py +++ b/src/axolotl/integrations/protrain/profiler/trace.py @@ -607,7 +607,14 @@ def _output_bytes(output: Any) -> int: # exactly what doesn't fit. The cost model falls back to defaults # (identity scale, default bwd_fwd ratio) for traces marked on-demand. 
engage_on_demand = False - if cfg.on_demand and cuda_available: + if cfg.force_all_persistent: + # force_all_persistent overrides the on-demand auto-engagement gate so the trace honors Mode A even on borderline configs + LOG.info( + "Profiler force_all_persistent=True; skipping on-demand " + "engagement gate. Trace pass will run the trainable " + "forward+backward fully on GPU." + ) + elif cfg.on_demand and cuda_available: try: gpu_total = int(torch.cuda.get_device_properties(device).total_memory) # State-aware footprint: params (all of them) + grads + fp32 @@ -1306,4 +1313,136 @@ def _extract_loss(output: Any) -> "torch.Tensor": ) -__all__ = ["run_trace"] +def _infer_hidden_size(model: "nn.Module") -> int: + """Best-effort hidden-size inference; falls back to 2048 so synthetic SWAP slot sizing stays finite.""" + cfg = getattr(model, "config", None) + if cfg is not None: + for attr in ("hidden_size", "d_model", "n_embd"): + v = getattr(cfg, attr, None) + if isinstance(v, int) and v > 0: + return v + return 2048 + + +def _infer_intermediate_size(model: "nn.Module", hidden_size: int) -> int: + """Best-effort FFN intermediate size; sized larger than hidden so synthetic SWAP slot sizing doesn't under-shoot the largest saved activation.""" + cfg = getattr(model, "config", None) + if cfg is not None: + for attr in ("intermediate_size", "ffn_hidden_size", "d_ff", "n_inner"): + v = getattr(cfg, attr, None) + if isinstance(v, int) and v > 0: + return v + return 4 * int(hidden_size) + + +def synth_trace_from_overrides( + model: "nn.Module", + *, + batch_size: int, + seq_len: int, + device: "torch.device | str", + world_size: int, + measure_pcie_bps: bool = True, + param_grad_bytes_per_param: int = DEFAULT_PARAM_GRAD_BYTES_PER_PARAM, + optim_state_bytes_per_param: int = DEFAULT_OPTIM_STATE_BYTES_PER_PARAM, +) -> ProfilerTrace: + """Synthesize a minimally-populated ProfilerTrace so the explicit-override skip path can bypass the OOM-prone real trace pass.""" + import torch + + 
# Lazy import to avoid pulling block layout deps at module import. + from axolotl.integrations.protrain.block.layout_rules import ( + block_id_path_map, + discover_blocks, + flatten_block_trees, + ) + + dev = torch.device(device) if not isinstance(device, torch.device) else device + + # Discover blocks so ``activation_sizes`` keys span the actual block + # ids the runtime will use. Falls back to a single synthetic block + # entry if discovery fails (degenerate / non-transformer models). + try: + trees = discover_blocks(model) + blocks = flatten_block_trees(trees) + block_count = max(1, len(blocks)) + path_map = block_id_path_map(model, trees) + block_tree_index: dict[BlockId, int] = {} + flat_idx = 0 + for tree in sorted(trees, key=lambda t: t.forward_order): + for _ in tree.blocks: + block_tree_index[BlockId(flat_idx)] = int(tree.forward_order) + flat_idx += 1 + # path_map currently unused beyond confirming discovery worked; + # keep around as a sanity check. + del path_map + except Exception as exc: # pragma: no cover - defensive + LOG.debug( + "synth_trace_from_overrides: discover_blocks failed (%s); " + "falling back to single-block placeholder", + exc, + ) + block_count = 1 + block_tree_index = {BlockId(0): 0} + + hidden_size = _infer_hidden_size(model) + intermediate_size = _infer_intermediate_size(model, hidden_size) + # Per-block activation upper bound. We size off the FFN intermediate + # (``bs * seq * intermediate * 2 B``) because that's typically the + # largest single saved tensor PyTorch's autograd retains for backward + # — block-output residual (``bs * seq * hidden * 2 B``) under-shoots + # by the FFN expansion factor (~3.5x on Llama). Sizing too small + # here triggers the SWAP runtime's "exceeds pool slot" warning path + # which silently degrades to "keep on GPU"; the analytical value is + # still consulted ONLY by sizing-path code, never by the cost + # model (which is bypassed entirely on the override path). 
+ per_block_act_bytes = int(batch_size) * int(seq_len) * int(intermediate_size) * 2 + activation_sizes: dict[BlockId, int] = { + BlockId(i): per_block_act_bytes for i in range(block_count) + } + + model_state_bytes = _count_model_state_bytes( + model, + param_grad_bytes_per_param=param_grad_bytes_per_param, + optim_state_bytes_per_param=optim_state_bytes_per_param, + ) + + # Conservative Gen3 fallback (matches the model_wrapper's + # default-prior threshold at line ~2078). + pcie_h2d_bps = 13e9 + pcie_d2h_bps = 13e9 + if measure_pcie_bps and dev.type == "cuda" and torch.cuda.is_available(): + try: + dev_idx = ( + dev.index if dev.index is not None else torch.cuda.current_device() + ) + pcie_h2d_bps, pcie_d2h_bps = measure_pcie(int(dev_idx)) + except Exception as exc: # pragma: no cover - defensive + LOG.warning( + "synth_trace_from_overrides: measure_pcie failed (%s); " + "falling back to 13 GB/s Gen3 prior", + exc, + ) + + return ProfilerTrace( + op_order=(), + intra_op_delta={}, + inter_op_delta={}, + activation_sizes=activation_sizes, + model_state_bytes=int(model_state_bytes), + pcie_h2d_bps=float(pcie_h2d_bps), + pcie_d2h_bps=float(pcie_d2h_bps), + nccl_gather_s={}, + nccl_reduce_s={}, + arch_hash=_arch_hash(model), + bs=int(batch_size), + seq=int(seq_len), + sku=_sku(dev), + world=int(world_size), + op_latencies={}, + cpu_adam_bytes_per_sec=0.0, + gpu_adam_bytes_per_sec=0.0, + block_tree_index=block_tree_index, + ) + + +__all__ = ["run_trace", "synth_trace_from_overrides"] diff --git a/src/axolotl/integrations/protrain/runtime/hooks.py b/src/axolotl/integrations/protrain/runtime/hooks.py index 3661bcb2f3..fb7ee6055d 100644 --- a/src/axolotl/integrations/protrain/runtime/hooks.py +++ b/src/axolotl/integrations/protrain/runtime/hooks.py @@ -1,23 +1,4 @@ -"""Block-granularity forward/backward hooks for the ProTrain runtime. 
- -``install_hooks`` attaches four hooks per transformer block: - -* forward-pre hook -> :meth:`Scheduler.pre_block_forward` -* forward-post hook -> :meth:`Scheduler.post_block_forward` -* backward-pre hook -> :meth:`Scheduler.pre_block_backward` -* backward-post hook -> :meth:`Scheduler.post_block_backward` - -The hooks operate at **block** granularity only — op-level hooks are -the profiler's job (M1). This module's contract is to wire the already- -wrapped blocks (see :mod:`axolotl.integrations.protrain.block.dispatcher`) -into the scheduler's prefetch / release / reduce-offload machine. - -Ordering note: ``protrain_model_wrapper`` wraps every block *before* -installing these hooks, so the hooks attach to the post-wrap modules -(``CheckpointedBlock`` / ``SwappedBlock`` / identity). The wrapper -idempotency guarantee means a re-search at epoch boundaries can -uninstall + re-wrap + re-install without any hook-level bookkeeping. -""" +"""Block-granularity forward/backward hooks plus per-PEFT-LoRA-container quartet hooks that re-bind chunk data across every autograd window where ``param.data`` could otherwise be observed as the empty placeholder.""" from __future__ import annotations @@ -30,9 +11,13 @@ flatten_block_trees, ) from axolotl.integrations.protrain.block.offload import OffloadedBlock +from axolotl.integrations.protrain.profiler.on_demand import ( + _find_peft_lora_containers, +) from axolotl.integrations.protrain.types import ( BlockId, BlockStrategyMap, + ChunkId, ) from axolotl.utils.logging import get_logger @@ -98,6 +83,79 @@ def _hook(module: nn.Module, grad_input, grad_output): # noqa: ARG001 return _hook +def _container_chunk_ids( + container: nn.Module, + chunk_manager: "ChunkManager", +) -> tuple[ChunkId, ...]: + """Return the sorted+deduped chunk-id set covering ``container``'s subtree; lookups go via ``id(param)`` because post-wrap names differ from chunk-manager construction-time names.""" + # Reverse index: id(Parameter) -> ParamId (dotted 
name string). + cm_id_to_name = {id(p): name for name, p in chunk_manager._params_by_id.items()} # noqa: SLF001 + chunk_ids: set[ChunkId] = set() + for param in container.parameters(recurse=True): + cm_name = cm_id_to_name.get(id(param)) + if cm_name is None: + # Param post-dates chunk-manager construction (e.g. an + # adapter PEFT installed AFTER protrain_model_wrapper — + # not the supported flow but cheap to skip defensively). + continue + cid = chunk_manager.layout.param_to_chunk.get(cm_name) + if cid is None: + continue + chunk_ids.add(cid) + # Sort for determinism — gather order doesn't matter (the chunk + # manager's gather is per-chunk independent), but a stable order + # keeps test-time enumeration reproducible. + return tuple(sorted(chunk_ids)) + + +def _make_lora_container_pre_forward_hook( + scheduler: "Scheduler", chunk_ids: tuple[ChunkId, ...] +): + """Build a forward-pre hook that gathers ``chunk_ids`` via idempotent ``ensure_chunks_resident``; chunk_ids is precomputed once per container to avoid walking parameters every forward.""" + + def _hook(module: nn.Module, inputs): # noqa: ARG001 + scheduler.ensure_chunks_resident(chunk_ids) + return None + + return _hook + + +def _make_lora_container_pre_backward_hook( + scheduler: "Scheduler", chunk_ids: tuple[ChunkId, ...] +): + """Backward-pre mirror of the forward variant; the cold-path re-gather prevents the autograd ``shape compatible with [0]`` error when a chunk was evicted before the LoRA backward kernel runs.""" + + def _hook(module: nn.Module, grad_output): # noqa: ARG001 + scheduler.ensure_chunks_resident(chunk_ids) + return None + + return _hook + + +def _make_lora_container_post_forward_hook( + scheduler: "Scheduler", chunk_ids: tuple[ChunkId, ...] 
+): + """Forward-post defensive re-bind; guarantees ``param.data`` is gathered before the block-level post-forward fires its release, even if an intermediate scheduler reentrancy nulled it mid-forward.""" + + def _hook(module: nn.Module, inputs, output): # noqa: ARG001 + scheduler.ensure_chunks_resident(chunk_ids) + return None + + return _hook + + +def _make_lora_container_post_backward_hook( + scheduler: "Scheduler", chunk_ids: tuple[ChunkId, ...] +): + """Backward-post defensive re-bind; covers the gap between the outer container's pre-backward and the inner Linear's ``TBackward0`` apply where the block-level scheduler may have released the chunk.""" + + def _hook(module: nn.Module, grad_input, grad_output): # noqa: ARG001 + scheduler.ensure_chunks_resident(chunk_ids) + return None + + return _hook + + def install_hooks( model: nn.Module, chunk_manager: "ChunkManager", @@ -195,10 +253,51 @@ def install_hooks( if isinstance(block, OffloadedBlock): block.attach_runtime(chunk_manager, scheduler) + # per-PEFT-LoRA-container hooks gather LoRA-factor chunks before autograd shape-derivation runs, closing the cold-path ``shape compatible with [0]`` failure that block-level hooks miss + peft_lora_containers = _find_peft_lora_containers(model) + if peft_lora_containers: + # INFO so the load-bearing per-container hook install surfaces in production logs + LOG.info( + "install_hooks: %d PEFT-LoRA container(s) detected; " + "installing per-container fwd/bwd pre+post-gather hook quartet", + len(peft_lora_containers), + ) + for container in peft_lora_containers: + cids = _container_chunk_ids(container, chunk_manager) + if not cids: + # container's params post-date chunk-manager construction; nothing to gather + continue + # prepend=True so the gather precedes any trace-driver snapshot pre-hook that would otherwise read pre-gather state + handles.append( + container.register_forward_pre_hook( + _make_lora_container_pre_forward_hook(scheduler, cids), + prepend=True, + ) + ) + # 
post-forward re-assert: closes the mid-forward param.data null window before block-level offload(cid) release + handles.append( + container.register_forward_hook( + _make_lora_container_post_forward_hook(scheduler, cids) + ) + ) + handles.append( + container.register_full_backward_pre_hook( + _make_lora_container_pre_backward_hook(scheduler, cids) + ) + ) + # post-backward re-assert: pins the chunk across the gap between outer container's post-forward and inner Linear's TBackward0 apply + handles.append( + container.register_full_backward_hook( + _make_lora_container_post_backward_hook(scheduler, cids) + ) + ) + LOG.debug( - "install_hooks: attached %d handles across %d transformer blocks", + "install_hooks: attached %d handles across %d transformer blocks " + "(plus %d PEFT-LoRA container pre+post fwd/bwd hook quartet(s))", len(handles), len(blocks), + len(peft_lora_containers), ) return handles diff --git a/src/axolotl/integrations/protrain/runtime/scheduler.py b/src/axolotl/integrations/protrain/runtime/scheduler.py index b811c15e78..c94dfffc4b 100644 --- a/src/axolotl/integrations/protrain/runtime/scheduler.py +++ b/src/axolotl/integrations/protrain/runtime/scheduler.py @@ -301,6 +301,28 @@ def ensure_block_resident(self, block_id: BlockId) -> None: self._gather_on_prefetch_stream(chunk_ids) self._sync_prefetch_with_compute() + def ensure_chunks_resident(self, chunk_ids: Iterable[ChunkId]) -> None: + """Synchronously gather an arbitrary chunk set on the compute stream so autograd shape-derivation sees real ``param.size()`` even on cold paths.""" + # Materialize once so we can both check emptiness and iterate + # twice (gather + the fast-path persistent-skip in the manager). + cids = tuple(chunk_ids) + if not cids: + return + # Wait on swap + prefetch streams so pool buffers and in-flight gathers complete before the compute-stream rebind. 
+ try: + import torch as _torch + except ImportError: # pragma: no cover — defensive, CPU-only lanes + _torch = None # type: ignore[assignment] + if _torch is not None and _torch.cuda.is_available(): + compute = _torch.cuda.current_stream() + if self._swap_stream is not None: + compute.wait_stream(self._swap_stream) + if self._prefetch_stream is not None: + compute.wait_stream(self._prefetch_stream) + # gather on the compute stream so the sharded all_gather completes before autograd records source-shape against the rebound param.data + for cid in cids: + self.chunk_manager.gather(cid) + # ---- forward ------------------------------------------------------- def pre_block_forward(self, block_id: BlockId) -> None: diff --git a/src/axolotl/integrations/protrain/search/exhaustive.py b/src/axolotl/integrations/protrain/search/exhaustive.py index d55ababe94..649a95b8dd 100644 --- a/src/axolotl/integrations/protrain/search/exhaustive.py +++ b/src/axolotl/integrations/protrain/search/exhaustive.py @@ -123,7 +123,7 @@ def block_map_runtime_admissible( ) -> bool: """Return True iff the block strategy is safe for current chunk offload. - Four-mode admissibility (post-Option B with the SWAP × non-persistent + Four-mode admissibility (post-Option B with the SWAP x non-persistent lift; see ``BLOCK_MODE_OFFLOAD_DESIGN.md`` §3.5 and §6.6): * ``CKPT`` — always admissible. The recompute path re-binds storage by @@ -148,16 +148,16 @@ def block_map_runtime_admissible( its bytes). Backward grad-accumulation reads ``param.data``, which ``Scheduler.pre_block_backward`` already re-gathers symmetrically with the CKPT/OFFLOAD paths, so no additional plumbing is needed - to make SWAP × non-persistent byte-exact. + to make SWAP x non-persistent byte-exact. * ``NONE`` — admissible iff every chunk owned by the block is in the persistent set. 
NONE installs no hooks, so PyTorch's autograd saved-tensors reference the original GPU storage directly; once that storage is reused by another chunk's gather H2D, the saved tensor's bytes are corrupt and backward produces silently wrong - gradients. There is no in-tree fix for NONE × non-persistent — + gradients. There is no in-tree fix for NONE x non-persistent — use CKPT, OFFLOAD, or SWAP for blocks with non-persistent chunks. - Pre-2026-05 history: SWAP × non-persistent was conservatively + Pre-2026-05 history: SWAP x non-persistent was conservatively rejected on the assumption that "saved tensors are not a safe persistence mechanism once ``param.data`` is rebound to the empty sentinel". The conjecture conflated NONE (which IS unsafe) with @@ -489,14 +489,16 @@ def search( # ``F(block_map)`` is the raw-peak contribution excluding the # ``(n_persist + n_buffer) * S_chunk`` term, pre-alpha. from axolotl.integrations.protrain.cost.memory import ( - ALPHA_FRAGMENTATION, + ALPHA_FRAGMENTATION, # noqa: F401 — re-exported for downstream consumers + alpha_fragmentation_for_dtype, apply_hot_iter_cap, block_tree_index_map, hot_iter_peak_cap, model_state_present_bytes, ) - alpha = ALPHA_FRAGMENTATION + # Must mirror estimate_peak's per-dtype alpha so the search's GPU-gate and the wrapper's post-search calibration agree. + alpha = alpha_fragmentation_for_dtype(hw.dominant_param_bytes_per_element) s_chunk = layout.S_chunk # Hoist trace-only maps out of the (n_swap, n_ckpt) hot loop — diff --git a/src/axolotl/integrations/protrain/types.py b/src/axolotl/integrations/protrain/types.py index e0571afd00..653f36caac 100644 --- a/src/axolotl/integrations/protrain/types.py +++ b/src/axolotl/integrations/protrain/types.py @@ -80,6 +80,15 @@ class ProfilerConfig: device: str # e.g. 
"cuda:2" include_backward: bool = True on_demand: bool = True # OnDemandTensorMgr for models > single-GPU + # When True, suppress the trace-pass on-demand engagement gate even if + # model_state exceeds the device-memory threshold. Plumbed from the + # caller's ``force_all_persistent`` flag so a user who has explicitly + # opted into Mode A doesn't get on-demand offloading silently re- + # engaged during the trace pass (which can hang or destabilize the + # host on borderline configurations). + # The trace pass still runs the trainable forward+backward; the + # caller is responsible for ensuring the model fits. + force_all_persistent: bool = False # Distributed world size. ``None`` (default) means "auto-detect" — the # tracer probes ``torch.distributed.get_world_size()`` if a process # group is initialized and falls back to 1 otherwise. Pass an explicit @@ -256,13 +265,13 @@ class ProfilerTrace: # Fraction of model parameters with ``requires_grad=True`` at trace time # (range [0.0, 1.0]). LoRA / adapter training has very low trainable - # fractions (~0.1% on 7B-LoRA-r8) — backward compute is then ~1× forward - # rather than the canonical 2× full-finetune ratio, because autograd + # fractions (~0.1% on 7B-LoRA-r8) — backward compute is then ~1x forward + # rather than the canonical 2x full-finetune ratio, because autograd # skips frozen subgraphs. The cost model's ``_bwd_compute_time_from_trace`` # consults this fraction to pick a tighter fallback ratio when the # measured ``steady_bwd_wall_s`` is unavailable (7B-class profiler runs # OOM the backward without chunk offload engaged). 0.0 means unmeasured - # (pre-v8) — falls back to the canonical 2× ratio. New in TRACE_VERSION=8. + # (pre-v8) — falls back to the canonical 2x ratio. New in TRACE_VERSION=8. 
trainable_param_fraction: float = 0.0 # ----- Phase-2 chunked-runtime measurements (TRACE_VERSION 10) ----- @@ -308,7 +317,7 @@ class ProfilerTrace: # These fields default to 0.0 / 0; the cost model treats 0.0 in # ``steady_bwd_chunked_wall_s`` as "no phase-2 measurement available" # and falls back to the v8 path (``steady_bwd_wall_s`` ratio → - # trainable-fraction heuristic → 2× canonical). + # trainable-fraction heuristic → 2x canonical). steady_bwd_chunked_wall_s: float = 0.0 steady_step_overlap_s: float = 0.0 steady_phase2_peak_bytes: int = 0 @@ -366,9 +375,9 @@ class ProfilerTrace: # captured pre-splice so the chunked-wall override does not # short-circuit the analytical path. # The cost model derives a multiplicative scale - # ``α = phase2_iter_s / phase2_analytical_iter_s`` and applies it to + # ``alpha = phase2_iter_s / phase2_analytical_iter_s`` and applies it to # any analytical-path prediction. When the analytical path is not - # taken (e.g. ``cfg.n_swap == 0`` and chunked walls populated) α is + # taken (e.g. ``cfg.n_swap == 0`` and chunked walls populated) alpha is # not consulted — the chunked-wall override is already absolute. # # ``phase2_analytical_peak_bytes`` plays the analogous role for peak @@ -382,7 +391,7 @@ class ProfilerTrace: # # All three fields default to 0 / 0.0 — that is the "no phase-2 # baseline available" sentinel that collapses both calibrations to - # their pre-refactor behaviour (no α scaling on the runtime side; + # their pre-refactor behaviour (no alpha scaling on the runtime side; # only the same-cfg measurement window on the peak side). 
phase2_iter_s: float = 0.0 phase2_analytical_iter_s: float = 0.0 @@ -390,7 +399,7 @@ class ProfilerTrace: # ----- Phase-2 PER-COMPONENT analytical-baseline calibration (TRACE_VERSION 21) ----- # - # The single-scalar α (``phase2_iter_s / phase2_analytical_iter_s``) + # The single-scalar alpha (``phase2_iter_s / phase2_analytical_iter_s``) # collapses three independent calibration scales — fwd, bwd, optim — # into one ratio anchored at the bootstrap cfg. That works only when # the production cfg has the same fwd/bwd/optim bias profile as boot; @@ -402,11 +411,11 @@ class ProfilerTrace: # forced an asymmetric structure-match gate that suppressed any # deflation outside boot's exact shape. # - # The per-component fix decomposes α into three independent scales: + # The per-component fix decomposes alpha into three independent scales: # - # αfwd = phase2_fwd_s / phase2_analytical_fwd_s - # αbwd = phase2_bwd_s / phase2_analytical_bwd_s - # αopt = phase2_step_s / phase2_analytical_step_s (= analytical + # alphafwd = phase2_fwd_s / phase2_analytical_fwd_s + # alphabwd = phase2_bwd_s / phase2_analytical_bwd_s + # alphaopt = phase2_step_s / phase2_analytical_step_s (= analytical # t_gpu_optim # + t_cpu_optim # at boot) @@ -414,9 +423,9 @@ class ProfilerTrace: # Each scale calibrates against the matching analytical component, so # cfg-shape changes that move the fwd/bwd/optim balance no longer # destabilise the prediction — the scales carry component-by-component - # rather than as a lumped ratio. This makes α<1 deflation safe (each + # rather than as a lumped ratio. This makes alpha<1 deflation safe (each # scale corrects only the component it was measured against), so the - # structure-match gate from the single-α era is dropped. + # structure-match gate from the single-alpha era is dropped. 
# # ``phase2_fwd_s`` / ``phase2_bwd_s`` / ``phase2_step_s`` are the # measured medians from ``measure_chunked_steady`` at the bootstrap @@ -427,7 +436,7 @@ class ProfilerTrace: # # All six default to 0.0 — the "no per-component baseline available" # sentinel. When any component baseline is zero, the cost model falls - # back to the single-α path (``phase2_iter_s / phase2_analytical_iter_s``) + # back to the single-alpha path (``phase2_iter_s / phase2_analytical_iter_s``) # if those legacy fields are populated, or to no calibration otherwise. # Cached traces from TRACE_VERSION <= 20 are invalidated by the # version bump on cache.py; in-memory traces constructed without these @@ -441,7 +450,7 @@ class ProfilerTrace: # ----- Phase-2 RESIDUAL whole-iter overhead anchor (TRACE_VERSION 22) ----- # - # Per-component α (TRACE_VERSION 21) corrects fwd/bwd/optim bias + # Per-component alpha (TRACE_VERSION 21) corrects fwd/bwd/optim bias # *within each component* — its strength is generalising the # measurement to a production cfg with a different fwd/bwd/optim # balance (different ``n_persist`` / ``n_swap`` / ``n_checkpoint``). @@ -449,28 +458,28 @@ class ProfilerTrace: # whole-iter overheads (Python hook dispatch, kernel launch latency, # NCCL handshake, allocator churn between fwd and bwd, etc.) that # scale roughly linearly with ``N_block`` rather than with any - # individual component. The previous single-α calibration absorbed + # individual component. The previous single-alpha calibration absorbed # those overheads accidentally because it scaled the whole iter; the # per-component decomposition by construction does not. # - # ``phase2_per_comp_pred_iter_s`` records what the per-component-α - # composition (using the SAME αfwd / αbwd / αopt values derived at + # ``phase2_per_comp_pred_iter_s`` records what the per-component-alpha + # composition (using the SAME alphafwd / alphabwd / alphaopt values derived at # boot) WOULD predict at the boot cfg. 
The cost model then derives # - # α_residual = phase2_iter_s / phase2_per_comp_pred_iter_s + # alpha_residual = phase2_iter_s / phase2_per_comp_pred_iter_s # # at boot and multiplies it onto every per-component prediction at - # production cfgs. By construction α_residual collapses to 1.0 when + # production cfgs. By construction alpha_residual collapses to 1.0 when # the per-component formula already explains the boot iter — i.e. # whole-iter overhead is fully captured by the components — so the # residual is a no-op on workloads where it should be. When the # analytical model systematically under-counts whole-iter overhead - # (the 7B-LoRA regression: ~50% bias on 32-block PEFT), α_residual + # (the 7B-LoRA regression: ~50% bias on 32-block PEFT), alpha_residual # > 1.0 inflates the prediction back toward the measurement. # # Bounds [0.8, 2.0] (wider on the inflate side than per-component's - # [0.5, 2.0]) reflect that residual α captures genuine missing - # overhead, not measurement noise — the natural regime is α ≥ 1. + # [0.5, 2.0]) reflect that residual alpha captures genuine missing + # overhead, not measurement noise — the natural regime is alpha ≥ 1. # # Default 0.0 means "no residual baseline available"; the cost # model collapses to per-component-only behaviour (the post- @@ -525,7 +534,7 @@ class ChunkLayout: of the source paper); ``mandatory_persistent`` is the local integration's correctness extension. Cost model + search keep ``cfg.n_persist`` strictly meaning "prefix length the search chose"; - the runtime resident set is ``{0..n_persist-1} ∪ mandatory_persistent``. + the runtime resident set is ``{0..n_persist-1} | mandatory_persistent``. The default is an empty frozenset so legacy ``ChunkLayout(...)`` constructions stay drop-in compatible. 
@@ -542,13 +551,7 @@ class ChunkLayout: mandatory_persistent: frozenset[ChunkId] = field(default_factory=frozenset) def effective_persistent_ids(self, n_persist: int) -> frozenset[ChunkId]: - """Return ``{0..n_persist-1} ∪ mandatory_persistent`` as a frozenset. - - Single source of truth for "which chunks are GPU-resident under - ``n_persist``" so the searcher, cost model, and runtime construction - cannot disagree. Clamps ``n_persist`` defensively into - ``[0, N_chunk]``. - """ + """Return ``{0..n_persist-1} | mandatory_persistent`` as a frozenset.""" n = max(0, min(int(n_persist), int(self.N_chunk))) prefix = {ChunkId(i) for i in range(n)} return frozenset(prefix | set(self.mandatory_persistent)) @@ -595,6 +598,7 @@ class SearchResult: block_map: BlockStrategyMap predicted_peak_bytes: int predicted_iter_s: float + predicted_init_transient_peak_bytes: int = 0 # --------------------------------------------------------------------------- @@ -639,6 +643,8 @@ class HardwareProfile: # scale. Populated by ``profiler.hw_bench.measure_compute_rate`` from # the model_wrapper just before the searcher runs. gpu_compute_tflops: float = 0.0 + # Drives per-dtype alpha lookup; bnb-4-bit ``Params4bit`` is mapped to 0.5 (packed) not the uint8 storage size. 
+ dominant_param_bytes_per_element: float = 2.0 # --------------------------------------------------------------------------- diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index d5f2d9f780..5bcb6d5dcf 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -25,24 +25,48 @@ def check_cuda_p2p_ib_support(): def check_cuda_p2p_support() -> bool: + """Return True iff every local-GPU pair supports P2P; rank-symmetric and fail-closed on introspection failure.""" + # fail-closed: unintrospectable pairs must be treated as unsafe so all ranks agree on NCCL_P2P_DISABLE try: world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) except ValueError: + LOG.warning( + "check_cuda_p2p_support: invalid WORLD_SIZE=%r; disabling P2P " + "(fail-closed posture).", + os.environ.get("WORLD_SIZE"), + ) + return False + + if world_size <= 1: return True - if world_size > 1: - node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8")) - local_other_rank = (local_rank // node_world_size) * node_world_size - local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0 - try: - can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank) - except AssertionError as exc: - # some sort of logic error in indexing processes, assume p2p is fine for now - LOG.warning(exc) - return True - return can_p2p + try: + n = torch.cuda.device_count() + except Exception as exc: # pragma: no cover - defensive # noqa: BLE001 + LOG.warning( + "check_cuda_p2p_support: device_count failed (%s); disabling P2P " + "(fail-closed posture).", + exc, + ) + return False + if n <= 1: + return True + for i in range(n): + for j in range(i + 1, n): + try: + if not torch.cuda.can_device_access_peer(i, j): + return False + except Exception as exc: # noqa: BLE001 — broad catch keeps fail-closed even if C++ binding raises a non-AssertionError + LOG.warning( + "check_cuda_p2p_support: 
can_device_access_peer(%s, %s) " + "raised %s (%s); disabling P2P (fail-closed posture).", + i, + j, + type(exc).__name__, + exc, + ) + return False return True diff --git a/tests/protrain/peft_edge_cases/__init__.py b/tests/protrain/peft_edge_cases/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/protrain/peft_edge_cases/test_dora.py b/tests/protrain/peft_edge_cases/test_dora.py new file mode 100644 index 0000000000..0155efe8df --- /dev/null +++ b/tests/protrain/peft_edge_cases/test_dora.py @@ -0,0 +1,157 @@ +"""DoRA + ProTrain smoke: magnitude vectors must traverse the per-region split alongside LoRA factors.""" + +from __future__ import annotations + +import math + +import pytest + +pytestmark = pytest.mark.gpu + + +def _build_tiny_llama_with_dora(): + """Tiny Llama-arch LM with DoRA LoRA; prefers cached SmolLM2-135M, falls back to fresh-init.""" + pytest.importorskip("torch") + pytest.importorskip("transformers") + pytest.importorskip("peft") + + import torch + from peft import LoraConfig, get_peft_model + from transformers import ( + AutoConfig, + AutoModelForCausalLM, + LlamaConfig, + LlamaForCausalLM, + ) + + # Narrow to offline-load failure families so genuine API breakage still surfaces. + try: + cfg = AutoConfig.from_pretrained( + "HuggingFaceTB/SmolLM2-135M", local_files_only=True + ) + cfg.use_cache = False + model = AutoModelForCausalLM.from_pretrained( + "HuggingFaceTB/SmolLM2-135M", + local_files_only=True, + torch_dtype=torch.bfloat16, + ) + except (OSError, ValueError, EnvironmentError): + cfg = LlamaConfig( + hidden_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=512, + vocab_size=1024, + max_position_embeddings=128, + rms_norm_eps=1e-5, + use_cache=False, + ) + model = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16) + + # --- DoRA-enabled LoRA config ---------------------------------------- + # Target the standard Llama attention + MLP linears. 
Use small r/alpha + # to keep the smoke fast; DoRA's distinguishing feature is the + # magnitude vector, not its rank. + lora_cfg = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + use_dora=True, + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(model, lora_cfg) + return peft_model, cfg + + +def test_protrain_dora_smoke() -> None: + """ProTrain + DoRA: 5 iters, finite losses, strictly decreasing.""" + pytest.importorskip("torch") + + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain DoRA smoke requires CUDA.") + + peft_model, cfg = _build_tiny_llama_with_dora() + + device = torch.device("cuda:0") + peft_model = peft_model.to(device) + + # --- Sanity: DoRA magnitude vectors must exist and be trainable ------ + # If this assertion fails, ``use_dora=True`` silently degraded to + # plain LoRA and the test wouldn't actually stress the new tensors. + magnitude_params = [ + (n, p) for n, p in peft_model.named_parameters() if "lora_magnitude_vector" in n + ] + assert magnitude_params, ( + "DoRA magnitude vectors not found; LoraConfig(use_dora=True) may " + "have silently degraded — this test would be testing plain LoRA" + ) + for n, p in magnitude_params: + assert p.requires_grad, f"DoRA magnitude vector {n} not trainable" + + # ProTrain wrap: Mode-A (single GPU, all chunks GPU-resident). + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + hw = HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + + bs, seq = 1, 64 + # try/finally ensures hook handles, pinned-host borrows, and CPU adapter threads release on assertion failure. 
+ wrapped = protrain_model_wrapper( + peft_model, + model_config=cfg, + hardware_profile=hw, + batch_size=bs, + seq_len=seq, + capacity_bytes=20 * (1 << 30), + force_all_persistent=True, + ) + try: + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + + vocab = int(getattr(cfg, "vocab_size", 1024)) + torch.manual_seed(0) + input_ids = torch.randint(0, vocab, (bs, seq), device=device, dtype=torch.long) + labels = input_ids.clone() + + losses: list[float] = [] + n_iters = 5 + for i in range(n_iters): + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_value = float(loss.detach()) + assert math.isfinite(loss_value), ( + f"iter {i}: non-finite loss {loss_value}; losses so far={losses}" + ) + loss.backward() + optim.step() + optim.zero_grad() + losses.append(loss_value) + + print(f"\nProTrain + DoRA smoke (tiny Llama): losses={losses}") + + # final < first on a fixed batch confirms DoRA magnitude vectors and LoRA factors actually receive gradient updates. + assert all(math.isfinite(v) for v in losses), f"non-finite loss in {losses}" + assert losses[-1] < losses[0], ( + f"DoRA + ProTrain loss did not decrease over {n_iters} iters: " + f"{losses} — magnitude vectors or LoRA factors may not be " + f"receiving gradient updates through the chunk-region split" + ) + finally: + close = getattr(wrapped, "close", None) + if callable(close): + close() diff --git a/tests/protrain/peft_edge_cases/test_multi_adapter.py b/tests/protrain/peft_edge_cases/test_multi_adapter.py new file mode 100644 index 0000000000..5aaa8044b3 --- /dev/null +++ b/tests/protrain/peft_edge_cases/test_multi_adapter.py @@ -0,0 +1,167 @@ +"""Multi-LoRA + ProTrain smoke: set_adapter transitions must not corrupt the chunk-region split.""" + +from __future__ import annotations + +import math + +import pytest + +pytestmark = pytest.mark.gpu + + +def _build_tiny_llama_with_two_adapters(): + pytest.importorskip("torch") + pytest.importorskip("transformers") + 
pytest.importorskip("peft") + + import torch + from peft import LoraConfig, get_peft_model + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=256, + vocab_size=512, + max_position_embeddings=64, + rms_norm_eps=1e-5, + use_cache=False, + ) + model = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16) + + lora_alpha = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + lora_beta = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + peft_model = get_peft_model(model, lora_alpha, adapter_name="alpha") + peft_model.add_adapter("beta", lora_beta) + return peft_model, cfg + + +def _wrap_protrain(peft_model, cfg, *, bs: int, seq: int, capacity_bytes: int): + import torch + + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + hw = HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + wrapped = protrain_model_wrapper( + peft_model, + model_config=cfg, + hardware_profile=hw, + batch_size=bs, + seq_len=seq, + capacity_bytes=capacity_bytes, + force_all_persistent=True, + ) + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + return wrapped, optim + + +def _train_loop(wrapped, optim, *, n_iters, input_ids, labels) -> list[float]: + + losses: list[float] = [] + for i in range(n_iters): + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_value = float(loss.detach()) + assert math.isfinite(loss_value), f"iter {i}: non-finite loss {loss_value}" + 
loss.backward() + optim.step() + optim.zero_grad() + losses.append(loss_value) + return losses + + +def test_protrain_multi_lora_adapter_switch() -> None: + """ProTrain + multi-LoRA adapter switch: alpha 3 iters, beta 3 iters, no crash.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain multi-adapter smoke requires CUDA.") + + peft_model, cfg = _build_tiny_llama_with_two_adapters() + device = torch.device("cuda:0") + peft_model = peft_model.to(device) + + # Sanity: both adapters are present. + adapter_names = set(getattr(peft_model.peft_config, "keys", lambda: [])()) + assert {"alpha", "beta"}.issubset(adapter_names), ( + f"expected both adapters loaded, got {adapter_names}" + ) + + bs, seq = 1, 32 + vocab = int(cfg.vocab_size) + torch.manual_seed(0) + input_ids = torch.randint(0, vocab, (bs, seq), device=device, dtype=torch.long) + labels = input_ids.clone() + + # Explicit close before re-wrap so DDP-ignore restoration and CPU-adapter teardown are deterministic, not GC-timing dependent. + wrapped_b = None + try: + peft_model.set_adapter("alpha") + wrapped_a, optim_a = _wrap_protrain( + peft_model, cfg, bs=bs, seq=seq, capacity_bytes=4 * (1 << 30) + ) + try: + losses_alpha = _train_loop( + wrapped_a, optim_a, n_iters=3, input_ids=input_ids, labels=labels + ) + assert losses_alpha[-1] < losses_alpha[0], ( + f"alpha adapter did not train: {losses_alpha}" + ) + finally: + close_a = getattr(wrapped_a, "close", None) + if callable(close_a): + close_a() + + # Switch to beta. Re-wrap (chunk layout depends on requires_grad which + # changed) and train another 3 iters. The point of the test is that + # the set_adapter transition + re-wrap path doesn't crash and beta + # also makes progress. 
+ peft_model.set_adapter("beta") + wrapped_b, optim_b = _wrap_protrain( + peft_model, cfg, bs=bs, seq=seq, capacity_bytes=4 * (1 << 30) + ) + losses_beta = _train_loop( + wrapped_b, optim_b, n_iters=3, input_ids=input_ids, labels=labels + ) + assert losses_beta[-1] < losses_beta[0], ( + f"beta adapter did not train after switch: {losses_beta}" + ) + + print( + f"\nProTrain + multi-adapter: losses_alpha={losses_alpha} " + f"losses_beta={losses_beta}" + ) + finally: + if wrapped_b is not None: + close_b = getattr(wrapped_b, "close", None) + if callable(close_b): + close_b() diff --git a/tests/protrain/peft_edge_cases/test_vision_lm_hybrid.py b/tests/protrain/peft_edge_cases/test_vision_lm_hybrid.py new file mode 100644 index 0000000000..7e7abe515c --- /dev/null +++ b/tests/protrain/peft_edge_cases/test_vision_lm_hybrid.py @@ -0,0 +1,134 @@ +"""Mixed trainable/frozen + LoRA + ProTrain smoke: chunk-region split must absorb a non-uniform requires_grad map.""" + +from __future__ import annotations + +import math + +import pytest + +pytestmark = pytest.mark.gpu + + +def _build_tiny_llama_mixed_trainable(): + pytest.importorskip("torch") + pytest.importorskip("transformers") + pytest.importorskip("peft") + + import torch + from peft import LoraConfig, get_peft_model + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=256, + vocab_size=512, + max_position_embeddings=64, + rms_norm_eps=1e-5, + use_cache=False, + ) + base_lm = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16) + lora_cfg = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(base_lm, lora_cfg) + + # Trainable embedding alongside LoRA factors yields the 3-way frozen/LoRA/dense requires_grad split. 
+ embed = peft_model.get_input_embeddings() + for p in embed.parameters(): + p.requires_grad = True + + return peft_model, cfg + + +def test_protrain_mixed_trainable_frozen_smoke() -> None: + """ProTrain + LoRA + trainable embed_tokens (mixed-grad chunk regions): 5 iters.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain mixed trainable/frozen smoke requires CUDA.") + + # Seed before model build so LoRA init is reproducible; re-seed at randint to make the synthetic batch deterministic. + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + peft_model, cfg = _build_tiny_llama_mixed_trainable() + device = torch.device("cuda:0") + peft_model = peft_model.to(device) + + # Sanity: trainable surface is what we expect (LoRA + embedding). + trainable = {n for n, p in peft_model.named_parameters() if p.requires_grad} + has_lora = any("lora" in n.lower() for n in trainable) + has_embed = any("embed_tokens" in n for n in trainable) + assert has_lora, f"expected trainable LoRA params, got {sorted(trainable)[:5]}" + assert has_embed, ( + f"expected embed_tokens.weight to be trainable, got {sorted(trainable)[:5]}" + ) + # And we still have frozen base attention/MLP — otherwise the test + # degrades to "everything trainable" and the mixed-grad split isn't + # exercised. 
+ frozen = [n for n, p in peft_model.named_parameters() if not p.requires_grad] + assert any("self_attn" in n or "mlp" in n for n in frozen), ( + f"expected frozen base attn/mlp, got first 5 frozen={frozen[:5]}" + ) + + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + hw = HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + + bs, seq = 1, 32 + wrapped = protrain_model_wrapper( + peft_model, + model_config=cfg, + hardware_profile=hw, + batch_size=bs, + seq_len=seq, + capacity_bytes=4 * (1 << 30), + force_all_persistent=True, + ) + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + + torch.manual_seed(0) + input_ids = torch.randint( + 0, cfg.vocab_size, (bs, seq), device=device, dtype=torch.long + ) + labels = input_ids.clone() + + losses: list[float] = [] + for i in range(5): + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_value = float(loss.detach()) + assert math.isfinite(loss_value), f"iter {i}: non-finite loss {loss_value}" + loss.backward() + optim.step() + optim.zero_grad() + losses.append(loss_value) + + print(f"\nProTrain + mixed trainable/frozen: losses={losses}") + + assert all(math.isfinite(v) for v in losses), f"non-finite loss in {losses}" + assert losses[-1] < losses[0], ( + f"mixed trainable/frozen loss did not decrease: {losses} — chunk-" + f"region split for mixed-grad components may be silently dropping " + f"gradient updates" + ) diff --git a/tests/protrain/test_adamw8bit_adapter.py b/tests/protrain/test_adamw8bit_adapter.py new file mode 100644 index 0000000000..0b2474acb4 --- /dev/null +++ b/tests/protrain/test_adamw8bit_adapter.py @@ -0,0 +1,394 @@ +"""Unit tests for ``GpuAdamW8bitAdapter`` construction, state round-trip, 
and the wrapper dispatch path.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any +from unittest import mock + +import pytest + +from axolotl.integrations.protrain.chunk.optim import ( + GpuAdamW8bitAdapter, + GpuFusedAdamAdapter, +) + +if TYPE_CHECKING: + import torch +else: + torch = pytest.importorskip("torch") + + +pytestmark = pytest.mark.gpu + + +def _gpu_device() -> "torch.device": + """Pick a CUDA device that respects ``CUDA_VISIBLE_DEVICES`` and skip cleanly when CUDA is absent.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available; test_adamw8bit_adapter requires GPU.") + return torch.device("cuda:0") + + +# --------------------------------------------------------------------------- +# Adapter unit tests +# --------------------------------------------------------------------------- + + +def test_adapter_state_shapes_after_step() -> None: + """After one step, per-param state must carry the bnb 8-bit moments.""" + bnb = pytest.importorskip("bitsandbytes") + device = _gpu_device() + # min_8bit_size defaults to 4096 — we need enough elements per param + # for bnb to actually 8-bit-quantize the state (smaller params fall + # back to fp32 state internally and ``state1.dtype`` would be float). + p = torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device=device)) + adapter = GpuAdamW8bitAdapter( + params=[p], + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.01, + ) + p.grad = torch.randn_like(p) + adapter.step() + + state = adapter.underlying.state[p] + assert state["state1"].dtype == torch.uint8 + assert state["state2"].dtype == torch.uint8 + assert state["state1"].shape == p.shape + assert state["state2"].shape == p.shape + # Codebooks (256-entry quantization maps) and absmax block scales. 
+ assert state["qmap1"].shape == (256,) + assert state["qmap2"].shape == (256,) + assert state["absmax1"].numel() > 0 + assert state["absmax2"].numel() > 0 + # ``bnb`` is imported by the adapter; keep the reference alive for + # the assertions to be non-trivial under some lazy-import paths. + assert bnb is not None + + +def test_state_dict_round_trip_preserves_8bit_state() -> None: + """state_dict -> new adapter -> load_state_dict preserves uint8 moments.""" + pytest.importorskip("bitsandbytes") + device = _gpu_device() + torch.manual_seed(123) + p1 = torch.nn.Parameter(torch.randn(256, 256, dtype=torch.float32, device=device)) + adapter1 = GpuAdamW8bitAdapter(params=[p1], lr=1e-3) + p1.grad = torch.randn_like(p1) + adapter1.step() + + state1_before = adapter1.underlying.state[p1]["state1"].clone() + state2_before = adapter1.underlying.state[p1]["state2"].clone() + qmap1_before = adapter1.underlying.state[p1]["qmap1"].clone() + absmax1_before = adapter1.underlying.state[p1]["absmax1"].clone() + sd = adapter1.state_dict() + + # Fresh adapter, identical params, load the saved state. 
+ p2 = torch.nn.Parameter(p1.detach().clone()) + adapter2 = GpuAdamW8bitAdapter(params=[p2], lr=1e-3) + adapter2.load_state_dict(sd) + + state1_after = adapter2.underlying.state[p2]["state1"] + state2_after = adapter2.underlying.state[p2]["state2"] + qmap1_after = adapter2.underlying.state[p2]["qmap1"] + absmax1_after = adapter2.underlying.state[p2]["absmax1"] + assert torch.equal(state1_before, state1_after) + assert torch.equal(state2_before, state2_after) + assert torch.equal(qmap1_before, qmap1_after) + assert torch.equal(absmax1_before, absmax1_after) + + +def test_cpu_param_raises_clear_error() -> None: + """Constructing the adapter with CPU params must surface the bail condition.""" + pytest.importorskip("bitsandbytes") + p = torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device="cpu")) + with pytest.raises(RuntimeError) as exc_info: + GpuAdamW8bitAdapter(params=[p], lr=1e-3) + msg = str(exc_info.value) + assert "CUDA" in msg + assert "non-persistent" in msg + assert "M2.5" in msg or "CpuFusedAdamAdapter" in msg + + +def test_empty_param_set_is_no_op() -> None: + """Mode-C with no persistent chunks: empty adapter must short-circuit cleanly.""" + pytest.importorskip("bitsandbytes") + adapter = GpuAdamW8bitAdapter(params=[], lr=1e-3) + # No underlying optimizer. + assert adapter.underlying is None + # step / zero_grad are silent no-ops; state_dict returns the + # canonical empty shape. + adapter.step() + adapter.zero_grad() + sd = adapter.state_dict() + assert sd == {"state": {}, "param_groups": []} + # load_state_dict accepts the matching empty shell silently. + adapter.load_state_dict({"state": {}, "param_groups": []}) + # ...but rejects a non-empty payload (Mode-A/Mode-C config mismatch). 
+ with pytest.raises(ValueError): + adapter.load_state_dict({"state": {0: {"step": 1}}, "param_groups": []}) + + +def test_paged_variant_constructs_paged_class() -> None: + """``paged=True`` must instantiate ``bnb.optim.PagedAdamW8bit``.""" + bnb = pytest.importorskip("bitsandbytes") + device = _gpu_device() + p = torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device=device)) + adapter = GpuAdamW8bitAdapter(params=[p], lr=1e-3, paged=True) + assert isinstance(adapter.underlying, bnb.optim.PagedAdamW8bit) + + +def test_step_actually_updates_params() -> None: + """One step should mutate ``param.data`` (sanity-check that the kernel ran).""" + pytest.importorskip("bitsandbytes") + device = _gpu_device() + torch.manual_seed(7) + p = torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device=device)) + p_before = p.detach().clone() + adapter = GpuAdamW8bitAdapter(params=[p], lr=1e-2) + p.grad = torch.ones_like(p) + adapter.step() + # AdamW with positive grads + positive LR moves params toward zero on the + # first step; the deltas are non-zero everywhere. 
+ assert not torch.equal(p.detach(), p_before) + + +# --------------------------------------------------------------------------- +# Dispatch test — protrain_optimizer_wrapper routing +# --------------------------------------------------------------------------- + + +class _FakeChunkLayout: + """Minimal stand-in for ``ChunkLayout`` exposing only the ``chunks`` field the wrapper iterates.""" + + def __init__(self, chunks: list[list[int]]) -> None: + self.chunks = chunks + + +class _FakeChunkManager: + """Minimal stand-in for ``ChunkManager`` for the dispatch test.""" + + def __init__( + self, + params_by_id: dict[int, torch.nn.Parameter], + persistent_ids: set[int], + chunks: list[list[int]], + ) -> None: + self.layout = _FakeChunkLayout(chunks) + self._params_by_id = params_by_id + self._persistent_ids = persistent_ids + self._non_persistent_ids = { + cid for cid, _ in enumerate(chunks) if cid not in persistent_ids + } + self._chunk_shards: dict[int, Any] = {} + self._cpu_slots: dict[int, list[Any]] = {} + # cpu_optim / gpu_optim are written by the wrapper at the end. 
+ self.cpu_optim = None + self.gpu_optim = None + self.zero3_shard = False + + +def _build_dispatch_fixture( + n_persistent_params: int = 1, + n_cpu_params: int = 0, +) -> tuple[Any, list[torch.nn.Parameter]]: + """Build a tiny WrappedModel + persistent-only chunk layout on CUDA.""" + device = _gpu_device() + persistent = [ + torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device=device)) + for _ in range(n_persistent_params) + ] + cpu_params = [ + torch.nn.Parameter(torch.randn(128, 128, dtype=torch.float32, device="cpu")) + for _ in range(n_cpu_params) + ] + all_params = persistent + cpu_params + params_by_id = {i: p for i, p in enumerate(all_params)} + chunks = [[i] for i in range(len(all_params))] + persistent_ids = set(range(n_persistent_params)) + + cm = _FakeChunkManager( + params_by_id=params_by_id, + persistent_ids=persistent_ids, + chunks=chunks, + ) + # ``module`` is consulted by ``_collect_no_decay_param_ids``; an empty + # nn.Module has no params, so the no-decay set is empty (acceptable + # for this dispatch test). + module = torch.nn.Module() + wrapped = SimpleNamespace( + module=module, + chunk_manager=cm, + ) + return wrapped, persistent + + +@pytest.mark.parametrize( + "optim_name", + ["adamw_8bit", "adamw_bnb_8bit", "paged_adamw_8bit"], +) +def test_dispatch_routes_8bit_names_to_bnb_adapter(optim_name: str) -> None: + """All three Axolotl/HF 8-bit names route persistent set through the bnb adapter.""" + pytest.importorskip("bitsandbytes") + pytest.importorskip("deepspeed") # CpuFusedAdam path import — ok if missing? skip + from axolotl.integrations.protrain.api.optim_wrapper import ( + protrain_optimizer_wrapper, + ) + + wrapped, _persistent = _build_dispatch_fixture( + n_persistent_params=1, + n_cpu_params=0, + ) + optim = protrain_optimizer_wrapper( + wrapped, + lr=1e-3, + optimizer_name=optim_name, + ) + # Inner adapter must be the 8-bit variant. 
+ assert isinstance(optim._gpu_optim, GpuAdamW8bitAdapter) + if optim_name == "paged_adamw_8bit": + assert optim._gpu_optim.paged is True + else: + assert optim._gpu_optim.paged is False + # No CPU chunks in this fixture, so cpu_optim is None. + assert optim._cpu_optim is None + + +def test_dispatch_default_optimizer_uses_fused_adam() -> None: + """``optimizer_name=None`` (and unrelated names) keeps the GpuFusedAdamAdapter path.""" + pytest.importorskip("bitsandbytes") + from axolotl.integrations.protrain.api.optim_wrapper import ( + protrain_optimizer_wrapper, + ) + + wrapped, _persistent = _build_dispatch_fixture( + n_persistent_params=1, + n_cpu_params=0, + ) + # Default / non-8bit name: persistent set must use the legacy path. + optim = protrain_optimizer_wrapper( + wrapped, + lr=1e-3, + optimizer_name="adamw_torch", + ) + assert isinstance(optim._gpu_optim, GpuFusedAdamAdapter) + + +def test_dispatch_warns_when_8bit_requested_with_cpu_chunks() -> None: + """Bail-condition warning fires when 8-bit + non-persistent chunks coexist.""" + pytest.importorskip("bitsandbytes") + pytest.importorskip("deepspeed") + from axolotl.integrations.protrain.api.optim_wrapper import ( + protrain_optimizer_wrapper, + ) + + wrapped, _persistent = _build_dispatch_fixture( + n_persistent_params=1, + n_cpu_params=1, + ) + # CpuFusedAdamAdapter requires DeepSpeed's compiled CPU Adam kernel — + # under DS_SKIP_CUDA_CHECK this is JIT-built on demand. Stub it so + # this test does not depend on the local DS build state. + captured_warnings: list[str] = [] + + def _capture_warning(msg, *args, **kwargs): + # ``LOG.warning`` from the wrapper uses %-style formatting. 
+ try: + captured_warnings.append(msg % args if args else msg) + except (TypeError, ValueError): + captured_warnings.append(str(msg)) + + with mock.patch( + "axolotl.integrations.protrain.chunk.optim.CpuFusedAdamAdapter", + autospec=True, + ) as fake_cpu_cls: + fake_cpu_cls.return_value = mock.MagicMock(_optims={}) + with mock.patch( + "axolotl.integrations.protrain.api.optim_wrapper.CpuFusedAdamAdapter", + fake_cpu_cls, + ): + with mock.patch( + "axolotl.integrations.protrain.api.optim_wrapper.LOG.warning", + side_effect=_capture_warning, + ): + _optim = protrain_optimizer_wrapper( + wrapped, + lr=1e-3, + optimizer_name="adamw_8bit", + ) + # The bail-condition warning must surface. + assert any( + "8-bit Adam kernels are CUDA-only" in msg for msg in captured_warnings + ), captured_warnings + + +# End-to-end smoke: full ProTrain pipeline with adamw_8bit on tiny GPT-2. + + +def _tiny_gpt2(device): + """Smallest HF causal-LM the profiler's batch factory drives end-to-end.""" + pytest.importorskip("transformers") + from transformers import GPT2Config, GPT2LMHeadModel + + torch.manual_seed(0) + cfg = GPT2Config( + n_layer=2, + n_head=2, + n_embd=64, + vocab_size=128, + n_positions=128, + ) + return GPT2LMHeadModel(cfg).to(device) + + +@pytest.mark.slow +def test_end_to_end_5_steps_descending_loss() -> None: + """5 forward+backward+step iterations on tiny GPT-2 with adamw_8bit yield descending loss.""" + pytest.importorskip("torch") + pytest.importorskip("transformers") + pytest.importorskip("bitsandbytes") + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA") + + from axolotl.integrations.protrain import auto_wrap + from axolotl.integrations.protrain.api.optim_wrapper import ( + protrain_optimizer_wrapper, + ) + + device = torch.device("cuda") + model = _tiny_gpt2(device) + + wrapped = auto_wrap(model, batch_size=2, seq_len=8) + try: + optim = protrain_optimizer_wrapper( + wrapped, + lr=1e-2, # high enough to see loss move in 5 steps + betas=(0.9, 
0.999), + eps=1e-8, + weight_decay=0.0, + optimizer_name="adamw_8bit", + ) + # Persistent set on tiny model routes to the 8-bit adapter; no CPU chunks in Mode A. + assert isinstance(optim._gpu_optim, GpuAdamW8bitAdapter), ( + f"expected GpuAdamW8bitAdapter, got {type(optim._gpu_optim).__name__}" + ) + + # Overfit a single fixed batch so per-iter noise cannot mask the descent. + torch.manual_seed(42) + fixed_input = torch.randint(0, 128, (2, 8), device=device) + losses: list[float] = [] + for _ in range(5): + out = wrapped.module(input_ids=fixed_input, labels=fixed_input) + loss = out.loss + losses.append(float(loss.detach())) + loss.backward() + optim.step() + optim.zero_grad() + + assert len(losses) == 5 + assert all(loss > 0 for loss in losses), f"non-positive loss: {losses}" + assert losses[-1] < losses[0], f"loss did not descend: {losses}" + finally: + # Release CUDA/chunk resources so a failure cannot leak into later GPU tests. + wrapped.close() diff --git a/tests/protrain/test_alpha_per_dtype.py b/tests/protrain/test_alpha_per_dtype.py new file mode 100644 index 0000000000..f4432278cc --- /dev/null +++ b/tests/protrain/test_alpha_per_dtype.py @@ -0,0 +1,223 @@ +"""Pin the per-dtype alpha fragmentation factor lookup so the 4-bit branch can't silently regress.""" + +from __future__ import annotations + +import pytest + +from axolotl.integrations.protrain.cost.memory import ( + ALPHA_FRAGMENTATION, + ALPHA_FRAGMENTATION_4BIT, + alpha_fragmentation_for_dtype, +) + + +def test_constants_have_expected_values(): + """Lock the two named constants so unrelated edits cannot drift the calibration.""" + assert ALPHA_FRAGMENTATION == pytest.approx(1.10) + assert ALPHA_FRAGMENTATION_4BIT == pytest.approx(0.75) + + +@pytest.mark.parametrize( + ("bpe", "expected_alpha", "description"), + [ + # fp32 — alpha=1.10 (the >=1.0 branch). + (4.0, ALPHA_FRAGMENTATION, "fp32 weights → alpha=1.10"), + # fp16 / bf16 — alpha=1.10 (paper default; Block G alpha_measured ≈ 0.96). 
+ (2.0, ALPHA_FRAGMENTATION, "fp16/bf16 weights → alpha=1.10"), + # bnb 8-bit — alpha=1.10 (Block G alpha_measured ≈ 0.93; mildly conservative). + (1.0, ALPHA_FRAGMENTATION, "bnb 8-bit weights → alpha=1.10"), + # bnb 4-bit (Params4bit) — alpha=0.75 (Block G alpha_measured ≈ 0.70). + (0.5, ALPHA_FRAGMENTATION_4BIT, "bnb 4-bit weights → alpha=0.75"), + ], +) +def test_alpha_lookup_by_dtype(bpe: float, expected_alpha: float, description: str): + assert alpha_fragmentation_for_dtype(bpe) == pytest.approx(expected_alpha), ( + description + ) + + +def test_alpha_lookup_threshold_is_one_byte(): + """The fp16/8-bit-vs-4-bit cutoff is exactly 1.0 B/element.""" + # Strictly below the cutoff — 4-bit branch. + assert alpha_fragmentation_for_dtype(0.99) == pytest.approx( + ALPHA_FRAGMENTATION_4BIT + ) + # Exactly at the cutoff — fp16 branch (8-bit is conservative-ish, keep alpha=1.10). + assert alpha_fragmentation_for_dtype(1.0) == pytest.approx(ALPHA_FRAGMENTATION) + # Strictly above the cutoff — fp16 branch. + assert alpha_fragmentation_for_dtype(1.01) == pytest.approx(ALPHA_FRAGMENTATION) + + +def test_alpha_lookup_extreme_bpe_does_not_crash(): + """Boundary / out-of-range inputs land in one of the two known branches.""" + # Tiny positive value — still routes to 4-bit branch. + assert alpha_fragmentation_for_dtype(0.001) == pytest.approx( + ALPHA_FRAGMENTATION_4BIT + ) + # Zero — by the documented rule (< 1.0) routes to 4-bit branch. + assert alpha_fragmentation_for_dtype(0.0) == pytest.approx(ALPHA_FRAGMENTATION_4BIT) + # Negative — by the documented rule (< 1.0) routes to 4-bit branch. + # Real callers should never pass negative; this just locks behaviour + # so a future ``max(0, bpe)`` guard is opt-in. + assert alpha_fragmentation_for_dtype(-1.0) == pytest.approx( + ALPHA_FRAGMENTATION_4BIT + ) + # Very large value — fp16 branch. 
+ assert alpha_fragmentation_for_dtype(1024.0) == pytest.approx(ALPHA_FRAGMENTATION) + + +def test_dominant_param_dtype_detector_default_for_fp16_model(): + """The detector returns 2.0 (fp16) for a typical bf16 model so non-quantized callers stay at alpha=1.10.""" + import torch + from torch import nn + + from axolotl.integrations.protrain.api.model_wrapper import ( + _detect_dominant_param_bytes_per_element, + ) + + class _Toy(nn.Module): + def __init__(self) -> None: + super().__init__() + # Two layers' worth of bf16 weights — dominant by aggregate count. + self.w1 = nn.Parameter(torch.zeros(128, 64, dtype=torch.bfloat16)) + self.w2 = nn.Parameter(torch.zeros(64, 32, dtype=torch.bfloat16)) + # A small fp32 buffer (layer-norm-scale-shaped) that should NOT + # flip the dominant classification despite element_size=4. + self.ln = nn.Parameter(torch.zeros(32, dtype=torch.float32)) + + bpe = _detect_dominant_param_bytes_per_element(_Toy()) + assert bpe == pytest.approx(2.0), ( + f"bf16 model with a small fp32 LN param should classify as bpe=2.0, got {bpe}" + ) + + +def test_dominant_param_dtype_detector_returns_default_on_empty_model(): + """The detector falls back to 2.0 (fp16/bf16) on a paramless model so the cost model picks alpha=1.10.""" + from torch import nn + + from axolotl.integrations.protrain.api.model_wrapper import ( + _detect_dominant_param_bytes_per_element, + ) + + class _Empty(nn.Module): + pass + + assert _detect_dominant_param_bytes_per_element(_Empty()) == pytest.approx(2.0) + + +def test_dominant_param_dtype_detector_classifies_int8_dominant_model(): + """An int8-dominant model with bf16 LoRA factors still classifies as bpe=1.0 and lands on alpha=1.10.""" + import torch + from torch import nn + + from axolotl.integrations.protrain.api.model_wrapper import ( + _detect_dominant_param_bytes_per_element, + ) + + class _Int8Heavy(nn.Module): + def __init__(self) -> None: + super().__init__() + # Large int8-storage weight (analog for bnb int8 base) — the 
+ # numel here is the logical-element count too (int8 is 1:1). + self.base_w = nn.Parameter( + torch.zeros(4096, 4096, dtype=torch.uint8), requires_grad=False + ) + # Small bf16 LoRA factors on top. + self.lora_a = nn.Parameter(torch.zeros(16, 4096, dtype=torch.bfloat16)) + self.lora_b = nn.Parameter(torch.zeros(4096, 16, dtype=torch.bfloat16)) + + bpe = _detect_dominant_param_bytes_per_element(_Int8Heavy()) + assert bpe == pytest.approx(1.0), ( + f"int8-dominant model should classify as bpe=1.0, got {bpe}" + ) + # And the lookup routes it to the conservative alpha=1.10. + assert alpha_fragmentation_for_dtype(bpe) == pytest.approx(ALPHA_FRAGMENTATION) + + +def test_estimate_peak_uses_per_dtype_alpha(): + """End-to-end pin: bpe=0.5 makes ``estimate_peak`` scale by 0.75 (4-bit alpha) while bpe=2.0 stays at 1.10.""" + from axolotl.integrations.protrain.cost.memory import estimate_peak + from axolotl.integrations.protrain.types import ( + BlockId, + BlockMode, + BlockStrategyMap, + ChunkLayout, + CostConfig, + HardwareProfile, + ProfilerTrace, + ) + + # Minimal viable trace + layout — one block, one tiny op. No + # measured per-block peaks, no measured deltas, so the op-walk + # raw peak is dominated by ``model_state_present`` (which is 0 + # because ``model_state_bytes`` is 0) plus the persistent / + # buffer pool terms. + # We arrange S_chunk * (n_persist + n_buffer) = 1 GiB so the raw + # peak is large and easy to multiply against alpha. 
+ s_chunk = 1 << 28 # 256 MiB + n_chunk = 4 + layout = ChunkLayout( + S_chunk=s_chunk, + N_chunk=n_chunk, + chunks=tuple(tuple() for _ in range(n_chunk)), # type: ignore[arg-type] + param_to_chunk={}, + block_to_chunks={BlockId(0): ()}, + ) + trace = ProfilerTrace( + op_order=(), + intra_op_delta={}, + inter_op_delta={}, + activation_sizes={BlockId(0): 0}, + model_state_bytes=0, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + nccl_gather_s={}, + nccl_reduce_s={}, + arch_hash="test", + bs=1, + seq=16, + sku="test", + world=1, + ) + cfg = CostConfig(n_persist=2, n_buffer=2, n_swap=0, n_checkpoint=0) + block_map: BlockStrategyMap = {BlockId(0): BlockMode.NONE} + + # Default HW profile — bpe=2.0 lands on alpha=1.10. + hw_fp16 = HardwareProfile( + gpu_sku="test", + gpu_memory_bytes=24 * (1 << 30), + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + # 4-bit HW profile — bpe=0.5 lands on alpha=0.75. + hw_4bit = HardwareProfile( + gpu_sku="test", + gpu_memory_bytes=24 * (1 << 30), + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + dominant_param_bytes_per_element=0.5, + ) + + peak_fp16 = estimate_peak(cfg, trace, layout, block_map, hw_fp16) + peak_4bit = estimate_peak(cfg, trace, layout, block_map, hw_4bit) + + # The alpha=0.75 branch must return strictly less peak than the + # alpha=1.10 branch on the same raw inputs — concrete value depends + # on the op-walk's exact accounting, so assert the relative + # contract. + assert peak_4bit < peak_fp16, ( + f"per-dtype alpha should yield smaller peak for 4-bit " + f"(alpha=0.75): got peak_4bit={peak_4bit}, peak_fp16={peak_fp16}" + ) + # Ratio is 0.75 / 1.10 modulo int() rounding (cost model + # casts the alpha-scaled value to int). Use 1% slack. 
+ expected_ratio = ALPHA_FRAGMENTATION_4BIT / ALPHA_FRAGMENTATION + observed_ratio = peak_4bit / max(peak_fp16, 1) + assert observed_ratio == pytest.approx(expected_ratio, rel=0.01), ( + f"peak_4bit / peak_fp16 = {observed_ratio:.4f} should match " + f"alpha_4bit / alpha_fp16 = {expected_ratio:.4f}" + ) diff --git a/tests/protrain/test_bnb_offload.py b/tests/protrain/test_bnb_offload.py new file mode 100644 index 0000000000..83d2f5f5a7 --- /dev/null +++ b/tests/protrain/test_bnb_offload.py @@ -0,0 +1,426 @@ +"""bnb 4-bit / 8-bit composition with the ProTrain offload path: gather/offload must not perturb ``quant_state``.""" + +from __future__ import annotations + +from typing import cast + +import pytest + +from axolotl.integrations.protrain.types import ( + BlockId, + ParamId, +) + +# --------------------------------------------------------------------------- +# Helpers — mirror the patterns used in test_chunk_manager_offload.py +# --------------------------------------------------------------------------- + + +def _bnb_or_skip(): + """Import bitsandbytes or skip — CPU-only CI lanes may lack the optional package.""" + try: + import bitsandbytes as bnb # noqa: F401 + + return bnb + except ImportError as exc: # pragma: no cover — env probe + pytest.skip(f"bitsandbytes unavailable: {exc}") + + +def _tiny_bnb_model(hidden: int = 64, n_layers: int = 2): + """A tiny Llama-shaped model whose blocks use ``bnb.nn.Linear4bit`` so the offload path hits real ``Params4bit`` storage.""" + bnb = _bnb_or_skip() + + import torch + from torch import nn + + class TinyBlock(nn.Module): + """One transformer-shaped block: a Linear4bit acting as ``self_attn``.""" + + def __init__(self) -> None: + super().__init__() + self.self_attn = bnb.nn.Linear4bit( + hidden, + hidden, + bias=False, + compute_dtype=torch.bfloat16, + quant_type="nf4", + quant_storage=torch.uint8, + ) + + def forward(self, x: "torch.Tensor") -> "torch.Tensor": + return self.self_attn(x) + + class InnerLlama(nn.Module): 
+ """Inner ``model.layers`` container; matches the Llama path layout.""" + + def __init__(self) -> None: + super().__init__() + self.embed_tokens = nn.Linear(hidden, hidden, bias=False).to( + dtype=torch.bfloat16 + ) + self.layers = nn.ModuleList([TinyBlock() for _ in range(n_layers)]) + + def forward(self, x: "torch.Tensor") -> "torch.Tensor": + x = self.embed_tokens(x) + for layer in self.layers: + x = layer(x) + return x + + class TinyBnbLlama(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = InnerLlama() + self.lm_head = nn.Linear(hidden, hidden, bias=False).to( + dtype=torch.bfloat16 + ) + + def forward(self, x: "torch.Tensor") -> "torch.Tensor": + return self.lm_head(self.model(x)) + + torch.manual_seed(0) + return TinyBnbLlama() + + +def _build_layout_for(model, S_chunk: int): + """Build a ChunkLayout where each ``model.layers.{i}`` block is its own chunk.""" + from axolotl.integrations.protrain.chunk.layout import build_layout + + block_spans: dict[BlockId, list[ParamId]] = {} + for name, _ in model.named_parameters(): + if name.startswith("model.layers."): + idx = int(name.split(".")[2]) + block_spans.setdefault(cast(BlockId, idx), []).append(cast(ParamId, name)) + + exec_order = [cast(ParamId, n) for n, _ in model.named_parameters()] + return build_layout(model, exec_order, S_chunk, block_spans) + + +def _build_chunk_manager( + model, n_persist: int, S_chunk: int, n_buffer: int | None = None +): + import torch + + from axolotl.integrations.protrain.chunk.buffer_pool import BufferPool + from axolotl.integrations.protrain.chunk.manager import ChunkManager + from axolotl.integrations.protrain.chunk.pinned_alloc import PinnedHostMemory + + layout = _build_layout_for(model, S_chunk) + if n_buffer is None: + n_buffer = max(2, min(4, layout.N_chunk - n_persist)) + host = PinnedHostMemory(n_buffer=n_buffer, S_chunk=layout.S_chunk) + pool = BufferPool( + n_buffer=n_buffer, + S_chunk=layout.S_chunk, + pinned_host=host, + 
device=torch.device("cuda"), + ) + mgr = ChunkManager( + model=model, + layout=layout, + n_persist=n_persist, + buffer_pool=pool, + cpu_optim=None, + gpu_optim=None, + device=torch.device("cuda"), + ) + return mgr, layout, pool, host + + +# --------------------------------------------------------------------------- +# Test 1: bnb 4-bit module discovery +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu +def test_bnb_4bit_module_discovery_in_trace() -> None: + """``discover_blocks`` finds blocks containing ``bnb.nn.Linear4bit`` (no special-casing of standard linears).""" + bnb = _bnb_or_skip() + + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime (Linear4bit needs cuda)") + + from axolotl.integrations.protrain.block.layout_rules import discover_blocks + + model = _tiny_bnb_model(hidden=64, n_layers=4).to("cuda") + + trees = discover_blocks(model) + assert trees, "discover_blocks returned no block trees for bnb model" + + # Walk the discovered trees and confirm 4 ``model.layers.*`` blocks + # were enumerated. ``BlockTree.blocks`` is the authoritative list of + # block instances (the ``model.layers.{i}`` modules) and + # ``parent_path`` records where in the dotted tree they live. + block_count = sum(len(tree.blocks) for tree in trees) + assert block_count == 4, ( + f"discover_blocks expected 4 bnb blocks, got {block_count} " + f"({[t.parent_path for t in trees]})" + ) + parent_paths = {tree.parent_path for tree in trees} + assert "model.layers" in parent_paths, ( + f"discover_blocks did not anchor to model.layers (got {parent_paths})" + ) + + # Confirm the discovered block instances are the bnb-bearing + # ``TinyBlock``s (i.e. discovery did not silently swap them out for + # something else) and their inner ``self_attn`` is a real Linear4bit. 
+ for tree in trees: + for block in tree.blocks: + assert isinstance(block.self_attn, bnb.nn.Linear4bit), ( + f"discovered block.self_attn is not Linear4bit: " + f"{type(block.self_attn).__name__}" + ) + assert isinstance(block.self_attn.weight, bnb.nn.Params4bit), ( + f"discovered block weight is not Params4bit: " + f"{type(block.self_attn.weight).__name__}" + ) + + +# --------------------------------------------------------------------------- +# Test 2: quant_state survives offload-restore round trip +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu +def test_quant_state_survives_offload_round_trip() -> None: + """A ``Params4bit``'s ``quant_state`` survives a chunk-manager offload/gather round trip (QLoRA + Mode C invariant).""" + # Skip-if-missing probe; we don't need the bnb handle here because + # the model's bnb modules are accessed via their PyTorch instances. + _bnb_or_skip() + + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + # n_persist=1 with this S_chunk leaves >= 2 non-persistent block-only chunks to exercise. + hidden = 64 + n_layers = 4 + model = _tiny_bnb_model(hidden=hidden, n_layers=n_layers).to("cuda") + + # Trigger the lazy quantization by running one forward — bnb only + # populates ``quant_state`` once Params4bit.cuda() OR a Linear4bit + # forward call has happened. ``.to("cuda")`` above takes care of + # the move; this forward populates the per-weight state2 etc. + x0 = torch.randn(2, hidden, dtype=torch.bfloat16, device="cuda") + y_pre = model(x0).detach().clone() + + # Snapshot every Linear4bit's pre-offload quant_state identity and + # absmax bytes so we can compare against the post-restore state. 
+ pre_state = {} + for i in range(n_layers): + layer = model.model.layers[i].self_attn + qs = layer.weight.quant_state + assert qs is not None, ( + f"model.layers.{i}.self_attn.weight.quant_state is None pre-offload" + ) + pre_state[i] = { + "qs_id": id(qs), + "absmax_bytes": qs.absmax.detach().clone(), + "absmax_device": qs.absmax.device, + "shape": qs.shape, + "quant_type": qs.quant_type, + } + + # Build the chunk manager. We want each block's Linear4bit weight + # to land in its own chunk AND we want embed_tokens/lm_head (the + # non-block params) to land in chunks separate from any block, so + # the non-block chunks become mandatory_persistent and the + # block-only chunks can offload. embed_tokens is bf16 64*64 = 8192 + # bytes; a single Linear4bit weight is 64*64/2 = 2048 packed bytes; + # an S_chunk of 4096 gives embed_tokens its own (oversize) chunk + # and each block weight its own chunk. + S_chunk = 4096 + mgr, layout, pool, host = _build_chunk_manager(model, n_persist=1, S_chunk=S_chunk) + try: + # Sanity: layout produced enough non-persistent chunks to exercise. + nonp_count = sum( + 1 + for cid in range(layout.N_chunk) + if cid >= 1 and cid not in layout.mandatory_persistent + ) + assert nonp_count >= 2, ( + f"test setup wants >= 2 non-persistent chunks, got {nonp_count} " + f"(N_chunk={layout.N_chunk}, " + f"mandatory={sorted(layout.mandatory_persistent)})" + ) + + # Offload — non-persistent chunks' param.data goes to pinned CPU. + freed = mgr.materialize_offload() + assert freed > 0, "materialize_offload freed 0 bytes (expected > 0)" + + # quant_state must still be attached after offload; otherwise gather + forward would crash in bnb.MatMul4Bit. 
+ for i in range(n_layers): + layer = model.model.layers[i].self_attn + qs = layer.weight.quant_state + assert qs is not None, ( + f"layers.{i}.self_attn.weight.quant_state vanished after offload" + ) + assert id(qs) == pre_state[i]["qs_id"], ( + f"layers.{i}.self_attn.weight.quant_state was replaced (id mismatch)" + ) + # absmax is owned by the QuantState object, not the chunk-managed storage. + assert qs.absmax.device == pre_state[i]["absmax_device"], ( + f"layers.{i}.self_attn.weight.quant_state.absmax migrated devices: " + f"was {pre_state[i]['absmax_device']}, now {qs.absmax.device}" + ) + assert torch.equal(qs.absmax, pre_state[i]["absmax_bytes"]), ( + f"layers.{i}.self_attn.weight.quant_state.absmax bytes changed" + ) + + # Gather every non-persistent chunk back; Linear4bit forward must still produce identical output. + for cid in sorted(mgr._non_persistent_ids): + mgr.gather(cid) + + # Confirm post-gather quant_state attribute is still intact and + # param.data is GPU-resident at the right shape. + for i in range(n_layers): + layer = model.model.layers[i].self_attn + assert layer.weight.data.device.type == "cuda" + assert layer.weight.data.numel() > 0 + qs = layer.weight.quant_state + assert id(qs) == pre_state[i]["qs_id"], ( + f"layers.{i}.self_attn quant_state replaced during gather" + ) + + # End-to-end correctness: forward should match pre-offload bit-for-bit + # because we never modified any weight bytes — only moved them. + y_post = model(x0) + assert torch.allclose(y_pre, y_post, rtol=0, atol=0), ( + "Linear4bit forward produced different output after offload-restore " + "round trip — quant_state metadata is out of sync with stored bytes" + ) + finally: + # Always free pinned-host buffers and chunk-manager state so a failure cannot bleed into later GPU tests. 
+ mgr.uninstall() + host.close() + del pool + + +# --------------------------------------------------------------------------- +# Test 3: 5-step training smoke through ProTrain offload + bnb 4-bit +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu +def test_offload_mode_4bit_e2e_5_steps() -> None: + """Five-step Linear4bit + ProTrain offload training smoke; loss must descend across the window.""" + # Skip-if-missing probe; the bnb instances live inside the model + # factory and are accessed via PyTorch's module tree, not directly. + _bnb_or_skip() + + import torch + from torch import nn + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_bnb_model(hidden=hidden, n_layers=n_layers).to("cuda") + + # Freeze all base weights inside the block sequence — those are + # the params that will be chunk-managed and offloaded. + for layer in model.model.layers: + for p in layer.parameters(): + p.requires_grad_(False) + # embed_tokens / lm_head are outside the block sequence and will + # land in mandatory_persistent chunks; freeze them too so the only + # trainable params are the LoRA adapters added below — the test + # is about offload + bnb correctness, not full base-weight training. + for p in model.model.embed_tokens.parameters(): + p.requires_grad_(False) + for p in model.lm_head.parameters(): + p.requires_grad_(False) + + # Tiny LoRA adapter set, kept OUTSIDE the chunked block sequence — + # they live as ``model.lora_adapters.{i}`` so the layout's + # block_spans (built from ``model.layers.*``) does not claim them. + # Non-block params land in mandatory_persistent chunks (always + # GPU-resident, never offloaded), so the trainable LoRA grads do + # not engage the per-param offload-time grad hook (which would + # require a CPU optimizer attached to the chunk manager). 
+ class LoRAAdapter(nn.Module): + def __init__(self, in_f: int, out_f: int, r: int = 2) -> None: + super().__init__() + self.lora_a = nn.Linear(in_f, r, bias=False).to( + dtype=torch.bfloat16, device="cuda" + ) + self.lora_b = nn.Linear(r, out_f, bias=False).to( + dtype=torch.bfloat16, device="cuda" + ) + + def forward(self, x: "torch.Tensor") -> "torch.Tensor": + return self.lora_b(self.lora_a(x)) + + model.lora_adapters = nn.ModuleList( + [LoRAAdapter(hidden, hidden) for _ in range(n_layers)] + ) + + # Patch each block's forward to add the corresponding LoRA delta + # AFTER the base bnb forward — same algebraic shape as a real QLoRA + # adapter, but with the adapter layer kept outside the block tree. + for i, block in enumerate(model.model.layers): + adapter = model.lora_adapters[i] + base_forward = block.forward + + def _patched(x, _base=base_forward, _adapter=adapter): + return _base(x) + _adapter(x) + + block.forward = _patched + + # Prime quant_state via one forward. + x = torch.randn(2, hidden, dtype=torch.bfloat16, device="cuda") + _ = model(x) + + # n_persist=1, S_chunk sized so each block weight gets its own chunk and embed/lm_head become mandatory_persistent. + S_chunk = 4096 + mgr, layout, pool, host = _build_chunk_manager( + model, n_persist=1, S_chunk=S_chunk, n_buffer=n_layers + ) + try: + freed = mgr.materialize_offload() + assert freed > 0, ( + f"materialize_offload freed 0 bytes — no non-persistent chunks " + f"(N_chunk={layout.N_chunk}, " + f"mandatory={sorted(layout.mandatory_persistent)})" + ) + + # Optimizer over LoRA-adapter params only; we only need to prove gather + dequant + backprop + offload composes. + trainable = [p for p in model.parameters() if p.requires_grad] + assert trainable, "no trainable params — LoRA wrap didn't take" + optim = torch.optim.AdamW(trainable, lr=1e-3) + + # All-resident approximation: gather every non-persistent chunk before forward, offload after step. 
+ nonp = sorted(mgr._non_persistent_ids) + + losses: list[float] = [] + target = torch.zeros(2, hidden, dtype=torch.bfloat16, device="cuda") + + for _step in range(5): + for cid in nonp: + mgr.gather(cid) + out = model(x) + loss = (out - target).pow(2).mean() + loss.backward() + optim.step() + optim.zero_grad() + for cid in nonp: + mgr.offload(cid) + losses.append(float(loss.detach())) + + # 5% headroom over noise: a regression in the gather path (e.g. quant_state desync) would fail this. + assert len(losses) == 5 + assert losses[-1] < losses[0] * 0.95, ( + f"loss did not descend across 5 steps: {losses}" + ) + finally: + # Always free pinned-host buffers and chunk-manager state so a failure cannot bleed into later GPU tests. + mgr.uninstall() + host.close() + del pool diff --git a/tests/protrain/test_chunk_optim_shutdown.py b/tests/protrain/test_chunk_optim_shutdown.py index aad56e96f7..53255699aa 100644 --- a/tests/protrain/test_chunk_optim_shutdown.py +++ b/tests/protrain/test_chunk_optim_shutdown.py @@ -160,7 +160,7 @@ def test_shutdown_skips_missing_ds_opt_adam(): def test_shutdown_logs_destroy_failure_but_continues(caplog): """A per-chunk destroy failure is logged and does not block other chunks.""" - import logging + from axolotl.integrations.protrain.chunk import optim as optim_module adapter, fakes = _make_adapter_with_mock_ds(n_chunks=3) @@ -175,7 +175,9 @@ def destroy_adam(self, _opt_id): # noqa: ANN001 exploding = _ExplodingDs() adapter._optims[ChunkId(1)].ds_opt_adam = exploding # type: ignore[attr-defined] - with caplog.at_level(logging.WARNING, logger="axolotl"): + with mock.patch.object( + optim_module.LOG, "warning", wraps=optim_module.LOG.warning + ) as mock_warn: adapter.shutdown() # Healthy chunks still got their destroy call. @@ -183,10 +185,27 @@ def destroy_adam(self, _opt_id): # noqa: ANN001 assert len(fakes[2].destroy_calls) == 1 # The failing chunk attempted destroy exactly once. 
assert exploding.calls == 1 - # And the failure surfaced via a warning. - assert any( - "destroy_adam failed" in record.getMessage() for record in caplog.records - ), "Expected a warning log for the failed destroy_adam call" + # And the failure surfaced via a warning. Inspect the mock's + # call args directly — match on the format-string prefix that + # uniquely identifies the destroy_adam-failure log site. + matching_calls = [ + call + for call in mock_warn.call_args_list + if call.args + and isinstance(call.args[0], str) + and "destroy_adam failed" in call.args[0] + ] + assert matching_calls, ( + f"Expected a LOG.warning call matching 'destroy_adam failed' but got " + f"{[call.args for call in mock_warn.call_args_list]}" + ) + # The warning's format args should include the failing chunk id (1) and + # the underlying exception. Sanity-check both so a future copy-edit of + # the warning text doesn't silently mask the diagnostic content. + matching_call = matching_calls[0] + assert ChunkId(1) in matching_call.args, ( + f"warning's chunk-id format arg should be ChunkId(1); got {matching_call.args}" + ) def test_shutdown_destroys_state_even_when_wait_all_raises(): diff --git a/tests/protrain/test_cost_search.py b/tests/protrain/test_cost_search.py index 8de2e41bba..c81edb426d 100644 --- a/tests/protrain/test_cost_search.py +++ b/tests/protrain/test_cost_search.py @@ -3492,10 +3492,7 @@ def test_alpha_residual_compensates_for_unmodeled_overhead(): # so the per-component composition's boot prediction equals the # analytical lumped iter (no per-component-bias correction). boot_per_comp_pred = ( - boot_t_fwd - + boot_t_bwd - + boot_t_gpu - + max(0.0, boot_t_cpu - boot_t_bwd) + boot_t_fwd + boot_t_bwd + boot_t_gpu + max(0.0, boot_t_cpu - boot_t_bwd) ) # Stage measured phase-2 iter at 2.0 × per-component prediction # — the missing whole-iter overhead the residual α must absorb. 
@@ -3590,10 +3587,7 @@ def test_alpha_residual_no_op_when_per_component_explains_boot(): boot_step = max(boot_t_gpu + boot_t_cpu, 1e-12) boot_per_comp_pred = ( - boot_t_fwd - + boot_t_bwd - + boot_t_gpu - + max(0.0, boot_t_cpu - boot_t_bwd) + boot_t_fwd + boot_t_bwd + boot_t_gpu + max(0.0, boot_t_cpu - boot_t_bwd) ) # Measured iter == per-component prediction → residual α = 1.0. measured_iter = boot_per_comp_pred diff --git a/tests/protrain/test_cross_mode_resume.py b/tests/protrain/test_cross_mode_resume.py new file mode 100644 index 0000000000..573d0516e5 --- /dev/null +++ b/tests/protrain/test_cross_mode_resume.py @@ -0,0 +1,610 @@ +"""Cross-mode (Mode A persistent vs Mode C sharded+offload) checkpoint resume smoke tests.""" + +from __future__ import annotations + +import math +import os +import socket +import subprocess +import sys +import textwrap +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.gpu + + +def _build_tiny_llama_lora(): + pytest.importorskip("torch") + pytest.importorskip("transformers") + pytest.importorskip("peft") + + import torch + from peft import LoraConfig, get_peft_model + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=256, + vocab_size=512, + max_position_embeddings=64, + rms_norm_eps=1e-5, + use_cache=False, + ) + model = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16) + lora_cfg = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + return get_peft_model(model, lora_cfg), cfg + + +def _wrap( + model, cfg, *, force_all_persistent: bool, zero3_shard: bool, bs: int, seq: int +): + import torch + + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + hw = 
HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + wrapped = protrain_model_wrapper( + model, + model_config=cfg, + hardware_profile=hw, + batch_size=bs, + seq_len=seq, + capacity_bytes=4 * (1 << 30), + force_all_persistent=force_all_persistent, + zero3_shard=zero3_shard, + ) + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + return wrapped, optim + + +def _train(wrapped, optim, *, n_iters, input_ids, labels) -> list[float]: + losses: list[float] = [] + for i in range(n_iters): + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_value = float(loss.detach()) + assert math.isfinite(loss_value), f"iter {i}: non-finite loss {loss_value}" + loss.backward() + optim.step() + optim.zero_grad() + losses.append(loss_value) + return losses + + +def _resume(wrapped, optim, model_state, optim_state): + """Best-effort cross-mode load: never crash, allow optimizer-state cold-start when layouts differ.""" + underlying = getattr(wrapped, "module", wrapped) + try: + # Allow strict=False because LoRA-PEFT state dicts contain only + # trainable params; PEFT's load_state_dict accepts strict-False. + load = getattr(underlying, "load_state_dict", None) + if load is not None: + load(model_state, strict=False) + except Exception as exc: + pytest.fail(f"cross-mode model state_dict load crashed: {exc}") + + if optim_state is not None and hasattr(optim, "load_state_dict"): + try: + optim.load_state_dict(optim_state) + except Exception as exc: # noqa: BLE001 + # Documented limitation: cross-mode optimizer-state remap may + # not be implemented. We don't fail the test on this — we + # log it and let training cold-start the optimizer. 
+ print( + f"\n[cross-mode-resume] optimizer state load failed (cold-start): {exc}" + ) + + +def _make_inputs(cfg, *, bs: int, seq: int): + import torch + + device = torch.device("cuda:0") + torch.manual_seed(0) + input_ids = torch.randint( + 0, cfg.vocab_size, (bs, seq), device=device, dtype=torch.long + ) + labels = input_ids.clone() + return input_ids, labels + + +def test_cross_mode_resume_a_to_c() -> None: + """Mode A trains+saves, Mode C re-wraps and resumes; assert finite loss with explicit close().""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain cross-mode resume smoke requires CUDA.") + + model, cfg = _build_tiny_llama_lora() + device = torch.device("cuda:0") + model = model.to(device) + + bs, seq = 1, 32 + input_ids, labels = _make_inputs(cfg, bs=bs, seq=seq) + + wrapped_c = None + try: + # Mode A: train + capture state. + wrapped_a, optim_a = _wrap( + model, cfg, force_all_persistent=True, zero3_shard=False, bs=bs, seq=seq + ) + try: + losses_a = _train( + wrapped_a, optim_a, n_iters=3, input_ids=input_ids, labels=labels + ) + underlying_a = getattr(wrapped_a, "module", wrapped_a) + model_state = { + k: v.detach().clone() for k, v in underlying_a.state_dict().items() + } + optim_state = ( + optim_a.state_dict() if hasattr(optim_a, "state_dict") else None + ) + finally: + # Explicit teardown BEFORE re-wrapping so the D2 snapshot is + # restored and the new chunk manager starts from a clean + # ``_ddp_params_and_buffers_to_ignore`` baseline. GC-only + # teardown would leave the prior wrap's hooks / pinned pool + # alive until the next allocator cycle. + close_a = getattr(wrapped_a, "close", None) + if callable(close_a): + close_a() + + # Mode C: re-wrap fresh from same model object, load state, train more. 
+ wrapped_c, optim_c = _wrap( + model, cfg, force_all_persistent=False, zero3_shard=True, bs=bs, seq=seq + ) + _resume(wrapped_c, optim_c, model_state, optim_state) + losses_c = _train( + wrapped_c, optim_c, n_iters=3, input_ids=input_ids, labels=labels + ) + + print(f"\nA→C resume: losses_a={losses_a} losses_c={losses_c}") + + # Acceptance: no crash above; losses are finite; Mode C losses are + # not catastrophically larger than the last Mode A loss (allow 5x as + # a generous bound — the optimizer may have cold-started). + assert all(math.isfinite(v) for v in losses_c), ( + f"non-finite Mode C loss: {losses_c}" + ) + assert losses_c[0] < 5.0 * losses_a[-1] + 1.0, ( + f"Mode C loss diverged after A→C resume: a-end={losses_a[-1]} " + f"c-start={losses_c[0]} (>5x is treated as catastrophic divergence)" + ) + finally: + if wrapped_c is not None: + close_c = getattr(wrapped_c, "close", None) + if callable(close_c): + close_c() + + +def test_cross_mode_resume_c_to_a() -> None: + """Mode C trains+saves, Mode A re-wraps and resumes; symmetric to A-to-C.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain cross-mode resume smoke requires CUDA.") + + model, cfg = _build_tiny_llama_lora() + device = torch.device("cuda:0") + model = model.to(device) + + bs, seq = 1, 32 + input_ids, labels = _make_inputs(cfg, bs=bs, seq=seq) + + wrapped_a = None + try: + wrapped_c, optim_c = _wrap( + model, cfg, force_all_persistent=False, zero3_shard=True, bs=bs, seq=seq + ) + try: + losses_c = _train( + wrapped_c, optim_c, n_iters=3, input_ids=input_ids, labels=labels + ) + underlying_c = getattr(wrapped_c, "module", wrapped_c) + model_state = { + k: v.detach().clone() for k, v in underlying_c.state_dict().items() + } + optim_state = ( + optim_c.state_dict() if hasattr(optim_c, "state_dict") else None + ) + finally: + close_c = getattr(wrapped_c, "close", None) + if callable(close_c): + close_c() + + wrapped_a, optim_a = _wrap( + 
def _nvidia_smi_gpu_indices() -> list[int]:
    """Return GPU indices from nvidia-smi (subprocess sidesteps CUDA_VISIBLE_DEVICES masking).

    Returns:
        The integer indices printed by ``nvidia-smi --query-gpu=index``, in
        output order. An empty list is returned when the tool cannot be
        started (missing, not executable, or any other ``OSError`` raised
        while spawning it), exits non-zero, or does not finish within 10 s,
        so callers can treat "no usable nvidia-smi" as "no GPUs" and skip.
    """
    try:
        raw = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader,nounits"],
            stderr=subprocess.DEVNULL,
            timeout=10,
        )
    except (
        # OSError covers FileNotFoundError (binary absent) as well as
        # PermissionError (binary present but not executable); either way
        # the probe must degrade to "no GPUs" instead of crashing the test.
        OSError,
        subprocess.CalledProcessError,
        subprocess.TimeoutExpired,
    ):
        return []

    indices: list[int] = []
    for line in raw.decode("utf-8", errors="replace").splitlines():
        token = line.strip()
        if not token:
            continue
        try:
            indices.append(int(token))
        except ValueError:
            # Ignore any non-numeric line rather than failing the probe.
            continue
    return indices
The +# corresponding precheck must verify these specific indices actually +# exist on the host — a count-based >=4 check passes on any 4-GPU box +# but launch fails late if e.g. GPU 7 isn't present. Kept in sync with +# the env in ``_launch_axolotl``. +_REQUIRED_GPU_INDICES = (1, 4, 5, 7) + + +_MODE_A_YAML = textwrap.dedent( + """\ + base_model: NousResearch/Meta-Llama-3-8B-Instruct + model_type: LlamaForCausalLM + load_in_8bit: false + load_in_4bit: false + strict: false + datasets: + - path: tatsu-lab/alpaca + type: alpaca + val_set_size: 0.0 + output_dir: {output_dir} + {resume_line} + sequence_len: 256 + sample_packing: false + pad_to_sequence_len: false + adapter: lora + lora_r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj + plugins: + - axolotl.integrations.protrain.ProTrainPlugin + protrain_auto_memory: true + protrain_auto_mode: false + protrain_force_all_persistent: true + protrain_zero3_shard: false + gradient_accumulation_steps: 1 + micro_batch_size: 1 + max_steps: {max_steps} + optimizer: adamw_torch + lr_scheduler: cosine + learning_rate: 0.0002 + bf16: true + fp16: false + tf32: false + gradient_checkpointing: false + flash_attention: false + xformers_attention: false + lora_mlp_kernel: false + lora_qkv_kernel: false + lora_o_kernel: false + logging_steps: 1 + save_steps: {save_steps} + save_first_step: false + save_total_limit: 2 + warmup_steps: 2 + weight_decay: 0.0 + """ +) + + +_MODE_C_YAML = textwrap.dedent( + """\ + base_model: NousResearch/Meta-Llama-3-8B-Instruct + model_type: LlamaForCausalLM + load_in_8bit: false + load_in_4bit: false + strict: false + datasets: + - path: tatsu-lab/alpaca + type: alpaca + val_set_size: 0.0 + output_dir: {output_dir} + {resume_line} + sequence_len: 256 + sample_packing: false + pad_to_sequence_len: false + adapter: lora + lora_r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + lora_target_modules: + - q_proj + - 
def _launch_axolotl(yaml_path: Path, log_path: Path, repo_root: Path) -> int:
    """Spawn ``accelerate launch`` of ``axolotl.cli.train`` and return its exit code.

    The child is pinned to the GPUs listed in ``_REQUIRED_GPU_INDICES`` (one
    process per listed GPU), so the launch environment and the precheck in
    the skip helper cannot drift apart.

    Args:
        yaml_path: Axolotl config file handed to ``axolotl.cli.train``.
        log_path: File that receives the child's combined stdout/stderr; it
            is truncated on every launch.
        repo_root: Worktree root; ``<repo_root>/src`` becomes the child's
            ``PYTHONPATH``.

    Returns:
        The child's return code (non-zero does not raise).

    Raises:
        subprocess.TimeoutExpired: If the launch exceeds the 720 s budget.
    """
    env = os.environ.copy()
    env["DS_SKIP_CUDA_CHECK"] = "1"
    env["PYTHONUNBUFFERED"] = "1"
    env["PYTHONPATH"] = str(repo_root / "src")
    # Derive the pinned set from the same constant the precheck verifies
    # instead of repeating the literal index list here.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in _REQUIRED_GPU_INDICES)
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # Pick a free port; prevents EADDRINUSE if other torch.distributed
    # processes are already bound (e.g. concurrent tests on the same
    # rig). Accelerate forwards MASTER_PORT into the child group.
    env.setdefault("MASTER_PORT", str(_pick_free_port()))

    cmd = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        "--num_processes",
        str(len(_REQUIRED_GPU_INDICES)),  # one rank per pinned GPU
        "--mixed_precision",
        "bf16",
        "-m",
        "axolotl.cli.train",
        str(yaml_path),
    ]
    with log_path.open("w") as f:
        proc = subprocess.run(
            cmd,
            env=env,
            stdout=f,
            stderr=subprocess.STDOUT,
            check=False,
            timeout=720,  # per-launch budget; multi-GPU bring-up takes ~1 min
        )
    return proc.returncode
def _repo_root() -> Path:
    """Return the directory three levels above this file (the worktree root)."""
    # Layout: <repo>/tests/protrain/test_cross_mode_resume.py, so
    # parents[0] is tests/protrain, parents[1] is tests, parents[2] is <repo>.
    this_file = Path(__file__).resolve()
    ancestors = this_file.parents
    return ancestors[2]
@pytest.mark.slow
@pytest.mark.gpu
def test_real_multigpu_cross_mode_resume_c_to_a(tmp_path: Path) -> None:
    """Multi-GPU C->A resume: a Mode C subprocess saves at step 5, a Mode A subprocess resumes to step 10.

    The resumed run must exit 0, log no ``Traceback``, and emit at least five
    ``'loss':`` lines.
    """
    _require_real_multigpu()

    repo_root = _repo_root()
    workdir = tmp_path
    modeC_ckpt_dir = workdir / "modeC_ckpt"
    modeA_resumed_dir = workdir / "modeA_resumed"
    saved_ckpt = modeC_ckpt_dir / "checkpoint-5"

    # Stage 1: train Mode C for five steps and save a checkpoint.
    yaml_c = workdir / "modeC_save.yml"
    save_cfg = _MODE_C_YAML.format(
        output_dir=str(modeC_ckpt_dir),
        resume_line="",
        max_steps=5,
        save_steps=5,
    )
    yaml_c.write_text(save_cfg)
    log_c = workdir / "modeC_save.log"
    rc_c = _launch_axolotl(yaml_c, log_c, repo_root)
    assert rc_c == 0, (
        f"Mode C train+save subprocess exited {rc_c}; tail:\n"
        f"{log_c.read_text()[-3000:]}"
    )
    assert saved_ckpt.is_dir(), (
        f"Mode C did not produce checkpoint-5/ under {modeC_ckpt_dir}"
    )

    # Stage 2: resume that checkpoint under Mode A for five more steps.
    yaml_a = workdir / "modeA_resume.yml"
    resume_cfg = _MODE_A_YAML.format(
        output_dir=str(modeA_resumed_dir),
        resume_line=f"resume_from_checkpoint: {saved_ckpt}",
        max_steps=10,
        save_steps=10,
    )
    yaml_a.write_text(resume_cfg)
    log_a = workdir / "modeA_resume.log"
    rc_a = _launch_axolotl(yaml_a, log_a, repo_root)
    log_a_text = log_a.read_text()
    log_tail = log_a_text[-3000:]
    assert rc_a == 0, (
        f"Mode A resume subprocess exited {rc_a}; tail:\n{log_tail}"
    )
    assert "Traceback" not in log_a_text, (
        f"Mode A resume produced a Traceback; tail:\n{log_tail}"
    )
    # One ``'loss':`` line is expected per resumed training step.
    assert log_a_text.count("'loss':") >= 5, (
        f"Mode A resume did not produce >= 5 step-loss lines; tail:\n"
        f"{log_tail}"
    )
def apply_lora_mlp_swiglu(self, x):  # noqa: D401 — stand-in
    """Stand-in fused MLP kernel computing ``(silu(x @ Wg^T) * (x @ Wu^T)) @ Wd^T``.

    The three weights are read straight off the child ``Linear`` modules
    rather than by calling them, so per-Linear forward hooks never run.
    """
    w_gate = self.gate_proj.weight  # [hidden, dim]
    w_up = self.up_proj.weight  # [hidden, dim]
    w_down = self.down_proj.weight  # [dim, hidden]
    # If ``w_gate.data`` is still the empty post-spill placeholder, the first
    # matmul below is where the size-mismatch crash shows up; the container
    # pre-hook must have gathered the weights before this point.
    gate_branch = torch.nn.functional.silu(x @ w_gate.t())
    up_branch = x @ w_up.t()
    return (gate_branch * up_branch) @ w_down.t()
class TinyAttn(nn.Module):
    """Minimal attention module with bias-free q/k/v/o projections.

    ``forward`` goes through ``self.apply_qkv`` and ``self.apply_o``, so
    replacing either attribute on an instance changes what ``forward`` runs.
    """

    def __init__(self, dim: int = 8):
        super().__init__()
        # Registration order (q, k, v, o) fixes both parameter order and the
        # order in which the RNG is consumed for weight init.
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, proj_name, nn.Linear(dim, dim, bias=False))

    def apply_qkv(self, x):
        """Return the ``(q, k, v)`` projections of ``x`` as a tuple."""
        return tuple(proj(x) for proj in (self.q_proj, self.k_proj, self.v_proj))

    def apply_o(self, x):
        """Apply the output projection to ``x``."""
        return self.o_proj(x)

    def forward(self, x):
        query, key, value = self.apply_qkv(x)
        # Unscaled dot-product scores, softmax over the last dimension.
        weights = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
        return self.apply_o(weights @ value)
def test_is_fused_method_recognises_all_fused_names():
    """Every ``apply_lora_*`` stand-in, once bound as a method, is flagged as fused."""
    holder = nn.Linear(2, 2)
    fused_kernels = (
        apply_lora_mlp_swiglu,
        apply_lora_qkv,
        apply_lora_o,
        apply_lora_embedding,
    )
    for kernel in fused_kernels:
        # Bind to an arbitrary module; only the bound function matters here.
        assert _is_fused_method(types.MethodType(kernel, holder)), (
            f"Detector missed fused kernel binding for {kernel.__name__}"
        )
def test_find_containers_picks_up_mixed_set():
    """With both mlp and self_attn patched, containers come back as [attn, mlp] per block."""
    model = TinyModel(n_blocks=2)
    mlps = _patch_mlp_swiglu(model)
    attns = _patch_attn_qkv_o(model)
    found = _find_fused_kernel_containers(model)
    # ``TinyBlock.__init__`` registers self_attn before mlp, so walking the
    # module tree yields the attention container first within each block.
    expected_ordered = [
        container
        for attn_mlp_pair in zip(attns, mlps, strict=True)
        for container in attn_mlp_pair
    ]
    assert found == expected_ordered, (
        f"expected interleaved [attn, mlp] x n_blocks, got {found!r}"
    )
+ got = model(x) + assert torch.allclose(got, expected, atol=0, rtol=0) + + +def test_container_pregather_fires_for_qkv_and_o(): + """Both apply_qkv and apply_o entrypoints see real weights inside the patched attn forward.""" + torch.manual_seed(1) + model = TinyModel(n_blocks=1, dim=8, hidden=16) + _patch_attn_qkv_o(model) + + x = torch.randn(2, 8) + expected = model(x) + + mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model) + with mgr: + assert len(mgr._fused_containers) == 1 + got = model(x) + assert torch.allclose(got, expected, atol=0, rtol=0) + + +def test_pre_post_hook_count_includes_per_container_pair(): + """Container hooks add exactly one pre + one post handle per fused container.""" + model = TinyModel(n_blocks=2, dim=8, hidden=16) + _patch_mlp_swiglu(model) + _patch_attn_qkv_o(model) + + n_modules = sum(1 for _ in model.modules()) + n_containers = len(_find_fused_kernel_containers(model)) + assert n_containers == 4 # 2 self_attn + 2 mlp + + mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model) + with mgr: + # Per-module loop registers 4 handles each (forward pre/post + + # backward pre/post). Container loop adds another 4 handles per + # container (forward pre/post + backward pre/post — backward is + # required because the fused autograd Function keeps base-weight + # refs on ctx outside the saved-tensors spill path). 
def test_post_release_clears_data_after_container_forward():
    """Once the model forward has returned, every parameter's ``.data`` is empty again."""
    torch.manual_seed(2)
    model = TinyModel(n_blocks=1, dim=8, hidden=16)
    _patch_mlp_swiglu(model)
    _patch_attn_qkv_o(model)

    sample = torch.randn(2, 8)
    with OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model):
        model(sample)
        # Still inside the manager but past the model call: the release
        # hooks are expected to have put the empty placeholder back on
        # every spilled parameter.
        for param_name, param in model.named_parameters():
            remaining = param.data.numel()
            assert remaining == 0, (
                f"param {param_name} not released after forward: numel={param.data.numel()}"
            )
def test_container_backward_under_fake_fused_autograd_function():
    """Backward subtree hook must re-gather weights when fused ctx keeps them outside save_for_backward.

    A reference forward/backward is run without the manager, then repeated
    under ``OnDemandTensorMgr``; parameter gradients from both runs must
    agree to ``atol=1e-6``.
    """

    class FakeFusedMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight):
            # Save x via the standard path (covered by pack/unpack); keep
            # weight as a plain Python attribute (the LoRA_MLP pattern).
            ctx.save_for_backward(x)
            ctx.weight = weight  # outside save_for_backward — needs gather
            return x @ weight.t()

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            weight = ctx.weight
            # This matmul is what blows up with vec(0) when weight.data
            # was cleared by the forward post-release. Same shape match
            # as ``LoRA_MLP.backward``'s ``matmul_lora`` step.
            grad_x = grad_output @ weight
            grad_w = grad_output.t() @ x
            return grad_x, grad_w

    class FakeFusedMLP(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim, bias=False)

    def fused_forward(self, x):
        return FakeFusedMatmul.apply(x, self.proj.weight)

    # Detection is name-based, so the autograd-backed forward is given the
    # fused kernel's name once, up front. The name has no effect on the
    # reference run below, which happens outside the manager.
    fused_forward.__name__ = "apply_lora_mlp_swiglu"  # match the detector

    class FakeBlock(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.mlp = FakeFusedMLP(dim)

        def forward(self, x):
            return self.mlp(x)

    class FakeModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([FakeBlock(8)])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    torch.manual_seed(7)
    model = FakeModel()
    # Single bind: the same bound method serves both the reference run and
    # the managed run, so no re-patching is needed in between.
    for layer in model.layers:
        layer.mlp.forward = types.MethodType(fused_forward, layer.mlp)

    x = torch.randn(2, 8, requires_grad=True)
    # Reference: forward + backward without the manager.
    y_ref = model(x)
    loss_ref = y_ref.sum()
    loss_ref.backward()
    grad_ref = {name: p.grad.detach().clone() for name, p in model.named_parameters()}
    model.zero_grad(set_to_none=True)
    x.grad = None

    mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model)
    with mgr:
        assert len(mgr._fused_containers) == 1
        y = model(x)
        loss = y.sum()
        # Backward subtree hook re-gathers weights; absent it, autograd's bwd matmul against the post-release placeholder raises vec(0) size mismatch.
        loss.backward()

    # Param grads must match the un-spilled reference (within fp32 tol).
    for name, p in model.named_parameters():
        assert p.grad is not None, f"missing grad on {name}"
        assert torch.allclose(p.grad, grad_ref[name], atol=1e-6), (
            f"grad on {name} differs under backward subtree hook path: "
            f"max_diff={(p.grad - grad_ref[name]).abs().max().item():.3e}"
        )
def _stub_chunk_manager(layout: ChunkLayout, per_chunk_bytes: int) -> SimpleNamespace:
    """Build a stand-in exposing ``.model.named_parameters()`` for every param id in ``layout.chunks``.

    Each parameter is an fp32 tensor on the ``meta`` device sized to hold at
    least ``per_chunk_bytes`` bytes (minimum one element), so no real memory
    is allocated for it.
    """

    def _meta_param() -> nn.Parameter:
        # fp32 = 4 bytes/element; round up so numel * 4 >= per_chunk_bytes.
        element_count = max(1, (per_chunk_bytes + 3) // 4)
        storage = torch.empty(element_count, dtype=torch.float32, device="meta")
        return nn.Parameter(storage, requires_grad=False)

    named: list[tuple[str, nn.Parameter]] = [
        (str(param_id), _meta_param())
        for chunk_ids in layout.chunks
        for param_id in chunk_ids
    ]

    def named_parameters():
        # Fresh iterator on every call, always over the same list.
        return iter(named)

    return SimpleNamespace(model=SimpleNamespace(named_parameters=named_parameters))
_make_layout_with_chunk_bytes( + sum_chunk_bytes=actual_sum_bytes, n_chunk=n_chunk, s_chunk=s_chunk + ) + chunk_manager = _stub_chunk_manager(layout, per_chunk_bytes) + # bpe=0.5 = bnb-4-bit Params4bit (the audit's actual dtype). + hw = _hw_profile(bpe=0.5) + + predicted_bytes = predict_init_transient_peak_bytes(layout, hw, chunk_manager) + predicted_gib = predicted_bytes / (1 << 30) + measured_gib = AUDIT_ITER1_PEAK_GIB + + residual = abs(predicted_gib - measured_gib) / measured_gib + assert residual <= 0.10, ( + f"iter-1 transient prediction must land within 10% of the " + f"audit-measured peak; got prediction={predicted_gib:.2f} GiB, " + f"measured={measured_gib:.2f} GiB, residual={residual * 100:.1f}%" + ) + + # And on the specific empirical anchor: 15.27 GiB x 1.10 = 16.80 GiB, + # which should match within tens of MiB (per-chunk byte-rounding + + # the actual int * float multiply at the prediction site). + expected_anchor_gib = AUDIT_30B_4BIT_SUM_CHUNK_GIB * ALPHA_FRAGMENTATION + assert predicted_gib == pytest.approx(expected_anchor_gib, rel=0.005), ( + f"prediction should anchor at sum_chunk_bytes x 1.10 = " + f"{expected_anchor_gib:.2f} GiB; got {predicted_gib:.2f} GiB" + ) + + +def test_fp16_30b_dense_predicts_full_residence_at_alpha_1_10(): + """fp16 30B dense layout: iter-1 alpha is dtype-agnostic, bpe=2.0 and bpe=0.5 yield identical predictions.""" + # 60 GiB raw model — Llama-30B at fp16 is ~60 GiB params. 
def test_falls_back_to_layout_upper_bound_without_chunk_manager():
    """Without a chunk_manager argument the prediction equals ``int(N_chunk * S_chunk * alpha)``."""
    chunk_count = 100
    chunk_size = 1 << 26  # 64 MiB
    layout = _make_layout_with_chunk_bytes(
        sum_chunk_bytes=0,  # unused: no chunk_manager
        n_chunk=chunk_count,
        s_chunk=chunk_size,
    )
    hw = _hw_profile(bpe=0.5)

    pred = predict_init_transient_peak_bytes(layout, hw)
    expected = int(chunk_count * chunk_size * ALPHA_FRAGMENTATION)
    assert pred == expected, (
        f"fallback path: expected {expected} bytes (N_chunk * S_chunk * alpha), got {pred}"
    )
"""Legacy SearchResult constructions without predicted_init_transient_peak_bytes must default to the 0 sentinel.""" + from axolotl.integrations.protrain.types import ( + BlockMode, + BlockStrategyMap, + CostConfig, + SearchResult, + ) + + block_map: BlockStrategyMap = {BlockId(0): BlockMode.NONE} + sr = SearchResult( + cfg=CostConfig(n_persist=0, n_buffer=1, n_swap=0, n_checkpoint=0), + block_map=block_map, + predicted_peak_bytes=1 << 30, + predicted_iter_s=0.5, + ) + assert sr.predicted_init_transient_peak_bytes == 0 + + +def test_chunk_manager_with_empty_named_parameters_falls_back(): + """Stub chunk_manager with no param overlap must fall back to the N_chunk * S_chunk upper bound, not emit 0.""" + n_chunk = 50 + s_chunk = 1 << 26 + layout = _make_layout_with_chunk_bytes( + sum_chunk_bytes=0, n_chunk=n_chunk, s_chunk=s_chunk + ) + # Empty named_parameters() → _chunk_bytes returns all-zero dict. + cm = SimpleNamespace( + model=SimpleNamespace(named_parameters=lambda: iter([])), + ) + pred = predict_init_transient_peak_bytes(layout, _hw_profile(bpe=0.5), cm) + expected_upper_bound = int(n_chunk * s_chunk * ALPHA_FRAGMENTATION) + assert pred == expected_upper_bound, ( + f"empty chunk_manager should fall back to upper bound " + f"{expected_upper_bound}, got {pred}" + ) diff --git a/tests/protrain/test_late_nccl_search_skip.py b/tests/protrain/test_late_nccl_search_skip.py new file mode 100644 index 0000000000..5ff841323c --- /dev/null +++ b/tests/protrain/test_late_nccl_search_skip.py @@ -0,0 +1,302 @@ +"""Late NCCL re-search must short-circuit when all four override knobs pin the bootstrap plan, avoiding cfg_changed RuntimeError.""" + +from __future__ import annotations + +from typing import cast +from unittest.mock import patch + +import pytest + +from axolotl.integrations.protrain.profiler.cache import ProfilerCacheKey +from axolotl.integrations.protrain.types import ( + BlockId, + BlockMode, + BlockStrategyMap, + ChunkLayout, + CostConfig, + HardwareProfile, + 
def _make_trace(*, world: int = 1) -> ProfilerTrace:
    """Return a one-op ``ProfilerTrace`` stub whose NCCL tables are empty.

    The empty ``nccl_gather_s`` / ``nccl_reduce_s`` dicts match the trace the
    override-skip path synthesizes. ``world`` is forwarded unchanged into the
    trace's ``world`` field.
    """
    op_zero = cast(OpId, 0)
    block_zero = cast(BlockId, 0)

    only_op = OpRecord(
        op_id=op_zero,
        module_path="layer0",
        qualified_name="aten::linear",
        shape_signature=((1, 4),),
        block_id=block_zero,
        is_forward=True,
    )

    return ProfilerTrace(
        op_order=(only_op,),
        intra_op_delta={op_zero: 0},
        inter_op_delta={op_zero: 0},
        activation_sizes={block_zero: 1024},
        model_state_bytes=1024,
        pcie_h2d_bps=10e9,
        pcie_d2h_bps=10e9,
        nccl_gather_s={},
        nccl_reduce_s={},
        arch_hash="deadbeef",
        bs=1,
        seq=128,
        sku="MockGPU",
        world=world,
    )
def _patch_dist(*, initialized: bool, world_size: int = 4) -> list:
    """Build unstarted patchers that stub three ``torch.distributed`` queries.

    Args:
        initialized: Value the patched ``is_initialized()`` will return.
        world_size: Value the patched ``get_world_size()`` will return.

    Returns:
        A list of three ``patch.object`` patchers, in the order
        ``is_available`` (always ``True``), ``is_initialized``,
        ``get_world_size``. The caller is responsible for starting and
        stopping them.
    """
    import torch.distributed as dist

    # Insertion order of this mapping fixes the order of the returned patchers.
    stubbed_returns = {
        "is_available": True,
        "is_initialized": initialized,
        "get_world_size": world_size,
    }
    return [
        patch.object(dist, attr_name, return_value=value)
        for attr_name, value in stubbed_returns.items()
    ]
patches = _patch_dist(initialized=True, world_size=4) + [ + patch( + "axolotl.integrations.protrain.profiler.measure_nccl", + side_effect=fake_measure, + ), + patch( + "axolotl.integrations.protrain.search.search", + side_effect=fake_search, + ), + ] + for p in patches: + p.start() + try: + updated, changed = plugin_mod._remeasure_nccl_and_research(wrapped) + finally: + for p in patches: + p.stop() + + # Helper returned the no-op signal. + assert (updated, changed) == (False, False) + + # Crucially: neither measurement nor search ran. + assert measure_calls == [], ( + f"measure_nccl was called {measure_calls} times on the override-skip " + "path; the gate should short-circuit before the measurement." + ) + assert search_calls == [], ( + f"search.search was called {len(search_calls)} times on the override-" + "skip path; the gate should short-circuit before the re-run." + ) + + # Trace and search_result untouched (still the bootstrap synthesis). + assert wrapped.search_result is orig_search_result + assert wrapped._trace is orig_trace # type: ignore[attr-defined] + assert wrapped._trace.nccl_gather_s == {} # type: ignore[attr-defined] + assert wrapped._trace.nccl_reduce_s == {} # type: ignore[attr-defined] + # post_nccl_search_result must NOT have been stashed (no late search ran). 
def test_late_search_runs_when_overrides_not_set(tmp_path, monkeypatch):
    """Control: with ``_override_skip_trace=False`` both measure_nccl and search.search fire.

    ``measure_nccl`` and ``search.search`` are replaced by recording fakes and
    ``torch.distributed`` is stubbed to report an initialized 4-rank group.
    The helper must call each fake exactly once, install the measured NCCL
    tables on the wrapped trace, and return ``(True, False)``.
    """
    from contextlib import ExitStack

    pytest.importorskip("torch")

    # NOTE(review): presumably keeps any on-disk cache the helper writes inside
    # the per-test tmp dir -- confirm against the plugin's cache-path logic.
    monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))

    from axolotl.integrations.protrain import plugin as plugin_mod

    wrapped = _make_wrapped(with_override_flag=False)

    fake_gather = {1 << 20: 0.0023}
    fake_reduce = {1 << 20: 0.0019}

    measure_calls: list[int] = []
    search_calls: list[ProfilerTrace] = []

    def fake_measure(world_size: int):
        measure_calls.append(world_size)
        return fake_gather, fake_reduce

    def fake_search(trace, layout, capacity_bytes, hw, cpu_capacity_bytes=None):
        search_calls.append(trace)
        # Return the SAME cfg so cfg_changed=False (no fail-fast raise).
        return _make_search_result()

    patches = _patch_dist(initialized=True, world_size=4) + [
        patch(
            "axolotl.integrations.protrain.profiler.measure_nccl",
            side_effect=fake_measure,
        ),
        patch(
            "axolotl.integrations.protrain.search.search",
            side_effect=fake_search,
        ),
    ]
    # ExitStack unwinds every patch that was successfully entered, even when a
    # later patch fails to start; a bare start()-loop placed before the ``try``
    # would leave the earlier patches active for the rest of the session.
    with ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        updated, changed = plugin_mod._remeasure_nccl_and_research(wrapped)

    # Both fired exactly once.
    assert measure_calls == [4], (
        f"measure_nccl call list {measure_calls} mismatched expected [4] on "
        "the non-override searcher path"
    )
    assert len(search_calls) == 1, (
        f"search.search ran {len(search_calls)} times; expected 1 on the "
        "non-override searcher path"
    )

    # Trace got the new tables; search_result swapped (same cfg, refreshed).
    assert (updated, changed) == (True, False)
    assert wrapped._trace.nccl_gather_s == fake_gather  # type: ignore[attr-defined]
    assert wrapped._trace.nccl_reduce_s == fake_reduce  # type: ignore[attr-defined]
class FakeLoraLayer(nn.Module):
    """Stand-in for a PEFT ``LoraLayer``: frozen base linear plus trainable low-rank factors.

    No PEFT install is needed; the parameters are simply stored under the
    attribute names ``lora_A`` / ``lora_B`` so that those substrings appear in
    ``named_parameters()``.
    """

    def __init__(self, in_features: int, out_features: int, r: int = 4) -> None:
        super().__init__()
        # The base projection is built first (so it consumes RNG first) and
        # is frozen as a whole.
        frozen_base = nn.Linear(in_features, out_features, bias=False)
        frozen_base.requires_grad_(False)
        self.base_layer = frozen_base
        # ParameterDict layout keyed by "default": ``lora_A["default"]`` is
        # the trainable ``[r, in_features]`` matrix, ``lora_B["default"]`` the
        # zero-initialised ``[out_features, r]`` matrix.
        down_proj = nn.Parameter(torch.randn(r, in_features))
        self.lora_A = nn.ParameterDict({"default": down_proj})
        up_proj = nn.Parameter(torch.zeros(out_features, r))
        self.lora_B = nn.ParameterDict({"default": up_proj})

    def forward(self, x):
        base_out = self.base_layer(x)
        # The factors are read directly off the ParameterDicts rather than
        # through a child module's forward, so a per-Linear gather hook never
        # sees them; the tests in this module rely on a container-level hook
        # for that reason.
        down = self.lora_A["default"]
        up = self.lora_B["default"]
        low_rank = x @ down.t()
        return base_out + low_rank @ up.t()
def test_has_peft_lora_factor_rejects_frozen_lora():
    """A layer whose ``lora_A`` / ``lora_B`` factors are all frozen is not detected.

    The detector only targets trainable PEFT factors; frozen ones don't need
    a container hook.
    """
    layer = FakeLoraLayer(4, 4, r=2)
    # Freeze lora_A first, then lora_B.
    for factor_dict in (layer.lora_A, layer.lora_B):
        for factor in factor_dict.parameters():
            factor.requires_grad_(False)

    detected = _has_peft_lora_factor(layer)
    assert not detected
def test_find_peft_lora_containers_skips_fused_overlap():
    """When one module matches both detectors, only the fused detector reports it.

    A duplicate PEFT hook on the same module would stack gather ref-counts, so
    the PEFT detector is expected to defer.
    """
    import types

    from tests.protrain.test_fused_lora_kernels import (
        TinyModel,
        _patch_attn_qkv_o,
        apply_lora_mlp_swiglu,
    )

    model = TinyModel(n_blocks=1, dim=8, hidden=16)
    mlp = model.layers[0].mlp

    # Step 1: bind the fused forward onto the MLP so the fused detector matches.
    mlp.forward = types.MethodType(apply_lora_mlp_swiglu, mlp)
    # Step 2: plant trainable LoRA-shaped ParameterDicts on that same MLP so
    # the PEFT detector would match it too.
    mlp.lora_A = nn.ParameterDict({"default": nn.Parameter(torch.randn(2, 8))})
    mlp.lora_B = nn.ParameterDict({"default": nn.Parameter(torch.zeros(16, 2))})

    fused = _find_fused_kernel_containers(model)
    peft = _find_peft_lora_containers(model)

    assert mlp in fused
    assert mlp not in peft, (
        "PEFT detector must defer to the fused detector when both match"
    )
    # Smoke check only: the helper must be importable from the sibling module.
    assert _patch_attn_qkv_o is not None
def test_lora_container_backward_succeeds_under_spill():
    """Backward through PEFT-LoRA layers under the spill manager reproduces the un-spilled grads.

    The reference grads come from a plain forward+backward; the same pass is
    then repeated inside ``OnDemandTensorMgr`` and every trainable parameter
    must end up with a finite grad that matches the reference.
    """
    torch.manual_seed(1)
    model = TinyPeftModel(n_blocks=2, dim=8)

    x = torch.randn(2, 8, requires_grad=False)
    target = torch.zeros(2, 8)

    # Reference pass (no manager): snapshot the grads before clearing them.
    out_ref = model(x)
    loss_ref = (out_ref - target).pow(2).mean()
    loss_ref.backward()
    grad_ref = {}
    for ref_name, ref_param in model.named_parameters():
        if ref_param.grad is None:
            continue
        grad_ref[ref_name] = ref_param.grad.detach().clone()
    model.zero_grad(set_to_none=True)

    # Hooked pass: identical forward + backward, now inside the manager.
    mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model)
    with mgr:
        assert len(mgr._peft_lora_containers) == 2
        out = model(x)
        loss = (out - target).pow(2).mean()
        # Per the regression this test pins, a missing container backward
        # hook surfaces here as an invalid-gradient-[0] error.
        loss.backward()

    # Grad presence is the core assertion; numerical agreement with the
    # reference is checked on top of it.
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    for name, p in trainable:
        assert p.grad is not None, f"missing grad on {name} after hooked backward"
        assert torch.isfinite(p.grad).all(), f"non-finite grad on {name}"
        assert torch.allclose(p.grad, grad_ref[name], atol=1e-6), (
            f"grad on {name} differs under hook path: "
            f"max_diff={(p.grad - grad_ref[name]).abs().max().item():.3e}"
        )
def test_e2e_5_steps_lora_under_on_demand():
    """Run five forward+backward iterations of a tiny PEFT-LoRA model under the spill manager.

    No optimizer step is taken; each iteration clears grads, runs
    forward+backward, and records the loss and the largest grad magnitude
    over the trainable parameters.
    """
    n_iters = 5
    torch.manual_seed(3)
    model = TinyPeftModel(n_blocks=2, dim=16)

    x = torch.randn(4, 16)
    target = torch.zeros(4, 16)

    trainable = [p for p in model.parameters() if p.requires_grad]
    assert trainable, "no trainable params — LoRA wrap didn't take"

    losses: list[float] = []
    grad_max_per_iter: list[float] = []
    mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model)
    with mgr:
        for _step in range(n_iters):
            model.zero_grad(set_to_none=True)
            out = model(x)
            loss = (out - target).pow(2).mean()
            # Per the regression this test pins, a missing container backward
            # hook fails on the very first iteration with a ToCopyBackward0
            # invalid-gradient error.
            loss.backward()
            losses.append(float(loss.detach()))

            # Largest grad magnitude across trainable params: a silently
            # failed backward would leave grads at None or all-zero.
            peak = 0.0
            for param in trainable:
                grad = param.grad
                if grad is None:
                    continue
                peak = max(peak, float(grad.abs().max()))
            grad_max_per_iter.append(peak)

    assert len(losses) == n_iters
    assert all(math.isfinite(v) for v in losses), f"non-finite loss: {losses}"
    # Every iteration produced finite, non-zero grads.
    assert all(g > 0.0 and math.isfinite(g) for g in grad_max_per_iter), (
        f"grads vanished or non-finite under hook path: {grad_max_per_iter}"
    )
@pytest.mark.parametrize("n_blocks", [1, 4])
def test_lora_repeated_forward_under_manager(n_blocks):
    """Three consecutive forwards under the manager each reproduce the un-spilled output exactly."""
    n_repeats = 3
    torch.manual_seed(5)
    model = TinyPeftModel(n_blocks=n_blocks, dim=8)
    x = torch.randn(2, 8)
    # Reference output, computed before the manager is entered.
    expected = model(x)

    mgr = OnDemandTensorMgr(device=torch.device("cpu"), disabled=False, model=model)
    with mgr:
        for _ in range(n_repeats):
            replay = model(x)
            # Zero tolerances: the comparison is exact.
            assert torch.allclose(replay, expected, atol=0, rtol=0)
class _RecordingScheduler:
    """Scheduler stub that appends every hook callback to ``self.calls``.

    Each entry is ``(call_kind, ids)``: block-level callbacks record the
    method name and a 1-tuple holding the block id, while
    ``ensure_chunks_resident`` records the chunk ids and, where the caller
    can be identified, which LoRA-container hook edge triggered it.
    """

    # (substring of the calling hook factory's qualname, edge label), scanned
    # in order; the first match wins.
    _EDGE_BY_FACTORY = (
        ("_make_lora_container_pre_forward_hook", "pre_forward"),
        ("_make_lora_container_post_forward_hook", "post_forward"),
        ("_make_lora_container_pre_backward_hook", "pre_backward"),
        ("_make_lora_container_post_backward_hook", "post_backward"),
    )

    def __init__(self) -> None:
        # Chronological log of (call_kind, tuple_of_ids) entries.
        self.calls: list[tuple[str, tuple]] = []

    def _log_block(self, kind: str, block_id) -> None:
        """Append one block-level entry, coercing the id to ``int``."""
        self.calls.append((kind, (int(block_id),)))

    def pre_block_forward(self, block_id) -> None:
        self._log_block("pre_block_forward", block_id)

    def post_block_forward(self, block_id) -> None:
        self._log_block("post_block_forward", block_id)

    def pre_block_backward(self, block_id) -> None:
        self._log_block("pre_block_backward", block_id)

    def post_block_backward(self, block_id) -> None:
        self._log_block("post_block_backward", block_id)

    def ensure_block_resident(self, block_id) -> None:
        self._log_block("ensure_block_resident", block_id)

    def ensure_chunks_resident(self, chunk_ids) -> None:
        """Record a chunk-residency request, tagged with the calling hook edge when known.

        The tag is derived from the immediate caller's code-object qualname.
        If that lookup raises ``AttributeError`` or ``ValueError``, or no
        known factory name matches, the plain ``"ensure_chunks_resident"``
        label is used.
        """
        import sys

        # Frame depth 1 is the direct caller of this method; the lookup must
        # stay inline here so that depth keeps pointing at the hook closure.
        try:
            caller_qualname = sys._getframe(1).f_code.co_qualname
        except (AttributeError, ValueError):  # pragma: no cover
            caller_qualname = ""

        matched_edge = next(
            (
                edge
                for needle, edge in self._EDGE_BY_FACTORY
                if needle in caller_qualname
            ),
            None,
        )
        if matched_edge is None:
            label = "ensure_chunks_resident"
        else:
            label = f"ensure_chunks_resident:{matched_edge}"
        self.calls.append((label, tuple(int(c) for c in chunk_ids)))
_cast(_ParamId, name): p for name, p in model.named_parameters() + } + + +def test_install_hooks_attaches_lora_container_pre_hooks_cpu(): + """install_hooks adds 4-hook quartets per block AND per PEFT-LoRA container (fwd+bwd pre+post).""" + from axolotl.integrations.protrain.runtime.hooks import install_hooks + from axolotl.integrations.protrain.types import ( + BlockId as _BlockId, + BlockMode as _BlockMode, + ) + + torch.manual_seed(7) + n_blocks = 3 + model = _TinyAttnPeftModel(n_blocks=n_blocks, dim=8) + + layout = _build_runtime_chunk_layout(model, S_chunk=4096) + cm = _RecordingChunkManagerStub(model, layout) + sched = _RecordingScheduler() + + block_map = {_BlockId(i): _BlockMode.NONE for i in range(n_blocks)} + + handles = install_hooks( + model=model, + chunk_manager=cm, # type: ignore[arg-type] + block_map=block_map, + scheduler=sched, # type: ignore[arg-type] + ) + try: + # Per-block: 4 hooks (fwd pre/post + bwd pre/post). Per LoRA container: also 4 hooks. + n_containers = len(_find_peft_lora_containers(model)) + assert n_containers == n_blocks # one FakeLoraLayer per block + expected = 4 * n_blocks + 4 * n_containers + assert len(handles) == expected, ( + f"hook count mismatch: got {len(handles)} expected {expected} " + f"(blocks={n_blocks}, containers={n_containers})" + ) + finally: + # Best-effort removal per-handle so one failure does not skip the rest. 
def test_install_hooks_lora_container_chunk_ids_cover_lora_factors():
    """Each LoRA container's chunk-id closure covers every trainable LoRA factor under it.

    For every detected container, ``_container_chunk_ids`` must return a
    non-empty set, and each trainable parameter reachable from the container
    that the chunk-manager stub knows about must map (via
    ``layout.param_to_chunk``) to a chunk id inside that set.
    """
    from axolotl.integrations.protrain.runtime.hooks import _container_chunk_ids

    torch.manual_seed(8)
    n_blocks = 2
    model = _TinyAttnPeftModel(n_blocks=n_blocks, dim=8)

    layout = _build_runtime_chunk_layout(model, S_chunk=4096)
    cm = _RecordingChunkManagerStub(model, layout)

    containers = _find_peft_lora_containers(model)
    assert len(containers) == n_blocks

    # Identity -> registered-name map; it depends only on ``cm``, so it is
    # built once rather than once per container.
    cm_id_to_name = {id(p): name for name, p in cm._params_by_id.items()}

    for container in containers:
        cids = _container_chunk_ids(container, cm)  # type: ignore[arg-type]
        assert cids, f"container {container} produced empty chunk-id set"

        # Count the factors actually verified so the skip branches below
        # cannot turn the invariant check into a vacuous pass.
        n_checked = 0
        for p in container.parameters(recurse=True):
            if not p.requires_grad:
                continue
            cm_name = cm_id_to_name.get(id(p))
            if cm_name is None:
                continue
            cid = layout.param_to_chunk.get(cm_name)
            assert cid in cids, (
                f"trainable param {cm_name} (chunk {cid}) not in container's "
                f"captured chunk-id set {cids}"
            )
            n_checked += 1
        assert n_checked > 0, (
            f"container {container} exposed no trainable param known to the "
            "chunk manager; coverage check would be vacuous"
        )
_BlockMode.NONE for i in range(n_blocks)} + + handles = install_hooks( + model=model, + chunk_manager=cm, # type: ignore[arg-type] + block_map=block_map, + scheduler=sched, # type: ignore[arg-type] + ) + try: + x = torch.randn(2, 8) + _ = model(x) + + # Filter on edge-tagged label so deletion of pre-forward (while post-forward stays) fails. + pre_fwd_calls = [ + c for c in sched.calls if c[0] == "ensure_chunks_resident:pre_forward" + ] + assert len(pre_fwd_calls) >= n_blocks, ( + f"expected at least {n_blocks} ensure_chunks_resident:pre_forward " + f"calls (one per container), got {len(pre_fwd_calls)} " + f"(all calls: {sched.calls})" + ) + for _kind, cids in pre_fwd_calls: + assert cids, "ensure_chunks_resident:pre_forward invoked with empty tuple" + finally: + # Best-effort removal per-handle so one failure does not skip the rest. + for h in handles: + with contextlib.suppress(Exception): + h.remove() + + +def test_install_hooks_lora_container_post_forward_fires_ensure_chunks_resident(): + """post-forward hook fires ensure_chunks_resident on each LoRA container (defense-in-depth re-bind).""" + from axolotl.integrations.protrain.runtime.hooks import install_hooks + from axolotl.integrations.protrain.types import ( + BlockId as _BlockId, + BlockMode as _BlockMode, + ) + + torch.manual_seed(11) + n_blocks = 2 + model = _TinyAttnPeftModel(n_blocks=n_blocks, dim=8) + + layout = _build_runtime_chunk_layout(model, S_chunk=4096) + cm = _RecordingChunkManagerStub(model, layout) + sched = _RecordingScheduler() + block_map = {_BlockId(i): _BlockMode.NONE for i in range(n_blocks)} + + handles = install_hooks( + model=model, + chunk_manager=cm, # type: ignore[arg-type] + block_map=block_map, + scheduler=sched, # type: ignore[arg-type] + ) + try: + x = torch.randn(2, 8) + _ = model(x) + + # Assert BOTH edges fired independently so dropping either is caught. 
+ n_containers = n_blocks # one FakeLoraLayer per block + pre_fwd_calls = [ + c for c in sched.calls if c[0] == "ensure_chunks_resident:pre_forward" + ] + post_fwd_calls = [ + c for c in sched.calls if c[0] == "ensure_chunks_resident:post_forward" + ] + assert len(pre_fwd_calls) >= n_containers, ( + f"expected at least {n_containers} ensure_chunks_resident:pre_forward " + f"calls (one per container per forward pass), got " + f"{len(pre_fwd_calls)} (all calls: {sched.calls})" + ) + assert len(post_fwd_calls) >= n_containers, ( + f"expected at least {n_containers} ensure_chunks_resident:post_forward " + f"calls (one per container per forward pass), got " + f"{len(post_fwd_calls)} (all calls: {sched.calls})" + ) + finally: + # Best-effort removal per-handle so one failure does not skip the rest. + for h in handles: + with contextlib.suppress(Exception): + h.remove() + + +def test_install_hooks_lora_container_post_backward_fires_ensure_chunks_resident(): + """post-backward hook fires ensure_chunks_resident; pins all 4 hook-quartet edges over fwd+bwd.""" + from axolotl.integrations.protrain.runtime.hooks import install_hooks + from axolotl.integrations.protrain.types import ( + BlockId as _BlockId, + BlockMode as _BlockMode, + ) + + torch.manual_seed(12) + n_blocks = 2 + model = _TinyAttnPeftModel(n_blocks=n_blocks, dim=8) + + layout = _build_runtime_chunk_layout(model, S_chunk=4096) + cm = _RecordingChunkManagerStub(model, layout) + sched = _RecordingScheduler() + block_map = {_BlockId(i): _BlockMode.NONE for i in range(n_blocks)} + + handles = install_hooks( + model=model, + chunk_manager=cm, # type: ignore[arg-type] + block_map=block_map, + scheduler=sched, # type: ignore[arg-type] + ) + try: + x = torch.randn(2, 8, requires_grad=False) + target = torch.zeros(2, 8) + out = model(x) + loss = (out - target).pow(2).mean() + loss.backward() + + n_containers = n_blocks + # Assert all four quartet edges fired so dropping any single edge is caught. 
+ per_edge_calls = { + edge: [c for c in sched.calls if c[0] == f"ensure_chunks_resident:{edge}"] + for edge in ( + "pre_forward", + "post_forward", + "pre_backward", + "post_backward", + ) + } + for edge, calls in per_edge_calls.items(): + assert len(calls) >= n_containers, ( + f"expected at least {n_containers} " + f"ensure_chunks_resident:{edge} calls (one per container " + f"per fwd/bwd window), got {len(calls)}. " + f"per-edge counts: " + f"{ {e: len(c) for e, c in per_edge_calls.items()} } " + f"(all calls: {sched.calls})" + ) + finally: + # Best-effort removal per-handle so one failure does not skip the rest. + for h in handles: + with contextlib.suppress(Exception): + h.remove() + + +def test_install_hooks_no_lora_no_container_hooks(): + """Non-LoRA model gets only block-quartet hooks; container walk does not raise.""" + from axolotl.integrations.protrain.runtime.hooks import install_hooks + from axolotl.integrations.protrain.types import ( + BlockId as _BlockId, + BlockMode as _BlockMode, + ) + + class _PlainAttnBlock(nn.Module): + def __init__(self, dim): + super().__init__() + # Expose ``self_attn`` so discover_blocks' attention + # heuristic identifies the enclosing ModuleList as a + # block list (mirrors _AttnLikeBlock). 
+ self.self_attn = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + return self.self_attn(x) + + class _PlainModel(nn.Module): + def __init__(self, n: int, dim: int) -> None: + super().__init__() + self.layers = nn.ModuleList([_PlainAttnBlock(dim) for _ in range(n)]) + + def forward(self, x): + for b in self.layers: + x = b(x) + return x + + n_blocks = 2 + model = _PlainModel(n_blocks, dim=4) + layout = _build_runtime_chunk_layout(model, S_chunk=4096) + cm = _RecordingChunkManagerStub(model, layout) + sched = _RecordingScheduler() + block_map = {_BlockId(i): _BlockMode.NONE for i in range(n_blocks)} + + handles = install_hooks( + model=model, + chunk_manager=cm, # type: ignore[arg-type] + block_map=block_map, + scheduler=sched, # type: ignore[arg-type] + ) + try: + # 4 per block, 0 per container. + assert len(handles) == 4 * n_blocks + finally: + # Best-effort removal per-handle so one failure does not skip the rest. + for h in handles: + with contextlib.suppress(Exception): + h.remove() + + +# --------------------------------------------------------------------------- +# Real-runtime end-to-end (GPU-gated): exercise the full +# ChunkManager + Scheduler stack against a tiny PEFT-LoRA model and +# confirm the LoRA forward + backward succeed under offload mode. +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu +def test_runtime_lora_e2e_under_offload_mode_smoke(): + """Pins PEFT-LoRA fwd+bwd through real ChunkManager+Scheduler under non-persistent chunks.""" + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + # Probe DeepSpeedCPUAdam availability so we can run the fwd+bwd validation + # even on degraded CPU-Adam environments (tolerating the offload-step skip). 
+ cpu_adam_available = False + try: + import deepspeed # noqa: F401 + from deepspeed.ops.adam import DeepSpeedCPUAdam + + # Probe the JIT-loaded extension by attempting one construction; + # CUDA/torch toolchain mismatch surfaces here. + _probe = torch.nn.Parameter(torch.zeros(2, dtype=torch.float32)) + try: + DeepSpeedCPUAdam([_probe], lr=1e-3) + cpu_adam_available = True + except Exception: # noqa: BLE001 + cpu_adam_available = False + except ImportError: + cpu_adam_available = False + + pytest.importorskip("peft") + pytest.importorskip("transformers") + + from peft import LoraConfig, get_peft_model + from transformers import LlamaConfig, LlamaForCausalLM + + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + # Sized so build_layout produces enough chunks that LoRA factors + # land in non-persistent chunks (mandatory_persistent only covers + # embed / final-norm). + cfg = LlamaConfig( + hidden_size=512, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=8, + intermediate_size=1024, + vocab_size=1024, + max_position_embeddings=64, + rms_norm_eps=1e-5, + use_cache=False, + ) + torch.manual_seed(13) + base_model = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16, device="cuda") + lora_cfg = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + model = get_peft_model(base_model, lora_cfg).to(device="cuda") + + # Force a small S_chunk so multiple chunks emerge and LoRA + # factors land in non-persistent chunks. 
+ import axolotl.integrations.protrain.api.model_wrapper as mw + + orig_pick = mw.pick_S_chunk + mw.pick_S_chunk = lambda *a, **k: 1 << 20 # 1 MiB + try: + hw = HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + # Env-failure substrings degrade this smoke to skip; any other + # ValueError/RuntimeError surfaces as a real wrapper regression. + _wrapper_env_failure_substrings = ( + "DeepSpeedCPUAdam", # CPU Adam JIT-load failed + "CUDA version", # DeepSpeed CUDA/torch toolchain mismatch + "bitsandbytes", # bnb load issues + "No module named", # ModuleNotFoundError surface + # Searcher / capacity gates that legitimately mean + # "config not feasible on this rig", not "wrapper + # regression": + "no feasible config", + "cpu_capacity", + "capacity_bytes", + ) + + def _is_wrapper_env_failure(exc: BaseException) -> bool: + msg = str(exc) + return any(sub in msg for sub in _wrapper_env_failure_substrings) + + try: + wrapped = protrain_model_wrapper( + model, + model_config=cfg, + hardware_profile=hw, + batch_size=1, + seq_len=32, + capacity_bytes=2 * (1 << 30), + force_all_persistent=False, + zero3_shard=False, + n_persist_override=0, + n_buffer_override=16, + n_swap_override=0, + n_checkpoint_override=0, + n_offload_override=cfg.num_hidden_layers, + ) + except (ValueError, RuntimeError) as exc: + if not _is_wrapper_env_failure(exc): + # Real wrapper regression — let it surface. + raise + pytest.skip(f"protrain_model_wrapper offload setup unavailable: {exc}") + + # Env-failure substrings degrade to skip optimizer round-trip; deferred: + # narrow further once exact DeepSpeedCPUAdam/torchao/apex error strings are captured. 
+ _env_failure_substrings = ( + "DeepSpeedCPUAdam", # DeepSpeed CPU Adam JIT-load failure + "CUDA version", # DeepSpeed CUDA/torch toolchain mismatch + "bitsandbytes", # bnb load issues + "No module named", # ModuleNotFoundError surface + "missing CPU optimizer for offloaded chunk", + # The fix-3 validation signal — backward unwound past the + # LoRA bf16-cast node BEFORE the per-chunk grad hook + # raised; the message confirms the fix worked. + ) + + def _is_env_failure(exc: BaseException) -> bool: + msg = str(exc) + return any(sub in msg for sub in _env_failure_substrings) + + optim = None + if cpu_adam_available: + try: + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + except RuntimeError as exc: + # Only suppress documented env-failure signatures; real + # protrain_optimizer_wrapper regressions must surface. + if not _is_env_failure(exc): + raise + optim = None + + input_ids = torch.randint( + 0, cfg.vocab_size, (1, 32), device="cuda", dtype=torch.long + ) + labels = input_ids.clone() + # iter-0 backward must NOT raise ToCopyBackward0 invalid-gradient-[0]: + # that signals the LoRA gather-before-cast invariant was broken. + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_v = float(loss.detach()) + assert math.isfinite(loss_v), f"non-finite loss: {loss_v}" + # Tolerate "missing CPU optimizer for offloaded chunk" since backward + # already unwound past the LoRA cast node before the offload-step hook fires. + try: + loss.backward() + except RuntimeError as exc: + msg = str(exc) + if "ToCopyBackward" in msg: + pytest.fail( + f"regression: ToCopyBackward0 fired in backward — " + f"runtime LoRA gather hook did not cover the autograd " + f"shape-derivation step.\n{exc}" + ) + if "missing CPU optimizer for offloaded chunk" in msg: + pass + else: + raise + # Only suppress documented env-failure substrings; real optim.step regressions surface. 
+ if optim is not None: + try: + optim.step() + optim.zero_grad() + except (RuntimeError, ImportError) as exc: + if not _is_env_failure(exc): + raise + finally: + mw.pick_S_chunk = orig_pick diff --git a/tests/protrain/test_modec_steady_peak_accuracy.py b/tests/protrain/test_modec_steady_peak_accuracy.py new file mode 100644 index 0000000000..899bca4136 --- /dev/null +++ b/tests/protrain/test_modec_steady_peak_accuracy.py @@ -0,0 +1,220 @@ +"""bnb-4-bit Mode-C steady-peak: predictor must charge the full ckpt-chain residual sum across all CKPT blocks.""" + +from __future__ import annotations + +import pytest + +from axolotl.integrations.protrain.block.layout_rules import assign_modes +from axolotl.integrations.protrain.cost.memory import estimate_peak +from axolotl.integrations.protrain.types import ( + BlockId, + ChunkId, + ChunkLayout, + CostConfig, + HardwareProfile, + ParamId, + ProfilerTrace, +) + +GiB = 1 << 30 + + +# Llama-30B (huggyllama/llama-30b) architecture from +# ``m0_artifacts/ext_30b_seq{512,1024,2048}.yml``: +# num_hidden_layers = 60 +# hidden_size = 6656 +# intermediate_size = 17920 +# num_attention_heads = 52 +# vocab_size = 32000 +LLAMA_30B_N_BLOCK = 60 +LLAMA_30B_INTERMEDIATE = 17920 + +# Audit Mode-C cfg knobs (identical across the three seq runs; see +# ``m0_artifacts/ext_30b_seq2048.yml``): +# protrain_n_persist_override: 0 +# protrain_n_buffer_override: 12 +# protrain_n_swap_override: 0 +# protrain_n_checkpoint_override: 60 +N_PERSIST = 0 +N_BUFFER = 12 +N_SWAP = 0 +N_CHECKPOINT = 60 + +# Layout knobs observed in every log: ``layout built: S_chunk=67108864 +# N_chunk=302``. ``layout.mandatory_persistent`` was [0, 300, 301] per +# the wrapper's residency = prefix[0..0) | mandatory line — 3 chunks +# pinned by layout regardless of n_persist. +S_CHUNK = 67108864 # 64 MiB +N_CHUNK = 302 +MANDATORY_PERSISTENT_IDS = (0, 300, 301) + +# Measured steady-state peaks (GiB) from empirical 30B 4-bit Mode-C runs at three seq lengths. 
MEASURED_STEADY_GIB: dict[int, float] = {
    # sequence length -> measured steady-state peak, in GiB
    512: 2.91,
    1024: 3.50,
    2048: 4.68,
}

# Aggregate model-state size seen in the 30B QLoRA audit runs (approximate):
# the frozen 4-bit base is ≈ 15 GiB and the LoRA adapters add ≈ 1.6 GiB
# (≈ 100M adapter params x 16 bytes each: param + grad + fp32 master + m + v).
# The trace's ``_count_model_state_bytes`` records these as one aggregate.
# The cost model's ``model_state_present_bytes`` clamps
# ``persistent_factor = max(1.0, model_state_bytes / fp16_total)``, so the
# exact value only matters once it exceeds ``N_chunk * S_chunk``
# (18.875 GiB here). 16 GiB sits BELOW that threshold, so
# ``persistent_factor`` clamps to 1.0 — matching the audit logs' implicit
# assumption (the wrapper's ``peak prediction calibrated 0.00 -> 2.54 GB``
# line only makes sense at ``persistent_factor=1.0``).
MODEL_STATE_BYTES_30B_QLORA: int = 16 * GiB


def _build_layout() -> ChunkLayout:
    """Rebuild the audit's chunk layout: ``N_CHUNK`` chunks of ``S_CHUNK`` bytes.

    Each chunk owns exactly one synthetic parameter named ``p.<chunk id>``, and
    the chunk ids listed in ``MANDATORY_PERSISTENT_IDS`` are marked as
    layout-mandatory persistent.
    """
    param_ids = [ParamId(f"p.{cid}") for cid in range(N_CHUNK)]

    # One placeholder chunk per block. Per the original author's note, the
    # audit's n_offload=0 config never reads this map — estimate_peak only
    # walks trace.activation_sizes and trace.op_order.
    block_to_chunks: dict[BlockId, tuple[ChunkId, ...]] = {}
    for block_idx in range(LLAMA_30B_N_BLOCK):
        block_to_chunks[BlockId(block_idx)] = (ChunkId(block_idx % N_CHUNK),)

    return ChunkLayout(
        S_chunk=S_CHUNK,
        N_chunk=N_CHUNK,
        chunks=tuple((pid,) for pid in param_ids),
        param_to_chunk={pid: ChunkId(cid) for cid, pid in enumerate(param_ids)},
        block_to_chunks=block_to_chunks,
        mandatory_persistent=frozenset(map(ChunkId, MANDATORY_PERSISTENT_IDS)),
    )


def _build_synth_trace(seq_len: int) -> ProfilerTrace:
    """Build a synthetic trace mirroring ``synth_trace_from_overrides`` output.

    The trace has an empty ``op_order`` and gives every block the same
    activation size, using the FFN-intermediate width as the per-token proxy.

    Args:
        seq_len: Sequence length of the audit run being reconstructed.
    """
    micro_batch = 1  # audit cfg: micro_batch_size: 1
    seq = int(seq_len)
    # batch x seq x FFN-intermediate width x 2 bytes per element.
    per_block_act_bytes = micro_batch * seq * LLAMA_30B_INTERMEDIATE * 2
    block_ids = (BlockId(b) for b in range(LLAMA_30B_N_BLOCK))

    return ProfilerTrace(
        op_order=(),
        intra_op_delta={},
        inter_op_delta={},
        activation_sizes=dict.fromkeys(block_ids, per_block_act_bytes),
        model_state_bytes=MODEL_STATE_BYTES_30B_QLORA,
        pcie_h2d_bps=13e9,
        pcie_d2h_bps=13e9,
        nccl_gather_s={},
        nccl_reduce_s={},
        arch_hash="huggyllama/llama-30b-qlora-modec",
        bs=micro_batch,
        seq=seq,
        sku="NVIDIA RTX PRO 6000 Blackwell (audit)",
        world=1,
    )


def _build_hw_4bit() -> HardwareProfile:
    """Return the single-GPU hardware profile used by the 4-bit audit tests.

    ``dominant_param_bytes_per_element=0.5`` is the knob that routes
    ``estimate_peak`` through its 4-bit fragmentation-alpha branch.
    """
    return HardwareProfile(
        # Device identity and capacity.
        gpu_sku="NVIDIA RTX PRO 6000 Blackwell (audit)",
        gpu_count=1,
        gpu_memory_bytes=24 * GiB,
        has_nvlink=False,
        zero3_shard=False,
        # Transfer and optimizer throughput.
        pcie_h2d_bps=13e9,
        pcie_d2h_bps=13e9,
        cpu_adam_bytes_per_sec=2e9,
        gpu_adam_bytes_per_sec=4e11,
        # 4-bit weights: half a byte per parameter element.
        dominant_param_bytes_per_element=0.5,
    )


# Band absorbs wrapper-side calibration offset, intermediate-vs-hidden proxy
# slack, and per-dtype alpha shift.
+TOLERANCE_FRAC = 0.35 + + +@pytest.mark.parametrize("seq_len", [512, 1024, 2048]) +def test_modec_steady_peak_within_tolerance(seq_len: int) -> None: + """estimate_peak lands within +/-35% of the measured steady peak across seq=512/1024/2048.""" + layout = _build_layout() + trace = _build_synth_trace(seq_len) + hw = _build_hw_4bit() + cfg = CostConfig( + n_persist=N_PERSIST, + n_buffer=N_BUFFER, + n_swap=N_SWAP, + n_checkpoint=N_CHECKPOINT, + n_offload=0, + ) + block_map = assign_modes(N_SWAP, N_CHECKPOINT, LLAMA_30B_N_BLOCK) + + predicted_bytes = estimate_peak(cfg, trace, layout, block_map, hw) + predicted_gib = predicted_bytes / GiB + measured_gib = MEASURED_STEADY_GIB[seq_len] + relative_error = abs(predicted_gib - measured_gib) / measured_gib + + assert relative_error <= TOLERANCE_FRAC, ( + f"30B 4-bit Mode-C seq={seq_len}: predicted_peak={predicted_gib:.3f} GiB " + f"vs measured_steady={measured_gib:.3f} GiB; relative_error={relative_error:.3f} " + f"(tolerance +/-{TOLERANCE_FRAC:.2f}). " + f"Check the ckpt_chain_bytes accumulator in cost/memory.py::estimate_peak " + f"and the raw_peak == 0 fallback." + ) + + +def test_modec_steady_peak_scales_with_seq() -> None: + """Predicted peak must grow with sequence length on Mode-C; flat-output regression is the failure mode.""" + layout = _build_layout() + hw = _build_hw_4bit() + cfg = CostConfig( + n_persist=N_PERSIST, + n_buffer=N_BUFFER, + n_swap=N_SWAP, + n_checkpoint=N_CHECKPOINT, + n_offload=0, + ) + block_map = assign_modes(N_SWAP, N_CHECKPOINT, LLAMA_30B_N_BLOCK) + + predictions: list[tuple[int, int]] = [] + for seq_len in (512, 1024, 2048): + trace = _build_synth_trace(seq_len) + peak_bytes = estimate_peak(cfg, trace, layout, block_map, hw) + predictions.append((seq_len, peak_bytes)) + + # Strict monotonicity in seq_len. 
Each doubling of seq_len doubles + # the per-block activation contribution (synth proxy is linear in + # seq); the CKPT-chain sum across 60 blocks therefore doubles too, + # and the prediction must grow. + for (seq_a, peak_a), (seq_b, peak_b) in zip( + predictions, predictions[1:], strict=False + ): + assert peak_b > peak_a, ( + f"predicted peak must grow with sequence length: " + f"seq={seq_a} -> {peak_a / GiB:.3f} GiB but " + f"seq={seq_b} -> {peak_b / GiB:.3f} GiB (expected strict increase). " + f"This breaks the per-seq scaling guarantee." + ) + + # Sanity: the seq=2048 prediction must grow by at least + # ``2 * N_block * (1024 * intermediate * 2 bytes) * alpha_4bit`` + # relative to seq=1024 — the chain contribution scales linearly + # with seq, so doubling seq adds at least that much to raw_peak. + expected_min_delta = int( + 0.75 # ALPHA_FRAGMENTATION_4BIT + * LLAMA_30B_N_BLOCK + * 1024 + * LLAMA_30B_INTERMEDIATE + * 2 + * 0.5 # half-credit slack for cap / rounding interactions + ) + actual_delta = predictions[2][1] - predictions[1][1] + assert actual_delta >= expected_min_delta, ( + f"seq=1024 -> 2048 should add >= " + f"{expected_min_delta / GiB:.2f} GiB via the CKPT-chain term; " + f"got delta={actual_delta / GiB:.2f} GiB. Suggests the " + f"``ckpt_chain_bytes`` accumulator is dropping CKPT blocks." 
+ ) diff --git a/tests/protrain/test_paged_adam_offload_mgpu.py b/tests/protrain/test_paged_adam_offload_mgpu.py new file mode 100644 index 0000000000..0dacca4de8 --- /dev/null +++ b/tests/protrain/test_paged_adam_offload_mgpu.py @@ -0,0 +1,226 @@ +"""Multi-GPU regression: QLoRA + paged_adamw_8bit + Mode C at seq=2048 crashed DDP broadcast on shape-preserving placeholders.""" + +from __future__ import annotations + +import os +import socket +import subprocess +import sys +import textwrap +from pathlib import Path + +import pytest + + +def _pick_free_port() -> int: + """Bind to port 0 so the OS hands back a free port and MASTER_PORT collisions are impossible.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] + + +def _nvidia_smi_gpu_indices() -> list[int]: + """List GPU indices via nvidia-smi to bypass the pytest host's CUDA_VISIBLE_DEVICES masking.""" + try: + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader,nounits"], + stderr=subprocess.DEVNULL, + timeout=10, + ).decode("utf-8", errors="replace") + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + return [] + indices: list[int] = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + try: + indices.append(int(line)) + except ValueError: + continue + return indices + + +# Precheck must verify these specific indices since count-based gating would still let launches fail late. +_REQUIRED_GPU_INDICES = (1, 4, 5, 7) + + +def _repo_root() -> Path: + """Resolve the worktree root (parent of ``src/axolotl``).""" + here = Path(__file__).resolve() + # tests/protrain/test_paged_adam_offload_mgpu.py -> tests/protrain -> tests -> repo + return here.parents[2] + + +# Every key in this YAML is part of the regression contract; do not edit without re-validating the failure repro. 
+_REPRODUCER_YAML = textwrap.dedent( + """\ + base_model: NousResearch/Meta-Llama-3-8B-Instruct + model_type: LlamaForCausalLM + + load_in_8bit: false + load_in_4bit: true + strict: false + + datasets: + - path: tatsu-lab/alpaca + type: alpaca + val_set_size: 0.0 + output_dir: {output_dir} + + sequence_len: 2048 + sample_packing: false + pad_to_sequence_len: true + + adapter: qlora + lora_r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - up_proj + - down_proj + - gate_proj + + plugins: + - axolotl.integrations.protrain.ProTrainPlugin + + protrain_auto_memory: true + protrain_auto_mode: false + protrain_force_all_persistent: false + protrain_zero3_shard: true + protrain_n_persist_override: 0 + protrain_n_buffer_override: 12 + protrain_n_swap_override: 0 + protrain_n_checkpoint_override: 32 + + gradient_accumulation_steps: 1 + micro_batch_size: 1 + max_steps: 5 + optimizer: paged_adamw_8bit + lr_scheduler: cosine + learning_rate: 0.0002 + + bf16: true + fp16: false + tf32: false + + gradient_checkpointing: false + + flash_attention: false + xformers_attention: false + + lora_mlp_kernel: false + lora_qkv_kernel: false + lora_o_kernel: false + + logging_steps: 1 + save_steps: 100 + save_first_step: false + save_total_limit: 1 + + warmup_steps: 1 + weight_decay: 0.0 + + peft_autocast_adapter_dtype: false + """ +) + + +def _launch_axolotl(yaml_path: Path, log_path: Path, repo_root: Path) -> int: + """Run accelerate launch of axolotl.cli.train; pins GPUs 1,4,5,7 with a 720s timeout for cold-cache hook install.""" + env = os.environ.copy() + env["DS_SKIP_CUDA_CHECK"] = "1" + env["PYTHONUNBUFFERED"] = "1" + env["PYTHONPATH"] = str(repo_root / "src") + env["CUDA_VISIBLE_DEVICES"] = "1,4,5,7" + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + env.setdefault("MASTER_PORT", str(_pick_free_port())) + + cmd = [ + sys.executable, + "-m", + "accelerate.commands.launch", + "--num_processes", + "4", + "--mixed_precision", + 
"bf16", + "-m", + "axolotl.cli.train", + str(yaml_path), + ] + with log_path.open("w") as f: + proc = subprocess.run( + cmd, + env=env, + stdout=f, + stderr=subprocess.STDOUT, + check=False, + timeout=720, + ) + return proc.returncode + + +def _require_real_multigpu() -> None: + """Skip helper for the multi-GPU subprocess test.""" + visible = _nvidia_smi_gpu_indices() + missing = [i for i in _REQUIRED_GPU_INDICES if i not in visible] + if missing: + pytest.skip( + f"4-bit + paged_adamw_8bit + Mode C multi-GPU regression requires " + f"GPU indices {list(_REQUIRED_GPU_INDICES)} (hard-coded in " + f"``_launch_axolotl``); nvidia-smi reports {visible}, " + f"missing {missing}" + ) + try: + import accelerate # noqa: F401 + except ImportError: + pytest.skip("accelerate not installed; required for multi-GPU launch") + + +@pytest.mark.slow +@pytest.mark.gpu +def test_paged_adam_offload_mgpu_no_ddp_broadcast_crash(tmp_path: Path) -> None: + """4x3090 QLoRA + paged_adamw_8bit + Mode C at seq=2048 trains 5 steps without the DDP broadcast crash on expand placeholders.""" + _require_real_multigpu() + + repo_root = _repo_root() + workdir = tmp_path + output_dir = workdir / "protrain_paged_qlora_mgpu_out" + + yaml_path = workdir / "ext_b1_qlora_paged_seq2048_mgpu.yml" + yaml_path.write_text(_REPRODUCER_YAML.format(output_dir=str(output_dir))) + + log_path = workdir / "ext_b1_qlora_paged_seq2048_mgpu.log" + rc = _launch_axolotl(yaml_path, log_path, repo_root) + log_text = log_path.read_text() + log_tail = log_text[-3000:] + + assert rc == 0, ( + f"paged_adamw_8bit + Mode C multi-GPU subprocess exited {rc} " + f"(expected 0); tail:\n{log_tail}" + ) + assert "Traceback" not in log_text, ( + f"unexpected Traceback in the captured log; tail:\n{log_tail}" + ) + # DDP init_sync bypass must engage when the chunk-managed marker is present, else broadcast over expand placeholders crashes. 
+ assert "patched-injection of init_sync=False" in log_text, ( + f"DDP init_sync bypass did NOT fire on this YAML's path. tail:\n{log_tail}" + ) + # Chunk-managed param-name registration is the secondary defence; keep pinning it so it cannot silently empty out. + assert "registered" in log_text and "chunk-managed param names" in log_text, ( + f"chunk-managed param-name registration log line missing. tail:\n{log_tail}" + ) + # Sanity: 5 steps of training means at least 5 per-step loss lines. + assert log_text.count("'loss':") >= 5, ( + f"expected >= 5 per-step loss log lines for max_steps=5, got " + f"{log_text.count(chr(0x27) + 'loss' + chr(0x27) + ':')}; " + f"tail:\n{log_tail}" + ) diff --git a/tests/protrain/test_param_data_shape_preservation.py b/tests/protrain/test_param_data_shape_preservation.py new file mode 100644 index 0000000000..ab88a35a0f --- /dev/null +++ b/tests/protrain/test_param_data_shape_preservation.py @@ -0,0 +1,584 @@ +"""Pin the shape-preserving placeholder invariant: released params keep their logical shape so autograd records the real size.""" + +from __future__ import annotations + +from typing import cast + +import pytest + +from axolotl.integrations.protrain.types import ( + BlockId, + ParamId, +) + + +def _tiny_model(hidden: int = 64, n_layers: int = 4): + """A tiny 4-layer transformer-shaped model so each ``h.{i}`` Linear becomes its own block / chunk.""" + import torch + from torch import nn + + class TinyTransformer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.embed = nn.Linear(hidden, hidden, bias=False) + self.h = nn.ModuleList( + [nn.Linear(hidden, hidden, bias=False) for _ in range(n_layers)] + ) + self.head = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x: "torch.Tensor") -> "torch.Tensor": + x = self.embed(x) + for layer in self.h: + x = layer(x) + return self.head(x) + + torch.manual_seed(0) + return TinyTransformer() + + +def _build_layout_for(model, S_chunk: int): + from 
axolotl.integrations.protrain.chunk.layout import build_layout + + block_spans: dict[BlockId, list[ParamId]] = {} + for name, _ in model.named_parameters(): + if name.startswith("h."): + idx = int(name.split(".")[1]) + block_spans.setdefault(cast(BlockId, idx), []).append(cast(ParamId, name)) + + exec_order = [cast(ParamId, n) for n, _ in model.named_parameters()] + return build_layout(model, exec_order, S_chunk, block_spans) + + +def _build_chunk_manager( + model, + n_persist: int, + S_chunk: int, + *, + shape_preserving_placeholders: bool, + n_buffer: int | None = None, +): + """Assemble a :class:`ChunkManager` with the shape-preserving-placeholders flag toggled.""" + import torch + + from axolotl.integrations.protrain.chunk.buffer_pool import BufferPool + from axolotl.integrations.protrain.chunk.manager import ChunkManager + from axolotl.integrations.protrain.chunk.pinned_alloc import PinnedHostMemory + + layout = _build_layout_for(model, S_chunk) + if n_buffer is None: + n_buffer = max(2, min(4, layout.N_chunk - n_persist)) + host = PinnedHostMemory(n_buffer=n_buffer, S_chunk=layout.S_chunk) + pool = BufferPool( + n_buffer=n_buffer, + S_chunk=layout.S_chunk, + pinned_host=host, + device=torch.device("cuda"), + ) + mgr = ChunkManager( + model=model, + layout=layout, + n_persist=n_persist, + buffer_pool=pool, + cpu_optim=None, + gpu_optim=None, + device=torch.device("cuda"), + shape_preserving_placeholders=shape_preserving_placeholders, + ) + return mgr, layout, pool, host + + +def _teardown_chunk_manager(mgr, host, pool) -> None: + """Best-effort teardown so an assertion failure cannot leak hooks, pinned-host borrows, or buffer-pool state into later tests.""" + try: + mgr.uninstall() + except Exception: # noqa: BLE001 — best-effort teardown + pass + try: + host.close() + except Exception: # noqa: BLE001 — best-effort teardown + pass + # ``del pool`` drops the local reference so the GC can release + # the pool's GPU buffer slots immediately rather than at + # 
function-return. + del pool + + +@pytest.mark.gpu +def test_release_state_preserves_shape() -> None: + """With the flag on, every non-persistent param keeps its real shape after ``materialize_offload`` (not ``Size([0])``).""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + # Record the canonical shape of every named param BEFORE + # materialize_offload — we'll compare against this snapshot below. + original_shapes: dict[str, torch.Size] = { + name: p.shape for name, p in model.named_parameters() + } + original_dtypes: dict[str, torch.dtype] = { + name: p.dtype for name, p in model.named_parameters() + } + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + mgr.materialize_offload() + + # Every non-persistent chunk's params should retain their original + # shape — the legacy code would have rebound to torch.Size([0]). + non_persist = sorted(mgr._non_persistent_ids) + assert non_persist, "need at least one non-persistent chunk" + for cid in non_persist: + for pid in layout.chunks[int(cid)]: + param = dict(model.named_parameters())[str(pid)] + expected_shape = original_shapes[str(pid)] + assert param.shape == expected_shape, ( + f"shape-preserving release violated: param={pid} " + f"expected shape={expected_shape}, got {param.shape}" + ) + assert param.size() == expected_shape, ( + f"param.size() drift: param={pid} expected {expected_shape}, " + f"got {param.size()}" + ) + # dim() must reflect the original ndim too (LoRA factors + # are 2-D; embedding is 2-D; layernorm scales are 1-D — the + # bug surface includes shape AND dim consistency). 
+ assert param.dim() == len(expected_shape), ( + f"param.dim() drift: param={pid} expected {len(expected_shape)}, " + f"got {param.dim()}" + ) + assert param.dtype == original_dtypes[str(pid)], ( + f"dtype drift: param={pid} expected {original_dtypes[str(pid)]}, " + f"got {param.dtype}" + ) + assert param.device.type == "cuda", ( + f"released param expected on cuda, got {param.device}" + ) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_release_state_default_off_is_unchanged() -> None: + """Default ``shape_preserving_placeholders=False`` keeps the legacy ``numel()==0`` placeholder semantics intact.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=False, + ) + try: + mgr.materialize_offload() + + # Legacy invariant: every non-persistent chunk's params have a + # torch.Size([0]) placeholder after release. 
+ non_persist = sorted(mgr._non_persistent_ids) + for cid in non_persist: + for pid in layout.chunks[int(cid)]: + param = dict(model.named_parameters())[str(pid)] + assert param.data.numel() == 0, ( + f"legacy invariant broken: param={pid} expected numel==0, " + f"got numel={param.data.numel()} shape={param.shape}" + ) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_gather_offload_round_trip_shape() -> None: + """After gather→offload, released shape is preserved — confirms ``offload()`` honours the flag, not just ``materialize_offload``.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + original_shapes: dict[str, torch.Size] = { + name: p.shape for name, p in model.named_parameters() + } + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + mgr.materialize_offload() + + non_persist = sorted(mgr._non_persistent_ids) + assert non_persist, "need at least one non-persistent chunk" + cid = non_persist[0] + + # gather → params should be at real shape with real storage + mgr.gather(cid) + for pid in layout.chunks[int(cid)]: + param = dict(model.named_parameters())[str(pid)] + assert param.shape == original_shapes[str(pid)] + assert param.data.numel() > 0, "gathered param should have real storage" + + # offload → released; under the flag, shape must still match. 
+ mgr.offload(cid) + for pid in layout.chunks[int(cid)]: + param = dict(model.named_parameters())[str(pid)] + assert param.shape == original_shapes[str(pid)], ( + f"post-offload shape drift on flag=True: param={pid} " + f"expected {original_shapes[str(pid)]}, got {param.shape}" + ) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_storage_footprint_is_bounded() -> None: + """The shape-preserving placeholder costs ~zero extra bytes: one 1-element scratch per dtype, shared via expand.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + mgr.materialize_offload() + + # Walk the released params; bucket their storage pointers by dtype. + seen_storage_ptrs: dict[torch.dtype, set[int]] = {} + for cid in sorted(mgr._non_persistent_ids): + for pid in layout.chunks[int(cid)]: + param = dict(model.named_parameters())[str(pid)] + ptr = param.data.untyped_storage().data_ptr() + seen_storage_ptrs.setdefault(param.dtype, set()).add(ptr) + + # For each dtype represented in the released set, every param's + # released-state storage_ptr should equal the per-dtype scratch's + # storage_ptr. + for dtype, ptrs in seen_storage_ptrs.items(): + scratch = mgr._shape_scratch_by_dtype.get(dtype) + assert scratch is not None, ( + f"no scratch cached for dtype={dtype} but released params exist" + ) + # Scratch is 1 element wide; expand views share that storage. 
+ assert scratch.numel() == 1, ( + f"scratch for dtype={dtype} should be 1-element, got " + f"numel={scratch.numel()}" + ) + scratch_ptr = scratch.untyped_storage().data_ptr() + assert ptrs == {scratch_ptr}, ( + f"dtype={dtype}: released params should all share scratch's " + f"storage_ptr={scratch_ptr}, got {ptrs}" + ) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_autograd_shape_capture_on_released_param() -> None: + """Direct reproducer of the autograd race: a forward over the placeholder must record the real shape, not ``[0]``.""" + pytest.importorskip("torch") + import torch + from torch import nn + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + # Build a Parameter with a non-trivial 2D shape (mirrors a LoRA + # factor [out_features, r]). + real_shape = (256, 16) + dtype = torch.bfloat16 + param = nn.Parameter( + torch.empty(0, dtype=dtype, device="cuda") + ) # initial "released" state + + # ---- Legacy [0] placeholder path: param.size() == [0] ---------- + assert param.shape == torch.Size([0]) + # Calling F.linear in this state fails BEFORE the autograd record + # can complete — the kernel's shape check trips. + x = torch.randn(4, real_shape[1], dtype=dtype, device="cuda") + with pytest.raises(RuntimeError): + _ = nn.functional.linear(x, param) + + # ---- Shape-preserving placeholder path: param.size() == real_shape --- + # We construct a manager just to use the helper method + # ``_shape_preserving_placeholder`` directly; full materialize is + # not needed for this micro-test. 
+ hidden = 64 + model = _tiny_model(hidden=hidden, n_layers=2).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + mgr, _layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + placeholder = mgr._shape_preserving_placeholder(real_shape, dtype) + assert placeholder.shape == torch.Size(real_shape) + assert placeholder.dtype == dtype + assert placeholder.device.type == "cuda" + # Storage cost: one element (the scratch). + assert placeholder.untyped_storage().nbytes() == placeholder.element_size() + + param.data = placeholder + assert param.shape == torch.Size(real_shape) + assert param.size() == torch.Size(real_shape) + assert param.dim() == 2 + + # Forward must run while the placeholder is still bound so autograd records its shape (not the real-data rebind). + x = torch.randn( + 4, real_shape[1], dtype=dtype, device="cuda", requires_grad=True + ) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + y_placeholder = nn.functional.linear(x, param) + # The matmul-output shape must reflect the placeholder's reported + # weight shape; if the placeholder shrank back to ``[0]`` the + # output would be ``(batch, 0)`` and the shape assertion below + # would catch it BEFORE backward fires. + assert y_placeholder.shape == torch.Size([4, real_shape[0]]), ( + f"forward through placeholder produced wrong-shape output: " + f"expected (4, {real_shape[0]}), got {tuple(y_placeholder.shape)} — " + f"placeholder.size() likely regressed." + ) + + # Rebind to real storage before backward; a placeholder-shape regression would surface as a ToCopyBackward0 error. 
+ real_data = torch.randn(*real_shape, dtype=dtype, device="cuda") + param.data = real_data + + loss = y_placeholder.sum() + loss.backward() + assert param.grad is not None + assert param.grad.shape == torch.Size(real_shape), ( + f"autograd recorded the WRONG shape: expected {real_shape}, " + f"got {tuple(param.grad.shape)} — the shape-preserving " + f"placeholder invariant has regressed." + ) + + # Also exercise the post-gather steady-state forward+backward + # path so a regression that only fires on the placeholder side + # is distinguishable from one that fires on the real-data side. + param.grad = None + x_real = torch.randn( + 4, real_shape[1], dtype=dtype, device="cuda", requires_grad=True + ) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + y_real = nn.functional.linear(x_real, param) + y_real.sum().backward() + assert param.grad is not None + assert param.grad.shape == torch.Size(real_shape) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_release_state_placeholder_is_write_unsafe() -> None: + """The expand placeholder is NOT write-safe: any in-place write trips PyTorch's shared-storage hazard (DDP broadcast root cause).""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + model = _tiny_model(hidden=hidden, n_layers=2).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + mgr, _layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + placeholder = mgr._shape_preserving_placeholder( + torch.Size([hidden, hidden]), torch.float32 + ) + # Shape preserved by the placeholder. + assert placeholder.shape == torch.Size([hidden, hidden]) + # Storage points at the per-dtype scratch (1 element). 
+ assert placeholder.untyped_storage().nbytes() == placeholder.element_size() + + # In-place write fails with the shared-storage hazard. Any of + # ``copy_``, ``add_``, ``zero_``, ``mul_`` triggers it. + real_payload = torch.zeros(hidden, hidden, dtype=torch.float32, device="cuda") + with pytest.raises(RuntimeError, match="more than one element"): + placeholder.copy_(real_payload) + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_chunk_managed_param_names_excludes_persistent() -> None: + """``chunk_managed_param_names()`` returns exactly the non-persistent param names that DDP must skip on broadcast.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + mgr.materialize_offload() + + ignored = mgr.chunk_managed_param_names() + + # Build the expected set: every param in a non-persistent chunk. + expected: set[str] = set() + for cid in mgr._non_persistent_ids: + for pid in layout.chunks[int(cid)]: + expected.add(str(pid)) + assert ignored == expected, ( + f"chunk_managed_param_names mismatch: " + f"expected={sorted(expected)} got={sorted(ignored)}" + ) + + # Persistent chunk params are explicitly NOT in the set. + persistent_names: set[str] = set() + for cid in mgr._persistent_ids: + for pid in layout.chunks[int(cid)]: + persistent_names.add(str(pid)) + assert ignored.isdisjoint(persistent_names), ( + f"persistent params leaked into ignore set: " + f"intersection={ignored & persistent_names}" + ) + + # Sanity: every returned name resolves through named_parameters(). 
+ by_name = dict(model.named_parameters()) + for name in ignored: + assert name in by_name, f"unknown param name in ignore set: {name}" + + finally: + _teardown_chunk_manager(mgr, host, pool) + + +@pytest.mark.gpu +def test_release_state_is_write_safe_through_gather_round_trip() -> None: + """Gather must rebind ``param.data`` to fresh storage before any write so the write-unsafe placeholder is never written to.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + torch.cuda.empty_cache() + + hidden = 64 + n_layers = 4 + model = _tiny_model(hidden=hidden, n_layers=n_layers).to("cuda") + S_chunk = hidden * hidden * 4 + 4096 + + mgr, layout, pool, host = _build_chunk_manager( + model, + n_persist=1, + S_chunk=S_chunk, + shape_preserving_placeholders=True, + ) + try: + mgr.materialize_offload() + + non_persist = sorted(mgr._non_persistent_ids) + assert non_persist, "need at least one non-persistent chunk" + cid = non_persist[0] + + # Pre-gather: param.data IS the expand placeholder (write-unsafe). + target_pid = str(layout.chunks[int(cid)][0]) + target_param = dict(model.named_parameters())[target_pid] + pre_gather_storage_ptr = target_param.data.untyped_storage().data_ptr() + + # gather → param.data must rebind to a fresh typed view of the pool + # buffer before any write reaches the placeholder. + mgr.gather(cid) + target_param = dict(model.named_parameters())[target_pid] + post_gather_storage_ptr = target_param.data.untyped_storage().data_ptr() + assert post_gather_storage_ptr != pre_gather_storage_ptr, ( + "gather did not rebind param.data — still pointing at the " + "expand placeholder; in-place write would trip the hazard" + ) + + # Confirm the gathered param IS write-safe: an in-place fill must + # succeed (proving the rebind landed on real storage). 
+ target_param.data.fill_(0.5) + assert torch.allclose( + target_param.data, + torch.full_like(target_param.data, 0.5), + ), "in-place fill on gathered param did not take effect" + + # Round-trip: offload returns to placeholder; another gather must + # again rebind to fresh storage. This pins the cycle. + mgr.offload(cid) + target_param = dict(model.named_parameters())[target_pid] + placeholder_storage_ptr = target_param.data.untyped_storage().data_ptr() + # Re-gather and confirm the rebind happens before any write. + mgr.gather(cid) + target_param = dict(model.named_parameters())[target_pid] + re_gather_storage_ptr = target_param.data.untyped_storage().data_ptr() + assert re_gather_storage_ptr != placeholder_storage_ptr, ( + "re-gather did not rebind param.data after offload returned " + "it to the expand placeholder" + ) + + finally: + _teardown_chunk_manager(mgr, host, pool) diff --git a/tests/protrain/test_plugin_args_validators.py b/tests/protrain/test_plugin_args_validators.py index 121932ebae..7e578356dc 100644 --- a/tests/protrain/test_plugin_args_validators.py +++ b/tests/protrain/test_plugin_args_validators.py @@ -123,22 +123,20 @@ def test_mutex_rejects_sequence_parallel() -> None: assert "sequence_parallel_degree" in str(exc.value) -def test_mutex_rejects_load_in_8bit() -> None: +def test_mutex_allows_load_in_8bit() -> None: + """M0 spike validated bnb 8-bit composes with ProTrain Mode A; validator must allow it.""" cfg = _minimal_active_cfg(load_in_8bit=True) - with pytest.raises(ValidationError) as exc: - ProTrainArgs.model_validate(cfg) - assert "load_in_8bit" in str(exc.value) + ProTrainArgs.model_validate(cfg) -def test_mutex_rejects_load_in_4bit() -> None: +def test_mutex_allows_load_in_4bit() -> None: + """M0 spike validated bnb 4-bit (QLoRA) composes with ProTrain Mode A; validator must allow it.""" cfg = _minimal_active_cfg(load_in_4bit=True) - with pytest.raises(ValidationError) as exc: - ProTrainArgs.model_validate(cfg) - assert 
"load_in_4bit" in str(exc.value) + ProTrainArgs.model_validate(cfg) def test_mutex_allows_load_in_xbit_false() -> None: - """Both bnb flags explicitly False is the supported path.""" + """Both bnb flags explicitly False is still the supported path.""" cfg = _minimal_active_cfg(load_in_8bit=False, load_in_4bit=False) ProTrainArgs.model_validate(cfg) @@ -168,3 +166,98 @@ def test_force_all_persistent_default_is_false() -> None: """ args = ProTrainArgs() assert args.protrain_force_all_persistent is False + + +# --------------------------------------------------------------------- +# Optimizer allow-list (M6B) — ProTrain's chunk-manager adapters only +# drive AdamW-shaped state. Unsupported optimizers must be rejected at +# config-load time rather than corrupting state inside the step path. +# --------------------------------------------------------------------- + + +def test_optimizer_validator_accepts_adamw_torch() -> None: + cfg = _minimal_active_cfg(optimizer="adamw_torch") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_adamw_torch_fused() -> None: + cfg = _minimal_active_cfg(optimizer="adamw_torch_fused") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_adamw_8bit() -> None: + cfg = _minimal_active_cfg(optimizer="adamw_8bit") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_adamw_bnb_8bit() -> None: + cfg = _minimal_active_cfg(optimizer="adamw_bnb_8bit") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_paged_adamw_8bit() -> None: + cfg = _minimal_active_cfg(optimizer="paged_adamw_8bit") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_missing_optimizer() -> None: + """No ``optimizer`` key — Axolotl picks a supported default elsewhere.""" + cfg = _minimal_active_cfg() + assert "optimizer" not in cfg + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_accepts_none_optimizer() -> None: + """Explicit 
``optimizer: null`` must not raise (default-fill happens later).""" + cfg = _minimal_active_cfg(optimizer=None) + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_rejects_lion() -> None: + cfg = _minimal_active_cfg(optimizer="lion_pytorch") + with pytest.raises(ValidationError) as exc: + ProTrainArgs.model_validate(cfg) + msg = str(exc.value) + assert "lion_pytorch" in msg + assert "ProTrain" in msg + + +def test_optimizer_validator_rejects_adafactor() -> None: + cfg = _minimal_active_cfg(optimizer="adafactor") + with pytest.raises(ValidationError) as exc: + ProTrainArgs.model_validate(cfg) + assert "adafactor" in str(exc.value) + + +def test_optimizer_validator_rejects_sgd() -> None: + cfg = _minimal_active_cfg(optimizer="sgd") + with pytest.raises(ValidationError) as exc: + ProTrainArgs.model_validate(cfg) + assert "sgd" in str(exc.value) + + +def test_optimizer_validator_message_cites_chunk_optim_path() -> None: + """Error message must point users at the adapter source file.""" + cfg = _minimal_active_cfg(optimizer="muon") + with pytest.raises(ValidationError) as exc: + ProTrainArgs.model_validate(cfg) + msg = str(exc.value) + assert "src/axolotl/integrations/protrain/chunk/optim.py" in msg + # Message should also enumerate the supported set + give a fix. 
+ assert "adamw_torch" in msg + assert "remove the ProTrain plugin" in msg + + +def test_optimizer_validator_is_case_insensitive_accept() -> None: + """Mixed-case supported names must still be accepted.""" + cfg = _minimal_active_cfg(optimizer="AdamW_Torch") + ProTrainArgs.model_validate(cfg) + + +def test_optimizer_validator_skips_when_protrain_inactive() -> None: + """An unsupported optimizer is fine if ProTrain isn't enabled.""" + cfg = { + "protrain_auto_memory": False, + "optimizer": "lion_pytorch", + } + ProTrainArgs.model_validate(cfg) diff --git a/tests/protrain/test_profiler.py b/tests/protrain/test_profiler.py index f99932edb5..1ef5145ba3 100644 --- a/tests/protrain/test_profiler.py +++ b/tests/protrain/test_profiler.py @@ -553,6 +553,72 @@ def forward(self, input_ids=None, **kwargs): ) +@pytest.mark.gpu +def test_force_all_persistent_suppresses_on_demand_in_run_trace( + gpu_device, monkeypatch, caplog +): + """force_all_persistent=True must skip the on-demand trace gate even at 0% device-memory threshold.""" + import logging + + import torch + from torch import nn + + if not torch.cuda.is_available(): + pytest.skip("CUDA unavailable") + + device = torch.device(f"cuda:{gpu_device}") + + class TinyBlock(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(32, 64) + self.fc2 = nn.Linear(64, 32) + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + + class TinyModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([TinyBlock(), TinyBlock()]) + + def forward(self, input_ids=None, **kwargs): + x = input_ids.to(torch.float32) + for layer in self.layers: + x = layer(x) + return type("Out", (), {"loss": x.sum()})() + + model = TinyModel().to(device) + batch = {"input_ids": torch.randn(2, 32, device=device)} + + # Force on-demand to engage (without the fix) by dropping the threshold. 
+ from axolotl.integrations.protrain.profiler import trace as trace_mod + + monkeypatch.setattr(trace_mod, "ON_DEMAND_STATE_BYTES_FRACTION", 0.0) + + cfg = ProfilerConfig( + batch_size=2, + seq_len=32, + device=str(device), + include_backward=False, + on_demand=True, + force_all_persistent=True, + ) + + with caplog.at_level(logging.INFO, logger=trace_mod.LOG.name): + trace = run_trace(model, batch, cfg) + + assert len(trace.op_order) > 0 + log_text = "\n".join(rec.getMessage() for rec in caplog.records) + assert "force_all_persistent=True; skipping on-demand" in log_text, ( + f"force_all_persistent did not suppress on-demand engagement; " + f"trace log was:\n{log_text}" + ) + assert "Profiler engaging on-demand mode" not in log_text, ( + f"on-demand was engaged despite force_all_persistent=True; log: {log_text}" + ) + + @pytest.mark.gpu def test_on_demand_engaged_cost_model_finite(gpu_device, monkeypatch): """Cost model must produce a finite, positive iter-time on an on-demand trace. diff --git a/tests/protrain/test_quantization.py b/tests/protrain/test_quantization.py new file mode 100644 index 0000000000..ad4bc81e78 --- /dev/null +++ b/tests/protrain/test_quantization.py @@ -0,0 +1,104 @@ +"""ProTrain + bitsandbytes quantization composability: validator drop and packed-byte param sizing.""" + +from __future__ import annotations + +from typing import cast + +import torch +from torch import nn + +from axolotl.integrations.protrain.args import ProTrainArgs +from axolotl.integrations.protrain.chunk.layout import _param_bytes +from axolotl.integrations.protrain.types import ParamId + + +def _minimal_active_cfg(**overrides) -> dict: + cfg: dict = { + "protrain_auto_memory": True, + "plugins": ["axolotl.integrations.protrain.ProTrainPlugin"], + "base_model": "HuggingFaceTB/SmolLM2-135M", + } + cfg.update(overrides) + return cfg + + +# --------------------------------------------------------------------- +# Validator drop — load_in_8bit / load_in_4bit must be accepted 
when +# ProTrain is active. Mirrors the positive-control test in +# ``test_plugin_args_validators.py`` but kept here so the quant +# milestone owns its own regression surface. +# --------------------------------------------------------------------- + + +def test_load_in_8bit_passes_with_protrain_active() -> None: + cfg = _minimal_active_cfg(load_in_8bit=True) + # Must NOT raise. + ProTrainArgs.model_validate(cfg) + + +def test_load_in_4bit_passes_with_protrain_active() -> None: + cfg = _minimal_active_cfg(load_in_4bit=True) + # Must NOT raise. + ProTrainArgs.model_validate(cfg) + + +def test_load_in_4bit_passes_with_qlora_adapter() -> None: + """QLoRA = ``load_in_4bit: true`` + ``adapter: qlora``; the canonical config.""" + cfg = _minimal_active_cfg(load_in_4bit=True, adapter="qlora") + ProTrainArgs.model_validate(cfg) + + +# --------------------------------------------------------------------- +# Chunk layout — _param_bytes must size packed-byte storage correctly. +# Synthetic models stand in for bnb's Int8Params / Params4bit because: +# * Int8Params post-.cuda() with has_fp16_weights=False is a +# torch.int8 tensor of shape (out, in), element_size=1. +# * Params4bit storage is a torch.uint8 tensor of shape +# (ceil(in*out/2), 1), element_size=1. +# In both cases byte size = numel * 1 = packed bytes — the exact +# accounting the chunk packer relies on. Reproduce that shape with +# stock dtypes so the test runs without bnb installed. +# --------------------------------------------------------------------- + + +def test_param_bytes_int8_matches_packed_bytes() -> None: + """Int8Params storage: numel == out*in, element_size == 1.""" + out, in_ = 32, 64 + model = nn.Module() + # Bypass nn.Parameter's float-only constraint by registering a buffer-shaped + # int8 storage as if it were a frozen weight (matches Int8Params stride). 
+ model.weight = nn.Parameter( + torch.zeros(out, in_, dtype=torch.int8), requires_grad=False + ) + sizes = _param_bytes(model) + assert sizes[cast(ParamId, "weight")] == out * in_ # 1 byte per element + + +def test_param_bytes_uint8_matches_packed_bytes() -> None: + """Params4bit storage: 2 weights packed per uint8 byte → numel == ceil(out*in/2).""" + out, in_ = 32, 64 + packed = (out * in_ + 1) // 2 # 2-per-byte packing + model = nn.Module() + model.weight = nn.Parameter( + torch.zeros(packed, 1, dtype=torch.uint8), requires_grad=False + ) + sizes = _param_bytes(model) + assert ( + sizes[cast(ParamId, "weight")] == packed + ) # 1 byte per element, packed storage + + +def test_param_bytes_mixed_dtypes() -> None: + """A frozen-int8 base + fp16 LoRA + fp32 norm scale — the realistic LoRA-on-8bit shape.""" + model = nn.Module() + model.base_weight = nn.Parameter( + torch.zeros(32, 64, dtype=torch.int8), requires_grad=False + ) + model.lora_a = nn.Parameter(torch.zeros(16, 64, dtype=torch.float16)) + model.lora_b = nn.Parameter(torch.zeros(32, 16, dtype=torch.float16)) + model.norm = nn.Parameter(torch.zeros(64, dtype=torch.float32)) + sizes = _param_bytes(model) + assert sizes[cast(ParamId, "base_weight")] == 32 * 64 * 1 # int8 packed + assert sizes[cast(ParamId, "lora_a")] == 16 * 64 * 2 # fp16 + assert sizes[cast(ParamId, "lora_b")] == 32 * 16 * 2 + assert sizes[cast(ParamId, "norm")] == 64 * 4 # fp32 diff --git a/tests/protrain/test_resume_robustness.py b/tests/protrain/test_resume_robustness.py new file mode 100644 index 0000000000..e934c03c67 --- /dev/null +++ b/tests/protrain/test_resume_robustness.py @@ -0,0 +1,536 @@ +"""In-process rebuild lifecycle invariants: DDP ignore rebuilds from snapshot, CPU adapter shuts down before swap, stale skip-state clears on non-shape-preserving rewrap.""" + +from __future__ import annotations + +import math + +import pytest + + +def _build_tiny_lora_model(): + """Minimal LoRA-on-Llama setup small enough for the chunk manager 
+ searcher to fit on any test rig.""" + pytest.importorskip("peft") + pytest.importorskip("transformers") + + import torch + from peft import LoraConfig, get_peft_model + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + hidden_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=512, + vocab_size=1024, + max_position_embeddings=128, + rms_norm_eps=1e-5, + use_cache=False, + ) + torch.manual_seed(0) + base = LlamaForCausalLM(cfg).to(dtype=torch.bfloat16) + lora_cfg = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.0, + bias="none", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + model = get_peft_model(base, lora_cfg) + return model, cfg + + +def _wrap_protrain( + model, + cfg, + *, + force_all_persistent: bool, + zero3_shard: bool, + n_persist_override: int | None = None, + n_buffer_override: int | None = None, + n_swap_override: int | None = None, + n_checkpoint_override: int | None = None, + n_offload_override: int | None = None, + small_chunk: bool = False, +): + """Wrap a model in ProTrain; small_chunk + overrides let tests force the CPU-adapter / non-persistent paths the searcher would otherwise skip.""" + import torch + + from axolotl.integrations.protrain.api import ( + protrain_model_wrapper, + protrain_optimizer_wrapper, + ) + from axolotl.integrations.protrain.types import HardwareProfile + + hw = HardwareProfile( + gpu_sku=torch.cuda.get_device_name(0), + gpu_memory_bytes=torch.cuda.get_device_properties(0).total_memory, + gpu_count=1, + pcie_h2d_bps=13e9, + pcie_d2h_bps=13e9, + has_nvlink=False, + ) + + # When small_chunk=True, monkey-patch pick_S_chunk so the layout + # builder produces multiple chunks. Without this, the tiny test + # model's params all fit in a single chunk and force_all_persistent + # vs override-driven non-persistent become indistinguishable. 
The + # 1 MiB value matches the working pattern in + # ``test_lora_offload_mode``; finer S_chunk values produce a + # larger N_chunk than n_buffer_override can satisfy + # (``min_n_buffer_for`` validates 2 * max(non_persistent_per_block)). + import axolotl.integrations.protrain.api.model_wrapper as mw + + orig_pick_S_chunk = mw.pick_S_chunk + if small_chunk: + mw.pick_S_chunk = lambda *a, **k: 1 << 20 # 1 MiB + try: + wrapped = protrain_model_wrapper( + model, + model_config=cfg, + hardware_profile=hw, + batch_size=1, + seq_len=32, + capacity_bytes=4 * (1 << 30), + force_all_persistent=force_all_persistent, + zero3_shard=zero3_shard, + n_persist_override=n_persist_override, + n_buffer_override=n_buffer_override, + n_swap_override=n_swap_override, + n_checkpoint_override=n_checkpoint_override, + n_offload_override=n_offload_override, + ) + finally: + # Restore the global so a subsequent test's wrap uses the + # searcher-picked S_chunk (one global monkey-patch leak would + # silently distort downstream resource accounting). 
+ mw.pick_S_chunk = orig_pick_S_chunk + optim = protrain_optimizer_wrapper(wrapped, lr=1e-3) + return wrapped, optim + + +def _train_one_step(wrapped, optim, *, input_ids, labels) -> float: + out = wrapped.module(input_ids=input_ids, labels=labels) + loss = out.loss + loss_value = float(loss.detach()) + loss.backward() + optim.step() + optim.zero_grad() + return loss_value + + +def _make_batch(cfg): + import torch + + torch.manual_seed(1) + return ( + torch.randint(0, cfg.vocab_size, (1, 32), device="cuda", dtype=torch.long), + torch.randint(0, cfg.vocab_size, (1, 32), device="cuda", dtype=torch.long), + ) + + +@pytest.mark.gpu +def test_ddp_ignore_set_does_not_grow_on_repeat_materialize() -> None: + """A second materialize_offload must not grow the DDP ignore set; rebuild from the original snapshot, do not union.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("ProTrain D2 invariant requires CUDA.") + + model, cfg = _build_tiny_lora_model() + model = model.to("cuda") + wrapped, _optim = _wrap_protrain( + model, cfg, force_all_persistent=False, zero3_shard=True + ) + try: + underlying = getattr(wrapped, "module", wrapped) + chunk_manager = getattr(wrapped, "chunk_manager", None) + if chunk_manager is None or not getattr( + chunk_manager, "_shape_preserving_placeholders", False + ): + # Single-process Mode C silently downgrades to Mode A + # (zero3_shard coerced to False when world_size <= 1), so + # the shape-preserving placeholder path isn't engaged. + # Skip in that case — multi-GPU coverage lives in + # ``test_real_multigpu_cross_mode_resume_*``. + pytest.skip( + "single-process Mode C downgrade path: " + "shape-preserving placeholders not engaged." 
+ ) + + first_ignore = list( + getattr(underlying, "_ddp_params_and_buffers_to_ignore", []) + ) + first_snapshot = getattr(underlying, "_protrain_ddp_original_ignore", "") + first_size = len(first_ignore) + + # Simulate the resume hook's second materialize_offload call. + assert chunk_manager is not None + chunk_manager.restore_to_gpu() + chunk_manager.materialize_offload() + + second_ignore = list( + getattr(underlying, "_ddp_params_and_buffers_to_ignore", []) + ) + second_snapshot = getattr( + underlying, "_protrain_ddp_original_ignore", "" + ) + second_size = len(second_ignore) + + # The snapshot must survive intact (we never re-snapshot). + assert first_snapshot == second_snapshot, ( + f"_protrain_ddp_original_ignore snapshot drifted between " + f"materialize_offload calls: {first_snapshot!r} -> " + f"{second_snapshot!r}. The D2 invariant requires the " + f"pre-protrain snapshot to be captured once and reused." + ) + # The ignore set size must be stable across repeat + # materialize_offload calls — not double / triple / etc. + # the protrain set. + assert second_size == first_size, ( + f"_ddp_params_and_buffers_to_ignore grew from {first_size} to " + f"{second_size} names across a repeat materialize_offload " + f"call — D2 regression: the pre-fix union logic is leaking " + f"stale names across resume cycles." + ) + # And the set membership must be identical (not just same + # cardinality with different names). + assert set(first_ignore) == set(second_ignore), ( + f"_ddp_params_and_buffers_to_ignore CONTENT diverged across " + f"a repeat materialize_offload call. First-only names: " + f"{set(first_ignore) - set(second_ignore)}. Second-only " + f"names: {set(second_ignore) - set(first_ignore)}." 
@pytest.mark.gpu
def test_ddp_ignore_snapshot_survives_restore_and_rematerialize() -> None:
    """A caller-registered DDP ignore list must survive the ProTrain lifecycle.

    Checks that the pre-existing ``_ddp_params_and_buffers_to_ignore`` value is
    captured as the snapshot, is still present after the wrap and after a
    second ``materialize_offload``, and is put back verbatim by ``close()``.
    """
    pytest.importorskip("torch")
    import torch

    if not torch.cuda.is_available():
        pytest.skip("ProTrain D2 invariant requires CUDA.")

    ignore_attr = "_ddp_params_and_buffers_to_ignore"
    snapshot_attr = "_protrain_ddp_original_ignore"
    fake_pre_existing = ["caller_registered_ignore_name"]

    model, cfg = _build_tiny_lora_model()
    model = model.to("cuda")
    # Seed the model with a caller-owned ignore list before wrapping.
    setattr(model, ignore_attr, list(fake_pre_existing))

    wrapped, _optim = _wrap_protrain(
        model, cfg, force_all_persistent=False, zero3_shard=True
    )
    try:
        underlying = getattr(wrapped, "module", wrapped)
        chunk_manager = getattr(wrapped, "chunk_manager", None)

        placeholders_engaged = chunk_manager is not None and getattr(
            chunk_manager, "_shape_preserving_placeholders", False
        )
        if not placeholders_engaged:
            pytest.skip(
                "single-process Mode C downgrade path: "
                "shape-preserving placeholders not engaged."
            )

        def current_ignore_names() -> set:
            """Read the live ignore list off the wrapped module as a set."""
            return set(getattr(underlying, ignore_attr, []))

        # The snapshot has to be exactly what the caller registered.
        snap = getattr(underlying, snapshot_attr, None)
        assert snap == fake_pre_existing, (
            "snapshot did not capture pre-existing user value: "
            f"expected {fake_pre_existing!r}, got {snap!r}"
        )

        # The caller's name stays in the post-wrap ignore set.
        assert "caller_registered_ignore_name" in current_ignore_names()

        # Repeat the offload cycle; the caller's name must still be there.
        assert chunk_manager is not None
        chunk_manager.restore_to_gpu()
        chunk_manager.materialize_offload()
        assert "caller_registered_ignore_name" in current_ignore_names()
    finally:
        closer = getattr(wrapped, "close", None)
        if callable(closer):
            closer()

    # close() must hand the original list back ...
    restored = list(getattr(model, ignore_attr, []))
    assert restored == fake_pre_existing, (
        "close() did not restore the pre-existing ignore set: "
        f"expected {fake_pre_existing!r}, got {restored!r}"
    )
    # ... and drop the snapshot sentinel.
    assert not hasattr(model, snapshot_attr), (
        "_protrain_ddp_original_ignore should be cleared after close()"
    )
+ ) + except ImportError: + pytest.skip("deepspeed not installed; D3 invariant requires CPU adapter.") + + from axolotl.integrations.protrain.api import protrain_optimizer_wrapper + + model, cfg = _build_tiny_lora_model() + model = model.to("cuda") + # Force non-persistent chunks so a CpuFusedAdamAdapter actually + # gets constructed. small_chunk=True ensures N_chunk > 1 even on + # this tiny model so the n_persist=0 override produces chunks + # that ARE offloaded. + wrapped, _optim = _wrap_protrain( + model, + cfg, + force_all_persistent=False, + zero3_shard=False, + n_persist_override=0, + n_buffer_override=16, + n_swap_override=0, + n_checkpoint_override=0, + # OFFLOAD mode re-gathers saved tensors on backward via the per-block hook, avoiding the NONE-mode chunk-slot-reuse hazard. + n_offload_override=cfg.num_hidden_layers, + small_chunk=True, + ) + try: + chunk_manager = wrapped.chunk_manager + previous_cpu_optim = getattr(chunk_manager, "cpu_optim", None) + assert previous_cpu_optim is not None, ( + "test setup did not produce a CPU adapter — the D3 invariant " + "needs at least one non-persistent chunk to be exercised. " + "Check that force_all_persistent=False + n_persist_override=0 " + "+ small_chunk=True actually produced non-persistent chunks " + "for this model size." + ) + + # Patch shutdown to record invocation. + shutdown_calls: list[bool] = [] + orig_shutdown = previous_cpu_optim.shutdown + + def _tracked_shutdown(*args, **kwargs): + shutdown_calls.append(True) + return orig_shutdown(*args, **kwargs) + + previous_cpu_optim.shutdown = _tracked_shutdown # type: ignore[method-assign] + + # Re-run the optimizer wrapper — this is the path D3 fixed. + _new_optim = protrain_optimizer_wrapper(wrapped, lr=2e-3) + + # The new cpu_optim must be a different object AND the old + # one's shutdown must have been called. 
@pytest.mark.gpu
def test_rewrap_non_shape_preserving_clears_ddp_skip_state() -> None:
    """A non-shape-preserving rewrap must wipe DDP markers left by an earlier wrap.

    The model is pre-loaded with the three attributes a prior shape-preserving
    wrap would leave behind; after wrapping with ``force_all_persistent=True``
    and ``zero3_shard=False`` none of them may remain.
    """
    pytest.importorskip("torch")
    import torch

    if not torch.cuda.is_available():
        pytest.skip("ProTrain D1 invariant requires CUDA.")

    model, cfg = _build_tiny_lora_model()
    model = model.to("cuda")

    # Residue of a simulated earlier Mode C wrap, applied in this order.
    stale_state = {
        "_protrain_ddp_skip_init_sync": True,
        "_protrain_ddp_original_ignore": None,
        "_ddp_params_and_buffers_to_ignore": ["fake.stale.name"],
    }
    for attr_name, attr_value in stale_state.items():
        setattr(model, attr_name, attr_value)

    wrapped, _optim = _wrap_protrain(
        model, cfg, force_all_persistent=True, zero3_shard=False
    )
    try:
        # The skip-init-sync marker must be gone (or falsy).
        skip_marker = getattr(model, "_protrain_ddp_skip_init_sync", False)
        assert not skip_marker, (
            "D1 regression: _protrain_ddp_skip_init_sync persisted across "
            "a non-shape-preserving rebuild. DDP would silently skip "
            "init_sync on the rebuilt Mode A runtime."
        )

        # The snapshot sentinel must be removed outright.
        has_snapshot = hasattr(model, "_protrain_ddp_original_ignore")
        assert not has_snapshot, (
            "D1 regression: _protrain_ddp_original_ignore not cleared on "
            "non-shape-preserving rebuild."
        )

        # The snapshot was None, so the stale ignore list is expected to be
        # deleted rather than restored to some value.
        has_ignore_list = hasattr(model, "_ddp_params_and_buffers_to_ignore")
        assert not has_ignore_list, (
            "D1 regression: stale _ddp_params_and_buffers_to_ignore "
            "(set to a fake value before the rebuild) was not deleted "
            "during the non-shape-preserving rebuild teardown."
        )
    finally:
        closer = getattr(wrapped, "close", None)
        if callable(closer):
            closer()
+ wrapped, optim = _wrap_protrain( + model, + cfg, + force_all_persistent=False, + zero3_shard=False, + n_persist_override=0, + n_buffer_override=16, + n_swap_override=0, + n_checkpoint_override=0, + # OFFLOAD mode re-gathers saved tensors on backward via the per-block hook, avoiding the NONE-mode chunk-slot-reuse hazard. + n_offload_override=cfg.num_hidden_layers, + small_chunk=True, + ) + try: + # Train 3 steps under the initial wrap. + losses_pre = [ + _train_one_step(wrapped, optim, input_ids=input_ids, labels=labels) + for _ in range(3) + ] + for i, lv in enumerate(losses_pre): + assert math.isfinite(lv), f"phase 1 step {i}: non-finite loss {lv}" + + # Simulate the resume hook's in-process cycle. + underlying = getattr(wrapped, "module", wrapped) + chunk_manager = wrapped.chunk_manager + assert chunk_manager is not None + + # Step 1: tear down the CPU optim BEFORE restore_to_gpu (per + # the resume hook's preamble at plugin.py:557-572). This is + # the SAME teardown the production resume hook performs; + # ``restore_to_gpu`` is about to invalidate the CPU shards + # the adapter holds references to. + if getattr(chunk_manager, "cpu_optim", None) is not None: + chunk_manager.cpu_optim.shutdown() + + # Step 2: restore_to_gpu — rebinds param.data back to standalone + # GPU storage so the state_dict capture below sees the real + # parameter shapes (not the ``[0]`` placeholder that's bound + # while chunks are offloaded). The production HF Trainer save + # path has the same property: checkpoints are taken AFTER + # ProTrain's resume hook restores chunks to GPU, not while + # offloaded — otherwise the saved state_dict would have + # ``Size([0])`` entries that would fail to load on resume. + chunk_manager.restore_to_gpu() + + # Step 3: capture the saved state and load it back. 
In + # production this is the HF Trainer's + # ``trainer.save_state_dict`` → user copies the checkpoint → + # ``_load_from_checkpoint`` cycle; here we do the round-trip + # in-process to keep the smoke unit-scoped. + saved_state = { + k: v.detach().clone() for k, v in underlying.state_dict().items() + } + underlying.load_state_dict(saved_state, strict=False) + + # Second materialize_offload on the same manager actually moves bytes thanks to the non-persistent overrides. + chunk_manager.materialize_offload() + + # Rebuild the optimizer adapter; cpu_optim is None here so this exercises the "no prior adapter" branch. + optim_resumed = protrain_optimizer_wrapper(wrapped, lr=1e-3) + + # Train 3 more steps after the simulated resume. + losses_post = [ + _train_one_step(wrapped, optim_resumed, input_ids=input_ids, labels=labels) + for _ in range(3) + ] + for i, lv in enumerate(losses_post): + assert math.isfinite(lv), ( + f"phase 3 (post-resume) step {i}: non-finite loss {lv}" + ) + + # Continuity: the first post-resume loss should not be wildly + # larger than the last pre-resume loss. Allow 5x as a generous + # bound that catches catastrophic divergence (NaN-precursor, + # state corruption) but tolerates the cold-started optimizer + # state. 
+ assert losses_post[0] < 5.0 * losses_pre[-1] + 1.0, ( + f"resume produced catastrophic divergence: " + f"pre-end={losses_pre[-1]:.4f}, post-start={losses_post[0]:.4f} " + f"(>5x is treated as a state-corruption signal)" + ) + print( + f"\nresume-robustness in-process cycle: " + f"losses_pre={losses_pre} losses_post={losses_post}" + ) + finally: + close = getattr(wrapped, "close", None) + if callable(close): + close() diff --git a/tests/protrain/test_sharded_lora_offload.py b/tests/protrain/test_sharded_lora_offload.py new file mode 100644 index 0000000000..c30e335f95 --- /dev/null +++ b/tests/protrain/test_sharded_lora_offload.py @@ -0,0 +1,359 @@ +"""Multi-rank sharded LoRA gather must restore full param.data shape on the compute stream, avoiding the ToCopyBackward [0] shape mismatch.""" + +from __future__ import annotations + +import os +import sys + +import pytest + +pytestmark = pytest.mark.gpu + + +# --------------------------------------------------------------------------- +# mp.spawn worker bodies (must be top-level so the spawn fork can pickle them) +# --------------------------------------------------------------------------- + + +def _build_tiny_lora_model_cpu(): + """Tiny CPU LoRA-wrapped Linear stack; bf16 base + fp32 lora factors reproduces the mixed-dtype region split.""" + import torch + from torch import nn + + torch.manual_seed(13) + + class _LoraWrappedLinear(nn.Module): + """Mimics PEFT's LoRA-wrapped Linear so chunk-manager offload sees lora_A/lora_B as separate slots in the same chunk.""" + + def __init__(self, in_dim: int, out_dim: int, r: int) -> None: + super().__init__() + self.base_layer = nn.Linear(in_dim, out_dim, bias=False).to(torch.bfloat16) + self.lora_A = nn.ModuleDict({"default": nn.Linear(in_dim, r, bias=False)}) + self.lora_B = nn.ModuleDict({"default": nn.Linear(r, out_dim, bias=False)}) + # Mirror PEFT's autocast_adapter_dtype default: upcast LoRA + # factor weights to fp32 even when the base is bf16. 
def _worker_sharded_lora_gather_rebinds(
    rank: int, world_size: int, tmpdir: str
) -> None:
    """2-rank gloo worker: a sharded gather must rebind every param to its full shape.

    Flow: build the tiny mixed-dtype LoRA model, offload it through a
    ``ChunkManager`` constructed with ``zero3_shard=True``, check that the
    LoRA factors collapsed to the ``[0]`` placeholder, gather chunk 0, then
    check that every parameter has its pre-offload shape and exactly its
    pre-offload values.

    If the backend rejects an operation, the worker writes
    ``rank{rank}.skip`` into ``tmpdir`` and returns; the parent process is
    expected to turn that marker into a ``pytest.skip``.

    Args:
        rank: Index of this process (supplied by ``mp.spawn``).
        world_size: Total number of spawned processes.
        tmpdir: Shared directory holding the file rendezvous and skip markers.
    """
    import contextlib
    import os as _os

    import torch
    import torch.distributed as dist

    from axolotl.integrations.protrain.chunk.buffer_pool import BufferPool
    from axolotl.integrations.protrain.chunk.layout import build_layout
    from axolotl.integrations.protrain.chunk.manager import ChunkManager
    from axolotl.integrations.protrain.chunk.pinned_alloc import PinnedHostMemory
    from axolotl.integrations.protrain.types import BlockId, ChunkId, ParamId

    def _record_skip(tag: str, exc: BaseException) -> None:
        """Write this rank's skip marker so the parent can skip instead of fail."""
        _os.makedirs(tmpdir, exist_ok=True)
        with open(_os.path.join(tmpdir, f"rank{rank}.skip"), "w") as f:
            f.write(f"{tag}: {exc}\n")

    _os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    _os.environ.setdefault("MASTER_PORT", "29605")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{tmpdir}/rendezvous-sharded-lora",
        rank=rank,
        world_size=world_size,
    )

    try:
        model = _build_tiny_lora_model_cpu()

        # Layout: one block, all params in one chunk (large S_chunk).
        block_spans: dict = {}
        for name, _p in model.named_parameters():
            block_spans.setdefault(BlockId(0), []).append(ParamId(name))  # type: ignore[index]
        exec_order = [ParamId(n) for n, _ in model.named_parameters()]
        S_chunk = 1 << 14  # 16 KB — fits the tiny model
        layout = build_layout(model, exec_order, S_chunk, block_spans)

        # Snapshot pre-offload shapes and values so the rebind invariant can
        # be asserted post-gather.
        pre_shapes = {str(name): tuple(p.shape) for name, p in model.named_parameters()}
        pre_data = {
            str(name): p.detach().clone().cpu() for name, p in model.named_parameters()
        }

        host = PinnedHostMemory(n_buffer=1, S_chunk=layout.S_chunk)
        pool = BufferPool(
            n_buffer=1,
            S_chunk=layout.S_chunk,
            pinned_host=host,
            device=torch.device("cpu"),
        )
        mgr = ChunkManager(
            model=model,
            layout=layout,
            n_persist=0,
            buffer_pool=pool,
            cpu_optim=None,
            gpu_optim=None,
            device=torch.device("cpu"),
            world_size=world_size,
            rank=rank,
            zero3_shard=True,
        )

        try:
            mgr.materialize_offload()
        except RuntimeError as exc:
            if "gloo" in str(exc).lower():
                _record_skip("gloo-unsupported", exc)
                return
            raise

        # Post-offload invariant: every offloaded LoRA param.data is the [0]
        # empty placeholder. This is the state the rebind is designed to keep
        # autograd from recording as a source shape.
        for name, p in model.named_parameters():
            if name in {"h.0.base_layer.weight"}:
                continue  # base weight may or may not be offloaded
            assert tuple(p.shape) == (0,), (
                f"rank {rank}: post-materialize_offload, '{name}' should "
                f"be the [0] empty placeholder, got shape {tuple(p.shape)}"
            )

        # Sharded gather collective: after this, every LoRA factor's
        # param.data must reflect its real shape.
        try:
            mgr.gather(ChunkId(0))
        except RuntimeError as exc:
            message = str(exc).lower()
            if "not implemented" in message or "nccl" in message:
                _record_skip("gloo-collective-unsupported", exc)
                return
            raise

        # Every param.data must hold its real shape after the sharded gather.
        for name, p in model.named_parameters():
            assert tuple(p.shape) == pre_shapes[str(name)], (
                f"rank {rank}: post-gather, '{name}' shape "
                f"{tuple(p.shape)} != pre-offload {pre_shapes[str(name)]}; "
                "the sharded gather did not restore the real shape, so "
                "any autograd source-shape derivation against this state "
                "would record [0] and backward would fail with "
                "'ToCopyBackward0 ... shape compatible with [0]'."
            )

        # Gathered values must match the pre-offload snapshot exactly.
        # torch.equal is used rather than allclose(atol=0.0): allclose still
        # applies its default relative tolerance, so it would let small
        # perturbations through.
        for name, p in model.named_parameters():
            snap = pre_data[str(name)]
            gathered = p.data.cpu().float()
            assert torch.equal(gathered, snap.float()), (
                f"rank {rank}: post-gather '{name}' bytes diverge from "
                "pre-offload snapshot."
            )

        mgr.uninstall()
        host.close()

    finally:
        # Best-effort rendezvous before teardown; a failed barrier must not
        # mask the real error.
        with contextlib.suppress(Exception):
            dist.barrier()
        dist.destroy_process_group()
rank=rank, + zero3_shard=True, + ) + + try: + mgr.materialize_offload() + except RuntimeError as exc: + if "gloo" in str(exc).lower(): + _os.makedirs(tmpdir, exist_ok=True) + with open(_os.path.join(tmpdir, f"rank{rank}.skip"), "w") as f: + f.write(f"gloo-unsupported: {exc}\n") + return + raise + + # Build a Scheduler. The block_map is needed by Scheduler's + # constructor; for this test we only exercise + # ``ensure_chunks_resident`` which doesn't actually consult + # block-mode keys, so OFFLOAD-everywhere is fine. + block_map = {BlockId(0): BlockMode.OFFLOAD} + # effective_h2d_bps / effective_d2h_bps are telemetry-only here; ensure_chunks_resident does not consult them. + scheduler = Scheduler( + chunk_manager=mgr, + block_map=block_map, + layout=layout, + effective_h2d_bps=1.0e10, + effective_d2h_bps=1.0e10, + ) + + # ensure_chunks_resident routes synchronously through the chunk manager so the rebind is inline. + try: + scheduler.ensure_chunks_resident([ChunkId(0)]) + except RuntimeError as exc: + if "not implemented" in str(exc).lower() or "nccl" in str(exc).lower(): + with open(_os.path.join(tmpdir, f"rank{rank}.skip"), "w") as f: + f.write(f"gloo-collective-unsupported: {exc}\n") + return + raise + + # The container-hook contract: after ensure_chunks_resident + # returns, every LoRA factor param has its real shape and the + # autograd source-shape derivation step (the + # ``ToCopyBackward0`` source-shape recorder in the multi-GPU + # failure mode) reads the correct shape. + for name, p in model.named_parameters(): + assert tuple(p.shape) == pre_shapes[str(name)], ( + f"rank {rank}: after ensure_chunks_resident, '{name}' " + f"shape {tuple(p.shape)} != pre-offload " + f"{pre_shapes[str(name)]}. The Scheduler did not synchronously " + "rebind the LoRA factor's param.data — autograd would " + "record [0] as the source shape and backward fails." + ) + + # Second call must hit the _active_chunks fast path without behavior change (idempotency contract). 
def _check_skip_files(tmpdir: str, world_size: int) -> None:
    """Surface a worker's ``rank{N}.skip`` marker as a ``pytest.skip``.

    Ranks are checked in ascending order and the first marker found wins; its
    stripped contents become the skip reason. Returns ``None`` when no rank in
    ``range(world_size)`` left a marker.

    Args:
        tmpdir: Directory the workers wrote their markers into.
        world_size: Number of ranks to check.
    """
    for r in range(world_size):
        skip_path = os.path.join(tmpdir, f"rank{r}.skip")
        # EAFP: open directly instead of exists()-then-open. errors="replace"
        # keeps an undecodable reason from raising and masking the skip.
        try:
            with open(skip_path, encoding="utf-8", errors="replace") as f:
                reason = f.read().strip()
        except FileNotFoundError:
            continue
        pytest.skip(f"sharded-lora gloo worker skipped: {reason}")
def _hw_profile_3090():
    """Build a ``HardwareProfile`` with single-GPU RTX 3090 figures."""
    from axolotl.integrations.protrain.types import HardwareProfile

    gib = 1 << 30
    pcie_bps = 16.0 * gib  # same figure is used for both transfer directions
    return HardwareProfile(
        gpu_sku="NVIDIA GeForce RTX 3090",
        gpu_memory_bytes=24 * gib,
        gpu_count=1,
        pcie_h2d_bps=pcie_bps,
        pcie_d2h_bps=pcie_bps,
        has_nvlink=False,
    )


def _tiny_gpt2(device):
    """Return a seeded 4-layer GPT-2 LM-head model moved to ``device``.

    Four layers leave room for distinct ``n_swap`` / ``n_checkpoint`` values.
    """
    pytest.importorskip("transformers")
    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    # Seed before construction so the weight init is reproducible.
    torch.manual_seed(0)
    dims = {
        "n_layer": 4,
        "n_head": 2,
        "n_embd": 64,
        "vocab_size": 128,
        "n_positions": 128,
    }
    lm = GPT2LMHeadModel(GPT2Config(**dims))
    return lm.to(device)
# ---------------------------------------------------------------------------
# Test 1 — pure unit: synth_trace_from_overrides field shapes
# ---------------------------------------------------------------------------


def test_synth_trace_from_overrides_shape() -> None:
    """The synthetic ``ProfilerTrace`` must expose the field shapes consumers rely on."""
    pytest.importorskip("torch")
    pytest.importorskip("transformers")
    import torch

    from axolotl.integrations.protrain.profiler.trace import (
        synth_trace_from_overrides,
    )
    from axolotl.integrations.protrain.types import ProfilerTrace

    tiny_model = _tiny_gpt2(torch.device("cpu"))
    trace = synth_trace_from_overrides(
        tiny_model,
        batch_size=2,
        seq_len=64,
        device="cpu",
        world_size=1,
        measure_pcie_bps=False,  # CPU-only test path
    )

    assert isinstance(trace, ProfilerTrace)

    # Op-order is empty — _param_exec_order falls back to named_parameters
    # declaration order, which is correct for uniform transformer stacks.
    assert trace.op_order == ()
    # All per-op tables are empty on the synthetic trace.
    for empty_table in (
        "intra_op_delta",
        "inter_op_delta",
        "op_latencies",
        "nccl_gather_s",
        "nccl_reduce_s",
    ):
        assert getattr(trace, empty_table) == {}

    # GPT-2 with n_layer=4 should produce 4 entries in activation_sizes, but
    # the discovery path may also pick up nested sub-blocks, so only require
    # at least one entry and that every value is a positive int.
    assert len(trace.activation_sizes) >= 1
    for bid, size in trace.activation_sizes.items():
        assert isinstance(size, int) and size > 0, (
            f"activation_sizes[{bid}] = {size}; expected positive int"
        )

    # model_state_bytes is a real measurement. Bounds are deliberately
    # liberal: the tiny model's state should be a few MB at most (rough hand
    # estimate — TODO confirm), far below the 100 MB sanity ceiling.
    assert 0 < trace.model_state_bytes < 100 * (1 << 20)

    # PCIe defaults when measure_pcie_bps=False: 13 GB/s Gen3 prior.
    expected_pcie = pytest.approx(13e9)
    assert trace.pcie_h2d_bps == expected_pcie
    assert trace.pcie_d2h_bps == expected_pcie

    # Cache key fields populated.
    assert (trace.bs, trace.seq, trace.world) == (2, 64, 1)
    assert isinstance(trace.arch_hash, str) and len(trace.arch_hash) == 64

    # Chunked-runtime fields default to "no measurement" sentinels so the
    # cost model collapses to its earlier path.
    for unmeasured in (
        "cpu_adam_bytes_per_sec",
        "gpu_adam_bytes_per_sec",
        "steady_bwd_chunked_wall_s",
    ):
        assert getattr(trace, unmeasured) == 0.0
+ ) + + monkeypatch.setattr(model_wrapper_mod, "run_trace", _exploding_run_trace) + + device = torch.device("cuda") + model = _tiny_gpt2(device) + hw = _hw_profile_3090() + + # Compute N_chunk/N_block dynamically so layout heuristic shifts don't trip min_n_buffer_for before the skip gate engages. + from axolotl.integrations.protrain.block.layout_rules import ( + discover_blocks, + flatten_block_trees, + ) + from axolotl.integrations.protrain.chunk.layout import build_layout + + discovered = discover_blocks(model) + flat_blocks = flatten_block_trees(discovered) + n_block_estimate = len(flat_blocks) + # Mirror the wrapper's layout build so n_persist_override == N_chunk holds when the override path runs. + block_spans: dict = {} + for name, param in model.named_parameters(): + # Find which block (if any) this param belongs to via the + # discovered block list. + for block_idx, block_module in enumerate(flat_blocks): + if any(p is param for p in block_module.parameters()): + from axolotl.integrations.protrain.types import ( + BlockId, + ParamId, + ) + + block_spans.setdefault(BlockId(block_idx), []).append(ParamId(name)) + break + from typing import cast as _cast + + from axolotl.integrations.protrain.types import ParamId as _ParamId + + exec_order = [_cast(_ParamId, n) for n, _ in model.named_parameters()] + # 4 MiB S_chunk matches the wrapper's default for tiny models; + # the exact value isn't load-bearing as long as the same value is + # used inside ``protrain_model_wrapper`` (which it will be, since + # the override path also takes the wrapper's default S_chunk). 
+ layout = build_layout(model, exec_order, 4 << 20, block_spans) + n_chunk_estimate = layout.N_chunk + + wrapped = protrain_model_wrapper( + model, + model_config=None, + hardware_profile=hw, + batch_size=2, + seq_len=64, + capacity_bytes=1 << 30, + cache_dir=str(tmp_path), # force cache miss + n_persist_override=n_chunk_estimate, + n_buffer_override=0, + n_swap_override=0, + n_checkpoint_override=n_block_estimate, + n_offload_override=0, + auto_mode=False, + ) + try: + assert isinstance(wrapped, WrappedModel) + # The override path's SearchResult round-trips into the wrapper. + assert wrapped.search_result is not None + assert wrapped.search_result.cfg.n_swap == 0 + # n_checkpoint is bounded by N_block which is what activation_sizes + # maps; the synthetic trace populates one entry per discovered + # block. The wrapper accepted the override so the bounds check + # passed — sanity check that we land at n_block from the synth. + assert wrapped.search_result.cfg.n_checkpoint <= n_block_estimate + + finally: + wrapped.close() + + +@pytest.mark.gpu +@pytest.mark.skipif(not _SEARCH_AVAILABLE, reason=_SEARCH_SKIP_REASON) +def test_run_trace_invoked_without_override(gpu_device, monkeypatch, tmp_path) -> None: # noqa: ARG001 — gpu_device fixture activates CUDA masking + """Control: without overrides, run_trace must fire exactly once on a fresh cache_dir.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + from axolotl.integrations.protrain.api import ( + model_wrapper as model_wrapper_mod, + protrain_model_wrapper, + ) + from axolotl.integrations.protrain.types import WrappedModel + + call_count = {"n": 0} + real_run_trace = model_wrapper_mod.run_trace + + def _counting_run_trace(*args, **kwargs): + call_count["n"] += 1 + return real_run_trace(*args, **kwargs) + + monkeypatch.setattr(model_wrapper_mod, "run_trace", _counting_run_trace) + + device = torch.device("cuda") + model = _tiny_gpt2(device) 
+ hw = _hw_profile_3090() + + wrapped = protrain_model_wrapper( + model, + model_config=None, + hardware_profile=hw, + batch_size=2, + seq_len=64, + capacity_bytes=1 << 30, + cache_dir=str(tmp_path), # force cache miss + # No overrides → searcher path → run_trace must fire. + auto_mode=False, + ) + try: + assert isinstance(wrapped, WrappedModel) + assert call_count["n"] == 1, ( + f"run_trace was called {call_count['n']} times; expected exactly 1 " + "on the searcher path with a fresh cache_dir" + ) + + finally: + wrapped.close() + + +# --------------------------------------------------------------------------- +# Test 3 — partial overrides do NOT skip the trace pass +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu +@pytest.mark.skipif(not _SEARCH_AVAILABLE, reason=_SEARCH_SKIP_REASON) +def test_partial_overrides_do_not_skip_trace(gpu_device, monkeypatch, tmp_path) -> None: # noqa: ARG001 — gpu_device fixture activates CUDA masking + """Partial overrides (e.g. 
only n_persist) must not trigger the skip; the gate requires all four knobs.""" + pytest.importorskip("torch") + import torch + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA runtime") + + from axolotl.integrations.protrain.api import ( + model_wrapper as model_wrapper_mod, + protrain_model_wrapper, + ) + from axolotl.integrations.protrain.types import WrappedModel + + call_count = {"n": 0} + real_run_trace = model_wrapper_mod.run_trace + + def _counting_run_trace(*args, **kwargs): + call_count["n"] += 1 + return real_run_trace(*args, **kwargs) + + monkeypatch.setattr(model_wrapper_mod, "run_trace", _counting_run_trace) + + device = torch.device("cuda") + model = _tiny_gpt2(device) + hw = _hw_profile_3090() + + wrapped = protrain_model_wrapper( + model, + model_config=None, + hardware_profile=hw, + batch_size=2, + seq_len=64, + capacity_bytes=1 << 30, + cache_dir=str(tmp_path), + n_persist_override=1, # only ONE override set + # The other three knobs are None ⇒ partial override ⇒ NO skip. + auto_mode=False, + ) + try: + assert isinstance(wrapped, WrappedModel) + assert call_count["n"] == 1, ( + f"run_trace was called {call_count['n']} times; expected exactly 1 " + "with partial overrides (only n_persist set)" + ) + + finally: + wrapped.close()