Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
fee826b
M0: add ProTrain plugin design doc
thad0ctor Apr 23, 2026
9d1a654
M1a: freeze ProTrain shared types
thad0ctor Apr 23, 2026
431042b
M1: memory-aware profiler
thad0ctor Apr 23, 2026
28d833d
M2: hierarchical chunk manager
thad0ctor Apr 23, 2026
7e3ff76
M3: interleaved block manager
thad0ctor Apr 23, 2026
aa7cf8c
M2 test: fix chunk-manager test contracts and pinned-alloc ctypes path
thad0ctor Apr 23, 2026
81a93b4
M4a: cost models + exhaustive searcher
thad0ctor Apr 23, 2026
5c1b19b
M4b: runtime scheduler + api wrappers
thad0ctor Apr 23, 2026
7e03e05
M4 integration: xfail with BufferPool-exhaustion at forward-block bou…
thad0ctor Apr 23, 2026
cc62164
M4 integration hardening: fix 4 bugs, document 2 runtime gaps
thad0ctor Apr 23, 2026
afa21c7
M5: Axolotl plugin glue + example + e2e test
thad0ctor Apr 23, 2026
10b0248
M4.5: implement init-time chunk offload + per-param grad offload
thad0ctor Apr 23, 2026
875577c
M6: multi-GPU 4x 3090 throughput validation
thad0ctor Apr 23, 2026
8f1f5ba
tests: harden 7B capacity-safety, add SWAP/monotonicity/multi-GPU-der…
thad0ctor Apr 24, 2026
5af9c2e
chunk: fix CPU-Adam race, view-dtype alignment, adapter order, data r…
thad0ctor Apr 24, 2026
45cff47
plugin: wire create_optimizer dispatch, broaden mutex, fix defaults +…
thad0ctor Apr 24, 2026
c481142
profiler: record per-op latencies; cost model uses measured compute; …
thad0ctor Apr 24, 2026
c59ec09
M7: true ZeRO-3 chunk sharding (all_gather / reduce_scatter)
thad0ctor Apr 24, 2026
54d3fe6
M7 followup: cost-model sharding awareness + mixed-dtype shard support
thad0ctor Apr 24, 2026
dbf47bb
bench: multi-GPU throughput comparison (DDP / replicated / ZeRO-3)
thad0ctor Apr 24, 2026
a6ea055
plugin: auto-select multi-GPU mode (A/B/C) based on workload fit + CP…
thad0ctor Apr 24, 2026
7d0892a
profiler: add CPU+GPU Adam microbenchmarks; loosen 7B runtime tolerance
thad0ctor Apr 24, 2026
a1e67a5
profiler: measure hook-less steady-state wall time; cost model scales…
thad0ctor Apr 24, 2026
95243f7
M7 cost-model close-out: PCIe plumb-through + steady-state cap + asym…
thad0ctor Apr 24, 2026
803ac6c
profiler: record steady_fwd_peak_bytes; memory cost model caps at mea…
thad0ctor Apr 24, 2026
814f27e
profiler: record per-block steady peaks; memory cost model uses them …
thad0ctor Apr 24, 2026
f2bd2fa
cost+search: extract hot_iter_peak_cap helper; plumb into searcher's …
thad0ctor Apr 24, 2026
a2234f3
profiler: multi-iter hot-loop steady measurement; cost model uses mea…
thad0ctor Apr 24, 2026
39e966f
test: loosen 7B runtime tolerance 0.25 -> 0.35 for 3090-vs-3090Ti SKU…
thad0ctor Apr 24, 2026
1f69fdc
review followups: bump TRACE_VERSION 6 -> 7; correct 7B test docstring
thad0ctor Apr 26, 2026
3e09937
docs: align comments with paper Eqs. and document trace v7 / swap-buf…
thad0ctor Apr 26, 2026
0c08dcb
profiler: implement on-demand replay (param offload + saved-tensor sp…
thad0ctor Apr 26, 2026
41bd25d
profiler: implement multi-rank NCCL benchmarks (TRACE_VERSION 7 -> 8)
thad0ctor Apr 26, 2026
6b60b87
cost-model: per-SKU compute-rate calibration + LoRA-aware bwd/fwd fal…
thad0ctor Apr 26, 2026
e7393c2
M5/M6: re-add decreasing-loss check + env-var gate skipped E2E tests
thad0ctor Apr 26, 2026
fa50735
docs: ratify two workstream-shape drifts from plan.md
thad0ctor Apr 26, 2026
e9ef434
on-demand: clean up partially-spilled params on __enter__ failure
thad0ctor Apr 26, 2026
3d27896
on-demand: support backward by routing unpack copy through self.device
thad0ctor Apr 26, 2026
a24fb7e
profiler: state-aware on-demand threshold (params + grads + optim state)
thad0ctor Apr 26, 2026
830fe37
profiler: fold requires_grad into arch_hash (TRACE_VERSION 8 -> 9)
thad0ctor Apr 26, 2026
1513b17
hw_bench: fix measure_nccl shard math comment (rounded down, not up)
thad0ctor Apr 26, 2026
7460d00
hw_bench: empty_cache between NCCL payload sizes
thad0ctor Apr 26, 2026
7e565b0
test(plugin_e2e): align comment with strict-< loss-reduction bar
thad0ctor Apr 26, 2026
7caa731
test(multi_gpu_benchmark): guard recorded-threshold tests on canonica…
thad0ctor Apr 26, 2026
bc03a93
test(profiler): smoke-test cost-model under engaged on-demand
thad0ctor Apr 26, 2026
1c6f8e9
docs: update DESIGN.md length self-reference (250 -> 260 lines)
thad0ctor Apr 26, 2026
fd27f0b
docs: document on-demand inflation of intra/inter_op_delta in DESIGN.md
thad0ctor Apr 26, 2026
f4434c2
profiler: dedupe _arch_hash, import canonical version in model_wrapper
thad0ctor Apr 27, 2026
eb82df3
on-demand: prepend pre-gather so intra_op_delta excludes gather bytes
thad0ctor Apr 27, 2026
63182a4
docs: document NCCL measurement gap in default plugin path
thad0ctor Apr 27, 2026
f5e9f7a
chunk-layout: derive exec order from trace.op_order (paper §3.1.1)
thad0ctor Apr 27, 2026
10c5658
plugin: late-bind NCCL measurement in post_trainer_create
thad0ctor Apr 27, 2026
29600aa
chunk: add ChunkManager.restore_to_gpu (materialize_offload inverse)
thad0ctor Apr 27, 2026
c60b4ce
cost-model: D1b translation for phase-2 chunked backward (TRACE_VERSI…
thad0ctor Apr 27, 2026
be94640
wrapper: refactor runtime construction into _construct_runtime helper
thad0ctor Apr 27, 2026
5e6ed13
phase-2: chunked-runtime backward measurement + bootstrap-rebuild plu…
thad0ctor Apr 27, 2026
a3c95fd
test(7b-integration): tighten runtime tolerance 0.35 -> 0.25 for v10/…
thad0ctor Apr 27, 2026
ec65f68
optim-partition: route by _persistent_ids set, not n_persist prefix
thad0ctor Apr 27, 2026
e79eb06
chunk: implement sharded restore_to_gpu via per-region all_gather
thad0ctor Apr 27, 2026
d390ce3
search: add CPU-RAM hard feasibility filter (cpu_capacity_bytes)
thad0ctor Apr 27, 2026
0c9acc4
phase-2: chunked-runtime forward measurement (TRACE_VERSION 11)
thad0ctor Apr 27, 2026
8ea2c82
Bypass chunk comm for phase2 backward runtime
thad0ctor Apr 27, 2026
71793b0
Merge branch 'protrain-sharded-restore' into protrain-paper-fidelity
thad0ctor Apr 27, 2026
6841205
Merge branch 'protrain-cpu-feasibility-filter' into protrain-paper-fi…
thad0ctor Apr 27, 2026
2d88e72
Merge branch 'protrain-phase2-backward-bypass' into protrain-paper-fi…
thad0ctor Apr 27, 2026
e8d14db
docs(protrain): align phase-2 calibration comments
thad0ctor Apr 27, 2026
99afc31
phase-2: calibrate checkpointed offload runtime
thad0ctor Apr 27, 2026
7588ec2
docs(protrain): add optimizer checkpoint/resume design note
thad0ctor Apr 28, 2026
5ce0c15
feat(protrain): Phase 1 optimizer checkpoint/resume (single-rank, non…
thad0ctor Apr 28, 2026
a809491
docs(protrain): Phase 2 checkpoint design — multi-rank + ZeRO-3 sharded
thad0ctor Apr 28, 2026
b959dfb
fix(protrain): Phase 1 review fixes — three bugs caught by review
thad0ctor Apr 28, 2026
865e5b7
test(protrain): subprocess functional-equivalence + post-load CPU pin…
thad0ctor Apr 28, 2026
00e9832
feat(protrain): Phase 2 Mode-B optimizer checkpoint (multi-rank repli…
thad0ctor Apr 28, 2026
aefb819
test(protrain): Phase 2 Mode-B unit + multi-rank gloo tests
thad0ctor Apr 28, 2026
7f6a9b9
feat(protrain): Phase 2 Mode-C optimizer checkpoint (ZeRO-3 sharded)
thad0ctor Apr 28, 2026
164cc3e
test(protrain): Phase 2 Mode-C unit + multi-rank gloo tests
thad0ctor Apr 28, 2026
b70ba03
test(protrain): fix flaky tiny_llama loss check + 4gpu MASTER_PORT co…
thad0ctor Apr 29, 2026
bb02f96
fix(protrain): Mode-C verify gate, bf16 hash, broadcast-aware size gate
thad0ctor Apr 30, 2026
a722635
Merge protrain-test-fixes into the consolidated ProTrain branch
thad0ctor Apr 30, 2026
1c94394
docs(protrain): align schema + design notes with Phase 2 implementation
thad0ctor Apr 30, 2026
3bb9259
feat(protrain): batch_factory abstraction for non-causal-LM calibration
thad0ctor Apr 30, 2026
34a30e3
feat(protrain): preflight NCCL measurement via early dist init
thad0ctor May 1, 2026
cf4055a
fix(protrain): Mode-C lockstep failure protocol + stray-file rejection
thad0ctor May 1, 2026
96c6a7d
perf(protrain): coalesce persistent-chunk grad reduce + clarify gathe…
thad0ctor May 1, 2026
a80a848
fix(protrain): translate phase-2 chunked backward across n_buffer in …
thad0ctor May 1, 2026
348e060
test(protrain): add Mistral Mode-C + SmolLM2 full-FT validation cells…
thad0ctor May 1, 2026
7319f56
test(protrain): add post-v1 validation matrix cells — seq-cls, enc-de…
thad0ctor May 1, 2026
59740c3
feat(protrain): paper-real activation SWAP path (option 2A, minimum v…
thad0ctor May 1, 2026
d384ce5
feat(protrain): T5/encoder-decoder support via discover_blocks BlockTree
thad0ctor May 1, 2026
37c05d5
feat(protrain): paper-real activation SWAP via saved_tensors_hooks (M5+)
thad0ctor May 1, 2026
f5d9aa6
feat(protrain): offline Mode-C cross-world-size reshard tool + test
thad0ctor May 1, 2026
f5d0aa6
perf(protrain): document SWAP backward unpack/free autograd-engine floor
thad0ctor May 1, 2026
007b7be
feat(protrain): per-tree cost-model walk for encoder-decoder peak acc…
thad0ctor May 1, 2026
5747c81
feat(protrain): opt-in Mode-C online cross-world-size reshard on load
thad0ctor May 1, 2026
71cd8de
refactor(protrain): /simplify pass on round-2 commits
thad0ctor May 1, 2026
2ef5f26
test(protrain): M5 CLI end-to-end smoke (axolotl train via subprocess)
thad0ctor May 1, 2026
1da51c6
test(protrain): M6 Mode-C external baseline vs DeepSpeed ZeRO-3
thad0ctor May 1, 2026
78e4259
fix(protrain): re-bind shard_param.grad on set_to_none=True after zer…
thad0ctor May 1, 2026
94bd0c9
fix(protrain): make CpuFusedAdam-unavailable warning honest about cor…
thad0ctor May 1, 2026
5d29598
refactor(protrain): public-promote cost-model helpers used by searcher
thad0ctor May 1, 2026
5df91d5
refactor(protrain): extract _perform_online_reshard helper from load …
thad0ctor May 1, 2026
817e494
refactor(protrain): persist BlockId->tree_index in ProfilerTrace
thad0ctor May 1, 2026
491b5e2
fix(protrain): address CodeRabbit PR #10 — 35 findings + multi-rank c…
thad0ctor May 3, 2026
e900a69
fix(protrain): address CodeRabbit PR #10 May-3 round (18 findings)
thad0ctor May 3, 2026
646d3ea
fix(protrain): CodeRabbit PR #10 round-2 + CI cleanup (6 findings + l…
thad0ctor May 3, 2026
a6b4c20
fix(protrain): CodeRabbit PR #10 round-3 (12 findings + test contract…
thad0ctor May 3, 2026
4454317
fix(protrain): CodeRabbit PR #10 round-4 (6 inline + 3 duplicates + C…
thad0ctor May 4, 2026
b0df26f
fix(protrain): CodeRabbit PR #10 round-5 (7 findings + 2 CI test fixes)
thad0ctor May 4, 2026
0c6997a
fix(protrain): CodeRabbit PR #10 round-6 (4 findings + caplog propaga…
thad0ctor May 4, 2026
edc20fa
fix(protrain): CodeRabbit PR #10 round-7 (3 findings)
thad0ctor May 4, 2026
6c5836e
fix(protrain): CodeRabbit PR #10 round-7b nitpick (post_trainer_creat…
thad0ctor May 4, 2026
430b4a0
fix(protrain): CodeRabbit PR #10 round-7c (LOCAL_RANK guard at pre-wr…
thad0ctor May 4, 2026
4934673
fix(protrain): CodeRabbit PR #12 round-1 (24 findings + 2 test fixes)
thad0ctor May 4, 2026
16809dc
fix(protrain): CodeRabbit PR #12 round-2 (10 findings + option-B desi…
thad0ctor May 4, 2026
8264f77
feat(protrain): Option B M1 + M2 — BlockMode.OFFLOAD types + runtime …
thad0ctor May 4, 2026
a1ab8af
feat(protrain): CodeRabbit round-3 (10 findings) + Option B M3 (sched…
thad0ctor May 5, 2026
ea20710
feat(protrain): Option B M4 — cost model + searcher (n_offload axis)
thad0ctor May 5, 2026
94fbca1
fix(protrain): CodeRabbit PR #12 round-4 (6 findings on a1ab8aff)
thad0ctor May 5, 2026
c7c155f
feat(protrain): CodeRabbit round-5 + Option B M5 — OFFLOAD ships end-…
thad0ctor May 5, 2026
d44f9c9
fix(protrain): CodeRabbit PR #13 round-1 (19 findings)
thad0ctor May 5, 2026
a927fa7
fix(protrain): CodeRabbit PR #13 round-2 (10 findings)
thad0ctor May 5, 2026
ac9e00f
fix(protrain): CodeRabbit PR #13 round-3 (2 findings)
thad0ctor May 5, 2026
b83a89f
fix(protrain): drop unused import to satisfy pre-commit ruff
thad0ctor May 5, 2026
0ccbc5d
ci: disable uv persistent cache in sdist job to fix Py3.12 install
thad0ctor May 5, 2026
48b9311
fix(protrain): CodeRabbit PR #14 round-1 (24 findings)
thad0ctor May 5, 2026
5383cdb
fix(protrain): CodeRabbit PR #14 round-2 (6 findings + lint fix)
thad0ctor May 5, 2026
018445d
fix(protrain): CodeRabbit PR #15 round-1 (14 findings)
thad0ctor May 5, 2026
c8f752f
fix(protrain): CodeRabbit PR #15 round-2 (3 findings)
thad0ctor May 5, 2026
c99b23a
fix(protrain): CodeRabbit PR #15 round-3 (2 findings)
thad0ctor May 5, 2026
09e8c9e
fix(protrain): CodeRabbit PR #16 round-1 (12 findings)
thad0ctor May 5, 2026
4b1a1e0
fix(protrain): CodeRabbit PR #17 round-1 (18 findings)
thad0ctor May 5, 2026
108ef58
fix(protrain): CodeRabbit PR #17 round-2 (9 findings)
thad0ctor May 5, 2026
f6f63d5
fix(protrain): CodeRabbit PR #17 round-3 (3 findings)
thad0ctor May 5, 2026
019df69
fix(protrain): CodeRabbit PR #17 round-4 (3 findings)
thad0ctor May 5, 2026
498e1af
fix(protrain): CodeRabbit PR #17 round-5 (3 findings + 1 followup)
thad0ctor May 5, 2026
c584e29
fix(protrain): CodeRabbit PR #17 round-6 (1 finding)
thad0ctor May 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ jobs:

- name: Install uv
uses: astral-sh/setup-uv@v7
with:
# Disable the action's persistent cache. With caching enabled
# the sdist install fails on Python 3.12 with
# "Failed to deserialize cache entry: invalid ID" — the cache
# entry written by one uv version is unreadable by the next,
# producing a deterministic failure across CI runs (same
# hash ID every time). The Python 3.14 leg is unaffected.
# Disabling cache for this single job costs ~10s of pip
# install time but unblocks Py3.12 sdist install.
enable-cache: false

- name: Install PyTorch
run: |
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ lora-out/*
qlora-out/*
mlruns/*

# Benchmark output (machine-specific, regenerate via scripts/benchmark_*.py)
scripts/*_results.json
scripts/**/*_results.json

/.quarto/
prepared-datasets/
submit.sh
Expand Down
115 changes: 115 additions & 0 deletions examples/protrain/3090-7b-lora.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# ProTrain 7B/8B LoRA on a single RTX 3090 (24 GB)
#
# Opts into the ProTrain plugin via `plugins:`. The plugin's post_model_load
# hook wraps the model with the hierarchical chunk manager + interleaved
# block manager. The plugin's post_trainer_create hook then installs
# `protrain_optimizer_wrapper` on trainer.optimizer — this is the real
# wiring path because Axolotl's OptimizerMixin.create_optimizer does NOT
# dispatch to PluginManager.create_optimizer (see plugin.py for why).
#
# Mode selection is automatic. Leave ``protrain_auto_mode`` on (default);
# the plugin runs the searcher and then picks Mode A (GPU-resident / DDP-
# friendly), Mode B (replicated CPU-offload), or Mode C (ZeRO-3 sharded
# CPU-offload) based on the model's fit and per-rank CPU RAM. For 7B/8B
# LoRA on a single 24 GB 3090 the selector picks Mode A — the frozen
# base fits in fp16 alongside LoRA optimizer state + activations, and
# DDP scales at ~3.6x on PCIe Gen3 4x 3090 while ZeRO-3 sharding on
# the same rig lands at ~0.7x (see DESIGN.md §Multi-GPU).
#
# Set ``protrain_auto_mode: false`` below only if you need explicit
# control (reproducing a specific benchmark configuration, or a
# heterogeneous-CPU setup where the node-RAM/world-size heuristic is
# wrong). In that case ``protrain_force_all_persistent`` and
# ``protrain_zero3_shard`` become the explicit overrides.

# NousResearch/Meta-Llama-3-8B-Instruct is the 8B-class Llama mirror on HF
# Hub that is *not* gated (public-license, no HF-terms accept step). It was
# chosen over mistralai/Mistral-7B-v0.3 (gated: 401 for new users) and
# meta-llama/Llama-3.1-8B (gated: requires accepted license) for frictionless
# downloads in CI and first-run contributors. HuggingFaceH4/zephyr-7b-beta is
# an equivalent ungated fallback if the Llama arch is undesirable.
base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
val_set_size: 0.0
output_dir: ./outputs/protrain-3090-7b-lora

sequence_len: 256 # small to keep activation memory low
sample_packing: false
pad_to_sequence_len: false

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
- gate_proj

plugins:
- axolotl.integrations.protrain.ProTrainPlugin

# -- ProTrain knobs (see axolotl.integrations.protrain.args.ProTrainArgs) --
protrain_auto_memory: true
# Leave auto-mode on (default); the plugin picks the right mode.
# protrain_auto_mode: true # default — the selector handles it
# protrain_force_all_persistent: true # explicit override (only honoured when protrain_auto_mode=false)

gradient_accumulation_steps: 1
micro_batch_size: 1
max_steps: 20
optimizer: adamw_torch # adamw_torch baseline; ProTrainPlugin.post_trainer_create replaces this with protrain_optimizer_wrapper
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16: false
tf32: false

# IMPORTANT: the ProTrain block manager installs its own CKPT hooks when
# the searcher assigns a block to CKPT mode (typical for tight-capacity
# offload configs). Enabling Axolotl / HuggingFace gradient checkpointing
# here would double-checkpoint the forward pass — and the ProTrainArgs
# validator will refuse the config.
gradient_checkpointing: false

flash_attention: false
xformers_attention: false

# IMPORTANT: Axolotl auto-enables fused Triton LoRA kernels (q/k/v/o/MLP)
# when these flags are unset. Those kernels read raw weight tensors
# directly via torch.matmul; ProTrain's profiler engages "on-demand"
# mode for 7B+ models on a 24 GB card (model state > 60% of device
# memory) and offloads params to CPU between modules using forward
# hooks. The Axolotl LoRA kernels bypass nn.Linear's standard forward
# hook machinery, so the offload-then-restore pattern does not see
# them and they read empty/CPU tensors -> RuntimeError("size mismatch
# ... vec (0)") inside matmul_lora. Disable them here to keep the
# stock PEFT LoRA forward path (which IS hookable) so the profiler's
# on-demand pass works. The performance cost is ~5-10% on this
# 7B-class workload — acceptable for the M5 acceptance run, and the
# steady-state runtime under the chunk manager itself is dominated by
# H2D/D2H traffic rather than LoRA matmul throughput.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

logging_steps: 1
save_steps: 20
save_first_step: false
save_total_limit: 1

warmup_steps: 2
weight_decay: 0.0
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ docstring-code-format = false
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
"gpu: marks tests that require a CUDA GPU",
]

# UV specific configuration
Expand Down
Loading
Loading