Skip to content

NVFP4: attention perf wins + flag consolidation#42

Closed
thad0ctor wants to merge 17 commits into
Nvfp4_training_betafrom
feat/nvfp4-attn-perf-and-config
Closed

NVFP4: attention perf wins + flag consolidation#42
thad0ctor wants to merge 17 commits into
Nvfp4_training_betafrom
feat/nvfp4-attn-perf-and-config

Conversation

@thad0ctor

@thad0ctor thad0ctor commented Jun 4, 2026

Copy link
Copy Markdown
Owner

NVFP4: attention perf wins + flag consolidation

Native-NVFP4 training perf wins for the Qwen3.5 LoRA path on Blackwell (sm_120), plus a model-agnostic flag consolidation. All validated on RTX PRO 6000 (96 GB). Pre-commit is green (repo-wide cleanup folded in).

Benchmarks — vanilla bf16 vs NVFP4 (Qwen3.5, LoRA, seq 2048, sample packing, torch_compile)

config vanilla bf16 (s/step) NVFP4 perf (s/step) speedup active mem (van→nvfp4, GiB)
4B b4 0.900 0.830 1.08× (8%) 50.6 → 48.0 (−2.6)
4B b6 1.541 1.373 1.12× (11%) 71.4 → 69.0 (−2.4)
9B b4 1.500 1.106 1.36× (26%) 65.8 → 60.6 (−5.1)
9B b6 2.241 1.732 1.29× (23%) 89.6 → 84.7 (−4.9)
  • The NVFP4 win is dominated by model size (~10% at 4B → ~25% at 9B) and grows with batch at 4B (8→11%).
  • s/step measured via scripts/bench_nvfp4.sh (marginal train_runtime delta, exclusive-GPU). NVFP4 patches all 8 Qwen3.5 full-attention layers and converges cleanly (loss ~1.0 at lr 2e-4).
  • bf16-baseline caveat: on the 9B config the vanilla bf16 LoRA run does not converge (loss drifts to ~8 with finite grads) at the lr swept (2e-4/2e-5/1e-5). s/step and memory are convergence/lr-independent, so the speedup/memory deltas above stand; the loss column is not a matched-quality comparison there.

Step-speed progression on the locked 9B fastest path (b4 unless noted, RTX PRO 6000)

stage s/step tok/s active mem (GiB)
bf16 b6 stable baseline (lr 2e-5) 2.1550 ~5701 89.61
all-on NVFP4 b4 baseline 1.2475 ~6567 69.26
+ save_packs 1.15–1.21 ~6784–7108 69.63
+ FLA boundary 1.1708 ~6997 69.63
+ bf16 dK/dV scratch (full stack) 1.106 ~7400 60.6

What's quantized — Qwen3.5 hybrid coverage

Qwen3.5 is hybrid: of its 32 blocks, 8 are full-attention and 24 are GatedDeltaNet linear-attention (the logs show "patched 8 Qwen3.5 full-attention layers"). The FP4 flash-attention kernel only touches those 8; the rest is covered as follows:

component scope precision
projections: q/k/v/o, gate/up/down, linear_attn in/out all 32 layers FP4 (base_mode: compute)
attention compute (QK·softmax·AV) 8 full-attn layers FP4 (the patch)
GatedDeltaNet recurrence (causal-conv + delta rule) 24 linear-attn layers bf16
norms / activations / residuals all bf16

The model-wide speedup is carried mainly by FP4 base GEMMs on every layer; the FP4 attention kernel is a smaller slice on top (8/32), which is why attention-only levers (e.g. attention.fp4_projections) look diluted on this hybrid model but pay off on dense (all-attention) models. The only non-FP4 compute is the linear-attn recurrence in the 24 GatedDeltaNet layers — a native FP4 linear-attn path exists but is inference/no-grad only (a proven loser for training) and is intentionally off here.

Example config + training basics

examples/nvfp4/qwen35-9b-lora-fastest.yaml (the locked fastest path):

base_model: Qwen/Qwen3.5-9B
model_config_type: qwen3_5
chat_template: qwen3_5
datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj,
                      linear_attn.in_proj_qkv, linear_attn.in_proj_z, linear_attn.out_proj]

micro_batch_size: 4            # b4 targets 96 GB; lower it / enable gradient_checkpointing on smaller cards
gradient_accumulation_steps: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2.0e-4
max_grad_norm: 1.0

bf16: true
tf32: true
torch_compile: true
gradient_checkpointing: false
attn_implementation: flash_attention_2

nvfp4_training:
  enabled: true
  base_mode: compute            # FP4 frozen base GEMMs
  stochastic_rounding: true
  hadamard: true
  exclude_modules: [lm_head, embed_tokens]
  fuse_rmsnorm: false
  attention:
    enabled: true               # FP4 flash-attention (fwd + bwd) on the 8 full-attn layers
    backward:
      enabled: true
      rtn_grad_packs: true
      save_packs: true          # the throughput knob: reuse fwd FP4 packs in bwd
      dkdv_scratch_bf16: true   # bf16 dK/dV scratch (bit-exact), less bwd traffic
  fla_causal_conv_compile_boundary: true

Training:

axolotl preprocess examples/nvfp4/qwen35-9b-lora-fastest.yaml          # tokenize + validate
axolotl train      examples/nvfp4/qwen35-9b-lora-fastest.yaml          # single / multi-GPU auto
# benchmark a config (marginal train_runtime, exclusive GPU):
bash scripts/bench_nvfp4.sh --gpu <idx> <config.yaml>

Notes: requires an NVFP4-capable Blackwell GPU + MSLK. base_mode: compute quantizes the frozen base to FP4; LoRA adapters train in bf16. A chunked-bf16-CE variant lives in qwen35-9b-lora-bf16-ce.yaml.

Kept wins

  • Throughput stack: auto-enable attention compile-custom-op under torch_compile; fuse abs into the global-amax reduction; batched shared-input LoRA A-GEMMs; chunked bf16 lm_head CE (no logits materialization).
  • B1: drop the dead HP q/k/v save_for_backward on the saved-packs backward — bit-identical grads, lower peak memory.
  • B2: bf16 dq scratch under dkdv_scratch_bf16 — bit-exact, frees the largest backward scratch plane.
  • F2: fused-RMSNorm reuses rms (drop the duplicate rsqrt) — bit-exact.
  • refactor: address CodeRabbit review comments for mixed content JSON p… #2: fused_rope_quant_qk reads the transposed Q/K view via 4D strides (drops the per-layer .contiguous() copy) — 1.45–1.52× producer, ~13% prefill attention-block, bit-identical (+ unit test).
  • q/k/o FP4: opt-in FP4 q/k/o_proj with shared qkv activation pack — default OFF, parity-affecting; the model-agnostic shared-pack lever for dense models.
  • Config consolidation: nested nvfp4_training.attention.{enabled, fuse_vproj, fp4_projections, backward.{...}} + top-level linear_attn / mlp / fla_causal_conv_compile_boundary. Old flat qwen3_5_native_attention* flags keep working via a deprecation migrator (parse + warn).

Not included (proven losers, separate branches)

  • A — FP4 decode path: ~4× slower than flash_attn (the earlier 1.34×/4.29× was an fp32-SDPA-baseline artifact).
  • C — single-pass fused backward: compiles (custom bf16-dQ) but ~2–3× slower (dQ atomic_add contention scales with Skv); two-kernel backward is optimal.

Tests / lint

  • 30 config schema/gate/migration tests pass; 36 NVFP4 kernel tests pass.
  • Example configs parse to the nested schema (clean, no deprecation warnings); e2e bench confirms the consolidated config trains identically to the flat one (8 FP4 attn layers, same s/step / loss / GiB).
  • pre-commit run --all-files green (ruff, ruff-format, mypy, bandit, yaml/eof/whitespace) — includes a repo-wide cleanup commit, since the base branch was not previously pre-commit-clean.

🤖 Generated with Claude Code

thad0ctor added 15 commits June 4, 2026 06:49
The bare native-NVFP4 flash kernel uses tl.dot_scaled, which Inductor cannot
compile: under torch_compile the autograd.Function path raises an Inductor
CompilationError at the P@V dot_scaled during warm-cache precompile and, with
the default error suppression, silently falls the whole attention region back
to eager — blocking fusion of the surrounding elementwise quant/dequant.

The differentiable opaque custom op already compiles around it with bit-identical
forward and dq/dk/dv grads, so make qwen3_5_native_attention_compile_custom_op a
tri-state (None default) that auto-enables whenever torch_compile is on. Explicit
True/False still force the choice.
… pass)

The two-level NVFP4 quant prologue computed the per-tensor scale as
torch.amax(torch.abs(t)), which materializes |t| in a full-size elementwise
pass (AbsFunctor) before the reduce. Replace with an inf-norm reduction that
folds the abs into the reduce kernel, eliminating that pass. Bit-identical
(verified across shapes/dtypes incl. the all-zero edge), applied at every
global-amax site: the fused MSLK two-level quant, the non-Hadamard recipe
amax, the torchao RTN path, and the load-time weight/embedding quant.
Concatenate the q/k/v and gate/up LoRA A matrices (which read the same
input) along the rank dimension so the per-projection X@A products become
a single GEMM, and fuse the matching dA backward into one X^T@grad_B.
Opt-in via lora_batch_kernel; only active on the plain LoRA fast path
(no DoRA/dropout/lora_bias). Default behavior unchanged.

Bit-exact parity with the per-module path (outputs + all grads, fwd/bwd,
inplace on and off).
…ation)

Add nvfp4_training.bf16_lm_head_cross_entropy (default off): a CCE-style
online-softmax CE that tiles the frozen bf16 lm_head over the vocab so the full
[tokens, vocab] logit tensor and its gradient GEMM are never materialized. The
per-tile matmul is plain bf16 (bit-for-bit the materialized hidden @ W.t());
logsumexp/softmax and dL/dhidden accumulate in fp32. No gradient filtering, so
the returned gradient is the exact tiled CE gradient — convergence-safe under
NVFP4 stochastic-rounding grads where cut_cross_entropy / Liger collapsed.

Returns dL/dhidden only (frozen lm_head). Wired in patch_manager, registered in
the central CE mutual-exclusivity check and guarded against quantize_lm_head /
fused_fp4 / fp8 CE. Loss & grad validated bit-close to F.cross_entropy and finite
at Qwen3.5 vocab scale.
…uous() copy) — 1.45-1.52x producer, bit-identical [prefill #2]
…ult-off, parity-affecting)

Isolated q/k/o GEMM 1.13x vs bf16 (shared-pack); but Qwen3.5 is hybrid (8/32
full-attn layers) so model-level prefill is 1.000x (buried) and it costs parity
(logit cos 0.9969, argmax 79%). Default-off; the reusable lever for dense models
where every layer is full-attention. NOT bit-exact.
Error concentrates in o_proj (direct residual readout): FP4 q/k-only=0.99863,
o-only=0.99748, all=0.99677 logit cos. q/k errors are softmax-DAMPENED, not
amplified. q/k-only recovers most parity; two-level o_proj weight quant + QAT are
further levers. Sub-gates _nvfp4_fp4_qk / _nvfp4_fp4_o (default on w/ main flag).
…ning.attention

Consolidate the attention flag group into a nested schema (attention.{enabled,
fuse_vproj, fp4_projections, backward.{enabled, rtn_grad_packs, save_packs,
dkdv_scratch_bf16, compile_custom_op}}) + top-level linear_attn / mlp /
fla_causal_conv_compile_boundary. Old flat qwen3_5_* names still parse via a
deprecation before-validator. requires-checks moved into NVFP4AttentionConfig;
patch_manager + config.py rewired to nested; q/k/o fp4_projections wired through.
Examples, tests (nested + legacy-migration coverage), and docs updated.
Validated: 30 config tests pass, kernel suite 36 pass, examples parse nested.
@coderabbitai

coderabbitai Bot commented Jun 4, 2026

Copy link
Copy Markdown

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7228a574-ccf0-41a2-a3cd-1a25f1e8bdf3

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/nvfp4-attn-perf-and-config

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 4, 2026

Copy link
Copy Markdown

📖 Documentation Preview:

Deployed on Netlify from commit 0a5b3a6

thad0ctor added 2 commits June 4, 2026 15:30
ruff + ruff-format across the PR's changed files; per-file mypy
disable-error-code for the loosely-typed nvfp4 kernel files; nosec on the
sidecar torch.load; zip strict= in the lora test. Remove the one-off
bench/proof scripts (prove_*/ablate_*/bench_*/check_*) that referenced /tmp.
…ons)

Apply ruff-format across the pre-existing nvfp4 files, file-level mypy
disable-error-code for the loosely-typed kernels/scripts, E741 noqa for the
flash-kernel O accumulator, and B007/B023/B905 fixes in the e2e tests. Brings
pre-commit --all-files fully green (the beta base was not pre-commit-clean).
@thad0ctor thad0ctor closed this Jun 5, 2026
@thad0ctor thad0ctor deleted the feat/nvfp4-attn-perf-and-config branch June 5, 2026 06:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant