NVFP4: attention perf wins + flag consolidation #42
Closed
thad0ctor wants to merge 17 commits into
The bare native-NVFP4 flash kernel uses tl.dot_scaled, which Inductor cannot compile: under torch_compile the autograd.Function path raises an Inductor CompilationError at the P@V dot_scaled during warm-cache precompile and, with the default error suppression, silently falls the whole attention region back to eager — blocking fusion of the surrounding elementwise quant/dequant. The differentiable opaque custom op already compiles around it with bit-identical forward and dq/dk/dv grads, so make qwen3_5_native_attention_compile_custom_op a tri-state (None default) that auto-enables whenever torch_compile is on. Explicit True/False still force the choice.
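The tri-state resolution described in this commit can be sketched as below. This is a minimal illustration, not the PR's code: the helper name `resolve_compile_custom_op` and its signature are hypothetical, and only the semantics (None follows `torch_compile`, explicit True/False win) come from the commit message.

```python
from typing import Optional


def resolve_compile_custom_op(flag: Optional[bool], torch_compile: bool) -> bool:
    """Resolve a tri-state flag (hypothetical helper, for illustration only).

    None       -> follow torch_compile (auto-enable when compiling)
    True/False -> the explicit user choice always wins
    """
    if flag is None:
        return torch_compile
    return flag


if __name__ == "__main__":
    for flag in (None, True, False):
        for compiling in (True, False):
            print(flag, compiling, "->", resolve_compile_custom_op(flag, compiling))
```

Keeping the default as None rather than False lets an unset flag be told apart from an explicit opt-out, which is what allows the auto-enable without overriding users who set False on purpose.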
… pass) The two-level NVFP4 quant prologue computed the per-tensor scale as torch.amax(torch.abs(t)), which materializes |t| in a full-size elementwise pass (AbsFunctor) before the reduce. Replace with an inf-norm reduction that folds the abs into the reduce kernel, eliminating that pass. Bit-identical (verified across shapes/dtypes incl. the all-zero edge), applied at every global-amax site: the fused MSLK two-level quant, the non-Hadamard recipe amax, the torchao RTN path, and the load-time weight/embedding quant.
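A minimal sketch of the two amax formulations, assuming PyTorch is available. The commit does not say which call implements the inf-norm reduction; `torch.linalg.vector_norm` with `ord=inf` is used here as one plausible choice, and both function names are hypothetical.

```python
import torch


def global_amax_two_pass(t: torch.Tensor) -> torch.Tensor:
    # Baseline: torch.abs materializes |t| as a full-size temporary,
    # then torch.amax reduces that temporary.
    return torch.amax(torch.abs(t))


def global_amax_inf_norm(t: torch.Tensor) -> torch.Tensor:
    # The infinity norm is max(|t_i|) by definition, so the absolute value
    # is taken inside the reduction (assumed stand-in for the PR's change).
    return torch.linalg.vector_norm(t, ord=float("inf"))


if __name__ == "__main__":
    x = torch.randn(4, 8)
    print(torch.equal(global_amax_two_pass(x), global_amax_inf_norm(x)))
```

Both functions return a 0-dim tensor in the input dtype, so the second can replace the first without changing downstream scale arithmetic.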
Concatenate the q/k/v and gate/up LoRA A matrices (which read the same input) along the rank dimension so the per-projection X@A products become a single GEMM, and fuse the matching dA backward into one X^T@grad_B. Opt-in via lora_batch_kernel; only active on the plain LoRA fast path (no DoRA/dropout/lora_bias). Default behavior unchanged. Bit-exact parity with the per-module path (outputs + all grads, fwd/bwd, inplace on and off).
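The forward half of this batching can be illustrated with plain Python lists. This is a sketch under stated assumptions: each A matrix is stored as [in_features, rank], and all helper names are hypothetical rather than taken from the PR. It covers only the X@A side, not the fused dA backward.

```python
def matmul(a, b):
    """Naive dense matmul on lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def concat_rank_dim(a_mats):
    """Concatenate A matrices along the rank (column) dimension."""
    return [sum((m[i] for m in a_mats), []) for i in range(len(a_mats[0]))]


def split_rank_dim(m, ranks):
    """Split the fused product back into one block per projection."""
    out, start = [], 0
    for r in ranks:
        out.append([row[start:start + r] for row in m])
        start += r
    return out


def batched_lora_a(x, a_mats):
    """One GEMM X @ [A_q | A_k | ...] instead of one GEMM per projection."""
    ranks = [len(a[0]) for a in a_mats]
    return split_rank_dim(matmul(x, concat_rank_dim(a_mats)), ranks)


if __name__ == "__main__":
    x = [[1, 2], [3, 4]]
    a_q = [[1, 0], [0, 1]]  # rank 2
    a_k = [[2], [3]]        # rank 1
    print(batched_lora_a(x, [a_q, a_k]))
```

Because every output column of the fused product depends on only one column of one A matrix, the split results match the per-module products; the gain comes from launching one larger GEMM over the shared input.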
…ation) Add nvfp4_training.bf16_lm_head_cross_entropy (default off): a CCE-style online-softmax CE that tiles the frozen bf16 lm_head over the vocab so the full [tokens, vocab] logit tensor and its gradient GEMM are never materialized. The per-tile matmul is plain bf16 (bit-for-bit the materialized hidden @ W.t()); logsumexp/softmax and dL/dhidden accumulate in fp32. No gradient filtering, so the returned gradient is the exact tiled CE gradient — convergence-safe under NVFP4 stochastic-rounding grads where cut_cross_entropy / Liger collapsed. Returns dL/dhidden only (frozen lm_head). Wired in patch_manager, registered in the central CE mutual-exclusivity check and guarded against quantize_lm_head / fused_fp4 / fp8 CE. Loss & grad validated bit-close to F.cross_entropy and finite at Qwen3.5 vocab scale.
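The online-softmax idea behind a vocab-tiled cross entropy can be sketched for a single token in pure Python. This illustrates the general technique only: the function names, the tile size, and the list-based layout are assumptions, and the gradient accumulation the commit describes is left out.

```python
import math


def full_cross_entropy(hidden, weight, target):
    """Reference: materialize every logit, then logsumexp minus the target logit."""
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in weight]
    m = max(logits)
    return m + math.log(sum(math.exp(l - m) for l in logits)) - logits[target]


def tiled_cross_entropy(hidden, weight, target, tile=2):
    """Same loss, visiting the vocab in tiles so only one tile of logits is live."""
    running_max = -math.inf
    running_sum = 0.0
    target_logit = None
    for start in range(0, len(weight), tile):
        logits = [sum(h * w for h, w in zip(hidden, row))
                  for row in weight[start:start + tile]]
        if start <= target < start + len(logits):
            target_logit = logits[target - start]
        new_max = max(running_max, max(logits))
        # Rescale the old partial sum to the new max before adding this tile.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + sum(math.exp(l - new_max) for l in logits))
        running_max = new_max
    return running_max + math.log(running_sum) - target_logit


if __name__ == "__main__":
    h = [0.5, -1.0]
    w = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, -1.0], [0.3, 0.3]]
    print(full_cross_entropy(h, w, 3), tiled_cross_entropy(h, w, 3))
```

The running max and rescaled running sum are what let the logsumexp be accumulated tile by tile; the two results agree up to floating-point summation order.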
…f16 CE memory mode
…-shape memory probe
…uous() copy) — 1.45-1.52x producer, bit-identical [prefill #2]
… ~13% attn-block, bit-identical)
…ult-off, parity-affecting) Isolated q/k/o GEMM 1.13x vs bf16 (shared-pack); but Qwen3.5 is hybrid (8/32 full-attn layers) so model-level prefill is 1.000x (buried) and it costs parity (logit cos 0.9969, argmax 79%). Default-off; the reusable lever for dense models where every layer is full-attention. NOT bit-exact.
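Why an isolated 1.13x GEMM win can vanish at model level follows from Amdahl's law. The sketch below is a generic worked formula, not a measurement from this PR: the runtime fractions are placeholders chosen for illustration, and only the 1.13x local speedup comes from the commit message.

```python
def overall_speedup(fraction: float, local_speedup: float) -> float:
    """Amdahl's law: end-to-end speedup when only `fraction` of runtime gets faster."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


if __name__ == "__main__":
    # Assumed fractions of total runtime spent in the accelerated q/k/o GEMMs.
    for fraction in (0.01, 0.05, 0.25, 1.0):
        print(f"fraction={fraction:.2f} -> {overall_speedup(fraction, 1.13):.4f}x")
```

On a hybrid model where only 8 of 32 layers are full-attention, the accelerated fraction is small, so the end-to-end figure stays close to 1.0x; on a dense model the same lever covers a larger share of runtime.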
Error concentrates in o_proj (direct residual readout): FP4 q/k-only=0.99863, o-only=0.99748, all=0.99677 logit cos. q/k errors are softmax-DAMPENED, not amplified. q/k-only recovers most parity; two-level o_proj weight quant + QAT are further levers. Sub-gates _nvfp4_fp4_qk / _nvfp4_fp4_o (default on w/ main flag).
…ning.attention
Consolidate the attention flag group into a nested schema (attention.{enabled,
fuse_vproj, fp4_projections, backward.{enabled, rtn_grad_packs, save_packs,
dkdv_scratch_bf16, compile_custom_op}}) + top-level linear_attn / mlp /
fla_causal_conv_compile_boundary. Old flat qwen3_5_* names still parse via a
deprecation before-validator. requires-checks moved into NVFP4AttentionConfig;
patch_manager + config.py rewired to nested; q/k/o fp4_projections wired through.
Examples, tests (nested + legacy-migration coverage), and docs updated.
Validated: 30 config tests pass, kernel suite 36 pass, examples parse nested.
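A legacy-flag migrator of the kind this commit describes could look roughly like the following. This is a sketch, not the PR's before-validator: the function name, the mapping table, and the choice that an explicit nested value wins over a legacy one are all assumptions, and the single mapping shown is inferred from the flag names in the commit messages.

```python
import copy
import warnings

# Assumed mapping from an old flat name to its nested path (illustrative only).
LEGACY_TO_NESTED = {
    "qwen3_5_native_attention_compile_custom_op": (
        "attention", "backward", "compile_custom_op",
    ),
}


def migrate_legacy_flags(cfg: dict) -> dict:
    """Return a copy of cfg with known flat flags moved into the nested schema."""
    out = copy.deepcopy({k: v for k, v in cfg.items() if k not in LEGACY_TO_NESTED})
    for old, path in LEGACY_TO_NESTED.items():
        if old not in cfg:
            continue
        warnings.warn(
            f"{old} is deprecated; use {'.'.join(path)} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        node = out
        for key in path[:-1]:
            node = node.setdefault(key, {})
        # Assumption: a value already set in the nested form takes precedence.
        node.setdefault(path[-1], cfg[old])
    return out


if __name__ == "__main__":
    print(migrate_legacy_flags({"qwen3_5_native_attention_compile_custom_op": True}))
```

Running the migration before schema validation means the nested model only ever sees nested keys, so the requires-checks can live in one place.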
Documentation Preview: Deployed on Netlify from commit 0a5b3a6
ruff + ruff-format across the PR's changed files; per-file mypy disable-error-code for the loosely-typed nvfp4 kernel files; nosec on the sidecar torch.load; zip strict= in the lora test. Remove the one-off bench/proof scripts (prove_*/ablate_*/bench_*/check_*) that referenced /tmp.
…ons) Apply ruff-format across the pre-existing nvfp4 files, file-level mypy disable-error-code for the loosely-typed kernels/scripts, E741 noqa for the flash-kernel O accumulator, and B007/B023/B905 fixes in the e2e tests. Brings pre-commit --all-files fully green (the beta base was not pre-commit-clean).
NVFP4: attention perf wins + flag consolidation
Native-NVFP4 training perf wins for the Qwen3.5 LoRA path on Blackwell (sm_120), plus a model-agnostic flag consolidation. All validated on RTX PRO 6000 (96 GB). Pre-commit is green (repo-wide cleanup folded in).
Benchmarks — vanilla bf16 vs NVFP4 (Qwen3.5, LoRA, seq 2048, sample packing, torch_compile)
scripts/bench_nvfp4.sh (marginal train_runtime delta, exclusive-GPU). NVFP4 patches all 8 Qwen3.5 full-attention layers and converges cleanly (loss ~1.0 at lr 2e-4).

Step-speed progression on the locked 9B fastest path (b4 unless noted, RTX PRO 6000)
What's quantized — Qwen3.5 hybrid coverage
Qwen3.5 is hybrid: of its 32 blocks, 8 are full-attention and 24 are GatedDeltaNet linear-attention (the logs show "patched 8 Qwen3.5 full-attention layers"). The FP4 flash-attention kernel only touches those 8; the rest is covered as follows:
… base_mode: compute)

The model-wide speedup is carried mainly by FP4 base GEMMs on every layer; the FP4 attention kernel is a smaller slice on top (8/32), which is why attention-only levers (e.g. attention.fp4_projections) look diluted on this hybrid model but pay off on dense (all-attention) models. The only non-FP4 compute is the linear-attn recurrence in the 24 GatedDeltaNet layers: a native FP4 linear-attn path exists but is inference/no-grad only (a proven loser for training) and is intentionally off here.

Example config + training basics
examples/nvfp4/qwen35-9b-lora-fastest.yaml (the locked fastest path):

Training:
Notes: requires an NVFP4-capable Blackwell GPU + MSLK. base_mode: compute quantizes the frozen base to FP4; LoRA adapters train in bf16. A chunked-bf16-CE variant lives in qwen35-9b-lora-bf16-ce.yaml.

Kept wins
- … torch_compile; fuse abs into the global-amax reduction; batched shared-input LoRA A-GEMMs; chunked bf16 lm_head CE (no logits materialization).
- … save_for_backward on the saved-packs backward: bit-identical grads, lower peak memory.
- … dq scratch under dkdv_scratch_bf16: bit-exact, frees the largest backward scratch plane.
- … rms (drop the duplicate rsqrt): bit-exact.
- fused_rope_quant_qk reads the transposed Q/K view via 4D strides (drops the per-layer .contiguous() copy): 1.45–1.52× producer, ~13% prefill attention-block, bit-identical (+ unit test).
- q/k/o_proj with shared qkv activation pack: default OFF, parity-affecting; the model-agnostic shared-pack lever for dense models.
- nvfp4_training.attention.{enabled, fuse_vproj, fp4_projections, backward.{...}} + top-level linear_attn / mlp / fla_causal_conv_compile_boundary. Old flat qwen3_5_native_attention* flags keep working via a deprecation migrator (parse + warn).

Not included (proven losers, separate branches)
- … flash_attn (the earlier 1.34×/4.29× was an fp32-SDPA-baseline artifact).
- … atomic_add contention scales with Skv); two-kernel backward is optimal.

Tests / lint
pre-commit run --all-files green (ruff, ruff-format, mypy, bandit, yaml/eof/whitespace); includes a repo-wide cleanup commit, since the base branch was not previously pre-commit-clean.

🤖 Generated with Claude Code