Closed
43 commits
12f1a12
chore(protrain): fix pre-existing mypy on model_wrapper:386 + formatting
thad0ctor May 9, 2026
6c3fcb1
feat(protrain): enable FlashAttention in canonical LoRA example (M4)
thad0ctor May 9, 2026
45a934f
feat(protrain): allow load_in_8bit / load_in_4bit (M2+M3 Mode A)
thad0ctor May 9, 2026
1fe8ddb
feat(protrain): integrate fused LoRA kernels via container hooks (M1)
thad0ctor May 9, 2026
a6dfc37
feat(protrain): add bnb 8-bit AdamW optimizer adapter (M2.5)
thad0ctor May 9, 2026
823b4db
feat(protrain): reject unsupported optimizers at config load (M6B)
thad0ctor May 9, 2026
c857675
test(protrain): PEFT edge-case smoke tests (M6A)
thad0ctor May 10, 2026
3ce55a8
test(protrain): cross-mode (A↔C) resume smoke tests (M6C)
thad0ctor May 10, 2026
6cb5c84
fix(protrain): force_all_persistent suppresses trace-pass on-demand e…
thad0ctor May 10, 2026
a868978
test(protrain): pin bnb 8-bit/4-bit + ProTrain offload-mode (M3 audit…
thad0ctor May 10, 2026
91e0912
fix(p2p): rank-symmetric check_cuda_p2p_support + measure_nccl barrier
thad0ctor May 10, 2026
a00df59
test(protrain): real multi-GPU cross-mode resume xfail tests (M6C)
thad0ctor May 10, 2026
016dac8
docs(protrain): document mode-pinned checkpoints + Mode C plain LoRA gap
thad0ctor May 10, 2026
4856090
feat(protrain): per-container PEFT-LoRA gather in on-demand profiler …
thad0ctor May 10, 2026
a71f26e
feat(protrain): cross-mode resume hook for HF Trainer load_checkpoint…
thad0ctor May 10, 2026
4eb6da6
feat(protrain): skip profiler trace pass when explicit override knobs…
thad0ctor May 10, 2026
32663f3
feat(protrain): runtime-side per-LoRA-container gather hooks (M6C-fix-3)
thad0ctor May 10, 2026
008b62e
docs(protrain): update Mode C PEFT-LoRA section per M6C-fix-3 close
thad0ctor May 10, 2026
b5ffa3d
refactor(protrain): synchronous gather in ensure_chunks_resident (M6C…
thad0ctor May 10, 2026
b787acb
feat(protrain): late-NCCL-re-search skip on overrides + autocast diag…
thad0ctor May 11, 2026
0f44bfb
feat(protrain): per-LoRA-container post-fwd/bwd hooks (M6C-fix-6 hard…
thad0ctor May 11, 2026
55d9237
docs(protrain): formalize M6C-fix end-of-chain limitation in DESIGN.md
thad0ctor May 11, 2026
c0da428
feat(protrain): shape-preserving release-state placeholder (M6C-fix-7…
thad0ctor May 11, 2026
17ffb8d
feat(protrain): close M6C chain — DDP init-sync bypass for chunk-mana…
thad0ctor May 11, 2026
6febed8
docs(protrain): close M6C limitation section — multi-GPU plain LoRA M…
thad0ctor May 11, 2026
2fcc1fc
feat(protrain): per-dtype α fragmentation factor (Coverage audit Bloc…
thad0ctor May 12, 2026
f74c559
test(protrain): regress paged_adamw_8bit + Mode C multi-GPU @ seq=2048
thad0ctor May 12, 2026
d1ef2dd
chore(protrain): address CodeRabbit PR #21 quick-win nits
thad0ctor May 12, 2026
3aff348
chore(protrain): apply CodeRabbit re-review quick-win nits (round 2)
thad0ctor May 12, 2026
09cf657
feat(protrain): in-process rebuild lifecycle (D1/D2/D3) + P2P fail-cl…
thad0ctor May 12, 2026
d7624fb
test(protrain): address remaining CodeRabbit test-quality deferrals (…
thad0ctor May 12, 2026
6961490
feat(protrain): scheduler SWAP-stream safety barrier (R3-#1) + resume…
thad0ctor May 12, 2026
e6d8a1a
test(protrain): CodeRabbit R3 test-quality fixes (R3-#2, #3, #4, #5, #8)
thad0ctor May 12, 2026
b61f04e
feat(protrain): predict iter-1 init-transient peak (audit Block G)
thad0ctor May 12, 2026
aa0c6ba
fix(protrain): Mode-C steady-peak CKPT-chain accounting (audit Block G)
thad0ctor May 12, 2026
c996ce9
fix(protrain): close CodeRabbit R4 review (1 Critical + 2 Major + 1 M…
thad0ctor May 12, 2026
f09be09
chore(protrain): apply ruff-format reformats to cost/runtime + test_c…
thad0ctor May 12, 2026
55377e5
chore(protrain): normalize confusable unicode in commentary/docstring…
thad0ctor May 12, 2026
69eb152
fix(protrain): CodeRabbit full-review Majors — 4 real correctness gap…
thad0ctor May 12, 2026
40bb8ad
chore(protrain): CodeRabbit full-review Minors — docs consistency + t…
thad0ctor May 12, 2026
67372c3
fix(test): test_chunk_optim_shutdown caplog → mock.patch on LOG (CI f…
thad0ctor May 13, 2026
db094b5
chore(protrain): trim non-WHY comments and address CodeRabbit findings
thad0ctor May 21, 2026
cc72ca4
docs(protrain): document deferred non-compute α decomposition (ticket B)
thad0ctor May 9, 2026
3 changes: 2 additions & 1 deletion examples/protrain/3090-8b-lora.yml
Expand Up @@ -85,7 +85,8 @@ tf32: false
# validator will refuse the config.
gradient_checkpointing: false

flash_attention: false
# M0 spike validated FA composes cleanly with ProTrain on this config.
flash_attention: true
xformers_attention: false

# IMPORTANT: Axolotl auto-enables fused Triton LoRA kernels (q/k/v/o/MLP)
Expand Down
97 changes: 86 additions & 11 deletions src/axolotl/integrations/protrain/DESIGN.md

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions src/axolotl/integrations/protrain/api/checkpoint.py
Expand Up @@ -2062,7 +2062,10 @@ def install_load_hook(
The closed-over ``optim`` is captured at install time (in
``post_trainer_create``, BEFORE Accelerate.prepare wraps the
optimizer), so it's already raw. We unwrap defensively in case
the caller hands in a wrapper.
the caller hands in a wrapper. At ``_patched()`` runtime we
re-resolve from ``trainer.optimizer`` so that a cross-mode resume
rebuild that swaps the facade loads into the live instance; if
``trainer.optimizer`` does not unwrap to a ProTrain optimizer, we
fall back to the install-time raw optimizer.

The ``allow_online_reshard`` flag plumbs through to
:func:`_load_protrain_optim_dir`. Default False keeps the Mode-C
Expand All @@ -2071,8 +2074,8 @@ def install_load_hook(
dir, all ranks barrier and load). See CHECKPOINT_DESIGN_PHASE2.md
§4.1.
"""
raw = _unwrap_protrain_optim(optim)
if raw is None:
raw_at_install = _unwrap_protrain_optim(optim)
if raw_at_install is None:
# Caller passed something that isn't a ProTrain optimizer —
# silently no-op rather than installing a hook that would
# never fire.
Expand All @@ -2081,6 +2084,11 @@ def install_load_hook(
original = trainer._load_optimizer_and_scheduler

def _patched(checkpoint: str | None) -> None:
# Re-resolve from ``trainer.optimizer`` so the cross-mode resume rebuild
# (which swaps trainer.optimizer = new_optim) loads into the live instance.
raw = _unwrap_protrain_optim(getattr(trainer, "optimizer", None))
if raw is None:
raw = raw_at_install
# Failure protocol: ``original(checkpoint)`` (the native HF
# optimizer/scheduler load) is outside any cluster-wide status
# handling, but the patched method still executes a distributed
Expand Down
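Review note: the late re-resolve in `_patched()` can be illustrated in isolation. The sketch below is a minimal stand-in, not the PR's code: `FakeProTrainOptim`, `unwrap`, and the `trainer.load` attribute are hypothetical names that only mimic the shape of `_unwrap_protrain_optim` and the patched loader.

```python
from __future__ import annotations

from types import SimpleNamespace


class FakeProTrainOptim:
    """Hypothetical stand-in for the ProTrain optimizer facade."""

    def __init__(self, label: str) -> None:
        self.label = label
        self.loaded_from: str | None = None


def unwrap(candidate: object) -> FakeProTrainOptim | None:
    """Stand-in for _unwrap_protrain_optim: peel one wrapper level, else None."""
    inner = getattr(candidate, "optimizer", candidate)
    return inner if isinstance(inner, FakeProTrainOptim) else None


def install_load_hook(trainer: SimpleNamespace, optim: object) -> None:
    raw_at_install = unwrap(optim)
    if raw_at_install is None:
        return  # not a ProTrain optimizer: install nothing

    def patched(checkpoint: str | None) -> None:
        # Late lookup: prefer whatever trainer.optimizer is at call time.
        raw = unwrap(getattr(trainer, "optimizer", None))
        if raw is None:
            raw = raw_at_install  # fall back to the install-time instance
        raw.loaded_from = checkpoint

    trainer.load = patched  # hypothetical attribute, for illustration only
```

Under these assumptions, a rebuild that assigns a new (possibly wrapped) optimizer to `trainer.optimizer` after the hook is installed still receives the load, while an unrecognized `trainer.optimizer` routes the load to the instance captured at install time.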
678 changes: 635 additions & 43 deletions src/axolotl/integrations/protrain/api/model_wrapper.py

Large diffs are not rendered by default.

131 changes: 120 additions & 11 deletions src/axolotl/integrations/protrain/api/optim_wrapper.py
Expand Up @@ -35,6 +35,7 @@

from axolotl.integrations.protrain.chunk import (
CpuFusedAdamAdapter,
GpuAdamW8bitAdapter,
GpuFusedAdamAdapter,
)
from axolotl.integrations.protrain.types import ChunkId, WrappedModel
Expand All @@ -59,7 +60,7 @@ class _ProTrainOptimizer(torch.optim.Optimizer):

def __init__(
self,
gpu_optim: GpuFusedAdamAdapter | None,
gpu_optim: GpuFusedAdamAdapter | GpuAdamW8bitAdapter | None,
cpu_optim: CpuFusedAdamAdapter | None,
params: list["nn.Parameter"],
defaults: dict[str, Any],
Expand Down Expand Up @@ -602,13 +603,35 @@ def _split_optim_param_groups(
inner.param_groups = new_groups


#: Axolotl / HF Trainer optimizer-name strings that route the persistent
#: chunk set through ``GpuAdamW8bitAdapter`` instead of
#: ``GpuFusedAdamAdapter``. ``adamw_8bit`` and ``adamw_bnb_8bit`` are
#: aliases in HF's ``OptimizerNames`` (training_args.py:128-129) that both
#: dispatch to ``bnb.optim.AdamW`` with ``optim_bits=8``; we accept both
#: spellings so users carrying configs from either origin work without
#: edits. ``paged_adamw_8bit`` selects the paged variant (UVM-backed
#: state) for the same set.
_BNB_8BIT_OPTIMIZERS: frozenset[str] = frozenset(
{"adamw_8bit", "adamw_bnb_8bit", "paged_adamw_8bit"}
)
_BNB_8BIT_PAGED_OPTIMIZERS: frozenset[str] = frozenset({"paged_adamw_8bit"})


def _normalize_optimizer_name(name: str | None) -> str | None:
"""Lower-case + strip whitespace, unwrapping ``OptimizerNames`` enums via ``.value``."""
if name is None:
return None
return str(getattr(name, "value", name)).strip().lower()
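Review note: a small sketch of what the normalizer accepts. `DemoOptimizerNames` is a hypothetical stand-in for HF's `OptimizerNames` enum (assumed here to be a `str`-mixin enum), and the function body simply mirrors the helper above.

```python
from __future__ import annotations

from enum import Enum


class DemoOptimizerNames(str, Enum):
    """Hypothetical stand-in for HF's OptimizerNames enum."""

    ADAMW_8BIT = "adamw_8bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"


def normalize_optimizer_name(name: object) -> str | None:
    """Mirror of the helper above: unwrap enums via .value, strip, lower-case."""
    if name is None:
        return None
    return str(getattr(name, "value", name)).strip().lower()


# YAML strings and enum members converge on the same canonical spelling.
from_yaml = normalize_optimizer_name("  AdamW_8bit\n")
from_enum = normalize_optimizer_name(DemoOptimizerNames.ADAMW_8BIT)
```

Reading `.value` first means the result does not depend on how the enum class implements `__str__`.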


def protrain_optimizer_wrapper(
wrapped: WrappedModel,
*,
lr: float,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
optimizer_name: str | None = None,
) -> torch.optim.Optimizer:
"""Rebuild the GPU/CPU FusedAdam adapters at user-specified hyperparams.

Expand Down Expand Up @@ -695,16 +718,40 @@ def protrain_optimizer_wrapper(
else:
cpu_params_per_chunk[ChunkId(cid)] = chunk_params

gpu_optim: GpuFusedAdamAdapter | None = None
# bnb 8-bit Adam kernels are CUDA-only, so only the persistent
# (GPU-resident) chunk set can use the 8-bit adapter; non-persistent
# CPU shards keep the 32-bit DeepSpeedCPUAdam path.
normalized_optim_name = _normalize_optimizer_name(optimizer_name)
use_bnb_8bit = normalized_optim_name in _BNB_8BIT_OPTIMIZERS
use_paged_8bit = normalized_optim_name in _BNB_8BIT_PAGED_OPTIMIZERS

gpu_optim: GpuFusedAdamAdapter | GpuAdamW8bitAdapter | None = None
cpu_optim: CpuFusedAdamAdapter | None = None
if persistent_params:
gpu_optim = GpuFusedAdamAdapter(
params=persistent_params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
if use_bnb_8bit:
LOG.info(
"protrain_optimizer_wrapper: routing %d persistent params "
"through bnb %s (optimizer_name=%s)",
len(persistent_params),
"PagedAdamW8bit" if use_paged_8bit else "AdamW8bit",
optimizer_name,
)
gpu_optim = GpuAdamW8bitAdapter(
params=persistent_params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
paged=use_paged_8bit,
)
else:
gpu_optim = GpuFusedAdamAdapter(
params=persistent_params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)

# M7: for sharded non-persistent chunks the CPU Adam updates each
# :class:`_DtypeRegion`'s flat shard_param (one per rank slice per
Expand All @@ -722,6 +769,37 @@ def protrain_optimizer_wrapper(
else:
cpu_params_per_chunk_for_optim[cid] = chunk_params

if use_bnb_8bit and any(
params for params in cpu_params_per_chunk_for_optim.values()
):
# bnb 8-bit Adam requires CUDA tensors; non-persistent chunks
# live on CPU. We keep the 32-bit CpuFusedAdamAdapter on those
# chunks so training stays correct (and the user still gets the
# persistent-chunk 8-bit win from above). Surface this once,
# loudly, so users configuring `adamw_8bit` aren't surprised by
# the partial adoption.
n_cpu_chunks = sum(
1 for params in cpu_params_per_chunk_for_optim.values() if params
)
LOG.warning(
"protrain_optimizer_wrapper: optimizer_name=%s requested 8-bit "
"AdamW, but %d non-persistent chunk(s) live on CPU and bnb's "
"8-bit Adam kernels are CUDA-only. Those chunks will keep "
"using 32-bit DeepSpeedCPUAdam (still correct, but the "
"optimizer-state memory win applies only to the persistent "
"set). To get end-to-end 8-bit, configure ProTrain to force "
"all chunks persistent (Mode A): set "
"``protrain_auto_mode: false`` AND "
"``protrain_force_all_persistent: true`` together — "
"``protrain_force_all_persistent`` is ignored while "
"``protrain_auto_mode`` is on (the auto-mode selector picks "
"the mode itself based on capacity), so disabling auto-mode "
"first is required for the Mode-A override to take effect.",
optimizer_name,
n_cpu_chunks,
)
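# Review note: the routing above reduces to a small decision table. The sketch
# below returns labels instead of constructing adapters; `plan_adapters` and
# its return shape are hypothetical and only summarize the branches in this hunk.

```python
from __future__ import annotations

BNB_8BIT = frozenset({"adamw_8bit", "adamw_bnb_8bit", "paged_adamw_8bit"})
BNB_8BIT_PAGED = frozenset({"paged_adamw_8bit"})


def plan_adapters(
    optimizer_name: str | None,
    n_persistent_params: int,
    cpu_chunk_sizes: list[int],
) -> tuple[str | None, str | None, bool]:
    """Return (gpu_adapter, cpu_adapter, warn_partial_8bit) as plain labels."""
    name = None if optimizer_name is None else str(optimizer_name).strip().lower()
    use_8bit = name in BNB_8BIT
    paged = name in BNB_8BIT_PAGED

    gpu = None
    if n_persistent_params > 0:
        if use_8bit:
            gpu = "PagedAdamW8bit" if paged else "AdamW8bit"
        else:
            gpu = "FusedAdam"

    # Non-persistent chunks stay on the 32-bit CPU path regardless of the name.
    has_cpu_chunks = any(size > 0 for size in cpu_chunk_sizes)
    cpu = "DeepSpeedCPUAdam" if has_cpu_chunks else None

    # The warning fires only when 8-bit was requested but CPU chunks exist.
    return gpu, cpu, use_8bit and has_cpu_chunks
```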

if any(params for params in cpu_params_per_chunk_for_optim.values()):
try:
cpu_optim = CpuFusedAdamAdapter(
Expand Down Expand Up @@ -827,9 +905,40 @@ def protrain_optimizer_wrapper(

# Swap the freshly-built adapters into the chunk manager so the
# scheduler's post_block_backward -> reduce_grads_and_offload ->
# cpu_optim.step_async chain uses them.
# cpu_optim.step_async chain uses them. The chunk manager's
# ``gpu_optim`` slot is typed ``GpuFusedAdamAdapter | None`` (the
# legacy adapter); ``GpuAdamW8bitAdapter`` is duck-type compatible
# at the call sites that consume the slot (``.step()``,
# ``.zero_grad()``, ``.state_dict()``; see
# :class:`GpuAdamW8bitAdapter`). We assign through a typing cast
# rather than widening the chunk manager's type signature, because
# this milestone treats the chunk manager file as read-only.
#
# D3 lifecycle (shutdown-before-swap): ``CpuFusedAdamAdapter`` owns
# a live ``ThreadPoolExecutor`` and per-chunk DeepSpeedCPUAdam
# C-state; overwriting ``chunk_manager.cpu_optim`` without first
# tearing the old adapter down leaks executor threads + DeepSpeed
# state on every re-wrap. The resume hook's "Step 1" already tears
# the adapter down at the plugin layer, but a direct second
# ``protrain_optimizer_wrapper`` invocation (e.g. a user rerunning
# the wrapper after changing optim hyperparams without going
# through the HF Trainer resume path) would otherwise leave the
# cleanup to garbage collection. Mirrors the teardown the resume
# hook performs before ``restore_to_gpu``.
_old_cpu_optim = getattr(chunk_manager, "cpu_optim", None)
if _old_cpu_optim is not None and _old_cpu_optim is not cpu_optim:
# F-#3 (Major): let ``shutdown()`` failures abort the swap
# rather than warning-and-continuing. The whole point of
# calling ``shutdown()`` here is the D3 deterministic-cleanup
# invariant — masking a real teardown failure (e.g.,
# ``ThreadPoolExecutor`` hung, DeepSpeed C-state corrupted)
# puts the failed adapter back on the GC path AND silently
# accepts a broken state-machine on the rebuild side. If the
# shutdown raises, the rebuild is in an inconsistent state
# and the call should fail rather than silently degrading.
_old_cpu_optim.shutdown()
chunk_manager.cpu_optim = cpu_optim
chunk_manager.gpu_optim = gpu_optim
chunk_manager.gpu_optim = cast("GpuFusedAdamAdapter | None", gpu_optim)

# Build the flat param list for the Optimizer base class.
all_params: list["nn.Parameter"] = list(persistent_params)
Expand Down
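Review note: the shutdown-before-swap ordering can be exercised without DeepSpeed. The sketch below is an assumption-laden stand-in: `DemoCpuAdapter`, `DemoChunkManager`, and `swap_cpu_optim` are hypothetical names, and the adapter only owns a `ThreadPoolExecutor` to represent the resource that must be released deterministically.

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor


class DemoCpuAdapter:
    """Hypothetical adapter that owns a worker pool, like the CPU Adam adapter."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.closed = False

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)
        self.closed = True


class DemoChunkManager:
    """Hypothetical holder for the cpu_optim slot."""

    cpu_optim: DemoCpuAdapter | None = None


def swap_cpu_optim(manager: DemoChunkManager, new_optim: DemoCpuAdapter | None) -> None:
    old = getattr(manager, "cpu_optim", None)
    if old is not None and old is not new_optim:
        # No try/except on purpose: a failed teardown aborts the swap.
        old.shutdown()
    manager.cpu_optim = new_optim
```

Because `shutdown()` runs before the assignment, an exception leaves the old adapter in the slot rather than silently replacing it, which matches the fail-loud intent described in the F-#3 comment.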
104 changes: 82 additions & 22 deletions src/axolotl/integrations/protrain/args.py
Expand Up @@ -64,6 +64,33 @@
)


# Strict allow-list of Axolotl/HF optimizer names that ProTrain's chunk
# manager + per-chunk adapters can drive correctly. The set is the union
# of names dispatched by ``api/optim_wrapper.protrain_optimizer_wrapper``:
#
# * ``adamw_torch`` / ``adamw_torch_fused`` — default route through
# ``GpuFusedAdamAdapter`` (Apex FusedAdam, falls back to
# ``torch.optim.AdamW``) for persistent chunks and
# ``CpuFusedAdamAdapter`` (DeepSpeedCPUAdam) for non-persistent chunks.
# * ``adamw_8bit`` / ``adamw_bnb_8bit`` / ``paged_adamw_8bit`` —
# route persistent chunks through ``GpuAdamW8bitAdapter``
# (``bnb.optim.AdamW8bit`` / ``bnb.optim.PagedAdamW8bit``).
#
# All other optimizer names (Lion, Adafactor, GaLore, Sophia, Muon,
# torchao, plain SGD, etc.) have state shapes that do not match the
# AdamW-shaped adapters and would break silently, so the validator
# below rejects them at config-load time.
_SUPPORTED_OPTIMIZERS: frozenset[str] = frozenset(
{
"adamw_torch",
"adamw_torch_fused",
"adamw_8bit",
"adamw_bnb_8bit",
"paged_adamw_8bit",
}
)


def _has_protrain_plugin(plugins) -> bool:
"""Return True iff the iterable contains an explicit ProTrain plugin id.

Expand Down Expand Up @@ -121,8 +148,12 @@ class ProTrainArgs(BaseModel):
"trainer. Requires "
"``plugins: [axolotl.integrations.protrain.ProTrainPlugin]``. "
"Mutually exclusive with DeepSpeed, FSDP, gradient_checkpointing, "
"TP/CP/SP > 1, and load_in_8bit/load_in_4bit (see "
"`_reject_incompatible_features`)."
"and TP/CP/SP > 1 (see `_reject_incompatible_features`). "
"Composes with bitsandbytes ``load_in_8bit`` / ``load_in_4bit`` "
"(M2/M3 validated; ``Params4bit`` / ``Int8Params`` survive the "
"chunk gather/offload path because ``quant_state`` lives as a "
"Python attribute on the param and ``chunk/manager.py`` rebinds "
"``param.data`` without touching python attrs)."
)
},
)
Expand Down Expand Up @@ -269,10 +300,7 @@ class ProTrainArgs(BaseModel):
},
)

# ------------------------------------------------------------------
# Optimizer-state checkpoint/resume (CHECKPOINT_DESIGN.md Phase 1,
# CHECKPOINT_DESIGN_PHASE2.md Modes B + C)
# ------------------------------------------------------------------
# Optimizer-state checkpoint/resume.

protrain_save_optimizer_state: bool | None = Field(
default=False,
Expand Down Expand Up @@ -426,10 +454,17 @@ def _reject_incompatible_features(cls, data):
``sequence_parallel_degree`` > 1 — scope-excluded per plan.md
(M6 single-3090 focus); the chunk layout does not shard
correctly across TP/CP ranks in this milestone.
* ``load_in_8bit`` / ``load_in_4bit`` — bnb weight quantization
wraps ``nn.Linear.weight`` in a non-owning proxy. The chunk
manager reads unquantized storage for gather / offload and
cannot reason about the 8-bit / 4-bit packed buffers.

Note: ``load_in_8bit`` / ``load_in_4bit`` are NOT in this mutex
list. M0 spike + M2/M3 audit validation established that bnb
weight quantization composes with ProTrain in both Mode A
(all-persistent) AND offload mode — ``Params4bit.data`` and
``Int8Params.data`` are uint8/int8 storage tensors, so the
chunk manager's ``numel * element_size`` byte math handles them
correctly, and ``quant_state`` lives as a Python attribute on
the param instance and survives ``param.data`` rebinding (see
``chunk/manager.py``). Pinned by
``tests/protrain/test_bnb_offload.py``.

Each rejection surfaces at config-load time rather than as a
silent mis-training run.
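Review note: the ``numel * element_size`` argument above can be checked with plain arithmetic. The sketch assumes that quantized weights are stored in 1-byte elements and that 4-bit packing stores two values per byte; `packed_nbytes` is a hypothetical helper, and the figure excludes ``quant_state`` metadata, which lives outside ``param.data``.

```python
def packed_nbytes(n_weights: int, bits: int) -> int:
    """Bytes occupied by n_weights values stored at `bits` each in 1-byte elements."""
    if bits not in (4, 8):
        raise ValueError("only 4-bit and 8-bit storage are modelled here")
    values_per_byte = 8 // bits
    numel = -(-n_weights // values_per_byte)  # ceiling division
    element_size = 1  # uint8 / int8 storage
    return numel * element_size


# A 4096 x 4096 linear weight, as an illustrative size.
n = 4096 * 4096
bytes_4bit = packed_nbytes(n, 4)  # 8 MiB
bytes_8bit = packed_nbytes(n, 8)  # 16 MiB
```

The same weight in a 16-bit dtype would take ``n * 2`` bytes (32 MiB), so a chunk layout that sizes buffers from the storage tensor sees the quantized footprint directly.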
Expand Down Expand Up @@ -500,19 +535,44 @@ def _reject_incompatible_features(cls, data):
"(scope-excluded per plan.md — single-3090 target). Set "
"sequence_parallel_degree=1 or remove the ProTrain plugin."
)
if data.get("load_in_8bit"):
raise ValueError(
"ProTrain is incompatible with load_in_8bit=true (bitsandbytes "
"8-bit quantization wraps nn.Linear.weight in a non-owning proxy; "
"the chunk manager operates on unquantized storage for gather / "
"offload). Set load_in_8bit=false or remove the ProTrain plugin."
)
if data.get("load_in_4bit"):
# M0 spike + M3 audit validation: bnb 8-bit / 4-bit weights compose with
# ProTrain in BOTH Mode A (all-persistent) AND offload mode (Mode C / single-GPU
# n_persist_override<N_chunk). Int8Params.data and Params4bit.data are int8/uint8
# tensors so chunk numel*element_size byte math handles them correctly; quant_state
# lives as a Python attribute on the param instance and survives the chunk gather/
# offload path because chunk/manager.py rebinds param.data without touching python
# attrs. Pinned by tests/protrain/test_bnb_offload.py.
return data

@model_validator(mode="before")
@classmethod
def _reject_unsupported_optimizer(cls, data):
"""Reject ``cfg.optimizer`` values that ProTrain's adapters cannot drive."""
if not isinstance(data, dict):
return data
if not data.get("protrain_auto_memory"):
return data
plugins = data.get("plugins") or []
if not _has_protrain_plugin(plugins):
return data
optimizer = data.get("optimizer")
if optimizer is None:
return data
# Tolerate enum values supplied programmatically (e.g.
# ``OptimizerNames.ADAMW_TORCH``) as well as the YAML string.
optimizer_str = getattr(optimizer, "value", optimizer)
normalized = str(optimizer_str).strip().lower()
if normalized not in _SUPPORTED_OPTIMIZERS:
supported = ", ".join(sorted(_SUPPORTED_OPTIMIZERS))
raise ValueError(
"ProTrain is incompatible with load_in_4bit=true (bitsandbytes "
"4-bit quantization wraps nn.Linear.weight in a non-owning proxy; "
"the chunk manager operates on unquantized storage for gather / "
"offload). Set load_in_4bit=false or remove the ProTrain plugin."
f"ProTrain currently supports AdamW family optimizers only "
f"(got `{optimizer_str}`). Lion, Adafactor, GaLore, Sophia, "
f"Muon, and torchao optimizers require optimizer-specific "
f"chunk-manager adapters that have not been implemented. See "
f"src/axolotl/integrations/protrain/chunk/optim.py for the "
f"supported adapter list. Supported optimizers: "
f"{supported}. Set `optimizer: adamw_torch` (or another "
f"supported value above) or remove the ProTrain plugin."
)
return data

Expand Down
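Review note: the allow-list check can be tried as a plain function, without pydantic. This is a simplified sketch: `reject_unsupported_optimizer` is a hypothetical free-function version of the validator, it omits the plugin-presence check, and its error text is shortened for illustration.

```python
from __future__ import annotations

SUPPORTED_OPTIMIZERS = frozenset(
    {
        "adamw_torch",
        "adamw_torch_fused",
        "adamw_8bit",
        "adamw_bnb_8bit",
        "paged_adamw_8bit",
    }
)


def reject_unsupported_optimizer(data: object) -> object:
    """Return data unchanged, or raise if ProTrain is on and the optimizer is unsupported."""
    if not isinstance(data, dict):
        return data
    if not data.get("protrain_auto_memory"):
        return data  # ProTrain disabled: nothing to validate
    optimizer = data.get("optimizer")
    if optimizer is None:
        return data
    # Accept enum members as well as YAML strings.
    optimizer_str = getattr(optimizer, "value", optimizer)
    normalized = str(optimizer_str).strip().lower()
    if normalized not in SUPPORTED_OPTIMIZERS:
        supported = ", ".join(sorted(SUPPORTED_OPTIMIZERS))
        raise ValueError(
            f"unsupported optimizer `{optimizer_str}`; supported: {supported}"
        )
    return data
```

Running the check in a "before" validator means an unsupported name fails at config load instead of surfacing later as a training run with mismatched optimizer state.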
2 changes: 2 additions & 0 deletions src/axolotl/integrations/protrain/chunk/__init__.py
Expand Up @@ -14,6 +14,7 @@
from axolotl.integrations.protrain.chunk.manager import ChunkManager
from axolotl.integrations.protrain.chunk.optim import (
CpuFusedAdamAdapter,
GpuAdamW8bitAdapter,
GpuFusedAdamAdapter,
)
from axolotl.integrations.protrain.chunk.pinned_alloc import PinnedHostMemory
Expand All @@ -23,6 +24,7 @@
"BufferPool",
"ChunkManager",
"CpuFusedAdamAdapter",
"GpuAdamW8bitAdapter",
"GpuFusedAdamAdapter",
"PinnedHostMemory",
"build_layout",
Expand Down