forked from axolotl-ai-cloud/axolotl
-
Notifications
You must be signed in to change notification settings - Fork 0
feat: ProTrain integration with BlockMode.OFFLOAD (Option B complete) #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
129 commits
Select commit
Hold shift + click to select a range
fee826b
M0: add ProTrain plugin design doc
thad0ctor 9d1a654
M1a: freeze ProTrain shared types
thad0ctor 431042b
M1: memory-aware profiler
thad0ctor 28d833d
M2: hierarchical chunk manager
thad0ctor 7e3ff76
M3: interleaved block manager
thad0ctor aa7cf8c
M2 test: fix chunk-manager test contracts and pinned-alloc ctypes path
thad0ctor 81a93b4
M4a: cost models + exhaustive searcher
thad0ctor 5c1b19b
M4b: runtime scheduler + api wrappers
thad0ctor 7e03e05
M4 integration: xfail with BufferPool-exhaustion at forward-block bou…
thad0ctor cc62164
M4 integration hardening: fix 4 bugs, document 2 runtime gaps
thad0ctor afa21c7
M5: Axolotl plugin glue + example + e2e test
thad0ctor 10b0248
M4.5: implement init-time chunk offload + per-param grad offload
thad0ctor 875577c
M6: multi-GPU 4x 3090 throughput validation
thad0ctor 8f1f5ba
tests: harden 7B capacity-safety, add SWAP/monotonicity/multi-GPU-der…
thad0ctor 5af9c2e
chunk: fix CPU-Adam race, view-dtype alignment, adapter order, data r…
thad0ctor 45cff47
plugin: wire create_optimizer dispatch, broaden mutex, fix defaults +…
thad0ctor c481142
profiler: record per-op latencies; cost model uses measured compute; …
thad0ctor c59ec09
M7: true ZeRO-3 chunk sharding (all_gather / reduce_scatter)
thad0ctor 54d3fe6
M7 followup: cost-model sharding awareness + mixed-dtype shard support
thad0ctor dbf47bb
bench: multi-GPU throughput comparison (DDP / replicated / ZeRO-3)
thad0ctor a6ea055
plugin: auto-select multi-GPU mode (A/B/C) based on workload fit + CP…
thad0ctor 7d0892a
profiler: add CPU+GPU Adam microbenchmarks; loosen 7B runtime tolerance
thad0ctor a1e67a5
profiler: measure hook-less steady-state wall time; cost model scales…
thad0ctor 95243f7
M7 cost-model close-out: PCIe plumb-through + steady-state cap + asym…
thad0ctor 803ac6c
profiler: record steady_fwd_peak_bytes; memory cost model caps at mea…
thad0ctor 814f27e
profiler: record per-block steady peaks; memory cost model uses them …
thad0ctor f2bd2fa
cost+search: extract hot_iter_peak_cap helper; plumb into searcher's …
thad0ctor a2234f3
profiler: multi-iter hot-loop steady measurement; cost model uses mea…
thad0ctor 39e966f
test: loosen 7B runtime tolerance 0.25 -> 0.35 for 3090-vs-3090Ti SKU…
thad0ctor 1f69fdc
review followups: bump TRACE_VERSION 6 -> 7; correct 7B test docstring
thad0ctor 3e09937
docs: align comments with paper Eqs. and document trace v7 / swap-buf…
thad0ctor 0c08dcb
profiler: implement on-demand replay (param offload + saved-tensor sp…
thad0ctor 41bd25d
profiler: implement multi-rank NCCL benchmarks (TRACE_VERSION 7 -> 8)
thad0ctor 6b60b87
cost-model: per-SKU compute-rate calibration + LoRA-aware bwd/fwd fal…
thad0ctor e7393c2
M5/M6: re-add decreasing-loss check + env-var gate skipped E2E tests
thad0ctor fa50735
docs: ratify two workstream-shape drifts from plan.md
thad0ctor e9ef434
on-demand: clean up partially-spilled params on __enter__ failure
thad0ctor 3d27896
on-demand: support backward by routing unpack copy through self.device
thad0ctor a24fb7e
profiler: state-aware on-demand threshold (params + grads + optim state)
thad0ctor 830fe37
profiler: fold requires_grad into arch_hash (TRACE_VERSION 8 -> 9)
thad0ctor 1513b17
hw_bench: fix measure_nccl shard math comment (rounded down, not up)
thad0ctor 7460d00
hw_bench: empty_cache between NCCL payload sizes
thad0ctor 7e565b0
test(plugin_e2e): align comment with strict-< loss-reduction bar
thad0ctor 7caa731
test(multi_gpu_benchmark): guard recorded-threshold tests on canonica…
thad0ctor bc03a93
test(profiler): smoke-test cost-model under engaged on-demand
thad0ctor 1c6f8e9
docs: update DESIGN.md length self-reference (250 -> 260 lines)
thad0ctor fd27f0b
docs: document on-demand inflation of intra/inter_op_delta in DESIGN.md
thad0ctor f4434c2
profiler: dedupe _arch_hash, import canonical version in model_wrapper
thad0ctor eb82df3
on-demand: prepend pre-gather so intra_op_delta excludes gather bytes
thad0ctor 63182a4
docs: document NCCL measurement gap in default plugin path
thad0ctor f5e9f7a
chunk-layout: derive exec order from trace.op_order (paper §3.1.1)
thad0ctor 10c5658
plugin: late-bind NCCL measurement in post_trainer_create
thad0ctor 29600aa
chunk: add ChunkManager.restore_to_gpu (materialize_offload inverse)
thad0ctor c60b4ce
cost-model: D1b translation for phase-2 chunked backward (TRACE_VERSI…
thad0ctor be94640
wrapper: refactor runtime construction into _construct_runtime helper
thad0ctor 5e6ed13
phase-2: chunked-runtime backward measurement + bootstrap-rebuild plu…
thad0ctor a3c95fd
test(7b-integration): tighten runtime tolerance 0.35 -> 0.25 for v10/…
thad0ctor ec65f68
optim-partition: route by _persistent_ids set, not n_persist prefix
thad0ctor e79eb06
chunk: implement sharded restore_to_gpu via per-region all_gather
thad0ctor d390ce3
search: add CPU-RAM hard feasibility filter (cpu_capacity_bytes)
thad0ctor 0c9acc4
phase-2: chunked-runtime forward measurement (TRACE_VERSION 11)
thad0ctor 8ea2c82
Bypass chunk comm for phase2 backward runtime
thad0ctor 71793b0
Merge branch 'protrain-sharded-restore' into protrain-paper-fidelity
thad0ctor 6841205
Merge branch 'protrain-cpu-feasibility-filter' into protrain-paper-fi…
thad0ctor 2d88e72
Merge branch 'protrain-phase2-backward-bypass' into protrain-paper-fi…
thad0ctor e8d14db
docs(protrain): align phase-2 calibration comments
thad0ctor 99afc31
phase-2: calibrate checkpointed offload runtime
thad0ctor 7588ec2
docs(protrain): add optimizer checkpoint/resume design note
thad0ctor 5ce0c15
feat(protrain): Phase 1 optimizer checkpoint/resume (single-rank, non…
thad0ctor a809491
docs(protrain): Phase 2 checkpoint design — multi-rank + ZeRO-3 sharded
thad0ctor b959dfb
fix(protrain): Phase 1 review fixes — three bugs caught by review
thad0ctor 865e5b7
test(protrain): subprocess functional-equivalence + post-load CPU pin…
thad0ctor 00e9832
feat(protrain): Phase 2 Mode-B optimizer checkpoint (multi-rank repli…
thad0ctor aefb819
test(protrain): Phase 2 Mode-B unit + multi-rank gloo tests
thad0ctor 7f6a9b9
feat(protrain): Phase 2 Mode-C optimizer checkpoint (ZeRO-3 sharded)
thad0ctor 164cc3e
test(protrain): Phase 2 Mode-C unit + multi-rank gloo tests
thad0ctor b70ba03
test(protrain): fix flaky tiny_llama loss check + 4gpu MASTER_PORT co…
thad0ctor bb02f96
fix(protrain): Mode-C verify gate, bf16 hash, broadcast-aware size gate
thad0ctor a722635
Merge protrain-test-fixes into the consolidated ProTrain branch
thad0ctor 1c94394
docs(protrain): align schema + design notes with Phase 2 implementation
thad0ctor 3bb9259
feat(protrain): batch_factory abstraction for non-causal-LM calibration
thad0ctor 34a30e3
feat(protrain): preflight NCCL measurement via early dist init
thad0ctor cf4055a
fix(protrain): Mode-C lockstep failure protocol + stray-file rejection
thad0ctor 96c6a7d
perf(protrain): coalesce persistent-chunk grad reduce + clarify gathe…
thad0ctor a80a848
fix(protrain): translate phase-2 chunked backward across n_buffer in …
thad0ctor 348e060
test(protrain): add Mistral Mode-C + SmolLM2 full-FT validation cells…
thad0ctor 7319f56
test(protrain): add post-v1 validation matrix cells — seq-cls, enc-de…
thad0ctor 59740c3
feat(protrain): paper-real activation SWAP path (option 2A, minimum v…
thad0ctor d384ce5
feat(protrain): T5/encoder-decoder support via discover_blocks BlockTree
thad0ctor 37c05d5
feat(protrain): paper-real activation SWAP via saved_tensors_hooks (M5+)
thad0ctor f5d9aa6
feat(protrain): offline Mode-C cross-world-size reshard tool + test
thad0ctor f5d0aa6
perf(protrain): document SWAP backward unpack/free autograd-engine floor
thad0ctor 007b7be
feat(protrain): per-tree cost-model walk for encoder-decoder peak acc…
thad0ctor 5747c81
feat(protrain): opt-in Mode-C online cross-world-size reshard on load
thad0ctor 71cd8de
refactor(protrain): /simplify pass on round-2 commits
thad0ctor 2ef5f26
test(protrain): M5 CLI end-to-end smoke (axolotl train via subprocess)
thad0ctor 1da51c6
test(protrain): M6 Mode-C external baseline vs DeepSpeed ZeRO-3
thad0ctor 78e4259
fix(protrain): re-bind shard_param.grad on set_to_none=True after zer…
thad0ctor 94bd0c9
fix(protrain): make CpuFusedAdam-unavailable warning honest about cor…
thad0ctor 5d29598
refactor(protrain): public-promote cost-model helpers used by searcher
thad0ctor 5df91d5
refactor(protrain): extract _perform_online_reshard helper from load …
thad0ctor 817e494
refactor(protrain): persist BlockId->tree_index in ProfilerTrace
thad0ctor 491b5e2
fix(protrain): address CodeRabbit PR #10 — 35 findings + multi-rank c…
thad0ctor e900a69
fix(protrain): address CodeRabbit PR #10 May-3 round (18 findings)
thad0ctor 646d3ea
fix(protrain): CodeRabbit PR #10 round-2 + CI cleanup (6 findings + l…
thad0ctor a6b4c20
fix(protrain): CodeRabbit PR #10 round-3 (12 findings + test contract…
thad0ctor 4454317
fix(protrain): CodeRabbit PR #10 round-4 (6 inline + 3 duplicates + C…
thad0ctor b0df26f
fix(protrain): CodeRabbit PR #10 round-5 (7 findings + 2 CI test fixes)
thad0ctor 0c6997a
fix(protrain): CodeRabbit PR #10 round-6 (4 findings + caplog propaga…
thad0ctor edc20fa
fix(protrain): CodeRabbit PR #10 round-7 (3 findings)
thad0ctor 6c5836e
fix(protrain): CodeRabbit PR #10 round-7b nitpick (post_trainer_creat…
thad0ctor 430b4a0
fix(protrain): CodeRabbit PR #10 round-7c (LOCAL_RANK guard at pre-wr…
thad0ctor 4934673
fix(protrain): CodeRabbit PR #12 round-1 (24 findings + 2 test fixes)
thad0ctor 16809dc
fix(protrain): CodeRabbit PR #12 round-2 (10 findings + option-B desi…
thad0ctor 8264f77
feat(protrain): Option B M1 + M2 — BlockMode.OFFLOAD types + runtime …
thad0ctor a1ab8af
feat(protrain): CodeRabbit round-3 (10 findings) + Option B M3 (sched…
thad0ctor ea20710
feat(protrain): Option B M4 — cost model + searcher (n_offload axis)
thad0ctor 94fbca1
fix(protrain): CodeRabbit PR #12 round-4 (6 findings on a1ab8aff)
thad0ctor c7c155f
feat(protrain): CodeRabbit round-5 + Option B M5 — OFFLOAD ships end-…
thad0ctor d44f9c9
fix(protrain): CodeRabbit PR #13 round-1 (19 findings)
thad0ctor a927fa7
fix(protrain): CodeRabbit PR #13 round-2 (10 findings)
thad0ctor ac9e00f
fix(protrain): CodeRabbit PR #13 round-3 (2 findings)
thad0ctor b83a89f
fix(protrain): drop unused import to satisfy pre-commit ruff
thad0ctor 0ccbc5d
ci: disable uv persistent cache in sdist job to fix Py3.12 install
thad0ctor 48b9311
fix(protrain): CodeRabbit PR #14 round-1 (24 findings)
thad0ctor 5383cdb
fix(protrain): CodeRabbit PR #14 round-2 (6 findings + lint fix)
thad0ctor 018445d
fix(protrain): CodeRabbit PR #15 round-1 (14 findings)
thad0ctor c8f752f
fix(protrain): CodeRabbit PR #15 round-2 (3 findings)
thad0ctor c99b23a
fix(protrain): CodeRabbit PR #15 round-3 (2 findings)
thad0ctor File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # ProTrain 7B/8B LoRA on a single RTX 3090 (24 GB) | ||
| # | ||
| # Opts into the ProTrain plugin via `plugins:`. The plugin's post_model_load | ||
| # hook wraps the model with the hierarchical chunk manager + interleaved | ||
| # block manager. The plugin's post_trainer_create hook then installs | ||
| # `protrain_optimizer_wrapper` on trainer.optimizer — this is the real | ||
| # wiring path because Axolotl's OptimizerMixin.create_optimizer does NOT | ||
| # dispatch to PluginManager.create_optimizer (see plugin.py for why). | ||
| # | ||
| # Mode selection is automatic. Leave ``protrain_auto_mode`` on (default); | ||
| # the plugin runs the searcher and then picks Mode A (GPU-resident / DDP- | ||
| # friendly), Mode B (replicated CPU-offload), or Mode C (ZeRO-3 sharded | ||
| # CPU-offload) based on the model's fit and per-rank CPU RAM. For 7B/8B | ||
| # LoRA on a single 24 GB 3090 the selector picks Mode A — the frozen | ||
| # base fits in fp16 alongside LoRA optimizer state + activations, and | ||
| # DDP scales at ~3.6x on PCIe Gen3 4x 3090 while ZeRO-3 sharding on | ||
| # the same rig lands at ~0.7x (see DESIGN.md §Multi-GPU). | ||
| # | ||
| # Set ``protrain_auto_mode: false`` below only if you need explicit | ||
| # control (reproducing a specific benchmark configuration, or a | ||
| # heterogeneous-CPU setup where the node-RAM/world-size heuristic is | ||
| # wrong). In that case ``protrain_force_all_persistent`` and | ||
| # ``protrain_zero3_shard`` become the explicit overrides. | ||
|
|
||
| # NousResearch/Meta-Llama-3-8B-Instruct is the 8B-class Llama mirror on HF | ||
| # Hub that is *not* gated (public-license, no HF-terms accept step). It was | ||
| # chosen over mistralai/Mistral-7B-v0.3 (gated: 401 for new users) and | ||
| # meta-llama/Llama-3.1-8B (gated: requires accepted license) for frictionless | ||
| # downloads in CI and first-run contributors. HuggingFaceH4/zephyr-7b-beta is | ||
| # an equivalent ungated fallback if the Llama arch is undesirable. | ||
| base_model: NousResearch/Meta-Llama-3-8B-Instruct | ||
| model_type: LlamaForCausalLM | ||
|
|
||
| load_in_8bit: false | ||
| load_in_4bit: false | ||
| strict: false | ||
|
|
||
| datasets: | ||
| - path: tatsu-lab/alpaca | ||
| type: alpaca | ||
| val_set_size: 0.0 | ||
| output_dir: ./outputs/protrain-3090-7b-lora | ||
|
|
||
| sequence_len: 256 # small to keep activation memory low | ||
| sample_packing: false | ||
| pad_to_sequence_len: false | ||
|
|
||
| adapter: lora | ||
| lora_r: 16 | ||
| lora_alpha: 32 | ||
| lora_dropout: 0.05 | ||
| lora_target_modules: | ||
| - q_proj | ||
| - k_proj | ||
| - v_proj | ||
| - o_proj | ||
| - up_proj | ||
| - down_proj | ||
| - gate_proj | ||
|
|
||
| plugins: | ||
| - axolotl.integrations.protrain.ProTrainPlugin | ||
|
|
||
| # -- ProTrain knobs (see axolotl.integrations.protrain.args.ProTrainArgs) -- | ||
| protrain_auto_memory: true | ||
| # Leave auto-mode on (default); the plugin picks the right mode. | ||
| # protrain_auto_mode: true # default — the selector handles it | ||
| # protrain_force_all_persistent: true # explicit override (only honoured when protrain_auto_mode=false) | ||
|
|
||
| gradient_accumulation_steps: 1 | ||
| micro_batch_size: 1 | ||
| max_steps: 20 | ||
| optimizer: adamw_torch # adamw_torch baseline; ProTrainPlugin.post_trainer_create replaces this with protrain_optimizer_wrapper | ||
| lr_scheduler: cosine | ||
| learning_rate: 0.0002 | ||
|
|
||
| bf16: true | ||
| fp16: false | ||
| tf32: false | ||
|
|
||
| # IMPORTANT: the ProTrain block manager installs its own CKPT hooks when | ||
| # the searcher assigns a block to CKPT mode (typical for tight-capacity | ||
| # offload configs). Enabling Axolotl / HuggingFace gradient checkpointing | ||
| # here would double-checkpoint the forward pass — and the ProTrainArgs | ||
| # validator will refuse the config. | ||
| gradient_checkpointing: false | ||
|
|
||
| flash_attention: false | ||
| xformers_attention: false | ||
|
|
||
| # IMPORTANT: Axolotl auto-enables fused Triton LoRA kernels (q/k/v/o/MLP) | ||
| # when these flags are unset. Those kernels read raw weight tensors | ||
| # directly via torch.matmul; ProTrain's profiler engages "on-demand" | ||
| # mode for 7B+ models on a 24 GB card (model state > 60% of device | ||
| # memory) and offloads params to CPU between modules using forward | ||
| # hooks. The Axolotl LoRA kernels bypass nn.Linear's standard forward | ||
| # hook machinery, so the offload-then-restore pattern does not see | ||
| # them and they read empty/CPU tensors -> RuntimeError("size mismatch | ||
| # ... vec (0)") inside matmul_lora. Disable them here to keep the | ||
| # stock PEFT LoRA forward path (which IS hookable) so the profiler's | ||
| # on-demand pass works. The performance cost is ~5-10% on this | ||
| # 7B-class workload — acceptable for the M5 acceptance run, and the | ||
| # steady-state runtime under the chunk manager itself is dominated by | ||
| # H2D/D2H traffic rather than LoRA matmul throughput. | ||
| lora_mlp_kernel: false | ||
| lora_qkv_kernel: false | ||
| lora_o_kernel: false | ||
|
|
||
| logging_steps: 1 | ||
| save_steps: 20 | ||
| save_first_step: false | ||
| save_total_limit: 1 | ||
|
|
||
| warmup_steps: 2 | ||
| weight_decay: 0.0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.