Closed
17 commits
4c67858
fix(nvfp4): auto-enable attention compile custom op under torch_compile
thad0ctor Jun 4, 2026
ee7e66d
perf(nvfp4): fuse abs into the global-amax reduction (drop AbsFunctor…
thad0ctor Jun 4, 2026
998d498
perf(lora): batch shared-input adapter GEMMs to cut launch overhead
thad0ctor Jun 4, 2026
a62d419
feat(nvfp4): chunked bf16 lm_head cross-entropy (no logits materializ…
thad0ctor Jun 4, 2026
e3fed9c
docs(nvfp4): record agent-fix bench (+4.8%/-9GiB stack) and chunked b…
thad0ctor Jun 4, 2026
34af026
perf(nvfp4): reuse rms in fused RMSNorm forward, drop double rsqrt [F2]
thad0ctor Jun 4, 2026
915c875
docs(nvfp4): correct rtn_grad_packs desc (dPt dO pack also stays SR) …
thad0ctor Jun 4, 2026
6fe10d5
perf(nvfp4): drop dead HP q/k/v save on saved-packs backward [B1]
thad0ctor Jun 4, 2026
333b750
perf(nvfp4): bf16 dq scratch under dkdv_scratch_bf16 (bit-exact, mem)…
thad0ctor Jun 4, 2026
612fe26
docs(nvfp4): dq scratch follows dkdv_scratch_bf16; record B1/B2 fixed…
thad0ctor Jun 4, 2026
2926b53
perf(nvfp4): read Q/K transposed in fused_rope_quant_qk (drop .contig…
thad0ctor Jun 4, 2026
63c18d6
test(nvfp4): rope_quant strided-parity unit test + e2e block bench (#…
thad0ctor Jun 4, 2026
178520c
feat(nvfp4): opt-in FP4 q/k/o_proj + shared qkv activation pack (defa…
thad0ctor Jun 4, 2026
c2fde37
feat(nvfp4): separable q/k vs o_proj FP4 sub-gates + ablation
thad0ctor Jun 4, 2026
94f62c6
refactor(nvfp4): nest attention flags under model-agnostic nvfp4_trai…
thad0ctor Jun 4, 2026
8d3123a
style(nvfp4): pre-commit clean PR files; drop throwaway bench scripts
thad0ctor Jun 4, 2026
0a5b3a6
chore: repo-wide pre-commit cleanup (ruff-format, mypy/ruff suppressi…
thad0ctor Jun 4, 2026
77 changes: 62 additions & 15 deletions docs/nvfp4_training.qmd
@@ -123,12 +123,12 @@ checkpointing path measured so far is:

- `adapter: lora`
- `nvfp4_training.base_mode: compute`
-- `qwen3_5_native_attention: true`
-- `qwen3_5_native_attention_backward: true`
-- `qwen3_5_native_attention_backward_rtn_grad_packs: true`
-- `qwen3_5_native_attention_save_backward_packs: true`
-- `qwen3_5_native_attention_dkdv_scratch_bf16: true`
-- `qwen3_5_fla_causal_conv_compile_boundary: true`
+- `attention.enabled: true`
+- `attention.backward.enabled: true`
+- `attention.backward.rtn_grad_packs: true`
+- `attention.backward.save_packs: true`
+- `attention.backward.dkdv_scratch_bf16: true`
+- `fla_causal_conv_compile_boundary: true`
- `fuse_rmsnorm: false`
- `max_grad_norm: 1.0`
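
The rename in this hunk moves six flat `qwen3_5_*` flags into a nested `attention:` block. As a rough illustration of that mapping, here is a minimal sketch. `FLAT_TO_NESTED` and `migrate_flat_flags` are hypothetical names invented for this example and are not taken from the PR; only the old and new key names come from the lines above.

```python
import copy

# Old flat key -> new dotted path, copied from the removed/added lines in this hunk.
FLAT_TO_NESTED = {
    "qwen3_5_native_attention": "attention.enabled",
    "qwen3_5_native_attention_backward": "attention.backward.enabled",
    "qwen3_5_native_attention_backward_rtn_grad_packs": "attention.backward.rtn_grad_packs",
    "qwen3_5_native_attention_save_backward_packs": "attention.backward.save_packs",
    "qwen3_5_native_attention_dkdv_scratch_bf16": "attention.backward.dkdv_scratch_bf16",
    "qwen3_5_fla_causal_conv_compile_boundary": "fla_causal_conv_compile_boundary",
}


def migrate_flat_flags(block: dict) -> dict:
    """Return a copy of an nvfp4_training block with the flat flags nested.

    Hypothetical helper for illustration; keys not in the mapping pass through.
    """
    out: dict = {}
    for key, value in block.items():
        path = FLAT_TO_NESTED.get(key, key).split(".")
        node = out
        for part in path[:-1]:
            # Create intermediate dicts such as "attention" and "backward".
            node = node.setdefault(part, {})
        node[path[-1]] = copy.deepcopy(value)
    return out


old_block = {
    "base_mode": "compute",
    "qwen3_5_native_attention": True,
    "qwen3_5_native_attention_backward": True,
    "qwen3_5_native_attention_backward_rtn_grad_packs": True,
    "qwen3_5_native_attention_save_backward_packs": True,
    "qwen3_5_native_attention_dkdv_scratch_bf16": True,
    "qwen3_5_fla_causal_conv_compile_boundary": True,
    "fuse_rmsnorm": False,
}
new_block = migrate_flat_flags(old_block)
print(new_block)
```

Whether the project itself performs this translation automatically for the deprecated flat flags is not shown in this hunk, so treat the helper purely as a reading aid.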

@@ -167,6 +167,25 @@ sequence length 2048, sample packing, no gradient checkpointing:
| Locked profile, 500-step long-tail (RTX PRO 6000 #1, PCIe-first) | 4 | 1.2302 | ~6,659 tok/s | 62.03 GiB |
| Locked profile + `hadamard: false` (RTX PRO 6000 #1) | 4 | 1.2155 | ~6,739 tok/s | 62.03 GiB |
| Locked profile + once-per-forward mask classify (RTX PRO 6000 #1) | 4 | 1.1863 | ~6,906 tok/s | 62.03 GiB |
+| Agent-fix re-bench: locked baseline (RTX PRO 6000 #1, 3-run median) | 4 | 1.1897 | ~6,886 tok/s | 69.63 GiB |
+| + attention compile custom op auto-enabled | 4 | 1.1678 | ~7,015 tok/s | 60.64 GiB |
+| + abs→`vector_norm` amax fuse + batched LoRA A-GEMMs (full stack) | 4 | 1.1310 | ~7,244 tok/s | 60.64 GiB |
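
The tok/s column in these rows can be reproduced from the s/step column. The sketch below assumes a single GPU, `gradient_accumulation_steps: 1`, and every packed row padded to `sequence_len`, so each step processes `micro_batch_size * sequence_len = 4 * 2048` tokens. That assumption is mine; it is consistent with the table to within a token or two per second.

```python
# Assumption (not stated in the table): tokens per step = micro_batch * seq_len.
TOKENS_PER_STEP = 4 * 2048


def tokens_per_second(seconds_per_step: float) -> float:
    """Throughput implied by a marginal step time."""
    return TOKENS_PER_STEP / seconds_per_step


def step_time_reduction(baseline: float, candidate: float) -> float:
    """Fraction of the baseline step time that the candidate saves."""
    return (baseline - candidate) / baseline


# s/step values copied from the three re-bench rows.
rows = {
    "locked baseline": 1.1897,
    "+ custom op": 1.1678,
    "full stack": 1.1310,
}
for name, seconds in rows.items():
    print(f"{name}: {tokens_per_second(seconds):.0f} tok/s")

reduction = step_time_reduction(rows["locked baseline"], rows["full stack"])
print(f"step-time reduction, baseline -> full stack: {reduction:.1%}")
```

Under this arithmetic the full stack saves roughly 5% of the baseline step time, in line with the approximate figure quoted for the A/B comparison.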
+
+A 2026-06 sweep profiled the whole step and found full attention is only ~7% of
+GPU time (so a faster FP4 attention backward is not the lever), while ~26% is
+unfused NVFP4 quant/elementwise and ~11% is the bf16 lm_head. The dominant cause
+was a **silent eager fallback**: the attention forward's P@V `tl.dot_scaled` raises
+an `InductorError` inside Inductor autotune, and with `suppress_errors` on (the
+training default) the whole attention-containing subgraph ran eager, blocking
+fusion of the surrounding elementwise. Auto-enabling the differentiable attention
+custom op when `torch_compile` is on (see `attention.backward.compile_custom_op`
+below) makes Inductor compile *around* the opaque op, restoring fusion. Stacked with
+the bit-exact amax fuse (`amax(abs(t))` → `vector_norm(t, inf)`, dropping the
+full-tensor abs pass) and batched shared-input LoRA A-GEMMs (`lora_batch_kernel`),
+the interleaved A/B above measured **1.131 vs 1.190 s/step (~4.8% faster) and
+−9 GiB** at identical loss. All three changes are bit-exact or opt-in; the first
+also drops active memory because the custom-op backward does not retain forward
+packs (`attention.backward.save_packs` has no effect under it).
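
The amax fuse mentioned in this paragraph rests on the identity that the infinity norm of a tensor equals the maximum of its absolute values. The sketch below shows why dropping the separate abs pass is bit-exact, using plain Python lists rather than tensors. The function names are illustrative only; in PyTorch the fused form corresponds to `torch.linalg.vector_norm(t, ord=float("inf"))`.

```python
def amax_two_pass(values):
    """amax(abs(t)): materialize a full-size abs temporary, then reduce."""
    abs_values = [abs(v) for v in values]  # extra full-size pass and buffer
    return max(abs_values)


def amax_fused(values):
    """Infinity norm: take abs inside the reduction, with no temporary."""
    running = 0.0
    for v in values:
        magnitude = abs(v)
        if magnitude > running:
            running = magnitude
    return running


# Non-empty, NaN-free input assumed for this illustration.
sample = [0.5, -7.25, 3.0, -0.125, 6.0]
print(amax_two_pass(sample), amax_fused(sample))
```

Both versions only select an existing element's magnitude and never add or multiply, so no rounding occurs and the results agree exactly. The saving is the skipped temporary and the extra memory pass.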

The marginal s/step varies ~5% **between the two RTX PRO 6000 boards**: the
PCIe-first board measured 1.2208–1.2302 across 60- and 500-step runs, while the
@@ -189,7 +208,11 @@ Negative checks from the same sweep:
attention packs: 2.0500 s/step (~5,994 tok/s).
- Cut Cross Entropy and Liger fused-linear CE are not valid with this NVFP4
training setup today: b5 warmups loss-collapsed to zero after the first step
-  and logged non-finite grad norms.
+  and logged non-finite grad norms. The opt-in `bf16_lm_head_cross_entropy` path
+  (below) avoids that collapse (it never filters gradient mass and keeps the
+  logsumexp and `grad_hidden` accumulation in fp32), but it is a memory win, not a
+  throughput win (1.2178 s/step, 56.48 GiB at b4; the tile GEMMs are bf16 with the
+  same FLOPs as the materialized lm_head, so it trades ~2.5% speed for ~13 GiB).
- `fp8_lm_head_cross_entropy` can reduce memory and run b5/b6, but it did not
beat the b4 saved-pack path on max throughput in clean validation. It is also
not fully deterministic yet: one b6 run loss-collapsed with `grad_norm: nan`,
@@ -424,6 +447,27 @@ path remained ahead on tokens/sec. Treat batch-6 FP8 CE as experimental: one
validation run loss-collapsed after the first step with non-finite grad norms,
while a later short repeat stayed finite.

+### Chunked bf16 lm_head cross-entropy (opt-in)
+
+`bf16_lm_head_cross_entropy: true` is a **memory** option for a frozen, bias-free
+plain `nn.Linear` `lm_head`. It computes the loss and `dL/dhidden` by streaming the
+vocabulary in tiles (`_VOCAB_BLOCK = 4096`) with online softmax, so the
+`[tokens, vocab]` logits tensor and its gradient are never materialized. Unlike the
+fused CCE/Liger kernels (which loss-collapse in this NVFP4 setup), it does **no**
+gradient filtering, keeps the logsumexp and `grad_hidden` accumulation in fp32, and
+downcasts to bf16 once at the end, matching `F.cross_entropy` to ~1e-7 (loss) /
+~3.5e-3 (grad, pure bf16-GEMM noise) at the Qwen3.5 vocab scale.
+
+It is a memory/batch-size unlock, **not** a throughput win: the tile GEMMs are bf16
+with the same FLOPs as the materialized lm_head (which is already frozen, so there is
+no weight gradient to save), so it trades speed for memory. On Qwen3.5-9B LoRA
+(seq 2048, b4, RTX PRO 6000) it measured **1.2178 s/step at 56.48 GiB active** versus
+the locked saved-pack path's ~1.19 s/step at 69.63 GiB: ~2.5% slower for ~13 GiB
+headroom. Use it when memory-bound (to fit a larger batch); otherwise leave it off.
+It engages only for a frozen bias-free `nn.Linear` lm_head in training (it falls back
+to materialized CE otherwise) and is mutually exclusive with `quantize_lm_head`,
+`fused_fp4_cross_entropy`, and `fp8_lm_head_cross_entropy`.
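
The streaming idea in this section can be illustrated with a small pure-Python sketch of the online-softmax recurrence for a single token. It is not the PR's kernel: the real path works on bf16 tiles of 4096 vocabulary rows and also accumulates `grad_hidden`, whereas this example uses Python floats, a toy block size, and computes the loss only. All function and variable names here are made up for the illustration.

```python
import math
import random


def full_cross_entropy(logits, target):
    """Reference: materialize every logit, then logsumexp minus the target logit."""
    peak = max(logits)
    lse = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return lse - logits[target]


def chunked_cross_entropy(hidden, weight_rows, target, block):
    """Stream the vocabulary in tiles, keeping only a running max and sum."""
    running_max = -math.inf
    running_sum = 0.0
    target_logit = None
    for start in range(0, len(weight_rows), block):
        tile = weight_rows[start:start + block]
        # Logits for this tile only; the full logits vector is never stored.
        tile_logits = [sum(h * w for h, w in zip(hidden, row)) for row in tile]
        new_max = max(running_max, max(tile_logits))
        # Rescale the old sum to the new max before adding this tile's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in tile_logits
        )
        running_max = new_max
        if start <= target < start + len(tile):
            target_logit = tile_logits[target - start]
    return running_max + math.log(running_sum) - target_logit


rng = random.Random(0)
hidden = [rng.uniform(-1.0, 1.0) for _ in range(8)]
weight_rows = [[rng.uniform(-1.0, 1.0) for _ in range(8)] for _ in range(37)]
target = 21

full_logits = [sum(h * w for h, w in zip(hidden, row)) for row in weight_rows]
loss_full = full_cross_entropy(full_logits, target)
loss_chunked = chunked_cross_entropy(hidden, weight_rows, target, block=5)
print(loss_full, loss_chunked)
```

The vocabulary size (37) is deliberately not a multiple of the block size, to exercise the ragged last tile. Peak memory in the chunked version scales with the tile size rather than the vocabulary size, which is the trade the section describes.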

::: {.callout-note}
At the time of writing, `skip_first_n_blocks` / `skip_last_n_blocks` may be applied
by the integration layer rather than inside `convert_to_nvfp4_training` directly
@@ -452,14 +496,17 @@ The `nvfp4_training:` block (schema: `src/axolotl/utils/schemas/nvfp4.py`,
| `skip_first_n_blocks` | `int` | `0` | Keep the first N transformer blocks in high precision (see the ~15% high-precision policy below). |
| `skip_last_n_blocks` | `int` | `0` | Keep the last N transformer blocks in high precision (the tail blocks matter most). |
| `save_nvfp4` | `bool` | `false` | Opt-in. Store eligible weights NVFP4-packed (qdata + scales) in a `torch.save` sidecar for ~3.5× smaller weight files. See [FP4-packed save](#fp4-packed-save-save_nvfp4) below. **Lossy for FFT resume** (no bf16 master kept); bit-exact for frozen weights. Off by default (bf16 save, unchanged). |
-| `qwen3_5_native_attention` | `bool` | `false` | Qwen3.5 only. Patch full softmax-attention layers to use the native NVFP4 attention path on dense causal/full batches. |
-| `qwen3_5_native_attention_backward` | `bool` | `false` | Qwen3.5 only. Requires `qwen3_5_native_attention`. Use the native NVFP4 autograd attention path while training. |
-| `qwen3_5_native_attention_backward_rtn_grad_packs` | `bool` | `false` | Qwen3.5 native attention training only. Use deterministic round-to-nearest for measured-safe gradient packs while leaving the dK routing-gradient dS pack governed by `stochastic_rounding`. |
-| `qwen3_5_native_attention_save_backward_packs` | `bool` | `false` | Qwen3.5 native attention training only. Save deterministic forward Q/K/V FP4 packs plus transposed backward layouts and reuse them in backward. Trades extra activation memory for higher throughput. |
-| `qwen3_5_native_attention_dkdv_scratch_bf16` | `bool` | `false` | Qwen3.5 native attention training only. Store per-query-head dK/dV scratch in bf16 before GQA reduction instead of fp32. Measured faster on Qwen3.5-9B b4; opt-in because it changes an intermediate gradient cast. |
-| `qwen3_5_native_attention_compile_custom_op` | `bool` | `false` | Qwen3.5 native attention inference only. Opaque custom-op escape hatch for strict compile coverage around Triton `tl.dot_scaled`; rejected when native attention backward is enabled. |
-| `qwen3_5_fla_causal_conv_compile_boundary` | `bool` | `false` | Qwen3.5 sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. |
-| `qwen3_5_fuse_vproj` / `qwen3_5_native_mlp` / `qwen3_5_native_linear_attn` / `fp8_lm_head` | `bool` | `false` | Qwen3.5/eval-scoring paths. Current implementations are eval/no-grad only and do not accelerate grad-enabled training. Use `fp8_lm_head_cross_entropy` separately for the opt-in training loss memory path. |
+| `attention.enabled` | `bool` | `false` | Patch full softmax-attention layers to the native NVFP4 attention path on dense causal/full batches (model-agnostic; applied where the architecture supports it). Replaces the deprecated flat `qwen3_5_native_attention*` flags. |
+| `attention.fuse_vproj` | `bool \| null` | `null` | Run v_proj as a native NVFP4 GEMM with a key-axis pack epilogue on inference/cache-free prefill. `null`: on for inference, off for training. |
+| `attention.fp4_projections` | `bool` | `false` | Run q/k/o_proj as native NVFP4 GEMMs (q/k share one activation pack) on inference prefill. **Parity-affecting** (not bit-exact); speed-neutral on hybrid models, a per-layer win on dense models. Plain-`nn.Linear` only. |
+| `attention.backward.enabled` | `bool` | `false` | Requires `attention.enabled`. Use the native NVFP4 autograd attention path while training. |
+| `attention.backward.rtn_grad_packs` | `bool` | `false` | Deterministic round-to-nearest for the measured-safe gradient packs, leaving the dK and dPt packs governed by `stochastic_rounding`. |
+| `attention.backward.save_packs` | `bool` | `false` | Save the forward Q/K/V FP4 packs (+ transposed backward layouts) and reuse them in backward; trades activation memory for higher throughput. |
+| `attention.backward.dkdv_scratch_bf16` | `bool` | `false` | Store dQ and per-query-head dK/dV backward scratch in bf16 (accumulate fp32 in-register, downcast once at the store → bit-identical to fp32-then-`.to(bf16)`; a pure memory save on the largest backward scratch planes). |
+| `attention.backward.compile_custom_op` | `bool \| null` | `null` (auto) | Wrap the attention path in an opaque differentiable custom op so Inductor compiles *around* the Triton `tl.dot_scaled` (which otherwise raises an `InductorError` and silently drops the block to eager). `null` auto-enables it when `torch_compile` is on and `attention.enabled`; `true`/`false` force it. Under it `attention.backward.save_packs` has no effect. |
+| `bf16_lm_head_cross_entropy` | `bool` | `false` | Opt-in **memory** path. Requires a frozen bias-free plain `nn.Linear` lm_head. Chunked online-softmax CE over bf16 vocab tiles: no `[tokens, vocab]` logits materialization, fp32 logsumexp/`grad_hidden`, no gradient filtering (avoids the CCE/Liger collapse). Trades ~2.5% throughput for ~13 GiB. Mutually exclusive with `quantize_lm_head` / `fused_fp4_cross_entropy` / `fp8_lm_head_cross_entropy`. |
+| `fla_causal_conv_compile_boundary` | `bool` | `false` | Sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. |
+| `linear_attn` / `mlp` / `fp8_lm_head` | `bool` | `false` | Eval/no-grad-only native-NVFP4 module patches (linear-attention projections, dense SwiGLU MLP). Do not accelerate grad-enabled training. Deprecated flat aliases: `qwen3_5_native_linear_attn`, `qwen3_5_native_mlp`. |
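
The `bool | null` default of `attention.backward.compile_custom_op` is a tri-state: `null` defers to other settings, while an explicit value wins. The sketch below is one reading of the table row, not the project's code. Both function names are hypothetical, and the assumption that an explicit `true` takes effect regardless of the other two settings is mine.

```python
from typing import Optional


def resolve_compile_custom_op(
    flag: Optional[bool], torch_compile: bool, attention_enabled: bool
) -> bool:
    """Resolve the tri-state flag as the table row describes it.

    None (YAML null) -> auto: on only when torch_compile and attention.enabled.
    True / False     -> forced (assumed here to override the auto rule).
    """
    if flag is not None:
        return flag
    return torch_compile and attention_enabled


def save_packs_effective(save_packs: bool, custom_op_active: bool) -> bool:
    """save_packs is documented as having no effect under the custom op."""
    return save_packs and not custom_op_active


# The example configs in this PR set torch_compile: true and attention.enabled: true
# and leave compile_custom_op unset, so the auto rule applies.
custom_op = resolve_compile_custom_op(None, torch_compile=True, attention_enabled=True)
print("custom op active:", custom_op)
print("save_packs takes effect:", save_packs_effective(True, custom_op))
```

Read this way, a config that combines `torch_compile: true` with `save_packs: true` would need an explicit `compile_custom_op: false` for the saved-pack path to apply. That is an inference from the table row and worth confirming against the schema.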

## FP4-packed save (`save_nvfp4`) {#fp4-packed-save-save_nvfp4}

138 changes: 138 additions & 0 deletions examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml
@@ -0,0 +1,138 @@
# Fastest NVFP4 Qwen3.5-9B LoRA path + chunked bf16 lm_head cross-entropy.
#
# Identical to qwen35-9b-lora-fastest.yaml but with
# nvfp4_training.bf16_lm_head_cross_entropy: true. The lm_head stays excluded
# from FP4 (bf16) and frozen under LoRA, so the chunked bf16 CE tiles the
# projection over the vocab instead of materializing the full [tokens, vocab]
# logits (and its gradient GEMM). This is a memory / backward-traffic win, not an
# FP4 tensor-core throughput win, and it is the convergence-safe alternative to
# cut_cross_entropy / Liger fused-linear CE, which loss-collapsed in this setup.
# Loss & dL/dhidden are bit-close to F.cross_entropy over full logits.
#

# Target shape: text-only SFT, sequence_len 2048, sample packing, no gradient
# checkpointing. Measured on RTX PRO 6000 Blackwell 96GB with
# scripts/bench_nvfp4.sh (20 -> 60 marginal train_runtime):
# bf16 b6 stable baseline: 2.1550 s/step, ~5701 tok/s, 89.61 GiB active
# (learning_rate 2e-5; finite grad norms)
# all-on NVFP4 b4 baseline: 1.2475 s/step, ~6567 tok/s, 69.26 GiB active
# this saved-pack path: 1.1525-1.2075 s/step, ~6784-7108 tok/s,
# 69.63 GiB active
# plus FLA boundary: 1.1708 s/step, ~6997 tok/s, 69.63 GiB active
# plus BF16 dK/dV scratch: 1.1415-1.1485 s/step, ~7133-7177 tok/s,
# 62.03 GiB active
# beta lock-down repeat: 1.2017 s/step, ~6817 tok/s,
# 62.03 GiB active
# clean GPU3 repeat: 1.0635 s/step, ~7701 tok/s,
# 69.63 GiB active
# b5 no CE repeat: 1.5195 s/step, ~6740 tok/s,
# 74.43 GiB active
# b6 no CE repeat: 1.8473 s/step, ~6652 tok/s,
# 86.83 GiB active
# FP8 lm_head CE b6: 1.7700-1.7807 s/step, ~6902-6943 tok/s,
# 80.32 GiB active (slower; one validation run
# loss-collapsed with non-finite grad norms)
# FP8 lm_head CE b5 repeat: 1.5312 s/step, ~6688 tok/s, 69.60 GiB active
# (memory win, not a throughput win)
#
# Do not use fp16 as the convergence baseline here: with NVFP4 omitted and
# max_grad_norm: 1.0 explicit, fp16 produced NaN LoRA gradients at the first AMP
# unscale. bf16 is the natural full-precision Blackwell baseline.
#
# The speed knob here is attention.backward.save_packs: it saves
# forward FP4 attention packs and reuses them in backward, trading ~0.4 GiB of
# activation memory for less backward pack-prep work. The FLA boundary prevents
# variable packed cu_seqlens from burning compile time in causal_conv1d. BF16
# dK/dV scratch reduces the attention-backward scratch traffic before GQA reduce.
# The no-grad/eval-only Qwen3.5 native MLP, native linear-attn, fused v_proj,
# standalone fp8_lm_head, FP8 lm_head CE, CCE, and Liger fused-linear CE
# switches are intentionally omitted. Same-prepared-data ablations measured:
# 1.2017 s/step without legacy switches vs 1.2335 s/step with them; FP8 CE b5
# 1.5312 s/step; CCE/Liger CE loss-collapse with non-finite grad norms in this
# NVFP4 training setup.
#
# For benchmark comparisons, copy this config to /tmp, set base_model to the
# local model path, and keep dataset_prepared_path fixed across compared runs.
base_model: Qwen/Qwen3.5-9B
model_config_type: qwen3_5
strict: false

chat_template: qwen3_5
datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
val_set_size: 0.0
dataset_prepared_path: /tmp/axolotl_nvfp4_qwen35_fastest_prepared
output_dir: /tmp/axolotl_nvfp4_qwen35_fastest_out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

load_in_8bit: false
load_in_4bit: false
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj

gradient_accumulation_steps: 1
# This measured b4 path targets 96GB Blackwell. Reduce micro_batch_size on
# smaller cards or enable gradient_checkpointing if you need memory headroom.
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2.0e-4
max_grad_norm: 1.0

bf16: true
fp16: false
tf32: true
torch_compile: true

gradient_checkpointing: false
attn_implementation: flash_attention_2

nvfp4_training:
  enabled: true
  base_mode: compute
  stochastic_rounding: true
  hadamard: true
  exclude_modules: [lm_head, embed_tokens]
  skip_first_n_blocks: 0
  skip_last_n_blocks: 0
  fuse_rmsnorm: false

  # Chunked bf16 lm_head CE: skip materializing the [tokens, vocab] logits.
  bf16_lm_head_cross_entropy: true

  # Native full-attention training path. save_packs is the measured throughput
  # win; rtn_grad_packs keeps the safe gradient-side packs deterministic.
  attention:
    enabled: true
    backward:
      enabled: true
      rtn_grad_packs: true
      save_packs: true
      dkdv_scratch_bf16: true
  fla_causal_conv_compile_boundary: true

warmup_steps: 10
logging_steps: 1
save_strategy: "no"
saves_per_epoch:
evals_per_epoch:
weight_decay: 0.0
special_tokens:
20 changes: 11 additions & 9 deletions examples/nvfp4/qwen35-9b-lora-fastest.yaml
@@ -29,7 +29,7 @@
# max_grad_norm: 1.0 explicit, fp16 produced NaN LoRA gradients at the first AMP
# unscale. bf16 is the natural full-precision Blackwell baseline.
#
-# The speed knob here is qwen3_5_native_attention_save_backward_packs: it saves
+# The speed knob here is attention.backward.save_packs: it saves
# forward FP4 attention packs and reuses them in backward, trading ~0.4 GiB of
# activation memory for less backward pack-prep work. The FLA boundary prevents
# variable packed cu_seqlens from burning compile time in causal_conv1d. BF16
@@ -105,14 +105,16 @@ nvfp4_training:
skip_last_n_blocks: 0
fuse_rmsnorm: false

-  # Qwen3.5 full-attention training path. The saved-pack flag is the measured
-  # throughput win; RTN grad packs keep the safe gradient-side packs deterministic.
-  qwen3_5_native_attention: true
-  qwen3_5_native_attention_backward: true
-  qwen3_5_native_attention_backward_rtn_grad_packs: true
-  qwen3_5_native_attention_save_backward_packs: true
-  qwen3_5_native_attention_dkdv_scratch_bf16: true
-  qwen3_5_fla_causal_conv_compile_boundary: true
+  # Native full-attention training path. save_packs is the measured throughput
+  # win; rtn_grad_packs keeps the safe gradient-side packs deterministic.
+  attention:
+    enabled: true
+    backward:
+      enabled: true
+      rtn_grad_packs: true
+      save_packs: true
+      dkdv_scratch_bf16: true
+  fla_causal_conv_compile_boundary: true

warmup_steps: 10
logging_steps: 1