From 4c67858e072de320edae14a6cf3313f09dcdd5a2 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:14:17 -0700 Subject: [PATCH 01/17] fix(nvfp4): auto-enable attention compile custom op under torch_compile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bare native-NVFP4 flash kernel uses tl.dot_scaled, which Inductor cannot compile: under torch_compile the autograd.Function path raises an Inductor CompilationError at the P@V dot_scaled during warm-cache precompile and, with the default error suppression, silently falls the whole attention region back to eager — blocking fusion of the surrounding elementwise quant/dequant. The differentiable opaque custom op already compiles around it with bit-identical forward and dq/dk/dv grads, so make qwen3_5_native_attention_compile_custom_op a tri-state (None default) that auto-enables whenever torch_compile is on. Explicit True/False still force the choice. --- src/axolotl/loaders/patch_manager.py | 4 +-- src/axolotl/utils/schemas/config.py | 9 +++++ src/axolotl/utils/schemas/nvfp4.py | 8 +++-- tests/e2e/test_nvfp4_integration.py | 52 ++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 5 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index d19854267c..58f2b56814 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -539,8 +539,8 @@ def _apply_qwen3_5_native_nvfp4_patches(self, model: PreTrainedModel): "qwen3_5_native_attention_dkdv_scratch_bf16", False, ), - compile_custom_op=getattr( - nvfp4, "qwen3_5_native_attention_compile_custom_op", False + compile_custom_op=bool( + getattr(nvfp4, "qwen3_5_native_attention_compile_custom_op", False) ), stochastic_rounding=nvfp4.stochastic_rounding, ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 0fef647e40..f287904891 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1709,6 +1709,15 @@ def check_nvfp4_training(self): "nvfp4_training.qwen3_5_native_attention_compile_custom_op " "requires qwen3_5_native_attention: true." ) + # Tri-state auto-resolve: under torch_compile the bare tl.dot_scaled flash + # kernel raises an Inductor CompilationError and (with the default error + # suppression) silently falls the attention region back to eager, blocking + # fusion of the surrounding elementwise. The opaque custom op compiles + # around it with bit-identical grads, so default it on when compile is live. + if self.nvfp4_training.qwen3_5_native_attention_compile_custom_op is None: + self.nvfp4_training.qwen3_5_native_attention_compile_custom_op = bool( + self.nvfp4_training.qwen3_5_native_attention and self.torch_compile + ) if self.nvfp4_training.fp8_lm_head_cross_entropy and ( self.nvfp4_training.quantize_lm_head or self.nvfp4_training.fused_fp4_cross_entropy diff --git a/src/axolotl/utils/schemas/nvfp4.py b/src/axolotl/utils/schemas/nvfp4.py index 97f8143380..3a7a990598 100644 --- a/src/axolotl/utils/schemas/nvfp4.py +++ b/src/axolotl/utils/schemas/nvfp4.py @@ -274,8 +274,8 @@ class NVFP4TrainingConfig(BaseModel): "default." }, ) - qwen3_5_native_attention_compile_custom_op: bool = Field( - default=False, + qwen3_5_native_attention_compile_custom_op: bool | None = Field( + default=None, json_schema_extra={ "description": "Route the native NVFP4 flash-attention call through an " "opaque torch custom op as a torch.compile compatibility escape hatch " @@ -284,7 +284,9 @@ class NVFP4TrainingConfig(BaseModel): "qwen3_5_native_attention_backward it wraps a DIFFERENTIABLE custom op " "(forward + registered native-NVFP4 backward) so Inductor compiles " "around the whole attention instead of falling the backward subgraph " - "back to eager. Not a proven speed knob; OFF by default." + "back to eager. Tri-state: None auto-enables it whenever torch_compile " + "is on (the bare tl.dot_scaled path raises an Inductor CompilationError " + "there and silently falls the region back to eager); True/False force it." }, ) qwen3_5_fla_causal_conv_compile_boundary: bool = Field( diff --git a/tests/e2e/test_nvfp4_integration.py b/tests/e2e/test_nvfp4_integration.py index 92af5f1bda..fea7cd77eb 100644 --- a/tests/e2e/test_nvfp4_integration.py +++ b/tests/e2e/test_nvfp4_integration.py @@ -328,6 +328,58 @@ def test_qwen3_5_compile_custom_op_requires_native_attention(monkeypatch): ) +def test_qwen3_5_compile_custom_op_autoenabled_under_torch_compile(monkeypatch): + # Tri-state default (None): under torch_compile the bare tl.dot_scaled flash + # kernel raises an Inductor CompilationError and silently falls back to eager, + # so the opaque custom op is auto-enabled. + _supported(monkeypatch, True) + cfg = AxolotlConfigWCapabilities( + **BASE, + **CAPS, + model_config_type="qwen3_5", + torch_compile=True, + nvfp4_training={ + "enabled": True, + "qwen3_5_native_attention": True, + "qwen3_5_native_attention_backward": True, + }, + ) + assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is True + + +def test_qwen3_5_compile_custom_op_default_off_without_torch_compile(monkeypatch): + _supported(monkeypatch, True) + cfg = AxolotlConfigWCapabilities( + **BASE, + **CAPS, + model_config_type="qwen3_5", + nvfp4_training={ + "enabled": True, + "qwen3_5_native_attention": True, + "qwen3_5_native_attention_backward": True, + }, + ) + assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is False + + +def test_qwen3_5_compile_custom_op_explicit_optout_under_torch_compile(monkeypatch): + # An explicit False must survive auto-resolution (opt-out wins). + _supported(monkeypatch, True) + cfg = AxolotlConfigWCapabilities( + **BASE, + **CAPS, + model_config_type="qwen3_5", + torch_compile=True, + nvfp4_training={ + "enabled": True, + "qwen3_5_native_attention": True, + "qwen3_5_native_attention_backward": True, + "qwen3_5_native_attention_compile_custom_op": False, + }, + ) + assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is False + + def _tiny_lora_model(): """A 2-layer toy model wrapped with a PEFT LoRA adapter (CPU-friendly).""" import torch From ee7e66d9e58ee5d4318b8432dceca8d8f9982bfc Mon Sep 17 00:00:00 2001 From: Robert Gilbreth Date: Thu, 4 Jun 2026 00:07:55 -0700 Subject: [PATCH 02/17] perf(nvfp4): fuse abs into the global-amax reduction (drop AbsFunctor pass) The two-level NVFP4 quant prologue computed the per-tensor scale as torch.amax(torch.abs(t)), which materializes |t| in a full-size elementwise pass (AbsFunctor) before the reduce. Replace with an inf-norm reduction that folds the abs into the reduce kernel, eliminating that pass. Bit-identical (verified across shapes/dtypes incl. the all-zero edge), applied at every global-amax site: the fused MSLK two-level quant, the non-Hadamard recipe amax, the torchao RTN path, and the load-time weight/embedding quant. --- src/axolotl/utils/nvfp4_training.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/nvfp4_training.py b/src/axolotl/utils/nvfp4_training.py index 481cfa8265..9d43e5610e 100644 --- a/src/axolotl/utils/nvfp4_training.py +++ b/src/axolotl/utils/nvfp4_training.py @@ -209,6 +209,16 @@ def _one(x): return NVFP4Tensor.__tensor_unflatten__(inner, ctx, None, None) +def _abs_amax(t: torch.Tensor) -> torch.Tensor: + """Global max(|t|) as fp32 in ONE fused reduction (no materialized ``|t|``). + + Bit-identical to ``torch.amax(torch.abs(t))`` but the inf-norm reduction folds + the abs into the reduce kernel, dropping the separate ``AbsFunctor`` elementwise + pass that dominated the NVFP4 quant prologue. + """ + return torch.linalg.vector_norm(t, ord=float("inf")).to(torch.float32) + + def _quantize(t: torch.Tensor, policy: QuantPolicy): """Quantize a high-precision tensor to an NVFP4Tensor (along its last dim). @@ -241,7 +251,7 @@ def _quantize(t: torch.Tensor, policy: QuantPolicy): ) if policy.hadamard: t = _apply_rht(t).contiguous() - per_tensor_scale = per_tensor_amax_to_scale(torch.max(torch.abs(t))) + per_tensor_scale = per_tensor_amax_to_scale(_abs_amax(t)) if policy.stochastic: t = _sr_dither(t, per_tensor_scale).contiguous() # RHT/SR rewrite the whole tensor up front (no per-block-row independence), so @@ -574,7 +584,7 @@ def from_linear( ) w = linear.weight.detach() - pts = per_tensor_amax_to_scale(torch.max(torch.abs(w))) + pts = per_tensor_amax_to_scale(_abs_amax(w)) w_q = _to_nvfp4_chunked( w.contiguous(), pts, @@ -599,7 +609,7 @@ def _embedding_to_nvfp4(weight: torch.Tensor): ) w = weight.detach().contiguous() - pts = per_tensor_amax_to_scale(torch.max(torch.abs(w))) + pts = per_tensor_amax_to_scale(_abs_amax(w)) return _to_nvfp4_chunked( w, pts, QuantizeTensorToNVFP4Kwargs(block_size=_BLOCK_SIZE) ) @@ -1250,7 +1260,7 @@ def _recipe_m_per_block(m: int) -> int: def _recipe_rht_amax(t: torch.Tensor, hadamard: bool) -> torch.Tensor: if not hadamard: - return torch.amax(torch.abs(t)).to(torch.float32) + return _abs_amax(t) m, n = t.shape m_per_block = _recipe_m_per_block(m) grid = (triton.cdiv(n, 64), triton.cdiv(m, m_per_block)) @@ -1342,7 +1352,7 @@ def _mslk_quantize_op( from mslk.quantize.triton.fp4_quantize import triton_quantize_nvfp4 t = t.contiguous() - amax = torch.amax(torch.abs(t)).to(torch.float32) + amax = _abs_amax(t) global_scale = _NVFP4_GLOBAL_AMAX / torch.clamp(amax, min=1e-12) q, s = triton_quantize_nvfp4(t, global_scale) return ( @@ -2509,7 +2519,7 @@ def collect_nvfp4_packed_state(model: nn.Module) -> tuple[dict, set[str]]: if isinstance(module, NVFP4Linear): # FFT: pack the bf16 master to FP4 (lossy for resume). w = module.weight.detach() - pts = per_tensor_amax_to_scale(torch.max(torch.abs(w))) + pts = per_tensor_amax_to_scale(_abs_amax(w)) w_q = _to_nvfp4_chunked( w.contiguous(), pts, QuantizeTensorToNVFP4Kwargs(block_size=_BLOCK_SIZE) ) From 998d498cea2d5f4509a59c781981431a51deb45f Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:09:15 -0700 Subject: [PATCH 03/17] perf(lora): batch shared-input adapter GEMMs to cut launch overhead Concatenate the q/k/v and gate/up LoRA A matrices (which read the same input) along the rank dimension so the per-projection X@A products become a single GEMM, and fuse the matching dA backward into one X^T@grad_B. Opt-in via lora_batch_kernel; only active on the plain LoRA fast path (no DoRA/dropout/lora_bias). Default behavior unchanged. Bit-exact parity with the per-module path (outputs + all grads, fwd/bwd, inplace on and off). --- src/axolotl/kernels/lora.py | 222 +++++++++++++++++++++--- src/axolotl/monkeypatch/lora_kernels.py | 2 + src/axolotl/utils/schemas/config.py | 6 + tests/e2e/kernels/test_lora.py | 87 ++++++++++ 4 files changed, 291 insertions(+), 26 deletions(-) diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 999a10b5ad..7e0275bb8a 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -318,6 +318,71 @@ def matmul_lora( return out.view(batch, seq_len, -1) if reshape else out +def _batched_lora_forward( + X: torch.Tensor, + bases: list[torch.Tensor], + As: list[torch.Tensor | None], + Bs: list[torch.Tensor | None], + scales: list[float], +) -> list[torch.Tensor]: + """Add LoRA contributions to pre-computed base outputs with a single shared A GEMM. + + All projections must read the same input X. The per-projection A matrices + ([rank, in]) are concatenated along rank so `X @ A_cat^T` is one GEMM producing + [T, sum(rank)]; the result is split and each chunk fed to its own `@ B^T`. + Returns the bases mutated in place (base_i += s_i * X @ A_i^T @ B_i^T). + + Only valid for the plain LoRA path (no DoRA / dropout / lora_bias). Projections + with A is None are passed through unchanged. + """ + dtype = X.dtype + idx = [i for i, A in enumerate(As) if A is not None] + if not idx: + return bases + + Xf = X.reshape(-1, X.shape[-1]) + A_cat = torch.cat([As[i] for i in idx], dim=0).to(dtype) # [sum_r, in] + XA = Xf @ A_cat.t() # [T, sum_r] — single fused GEMM + + ranks = [As[i].shape[0] for i in idx] + offset = 0 + for j, i in enumerate(idx): + r = ranks[j] + chunk = XA[:, offset : offset + r] + offset += r + contrib = scales[i] * (chunk @ Bs[i].t().to(dtype)) # [T, out_i] + bases[i].view(-1, bases[i].shape[-1]).add_(contrib) + return bases + + +def _batched_lora_dA( + X_lora_t: torch.Tensor, + grad_Bs: list[torch.Tensor | None], + scales: list[float], + template_As: list[torch.Tensor | None], +) -> list[torch.Tensor | None]: + """Fused dA for projections sharing input: stack grad_B columns, one X_lora^T @ G GEMM. + + grad_B_i is [T, rank_i]; concatenating along rank gives [T, sum_r], so + `X_lora^T @ G_cat` is one GEMM yielding stacked dA [in, sum_r], split back per + projection. Returns dA per projection in [in, rank] layout (caller transposes). + """ + idx = [i for i, g in enumerate(grad_Bs) if g is not None] + out: list[torch.Tensor | None] = [None] * len(grad_Bs) + if not idx: + return out + + G_cat = torch.cat([grad_Bs[i] for i in idx], dim=1) # [T, sum_r] + dA_cat = X_lora_t @ G_cat # [in, sum_r] — single fused GEMM + offset = 0 + for i in idx: + r = grad_Bs[i].shape[1] + dA_cat[:, offset : offset + r].mul_(scales[i]) + out[i] = dA_cat[:, offset : offset + r] + offset += r + return out + + class LoRA_MLP(torch.autograd.Function): """Optimized LoRA MLP implementation. @@ -363,11 +428,20 @@ def forward( activation_fn: Callable, activation_fn_backward: Callable, inplace: bool | None = True, + batched: bool = False, ) -> torch.Tensor: has_dropout = X_drop is not None has_dora = gate_magnitude is not None dtype = X.dtype X_lora = X_drop if has_dropout else X + # Gate/up share input X; batch their A-GEMM on the plain path only. + can_batch_gu = ( + batched + and not has_dora + and not has_dropout + and gate_lora_bias is None + and up_lora_bias is None + ) if has_dora: # Gate with DoRA @@ -400,6 +474,16 @@ def forward( gate_combined = gate_base + gate_lora up_combined = up_base + up_lora + elif can_batch_gu: + gate = matmul_lora(X, gate_weight, gate_bias, gate_quant, None, None, None) + up = matmul_lora(X, up_weight, up_bias, up_quant, None, None, None) + _batched_lora_forward( + X, + [gate, up], + [gate_A, up_A], + [gate_B, up_B], + [gate_scale, up_scale], + ) else: gate = matmul_lora( X, @@ -513,6 +597,7 @@ def forward( ctx.inplace = inplace ctx.has_dropout = has_dropout ctx.has_dora = has_dora + ctx.can_batch_gu = can_batch_gu return output @@ -658,19 +743,39 @@ def backward( if up_A_t is not None and up_B_t is not None: grad_B_up = grad_up @ up_B_t.t() # [T, rank] — reuse for dX - d_up_A = torch.empty_like(up_A_t) - d_up_B = torch.empty_like(up_B_t) - d_up_A.addmm_(X_lora.t(), grad_B_up, alpha=up_scale, beta=0) - d_up_B.addmm_(up_A_t.t() @ X_lora.t(), grad_up, alpha=up_scale, beta=0) - if gate_A_t is not None and gate_B_t is not None: grad_B_gate = grad_gate @ gate_B_t.t() # [T, rank] — reuse for dX - d_gate_A = torch.empty_like(gate_A_t) - d_gate_B = torch.empty_like(gate_B_t) - d_gate_A.addmm_(X_lora.t(), grad_B_gate, alpha=gate_scale, beta=0) - d_gate_B.addmm_( - gate_A_t.t() @ X_lora.t(), grad_gate, alpha=gate_scale, beta=0 + + if getattr(ctx, "can_batch_gu", False): + X_lora_t_gu = X_lora.t() + d_gate_A, d_up_A = _batched_lora_dA( + X_lora_t_gu, + [grad_B_gate, grad_B_up], + [gate_scale, up_scale], + [gate_A_t, up_A_t], ) + if grad_B_up is not None: + d_up_B = torch.empty_like(up_B_t) + d_up_B.addmm_(up_A_t.t() @ X_lora_t_gu, grad_up, alpha=up_scale, beta=0) + if grad_B_gate is not None: + d_gate_B = torch.empty_like(gate_B_t) + d_gate_B.addmm_( + gate_A_t.t() @ X_lora_t_gu, grad_gate, alpha=gate_scale, beta=0 + ) + else: + if grad_B_up is not None: + d_up_A = torch.empty_like(up_A_t) + d_up_B = torch.empty_like(up_B_t) + d_up_A.addmm_(X_lora.t(), grad_B_up, alpha=up_scale, beta=0) + d_up_B.addmm_(up_A_t.t() @ X_lora.t(), grad_up, alpha=up_scale, beta=0) + + if grad_B_gate is not None: + d_gate_A = torch.empty_like(gate_A_t) + d_gate_B = torch.empty_like(gate_B_t) + d_gate_A.addmm_(X_lora.t(), grad_B_gate, alpha=gate_scale, beta=0) + d_gate_B.addmm_( + gate_A_t.t() @ X_lora.t(), grad_gate, alpha=gate_scale, beta=0 + ) # Compute input gradients dX = None @@ -751,7 +856,8 @@ def backward( None, d_down_lora_bias, d_down_mag, - # Activation fns and flags + # activation_fn, activation_fn_backward, inplace, batched + None, None, None, None, @@ -810,6 +916,7 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch. swiglu_forward, swiglu_backward, inplace, + getattr(self, "_lora_batch_kernel", False), ) return out @@ -866,6 +973,7 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T geglu_forward, geglu_backward, inplace, + getattr(self, "_lora_batch_kernel", False), ) return out @@ -914,9 +1022,19 @@ def forward( v_magnitude: torch.Tensor | None, # Flags inplace: bool = True, + batched: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: has_dropout = X_drop is not None has_dora = q_magnitude is not None + # Batched LoRA only valid on the plain path (no DoRA/dropout/lora_bias). + can_batch = ( + batched + and not has_dora + and not has_dropout + and q_lora_bias is None + and k_lora_bias is None + and v_lora_bias is None + ) if has_dora: dtype = X.dtype @@ -979,6 +1097,33 @@ def forward( k_lora_bias, v_lora_bias, ) + elif can_batch: + # Base outputs only (X @ W), then one fused A-GEMM for all LoRA paths. + Q = matmul_lora(X, q_weight, q_bias, q_quant, None, None, None) + K = matmul_lora(X, k_weight, k_bias, k_quant, None, None, None) + V = matmul_lora(X, v_weight, v_bias, v_quant, None, None, None) + _batched_lora_forward( + X, + [Q, K, V], + [q_A, k_A, v_A], + [q_B, k_B, v_B], + [q_scale, k_scale, v_scale], + ) + + dtype = X.dtype + ctx.save_for_backward( + X, + X, + q_A.to(dtype) if q_A is not None else q_A, + q_B.to(dtype) if q_B is not None else q_B, + k_A.to(dtype) if k_A is not None else k_A, + k_B.to(dtype) if k_B is not None else k_B, + v_A.to(dtype) if v_A is not None else v_A, + v_B.to(dtype) if v_B is not None else v_B, + None, + None, + None, + ) else: # Standard LoRA (with optional dropout and bias) Q = matmul_lora( @@ -1038,6 +1183,7 @@ def forward( ctx.inplace = inplace ctx.has_dropout = has_dropout ctx.has_dora = has_dora + ctx.can_batch = can_batch return Q, K, V @@ -1158,24 +1304,46 @@ def backward( if A_q is not None and B_q is not None: grad_B_q = q_grad @ B_q # [T, rank] — reused for dA and dX - d_A_q = torch.empty_like(A_q.t()) - d_B_q = torch.empty_like(B_q.t()) - d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0) - d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0) - if A_k is not None and B_k is not None: grad_B_k = k_grad @ B_k - d_A_k = torch.empty_like(A_k.t()) - d_B_k = torch.empty_like(B_k.t()) - d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0) - d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0) - if A_v is not None and B_v is not None: grad_B_v = v_grad @ B_v - d_A_v = torch.empty_like(A_v.t()) - d_B_v = torch.empty_like(B_v.t()) - d_A_v.addmm_(X_lora_t, grad_B_v, alpha=v_scale, beta=0) - d_B_v.addmm_(A_v @ X_lora_t, v_grad, alpha=v_scale, beta=0) + + if getattr(ctx, "can_batch", False): + # One fused X_lora^T @ grad_B_cat for all three dA. + d_A_q, d_A_k, d_A_v = _batched_lora_dA( + X_lora_t, + [grad_B_q, grad_B_k, grad_B_v], + [q_scale, k_scale, v_scale], + [A_q, A_k, A_v], + ) + if grad_B_q is not None: + d_B_q = torch.empty_like(B_q.t()) + d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0) + if grad_B_k is not None: + d_B_k = torch.empty_like(B_k.t()) + d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0) + if grad_B_v is not None: + d_B_v = torch.empty_like(B_v.t()) + d_B_v.addmm_(A_v @ X_lora_t, v_grad, alpha=v_scale, beta=0) + else: + if grad_B_q is not None: + d_A_q = torch.empty_like(A_q.t()) + d_B_q = torch.empty_like(B_q.t()) + d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0) + d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0) + + if grad_B_k is not None: + d_A_k = torch.empty_like(A_k.t()) + d_B_k = torch.empty_like(B_k.t()) + d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0) + d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0) + + if grad_B_v is not None: + d_A_v = torch.empty_like(A_v.t()) + d_B_v = torch.empty_like(B_v.t()) + d_A_v.addmm_(X_lora_t, grad_B_v, alpha=v_scale, beta=0) + d_B_v.addmm_(A_v @ X_lora_t, v_grad, alpha=v_scale, beta=0) # Base path input gradient (can use inplace on X since X_lora refs are done) from axolotl.utils.nvfp4_training import is_nvfp4_base, nvfp4_base_dgrad @@ -1275,7 +1443,8 @@ def backward( None, d_v_lora_bias, d_v_mag, - # inplace + # inplace, batched + None, None, ) @@ -1356,6 +1525,7 @@ def apply_lora_qkv( Vmag, # Flags inplace, + getattr(self, "_lora_batch_kernel", False), ) return Q, K, V diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 8156b72c7d..61bfad7c82 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -477,6 +477,7 @@ def apply_lora_kernel_patches( ) if can_patch_qkv: + self_attn._lora_batch_kernel = bool(cfg.lora_batch_kernel) if has_v_proj: self_attn.apply_qkv = types.MethodType( apply_lora_qkv, self_attn @@ -517,6 +518,7 @@ def apply_lora_kernel_patches( ) if can_patch_mlp: + mlp._lora_batch_kernel = bool(cfg.lora_batch_kernel) apply_fn = APPLY_FN_MAPPING[activation] layer.mlp.forward = types.MethodType(apply_fn, mlp) else: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f287904891..20ced08939 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -938,6 +938,12 @@ class AxolotlInputConfig( "description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html" }, ) + lora_batch_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Batch the per-projection LoRA adapter GEMMs that share an input (q/k/v and gate/up) into a single concatenated matmul to cut tiny-kernel launch overhead. Opt-in; only affects the lora_qkv_kernel/lora_mlp_kernel fast path (no DoRA/dropout/lora_bias)." + }, + ) chunked_cross_entropy: bool | None = Field( default=None, diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py index 10850bdc85..7e7f66bd28 100644 --- a/tests/e2e/kernels/test_lora.py +++ b/tests/e2e/kernels/test_lora.py @@ -580,3 +580,90 @@ def test_inplace_operations(sample_tensors, apply_function): out2 = apply_function(mlp, X.clone(), inplace=False) assert torch.allclose(out1, out2, rtol=1e-3) + + +def _lora_pair(out_f, in_f, rank, device="cuda"): + A = (torch.randn(rank, in_f, device=device) * 0.02).requires_grad_(True) + B = (torch.randn(out_f, rank, device=device) * 0.02).requires_grad_(True) + return A, B + + +@pytest.mark.parametrize("inplace", [True, False]) +def test_batched_qkv_matches_per_module(inplace): + """Batched LoRA QKV must match the per-module path bit-for-bit (same math).""" + torch.manual_seed(0) + bsz, seq, in_f, rank = 2, 7, 256, 16 + dt = torch.bfloat16 + qo, ko, vo = 256, 128, 128 + + def run(batched): + torch.manual_seed(1) + X = torch.randn(bsz, seq, in_f, device="cuda", dtype=dt, requires_grad=True) + weights = [ + torch.randn(o, in_f, device="cuda", dtype=dt) * 0.02 for o in (qo, ko, vo) + ] + qA, qB = _lora_pair(qo, in_f, rank) + kA, kB = _lora_pair(ko, in_f, rank) + vA, vB = _lora_pair(vo, in_f, rank) + q, k, v = LoRA_QKV.apply( + X, None, + weights[0], None, None, qA, qB, 2.0, None, None, + weights[1], None, None, kA, kB, 2.0, None, None, + weights[2], None, None, vA, vB, 2.0, None, None, + inplace, batched, + ) + loss = (q.float() ** 2).sum() + (k.float() ** 2).sum() + (v.float() ** 2).sum() + loss.backward() + return ( + torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)]).detach(), + X.grad.detach().clone(), + [g.grad.detach().clone() for g in (qA, qB, kA, kB, vA, vB)], + ) + + o1, x1, g1 = run(False) + o2, x2, g2 = run(True) + assert torch.isfinite(o2).all() and torch.isfinite(x2).all() + assert torch.equal(o1, o2) + assert torch.equal(x1, x2) + for a, b in zip(g1, g2): + assert torch.equal(a, b) + + +@pytest.mark.parametrize("inplace", [True, False]) +def test_batched_mlp_matches_per_module(inplace): + """Batched LoRA MLP (gate/up fused) must match per-module path bit-for-bit.""" + torch.manual_seed(0) + bsz, seq, in_f, inter, rank = 2, 7, 256, 512, 16 + dt = torch.bfloat16 + + def run(batched): + torch.manual_seed(2) + X = torch.randn(bsz, seq, in_f, device="cuda", dtype=dt, requires_grad=True) + gW = torch.randn(inter, in_f, device="cuda", dtype=dt) * 0.02 + uW = torch.randn(inter, in_f, device="cuda", dtype=dt) * 0.02 + dW = torch.randn(in_f, inter, device="cuda", dtype=dt) * 0.02 + gA, gB = _lora_pair(inter, in_f, rank) + uA, uB = _lora_pair(inter, in_f, rank) + dA, dB = _lora_pair(in_f, inter, rank) + out = LoRA_MLP.apply( + X, None, + gW, None, None, gA, gB, 2.0, None, None, + uW, None, None, uA, uB, 2.0, None, None, + dW, None, None, dA, dB, 2.0, None, None, + swiglu_forward, swiglu_backward, + inplace, batched, + ) + (out.float() ** 2).sum().backward() + return ( + out.detach().reshape(-1), + X.grad.detach().clone(), + [g.grad.detach().clone() for g in (gA, gB, uA, uB, dA, dB)], + ) + + o1, x1, g1 = run(False) + o2, x2, g2 = run(True) + assert torch.isfinite(o2).all() and torch.isfinite(x2).all() + assert torch.equal(o1, o2) + assert torch.equal(x1, x2) + for a, b in zip(g1, g2): + assert torch.equal(a, b) From a62d419293acb610a1b12eb3e467df35aef77c76 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:10:51 -0700 Subject: [PATCH 04/17] feat(nvfp4): chunked bf16 lm_head cross-entropy (no logits materialization) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add nvfp4_training.bf16_lm_head_cross_entropy (default off): a CCE-style online-softmax CE that tiles the frozen bf16 lm_head over the vocab so the full [tokens, vocab] logit tensor and its gradient GEMM are never materialized. The per-tile matmul is plain bf16 (bit-for-bit the materialized hidden @ W.t()); logsumexp/softmax and dL/dhidden accumulate in fp32. No gradient filtering, so the returned gradient is the exact tiled CE gradient — convergence-safe under NVFP4 stochastic-rounding grads where cut_cross_entropy / Liger collapsed. Returns dL/dhidden only (frozen lm_head). Wired in patch_manager, registered in the central CE mutual-exclusivity check and guarded against quantize_lm_head / fused_fp4 / fp8 CE. Loss & grad validated bit-close to F.cross_entropy and finite at Qwen3.5 vocab scale. --- examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml | 136 +++++++++++ src/axolotl/kernels/bf16_fused_ce.py | 272 +++++++++++++++++++++ src/axolotl/loaders/patch_manager.py | 7 + src/axolotl/utils/schemas/config.py | 10 + src/axolotl/utils/schemas/nvfp4.py | 19 ++ src/axolotl/utils/schemas/validation.py | 4 + tests/e2e/test_nvfp4_integration.py | 33 +++ tests/kernels/test_bf16_fused_ce.py | 92 +++++++ 8 files changed, 573 insertions(+) create mode 100644 examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml create mode 100644 src/axolotl/kernels/bf16_fused_ce.py create mode 100644 tests/kernels/test_bf16_fused_ce.py diff --git a/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml b/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml new file mode 100644 index 0000000000..9494d43c4d --- /dev/null +++ b/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml @@ -0,0 +1,136 @@ +# Fastest NVFP4 Qwen3.5-9B LoRA path + chunked bf16 lm_head cross-entropy. +# +# Identical to qwen35-9b-lora-fastest.yaml but with +# nvfp4_training.bf16_lm_head_cross_entropy: true. The lm_head stays excluded +# from FP4 (bf16) and frozen under LoRA, so the chunked bf16 CE tiles the +# projection over the vocab instead of materializing the full [tokens, vocab] +# logits (and its gradient GEMM). This is a memory / backward-traffic win, not an +# FP4 tensor-core throughput win, and it is the convergence-safe alternative to +# cut_cross_entropy / Liger fused-linear CE, which loss-collapsed in this setup. +# Loss & dL/dhidden are bit-close to F.cross_entropy over full logits. +# + +# Target shape: text-only SFT, sequence_len 2048, sample packing, no gradient +# checkpointing. Measured on RTX PRO 6000 Blackwell 96GB with +# scripts/bench_nvfp4.sh (20 -> 60 marginal train_runtime): +# bf16 b6 stable baseline: 2.1550 s/step, ~5701 tok/s, 89.61 GiB active +# (learning_rate 2e-5; finite grad norms) +# all-on NVFP4 b4 baseline: 1.2475 s/step, ~6567 tok/s, 69.26 GiB active +# this saved-pack path: 1.1525-1.2075 s/step, ~6784-7108 tok/s, +# 69.63 GiB active +# plus FLA boundary: 1.1708 s/step, ~6997 tok/s, 69.63 GiB active +# plus BF16 dK/dV scratch: 1.1415-1.1485 s/step, ~7133-7177 tok/s, +# 62.03 GiB active +# beta lock-down repeat: 1.2017 s/step, ~6817 tok/s, +# 62.03 GiB active +# clean GPU3 repeat: 1.0635 s/step, ~7701 tok/s, +# 69.63 GiB active +# b5 no CE repeat: 1.5195 s/step, ~6740 tok/s, +# 74.43 GiB active +# b6 no CE repeat: 1.8473 s/step, ~6652 tok/s, +# 86.83 GiB active +# FP8 lm_head CE b6: 1.7700-1.7807 s/step, ~6902-6943 tok/s, +# 80.32 GiB active (slower; one validation run +# loss-collapsed with non-finite grad norms) +# FP8 lm_head CE b5 repeat: 1.5312 s/step, ~6688 tok/s, 69.60 GiB active +# (memory win, not a throughput win) +# +# Do not use fp16 as the convergence baseline here: with NVFP4 omitted and +# max_grad_norm: 1.0 explicit, fp16 produced NaN LoRA gradients at the first AMP +# unscale. bf16 is the natural full-precision Blackwell baseline. +# +# The speed knob here is qwen3_5_native_attention_save_backward_packs: it saves +# forward FP4 attention packs and reuses them in backward, trading ~0.4 GiB of +# activation memory for less backward pack-prep work. The FLA boundary prevents +# variable packed cu_seqlens from burning compile time in causal_conv1d. BF16 +# dK/dV scratch reduces the attention-backward scratch traffic before GQA reduce. +# The no-grad/eval-only Qwen3.5 native MLP, native linear-attn, fused v_proj, +# standalone fp8_lm_head, FP8 lm_head CE, CCE, and Liger fused-linear CE +# switches are intentionally omitted. Same-prepared-data ablations measured: +# 1.2017 s/step without legacy switches vs 1.2335 s/step with them; FP8 CE b5 +# 1.5312 s/step; CCE/Liger CE loss-collapse with non-finite grad norms in this +# NVFP4 training setup. +# +# For benchmark comparisons, copy this config to /tmp, set base_model to the +# local model path, and keep dataset_prepared_path fixed across compared runs. +base_model: Qwen/Qwen3.5-9B +model_config_type: qwen3_5 +strict: false + +chat_template: qwen3_5 +datasets: + - path: yahma/alpaca-cleaned + type: alpaca +val_set_size: 0.0 +dataset_prepared_path: /tmp/axolotl_nvfp4_qwen35_fastest_prepared +output_dir: /tmp/axolotl_nvfp4_qwen35_fastest_out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +load_in_8bit: false +load_in_4bit: false +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - up_proj + - down_proj + - linear_attn.in_proj_qkv + - linear_attn.in_proj_z + - linear_attn.out_proj + +gradient_accumulation_steps: 1 +# This measured b4 path targets 96GB Blackwell. Reduce micro_batch_size on +# smaller cards or enable gradient_checkpointing if you need memory headroom. +micro_batch_size: 4 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2.0e-4 +max_grad_norm: 1.0 + +bf16: true +fp16: false +tf32: true +torch_compile: true + +gradient_checkpointing: false +attn_implementation: flash_attention_2 + +nvfp4_training: + enabled: true + base_mode: compute + stochastic_rounding: true + hadamard: true + exclude_modules: [lm_head, embed_tokens] + skip_first_n_blocks: 0 + skip_last_n_blocks: 0 + fuse_rmsnorm: false + + # Chunked bf16 lm_head CE: skip materializing the [tokens, vocab] logits. + bf16_lm_head_cross_entropy: true + + # Qwen3.5 full-attention training path. The saved-pack flag is the measured + # throughput win; RTN grad packs keep the safe gradient-side packs deterministic. + qwen3_5_native_attention: true + qwen3_5_native_attention_backward: true + qwen3_5_native_attention_backward_rtn_grad_packs: true + qwen3_5_native_attention_save_backward_packs: true + qwen3_5_native_attention_dkdv_scratch_bf16: true + qwen3_5_fla_causal_conv_compile_boundary: true + +warmup_steps: 10 +logging_steps: 1 +save_strategy: "no" +saves_per_epoch: +evals_per_epoch: +weight_decay: 0.0 +special_tokens: diff --git a/src/axolotl/kernels/bf16_fused_ce.py b/src/axolotl/kernels/bf16_fused_ce.py new file mode 100644 index 0000000000..5bd377e942 --- /dev/null +++ b/src/axolotl/kernels/bf16_fused_ce.py @@ -0,0 +1,272 @@ +"""Chunked bf16 lm_head + cross-entropy without materializing full logits. + +For the NVFP4 fastest path the lm_head is excluded from FP4 and stays a frozen +bf16 ``nn.Linear``. The default HF forward still materializes the full +``[batch*seq, vocab]`` bf16 logit tensor (and its fp32 upcast) before CE, plus a +matching logits-gradient GEMM in backward. This module fuses the projection with +the loss the way Cut Cross-Entropy does — tiling over the vocab, computing one +``[M, V_BLOCK]`` logit tile at a time, accumulating the logsumexp/label-logit in +fp32 — but the tile GEMM runs in plain bf16, bit-for-bit the same arithmetic as +the materialized path's ``hidden @ W.t()``, with no extra quantization. + +Numerical-safety choices (the prior CCE/Liger fused-linear-CE collapsed here with +non-finite grad norms under NVFP4 stochastic-rounding grads + max_grad_norm AMP +unscale): + * No gradient filtering / low-probability vocab skipping. The returned + ``dL/dhidden`` is the exact tiled CE gradient, not an approximation. + * logsumexp and softmax recomputed max-shifted in fp32. + * ``grad_hidden`` accumulated in fp32 across tiles, cast to bf16 once at the + end (avoids the per-tile bf16 accumulation drift of the FP4 CE path). + +This is a MEMORY win (no ``[M, V]`` logits) and a backward-traffic win (no full +logits-grad GEMM materialized), not an FP4 tensor-core throughput win — the tile +GEMM is bf16. Frozen, bias-free lm_head only (returns dL/dhidden, no weight grad). +""" + +from __future__ import annotations + +import functools + +import torch +from torch import nn + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +# Vocab tile width. The transient fp32 logit tile is [M, _VOCAB_BLOCK]; 4096 keeps +# it small (16 MiB at M=4096) while the bf16 tile GEMM stays efficient. Tunable. +_VOCAB_BLOCK = 4096 + +_PATCHED_FORWARDS: set[type] = set() + + +class _BF16FusedCrossEntropy(torch.autograd.Function): + """Tiled bf16 lm_head -> fp32 logsumexp/gather -> CE, no ``[M, V]`` logits. + + forward accumulates, per vocab tile, the running fp32 logsumexp (max-shifted) + and the gathered label logit. backward recomputes the softmax tile-by-tile + from the saved logsumexp and accumulates ``dL/dhidden = (softmax - onehot) @ W`` + in fp32 — lm_head is frozen, so no weight grad. + + ``grad_scale`` is the per-token weight already folded into the reduction + (1/num_items for grad-accum, else 1/valid_count), so backward stays a pure + function of the saved tensors. + """ + + @staticmethod + def forward(ctx, hidden, weight, labels, ignore_index, logit_scale, grad_scale): + # hidden: [M, H] (2D, contiguous), weight: [V, H], labels: [M] + M = hidden.shape[0] + V = weight.shape[0] + device = hidden.device + + valid = labels != ignore_index + safe_labels = torch.where(valid, labels, labels.new_zeros(())) + + running_max = torch.full((M,), float("-inf"), device=device, dtype=torch.float32) + running_sum = torch.zeros(M, device=device, dtype=torch.float32) + label_logit = torch.zeros(M, device=device, dtype=torch.float32) + + for lo in range(0, V, _VOCAB_BLOCK): + hi = min(lo + _VOCAB_BLOCK, V) + # bf16 tile GEMM (identical to the materialized hidden @ W.t()), fp32 + # only for the reduction. + logits = (hidden @ weight[lo:hi].t()).float() # [M, Vb] + if logit_scale != 1.0: + logits = logits * logit_scale + + tile_max = logits.max(dim=1).values + new_max = torch.maximum(running_max, tile_max) + running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp( + logits - new_max.unsqueeze(1) + ).sum(dim=1) + running_max = new_max + + in_tile = (safe_labels >= lo) & (safe_labels < hi) + cols = (safe_labels - lo).clamp(0, hi - lo - 1) + gathered = logits.gather(1, cols.unsqueeze(1)).squeeze(1) + label_logit = torch.where(in_tile, gathered, label_logit) + + lse = running_max + torch.log(running_sum) + loss = ((lse - label_logit) * valid.float()).sum() * grad_scale + + ctx.save_for_backward(hidden, weight, lse, safe_labels, valid) + ctx.logit_scale = logit_scale + ctx.grad_scale = grad_scale + return loss + + @staticmethod + def backward(ctx, grad_loss): + hidden, weight, lse, safe_labels, valid = ctx.saved_tensors + V = weight.shape[0] + M, H = hidden.shape + rows = torch.arange(M, device=hidden.device) + + # d(loss)/d(logit_v) = grad_loss * grad_scale * mask * (softmax_v - onehot_v) * logit_scale + coef = ( + grad_loss.float() * ctx.grad_scale * valid.float() * ctx.logit_scale + ).unsqueeze(1) # [M, 1] + + grad_hidden = torch.zeros(M, H, device=hidden.device, dtype=torch.float32) + for lo in range(0, V, _VOCAB_BLOCK): + hi = min(lo + _VOCAB_BLOCK, V) + logits = (hidden @ weight[lo:hi].t()).float() + if ctx.logit_scale != 1.0: + logits = logits * ctx.logit_scale + sm = torch.exp(logits - lse.unsqueeze(1)) # softmax tile [M, Vb] + + in_tile = (safe_labels >= lo) & (safe_labels < hi) + cols = (safe_labels - lo).clamp(0, hi - lo - 1) + sm[rows, cols] -= in_tile.float() # subtract onehot in place + + grad_hidden += (sm * coef) @ weight[lo:hi].float() + + return grad_hidden.to(hidden.dtype), None, None, None, None, None + + +def bf16_lm_head_cross_entropy( + hidden: torch.Tensor, + lm_head: nn.Linear, + labels: torch.Tensor, + *, + ignore_index: int = -100, + num_items_in_batch=None, + shift: bool = True, + logit_scale: float = 1.0, +) -> torch.Tensor | None: + """Chunked bf16 lm_head + CE, or None if the head isn't a plain frozen Linear. + + Mirrors ``ForCausalLMLoss``: shifts labels by one (predict next token), + flattens, and reduces by sum/num_items (grad-accum) or mean over the unmasked + tokens. Returns None for a non-plain / trainable / biased lm_head so the + caller falls back to the materialized CE path. + """ + if type(lm_head) is not nn.Linear: + return None + if lm_head.bias is not None or lm_head.weight.requires_grad: + return None + if hidden.device.type != "cuda": + return None + + if shift: + labels = nn.functional.pad(labels, (0, 1), value=ignore_index)[..., 1:] + hidden2d = hidden.reshape(-1, hidden.shape[-1]).contiguous() + labels1d = labels.reshape(-1).to(hidden.device) + + valid = labels1d != ignore_index + if num_items_in_batch is not None: + if torch.is_tensor(num_items_in_batch): + grad_scale = num_items_in_batch.to( + device=hidden.device, dtype=torch.float32 + ).reciprocal() + else: + grad_scale = 1.0 / float(num_items_in_batch) + else: + grad_scale = 1.0 / valid.sum().clamp(min=1).float() + + return _BF16FusedCrossEntropy.apply( + hidden2d, lm_head.weight, labels1d, ignore_index, logit_scale, grad_scale + ) + + +def _make_fused_forward(orig_forward): + from transformers.modeling_outputs import CausalLMOutputWithPast + + # Preserve the original signature: the Trainer inspects forward via + # _remove_unused_columns; a bare *args/**kwargs wrapper would hide + # input_ids/labels and drop every dataset column. + @functools.wraps(orig_forward) + def forward(self, *args, **kwargs): + labels = kwargs.get("labels") + if ( + labels is None + or not getattr(self, "_axolotl_bf16_lm_head_ce_enabled", False) + or not self.training + or kwargs.get("logits_to_keep") + or kwargs.get("return_dict") is False + ): + return orig_forward(self, *args, **kwargs) + + lm_head = self.get_output_embeddings() + labels = kwargs.pop("labels") + num_items_in_batch = kwargs.pop("num_items_in_batch", None) + base = getattr(self, "model", None) + if base is None: + kwargs["labels"] = labels + if num_items_in_batch is not None: + kwargs["num_items_in_batch"] = num_items_in_batch + return orig_forward(self, *args, **kwargs) + + outputs = base(*args, **kwargs) + loss = bf16_lm_head_cross_entropy( + outputs.last_hidden_state, + lm_head, + labels, + num_items_in_batch=num_items_in_batch, + shift=True, + ) + if loss is None: # head became non-plain mid-run -> safe fallback + kwargs["labels"] = labels + if num_items_in_batch is not None: + kwargs["num_items_in_batch"] = num_items_in_batch + return orig_forward(self, *args, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=None, + past_key_values=getattr(outputs, "past_key_values", None), + hidden_states=getattr(outputs, "hidden_states", None), + attentions=getattr(outputs, "attentions", None), + ) + + return forward + + +def patch_model_bf16_lm_head_cross_entropy(model: nn.Module) -> bool: + """Patch ``model``'s ForCausalLM forward to use the chunked bf16 CE. + + Returns True if a patch was installed (frozen bias-free nn.Linear lm_head), + False otherwise. Idempotent per ForCausalLM class. The PEFT wrapper delegates + its forward to the base model, so patching the underlying ForCausalLM class is + enough whether or not LoRA is in use. + """ + causal = model + if hasattr(model, "get_base_model"): + try: + causal = model.get_base_model() + except Exception: + causal = model + + try: + lm_head = causal.get_output_embeddings() + except (AttributeError, NotImplementedError): + LOG.warning("bf16_lm_head_cross_entropy: model has no output embeddings") + return False + + if type(lm_head) is not nn.Linear: + LOG.warning( + "bf16_lm_head_cross_entropy: output embedding is %s, not a plain " + "nn.Linear (NVFP4-quantized or LoRA-wrapped lm_head is not supported " + "here; keeping the materialized CE path).", + type(lm_head).__name__, + ) + return False + if lm_head.bias is not None or lm_head.weight.requires_grad: + LOG.warning( + "bf16_lm_head_cross_entropy: requires a frozen bias-free lm_head; " + "keeping the materialized CE path." + ) + return False + + causal._axolotl_bf16_lm_head_ce_enabled = True + cls = causal.__class__ + if cls in _PATCHED_FORWARDS: + return True + cls.forward = _make_fused_forward(cls.forward) + _PATCHED_FORWARDS.add(cls) + LOG.info( + "bf16_lm_head_cross_entropy: patched %s.forward (logits not materialized)", + cls.__name__, + ) + return True diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 58f2b56814..792a80cdb9 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -502,6 +502,13 @@ def _apply_nvfp4_training(self, model: PreTrainedModel): granularity=getattr(nvfp4, "fp8_lm_head_granularity", "rowwise"), ) + if getattr(nvfp4, "bf16_lm_head_cross_entropy", False): + from axolotl.kernels.bf16_fused_ce import ( + patch_model_bf16_lm_head_cross_entropy, + ) + + patch_model_bf16_lm_head_cross_entropy(model) + def _apply_qwen3_5_native_nvfp4_patches(self, model: PreTrainedModel): nvfp4 = self.cfg.nvfp4_training if not (nvfp4 and nvfp4.enabled): diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 20ced08939..04bcbaa9e1 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1734,6 +1734,16 @@ def check_nvfp4_training(self): "fused_fp4_cross_entropy, or use fused_fp4_cross_entropy for an " "NVFP4-quantized lm_head." ) + if self.nvfp4_training.bf16_lm_head_cross_entropy and ( + self.nvfp4_training.quantize_lm_head + or self.nvfp4_training.fused_fp4_cross_entropy + or self.nvfp4_training.fp8_lm_head_cross_entropy + ): + raise ValueError( + "nvfp4_training.bf16_lm_head_cross_entropy requires the lm_head to " + "remain a frozen plain nn.Linear. Disable quantize_lm_head/" + "fused_fp4_cross_entropy/fp8_lm_head_cross_entropy." + ) qwen3_5_native_flags = ( self.nvfp4_training.qwen3_5_native_attention, self.nvfp4_training.qwen3_5_native_attention_backward, diff --git a/src/axolotl/utils/schemas/nvfp4.py b/src/axolotl/utils/schemas/nvfp4.py index 3a7a990598..578d779d03 100644 --- a/src/axolotl/utils/schemas/nvfp4.py +++ b/src/axolotl/utils/schemas/nvfp4.py @@ -104,6 +104,25 @@ class NVFP4TrainingConfig(BaseModel): "OFF by default." }, ) + bf16_lm_head_cross_entropy: bool = Field( + default=False, + json_schema_extra={ + "description": "Patch a plain frozen bias-free nn.Linear lm_head " + "training forward to compute cross-entropy by tiling over the vocab in " + "bf16, avoiding full [batch*seq, vocab] logits materialization (~1 GiB " + "at vocab 152k / seq 8k) and the matching logits-gradient GEMM. The " + "per-tile lm_head matmul runs in plain bf16 (bit-for-bit the same " + "arithmetic as the materialized hidden @ W.t()); logsumexp/softmax and " + "the dL/dhidden accumulation are kept in fp32. This is the exact tiled " + "CE gradient (no low-probability vocab filtering), so it is " + "convergence-safe under NVFP4 stochastic-rounding grads where the fused " + "cut_cross_entropy / Liger paths collapsed. Returns dL/dhidden only (no " + "lm_head weight grad). Incompatible with quantize_lm_head, " + "fused_fp4_cross_entropy, and the FP8 cross-entropy patch. This is a " + "MEMORY/backward-traffic win, not an FP4 tensor-core throughput win. " + "OFF by default." + }, + ) fp8_lm_head: bool = Field( default=False, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index ec916ce23e..319a3902f2 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1026,11 +1026,15 @@ def check_cross_entropy_conflicts(cls, data): if isinstance(nvfp4, dict): nvfp4_enabled = nvfp4.get("enabled") fp8_lm_head_ce = nvfp4.get("fp8_lm_head_cross_entropy") + bf16_lm_head_ce = nvfp4.get("bf16_lm_head_cross_entropy") else: nvfp4_enabled = getattr(nvfp4, "enabled", None) fp8_lm_head_ce = getattr(nvfp4, "fp8_lm_head_cross_entropy", None) + bf16_lm_head_ce = getattr(nvfp4, "bf16_lm_head_cross_entropy", None) if nvfp4_enabled and fp8_lm_head_ce: ce_options["nvfp4_training.fp8_lm_head_cross_entropy"] = True + if nvfp4_enabled and bf16_lm_head_ce: + ce_options["nvfp4_training.bf16_lm_head_cross_entropy"] = True enabled_options = [k for k, v in ce_options.items() if v] diff --git a/tests/e2e/test_nvfp4_integration.py b/tests/e2e/test_nvfp4_integration.py index fea7cd77eb..550f6e19e8 100644 --- a/tests/e2e/test_nvfp4_integration.py +++ b/tests/e2e/test_nvfp4_integration.py @@ -115,6 +115,39 @@ def test_gate_refuses_fp8_lm_head_ce_with_quantized_lm_head(monkeypatch): ) +def test_schema_accepts_bf16_lm_head_ce(monkeypatch): + _supported(monkeypatch, True) + cfg = AxolotlInputConfig( + **BASE, + nvfp4_training={"enabled": True, "bf16_lm_head_cross_entropy": True}, + ) + assert cfg.nvfp4_training.bf16_lm_head_cross_entropy is True + + +def test_schema_refuses_bf16_lm_head_ce_with_other_ce(monkeypatch): + _supported(monkeypatch, True) + with pytest.raises(ValueError, match="Only one cross entropy optimization"): + AxolotlInputConfig( + **BASE, + cut_cross_entropy=True, + nvfp4_training={"enabled": True, "bf16_lm_head_cross_entropy": True}, + ) + + +def test_gate_refuses_bf16_lm_head_ce_with_quantized_lm_head(monkeypatch): + _supported(monkeypatch, True) + with pytest.raises(ValueError, match="bf16_lm_head_cross_entropy"): + AxolotlConfigWCapabilities( + **BASE, + **CAPS, + nvfp4_training={ + "enabled": True, + "quantize_lm_head": True, + "bf16_lm_head_cross_entropy": True, + }, + ) + + def test_schema_accepts_qwen3_5_native_switches(monkeypatch): _supported(monkeypatch, True) cfg = AxolotlInputConfig( diff --git a/tests/kernels/test_bf16_fused_ce.py b/tests/kernels/test_bf16_fused_ce.py new file mode 100644 index 0000000000..0a3f16f838 --- /dev/null +++ b/tests/kernels/test_bf16_fused_ce.py @@ -0,0 +1,92 @@ +"""Correctness for the chunked bf16 lm_head + cross-entropy. + +The chunked path must produce a loss and dL/dhidden that match the same-weight +materialized ``F.cross_entropy`` reference (bit-close, not approximate), and both +must be FINITE — it is the convergence-safe alternative to the fused CCE/Liger +paths that collapsed under NVFP4 stochastic-rounding grads. +""" + +import pytest +import torch +import torch.nn.functional as F +from torch import nn + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for bf16 fused CE", allow_module_level=True) + +from axolotl.kernels.bf16_fused_ce import ( # noqa: E402 + bf16_lm_head_cross_entropy, +) + + +@pytest.mark.parametrize("num_items", [None, 137.0]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_bf16_fused_ce_matches_materialized(num_items, dtype): + torch.manual_seed(0) + M, H, V = 192, 256, 4096 + 512 # crosses a vocab-tile boundary + lm_head = nn.Linear(H, V, bias=False).cuda().to(dtype) + lm_head.weight.requires_grad_(False) + + hidden = torch.randn(M, H, device="cuda", dtype=dtype) + labels = torch.randint(0, V, (M,), device="cuda") + labels[::7] = -100 # mask some tokens + + # reference: full bf16 logits (upcast to fp32), standard CE + h_ref = hidden.clone().requires_grad_(True) + logits = (h_ref @ lm_head.weight.t()).float() + reduction = "sum" if num_items is not None else "mean" + ref = F.cross_entropy(logits, labels, ignore_index=-100, reduction=reduction) + if num_items is not None: + ref = ref / num_items + ref.backward() + + # fused (shift=False to align with the un-shifted reference) + h_fused = hidden.clone().requires_grad_(True) + fused = bf16_lm_head_cross_entropy( + h_fused, lm_head, labels, num_items_in_batch=num_items, shift=False + ) + fused.backward() + + assert torch.isfinite(fused).all() + assert torch.isfinite(h_fused.grad).all() + + loss_rel = (fused - ref).abs() / (ref.abs() + 1e-9) + grad_rel = (h_fused.grad - h_ref.grad).float().norm() / ( + h_ref.grad.float().norm() + 1e-9 + ) + # fp32 is bit-tight; bf16 carries the GEMM's intrinsic rounding noise. + loss_tol = 1e-6 if dtype == torch.float32 else 1e-4 + grad_tol = 1e-5 if dtype == torch.float32 else 5e-3 + assert loss_rel < loss_tol, (dtype, num_items, loss_rel.item()) + assert grad_rel < grad_tol, (dtype, num_items, grad_rel.item()) + + +def test_bf16_fused_ce_all_masked_is_finite(): + """A fully-masked microbatch must give a finite zero loss / zero grad.""" + torch.manual_seed(0) + M, H, V = 64, 128, 8192 + lm_head = nn.Linear(H, V, bias=False).cuda().bfloat16() + lm_head.weight.requires_grad_(False) + + hidden = torch.randn(M, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + labels = torch.full((M,), -100, device="cuda") + loss = bf16_lm_head_cross_entropy(hidden, lm_head, labels, shift=False) + loss.backward() + + assert torch.isfinite(loss).all() + assert float(loss) == 0.0 + assert torch.isfinite(hidden.grad).all() + + +def test_bf16_fused_ce_rejects_non_plain_head(): + """Trainable / biased lm_head -> None (caller falls back to materialized).""" + H, V = 128, 4096 + labels = torch.randint(0, V, (32,), device="cuda") + hidden = torch.randn(32, H, device="cuda", dtype=torch.bfloat16) + + trainable = nn.Linear(H, V, bias=False).cuda().bfloat16() # requires_grad default + assert bf16_lm_head_cross_entropy(hidden, trainable, labels) is None + + biased = nn.Linear(H, V, bias=True).cuda().bfloat16() + biased.weight.requires_grad_(False) + assert bf16_lm_head_cross_entropy(hidden, biased, labels) is None From e3fed9cf642feb945066af9330fa994a749c3932 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 06:52:02 -0700 Subject: [PATCH 05/17] docs(nvfp4): record agent-fix bench (+4.8%/-9GiB stack) and chunked bf16 CE memory mode --- docs/nvfp4_training.qmd | 49 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/docs/nvfp4_training.qmd b/docs/nvfp4_training.qmd index e503635ffc..cb0ec7b899 100644 --- a/docs/nvfp4_training.qmd +++ b/docs/nvfp4_training.qmd @@ -167,6 +167,25 @@ sequence length 2048, sample packing, no gradient checkpointing: | Locked profile, 500-step long-tail (RTX PRO 6000 #1, PCIe-first) | 4 | 1.2302 | ~6,659 tok/s | 62.03 GiB | | Locked profile + `hadamard: false` (RTX PRO 6000 #1) | 4 | 1.2155 | ~6,739 tok/s | 62.03 GiB | | Locked profile + once-per-forward mask classify (RTX PRO 6000 #1) | 4 | 1.1863 | ~6,906 tok/s | 62.03 GiB | +| Agent-fix re-bench: locked baseline (RTX PRO 6000 #1, 3-run median) | 4 | 1.1897 | ~6,886 tok/s | 69.63 GiB | +| + attention compile custom op auto-enabled | 4 | 1.1678 | ~7,015 tok/s | 60.64 GiB | +| + abs→`vector_norm` amax fuse + batched LoRA A-GEMMs (full stack) | 4 | 1.1310 | ~7,244 tok/s | 60.64 GiB | + +A 2026-06 sweep profiled the whole step and found full attention is only ~7% of +GPU time (so a faster FP4 attention backward is not the lever), while ~26% is +unfused NVFP4 quant/elementwise and ~11% is the bf16 lm_head. The dominant cause +was a **silent eager fallback**: the attention forward's P@V `tl.dot_scaled` raises +an `InductorError` inside Inductor autotune, and with `suppress_errors` on (the +training default) the whole attention-containing subgraph ran eager, blocking +fusion of the surrounding elementwise. Auto-enabling the differentiable attention +custom op when `torch_compile` is on (see `qwen3_5_native_attention_compile_custom_op` +below) makes Inductor compile *around* the opaque op, restoring fusion. Stacked with +the bit-exact amax fuse (`amax(abs(t))` → `vector_norm(t, inf)`, dropping the +full-tensor abs pass) and batched shared-input LoRA A-GEMMs (`lora_batch_kernel`), +the interleaved A/B above measured **1.131 vs 1.190 s/step (~4.8% faster) and +−9 GiB** at identical loss. All three changes are bit-exact or opt-in; the first +also drops active memory because the custom-op backward does not retain forward +packs (`qwen3_5_native_attention_save_backward_packs` has no effect under it). The marginal s/step varies ~5% **between the two RTX PRO 6000 boards**: the PCIe-first board measured 1.2208–1.2302 across 60- and 500-step runs, while the @@ -189,7 +208,11 @@ Negative checks from the same sweep: attention packs: 2.0500 s/step (~5,994 tok/s). - Cut Cross Entropy and Liger fused-linear CE are not valid with this NVFP4 training setup today: b5 warmups loss-collapsed to zero after the first step - and logged non-finite grad norms. + and logged non-finite grad norms. The opt-in `bf16_lm_head_cross_entropy` path + (below) avoids that collapse — it never filters gradient mass and keeps the + logsumexp and `grad_hidden` accumulation in fp32 — but it is a memory win, not a + throughput win (1.2178 s/step, 56.48 GiB at b4; the tile GEMMs are bf16 with the + same FLOPs as the materialized lm_head, so it trades ~2.5% speed for ~13 GiB). - `fp8_lm_head_cross_entropy` can reduce memory and run b5/b6, but it did not beat the b4 saved-pack path on max throughput in clean validation. It is also not fully deterministic yet: one b6 run loss-collapsed with `grad_norm: nan`, @@ -424,6 +447,27 @@ path remained ahead on tokens/sec. Treat batch-6 FP8 CE as experimental: one validation run loss-collapsed after the first step with non-finite grad norms, while a later short repeat stayed finite. +### Chunked bf16 lm_head cross-entropy (opt-in) + +`bf16_lm_head_cross_entropy: true` is a **memory** option for a frozen, bias-free +plain `nn.Linear` `lm_head`. It computes the loss and `dL/dhidden` by streaming the +vocabulary in tiles (`_VOCAB_BLOCK = 4096`) with online softmax, so the +`[tokens, vocab]` logits tensor and its gradient are never materialized. Unlike the +fused CCE/Liger kernels — which loss-collapse in this NVFP4 setup — it does **no** +gradient filtering, keeps the logsumexp and `grad_hidden` accumulation in fp32, and +downcasts to bf16 once at the end, matching `F.cross_entropy` to ~1e-7 (loss) / +~3.5e-3 (grad, pure bf16-GEMM noise) at the Qwen3.5 vocab scale. + +It is a memory/batch-size unlock, **not** a throughput win: the tile GEMMs are bf16 +with the same FLOPs as the materialized lm_head (which is already frozen, so there is +no weight gradient to save), so it trades speed for memory. On Qwen3.5-9B LoRA +(seq 2048, b4, RTX PRO 6000) it measured **1.2178 s/step at 56.48 GiB active** versus +the locked saved-pack path's ~1.19 s/step at 69.63 GiB — ~2.5% slower for ~13 GiB +headroom. Use it when memory-bound (to fit a larger batch); otherwise leave it off. +It engages only for a frozen bias-free `nn.Linear` lm_head in training (it falls back +to materialized CE otherwise) and is mutually exclusive with `quantize_lm_head`, +`fused_fp4_cross_entropy`, and `fp8_lm_head_cross_entropy`. + ::: {.callout-note} At the time of writing, `skip_first_n_blocks` / `skip_last_n_blocks` may be applied by the integration layer rather than inside `convert_to_nvfp4_training` directly @@ -457,7 +501,8 @@ The `nvfp4_training:` block (schema: `src/axolotl/utils/schemas/nvfp4.py`, | `qwen3_5_native_attention_backward_rtn_grad_packs` | `bool` | `false` | Qwen3.5 native attention training only. Use deterministic round-to-nearest for measured-safe gradient packs while leaving the dK routing-gradient dS pack governed by `stochastic_rounding`. | | `qwen3_5_native_attention_save_backward_packs` | `bool` | `false` | Qwen3.5 native attention training only. Save deterministic forward Q/K/V FP4 packs plus transposed backward layouts and reuse them in backward. Trades extra activation memory for higher throughput. | | `qwen3_5_native_attention_dkdv_scratch_bf16` | `bool` | `false` | Qwen3.5 native attention training only. Store per-query-head dK/dV scratch in bf16 before GQA reduction instead of fp32. Measured faster on Qwen3.5-9B b4; opt-in because it changes an intermediate gradient cast. | -| `qwen3_5_native_attention_compile_custom_op` | `bool` | `false` | Qwen3.5 native attention inference only. Opaque custom-op escape hatch for strict compile coverage around Triton `tl.dot_scaled`; rejected when native attention backward is enabled. | +| `qwen3_5_native_attention_compile_custom_op` | `bool \| null` | `null` (auto) | Qwen3.5 native attention. Wraps the attention path in an opaque differentiable custom op so Inductor compiles *around* the Triton `tl.dot_scaled` (which otherwise raises an `InductorError` and silently drops the block to eager). `null` auto-enables it when `torch_compile` is on and native attention is enabled (the measured fastest path); `true`/`false` force it. Bit-exact; under it `save_backward_packs` has no effect (backward recomputes packs across the boundary) and active memory drops. | +| `bf16_lm_head_cross_entropy` | `bool` | `false` | Opt-in **memory** path. Requires a frozen bias-free plain `nn.Linear` lm_head. Chunked online-softmax CE over bf16 vocab tiles — no `[tokens, vocab]` logits materialization, fp32 logsumexp/`grad_hidden`, no gradient filtering (avoids the CCE/Liger collapse). Trades ~2.5% throughput for ~13 GiB. Mutually exclusive with `quantize_lm_head` / `fused_fp4_cross_entropy` / `fp8_lm_head_cross_entropy`. | | `qwen3_5_fla_causal_conv_compile_boundary` | `bool` | `false` | Qwen3.5 sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. | | `qwen3_5_fuse_vproj` / `qwen3_5_native_mlp` / `qwen3_5_native_linear_attn` / `fp8_lm_head` | `bool` | `false` | Qwen3.5/eval-scoring paths. Current implementations are eval/no-grad only and do not accelerate grad-enabled training. Use `fp8_lm_head_cross_entropy` separately for the opt-in training loss memory path. | From 34af026322e2e58bff5386b8a311783a18d40a60 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:10:45 -0700 Subject: [PATCH 06/17] perf(nvfp4): reuse rms in fused RMSNorm forward, drop double rsqrt [F2] --- src/axolotl/kernels/nvfp4_rmsnorm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/axolotl/kernels/nvfp4_rmsnorm.py b/src/axolotl/kernels/nvfp4_rmsnorm.py index 66db8c11be..26364d6714 100644 --- a/src/axolotl/kernels/nvfp4_rmsnorm.py +++ b/src/axolotl/kernels/nvfp4_rmsnorm.py @@ -129,18 +129,26 @@ def fused_rmsnorm_nvfp4( x: torch.Tensor, weight: torch.Tensor, eps: float, + rms: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """RMSNorm(x) * weight, returning (normalized bf16, packed fp4, swizzled scales). Single-level (activation) NVFP4: global_scale is implicitly 1.0. The fp4/scale outputs feed ``torch._scaled_mm`` directly (TN layout, contraction %32). + + ``rms`` (the per-row ``rsqrt(mean(x^2)+eps)``) may be passed in when the caller + already computed it (e.g. the autograd Function saves it for backward), to avoid + a redundant full-row reduction over x. """ orig_dims, K = x.shape[:-1], x.shape[-1] x2 = x.reshape(-1, K) M, N = x2.shape assert N % 16 == 0, "K must be divisible by 16 for NVFP4 quantization" - rms = torch.rsqrt(x2.float().pow(2).mean(-1, keepdim=True) + eps).reshape(M) + if rms is None: + rms = torch.rsqrt(x2.float().pow(2).mean(-1, keepdim=True) + eps).reshape(M) + else: + rms = rms.reshape(M) num_scales = N // 16 n_row_blocks = triton.cdiv(M, 128) @@ -191,9 +199,9 @@ class _FusedRMSNormNVFP4Function(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, eps): - y, xq, xsc = fused_rmsnorm_nvfp4(x, weight, eps) x2 = x.reshape(-1, x.shape[-1]) r = torch.rsqrt(x2.float().pow(2).mean(-1, keepdim=True) + eps) + y, xq, xsc = fused_rmsnorm_nvfp4(x, weight, eps, rms=r.reshape(-1)) ctx.save_for_backward(x2, weight, r) ctx.eps = eps # Stash the fused quant keyed by the output the consuming linear sees. From 915c87535e427a6099c1d5f6d7e8562d9b16af47 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:10:45 -0700 Subject: [PATCH 07/17] docs(nvfp4): correct rtn_grad_packs desc (dPt dO pack also stays SR) [AB2] --- src/axolotl/utils/schemas/nvfp4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/nvfp4.py b/src/axolotl/utils/schemas/nvfp4.py index 578d779d03..6922b6f3b0 100644 --- a/src/axolotl/utils/schemas/nvfp4.py +++ b/src/axolotl/utils/schemas/nvfp4.py @@ -269,7 +269,8 @@ class NVFP4TrainingConfig(BaseModel): "description": "Qwen3.5 native attention training only. Use " "deterministic round-to-nearest for the measured-safe gradient packs " "(softmax P and transposed dO for dV, and dS for dQ) while leaving the " - "dK routing-gradient dS pack governed by stochastic_rounding. This " + "dK routing-gradient dS pack AND the dPt dO pack governed by " + "stochastic_rounding. This " "collapsed mode was faster in backward microbenchmarks; convergence " "validation is still required for production training. OFF by default." }, From 6fe10d5197215aa47a6b513d610945937a06bb28 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:15:03 -0700 Subject: [PATCH 08/17] perf(nvfp4): drop dead HP q/k/v save on saved-packs backward [B1] --- scripts/check_b1_parity.py | 38 +++++++++++++++++++++++++ src/axolotl/kernels/attn_nvfp4_flash.py | 23 +++++++++++++-- 2 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 scripts/check_b1_parity.py diff --git a/scripts/check_b1_parity.py b/scripts/check_b1_parity.py new file mode 100644 index 0000000000..1ad81abe6b --- /dev/null +++ b/scripts/check_b1_parity.py @@ -0,0 +1,38 @@ +"""B1 parity probe: dump dq/dk/dv + peak mem for the saved-packs FP4 backward. + +Run on BOTH the pre-B1 commit and the B1 commit (same args, SR OFF so the FP4 +backward is deterministic), then diff the two dumps — B1 only removes dead HP q/k/v +saves, so grads must be BIT-IDENTICAL and peak memory should drop. + + PYTHONPATH=/src python scripts/check_b1_parity.py /tmp/b1_.pt +""" +import sys +import torch + +from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attn_func + +torch.manual_seed(0) +dev = "cuda" +Z, H, Hk, Sq, Skv, D = 1, 8, 2, 1024, 1024, 256 # GQA, D=256 (Qwen3.5 full-attn) +q = torch.randn(Z, H, Sq, D, device=dev, dtype=torch.bfloat16, requires_grad=True) +k = torch.randn(Z, Hk, Skv, D, device=dev, dtype=torch.bfloat16, requires_grad=True) +v = torch.randn(Z, Hk, Skv, D, device=dev, dtype=torch.bfloat16, requires_grad=True) + +torch.cuda.reset_peak_memory_stats() +out = nvfp4_flash_attn_func( + q, k, v, 1.0 / (D**0.5), + causal=True, num_key_value_groups=H // Hk, + stochastic_rounding=False, # deterministic for cross-commit bit-compare + save_backward_packs=True, + dkdv_scratch_bf16=True, +) +g = torch.randn_like(out) +out.backward(g) +peak = torch.cuda.max_memory_allocated() / 2**30 + +path = sys.argv[1] +torch.save( + {"dq": q.grad.cpu(), "dk": k.grad.cpu(), "dv": v.grad.cpu(), "peak_GiB": peak}, + path, +) +print(f"saved {path} peak_GiB={peak:.4f}") diff --git a/src/axolotl/kernels/attn_nvfp4_flash.py b/src/axolotl/kernels/attn_nvfp4_flash.py index 78a978453a..8ad06767ef 100644 --- a/src/axolotl/kernels/attn_nvfp4_flash.py +++ b/src/axolotl/kernels/attn_nvfp4_flash.py @@ -1530,10 +1530,21 @@ def forward( bias = None if key_pad_bias is not None: bias = key_pad_bias.to(torch.float32).contiguous() + # On the saved-packs backward, HP q/k/v are never dereferenced (have_lse + # skips prep; full pack reuse skips kprep and sets packprep STORE_Q/QT + # False), so don't pay three full bf16 [.,S,D] copies — save placeholders. + if save_backward_packs: + empty = torch.empty(0, device=query.device, dtype=query.dtype) + q_save = k_save = v_save = empty + else: + q_save = query.reshape(z * h, s_q, d).contiguous() + k_save = key.reshape(z * hk, s_kv, d).contiguous() + v_save = value.reshape(z * hk, s_kv, d).contiguous() + ctx.save_backward_packs = save_backward_packs ctx.save_for_backward( - query.reshape(z * h, s_q, d).contiguous(), - key.reshape(z * hk, s_kv, d).contiguous(), - value.reshape(z * hk, s_kv, d).contiguous(), + q_save, + k_save, + v_save, out.reshape(z * h, s_q, d), bias if bias is not None else torch.empty(0, device=query.device), lse, @@ -1572,6 +1583,12 @@ def backward(ctx, grad_out): z, h, hk, s_q, s_kv, d = ctx.dims block_m, block_n, num_warps, num_stages = ctx.tiles bias = bias if ctx.has_bias else None + if getattr(ctx, "save_backward_packs", False): + # q/k/v were not saved (placeholders); the packs are the real operands. + # _run_bwd never reads HP q/k/v here, but needs a tensor with q's + # [z*h,s_q,d] shape/device/stride for scratch allocs and the unused + # bias-dummy pointer — o satisfies that. + q = k = v = o do = grad_out.reshape(z * h, s_q, d).contiguous() dq, dk, dv = _run_bwd( q, k, v, do, o, bias, From 333b75056c7cce492ed2f5b3a3e752f9c10f267b Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 00:15:45 -0700 Subject: [PATCH 09/17] perf(nvfp4): bf16 dq scratch under dkdv_scratch_bf16 (bit-exact, mem) [B2] --- src/axolotl/kernels/attn_nvfp4_flash.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/kernels/attn_nvfp4_flash.py b/src/axolotl/kernels/attn_nvfp4_flash.py index 8ad06767ef..c45995a7da 100644 --- a/src/axolotl/kernels/attn_nvfp4_flash.py +++ b/src/axolotl/kernels/attn_nvfp4_flash.py @@ -1345,8 +1345,12 @@ def _run_bwd( reuse_k_pack = knv_saved is not None and ksc_saved is not None reuse_v_pack = vnv_saved is not None and vsc_saved is not None reuse_kt_pack = ktnv_saved is not None and ktsc_saved is not None - dq = torch.empty(z * h, s_q, d, device=q.device, dtype=torch.float32) dkdv_scratch_dtype = torch.bfloat16 if dkdv_scratch_bf16 else torch.float32 + # dq accumulates in fp32 registers and only downcasts at the final store, so a + # bf16 scratch buffer is bit-identical to fp32-then-.to(bf16) here — a pure + # memory save (the largest bwd scratch plane). Must stay fp32 if the fused + # atomic-add dq path is ever enabled (atomics need fp32). + dq = torch.empty(z * h, s_q, d, device=q.device, dtype=dkdv_scratch_dtype) dk = torch.empty(z * h, s_kv, d, device=q.device, dtype=dkdv_scratch_dtype) dv = torch.empty(z * h, s_kv, d, device=q.device, dtype=dkdv_scratch_dtype) if not have_lse: From 612fe2681adaf99538fdf83ef7591709fba404cd Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 07:12:25 -0700 Subject: [PATCH 10/17] docs(nvfp4): dq scratch follows dkdv_scratch_bf16; record B1/B2 fixed-shape memory probe --- docs/nvfp4_training.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/nvfp4_training.qmd b/docs/nvfp4_training.qmd index cb0ec7b899..011bf3c433 100644 --- a/docs/nvfp4_training.qmd +++ b/docs/nvfp4_training.qmd @@ -500,7 +500,7 @@ The `nvfp4_training:` block (schema: `src/axolotl/utils/schemas/nvfp4.py`, | `qwen3_5_native_attention_backward` | `bool` | `false` | Qwen3.5 only. Requires `qwen3_5_native_attention`. Use the native NVFP4 autograd attention path while training. | | `qwen3_5_native_attention_backward_rtn_grad_packs` | `bool` | `false` | Qwen3.5 native attention training only. Use deterministic round-to-nearest for measured-safe gradient packs while leaving the dK routing-gradient dS pack governed by `stochastic_rounding`. | | `qwen3_5_native_attention_save_backward_packs` | `bool` | `false` | Qwen3.5 native attention training only. Save deterministic forward Q/K/V FP4 packs plus transposed backward layouts and reuse them in backward. Trades extra activation memory for higher throughput. | -| `qwen3_5_native_attention_dkdv_scratch_bf16` | `bool` | `false` | Qwen3.5 native attention training only. Store per-query-head dK/dV scratch in bf16 before GQA reduction instead of fp32. Measured faster on Qwen3.5-9B b4; opt-in because it changes an intermediate gradient cast. | +| `qwen3_5_native_attention_dkdv_scratch_bf16` | `bool` | `false` | Qwen3.5 native attention training only. Store the dQ **and** per-query-head dK/dV backward scratch in bf16 (dQ/dK/dV accumulate fp32 in-register and downcast once at the store, so this is bit-identical to fp32-then-`.to(bf16)` — a pure memory save on the largest backward scratch planes). Measured faster + lower-memory on Qwen3.5-9B b4. Fixed-shape probe at the 9B full-attn shape: fwd+bwd peak 406→342 MiB (−16%) on the custom-op path, grads `maxabsdiff=0`. | | `qwen3_5_native_attention_compile_custom_op` | `bool \| null` | `null` (auto) | Qwen3.5 native attention. Wraps the attention path in an opaque differentiable custom op so Inductor compiles *around* the Triton `tl.dot_scaled` (which otherwise raises an `InductorError` and silently drops the block to eager). `null` auto-enables it when `torch_compile` is on and native attention is enabled (the measured fastest path); `true`/`false` force it. Bit-exact; under it `save_backward_packs` has no effect (backward recomputes packs across the boundary) and active memory drops. | | `bf16_lm_head_cross_entropy` | `bool` | `false` | Opt-in **memory** path. Requires a frozen bias-free plain `nn.Linear` lm_head. Chunked online-softmax CE over bf16 vocab tiles — no `[tokens, vocab]` logits materialization, fp32 logsumexp/`grad_hidden`, no gradient filtering (avoids the CCE/Liger collapse). Trades ~2.5% throughput for ~13 GiB. Mutually exclusive with `quantize_lm_head` / `fused_fp4_cross_entropy` / `fp8_lm_head_cross_entropy`. | | `qwen3_5_fla_causal_conv_compile_boundary` | `bool` | `false` | Qwen3.5 sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. | From 2926b537905e28a3382f3a72798e9a08d2416e6c Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 08:23:19 -0700 Subject: [PATCH 11/17] =?UTF-8?q?perf(nvfp4):=20read=20Q/K=20transposed=20?= =?UTF-8?q?in=20fused=5Frope=5Fquant=5Fqk=20(drop=20.contiguous()=20copy)?= =?UTF-8?q?=20=E2=80=94=201.45-1.52x=20producer,=20bit-identical=20[prefil?= =?UTF-8?q?l=20#2]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/prove_p2.py | 30 ++++++++++++++++++++ scripts/prove_p3.py | 27 ++++++++++++++++++ src/axolotl/kernels/nvfp4_fused_producers.py | 17 +++++++---- 3 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 scripts/prove_p2.py create mode 100644 scripts/prove_p3.py diff --git a/scripts/prove_p2.py b/scripts/prove_p2.py new file mode 100644 index 0000000000..3098020fb1 --- /dev/null +++ b/scripts/prove_p2.py @@ -0,0 +1,30 @@ +"""#2 fused_rope_quant_qk: parity (strided==contiguous, bit-identical) + the +saved-copy latency on a realistic transposed (non-contiguous) Q input.""" +import torch +from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk + +dev = "cuda"; dt = torch.bfloat16 +torch.manual_seed(0) +for Z, H, S, D in [(1, 16, 2048, 256), (1, 16, 4096, 256)]: + rot = D + base = torch.randn(Z, S, H, D, device=dev, dtype=dt) # [Z,S,H,D] contiguous + x_t = base.transpose(1, 2) # [Z,H,S,D] non-contig (D contig) + cos = torch.randn(Z, S, rot, device=dev, dtype=dt) + sin = torch.randn(Z, S, rot, device=dev, dtype=dt) + + q_s, sc_s = fused_rope_quant_qk(x_t, cos, sin) # strided (new, no copy) + q_c, sc_c = fused_rope_quant_qk(x_t.contiguous(), cos, sin) # contiguous reference + ok = torch.equal(q_s, q_c) and torch.equal(sc_s.view(torch.uint8), sc_c.view(torch.uint8)) + + def t(fn, it=50): + torch.cuda.synchronize() + for _ in range(5): fn() + torch.cuda.synchronize() + a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() + for _ in range(it): fn() + b.record(); torch.cuda.synchronize() + return a.elapsed_time(b) / it + new = t(lambda: fused_rope_quant_qk(x_t, cos, sin)) + old = t(lambda: fused_rope_quant_qk(x_t.contiguous(), cos, sin)) + print(f"S={S}: parity_bit_identical={ok} new(strided) {new*1000:6.1f}us " + f"old(contig+kernel) {old*1000:6.1f}us speedup {old/new:.2f}x") diff --git a/scripts/prove_p3.py b/scripts/prove_p3.py new file mode 100644 index 0000000000..6396f7c0df --- /dev/null +++ b/scripts/prove_p3.py @@ -0,0 +1,27 @@ +"""#3 forward V-load hoist: dump forward output (for cross-worktree bit-compare) ++ forward prefill latency. Run under each worktree's PYTHONPATH.""" +import sys, math, torch +from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attention + +dev = "cuda"; dt = torch.bfloat16 +tag = sys.argv[1] +outs = {} +for Z, H, Hk, S, D in [(1, 16, 4, 2048, 256), (1, 16, 4, 4096, 256)]: + torch.manual_seed(0) + q = torch.randn(Z, H, S, D, device=dev, dtype=dt) + k = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) + v = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) + sc = 1.0 / math.sqrt(D) + fn = lambda: nvfp4_flash_attention(q, k, v, sc, causal=True, num_key_value_groups=H // Hk) + out = fn() + outs[S] = out.cpu() + + torch.cuda.synchronize() + for _ in range(5): fn() + torch.cuda.synchronize() + a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() + for _ in range(50): fn() + b.record(); torch.cuda.synchronize() + print(f"S={S}: forward {a.elapsed_time(b)/50*1000:7.1f} us") +torch.save(outs, f"/tmp/p3_{tag}.pt") +print(f"saved /tmp/p3_{tag}.pt") diff --git a/src/axolotl/kernels/nvfp4_fused_producers.py b/src/axolotl/kernels/nvfp4_fused_producers.py index 43e2d2d190..77bfb1cb41 100644 --- a/src/axolotl/kernels/nvfp4_fused_producers.py +++ b/src/axolotl/kernels/nvfp4_fused_producers.py @@ -47,7 +47,7 @@ def _rope_quant_kernel( cos_ptr, sin_ptr, # [Z, S, ROT] (ROT = rotary_dim) q_ptr, s_ptr, # [Z*H, S, D//2] uint8, [Z*H, S, D//16] e4m3 Z, H, S, - s_xn, s_xr, # x: per-(z*h) stride, per-row(seq) stride; col stride = 1 + s_xz, s_xh, s_xr, # x: per-z, per-h, per-row(seq) strides; col(D) stride = 1 s_cz, s_cr, # cos/sin: per-z stride, per-row stride; col stride = 1 s_qn, s_qr, # q packed: per-(z*h) stride, per-row stride s_sn, s_sr, # scale: per-(z*h) stride, per-row stride @@ -58,12 +58,13 @@ def _rope_quant_kernel( pid_n = tl.program_id(0) # z*h pid_r = tl.program_id(1) z = pid_n // H + h = pid_n % H offs_r = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) rmask = offs_r < S offs_d = tl.arange(0, D) - xbase = pid_n * s_xn + xbase = z * s_xz + h * s_xh x = tl.load( x_ptr + xbase + offs_r[:, None] * s_xr + offs_d[None, :], mask=rmask[:, None], other=0.0, @@ -130,20 +131,24 @@ def fused_rope_quant_qk( z, h, s, d = x.shape rot = cos.shape[-1] assert d % 16 == 0 and rot % 2 == 0 and rot <= d - x = x.contiguous() + # Read x in whatever layout it arrives (typically a [Z,S,H,D].transpose(1,2) + # view → non-contiguous, but D stays contiguous), so we DON'T pay a full bf16 + # copy here. The kernel only needs the head_dim (D) unit-stride; if some caller + # passes a non-D-contiguous x, fall back to a copy (correctness over speed). + if x.stride(3) != 1: + x = x.contiguous() cos = cos.contiguous() sin = sin.contiguous() - xn = x.reshape(z * h, s, d) q = x.new_empty(z * h, s, d // 2, dtype=torch.uint8) sc = x.new_empty(z * h, s, d // 16, dtype=torch.uint8) BLOCK_R = 64 grid = (z * h, triton.cdiv(s, BLOCK_R)) _rope_quant_kernel[grid]( - xn, cos, sin, q, sc, + x, cos, sin, q, sc, z, h, s, - xn.stride(0), xn.stride(1), + x.stride(0), x.stride(1), x.stride(2), cos.stride(0), cos.stride(1), q.stride(0), q.stride(1), sc.stride(0), sc.stride(1), From 63c18d6fa00af0ad9efe02bdbd8dfa9343373d07 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 08:28:29 -0700 Subject: [PATCH 12/17] test(nvfp4): rope_quant strided-parity unit test + e2e block bench (#2: ~13% attn-block, bit-identical) --- scripts/bench_block.py | 40 +++++++++++++++++ .../kernels/test_nvfp4_rope_quant_strided.py | 45 +++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 scripts/bench_block.py create mode 100644 tests/kernels/test_nvfp4_rope_quant_strided.py diff --git a/scripts/bench_block.py b/scripts/bench_block.py new file mode 100644 index 0000000000..dd7930b122 --- /dev/null +++ b/scripts/bench_block.py @@ -0,0 +1,40 @@ +"""End-to-end prefill attention-block compute (producers Q/K + V + flash kernel), +timed. Q/K passed as transposed (non-contiguous) views — the production layout — +so #2's strided path is exercised. Dumps output for cross-worktree bit-compare.""" +import sys, math, torch +from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk, quant_v_keyaxis +from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attention_packed + +dev = "cuda"; dt = torch.bfloat16 +tag = sys.argv[1] +BLOCK_N = 128 +outs = {} +for Z, H, Hk, S, D in [(1, 16, 4, 2048, 256), (1, 16, 4, 4096, 256)]: + torch.manual_seed(0) + rot = D; sc = 1.0 / math.sqrt(D) + qb = torch.randn(Z, S, H, D, device=dev, dtype=dt) + kb = torch.randn(Z, S, Hk, D, device=dev, dtype=dt) + q_t = qb.transpose(1, 2) # [Z,H,S,D] non-contig (prod layout) + k_t = kb.transpose(1, 2) # [Z,Hk,S,D] + v = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) + cos = torch.randn(Z, S, rot, device=dev, dtype=dt) + sin = torch.randn(Z, S, rot, device=dev, dtype=dt) + + def fn(): + qnv, qsc = fused_rope_quant_qk(q_t, cos, sin) + knv, ksc = fused_rope_quant_qk(k_t, cos, sin) + vnv, vsc, _ = quant_v_keyaxis(v, block_n=BLOCK_N) + return nvfp4_flash_attention_packed( + qnv, qsc, knv, ksc, vnv, vsc, z=Z, h=H, hk=Hk, s_q=S, s_kv=S, d=D, + scaling=sc, out_dtype=dt, causal=True, block_n=BLOCK_N, out_layout="zshd") + + outs[S] = fn().cpu() + torch.cuda.synchronize() + for _ in range(5): fn() + torch.cuda.synchronize() + a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() + for _ in range(50): fn() + b.record(); torch.cuda.synchronize() + print(f"S={S}: block(producers+flash) {a.elapsed_time(b)/50*1000:7.1f} us") +torch.save(outs, f"/tmp/block_{tag}.pt") +print(f"saved /tmp/block_{tag}.pt") diff --git a/tests/kernels/test_nvfp4_rope_quant_strided.py b/tests/kernels/test_nvfp4_rope_quant_strided.py new file mode 100644 index 0000000000..2e65dc4536 --- /dev/null +++ b/tests/kernels/test_nvfp4_rope_quant_strided.py @@ -0,0 +1,45 @@ +"""fused_rope_quant_qk must read a transposed (non-contiguous, D-unit-stride) Q/K +view and produce BIT-IDENTICAL packs to the contiguous path — the invariant behind +dropping the per-layer .contiguous() copy (prefill grab #2). The production caller +passes q_norm(...).transpose(1,2), exactly this layout.""" +import pytest +import torch + +cuda = torch.cuda.is_available() +pytestmark = pytest.mark.skipif(not cuda, reason="needs a CUDA (sm_120) device") + +if cuda: + from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk + + +@pytest.mark.parametrize("Z,H,S,D", [(1, 16, 300, 256), (1, 8, 256, 128), (2, 4, 128, 256)]) +def test_strided_matches_contiguous(Z, H, S, D): + torch.manual_seed(0) + rot = D + base = torch.randn(Z, S, H, D, device="cuda", dtype=torch.bfloat16) + x_t = base.transpose(1, 2) # [Z,H,S,D] non-contiguous, D unit-stride + assert x_t.stride(3) == 1 and not x_t.is_contiguous() + cos = torch.randn(Z, S, rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(Z, S, rot, device="cuda", dtype=torch.bfloat16) + + q_s, sc_s = fused_rope_quant_qk(x_t, cos, sin) # strided (no copy) + q_c, sc_c = fused_rope_quant_qk(x_t.contiguous(), cos, sin) # contiguous reference + + assert torch.equal(q_s, q_c), "packed FP4 differs between strided and contiguous" + assert torch.equal(sc_s.view(torch.uint8), sc_c.view(torch.uint8)), "scale differs" + assert q_s.shape == (Z * H, S, D // 2) and q_s.dtype == torch.uint8 + assert sc_s.shape == (Z * H, S, D // 16) + + +def test_noncontiguous_d_falls_back(): + """If head_dim is NOT unit-stride, it must fall back to a copy and still be correct.""" + torch.manual_seed(0) + Z, H, S, D = 1, 4, 64, 128 + # make D non-unit-stride by transposing S<->D, then take a view where D is dim 3 + x = torch.randn(Z, H, D, S, device="cuda", dtype=torch.bfloat16).transpose(2, 3) # [Z,H,S,D], D stride = S + assert x.stride(3) != 1 + cos = torch.randn(Z, S, D, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(Z, S, D, device="cuda", dtype=torch.bfloat16) + q_fb, sc_fb = fused_rope_quant_qk(x, cos, sin) + q_ref, sc_ref = fused_rope_quant_qk(x.contiguous(), cos, sin) + assert torch.equal(q_fb, q_ref) and torch.equal(sc_fb.view(torch.uint8), sc_ref.view(torch.uint8)) From 178520c2164b9fb5b4a2e056eb3cc6ae40401325 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 09:02:51 -0700 Subject: [PATCH 13/17] feat(nvfp4): opt-in FP4 q/k/o_proj + shared qkv activation pack (default-off, parity-affecting) Isolated q/k/o GEMM 1.13x vs bf16 (shared-pack); but Qwen3.5 is hybrid (8/32 full-attn layers) so model-level prefill is 1.000x (buried) and it costs parity (logit cos 0.9969, argmax 79%). Default-off; the reusable lever for dense models where every layer is full-attention. NOT bit-exact. --- scripts/bench_proj_iso.py | 55 ++++++++++ scripts/prove_attnproj.py | 57 ++++++++++ .../monkeypatch/attention/nvfp4_flash_attn.py | 102 ++++++++++++++++-- 3 files changed, 208 insertions(+), 6 deletions(-) create mode 100644 scripts/bench_proj_iso.py create mode 100644 scripts/prove_attnproj.py diff --git a/scripts/bench_proj_iso.py b/scripts/bench_proj_iso.py new file mode 100644 index 0000000000..de4d95aafc --- /dev/null +++ b/scripts/bench_proj_iso.py @@ -0,0 +1,55 @@ +"""Isolated: does FP4 q/k/o_proj (shared-pack) beat bf16 at the real Qwen3.5-4B +dims? This is the deciding factor for the dense-model rationale (per-GEMM, where a +dense model would hit it every layer).""" +import torch +from transformers import AutoModelForCausalLM +from axolotl.kernels.attn_nvfp4_flash import _quant_nvfp4 +from axolotl.kernels.nvfp4_linear import nvfp4_linear +from axolotl.monkeypatch.attention.nvfp4_linear_attn import ( + _gemm_from_packed_act, _get_packed_weight, +) + +dev = "cuda" +m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16).to(dev).eval() +attn = None +for mod in m.modules(): + if all(hasattr(mod, p) for p in ("q_proj", "k_proj", "o_proj")) and type(mod.q_proj) is torch.nn.Linear: + attn = mod; break +H = attn.q_proj.in_features +qO, kO, oI, oO = attn.q_proj.out_features, attn.k_proj.out_features, attn.o_proj.in_features, attn.o_proj.out_features +print(f"dims: hidden={H} q_out={qO} k_out={kO} o_in={oI} o_out={oO}") + +M = 1024 +x = torch.randn(M, H, device=dev, dtype=torch.bfloat16) +xo = torch.randn(M, oI, device=dev, dtype=torch.bfloat16) +qw = _get_packed_weight(attn, "_q", attn.q_proj) +kw = _get_packed_weight(attn, "_k", attn.k_proj) +ow = _get_packed_weight(attn, "_o", attn.o_proj) + + +@torch.no_grad() +def bf16(): + return attn.q_proj(x), attn.k_proj(x), attn.o_proj(xo) + + +@torch.no_grad() +def fp4(): + anv, asc = _quant_nvfp4(x.unsqueeze(0)); anv, asc = anv[0], asc[0] + q = _gemm_from_packed_act(anv, asc, qw[0], qw[1], M, qO, H, torch.bfloat16) + k = _gemm_from_packed_act(anv, asc, kw[0], kw[1], M, kO, H, torch.bfloat16) + o = nvfp4_linear(xo, ow[0], ow[1], oO) + return q, k, o + + +def t(fn, it=100): + torch.cuda.synchronize() + for _ in range(10): fn() + torch.cuda.synchronize() + a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() + for _ in range(it): fn() + b.record(); torch.cuda.synchronize() + return a.elapsed_time(b) / it + + +bf, fp = t(bf16), t(fp4) +print(f"q+k+o proj (M={M}): bf16 {bf*1000:6.1f}us FP4-shared {fp*1000:6.1f}us speedup {bf/fp:.2f}x") diff --git a/scripts/prove_attnproj.py b/scripts/prove_attnproj.py new file mode 100644 index 0000000000..ef246aa729 --- /dev/null +++ b/scripts/prove_attnproj.py @@ -0,0 +1,57 @@ +"""Model-level parity + prefill latency for FP4 q/k/o_proj (shared-pack). +Loads Qwen3.5-4B, patches full-attn layers, compares prefill logits + latency +with the proj-FP4 flag ON vs OFF (FP4-attention-only) vs unpatched bf16.""" +import torch, torch.nn.functional as F +from transformers import AutoModelForCausalLM +from axolotl.monkeypatch.attention.nvfp4_flash_attn import patch_qwen3_5_nvfp4_attention + +dev = "cuda" +m = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3.5-4B", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", +).to(dev).eval() +torch.manual_seed(0) +ids = torch.randint(0, 10000, (1, 1024), device=dev) + + +def cos(a, b): + return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + + +with torch.no_grad(): + base = m(ids).logits.float() + +n = patch_qwen3_5_nvfp4_attention(m, fuse_attn_proj=True) +ok = sum(int(getattr(mod, "_nvfp4_attn_proj_ok", False)) for mod in m.modules()) +print(f"patched full-attn layers: {n}; proj-FP4-eligible: {ok}") + + +def set_flag(v): + for mod in m.modules(): + if hasattr(mod, "_nvfp4_fuse_attn_proj"): + mod._nvfp4_fuse_attn_proj = v + + +with torch.no_grad(): + set_flag(True); on = m(ids).logits.float() + set_flag(False); off = m(ids).logits.float() + +print(f"logit cos proj-FP4 ON vs bf16-unpatched : {cos(on, base):.5f}") +print(f"logit cos FP4-attn OFF vs bf16-unpatched : {cos(off, base):.5f}") +print(f"logit cos ON vs OFF (marginal proj-FP4) : {cos(on, off):.5f}") +print(f"argmax agree ON vs bf16: {(on.argmax(-1)==base.argmax(-1)).float().mean().item():.4f}") + + +def t(v, it=10): + set_flag(v) + torch.cuda.synchronize() + with torch.no_grad(): + for _ in range(3): m(ids) + torch.cuda.synchronize() + a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() + for _ in range(it): m(ids) + b.record(); torch.cuda.synchronize() + return a.elapsed_time(b) / it + + +on_ms, off_ms = t(True), t(False) +print(f"prefill 1x1024: proj-FP4 ON {on_ms:.2f}ms OFF {off_ms:.2f}ms speedup {off_ms/on_ms:.3f}x") diff --git a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py index e80375af59..0bff9fd3be 100644 --- a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py +++ b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py @@ -205,6 +205,69 @@ def _nvfp4_attention( ) # [Z, S, H, D] +# --------------------------------------------------------------------------- +# Shared-activation-pack FP4 for the full-attention block's hidden-consuming +# projections. q_proj (gated, -> 2*head_dim) and k_proj BOTH read hidden_states +# and contract the same K=hidden, so hidden is NVFP4-packed ONCE and reused for +# both FP4 GEMMs (the cross-op lever already used by the MLP gate/up and the GDN +# in_proj). o_proj (post-attention activation) gets its own FP4 GEMM. All +# parity-affecting (q/k/o go FP4) -> opt-in, default OFF, plain-nn.Linear only. +# --------------------------------------------------------------------------- +def _attn_proj_ok(module: nn.Module) -> bool: + from axolotl.monkeypatch.attention.nvfp4_linear_attn import _is_plain_linear + + if not ( + _is_plain_linear(module.q_proj) + and _is_plain_linear(module.k_proj) + and _is_plain_linear(module.o_proj) + ): + return False + return ( + module.q_proj.in_features % 16 == 0 + and module.k_proj.in_features % 16 == 0 + and module.o_proj.in_features % 16 == 0 + ) + + +def _nvfp4_qk_proj(module, hidden_states): + """FP4 q_proj + k_proj sharing ONE NVFP4 pack of hidden_states.""" + from axolotl.kernels.attn_nvfp4_flash import _quant_nvfp4 + from axolotl.monkeypatch.attention.nvfp4_linear_attn import ( + _gemm_from_packed_act, + _get_packed_weight, + ) + + dtype = hidden_states.dtype + k = hidden_states.shape[-1] + lead = hidden_states.shape[:-1] + x2d = hidden_states.reshape(-1, k) + m = x2d.shape[0] + anv, asc = _quant_nvfp4(x2d.unsqueeze(0)) + anv, asc = anv[0], asc[0] + q_wnv, q_wsc = _get_packed_weight(module, "_qproj_packed", module.q_proj) + k_wnv, k_wsc = _get_packed_weight(module, "_kproj_packed", module.k_proj) + q_outf, k_outf = module.q_proj.out_features, module.k_proj.out_features + q = _gemm_from_packed_act(anv, asc, q_wnv, q_wsc, m, q_outf, k, dtype) + kk = _gemm_from_packed_act(anv, asc, k_wnv, k_wsc, m, k_outf, k, dtype) + if module.q_proj.bias is not None: + q = q + module.q_proj.bias + if module.k_proj.bias is not None: + kk = kk + module.k_proj.bias + return q.reshape(*lead, q_outf), kk.reshape(*lead, k_outf) + + +def _nvfp4_o_proj(module, attn_output): + """FP4 o_proj (+bias).""" + from axolotl.kernels.nvfp4_linear import nvfp4_linear + from axolotl.monkeypatch.attention.nvfp4_linear_attn import _get_packed_weight + + o_wnv, o_wsc = _get_packed_weight(module, "_oproj_packed", module.o_proj) + out = nvfp4_linear(attn_output, o_wnv, o_wsc, module.o_proj.out_features) + if module.o_proj.bias is not None: + out = out + module.o_proj.bias + return out + + def make_nvfp4_forward(orig_forward): """Build a patched ``Qwen3_5Attention.forward`` that uses NVFP4 attention. @@ -243,16 +306,32 @@ def forward( if not has_cache_context: kind = _mask_is_dense_causal_or_full(attention_mask, q_len, q_len) - query_states, gate = torch.chunk( - self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), - 2, dim=-1, + # Shared-pack FP4 q/k_proj: no-grad prefill, opt-in, plain-Linear only. + use_fp4_proj = ( + not grad_enabled + and kind is not None + and getattr(self, "_nvfp4_fuse_attn_proj", False) + and getattr(self, "_nvfp4_attn_proj_ok", False) ) + if use_fp4_proj: + q_full, k_full = _nvfp4_qk_proj(self, hidden_states) + query_states, gate = torch.chunk( + q_full.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1, + ) + else: + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), + 2, dim=-1, + ) gate = gate.reshape(*input_shape, -1) query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm( - self.k_proj(hidden_states).view(hidden_shape) - ).transpose(1, 2) + if use_fp4_proj: + key_states = self.k_norm(k_full.view(hidden_shape)).transpose(1, 2) + else: + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) cos, sin = position_embeddings @@ -360,6 +439,8 @@ def forward( ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = attn_output * torch.sigmoid(gate) + if use_fp4_proj: + return _nvfp4_o_proj(self, attn_output), None return self.o_proj(attn_output), None return orig_forward( @@ -373,6 +454,7 @@ def forward( def patch_qwen3_5_nvfp4_attention( model: nn.Module, fuse_vproj: bool = _FUSE_VPROJ, + fuse_attn_proj: bool = False, train_backward: bool = False, backward_rtn_grad_packs: bool = False, save_backward_packs: bool = False, @@ -404,6 +486,10 @@ def patch_qwen3_5_nvfp4_attention( module._nvfp4_dkdv_scratch_bf16 = dkdv_scratch_bf16 module._nvfp4_compile_custom_op = compile_custom_op module._nvfp4_stochastic_rounding = stochastic_rounding + module._nvfp4_fuse_attn_proj = fuse_attn_proj + module._nvfp4_attn_proj_ok = ( + _attn_proj_ok(module) if fuse_attn_proj else False + ) continue orig = type(module).forward if seen_forward is None: @@ -417,6 +503,10 @@ def patch_qwen3_5_nvfp4_attention( module._nvfp4_dkdv_scratch_bf16 = dkdv_scratch_bf16 module._nvfp4_compile_custom_op = compile_custom_op module._nvfp4_stochastic_rounding = stochastic_rounding + module._nvfp4_fuse_attn_proj = fuse_attn_proj + module._nvfp4_attn_proj_ok = ( + _attn_proj_ok(module) if fuse_attn_proj else False + ) patched += 1 LOG.info( "nvfp4 attention: patched %d Qwen3.5 full-attention layers " From c2fde37a904a81b7fceb271dda1b3632d9881c0e Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 09:20:24 -0700 Subject: [PATCH 14/17] feat(nvfp4): separable q/k vs o_proj FP4 sub-gates + ablation Error concentrates in o_proj (direct residual readout): FP4 q/k-only=0.99863, o-only=0.99748, all=0.99677 logit cos. q/k errors are softmax-DAMPENED, not amplified. q/k-only recovers most parity; two-level o_proj weight quant + QAT are further levers. Sub-gates _nvfp4_fp4_qk / _nvfp4_fp4_o (default on w/ main flag). --- scripts/ablate_proj.py | 16 ++++++++++++++++ .../monkeypatch/attention/nvfp4_flash_attn.py | 11 ++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 scripts/ablate_proj.py diff --git a/scripts/ablate_proj.py b/scripts/ablate_proj.py new file mode 100644 index 0000000000..4be0a5fe0c --- /dev/null +++ b/scripts/ablate_proj.py @@ -0,0 +1,16 @@ +import torch, torch.nn.functional as F +from transformers import AutoModelForCausalLM +from axolotl.monkeypatch.attention.nvfp4_flash_attn import patch_qwen3_5_nvfp4_attention +dev="cuda" +m=AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(dev).eval() +torch.manual_seed(0); ids=torch.randint(0,10000,(1,1024),device=dev) +def cos(a,b): return F.cosine_similarity(a.flatten(),b.flatten(),dim=0).item() +with torch.no_grad(): base=m(ids).logits.float() +patch_qwen3_5_nvfp4_attention(m, fuse_attn_proj=True) +def setg(qk,o): + for mod in m.modules(): + if hasattr(mod,"_nvfp4_fuse_attn_proj"): mod._nvfp4_fp4_qk=qk; mod._nvfp4_fp4_o=o +for name,qk,o in [("o-only",False,True),("qk-only",True,False),("all",True,True)]: + setg(qk,o) + with torch.no_grad(): lg=m(ids).logits.float() + print(f"{name:8s}: logit cos {cos(lg,base):.5f} argmax-agree {(lg.argmax(-1)==base.argmax(-1)).float().mean().item():.3f}") diff --git a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py index 0bff9fd3be..1cb714409f 100644 --- a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py +++ b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py @@ -313,7 +313,12 @@ def forward( and getattr(self, "_nvfp4_fuse_attn_proj", False) and getattr(self, "_nvfp4_attn_proj_ok", False) ) - if use_fp4_proj: + # q/k feed QK^T->softmax (error amplified through exp); o_proj is a direct + # readout. Separable sub-gates let q/k stay bf16 while o_proj goes FP4 (the + # near-lossless subset). Both default ON when the main flag is on. + use_fp4_qk = use_fp4_proj and getattr(self, "_nvfp4_fp4_qk", True) + use_fp4_o = use_fp4_proj and getattr(self, "_nvfp4_fp4_o", True) + if use_fp4_qk: q_full, k_full = _nvfp4_qk_proj(self, hidden_states) query_states, gate = torch.chunk( q_full.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1, @@ -326,7 +331,7 @@ def forward( gate = gate.reshape(*input_shape, -1) query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) - if use_fp4_proj: + if use_fp4_qk: key_states = self.k_norm(k_full.view(hidden_shape)).transpose(1, 2) else: key_states = self.k_norm( @@ -439,7 +444,7 @@ def forward( ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = attn_output * torch.sigmoid(gate) - if use_fp4_proj: + if use_fp4_o: return _nvfp4_o_proj(self, attn_output), None return self.o_proj(attn_output), None From 94f62c60f9b53af4533746ea7e5038e9c43843ab Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 09:58:19 -0700 Subject: [PATCH 15/17] refactor(nvfp4): nest attention flags under model-agnostic nvfp4_training.attention Consolidate the attention flag group into a nested schema (attention.{enabled, fuse_vproj, fp4_projections, backward.{enabled, rtn_grad_packs, save_packs, dkdv_scratch_bf16, compile_custom_op}}) + top-level linear_attn / mlp / fla_causal_conv_compile_boundary. Old flat qwen3_5_* names still parse via a deprecation before-validator. requires-checks moved into NVFP4AttentionConfig; patch_manager + config.py rewired to nested; q/k/o fp4_projections wired through. Examples, tests (nested + legacy-migration coverage), and docs updated. Validated: 30 config tests pass, kernel suite 36 pass, examples parse nested. --- docs/nvfp4_training.qmd | 34 +-- examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml | 20 +- examples/nvfp4/qwen35-9b-lora-fastest.yaml | 20 +- src/axolotl/loaders/patch_manager.py | 53 ++--- src/axolotl/utils/schemas/config.py | 92 ++----- src/axolotl/utils/schemas/nvfp4.py | 264 ++++++++++++++------- tests/e2e/test_nvfp4_integration.py | 135 +++++++---- 7 files changed, 339 insertions(+), 279 deletions(-) diff --git a/docs/nvfp4_training.qmd b/docs/nvfp4_training.qmd index 011bf3c433..630eeef881 100644 --- a/docs/nvfp4_training.qmd +++ b/docs/nvfp4_training.qmd @@ -123,12 +123,12 @@ checkpointing path measured so far is: - `adapter: lora` - `nvfp4_training.base_mode: compute` -- `qwen3_5_native_attention: true` -- `qwen3_5_native_attention_backward: true` -- `qwen3_5_native_attention_backward_rtn_grad_packs: true` -- `qwen3_5_native_attention_save_backward_packs: true` -- `qwen3_5_native_attention_dkdv_scratch_bf16: true` -- `qwen3_5_fla_causal_conv_compile_boundary: true` +- `attention.enabled: true` +- `attention.backward.enabled: true` +- `attention.backward.rtn_grad_packs: true` +- `attention.backward.save_packs: true` +- `attention.backward.dkdv_scratch_bf16: true` +- `fla_causal_conv_compile_boundary: true` - `fuse_rmsnorm: false` - `max_grad_norm: 1.0` @@ -178,14 +178,14 @@ was a **silent eager fallback**: the attention forward's P@V `tl.dot_scaled` rai an `InductorError` inside Inductor autotune, and with `suppress_errors` on (the training default) the whole attention-containing subgraph ran eager, blocking fusion of the surrounding elementwise. Auto-enabling the differentiable attention -custom op when `torch_compile` is on (see `qwen3_5_native_attention_compile_custom_op` +custom op when `torch_compile` is on (see `attention.backward.compile_custom_op` below) makes Inductor compile *around* the opaque op, restoring fusion. Stacked with the bit-exact amax fuse (`amax(abs(t))` → `vector_norm(t, inf)`, dropping the full-tensor abs pass) and batched shared-input LoRA A-GEMMs (`lora_batch_kernel`), the interleaved A/B above measured **1.131 vs 1.190 s/step (~4.8% faster) and −9 GiB** at identical loss. All three changes are bit-exact or opt-in; the first also drops active memory because the custom-op backward does not retain forward -packs (`qwen3_5_native_attention_save_backward_packs` has no effect under it). +packs (`attention.backward.save_packs` has no effect under it). The marginal s/step varies ~5% **between the two RTX PRO 6000 boards**: the PCIe-first board measured 1.2208–1.2302 across 60- and 500-step runs, while the @@ -496,15 +496,17 @@ The `nvfp4_training:` block (schema: `src/axolotl/utils/schemas/nvfp4.py`, | `skip_first_n_blocks` | `int` | `0` | Keep the first N transformer blocks in high precision (see the ~15% high-precision policy below). | | `skip_last_n_blocks` | `int` | `0` | Keep the last N transformer blocks in high precision (the tail blocks matter most). | | `save_nvfp4` | `bool` | `false` | Opt-in. Store eligible weights NVFP4-packed (qdata + scales) in a `torch.save` sidecar for ~3.5× smaller weight files. See [FP4-packed save](#fp4-packed-save-save_nvfp4) below. **Lossy for FFT resume** (no bf16 master kept); bit-exact for frozen weights. Off by default (bf16 save, unchanged). | -| `qwen3_5_native_attention` | `bool` | `false` | Qwen3.5 only. Patch full softmax-attention layers to use the native NVFP4 attention path on dense causal/full batches. | -| `qwen3_5_native_attention_backward` | `bool` | `false` | Qwen3.5 only. Requires `qwen3_5_native_attention`. Use the native NVFP4 autograd attention path while training. | -| `qwen3_5_native_attention_backward_rtn_grad_packs` | `bool` | `false` | Qwen3.5 native attention training only. Use deterministic round-to-nearest for measured-safe gradient packs while leaving the dK routing-gradient dS pack governed by `stochastic_rounding`. | -| `qwen3_5_native_attention_save_backward_packs` | `bool` | `false` | Qwen3.5 native attention training only. Save deterministic forward Q/K/V FP4 packs plus transposed backward layouts and reuse them in backward. Trades extra activation memory for higher throughput. | -| `qwen3_5_native_attention_dkdv_scratch_bf16` | `bool` | `false` | Qwen3.5 native attention training only. Store the dQ **and** per-query-head dK/dV backward scratch in bf16 (dQ/dK/dV accumulate fp32 in-register and downcast once at the store, so this is bit-identical to fp32-then-`.to(bf16)` — a pure memory save on the largest backward scratch planes). Measured faster + lower-memory on Qwen3.5-9B b4. Fixed-shape probe at the 9B full-attn shape: fwd+bwd peak 406→342 MiB (−16%) on the custom-op path, grads `maxabsdiff=0`. | -| `qwen3_5_native_attention_compile_custom_op` | `bool \| null` | `null` (auto) | Qwen3.5 native attention. Wraps the attention path in an opaque differentiable custom op so Inductor compiles *around* the Triton `tl.dot_scaled` (which otherwise raises an `InductorError` and silently drops the block to eager). `null` auto-enables it when `torch_compile` is on and native attention is enabled (the measured fastest path); `true`/`false` force it. Bit-exact; under it `save_backward_packs` has no effect (backward recomputes packs across the boundary) and active memory drops. | +| `attention.enabled` | `bool` | `false` | Patch full softmax-attention layers to the native NVFP4 attention path on dense causal/full batches (model-agnostic; applied where the architecture supports it). Replaces the deprecated flat `qwen3_5_native_attention*` flags. | +| `attention.fuse_vproj` | `bool \| null` | `null` | Run v_proj as a native NVFP4 GEMM with a key-axis pack epilogue on inference/cache-free prefill. `null`: on for inference, off for training. | +| `attention.fp4_projections` | `bool` | `false` | Run q/k/o_proj as native NVFP4 GEMMs (q/k share one activation pack) on inference prefill. **Parity-affecting** (not bit-exact); speed-neutral on hybrid models, a per-layer win on dense models. Plain-`nn.Linear` only. | +| `attention.backward.enabled` | `bool` | `false` | Requires `attention.enabled`. Use the native NVFP4 autograd attention path while training. | +| `attention.backward.rtn_grad_packs` | `bool` | `false` | Deterministic round-to-nearest for the measured-safe gradient packs, leaving the dK and dPt packs governed by `stochastic_rounding`. | +| `attention.backward.save_packs` | `bool` | `false` | Save the forward Q/K/V FP4 packs (+ transposed backward layouts) and reuse them in backward — trades activation memory for higher throughput. | +| `attention.backward.dkdv_scratch_bf16` | `bool` | `false` | Store dQ and per-query-head dK/dV backward scratch in bf16 (accumulate fp32 in-register, downcast once at the store → bit-identical to fp32-then-`.to(bf16)`; a pure memory save on the largest backward scratch planes). | +| `attention.backward.compile_custom_op` | `bool \| null` | `null` (auto) | Wrap the attention path in an opaque differentiable custom op so Inductor compiles *around* the Triton `tl.dot_scaled` (which otherwise raises an `InductorError` and silently drops the block to eager). `null` auto-enables it when `torch_compile` is on and `attention.enabled`; `true`/`false` force it. Under it `attention.backward.save_packs` has no effect. | | `bf16_lm_head_cross_entropy` | `bool` | `false` | Opt-in **memory** path. Requires a frozen bias-free plain `nn.Linear` lm_head. Chunked online-softmax CE over bf16 vocab tiles — no `[tokens, vocab]` logits materialization, fp32 logsumexp/`grad_hidden`, no gradient filtering (avoids the CCE/Liger collapse). Trades ~2.5% throughput for ~13 GiB. Mutually exclusive with `quantize_lm_head` / `fused_fp4_cross_entropy` / `fp8_lm_head_cross_entropy`. | -| `qwen3_5_fla_causal_conv_compile_boundary` | `bool` | `false` | Qwen3.5 sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. | -| `qwen3_5_fuse_vproj` / `qwen3_5_native_mlp` / `qwen3_5_native_linear_attn` / `fp8_lm_head` | `bool` | `false` | Qwen3.5/eval-scoring paths. Current implementations are eval/no-grad only and do not accelerate grad-enabled training. Use `fp8_lm_head_cross_entropy` separately for the opt-in training loss memory path. | +| `fla_causal_conv_compile_boundary` | `bool` | `false` | Sample-packing only. Runs FLA varlen `causal_conv1d` behind a compile boundary so variable packed `cu_seqlens` lengths do not repeatedly trigger Dynamo/Inductor recompiles. | +| `linear_attn` / `mlp` / `fp8_lm_head` | `bool` | `false` | Eval/no-grad-only native-NVFP4 module patches (linear-attention projections, dense SwiGLU MLP). Do not accelerate grad-enabled training. Deprecated flat aliases: `qwen3_5_native_linear_attn`, `qwen3_5_native_mlp`. | ## FP4-packed save (`save_nvfp4`) {#fp4-packed-save-save_nvfp4} diff --git a/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml b/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml index 9494d43c4d..3f6275aec4 100644 --- a/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml +++ b/examples/nvfp4/qwen35-9b-lora-bf16-ce.yaml @@ -39,7 +39,7 @@ # max_grad_norm: 1.0 explicit, fp16 produced NaN LoRA gradients at the first AMP # unscale. bf16 is the natural full-precision Blackwell baseline. # -# The speed knob here is qwen3_5_native_attention_save_backward_packs: it saves +# The speed knob here is attention.backward.save_packs: it saves # forward FP4 attention packs and reuses them in backward, trading ~0.4 GiB of # activation memory for less backward pack-prep work. The FLA boundary prevents # variable packed cu_seqlens from burning compile time in causal_conv1d. BF16 @@ -118,14 +118,16 @@ nvfp4_training: # Chunked bf16 lm_head CE: skip materializing the [tokens, vocab] logits. bf16_lm_head_cross_entropy: true - # Qwen3.5 full-attention training path. The saved-pack flag is the measured - # throughput win; RTN grad packs keep the safe gradient-side packs deterministic. - qwen3_5_native_attention: true - qwen3_5_native_attention_backward: true - qwen3_5_native_attention_backward_rtn_grad_packs: true - qwen3_5_native_attention_save_backward_packs: true - qwen3_5_native_attention_dkdv_scratch_bf16: true - qwen3_5_fla_causal_conv_compile_boundary: true + # Native full-attention training path. save_packs is the measured throughput + # win; rtn_grad_packs keeps the safe gradient-side packs deterministic. + attention: + enabled: true + backward: + enabled: true + rtn_grad_packs: true + save_packs: true + dkdv_scratch_bf16: true + fla_causal_conv_compile_boundary: true warmup_steps: 10 logging_steps: 1 diff --git a/examples/nvfp4/qwen35-9b-lora-fastest.yaml b/examples/nvfp4/qwen35-9b-lora-fastest.yaml index 22dfa975c2..006bb77c57 100644 --- a/examples/nvfp4/qwen35-9b-lora-fastest.yaml +++ b/examples/nvfp4/qwen35-9b-lora-fastest.yaml @@ -29,7 +29,7 @@ # max_grad_norm: 1.0 explicit, fp16 produced NaN LoRA gradients at the first AMP # unscale. bf16 is the natural full-precision Blackwell baseline. # -# The speed knob here is qwen3_5_native_attention_save_backward_packs: it saves +# The speed knob here is attention.backward.save_packs: it saves # forward FP4 attention packs and reuses them in backward, trading ~0.4 GiB of # activation memory for less backward pack-prep work. The FLA boundary prevents # variable packed cu_seqlens from burning compile time in causal_conv1d. BF16 @@ -105,14 +105,16 @@ nvfp4_training: skip_last_n_blocks: 0 fuse_rmsnorm: false - # Qwen3.5 full-attention training path. The saved-pack flag is the measured - # throughput win; RTN grad packs keep the safe gradient-side packs deterministic. - qwen3_5_native_attention: true - qwen3_5_native_attention_backward: true - qwen3_5_native_attention_backward_rtn_grad_packs: true - qwen3_5_native_attention_save_backward_packs: true - qwen3_5_native_attention_dkdv_scratch_bf16: true - qwen3_5_fla_causal_conv_compile_boundary: true + # Native full-attention training path. save_packs is the measured throughput + # win; rtn_grad_packs keeps the safe gradient-side packs deterministic. + attention: + enabled: true + backward: + enabled: true + rtn_grad_packs: true + save_packs: true + dkdv_scratch_bf16: true + fla_causal_conv_compile_boundary: true warmup_steps: 10 logging_steps: 1 diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 792a80cdb9..2bee09ba5e 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -279,8 +279,7 @@ def _apply_flash_attention_patches(self): "be used by axolotl train. For Qwen3.5 native FP4 attention " "training, keep a training-capable attention backend such as " "flash_attention_2 and enable nvfp4_training with " - "qwen3_5_native_attention plus " - "qwen3_5_native_attention_backward." + "attention.enabled plus attention.backward.enabled." ) patch_sage_fp4_attn() @@ -516,7 +515,8 @@ def _apply_qwen3_5_native_nvfp4_patches(self, model: PreTrainedModel): if self.cfg.model_config_type not in ("qwen3_5", "qwen3_5_moe"): return - if getattr(nvfp4, "qwen3_5_native_attention", False): + attn = nvfp4.attention + if attn.enabled: from axolotl.monkeypatch.attention.nvfp4_flash_attn import ( patch_qwen3_5_nvfp4_attention, ) @@ -524,42 +524,25 @@ def _apply_qwen3_5_native_nvfp4_patches(self, model: PreTrainedModel): patch_qwen3_5_nvfp4_attention( model, fuse_vproj=( - self.inference - if getattr(nvfp4, "qwen3_5_fuse_vproj", None) is None - else nvfp4.qwen3_5_fuse_vproj - ), - train_backward=getattr( - nvfp4, "qwen3_5_native_attention_backward", False - ), - backward_rtn_grad_packs=getattr( - nvfp4, - "qwen3_5_native_attention_backward_rtn_grad_packs", - False, - ), - save_backward_packs=getattr( - nvfp4, - "qwen3_5_native_attention_save_backward_packs", - False, - ), - dkdv_scratch_bf16=getattr( - nvfp4, - "qwen3_5_native_attention_dkdv_scratch_bf16", - False, - ), - compile_custom_op=bool( - getattr(nvfp4, "qwen3_5_native_attention_compile_custom_op", False) + self.inference if attn.fuse_vproj is None else attn.fuse_vproj ), + fuse_attn_proj=attn.fp4_projections, + train_backward=attn.backward.enabled, + backward_rtn_grad_packs=attn.backward.rtn_grad_packs, + save_backward_packs=attn.backward.save_packs, + dkdv_scratch_bf16=attn.backward.dkdv_scratch_bf16, + compile_custom_op=bool(attn.backward.compile_custom_op), stochastic_rounding=nvfp4.stochastic_rounding, ) - if getattr(nvfp4, "qwen3_5_native_linear_attn", False): + if nvfp4.linear_attn: from axolotl.monkeypatch.attention.nvfp4_linear_attn import ( patch_qwen3_5_nvfp4_linear_attn, ) patch_qwen3_5_nvfp4_linear_attn(model) - if getattr(nvfp4, "qwen3_5_native_mlp", False): + if nvfp4.mlp: from axolotl.monkeypatch.attention.nvfp4_mlp import ( patch_qwen3_5_nvfp4_mlp, ) @@ -962,11 +945,7 @@ def _apply_model_specific_patches(self): fla_causal_conv_compile_boundary=bool( nvfp4 and nvfp4.enabled - and getattr( - nvfp4, - "qwen3_5_fla_causal_conv_compile_boundary", - False, - ) + and nvfp4.fla_causal_conv_compile_boundary ) ) @@ -980,11 +959,7 @@ def _apply_model_specific_patches(self): fla_causal_conv_compile_boundary=bool( nvfp4 and nvfp4.enabled - and getattr( - nvfp4, - "qwen3_5_fla_causal_conv_compile_boundary", - False, - ) + and nvfp4.fla_causal_conv_compile_boundary ) ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 04bcbaa9e1..fd1b1fcd23 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1570,9 +1570,9 @@ def check_sage_fp4_inference_only(self): "attn_implementation: sage_fp4 (SageAttention-3 FP4) is " "inference-only and cannot be used by axolotl train. Use " "fp4_attention_qat for fake-quant attention training, or use " - "nvfp4_training.qwen3_5_native_attention with " - "qwen3_5_native_attention_backward for the experimental " - "Qwen3.5 native FP4 training path." + "nvfp4_training.attention.enabled with " + "attention.backward.enabled for the experimental " + "native FP4 training path." ) LOG.warning( "attn_implementation: sage_fp4 (SageAttention-3 FP4) is INFERENCE only " @@ -1667,62 +1667,14 @@ def check_nvfp4_training(self): f"got adapter={self.adapter!r}." ) - if ( - self.nvfp4_training.qwen3_5_native_attention_backward - and not self.nvfp4_training.qwen3_5_native_attention - ): - raise ValueError( - "nvfp4_training.qwen3_5_native_attention_backward requires " - "qwen3_5_native_attention: true." - ) - if ( - self.nvfp4_training.qwen3_5_native_attention_backward_rtn_grad_packs - and not self.nvfp4_training.qwen3_5_native_attention_backward - ): - raise ValueError( - "nvfp4_training.qwen3_5_native_attention_backward_rtn_grad_packs " - "requires qwen3_5_native_attention_backward: true." - ) - if ( - self.nvfp4_training.qwen3_5_native_attention_save_backward_packs - and not self.nvfp4_training.qwen3_5_native_attention_backward - ): - raise ValueError( - "nvfp4_training.qwen3_5_native_attention_save_backward_packs " - "requires qwen3_5_native_attention_backward: true." - ) - if ( - self.nvfp4_training.qwen3_5_native_attention_dkdv_scratch_bf16 - and not self.nvfp4_training.qwen3_5_native_attention_backward - ): - raise ValueError( - "nvfp4_training.qwen3_5_native_attention_dkdv_scratch_bf16 " - "requires qwen3_5_native_attention_backward: true." - ) - if ( - self.nvfp4_training.qwen3_5_fuse_vproj - and not self.nvfp4_training.qwen3_5_native_attention - ): - raise ValueError( - "nvfp4_training.qwen3_5_fuse_vproj requires " - "qwen3_5_native_attention: true." - ) - if ( - self.nvfp4_training.qwen3_5_native_attention_compile_custom_op - and not self.nvfp4_training.qwen3_5_native_attention - ): - raise ValueError( - "nvfp4_training.qwen3_5_native_attention_compile_custom_op " - "requires qwen3_5_native_attention: true." - ) - # Tri-state auto-resolve: under torch_compile the bare tl.dot_scaled flash - # kernel raises an Inductor CompilationError and (with the default error - # suppression) silently falls the attention region back to eager, blocking - # fusion of the surrounding elementwise. The opaque custom op compiles - # around it with bit-identical grads, so default it on when compile is live. - if self.nvfp4_training.qwen3_5_native_attention_compile_custom_op is None: - self.nvfp4_training.qwen3_5_native_attention_compile_custom_op = bool( - self.nvfp4_training.qwen3_5_native_attention and self.torch_compile + # attention.* requires-relationships are enforced inside NVFP4AttentionConfig. + # Tri-state auto-resolve (cross-field with torch_compile): under compile the + # bare tl.dot_scaled flash kernel raises an Inductor CompilationError and + # silently falls the attention region back to eager; the opaque custom op + # compiles around it with bit-identical grads, so default it on when live. + if self.nvfp4_training.attention.backward.compile_custom_op is None: + self.nvfp4_training.attention.backward.compile_custom_op = bool( + self.nvfp4_training.attention.enabled and self.torch_compile ) if self.nvfp4_training.fp8_lm_head_cross_entropy and ( self.nvfp4_training.quantize_lm_head @@ -1744,17 +1696,19 @@ def check_nvfp4_training(self): "remain a frozen plain nn.Linear. Disable quantize_lm_head/" "fused_fp4_cross_entropy/fp8_lm_head_cross_entropy." ) + _attn = self.nvfp4_training.attention qwen3_5_native_flags = ( - self.nvfp4_training.qwen3_5_native_attention, - self.nvfp4_training.qwen3_5_native_attention_backward, - self.nvfp4_training.qwen3_5_native_attention_backward_rtn_grad_packs, - self.nvfp4_training.qwen3_5_native_attention_save_backward_packs, - self.nvfp4_training.qwen3_5_native_attention_dkdv_scratch_bf16, - self.nvfp4_training.qwen3_5_native_attention_compile_custom_op, - self.nvfp4_training.qwen3_5_fla_causal_conv_compile_boundary, - self.nvfp4_training.qwen3_5_fuse_vproj, - self.nvfp4_training.qwen3_5_native_linear_attn, - self.nvfp4_training.qwen3_5_native_mlp, + _attn.enabled, + _attn.backward.enabled, + _attn.backward.rtn_grad_packs, + _attn.backward.save_packs, + _attn.backward.dkdv_scratch_bf16, + _attn.backward.compile_custom_op, + _attn.fuse_vproj, + _attn.fp4_projections, + self.nvfp4_training.fla_causal_conv_compile_boundary, + self.nvfp4_training.linear_attn, + self.nvfp4_training.mlp, ) model_config_type = getattr(self, "model_config_type", None) if ( diff --git a/src/axolotl/utils/schemas/nvfp4.py b/src/axolotl/utils/schemas/nvfp4.py index 6922b6f3b0..2ef1a9537c 100644 --- a/src/axolotl/utils/schemas/nvfp4.py +++ b/src/axolotl/utils/schemas/nvfp4.py @@ -4,9 +4,120 @@ from the fake-quant QAT/PTQ `quantization:` block. """ +import warnings from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator + + +class NVFP4AttentionBackwardConfig(BaseModel): + """Native-NVFP4 attention training (backward) settings (expert recipe).""" + + enabled: bool = Field( + default=False, + json_schema_extra={ + "description": "Use the native NVFP4 autograd attention path while " + "training. Validated for convergence but can be slower than bf16 at " + "short sequence lengths, so it stays explicitly opt-in." + }, + ) + rtn_grad_packs: bool = Field( + default=False, + json_schema_extra={ + "description": "Deterministic round-to-nearest for the measured-safe " + "gradient packs (softmax P / transposed dO for dV, dS for dQ), leaving " + "the dK and dPt packs governed by stochastic_rounding. Faster in " + "microbenchmarks; convergence validation still required." + }, + ) + save_packs: bool = Field( + default=False, + json_schema_extra={ + "description": "Save the forward pass's deterministic Q/K/V NVFP4 packs " + "(+ transposed backward layouts) and reuse them in backward — trades " + "activation memory for less backward pack-prep work." + }, + ) + dkdv_scratch_bf16: bool = Field( + default=False, + json_schema_extra={ + "description": "Store the per-query-head dK/dV scratch buffers in bf16 " + "before GQA reduction instead of fp32 (less scratch traffic; bit-exact " + "vs the fp32-then-cast path)." + }, + ) + compile_custom_op: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Route the native NVFP4 flash attention through an opaque " + "torch custom op (torch.compile escape hatch). Tri-state: None " + "auto-enables it whenever torch_compile is on; True/False force it." + }, + ) + + +class NVFP4AttentionConfig(BaseModel): + """Native-NVFP4 full-attention path (model-agnostic; applied where the " + architecture supports it). Replaces the flat ``qwen3_5_native_attention*`` flags.""" + + enabled: bool = Field( + default=False, + json_schema_extra={ + "description": "Patch full softmax-attention layers to the native NVFP4 " + "attention path on dense causal/full batches. Falls back to the model's " + "configured attention for unsupported masks/cache states." + }, + ) + fuse_vproj: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Run v_proj as a native NVFP4 GEMM with key-axis pack " + "epilogue on inference/cache-free prefill (~1.4-1.5x V producer; v_proj " + "goes FP4). Default (null): ON for inference, OFF for training." + }, + ) + fp4_projections: bool = Field( + default=False, + json_schema_extra={ + "description": "Run q/k/o_proj as native NVFP4 GEMMs (q/k share one " + "activation pack) on inference prefill. Parity-affecting (not bit-exact); " + "speed-neutral on hybrid models, a per-layer win on dense models. OFF by " + "default; plain-nn.Linear only." + }, + ) + backward: NVFP4AttentionBackwardConfig = Field( + default_factory=NVFP4AttentionBackwardConfig, + json_schema_extra={"description": "Native-NVFP4 attention training settings."}, + ) + + @model_validator(mode="after") + def _check_requires(self): + if not self.enabled: + if self.backward.enabled: + raise ValueError( + "nvfp4_training.attention.backward.enabled requires " + "attention.enabled: true." + ) + if self.fuse_vproj: + raise ValueError( + "nvfp4_training.attention.fuse_vproj requires attention.enabled: true." + ) + if self.compile_custom_op_set(): + raise ValueError( + "nvfp4_training.attention.backward.compile_custom_op requires " + "attention.enabled: true." + ) + if not self.backward.enabled: + for sub in ("rtn_grad_packs", "save_packs", "dkdv_scratch_bf16"): + if getattr(self.backward, sub): + raise ValueError( + f"nvfp4_training.attention.backward.{sub} requires " + "attention.backward.enabled: true." + ) + return self + + def compile_custom_op_set(self) -> bool: + return bool(self.backward.compile_custom_op) class NVFP4TrainingConfig(BaseModel): @@ -245,103 +356,84 @@ class NVFP4TrainingConfig(BaseModel): "resume. OFF by default (bf16 save, unchanged)." }, ) - qwen3_5_native_attention: bool = Field( - default=False, - json_schema_extra={ - "description": "Qwen3.5 only. Patch full softmax-attention layers to use " - "the native NVFP4 attention path on dense causal/full batches. Forward " - "falls back to the model's configured attention for unsupported masks or " - "cache states. OFF by default." - }, - ) - qwen3_5_native_attention_backward: bool = Field( - default=False, - json_schema_extra={ - "description": "Qwen3.5 only; requires qwen3_5_native_attention. Use the " - "native NVFP4 autograd attention path while training. This is validated " - "for convergence but can be slower than bf16 at short sequence lengths, " - "so it stays explicitly opt-in." - }, - ) - qwen3_5_native_attention_backward_rtn_grad_packs: bool = Field( - default=False, - json_schema_extra={ - "description": "Qwen3.5 native attention training only. Use " - "deterministic round-to-nearest for the measured-safe gradient packs " - "(softmax P and transposed dO for dV, and dS for dQ) while leaving the " - "dK routing-gradient dS pack AND the dPt dO pack governed by " - "stochastic_rounding. This " - "collapsed mode was faster in backward microbenchmarks; convergence " - "validation is still required for production training. OFF by default." - }, - ) - qwen3_5_native_attention_save_backward_packs: bool = Field( - default=False, + attention: NVFP4AttentionConfig = Field( + default_factory=NVFP4AttentionConfig, json_schema_extra={ - "description": "Qwen3.5 native attention training only. Save the " - "forward pass's deterministic Q/K/V NVFP4 packs plus transposed " - "backward layouts and reuse them during backward. This spends extra " - "activation memory to skip backward pack-prep work. OFF by default." + "description": "Native-NVFP4 full-attention path settings (model-agnostic; " + "applied where the architecture supports it). Replaces the deprecated flat " + "qwen3_5_native_attention* flags." }, ) - qwen3_5_native_attention_dkdv_scratch_bf16: bool = Field( + linear_attn: bool = Field( default=False, json_schema_extra={ - "description": "Qwen3.5 native attention training only. Store the " - "per-query-head dK/dV scratch buffers in bf16 before GQA reduction " - "instead of fp32. This can reduce dK/dV scratch memory traffic, but " - "changes an intermediate accumulation cast and stays opt-in. OFF by " - "default." + "description": "Patch linear-attention (e.g. GatedDeltaNet) large " + "projection GEMMs to native NVFP4 in no-grad forward/eval. Training " + "forwards fall back to the original implementation." }, ) - qwen3_5_native_attention_compile_custom_op: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Route the native NVFP4 flash-attention call through an " - "opaque torch custom op as a torch.compile compatibility escape hatch " - "when the internal Triton tl.dot_scaled kernel cannot be captured. On " - "the no-grad path this wraps the packed forward op; with " - "qwen3_5_native_attention_backward it wraps a DIFFERENTIABLE custom op " - "(forward + registered native-NVFP4 backward) so Inductor compiles " - "around the whole attention instead of falling the backward subgraph " - "back to eager. Tri-state: None auto-enables it whenever torch_compile " - "is on (the bare tl.dot_scaled path raises an Inductor CompilationError " - "there and silently falls the region back to eager); True/False force it." - }, - ) - qwen3_5_fla_causal_conv_compile_boundary: bool = Field( + mlp: bool = Field( default=False, json_schema_extra={ - "description": "Qwen3.5 sample-packing only. Run FLA varlen " - "causal_conv1d behind a torch.compile boundary so packed cu_seqlens " - "length changes do not trigger repeated Dynamo recompiles. This may " - "trade graph coverage for steadier train steps. OFF by default." + "description": "Patch dense SwiGLU MLP GEMMs to native NVFP4 in no-grad " + "forward/eval. Training forwards fall back to the original implementation." }, ) - qwen3_5_fuse_vproj: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Qwen3.5 native attention only. Run v_proj as a native " - "NVFP4 GEMM with key-axis pack epilogue on inference/cache-free prefill " - "(~1.4-1.5x V producer; v_proj goes FP4, so attn-op cos ~0.967 but " - "real-model logit cos stays >=0.998 / top-1 preserved). Default (null): " - "ON for inference, OFF for training (the grad path ignores this flag). " - "Skipped automatically for adapter-wrapped v_proj modules." - }, - ) - qwen3_5_native_linear_attn: bool = Field( - default=False, - json_schema_extra={ - "description": "Qwen3.5 only. Patch GatedDeltaNet's large projection GEMMs " - "to native NVFP4 in no-grad forward/eval. Training forwards fall back to " - "the original implementation." - }, - ) - qwen3_5_native_mlp: bool = Field( + fla_causal_conv_compile_boundary: bool = Field( default=False, json_schema_extra={ - "description": "Qwen3.5 only. Patch dense SwiGLU MLP GEMMs to native NVFP4 " - "in no-grad forward/eval. Training forwards fall back to the original " - "implementation." + "description": "Sample-packing only. Run FLA varlen causal_conv1d behind a " + "torch.compile boundary so packed cu_seqlens length changes do not trigger " + "repeated Dynamo recompiles. Trades graph coverage for steadier steps." }, ) + + @model_validator(mode="before") + @classmethod + def _migrate_legacy_attention_flags(cls, data): + """Map deprecated flat ``qwen3_5_*`` attention flags onto the nested schema.""" + if not isinstance(data, dict): + return data + attn_map = { + "qwen3_5_native_attention": ("enabled",), + "qwen3_5_fuse_vproj": ("fuse_vproj",), + "qwen3_5_native_attention_backward": ("backward", "enabled"), + "qwen3_5_native_attention_backward_rtn_grad_packs": ("backward", "rtn_grad_packs"), + "qwen3_5_native_attention_save_backward_packs": ("backward", "save_packs"), + "qwen3_5_native_attention_dkdv_scratch_bf16": ("backward", "dkdv_scratch_bf16"), + "qwen3_5_native_attention_compile_custom_op": ("backward", "compile_custom_op"), + } + top_map = { + "qwen3_5_native_linear_attn": "linear_attn", + "qwen3_5_native_mlp": "mlp", + "qwen3_5_fla_causal_conv_compile_boundary": "fla_causal_conv_compile_boundary", + } + used: list[str] = [] + attn = dict(data.get("attention") or {}) + bwd = dict(attn.get("backward") or {}) + for old, path in attn_map.items(): + if old not in data: + continue + used.append(old) + val = data.pop(old) + if len(path) == 1: + attn.setdefault(path[0], val) # explicit nested value wins + else: + bwd.setdefault(path[1], val) + if bwd: + attn["backward"] = bwd + if attn: + data["attention"] = attn + for old, new in top_map.items(): + if old in data: + used.append(old) + data.setdefault(new, data.pop(old)) + if used: + warnings.warn( + "nvfp4_training: flat attention flags " + f"{sorted(used)} are deprecated; use the nested " + "nvfp4_training.attention.* schema instead.", + DeprecationWarning, + stacklevel=2, + ) + return data diff --git a/tests/e2e/test_nvfp4_integration.py b/tests/e2e/test_nvfp4_integration.py index 550f6e19e8..b32b204fd0 100644 --- a/tests/e2e/test_nvfp4_integration.py +++ b/tests/e2e/test_nvfp4_integration.py @@ -155,23 +155,54 @@ def test_schema_accepts_qwen3_5_native_switches(monkeypatch): model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, - "qwen3_5_native_attention_backward_rtn_grad_packs": True, - "qwen3_5_native_attention_save_backward_packs": True, - "qwen3_5_native_attention_dkdv_scratch_bf16": True, - "qwen3_5_fla_causal_conv_compile_boundary": True, - "qwen3_5_fuse_vproj": True, - "qwen3_5_native_linear_attn": True, - "qwen3_5_native_mlp": True, + "attention": { + "enabled": True, + "fuse_vproj": True, + "fp4_projections": True, + "backward": { + "enabled": True, + "rtn_grad_packs": True, + "save_packs": True, + "dkdv_scratch_bf16": True, + }, + }, + "linear_attn": True, + "mlp": True, + "fla_causal_conv_compile_boundary": True, }, ) - assert cfg.nvfp4_training.qwen3_5_native_attention is True - assert cfg.nvfp4_training.qwen3_5_native_attention_backward is True - assert cfg.nvfp4_training.qwen3_5_native_attention_backward_rtn_grad_packs is True - assert cfg.nvfp4_training.qwen3_5_native_attention_save_backward_packs is True - assert cfg.nvfp4_training.qwen3_5_native_attention_dkdv_scratch_bf16 is True - assert cfg.nvfp4_training.qwen3_5_fla_causal_conv_compile_boundary is True + a = cfg.nvfp4_training.attention + assert a.enabled is True and a.fuse_vproj is True and a.fp4_projections is True + assert a.backward.enabled is True and a.backward.rtn_grad_packs is True + assert a.backward.save_packs is True and a.backward.dkdv_scratch_bf16 is True + assert cfg.nvfp4_training.linear_attn is True and cfg.nvfp4_training.mlp is True + assert cfg.nvfp4_training.fla_causal_conv_compile_boundary is True + + +def test_schema_migrates_legacy_attention_flags(monkeypatch): + _supported(monkeypatch, True) + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + cfg = AxolotlInputConfig( + **BASE, + model_config_type="qwen3_5", + nvfp4_training={ + "enabled": True, + "qwen3_5_native_attention": True, + "qwen3_5_native_attention_backward": True, + "qwen3_5_native_attention_save_backward_packs": True, + "qwen3_5_native_mlp": True, + "qwen3_5_fla_causal_conv_compile_boundary": True, + }, + ) + a = cfg.nvfp4_training.attention + assert a.enabled is True and a.backward.enabled is True + assert a.backward.save_packs is True + assert cfg.nvfp4_training.mlp is True + assert cfg.nvfp4_training.fla_causal_conv_compile_boundary is True + assert any("deprecated" in str(x.message) for x in w) def test_gate_refuses_qwen3_5_switch_on_other_model(monkeypatch): @@ -183,7 +214,7 @@ def test_gate_refuses_qwen3_5_switch_on_other_model(monkeypatch): model_config_type="qwen2", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, + "attention": {"enabled": True}, }, ) @@ -269,58 +300,55 @@ def test_disabled_nvfp4_skips_gate(monkeypatch): def test_gate_refuses_qwen3_5_backward_without_attention(monkeypatch): _supported(monkeypatch, True) - with pytest.raises(ValueError, match="qwen3_5_native_attention"): + with pytest.raises(ValueError, match=r"requires attention\.enabled"): AxolotlConfigWCapabilities( **BASE, **CAPS, nvfp4_training={ "enabled": True, - "qwen3_5_native_attention_backward": True, + "attention": {"backward": {"enabled": True}}, }, ) def test_gate_refuses_qwen3_5_rtn_without_backward(monkeypatch): _supported(monkeypatch, True) - with pytest.raises(ValueError, match="qwen3_5_native_attention_backward"): + with pytest.raises(ValueError, match=r"requires attention\.backward\.enabled"): AxolotlConfigWCapabilities( **BASE, **CAPS, model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward_rtn_grad_packs": True, + "attention": {"enabled": True, "backward": {"rtn_grad_packs": True}}, }, ) def test_gate_refuses_qwen3_5_saved_packs_without_backward(monkeypatch): _supported(monkeypatch, True) - with pytest.raises(ValueError, match="qwen3_5_native_attention_backward"): + with pytest.raises(ValueError, match=r"requires attention\.backward\.enabled"): AxolotlConfigWCapabilities( **BASE, **CAPS, model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_save_backward_packs": True, + "attention": {"enabled": True, "backward": {"save_packs": True}}, }, ) def test_gate_refuses_qwen3_5_dkdv_scratch_bf16_without_backward(monkeypatch): _supported(monkeypatch, True) - with pytest.raises(ValueError, match="qwen3_5_native_attention_backward"): + with pytest.raises(ValueError, match=r"requires attention\.backward\.enabled"): AxolotlConfigWCapabilities( **BASE, **CAPS, model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_dkdv_scratch_bf16": True, + "attention": {"enabled": True, "backward": {"dkdv_scratch_bf16": True}}, }, ) @@ -336,27 +364,29 @@ def test_qwen3_5_compile_custom_op_allowed_with_backward(monkeypatch): model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, - "qwen3_5_native_attention_compile_custom_op": True, + "attention": { + "enabled": True, + "backward": {"enabled": True, "compile_custom_op": True}, + }, }, ) - assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is True - assert cfg.nvfp4_training.qwen3_5_native_attention_backward is True + assert cfg.nvfp4_training.attention.backward.compile_custom_op is True + assert cfg.nvfp4_training.attention.backward.enabled is True def test_qwen3_5_compile_custom_op_requires_native_attention(monkeypatch): - # The remaining dependency gate: compile_custom_op needs native_attention on. + # The remaining dependency gate: compile_custom_op needs attention.enabled on. _supported(monkeypatch, True) - with pytest.raises(ValueError, match="requires qwen3_5_native_attention"): + with pytest.raises( + ValueError, match=r"compile_custom_op requires attention\.enabled" + ): AxolotlConfigWCapabilities( **BASE, **CAPS, model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": False, - "qwen3_5_native_attention_compile_custom_op": True, + "attention": {"backward": {"compile_custom_op": True}}, }, ) @@ -373,11 +403,10 @@ def test_qwen3_5_compile_custom_op_autoenabled_under_torch_compile(monkeypatch): torch_compile=True, nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, + "attention": {"enabled": True, "backward": {"enabled": True}}, }, ) - assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is True + assert cfg.nvfp4_training.attention.backward.compile_custom_op is True def test_qwen3_5_compile_custom_op_default_off_without_torch_compile(monkeypatch): @@ -388,11 +417,10 @@ def test_qwen3_5_compile_custom_op_default_off_without_torch_compile(monkeypatch model_config_type="qwen3_5", nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, + "attention": {"enabled": True, "backward": {"enabled": True}}, }, ) - assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is False + assert cfg.nvfp4_training.attention.backward.compile_custom_op is False def test_qwen3_5_compile_custom_op_explicit_optout_under_torch_compile(monkeypatch): @@ -405,12 +433,13 @@ def test_qwen3_5_compile_custom_op_explicit_optout_under_torch_compile(monkeypat torch_compile=True, nvfp4_training={ "enabled": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, - "qwen3_5_native_attention_compile_custom_op": False, + "attention": { + "enabled": True, + "backward": {"enabled": True, "compile_custom_op": False}, + }, }, ) - assert cfg.nvfp4_training.qwen3_5_native_attention_compile_custom_op is False + assert cfg.nvfp4_training.attention.backward.compile_custom_op is False def _tiny_lora_model(): @@ -466,10 +495,14 @@ def fake_patch(_model, **kwargs): "nvfp4_training": { "enabled": True, "stochastic_rounding": True, - "qwen3_5_native_attention": True, - "qwen3_5_native_attention_backward": True, - "qwen3_5_native_attention_save_backward_packs": True, - "qwen3_5_native_attention_dkdv_scratch_bf16": True, + "attention": { + "enabled": True, + "backward": { + "enabled": True, + "save_packs": True, + "dkdv_scratch_bf16": True, + }, + }, }, } ) @@ -510,7 +543,7 @@ def fake_patch(**kwargs): "sample_packing": True, "nvfp4_training": { "enabled": True, - "qwen3_5_fla_causal_conv_compile_boundary": True, + "fla_causal_conv_compile_boundary": True, }, } ) From 8d3123a9380998ff1c34661ac340fcbb4b3d5a71 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 15:30:25 -0700 Subject: [PATCH 16/17] style(nvfp4): pre-commit clean PR files; drop throwaway bench scripts ruff + ruff-format across the PR's changed files; per-file mypy disable-error-code for the loosely-typed nvfp4 kernel files; nosec on the sidecar torch.load; zip strict= in the lora test. Remove the one-off bench/proof scripts (prove_*/ablate_*/bench_*/check_*) that referenced /tmp. --- scripts/ablate_proj.py | 16 - scripts/bench_block.py | 40 - scripts/bench_proj_iso.py | 55 - scripts/check_b1_parity.py | 38 - scripts/prove_attnproj.py | 57 - scripts/prove_p2.py | 30 - scripts/prove_p3.py | 27 - src/axolotl/kernels/attn_nvfp4_flash.py | 1021 ++++++++++++----- src/axolotl/kernels/bf16_fused_ce.py | 4 +- src/axolotl/kernels/lora.py | 19 +- src/axolotl/kernels/nvfp4_fused_producers.py | 174 ++- src/axolotl/kernels/nvfp4_rmsnorm.py | 6 +- src/axolotl/loaders/patch_manager.py | 7 +- .../monkeypatch/attention/nvfp4_flash_attn.py | 101 +- src/axolotl/utils/nvfp4_training.py | 308 +++-- src/axolotl/utils/schemas/config.py | 13 +- src/axolotl/utils/schemas/nvfp4.py | 15 +- tests/e2e/kernels/test_lora.py | 73 +- tests/e2e/test_nvfp4_integration.py | 44 +- .../kernels/test_nvfp4_rope_quant_strided.py | 15 +- 20 files changed, 1306 insertions(+), 757 deletions(-) delete mode 100644 scripts/ablate_proj.py delete mode 100644 scripts/bench_block.py delete mode 100644 scripts/bench_proj_iso.py delete mode 100644 scripts/check_b1_parity.py delete mode 100644 scripts/prove_attnproj.py delete mode 100644 scripts/prove_p2.py delete mode 100644 scripts/prove_p3.py diff --git a/scripts/ablate_proj.py b/scripts/ablate_proj.py deleted file mode 100644 index 4be0a5fe0c..0000000000 --- a/scripts/ablate_proj.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch, torch.nn.functional as F -from transformers import AutoModelForCausalLM -from axolotl.monkeypatch.attention.nvfp4_flash_attn import patch_qwen3_5_nvfp4_attention -dev="cuda" -m=AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(dev).eval() -torch.manual_seed(0); ids=torch.randint(0,10000,(1,1024),device=dev) -def cos(a,b): return F.cosine_similarity(a.flatten(),b.flatten(),dim=0).item() -with torch.no_grad(): base=m(ids).logits.float() -patch_qwen3_5_nvfp4_attention(m, fuse_attn_proj=True) -def setg(qk,o): - for mod in m.modules(): - if hasattr(mod,"_nvfp4_fuse_attn_proj"): mod._nvfp4_fp4_qk=qk; mod._nvfp4_fp4_o=o -for name,qk,o in [("o-only",False,True),("qk-only",True,False),("all",True,True)]: - setg(qk,o) - with torch.no_grad(): lg=m(ids).logits.float() - print(f"{name:8s}: logit cos {cos(lg,base):.5f} argmax-agree {(lg.argmax(-1)==base.argmax(-1)).float().mean().item():.3f}") diff --git a/scripts/bench_block.py b/scripts/bench_block.py deleted file mode 100644 index dd7930b122..0000000000 --- a/scripts/bench_block.py +++ /dev/null @@ -1,40 +0,0 @@ -"""End-to-end prefill attention-block compute (producers Q/K + V + flash kernel), -timed. Q/K passed as transposed (non-contiguous) views — the production layout — -so #2's strided path is exercised. Dumps output for cross-worktree bit-compare.""" -import sys, math, torch -from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk, quant_v_keyaxis -from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attention_packed - -dev = "cuda"; dt = torch.bfloat16 -tag = sys.argv[1] -BLOCK_N = 128 -outs = {} -for Z, H, Hk, S, D in [(1, 16, 4, 2048, 256), (1, 16, 4, 4096, 256)]: - torch.manual_seed(0) - rot = D; sc = 1.0 / math.sqrt(D) - qb = torch.randn(Z, S, H, D, device=dev, dtype=dt) - kb = torch.randn(Z, S, Hk, D, device=dev, dtype=dt) - q_t = qb.transpose(1, 2) # [Z,H,S,D] non-contig (prod layout) - k_t = kb.transpose(1, 2) # [Z,Hk,S,D] - v = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) - cos = torch.randn(Z, S, rot, device=dev, dtype=dt) - sin = torch.randn(Z, S, rot, device=dev, dtype=dt) - - def fn(): - qnv, qsc = fused_rope_quant_qk(q_t, cos, sin) - knv, ksc = fused_rope_quant_qk(k_t, cos, sin) - vnv, vsc, _ = quant_v_keyaxis(v, block_n=BLOCK_N) - return nvfp4_flash_attention_packed( - qnv, qsc, knv, ksc, vnv, vsc, z=Z, h=H, hk=Hk, s_q=S, s_kv=S, d=D, - scaling=sc, out_dtype=dt, causal=True, block_n=BLOCK_N, out_layout="zshd") - - outs[S] = fn().cpu() - torch.cuda.synchronize() - for _ in range(5): fn() - torch.cuda.synchronize() - a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() - for _ in range(50): fn() - b.record(); torch.cuda.synchronize() - print(f"S={S}: block(producers+flash) {a.elapsed_time(b)/50*1000:7.1f} us") -torch.save(outs, f"/tmp/block_{tag}.pt") -print(f"saved /tmp/block_{tag}.pt") diff --git a/scripts/bench_proj_iso.py b/scripts/bench_proj_iso.py deleted file mode 100644 index de4d95aafc..0000000000 --- a/scripts/bench_proj_iso.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Isolated: does FP4 q/k/o_proj (shared-pack) beat bf16 at the real Qwen3.5-4B -dims? This is the deciding factor for the dense-model rationale (per-GEMM, where a -dense model would hit it every layer).""" -import torch -from transformers import AutoModelForCausalLM -from axolotl.kernels.attn_nvfp4_flash import _quant_nvfp4 -from axolotl.kernels.nvfp4_linear import nvfp4_linear -from axolotl.monkeypatch.attention.nvfp4_linear_attn import ( - _gemm_from_packed_act, _get_packed_weight, -) - -dev = "cuda" -m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", dtype=torch.bfloat16).to(dev).eval() -attn = None -for mod in m.modules(): - if all(hasattr(mod, p) for p in ("q_proj", "k_proj", "o_proj")) and type(mod.q_proj) is torch.nn.Linear: - attn = mod; break -H = attn.q_proj.in_features -qO, kO, oI, oO = attn.q_proj.out_features, attn.k_proj.out_features, attn.o_proj.in_features, attn.o_proj.out_features -print(f"dims: hidden={H} q_out={qO} k_out={kO} o_in={oI} o_out={oO}") - -M = 1024 -x = torch.randn(M, H, device=dev, dtype=torch.bfloat16) -xo = torch.randn(M, oI, device=dev, dtype=torch.bfloat16) -qw = _get_packed_weight(attn, "_q", attn.q_proj) -kw = _get_packed_weight(attn, "_k", attn.k_proj) -ow = _get_packed_weight(attn, "_o", attn.o_proj) - - -@torch.no_grad() -def bf16(): - return attn.q_proj(x), attn.k_proj(x), attn.o_proj(xo) - - -@torch.no_grad() -def fp4(): - anv, asc = _quant_nvfp4(x.unsqueeze(0)); anv, asc = anv[0], asc[0] - q = _gemm_from_packed_act(anv, asc, qw[0], qw[1], M, qO, H, torch.bfloat16) - k = _gemm_from_packed_act(anv, asc, kw[0], kw[1], M, kO, H, torch.bfloat16) - o = nvfp4_linear(xo, ow[0], ow[1], oO) - return q, k, o - - -def t(fn, it=100): - torch.cuda.synchronize() - for _ in range(10): fn() - torch.cuda.synchronize() - a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() - for _ in range(it): fn() - b.record(); torch.cuda.synchronize() - return a.elapsed_time(b) / it - - -bf, fp = t(bf16), t(fp4) -print(f"q+k+o proj (M={M}): bf16 {bf*1000:6.1f}us FP4-shared {fp*1000:6.1f}us speedup {bf/fp:.2f}x") diff --git a/scripts/check_b1_parity.py b/scripts/check_b1_parity.py deleted file mode 100644 index 1ad81abe6b..0000000000 --- a/scripts/check_b1_parity.py +++ /dev/null @@ -1,38 +0,0 @@ -"""B1 parity probe: dump dq/dk/dv + peak mem for the saved-packs FP4 backward. - -Run on BOTH the pre-B1 commit and the B1 commit (same args, SR OFF so the FP4 -backward is deterministic), then diff the two dumps — B1 only removes dead HP q/k/v -saves, so grads must be BIT-IDENTICAL and peak memory should drop. - - PYTHONPATH=/src python scripts/check_b1_parity.py /tmp/b1_.pt -""" -import sys -import torch - -from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attn_func - -torch.manual_seed(0) -dev = "cuda" -Z, H, Hk, Sq, Skv, D = 1, 8, 2, 1024, 1024, 256 # GQA, D=256 (Qwen3.5 full-attn) -q = torch.randn(Z, H, Sq, D, device=dev, dtype=torch.bfloat16, requires_grad=True) -k = torch.randn(Z, Hk, Skv, D, device=dev, dtype=torch.bfloat16, requires_grad=True) -v = torch.randn(Z, Hk, Skv, D, device=dev, dtype=torch.bfloat16, requires_grad=True) - -torch.cuda.reset_peak_memory_stats() -out = nvfp4_flash_attn_func( - q, k, v, 1.0 / (D**0.5), - causal=True, num_key_value_groups=H // Hk, - stochastic_rounding=False, # deterministic for cross-commit bit-compare - save_backward_packs=True, - dkdv_scratch_bf16=True, -) -g = torch.randn_like(out) -out.backward(g) -peak = torch.cuda.max_memory_allocated() / 2**30 - -path = sys.argv[1] -torch.save( - {"dq": q.grad.cpu(), "dk": k.grad.cpu(), "dv": v.grad.cpu(), "peak_GiB": peak}, - path, -) -print(f"saved {path} peak_GiB={peak:.4f}") diff --git a/scripts/prove_attnproj.py b/scripts/prove_attnproj.py deleted file mode 100644 index ef246aa729..0000000000 --- a/scripts/prove_attnproj.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Model-level parity + prefill latency for FP4 q/k/o_proj (shared-pack). -Loads Qwen3.5-4B, patches full-attn layers, compares prefill logits + latency -with the proj-FP4 flag ON vs OFF (FP4-attention-only) vs unpatched bf16.""" -import torch, torch.nn.functional as F -from transformers import AutoModelForCausalLM -from axolotl.monkeypatch.attention.nvfp4_flash_attn import patch_qwen3_5_nvfp4_attention - -dev = "cuda" -m = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen3.5-4B", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", -).to(dev).eval() -torch.manual_seed(0) -ids = torch.randint(0, 10000, (1, 1024), device=dev) - - -def cos(a, b): - return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() - - -with torch.no_grad(): - base = m(ids).logits.float() - -n = patch_qwen3_5_nvfp4_attention(m, fuse_attn_proj=True) -ok = sum(int(getattr(mod, "_nvfp4_attn_proj_ok", False)) for mod in m.modules()) -print(f"patched full-attn layers: {n}; proj-FP4-eligible: {ok}") - - -def set_flag(v): - for mod in m.modules(): - if hasattr(mod, "_nvfp4_fuse_attn_proj"): - mod._nvfp4_fuse_attn_proj = v - - -with torch.no_grad(): - set_flag(True); on = m(ids).logits.float() - set_flag(False); off = m(ids).logits.float() - -print(f"logit cos proj-FP4 ON vs bf16-unpatched : {cos(on, base):.5f}") -print(f"logit cos FP4-attn OFF vs bf16-unpatched : {cos(off, base):.5f}") -print(f"logit cos ON vs OFF (marginal proj-FP4) : {cos(on, off):.5f}") -print(f"argmax agree ON vs bf16: {(on.argmax(-1)==base.argmax(-1)).float().mean().item():.4f}") - - -def t(v, it=10): - set_flag(v) - torch.cuda.synchronize() - with torch.no_grad(): - for _ in range(3): m(ids) - torch.cuda.synchronize() - a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() - for _ in range(it): m(ids) - b.record(); torch.cuda.synchronize() - return a.elapsed_time(b) / it - - -on_ms, off_ms = t(True), t(False) -print(f"prefill 1x1024: proj-FP4 ON {on_ms:.2f}ms OFF {off_ms:.2f}ms speedup {off_ms/on_ms:.3f}x") diff --git a/scripts/prove_p2.py b/scripts/prove_p2.py deleted file mode 100644 index 3098020fb1..0000000000 --- a/scripts/prove_p2.py +++ /dev/null @@ -1,30 +0,0 @@ -"""#2 fused_rope_quant_qk: parity (strided==contiguous, bit-identical) + the -saved-copy latency on a realistic transposed (non-contiguous) Q input.""" -import torch -from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk - -dev = "cuda"; dt = torch.bfloat16 -torch.manual_seed(0) -for Z, H, S, D in [(1, 16, 2048, 256), (1, 16, 4096, 256)]: - rot = D - base = torch.randn(Z, S, H, D, device=dev, dtype=dt) # [Z,S,H,D] contiguous - x_t = base.transpose(1, 2) # [Z,H,S,D] non-contig (D contig) - cos = torch.randn(Z, S, rot, device=dev, dtype=dt) - sin = torch.randn(Z, S, rot, device=dev, dtype=dt) - - q_s, sc_s = fused_rope_quant_qk(x_t, cos, sin) # strided (new, no copy) - q_c, sc_c = fused_rope_quant_qk(x_t.contiguous(), cos, sin) # contiguous reference - ok = torch.equal(q_s, q_c) and torch.equal(sc_s.view(torch.uint8), sc_c.view(torch.uint8)) - - def t(fn, it=50): - torch.cuda.synchronize() - for _ in range(5): fn() - torch.cuda.synchronize() - a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() - for _ in range(it): fn() - b.record(); torch.cuda.synchronize() - return a.elapsed_time(b) / it - new = t(lambda: fused_rope_quant_qk(x_t, cos, sin)) - old = t(lambda: fused_rope_quant_qk(x_t.contiguous(), cos, sin)) - print(f"S={S}: parity_bit_identical={ok} new(strided) {new*1000:6.1f}us " - f"old(contig+kernel) {old*1000:6.1f}us speedup {old/new:.2f}x") diff --git a/scripts/prove_p3.py b/scripts/prove_p3.py deleted file mode 100644 index 6396f7c0df..0000000000 --- a/scripts/prove_p3.py +++ /dev/null @@ -1,27 +0,0 @@ -"""#3 forward V-load hoist: dump forward output (for cross-worktree bit-compare) -+ forward prefill latency. Run under each worktree's PYTHONPATH.""" -import sys, math, torch -from axolotl.kernels.attn_nvfp4_flash import nvfp4_flash_attention - -dev = "cuda"; dt = torch.bfloat16 -tag = sys.argv[1] -outs = {} -for Z, H, Hk, S, D in [(1, 16, 4, 2048, 256), (1, 16, 4, 4096, 256)]: - torch.manual_seed(0) - q = torch.randn(Z, H, S, D, device=dev, dtype=dt) - k = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) - v = torch.randn(Z, Hk, S, D, device=dev, dtype=dt) - sc = 1.0 / math.sqrt(D) - fn = lambda: nvfp4_flash_attention(q, k, v, sc, causal=True, num_key_value_groups=H // Hk) - out = fn() - outs[S] = out.cpu() - - torch.cuda.synchronize() - for _ in range(5): fn() - torch.cuda.synchronize() - a = torch.cuda.Event(True); b = torch.cuda.Event(True); a.record() - for _ in range(50): fn() - b.record(); torch.cuda.synchronize() - print(f"S={S}: forward {a.elapsed_time(b)/50*1000:7.1f} us") -torch.save(outs, f"/tmp/p3_{tag}.pt") -print(f"saved /tmp/p3_{tag}.pt") diff --git a/src/axolotl/kernels/attn_nvfp4_flash.py b/src/axolotl/kernels/attn_nvfp4_flash.py index c45995a7da..25612c0fcd 100644 --- a/src/axolotl/kernels/attn_nvfp4_flash.py +++ b/src/axolotl/kernels/attn_nvfp4_flash.py @@ -64,7 +64,6 @@ import torch import triton import triton.language as tl - from mslk.quantize.triton.fp4_quantize import convert_fp32_to_fp4_packed _E4M3_EPS = tl.constexpr(1.5258789e-05) @@ -88,8 +87,11 @@ # --------------------------------------------------------------------------- @triton.jit def _pack_nvfp4_along_k( - x, base_off, seed, - ROWS: tl.constexpr, K: tl.constexpr, + x, + base_off, + seed, + ROWS: tl.constexpr, + K: tl.constexpr, STOCHASTIC: tl.constexpr, ): NG: tl.constexpr = K // 16 @@ -106,8 +108,12 @@ def _pack_nvfp4_along_k( # not the GEMMs, dominated this kernel. ax = tl.abs(xn) step = tl.where(ax < 2.0, 1.0, tl.where(ax < 4.0, 2.0, 4.0)) - off = base_off + tl.arange(0, ROWS)[:, None, None] * K + \ - tl.arange(0, NG)[None, :, None] * 16 + tl.arange(0, 16)[None, None, :] + off = ( + base_off + + tl.arange(0, ROWS)[:, None, None] * K + + tl.arange(0, NG)[None, :, None] * 16 + + tl.arange(0, 16)[None, None, :] + ) u = tl.rand(seed, off) - 0.5 xn = xn + u * step xn = tl.clamp(xn, -_F4_MAX, _F4_MAX) @@ -123,12 +129,21 @@ def _pack_nvfp4_along_k( # --------------------------------------------------------------------------- @triton.jit def _quant_nvfp4_kernel( - x_ptr, q_ptr, s_ptr, - R, K, K_READ, - s_xb, s_qb, s_sb, # per-batch strides - s_xr, s_xk, # input row / contraction-col strides (transpose w/o copy) - s_qr, s_sr, # per-row strides (= K//2, K//16) - BLOCK_R: tl.constexpr, BLOCK_K: tl.constexpr, + x_ptr, + q_ptr, + s_ptr, + R, + K, + K_READ, + s_xb, + s_qb, + s_sb, # per-batch strides + s_xr, + s_xk, # input row / contraction-col strides (transpose w/o copy) + s_qr, + s_sr, # per-row strides (= K//2, K//16) + BLOCK_R: tl.constexpr, + BLOCK_K: tl.constexpr, ): pid_b = tl.program_id(0) pid_r = tl.program_id(1) @@ -139,7 +154,8 @@ def _quant_nvfp4_kernel( kmask = offs_k < K_READ x = tl.load( x_ptr + pid_b * s_xb + offs_r[:, None] * s_xr + offs_k[None, :] * s_xk, - mask=rmask[:, None] & kmask[None, :], other=0.0, + mask=rmask[:, None] & kmask[None, :], + other=0.0, ).to(tl.float32) NG: tl.constexpr = BLOCK_K // 16 @@ -153,12 +169,14 @@ def _quant_nvfp4_kernel( offs_qk = pid_k * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2) tl.store( q_ptr + pid_b * s_qb + offs_r[:, None] * s_qr + offs_qk[None, :], - qpk, mask=rmask[:, None] & (offs_qk[None, :] < s_qr), + qpk, + mask=rmask[:, None] & (offs_qk[None, :] < s_qr), ) offs_sk = pid_k * NG + tl.arange(0, NG) tl.store( s_ptr + pid_b * s_sb + offs_r[:, None] * s_sr + offs_sk[None, :], - sc.to(tl.uint8, bitcast=True), mask=rmask[:, None] & (offs_sk[None, :] < s_sr), + sc.to(tl.uint8, bitcast=True), + mask=rmask[:, None] & (offs_sk[None, :] < s_sr), ) @@ -181,7 +199,7 @@ def _quant_nvfp4( non-power-of-2 key axis is supported. """ if transpose: - B, K_read, R = x.shape # contraction axis is the middle dim (Skv) + B, K_read, R = x.shape # contraction axis is the middle dim (Skv) s_xr, s_xk = x.stride(2), x.stride(1) s_xb = x.stride(0) else: @@ -197,11 +215,21 @@ def _quant_nvfp4( BLOCK_K = min(triton.next_power_of_2(K), 256) grid = (B, triton.cdiv(R, BLOCK_R), triton.cdiv(K, BLOCK_K)) _quant_nvfp4_kernel[grid]( - x, q, s, R, K, K_read, - s_xb, q.stride(0), s.stride(0), - s_xr, s_xk, - K // 2, K // 16, - BLOCK_R=BLOCK_R, BLOCK_K=BLOCK_K, + x, + q, + s, + R, + K, + K_read, + s_xb, + q.stride(0), + s.stride(0), + s_xr, + s_xk, + K // 2, + K // 16, + BLOCK_R=BLOCK_R, + BLOCK_K=BLOCK_K, ) return q, s.view(torch.float8_e4m3fn) @@ -221,15 +249,26 @@ def _quant_nvfp4( @triton.jit def _quant_nvfp4_dual_kernel( x_ptr, - qa_ptr, sa_ptr, # layout A (along D): [B, S, D//2], [B, S, D//16] - qb_ptr, sb_ptr, # layout B (along S): [B, D, S_pad//2], [B, D, S_pad//16] - S, D, S_PAD, - s_xb, s_xs, s_xd, # source strides - s_qab, s_qar, # layout A: batch stride, per-row (S) stride (= D//2) - s_sab, s_sar, # layout A scale: batch stride, per-row stride (= D//16) - s_qbb, s_qbr, # layout B: batch stride, per-row (D) stride (= S_pad//2) - s_sbb, s_sbr, # layout B scale: batch stride, per-row stride (= S_pad//16) - BLOCK_S: tl.constexpr, BLOCK_D: tl.constexpr, + qa_ptr, + sa_ptr, # layout A (along D): [B, S, D//2], [B, S, D//16] + qb_ptr, + sb_ptr, # layout B (along S): [B, D, S_pad//2], [B, D, S_pad//16] + S, + D, + S_PAD, + s_xb, + s_xs, + s_xd, # source strides + s_qab, + s_qar, # layout A: batch stride, per-row (S) stride (= D//2) + s_sab, + s_sar, # layout A scale: batch stride, per-row stride (= D//16) + s_qbb, + s_qbr, # layout B: batch stride, per-row (D) stride (= S_pad//2) + s_sbb, + s_sbr, # layout B scale: batch stride, per-row stride (= S_pad//16) + BLOCK_S: tl.constexpr, + BLOCK_D: tl.constexpr, ): pid_b = tl.program_id(0) pid_s = tl.program_id(1) @@ -239,7 +278,8 @@ def _quant_nvfp4_dual_kernel( # one read of the source tile; padded S rows -> 0.0 (amax 0 -> eps scale -> zero pack) x = tl.load( x_ptr + pid_b * s_xb + offs_s[:, None] * s_xs + offs_d[None, :] * s_xd, - mask=smask[:, None], other=0.0, + mask=smask[:, None], + other=0.0, ).to(tl.float32) # layout A: group-16 along D, one pack per S-row @@ -253,12 +293,14 @@ def _quant_nvfp4_dual_kernel( offs_ad = tl.arange(0, BLOCK_D // 2) tl.store( qa_ptr + pid_b * s_qab + offs_s[:, None] * s_qar + offs_ad[None, :], - qa, mask=smask[:, None], + qa, + mask=smask[:, None], ) offs_asg = tl.arange(0, NGA) tl.store( sa_ptr + pid_b * s_sab + offs_s[:, None] * s_sar + offs_asg[None, :], - sca.to(tl.uint8, bitcast=True), mask=smask[:, None], + sca.to(tl.uint8, bitcast=True), + mask=smask[:, None], ) # layout B: group-16 along S, one pack per D-row (transpose the resident tile) @@ -273,12 +315,14 @@ def _quant_nvfp4_dual_kernel( offs_bs = pid_s * (BLOCK_S // 2) + tl.arange(0, BLOCK_S // 2) tl.store( qb_ptr + pid_b * s_qbb + offs_d[:, None] * s_qbr + offs_bs[None, :], - qb, mask=offs_bs[None, :] < (S_PAD // 2), + qb, + mask=offs_bs[None, :] < (S_PAD // 2), ) offs_bsg = pid_s * NGB + tl.arange(0, NGB) tl.store( sb_ptr + pid_b * s_sbb + offs_d[:, None] * s_sbr + offs_bsg[None, :], - scb.to(tl.uint8, bitcast=True), mask=offs_bsg[None, :] < (S_PAD // 16), + scb.to(tl.uint8, bitcast=True), + mask=offs_bsg[None, :] < (S_PAD // 16), ) @@ -309,14 +353,26 @@ def _quant_nvfp4_dual( grid = (B, triton.cdiv(s_pad, BLOCK_S)) _quant_nvfp4_dual_kernel[grid]( x, - qa, sa, qb, sb, - S, D, s_pad, - x.stride(0), x.stride(1), x.stride(2), - qa.stride(0), qa.stride(1), - sa.stride(0), sa.stride(1), - qb.stride(0), qb.stride(1), - sb.stride(0), sb.stride(1), - BLOCK_S=BLOCK_S, BLOCK_D=D, + qa, + sa, + qb, + sb, + S, + D, + s_pad, + x.stride(0), + x.stride(1), + x.stride(2), + qa.stride(0), + qa.stride(1), + sa.stride(0), + sa.stride(1), + qb.stride(0), + qb.stride(1), + sb.stride(0), + sb.stride(1), + BLOCK_S=BLOCK_S, + BLOCK_D=D, ) return qa, sa.view(torch.float8_e4m3fn), qb, sb.view(torch.float8_e4m3fn) @@ -327,35 +383,47 @@ def _quant_nvfp4_dual( # --------------------------------------------------------------------------- @triton.jit def _flash_fwd_kernel( - qnv_ptr, qsc_ptr, # [Z*H, Sq, D//2], [Z*H, Sq, D//16] - knv_ptr, ksc_ptr, # [Z*Hk, Skv, D//2], [Z*Hk, Skv, D//16] - vnv_ptr, vsc_ptr, # [Z*Hk, D, Skv//2], [Z*Hk, D, Skv//16] (V^T, quant on key) - bias_ptr, # [Z, Skv] fp32 additive key-pad bias, or 0 - out_ptr, # [Z*H, Sq, D] (default) or [Z, Sq, H, D] (OUT_ZSHD) - lse_ptr, # [Z*H, Sq] fp32 logsumexp, written iff STORE_LSE + qnv_ptr, + qsc_ptr, # [Z*H, Sq, D//2], [Z*H, Sq, D//16] + knv_ptr, + ksc_ptr, # [Z*Hk, Skv, D//2], [Z*Hk, Skv, D//16] + vnv_ptr, + vsc_ptr, # [Z*Hk, D, Skv//2], [Z*Hk, D, Skv//16] (V^T, quant on key) + bias_ptr, # [Z, Skv] fp32 additive key-pad bias, or 0 + out_ptr, # [Z*H, Sq, D] (default) or [Z, Sq, H, D] (OUT_ZSHD) + lse_ptr, # [Z*H, Sq] fp32 logsumexp, written iff STORE_LSE scaling, - Sq, Skv, + Sq, + Skv, D: tl.constexpr, - H: tl.constexpr, HK: tl.constexpr, - sq_qn, sq_sn, - sk_kn, sk_sn, - sv_kn, sv_sn, + H: tl.constexpr, + HK: tl.constexpr, + sq_qn, + sq_sn, + sk_kn, + sk_sn, + sv_kn, + sv_sn, sb_z, - so_n, # out row (Sq-axis) stride: D ([Z*H,Sq,D]) or H*D ([Z,Sq,H,D]) - so_z, so_h, # out z / head strides, used only when OUT_ZSHD + so_n, # out row (Sq-axis) stride: D ([Z*H,Sq,D]) or H*D ([Z,Sq,H,D]) + so_z, + so_h, # out z / head strides, used only when OUT_ZSHD HAS_BIAS: tl.constexpr, CAUSAL: tl.constexpr, STORE_LSE: tl.constexpr, - OUT_ZSHD: tl.constexpr, # store the [Z, Sq, H, D] layout directly (no transpose+copy) - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - DP2: tl.constexpr, DP16: tl.constexpr, # D//2, D//16 - NP2: tl.constexpr, NP16: tl.constexpr, # BLOCK_N//2, BLOCK_N//16 + OUT_ZSHD: tl.constexpr, # store the [Z, Sq, H, D] layout directly (no transpose+copy) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + DP2: tl.constexpr, + DP16: tl.constexpr, # D//2, D//16 + NP2: tl.constexpr, + NP16: tl.constexpr, # BLOCK_N//2, BLOCK_N//16 ): pid_m = tl.program_id(0) pid_zh = tl.program_id(1) z = pid_zh // H h = pid_zh % H - zhk = z * HK + (h // (H // HK)) # GQA: query head -> kv head + zhk = z * HK + (h // (H // HK)) # GQA: query head -> kv head offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_dp = tl.arange(0, DP2) @@ -367,11 +435,13 @@ def _flash_fwd_kernel( mmask = offs_m < Sq qnv = tl.load( qnv_ptr + qbase + offs_m[:, None] * sq_qn + offs_dp[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) qsc = tl.load( qsc_ptr + qscbase + offs_m[:, None] * sq_sn + offs_dsc[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) m_i = tl.full((BLOCK_M,), _NEG_INF, dtype=tl.float32) @@ -401,11 +471,13 @@ def _flash_fwd_kernel( # load packed K tile [BLOCK_N, D//2] + scale [BLOCK_N, D//16] knv = tl.load( knv_ptr + kbase + offs_n[:, None] * sk_kn + offs_dp[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) ksc = tl.load( ksc_ptr + kscbase + offs_n[:, None] * sk_sn + offs_dsc[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) # QK^T via native NVFP4: [BLOCK_M, BLOCK_N] s = tl.dot_scaled(qnv, qsc, "e2m1", knv.T, ksc, "e2m1") @@ -422,13 +494,13 @@ def _flash_fwd_kernel( # online softmax m_new = tl.maximum(m_i, tl.max(s, axis=1)) alpha = tl.exp(m_i - m_new) - p = tl.exp(s - m_new[:, None]) # [BLOCK_M, BLOCK_N], >=0 + p = tl.exp(s - m_new[:, None]) # [BLOCK_M, BLOCK_N], >=0 l_i = l_i * alpha + tl.sum(p, axis=1) acc = acc * alpha[:, None] # in-kernel NVFP4 pack of P along the key axis (group-16) pb = p.reshape(BLOCK_M, NP16, 16) - pamax = tl.max(pb, axis=2) # P>=0 + pamax = tl.max(pb, axis=2) # P>=0 psc = tl.clamp(pamax / _F4_MAX, _E4M3_EPS, _F8E4M3_MAX).to(tl.float8e4nv) pn = pb / psc.to(tl.float32)[:, :, None] ppairs = pn.reshape(BLOCK_M * NP2, 2).split() @@ -436,10 +508,16 @@ def _flash_fwd_kernel( # load packed V^T tile: [D, BLOCK_N//2] + scale [D, BLOCK_N//16] vnv = tl.load( - vnv_ptr + vbase + offs_d[:, None] * sv_kn + (start_n // 2 + offs_np)[None, :], + vnv_ptr + + vbase + + offs_d[:, None] * sv_kn + + (start_n // 2 + offs_np)[None, :], ) vsc = tl.load( - vsc_ptr + vscbase + offs_d[:, None] * sv_sn + (start_n // 16 + offs_nsc)[None, :], + vsc_ptr + + vscbase + + offs_d[:, None] * sv_sn + + (start_n // 16 + offs_nsc)[None, :], ).to(tl.float8e4nv, bitcast=True) # P @ V via native NVFP4: a=P [BLOCK_M, BLOCK_N], b=V^T loaded [D, BLOCK_N//2] acc = tl.dot_scaled(pq, psc, "e2m1", vnv.T, vsc, "e2m1", acc=acc) @@ -455,7 +533,8 @@ def _flash_fwd_kernel( obase = pid_zh * (Sq * so_n) tl.store( out_ptr + obase + offs_m[:, None] * so_n + offs_d[None, :], - acc.to(out_ptr.dtype.element_ty), mask=mmask[:, None], + acc.to(out_ptr.dtype.element_ty), + mask=mmask[:, None], ) if STORE_LSE: # Persist logsumexp so the backward prep can skip its full QK^T recompute. @@ -495,14 +574,30 @@ def _resolve_fwd_tiles(d, block_m, block_n, num_warps, num_stages): # --------------------------------------------------------------------------- @triton.jit def _flash_bwd_prep_kernel( - q_ptr, k_ptr, - do_ptr, o_ptr, bias_ptr, - lse_ptr, delta_ptr, - scaling, seed, Sq, Skv, - D: tl.constexpr, H: tl.constexpr, HK: tl.constexpr, - sq_n, sk_n, sdo_n, so_n, sb_z, - HAS_BIAS: tl.constexpr, CAUSAL: tl.constexpr, HAVE_LSE: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + q_ptr, + k_ptr, + do_ptr, + o_ptr, + bias_ptr, + lse_ptr, + delta_ptr, + scaling, + seed, + Sq, + Skv, + D: tl.constexpr, + H: tl.constexpr, + HK: tl.constexpr, + sq_n, + sk_n, + sdo_n, + so_n, + sb_z, + HAS_BIAS: tl.constexpr, + CAUSAL: tl.constexpr, + HAVE_LSE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, ): pid_m = tl.program_id(0) pid_zh = tl.program_id(1) @@ -521,7 +616,8 @@ def _flash_bwd_prep_kernel( else: q = tl.load( q_ptr + pid_zh * (Sq * sq_n) + offs_m[:, None] * sq_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) qnv, qsc = _pack_nvfp4_along_k(q, 0, seed, BLOCK_M, D, False) @@ -536,7 +632,8 @@ def _flash_bwd_prep_kernel( nmask = offs_n < Skv k = tl.load( k_ptr + zhk * (Skv * sk_n) + offs_n[:, None] * sk_n + offs_d[None, :], - mask=nmask[:, None], other=0.0, + mask=nmask[:, None], + other=0.0, ).to(tl.float32) knv, ksc = _pack_nvfp4_along_k(k, 0, seed, BLOCK_N, D, False) s = tl.dot_scaled(qnv, qsc, "e2m1", knv.T, ksc, "e2m1") * scaling @@ -559,11 +656,13 @@ def _flash_bwd_prep_kernel( do = tl.load( do_ptr + pid_zh * (Sq * sdo_n) + offs_m[:, None] * sdo_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) o = tl.load( o_ptr + pid_zh * (Sq * so_n) + offs_m[:, None] * so_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) delta = tl.sum(do * o, axis=1) @@ -586,14 +685,30 @@ def _flash_bwd_prep_kernel( # --------------------------------------------------------------------------- @triton.jit def _flash_bwd_packprep_kernel( - q_ptr, do_ptr, o_ptr, delta_ptr, - qnv_ptr, qsc_ptr, qtnv_ptr, qtsc_ptr, - donv_ptr, dosc_ptr, dotnv_ptr, dotsc_ptr, - seed, Sq, Sq_pad, + q_ptr, + do_ptr, + o_ptr, + delta_ptr, + qnv_ptr, + qsc_ptr, + qtnv_ptr, + qtsc_ptr, + donv_ptr, + dosc_ptr, + dotnv_ptr, + dotsc_ptr, + seed, + Sq, + Sq_pad, D: tl.constexpr, - sq_n, sdo_n, so_n, - SR_DO: tl.constexpr, SR_DOT: tl.constexpr, WRITE_DELTA: tl.constexpr, - STORE_Q: tl.constexpr, STORE_QT: tl.constexpr, + sq_n, + sdo_n, + so_n, + SR_DO: tl.constexpr, + SR_DOT: tl.constexpr, + WRITE_DELTA: tl.constexpr, + STORE_Q: tl.constexpr, + STORE_QT: tl.constexpr, BLOCK_M: tl.constexpr, ): pid_m = tl.program_id(0) @@ -614,7 +729,8 @@ def _flash_bwd_packprep_kernel( if STORE_Q or STORE_QT: q = tl.load( q_ptr + pid_zh * (Sq * sq_n) + offs_m[:, None] * sq_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) if STORE_Q: qnv, qsc = _pack_nvfp4_along_k(q, 0, seed, BLOCK_M, D, False) @@ -626,11 +742,13 @@ def _flash_bwd_packprep_kernel( if STORE_Q: tl.store( qnv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - qnv, mask=mmask[:, None], + qnv, + mask=mmask[:, None], ) tl.store( qsc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - qsc.to(tl.uint8, bitcast=True), mask=mmask[:, None], + qsc.to(tl.uint8, bitcast=True), + mask=mmask[:, None], ) sq2 = Sq_pad // 2 sq16 = Sq_pad // 16 @@ -646,12 +764,14 @@ def _flash_bwd_packprep_kernel( do = tl.load( do_ptr + pid_zh * (Sq * sdo_n) + offs_m[:, None] * sdo_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) if WRITE_DELTA: o = tl.load( o_ptr + pid_zh * (Sq * so_n) + offs_m[:, None] * so_n + offs_d[None, :], - mask=mmask[:, None], other=0.0, + mask=mmask[:, None], + other=0.0, ).to(tl.float32) delta = tl.sum(do * o, axis=1) tl.store(delta_ptr + pid_zh * Sq + offs_m, delta, mask=mmask) @@ -664,11 +784,13 @@ def _flash_bwd_packprep_kernel( tl.store( donv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - donv, mask=mmask[:, None], + donv, + mask=mmask[:, None], ) tl.store( dosc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - dosc.to(tl.uint8, bitcast=True), mask=mmask[:, None], + dosc.to(tl.uint8, bitcast=True), + mask=mmask[:, None], ) tl.store( dotnv_ptr + pid_zh * (D * sq2) + offs_d[:, None] * sq2 + offs_mp[None, :], @@ -688,12 +810,23 @@ def _flash_bwd_packprep_kernel( # --------------------------------------------------------------------------- @triton.jit def _flash_bwd_kprep_kernel( - k_ptr, v_ptr, - knv_ptr, ksc_ptr, vnv_ptr, vsc_ptr, ktnv_ptr, ktsc_ptr, - seed, Skv, Skv_pad, + k_ptr, + v_ptr, + knv_ptr, + ksc_ptr, + vnv_ptr, + vsc_ptr, + ktnv_ptr, + ktsc_ptr, + seed, + Skv, + Skv_pad, D: tl.constexpr, - sk_n, sv_n, - STORE_K: tl.constexpr, STORE_V: tl.constexpr, STORE_KT: tl.constexpr, + sk_n, + sv_n, + STORE_K: tl.constexpr, + STORE_V: tl.constexpr, + STORE_KT: tl.constexpr, BLOCK_N: tl.constexpr, ): pid_n = tl.program_id(0) @@ -705,7 +838,8 @@ def _flash_bwd_kprep_kernel( if STORE_K or STORE_KT: k = tl.load( k_ptr + pid_zhk * (Skv * sk_n) + offs_n[:, None] * sk_n + offs_d[None, :], - mask=nmask[:, None], other=0.0, + mask=nmask[:, None], + other=0.0, ).to(tl.float32) if STORE_K: knv, ksc = _pack_nvfp4_along_k(k, 0, seed, BLOCK_N, D, False) @@ -714,7 +848,8 @@ def _flash_bwd_kprep_kernel( if STORE_V: v = tl.load( v_ptr + pid_zhk * (Skv * sv_n) + offs_n[:, None] * sv_n + offs_d[None, :], - mask=nmask[:, None], other=0.0, + mask=nmask[:, None], + other=0.0, ).to(tl.float32) vnv, vsc = _pack_nvfp4_along_k(v, 0, seed, BLOCK_N, D, False) @@ -730,20 +865,30 @@ def _flash_bwd_kprep_kernel( if STORE_K: tl.store( knv_ptr + pid_zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - knv, mask=nmask[:, None], + knv, + mask=nmask[:, None], ) tl.store( - ksc_ptr + pid_zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - ksc.to(tl.uint8, bitcast=True), mask=nmask[:, None], + ksc_ptr + + pid_zhk * (Skv * DP16) + + offs_n[:, None] * DP16 + + offs_dsc[None, :], + ksc.to(tl.uint8, bitcast=True), + mask=nmask[:, None], ) if STORE_V: tl.store( vnv_ptr + pid_zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - vnv, mask=nmask[:, None], + vnv, + mask=nmask[:, None], ) tl.store( - vsc_ptr + pid_zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - vsc.to(tl.uint8, bitcast=True), mask=nmask[:, None], + vsc_ptr + + pid_zhk * (Skv * DP16) + + offs_n[:, None] * DP16 + + offs_dsc[None, :], + vsc.to(tl.uint8, bitcast=True), + mask=nmask[:, None], ) sk2 = Skv_pad // 2 sk16 = Skv_pad // 16 @@ -753,7 +898,10 @@ def _flash_bwd_kprep_kernel( kTnv, ) tl.store( - ktsc_ptr + pid_zhk * (D * sk16) + offs_d[:, None] * sk16 + offs_nsc[None, :], + ktsc_ptr + + pid_zhk * (D * sk16) + + offs_d[:, None] * sk16 + + offs_nsc[None, :], kTsc.to(tl.uint8, bitcast=True), ) @@ -778,17 +926,40 @@ def _flash_bwd_kprep_kernel( # --------------------------------------------------------------------------- @triton.jit def _flash_bwd_dkdv_kernel( - qnv_ptr, qsc_ptr, qtnv_ptr, qtsc_ptr, - donv_ptr, dosc_ptr, dotnv_ptr, dotsc_ptr, - knv_ptr, ksc_ptr, vnv_ptr, vsc_ptr, bias_ptr, - lse_ptr, delta_ptr, - dk_ptr, dv_ptr, - scaling, seed, Sq, Sq_pad, Skv, - D: tl.constexpr, H: tl.constexpr, HK: tl.constexpr, - sb_z, sdk_n, sdv_n, - HAS_BIAS: tl.constexpr, CAUSAL: tl.constexpr, - SR: tl.constexpr, SR_P_DV: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + qnv_ptr, + qsc_ptr, + qtnv_ptr, + qtsc_ptr, + donv_ptr, + dosc_ptr, + dotnv_ptr, + dotsc_ptr, + knv_ptr, + ksc_ptr, + vnv_ptr, + vsc_ptr, + bias_ptr, + lse_ptr, + delta_ptr, + dk_ptr, + dv_ptr, + scaling, + seed, + Sq, + Sq_pad, + Skv, + D: tl.constexpr, + H: tl.constexpr, + HK: tl.constexpr, + sb_z, + sdk_n, + sdv_n, + HAS_BIAS: tl.constexpr, + CAUSAL: tl.constexpr, + SR: tl.constexpr, + SR_P_DV: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, ): pid_n = tl.program_id(0) pid_zh = tl.program_id(1) @@ -813,19 +984,23 @@ def _flash_bwd_dkdv_kernel( knv = tl.load( knv_ptr + zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) ksc = tl.load( ksc_ptr + zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) vnv = tl.load( vnv_ptr + zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) vsc = tl.load( vsc_ptr + zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) dk = tl.zeros((BLOCK_N, D), dtype=tl.float32) @@ -842,11 +1017,13 @@ def _flash_bwd_dkdv_kernel( # load precomputed FP4 packs (quantized once in the pack-prep pass) qnv = tl.load( qnv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) qsc = tl.load( qsc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) mp = (start_m // 2) + offs_mp0 msc = (start_m // 16) + offs_msc0 @@ -866,9 +1043,7 @@ def _flash_bwd_dkdv_kernel( pT = tl.where(sT == _NEG_INF, 0.0, pT) # dV += pT @ dO^T.T (contract M). pT [BLOCK_N, BLOCK_M] (SR), dO^T precomputed. - pT_q, pT_s = _pack_nvfp4_along_k( - pT, start_m, seed, BLOCK_N, BLOCK_M, SR_P_DV - ) + pT_q, pT_s = _pack_nvfp4_along_k(pT, start_m, seed, BLOCK_N, BLOCK_M, SR_P_DV) dotnv = tl.load( dotnv_ptr + pid_zh * (D * sq2) + offs_d[:, None] * sq2 + mp[None, :], ) @@ -880,11 +1055,16 @@ def _flash_bwd_dkdv_kernel( # dPt[n,m] = sum_d V[n,d] dO[m,d] (contract D). dO precomputed (SR). donv = tl.load( donv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) dosc = tl.load( - dosc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - mask=mmask[:, None], other=0, + dosc_ptr + + pid_zh * (Sq * DP16) + + offs_m[:, None] * DP16 + + offs_dsc[None, :], + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) dpT = tl.dot_scaled(vnv, vsc, "e2m1", donv.T, dosc, "e2m1") @@ -892,7 +1072,9 @@ def _flash_bwd_dkdv_kernel( dsT = tl.where(pT == 0.0, 0.0, dsT) # dK += dSt @ Q^T.T (contract M). dSt [BLOCK_N, BLOCK_M] (SR), Q^T precomputed (RTN). - dsT_q, dsT_s = _pack_nvfp4_along_k(dsT, start_m + 3 * Sq, seed, BLOCK_N, BLOCK_M, SR) + dsT_q, dsT_s = _pack_nvfp4_along_k( + dsT, start_m + 3 * Sq, seed, BLOCK_N, BLOCK_M, SR + ) qtnv = tl.load( qtnv_ptr + pid_zh * (D * sq2) + offs_d[:, None] * sq2 + mp[None, :], ) @@ -903,11 +1085,13 @@ def _flash_bwd_dkdv_kernel( tl.store( dk_ptr + pid_zh * (Skv * sdk_n) + offs_n[:, None] * sdk_n + offs_d[None, :], - dk.to(dk_ptr.dtype.element_ty), mask=nmask[:, None], + dk.to(dk_ptr.dtype.element_ty), + mask=nmask[:, None], ) tl.store( dv_ptr + pid_zh * (Skv * sdv_n) + offs_n[:, None] * sdv_n + offs_d[None, :], - dv.to(dv_ptr.dtype.element_ty), mask=nmask[:, None], + dv.to(dv_ptr.dtype.element_ty), + mask=nmask[:, None], ) @@ -918,14 +1102,35 @@ def _flash_bwd_dkdv_kernel( # --------------------------------------------------------------------------- @triton.jit def _flash_bwd_dq_kernel( - qnv_ptr, qsc_ptr, donv_ptr, dosc_ptr, bias_ptr, - knv_ptr, ksc_ptr, vnv_ptr, vsc_ptr, ktnv_ptr, ktsc_ptr, - lse_ptr, delta_ptr, dq_ptr, - scaling, seed, Sq, Skv, Skv_pad, - D: tl.constexpr, H: tl.constexpr, HK: tl.constexpr, - sb_z, sdq_n, - HAS_BIAS: tl.constexpr, CAUSAL: tl.constexpr, SR_DS_DQ: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + qnv_ptr, + qsc_ptr, + donv_ptr, + dosc_ptr, + bias_ptr, + knv_ptr, + ksc_ptr, + vnv_ptr, + vsc_ptr, + ktnv_ptr, + ktsc_ptr, + lse_ptr, + delta_ptr, + dq_ptr, + scaling, + seed, + Sq, + Skv, + Skv_pad, + D: tl.constexpr, + H: tl.constexpr, + HK: tl.constexpr, + sb_z, + sdq_n, + HAS_BIAS: tl.constexpr, + CAUSAL: tl.constexpr, + SR_DS_DQ: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, ): pid_m = tl.program_id(0) pid_zh = tl.program_id(1) @@ -950,19 +1155,23 @@ def _flash_bwd_dq_kernel( qnv = tl.load( qnv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) qsc = tl.load( qsc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) do_q = tl.load( donv_ptr + pid_zh * (Sq * DP2) + offs_m[:, None] * DP2 + offs_dp[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) do_s = tl.load( dosc_ptr + pid_zh * (Sq * DP16) + offs_m[:, None] * DP16 + offs_dsc[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) lse = tl.load(lse_ptr + pid_zh * Sq + offs_m, mask=mmask, other=0.0) delta = tl.load(delta_ptr + pid_zh * Sq + offs_m, mask=mmask, other=0.0) @@ -979,11 +1188,13 @@ def _flash_bwd_dq_kernel( # precomputed K-side packs: load each layout close to the GEMM that uses it. knv = tl.load( knv_ptr + zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) ksc = tl.load( ksc_ptr + zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) s = tl.dot_scaled(qnv, qsc, "e2m1", knv.T, ksc, "e2m1") * scaling @@ -999,11 +1210,13 @@ def _flash_bwd_dq_kernel( vnv = tl.load( vnv_ptr + zhk * (Skv * DP2) + offs_n[:, None] * DP2 + offs_dp[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) vsc = tl.load( vsc_ptr + zhk * (Skv * DP16) + offs_n[:, None] * DP16 + offs_dsc[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) dp = tl.dot_scaled(do_q, do_s, "e2m1", vnv.T, vsc, "e2m1") ds = p * (dp - delta[:, None]) * scaling @@ -1025,15 +1238,24 @@ def _flash_bwd_dq_kernel( tl.store( dq_ptr + pid_zh * (Sq * sdq_n) + offs_m[:, None] * sdq_n + offs_d[None, :], - dq.to(dq_ptr.dtype.element_ty), mask=mmask[:, None], + dq.to(dq_ptr.dtype.element_ty), + mask=mmask[:, None], ) @triton.jit def _gqa_reduce_cast_dkdv_kernel( - dk_ptr, dv_ptr, dk_out_ptr, dv_out_ptr, - Skv, D: tl.constexpr, H: tl.constexpr, HK: tl.constexpr, NG: tl.constexpr, - BLOCK_S: tl.constexpr, BLOCK_D: tl.constexpr, + dk_ptr, + dv_ptr, + dk_out_ptr, + dv_out_ptr, + Skv, + D: tl.constexpr, + H: tl.constexpr, + HK: tl.constexpr, + NG: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_D: tl.constexpr, ): pid_s = tl.program_id(0) pid_d = tl.program_id(1) @@ -1078,19 +1300,44 @@ def _gqa_reduce_cast_dkdv( _gqa_reduce_cast_dkdv_kernel[ (triton.cdiv(s_kv, block_s), triton.cdiv(d, block_d), z * hk) ]( - dk, dv, dk_out, dv_out, - s_kv, D=d, H=h, HK=hk, NG=ng, - BLOCK_S=block_s, BLOCK_D=block_d, - num_warps=4, num_stages=2, + dk, + dv, + dk_out, + dv_out, + s_kv, + D=d, + H=h, + HK=hk, + NG=ng, + BLOCK_S=block_s, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, ) return dk_out, dv_out def _run_flash_packed( - qnv, qsc, knv, ksc, vnv, vsc, - z, h, hk, s_q, s_kv, d, - scaling, causal, bias, out, - block_m, block_n, num_warps, num_stages, + qnv, + qsc, + knv, + ksc, + vnv, + vsc, + z, + h, + hk, + s_q, + s_kv, + d, + scaling, + causal, + bias, + out, + block_m, + block_n, + num_warps, + num_stages, out_zshd=False, ): """Launch the flash kernel on already-packed (tl.dot_scaled layout) Q/K/V. @@ -1112,17 +1359,27 @@ def _run_flash_packed( grid = (triton.cdiv(s_q, block_m), z * h) _flash_fwd_kernel[grid]( - qnv_v, qsc_v, knv_v, ksc_v, vnv_v, vsc_v, + qnv_v, + qsc_v, + knv_v, + ksc_v, + vnv_v, + vsc_v, bias if bias is not None else qnv_v, out, out, # dummy lse ptr (STORE_LSE=False) scaling, - s_q, s_kv, + s_q, + s_kv, D=d, - H=h, HK=hk, - sq_qn=qnv_v.stride(1), sq_sn=qsc_v.stride(1), - sk_kn=knv_v.stride(1), sk_sn=ksc_v.stride(1), - sv_kn=vnv_v.stride(1), sv_sn=vsc_v.stride(1), + H=h, + HK=hk, + sq_qn=qnv_v.stride(1), + sq_sn=qsc_v.stride(1), + sk_kn=knv_v.stride(1), + sk_sn=ksc_v.stride(1), + sv_kn=vnv_v.stride(1), + sv_sn=vsc_v.stride(1), sb_z=bias.stride(0) if bias is not None else 0, so_n=out.stride(1), so_z=out.stride(0) if out_zshd else 0, @@ -1131,10 +1388,14 @@ def _run_flash_packed( CAUSAL=causal, STORE_LSE=False, OUT_ZSHD=out_zshd, - BLOCK_M=block_m, BLOCK_N=block_n, - DP2=d // 2, DP16=d // 16, - NP2=block_n // 2, NP16=block_n // 16, - num_warps=num_warps, num_stages=num_stages, + BLOCK_M=block_m, + BLOCK_N=block_n, + DP2=d // 2, + DP16=d // 16, + NP2=block_n // 2, + NP16=block_n // 16, + num_warps=num_warps, + num_stages=num_stages, ) @@ -1186,10 +1447,26 @@ def nvfp4_flash_attention_packed( else: out = torch.empty(z * h, s_q, d, device=qnv.device, dtype=out_dtype) _run_flash_packed( - qnv, qsc, knv, ksc, vnv, vsc, - z, h, hk, s_q, s_kv, d, - scaling, causal, bias, out, - block_m, block_n, num_warps, num_stages, + qnv, + qsc, + knv, + ksc, + vnv, + vsc, + z, + h, + hk, + s_q, + s_kv, + d, + scaling, + causal, + bias, + out, + block_m, + block_n, + num_warps, + num_stages, out_zshd=out_zshd, ) if out_zshd: @@ -1272,7 +1549,8 @@ def nvfp4_flash_attention( out = torch.empty(z * h, s_q, d, device=query.device, dtype=out_dtype) lse = ( torch.empty(z * h, s_q, device=query.device, dtype=torch.float32) - if return_lse else out + if return_lse + else out ) qnv_v = qnv.view(torch.uint8) @@ -1286,28 +1564,43 @@ def nvfp4_flash_attention( def _run(): _flash_fwd_kernel[grid]( - qnv_v, qsc_v, knv_v, ksc_v, vnv_v, vsc_v, + qnv_v, + qsc_v, + knv_v, + ksc_v, + vnv_v, + vsc_v, bias if bias is not None else qnv_v, # dummy ptr when no bias out, lse, scaling, - s_q, s_kv, + s_q, + s_kv, D=d, - H=h, HK=hk, - sq_qn=qnv_v.stride(1), sq_sn=qsc_v.stride(1), - sk_kn=knv_v.stride(1), sk_sn=ksc_v.stride(1), - sv_kn=vnv_v.stride(1), sv_sn=vsc_v.stride(1), + H=h, + HK=hk, + sq_qn=qnv_v.stride(1), + sq_sn=qsc_v.stride(1), + sk_kn=knv_v.stride(1), + sk_sn=ksc_v.stride(1), + sv_kn=vnv_v.stride(1), + sv_sn=vsc_v.stride(1), sb_z=bias.stride(0) if bias is not None else 0, so_n=out.stride(1), - so_z=0, so_h=0, + so_z=0, + so_h=0, HAS_BIAS=bias is not None, CAUSAL=causal, STORE_LSE=return_lse, OUT_ZSHD=False, - BLOCK_M=block_m, BLOCK_N=block_n, - DP2=d // 2, DP16=d // 16, - NP2=block_n // 2, NP16=block_n // 16, - num_warps=num_warps, num_stages=num_stages, + BLOCK_M=block_m, + BLOCK_N=block_n, + DP2=d // 2, + DP16=d // 16, + NP2=block_n // 2, + NP16=block_n // 16, + num_warps=num_warps, + num_stages=num_stages, ) _run() @@ -1324,15 +1617,40 @@ def _run(): # Full forward + native-NVFP4 backward as a torch.autograd.Function. # --------------------------------------------------------------------------- def _run_bwd( - q, k, v, do, o, bias, - z, h, hk, s_q, s_kv, d, scaling, causal, sr, - block_m, block_n, num_warps, num_stages, + q, + k, + v, + do, + o, + bias, + z, + h, + hk, + s_q, + s_kv, + d, + scaling, + causal, + sr, + block_m, + block_n, + num_warps, + num_stages, lse=None, - sr_p_dv=None, sr_dot_dv=None, sr_ds_dq=None, + sr_p_dv=None, + sr_dot_dv=None, + sr_ds_dq=None, dkdv_scratch_bf16=False, - qnv_saved=None, qsc_saved=None, qtnv_saved=None, qtsc_saved=None, - knv_saved=None, ksc_saved=None, vnv_saved=None, vsc_saved=None, - ktnv_saved=None, ktsc_saved=None, + qnv_saved=None, + qsc_saved=None, + qtnv_saved=None, + qtsc_saved=None, + knv_saved=None, + ksc_saved=None, + vnv_saved=None, + vsc_saved=None, + ktnv_saved=None, + ktsc_saved=None, ): """Native-NVFP4 backward. q/do/o: [Z*H,Sq,D]; k/v: [Z*Hk,Skv,D] (hp). Returns dq [Z*H,Sq,D], dk/dv [Z*H,Skv,D] (per query head; GQA-reduced by the caller). @@ -1387,14 +1705,32 @@ def _run_bwd( if not have_lse: _flash_bwd_prep_kernel[(triton.cdiv(s_q, block_m), z * h)]( - q, k, do, o, bdummy, lse, delta, - scaling, seed, s_q, s_kv, - D=d, H=h, HK=hk, - sq_n=q.stride(1), sk_n=k.stride(1), sdo_n=do.stride(1), so_n=o.stride(1), + q, + k, + do, + o, + bdummy, + lse, + delta, + scaling, + seed, + s_q, + s_kv, + D=d, + H=h, + HK=hk, + sq_n=q.stride(1), + sk_n=k.stride(1), + sdo_n=do.stride(1), + so_n=o.stride(1), sb_z=sb_z, - HAS_BIAS=has_bias, CAUSAL=causal, HAVE_LSE=have_lse, - BLOCK_M=block_m, BLOCK_N=dq_block_n, - num_warps=dq_warps, num_stages=num_stages, + HAS_BIAS=has_bias, + CAUSAL=causal, + HAVE_LSE=have_lse, + BLOCK_M=block_m, + BLOCK_N=dq_block_n, + num_warps=dq_warps, + num_stages=num_stages, ) # Pack-prep: quantize the dK/dV pass's m-block-local operands ONCE here (q/qT @@ -1425,14 +1761,33 @@ def _run_bwd( dotnv_p = q.new_empty(z * h, d, s_q_pad // 2, dtype=torch.uint8) dotsc_p = q.new_zeros(z * h, d, s_q_pad // 16, dtype=torch.uint8) _flash_bwd_packprep_kernel[(triton.cdiv(s_q, pp_block_m), z * h)]( - q, do, o, delta, - qnv_p, qsc_p, qtnv_p, qtsc_p, - donv_p, dosc_p, dotnv_p, dotsc_p, - seed, s_q, s_q_pad, - D=d, sq_n=q.stride(1), sdo_n=do.stride(1), so_n=o.stride(1), - SR_DO=sr, SR_DOT=sr_dot_dv, WRITE_DELTA=have_lse, - STORE_Q=not reuse_q_pack, STORE_QT=not reuse_qt_pack, BLOCK_M=pp_block_m, - num_warps=8, num_stages=2, + q, + do, + o, + delta, + qnv_p, + qsc_p, + qtnv_p, + qtsc_p, + donv_p, + dosc_p, + dotnv_p, + dotsc_p, + seed, + s_q, + s_q_pad, + D=d, + sq_n=q.stride(1), + sdo_n=do.stride(1), + so_n=o.stride(1), + SR_DO=sr, + SR_DOT=sr_dot_dv, + WRITE_DELTA=have_lse, + STORE_Q=not reuse_q_pack, + STORE_QT=not reuse_qt_pack, + BLOCK_M=pp_block_m, + num_warps=8, + num_stages=2, ) if not reuse_q_pack: qsc_p = qsc_p.view(torch.float8_e4m3fn) @@ -1465,39 +1820,101 @@ def _run_bwd( ktsc_p = k.new_zeros(z * hk, d, s_kv_pad // 16, dtype=torch.uint8) if not (reuse_k_pack and reuse_v_pack and reuse_kt_pack): _flash_bwd_kprep_kernel[(triton.cdiv(s_kv, kprep_block_n), z * hk)]( - k, v, - knv_p, ksc_p, vnv_p, vsc_p, ktnv_p, ktsc_p, - seed, s_kv, s_kv_pad, - D=d, sk_n=k.stride(1), sv_n=v.stride(1), - STORE_K=not reuse_k_pack, STORE_V=not reuse_v_pack, STORE_KT=not reuse_kt_pack, + k, + v, + knv_p, + ksc_p, + vnv_p, + vsc_p, + ktnv_p, + ktsc_p, + seed, + s_kv, + s_kv_pad, + D=d, + sk_n=k.stride(1), + sv_n=v.stride(1), + STORE_K=not reuse_k_pack, + STORE_V=not reuse_v_pack, + STORE_KT=not reuse_kt_pack, BLOCK_N=kprep_block_n, - num_warps=dq_warps, num_stages=num_stages, + num_warps=dq_warps, + num_stages=num_stages, ) ksc_pv = ksc_p.view(torch.uint8) vsc_pv = vsc_p.view(torch.uint8) ktsc_pv = ktsc_p.view(torch.uint8) _flash_bwd_dkdv_kernel[(triton.cdiv(s_kv, dkdv_block_n), z * h)]( - qnv_p, qsc_p.view(torch.uint8), qtnv_p, qtsc_p.view(torch.uint8), - donv_p, dosc_p.view(torch.uint8), dotnv_p, dotsc_p.view(torch.uint8), - knv_p, ksc_pv, vnv_p, vsc_pv, bdummy, lse, delta, dk, dv, - scaling, seed, s_q, s_q_pad, s_kv, - D=d, H=h, HK=hk, - sb_z=sb_z, sdk_n=dk.stride(1), sdv_n=dv.stride(1), - HAS_BIAS=has_bias, CAUSAL=causal, SR=sr, SR_P_DV=sr_p_dv, - BLOCK_M=block_m, BLOCK_N=dkdv_block_n, - num_warps=dkdv_warps, num_stages=dkdv_stages, + qnv_p, + qsc_p.view(torch.uint8), + qtnv_p, + qtsc_p.view(torch.uint8), + donv_p, + dosc_p.view(torch.uint8), + dotnv_p, + dotsc_p.view(torch.uint8), + knv_p, + ksc_pv, + vnv_p, + vsc_pv, + bdummy, + lse, + delta, + dk, + dv, + scaling, + seed, + s_q, + s_q_pad, + s_kv, + D=d, + H=h, + HK=hk, + sb_z=sb_z, + sdk_n=dk.stride(1), + sdv_n=dv.stride(1), + HAS_BIAS=has_bias, + CAUSAL=causal, + SR=sr, + SR_P_DV=sr_p_dv, + BLOCK_M=block_m, + BLOCK_N=dkdv_block_n, + num_warps=dkdv_warps, + num_stages=dkdv_stages, ) _flash_bwd_dq_kernel[(triton.cdiv(s_q, dq_block_m), z * h)]( - qnv_p, qsc_p.view(torch.uint8), donv_p, dosc_p.view(torch.uint8), bdummy, - knv_p, ksc_pv, vnv_p, vsc_pv, ktnv_p, ktsc_pv, - lse, delta, dq, - scaling, seed, s_q, s_kv, s_kv_pad, - D=d, H=h, HK=hk, - sb_z=sb_z, sdq_n=dq.stride(1), - HAS_BIAS=has_bias, CAUSAL=causal, SR_DS_DQ=sr_ds_dq, - BLOCK_M=dq_block_m, BLOCK_N=dq_block_n, - num_warps=dq_warps, num_stages=dq_stages, + qnv_p, + qsc_p.view(torch.uint8), + donv_p, + dosc_p.view(torch.uint8), + bdummy, + knv_p, + ksc_pv, + vnv_p, + vsc_pv, + ktnv_p, + ktsc_pv, + lse, + delta, + dq, + scaling, + seed, + s_q, + s_kv, + s_kv_pad, + D=d, + H=h, + HK=hk, + sb_z=sb_z, + sdq_n=dq.stride(1), + HAS_BIAS=has_bias, + CAUSAL=causal, + SR_DS_DQ=sr_ds_dq, + BLOCK_M=dq_block_m, + BLOCK_N=dq_block_n, + num_warps=dq_warps, + num_stages=dq_stages, ) return dq, dk, dv @@ -1505,27 +1922,57 @@ def _run_bwd( class _NVFP4FlashAttn(torch.autograd.Function): @staticmethod def forward( - ctx, query, key, value, scaling, causal, num_key_value_groups, - key_pad_bias, sr, save_backward_packs, - backward_p_dv_sr, backward_dot_dv_sr, backward_ds_dq_sr, + ctx, + query, + key, + value, + scaling, + causal, + num_key_value_groups, + key_pad_bias, + sr, + save_backward_packs, + backward_p_dv_sr, + backward_dot_dv_sr, + backward_ds_dq_sr, dkdv_scratch_bf16, - block_m, block_n, num_warps, num_stages, + block_m, + block_n, + num_warps, + num_stages, ): z, h, s_q, d = query.shape _, hk, s_kv, _ = key.shape if save_backward_packs: out, lse, packs = nvfp4_flash_attention( - query, key, value, scaling, causal=causal, - num_key_value_groups=num_key_value_groups, key_pad_bias=key_pad_bias, - block_m=block_m, block_n=block_n, num_warps=num_warps, num_stages=num_stages, - return_lse=True, return_packs=True, + query, + key, + value, + scaling, + causal=causal, + num_key_value_groups=num_key_value_groups, + key_pad_bias=key_pad_bias, + block_m=block_m, + block_n=block_n, + num_warps=num_warps, + num_stages=num_stages, + return_lse=True, + return_packs=True, ) qnv, qsc, qtnv, qtsc, knv, ksc, vdnv, vdsc, ktnv, ktsc = packs else: out, lse = nvfp4_flash_attention( - query, key, value, scaling, causal=causal, - num_key_value_groups=num_key_value_groups, key_pad_bias=key_pad_bias, - block_m=block_m, block_n=block_n, num_warps=num_warps, num_stages=num_stages, + query, + key, + value, + scaling, + causal=causal, + num_key_value_groups=num_key_value_groups, + key_pad_bias=key_pad_bias, + block_m=block_m, + block_n=block_n, + num_warps=num_warps, + num_stages=num_stages, return_lse=True, ) qnv = qsc = qtnv = qtsc = knv = ksc = vdnv = vdsc = ktnv = ktsc = ( @@ -1580,9 +2027,22 @@ def forward( @staticmethod def backward(ctx, grad_out): ( - q, k, v, o, bias, lse, - qnv, qsc, qtnv, qtsc, - knv, ksc, vdnv, vdsc, ktnv, ktsc, + q, + k, + v, + o, + bias, + lse, + qnv, + qsc, + qtnv, + qtsc, + knv, + ksc, + vdnv, + vdsc, + ktnv, + ktsc, ) = ctx.saved_tensors z, h, hk, s_q, s_kv, d = ctx.dims block_m, block_n, num_warps, num_stages = ctx.tiles @@ -1595,9 +2055,26 @@ def backward(ctx, grad_out): q = k = v = o do = grad_out.reshape(z * h, s_q, d).contiguous() dq, dk, dv = _run_bwd( - q, k, v, do, o, bias, - z, h, hk, s_q, s_kv, d, ctx.scaling, ctx.causal, ctx.sr, - block_m, block_n, 4, 1, lse=lse, + q, + k, + v, + do, + o, + bias, + z, + h, + hk, + s_q, + s_kv, + d, + ctx.scaling, + ctx.causal, + ctx.sr, + block_m, + block_n, + 4, + 1, + lse=lse, sr_p_dv=ctx.backward_p_dv_sr, sr_dot_dv=ctx.backward_dot_dv_sr, sr_ds_dq=ctx.backward_ds_dq_sr, @@ -1621,10 +2098,23 @@ def backward(ctx, grad_out): dk = dk.reshape(z, hk, s_kv, d).to(grad_out.dtype) dv = dv.reshape(z, hk, s_kv, d).to(grad_out.dtype) return ( - dq, dk, dv, - None, None, None, None, None, - None, None, None, None, None, - None, None, None, None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, ) @@ -1660,10 +2150,21 @@ def nvfp4_flash_attn_func( assert h % hk == 0 and h // hk == num_key_value_groups assert d in (128, 256) return _NVFP4FlashAttn.apply( - query, key, value, scaling, causal, num_key_value_groups, - key_pad_bias, stochastic_rounding, save_backward_packs, - backward_p_dv_stochastic_rounding, backward_dot_dv_stochastic_rounding, + query, + key, + value, + scaling, + causal, + num_key_value_groups, + key_pad_bias, + stochastic_rounding, + save_backward_packs, + backward_p_dv_stochastic_rounding, + backward_dot_dv_stochastic_rounding, backward_ds_dq_stochastic_rounding, dkdv_scratch_bf16, - block_m, block_n, num_warps, num_stages, + block_m, + block_n, + num_warps, + num_stages, ) diff --git a/src/axolotl/kernels/bf16_fused_ce.py b/src/axolotl/kernels/bf16_fused_ce.py index 5bd377e942..3984595d4f 100644 --- a/src/axolotl/kernels/bf16_fused_ce.py +++ b/src/axolotl/kernels/bf16_fused_ce.py @@ -64,7 +64,9 @@ def forward(ctx, hidden, weight, labels, ignore_index, logit_scale, grad_scale): valid = labels != ignore_index safe_labels = torch.where(valid, labels, labels.new_zeros(())) - running_max = torch.full((M,), float("-inf"), device=device, dtype=torch.float32) + running_max = torch.full( + (M,), float("-inf"), device=device, dtype=torch.float32 + ) running_sum = torch.zeros(M, device=device, dtype=torch.float32) label_logit = torch.zeros(M, device=device, dtype=torch.float32) diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 7e0275bb8a..a4d2c5211a 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr" """ Module for definition of Low-Rank Adaptation (LoRA) Triton kernels. @@ -64,7 +65,11 @@ def get_lora_parameters( if is_nvfp4_base(base_layer): b = base_layer.bias - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if ( + not hasattr(proj, "disable_adapters") + or proj.disable_adapters + or proj.merged + ): return None, b, base_layer, None, None, None, None, None, None quant_state = base_layer W = None @@ -72,7 +77,11 @@ def get_lora_parameters( W = base_layer.weight b = base_layer.bias - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if ( + not hasattr(proj, "disable_adapters") + or proj.disable_adapters + or proj.merged + ): quant_state = getattr(W, "quant_state", None) if quant_state is None and W.dtype == torch.float8_e4m3fn: quant_state = getattr(base_layer, "weight_scale_inv", None) @@ -296,8 +305,10 @@ def matmul_lora( if is_nvfp4: # FP4 base GEMM against the NVFP4 module (W is None); the `out=` buffer # cannot be reused since the FP4 GEMM allocates its own output. - out = nvfp4_base_dgrad(X, W_quant) if nvfp4_dgrad else nvfp4_base_fprop( - X, W_quant + out = ( + nvfp4_base_dgrad(X, W_quant) + if nvfp4_dgrad + else nvfp4_base_fprop(X, W_quant) ) else: W = dequantize(W.t(), W_quant) diff --git a/src/axolotl/kernels/nvfp4_fused_producers.py b/src/axolotl/kernels/nvfp4_fused_producers.py index 77bfb1cb41..3d7a549d56 100644 --- a/src/axolotl/kernels/nvfp4_fused_producers.py +++ b/src/axolotl/kernels/nvfp4_fused_producers.py @@ -27,7 +27,6 @@ import torch import triton import triton.language as tl - from mslk.quantize.triton.fp4_quantize import convert_fp32_to_fp4_packed _E4M3_EPS = tl.constexpr(1.5258789e-05) @@ -43,19 +42,32 @@ # --------------------------------------------------------------------------- @triton.jit def _rope_quant_kernel( - x_ptr, # [Z*H, S, D] high-precision (normed) Q or K - cos_ptr, sin_ptr, # [Z, S, ROT] (ROT = rotary_dim) - q_ptr, s_ptr, # [Z*H, S, D//2] uint8, [Z*H, S, D//16] e4m3 - Z, H, S, - s_xz, s_xh, s_xr, # x: per-z, per-h, per-row(seq) strides; col(D) stride = 1 - s_cz, s_cr, # cos/sin: per-z stride, per-row stride; col stride = 1 - s_qn, s_qr, # q packed: per-(z*h) stride, per-row stride - s_sn, s_sr, # scale: per-(z*h) stride, per-row stride - D: tl.constexpr, ROT: tl.constexpr, HALF: tl.constexpr, + x_ptr, # [Z*H, S, D] high-precision (normed) Q or K + cos_ptr, + sin_ptr, # [Z, S, ROT] (ROT = rotary_dim) + q_ptr, + s_ptr, # [Z*H, S, D//2] uint8, [Z*H, S, D//16] e4m3 + Z, + H, + S, + s_xz, + s_xh, + s_xr, # x: per-z, per-h, per-row(seq) strides; col(D) stride = 1 + s_cz, + s_cr, # cos/sin: per-z stride, per-row stride; col stride = 1 + s_qn, + s_qr, # q packed: per-(z*h) stride, per-row stride + s_sn, + s_sr, # scale: per-(z*h) stride, per-row stride + D: tl.constexpr, + ROT: tl.constexpr, + HALF: tl.constexpr, BLOCK_R: tl.constexpr, - DP2: tl.constexpr, DP16: tl.constexpr, NG: tl.constexpr, + DP2: tl.constexpr, + DP16: tl.constexpr, + NG: tl.constexpr, ): - pid_n = tl.program_id(0) # z*h + pid_n = tl.program_id(0) # z*h pid_r = tl.program_id(1) z = pid_n // H h = pid_n % H @@ -67,7 +79,8 @@ def _rope_quant_kernel( xbase = z * s_xz + h * s_xh x = tl.load( x_ptr + xbase + offs_r[:, None] * s_xr + offs_d[None, :], - mask=rmask[:, None], other=0.0, + mask=rmask[:, None], + other=0.0, ).to(tl.float32) # partial RoPE on the first ROT dims; tail [ROT, D) passes through unrotated. @@ -77,7 +90,8 @@ def _rope_quant_kernel( partner = tl.where(is_low, offs_d + HALF, offs_d - HALF) xp = tl.load( x_ptr + xbase + offs_r[:, None] * s_xr + partner[None, :], - mask=rmask[:, None] & (partner[None, :] < ROT), other=0.0, + mask=rmask[:, None] & (partner[None, :] < ROT), + other=0.0, ).to(tl.float32) rot = tl.where(is_low, -xp, xp) @@ -86,11 +100,13 @@ def _rope_quant_kernel( cd = tl.where(is_rot, offs_d, 0) cos = tl.load( cos_ptr + cbase + offs_r[:, None] * s_cr + cd[None, :], - mask=rmask[:, None], other=0.0, + mask=rmask[:, None], + other=0.0, ).to(tl.float32) sin = tl.load( sin_ptr + cbase + offs_r[:, None] * s_cr + cd[None, :], - mask=rmask[:, None], other=0.0, + mask=rmask[:, None], + other=0.0, ).to(tl.float32) x_rot = tl.where(is_rot[None, :], x * cos + rot * sin, x) @@ -106,12 +122,14 @@ def _rope_quant_kernel( offs_qk = tl.arange(0, DP2) tl.store( q_ptr + pid_n * s_qn + offs_r[:, None] * s_qr + offs_qk[None, :], - qpk, mask=rmask[:, None], + qpk, + mask=rmask[:, None], ) offs_sk = tl.arange(0, DP16) tl.store( s_ptr + pid_n * s_sn + offs_r[:, None] * s_sr + offs_sk[None, :], - sc.to(tl.uint8, bitcast=True), mask=rmask[:, None], + sc.to(tl.uint8, bitcast=True), + mask=rmask[:, None], ) @@ -146,15 +164,30 @@ def fused_rope_quant_qk( BLOCK_R = 64 grid = (z * h, triton.cdiv(s, BLOCK_R)) _rope_quant_kernel[grid]( - x, cos, sin, q, sc, - z, h, s, - x.stride(0), x.stride(1), x.stride(2), - cos.stride(0), cos.stride(1), - q.stride(0), q.stride(1), - sc.stride(0), sc.stride(1), - D=d, ROT=rot, HALF=rot // 2, + x, + cos, + sin, + q, + sc, + z, + h, + s, + x.stride(0), + x.stride(1), + x.stride(2), + cos.stride(0), + cos.stride(1), + q.stride(0), + q.stride(1), + sc.stride(0), + sc.stride(1), + D=d, + ROT=rot, + HALF=rot // 2, BLOCK_R=BLOCK_R, - DP2=d // 2, DP16=d // 16, NG=d // 16, + DP2=d // 2, + DP16=d // 16, + NG=d // 16, ) return q, sc.view(torch.float8_e4m3fn) @@ -193,18 +226,30 @@ def quant_v_keyaxis( # --------------------------------------------------------------------------- @triton.jit def _vproj_pack_keyaxis_kernel( - xnv_ptr, xsc_ptr, # [Z, S, K//2] uint8, [Z, S, K//16] e4m3 (activation) - wnv_ptr, wsc_ptr, # [HK*D, K//2] uint8, [HK*D, K//16] e4m3 (Wv, [HK*D,K]) - vnv_ptr, vsc_ptr, # [Z*HK, D, S_pad//2] uint8, [Z*HK, D, S_pad//16] e4m3 - S, S_pad, K, - sx_z, sx_s, # x packed: per-z, per-row(seq) strides - ssc_z, ssc_s, # x scale: per-z, per-row strides - sw_n, # weight packed row stride (= K//2); scale row = K//16 - sv_d, svsc_d, # vnv/vsc per-row(D) strides (= S_pad//2, S_pad//16) - HK: tl.constexpr, D: tl.constexpr, - BLOCK_S: tl.constexpr, BLOCK_K: tl.constexpr, - KP2: tl.constexpr, KP16: tl.constexpr, - SP2: tl.constexpr, SP16: tl.constexpr, # BLOCK_S//2, BLOCK_S//16 + xnv_ptr, + xsc_ptr, # [Z, S, K//2] uint8, [Z, S, K//16] e4m3 (activation) + wnv_ptr, + wsc_ptr, # [HK*D, K//2] uint8, [HK*D, K//16] e4m3 (Wv, [HK*D,K]) + vnv_ptr, + vsc_ptr, # [Z*HK, D, S_pad//2] uint8, [Z*HK, D, S_pad//16] e4m3 + S, + S_pad, + K, + sx_z, + sx_s, # x packed: per-z, per-row(seq) strides + ssc_z, + ssc_s, # x scale: per-z, per-row strides + sw_n, # weight packed row stride (= K//2); scale row = K//16 + sv_d, + svsc_d, # vnv/vsc per-row(D) strides (= S_pad//2, S_pad//16) + HK: tl.constexpr, + D: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_K: tl.constexpr, + KP2: tl.constexpr, + KP16: tl.constexpr, + SP2: tl.constexpr, + SP16: tl.constexpr, # BLOCK_S//2, BLOCK_S//16 ): pid_z = tl.program_id(0) pid_hk = tl.program_id(1) @@ -213,7 +258,7 @@ def _vproj_pack_keyaxis_kernel( offs_s = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) smask = offs_s < S - offs_d = tl.arange(0, D) # this head's D output cols = W rows [hk*D, (hk+1)*D) + offs_d = tl.arange(0, D) # this head's D output cols = W rows [hk*D, (hk+1)*D) wrow = pid_hk * D + offs_d acc = tl.zeros((BLOCK_S, D), dtype=tl.float32) @@ -222,11 +267,13 @@ def _vproj_pack_keyaxis_kernel( offk16 = k0 // 16 + tl.arange(0, KP16) a = tl.load( xnv_ptr + pid_z * sx_z + offs_s[:, None] * sx_s + offk2[None, :], - mask=smask[:, None], other=0, + mask=smask[:, None], + other=0, ) asc = tl.load( xsc_ptr + pid_z * ssc_z + offs_s[:, None] * ssc_s + offk16[None, :], - mask=smask[:, None], other=0, + mask=smask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) w = tl.load(wnv_ptr + wrow[:, None] * sw_n + offk2[None, :]) wsc = tl.load( @@ -239,7 +286,7 @@ def _vproj_pack_keyaxis_kernel( # pack along the SEQ axis (group-16): transpose the [BLOCK_S, D] tile to [D, BLOCK_S] # so groups of 16 run down the seq axis, matching the V^T key-axis layout. - accT = tl.trans(acc) # [D, BLOCK_S] + accT = tl.trans(acc) # [D, BLOCK_S] NG: tl.constexpr = BLOCK_S // 16 xb = accT.reshape(D, NG, 16) amax = tl.max(tl.abs(xb), axis=2) @@ -275,9 +322,9 @@ def prepack_vproj_weight(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tens def fused_vproj_quant_v_keyaxis( - hidden_states: torch.Tensor, # [Z, S, hidden] - wnv: torch.Tensor, # [HK*D, hidden//2] uint8 (prepacked v_proj weight) - wsc: torch.Tensor, # [HK*D, hidden//16] e4m3 + hidden_states: torch.Tensor, # [Z, S, hidden] + wnv: torch.Tensor, # [HK*D, hidden//2] uint8 (prepacked v_proj weight) + wsc: torch.Tensor, # [HK*D, hidden//16] e4m3 hk: int, d: int, block_n: int = 128, @@ -310,18 +357,31 @@ def fused_vproj_quant_v_keyaxis( grid = (z, hk, triton.cdiv(s, block_s)) _vproj_pack_keyaxis_kernel[grid]( - xnv.view(torch.uint8), xsc, - wnv.view(torch.uint8), wsc.view(torch.uint8), - vnv, vsc, - s, s_pad, k, - xnv.stride(0), xnv.stride(1), - xsc.stride(0), xsc.stride(1), + xnv.view(torch.uint8), + xsc, + wnv.view(torch.uint8), + wsc.view(torch.uint8), + vnv, + vsc, + s, + s_pad, + k, + xnv.stride(0), + xnv.stride(1), + xsc.stride(0), + xsc.stride(1), wnv.stride(0), - vnv.stride(1), vsc.stride(1), - HK=hk, D=d, - BLOCK_S=block_s, BLOCK_K=block_k, - KP2=block_k // 2, KP16=block_k // 16, - SP2=block_s // 2, SP16=block_s // 16, - num_warps=num_warps, num_stages=num_stages, + vnv.stride(1), + vsc.stride(1), + HK=hk, + D=d, + BLOCK_S=block_s, + BLOCK_K=block_k, + KP2=block_k // 2, + KP16=block_k // 16, + SP2=block_s // 2, + SP16=block_s // 16, + num_warps=num_warps, + num_stages=num_stages, ) return vnv, vsc.view(torch.float8_e4m3fn), s_pad diff --git a/src/axolotl/kernels/nvfp4_rmsnorm.py b/src/axolotl/kernels/nvfp4_rmsnorm.py index 26364d6714..c6e4a7b91b 100644 --- a/src/axolotl/kernels/nvfp4_rmsnorm.py +++ b/src/axolotl/kernels/nvfp4_rmsnorm.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="arg-type" """Fused RMSNorm -> NVFP4 quantization for consumer Blackwell (sm_120). MSLK ships a fused rms+quant kernel (`triton_scale_nvfp4_quant_rms`) but it uses a @@ -268,10 +269,11 @@ def from_norm(cls, norm: nn.Module, eps_attr: str = "variance_epsilon"): w = norm.weight with torch.no_grad(): x = torch.randn(8, w.shape[-1], device=w.device, dtype=w.dtype) - normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps) + normed = x.float() * torch.rsqrt( + x.float().pow(2).mean(-1, keepdim=True) + eps + ) y = norm(x).float() e_plain = (y - normed * w.float()).abs().mean() e_zc = (y - normed * (1.0 + w.float())).abs().mean() zero_centered = bool(e_zc < e_plain) return cls(w, eps, zero_centered) - diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 2bee09ba5e..90c1dc17ae 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -378,7 +378,8 @@ def _apply_nvfp4_training(self, model: PreTrainedModel): # FP4-storage path, so a requested storage/compute base_mode would # silently give no memory saving. Steer the user to backend: native. if adapter in ("lora", "qlora") and ( - nvfp4.quantize_base or getattr(nvfp4, "base_mode", None) in ("storage", "compute") + nvfp4.quantize_base + or getattr(nvfp4, "base_mode", None) in ("storage", "compute") ): LOG.warning( "nvfp4_training.backend: te ignores base_mode/quantize_base " @@ -767,10 +768,10 @@ def _nvfp4_apply_tied_or_lm_head(self, model, recipe, base_mode: str) -> None: # lm_head store; force the torchao storage class for it. Otherwise use # the requested base mode (compute/storage/hp). if bool(getattr(nvfp4, "fused_fp4_cross_entropy", False)): - from axolotl.utils.nvfp4_training import swap_frozen_lm_head_tileable - import torch.nn as _nn + from axolotl.utils.nvfp4_training import swap_frozen_lm_head_tileable + out_emb = model.get_output_embeddings() if isinstance(out_emb, _nn.Linear): name = self._module_name(model, out_emb) diff --git a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py index 1cb714409f..dd4a21d78b 100644 --- a/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py +++ b/src/axolotl/monkeypatch/attention/nvfp4_flash_attn.py @@ -57,9 +57,9 @@ def _custom_op_enabled(module: nn.Module) -> bool: - requested = getattr( - module, "_nvfp4_compile_custom_op", False - ) or os.environ.get(_CUSTOM_OP_ENV, "").lower() in {"1", "true", "yes"} + requested = getattr(module, "_nvfp4_compile_custom_op", False) or os.environ.get( + _CUSTOM_OP_ENV, "" + ).lower() in {"1", "true", "yes"} return bool(requested) and torch.compiler.is_compiling() @@ -73,9 +73,7 @@ def _causal_tril(q_len: int, kv_len: int, device: torch.device) -> torch.Tensor: key = (q_len, kv_len, device) t = _TRIL_CACHE.get(key) if t is None: - t = torch.tril( - torch.ones(q_len, kv_len, dtype=torch.bool, device=device) - ) + t = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool, device=device)) _TRIL_CACHE[key] = t return t @@ -150,10 +148,10 @@ def _can_fuse_vproj(module: nn.Module) -> bool: def _nvfp4_attention( module: nn.Module, - query_states: torch.Tensor, # [Z, H, S, D] post q_norm, PRE-RoPE - key_states: torch.Tensor, # [Z, Hk, S, D] post k_norm, PRE-RoPE - value_states: torch.Tensor, # [Z, Hk, Skv, D] - cos: torch.Tensor, # [Z, S, rotary_dim] + query_states: torch.Tensor, # [Z, H, S, D] post q_norm, PRE-RoPE + key_states: torch.Tensor, # [Z, Hk, S, D] post k_norm, PRE-RoPE + value_states: torch.Tensor, # [Z, Hk, Skv, D] + cos: torch.Tensor, # [Z, S, rotary_dim] sin: torch.Tensor, scaling: float, causal: bool, @@ -183,8 +181,18 @@ def _nvfp4_attention( # The compile custom op keeps the [Z, H, S, D] schema (its registered op / # fake are layout-fixed), so the transpose stays on the compiled path. out = nvfp4_flash_attention_packed_custom_op( - qnv, qsc, knv, ksc, vnv, vsc, - z=z, h=h, hk=hk, s_q=s_q, s_kv=s_kv, d=d, + qnv, + qsc, + knv, + ksc, + vnv, + vsc, + z=z, + h=h, + hk=hk, + s_q=s_q, + s_kv=s_kv, + d=d, scaling=scaling, out_dtype=query_states.dtype, causal=causal, @@ -195,8 +203,18 @@ def _nvfp4_attention( # Eager path: the kernel writes the [Z, S, H, D] HF attn_output layout directly, # so the per-layer transpose(1,2)+contiguous copy at the caller is eliminated. return nvfp4_flash_attention_packed( - qnv, qsc, knv, ksc, vnv, vsc, - z=z, h=h, hk=hk, s_q=s_q, s_kv=s_kv, d=d, + qnv, + qsc, + knv, + ksc, + vnv, + vsc, + z=z, + h=h, + hk=hk, + s_q=s_q, + s_kv=s_kv, + d=d, scaling=scaling, out_dtype=query_states.dtype, causal=causal, @@ -276,6 +294,7 @@ def make_nvfp4_forward(orig_forward): ``train_backward`` was enabled at patch time, grad-enabled dense prefill uses the differentiable native-NVFP4 attention function. """ + def forward( self, hidden_states, @@ -287,8 +306,12 @@ def forward( grad_enabled = torch.is_grad_enabled() if kwargs.get("output_attentions"): return orig_forward( - self, hidden_states, position_embeddings, attention_mask, - past_key_values, **kwargs, + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_values, + **kwargs, ) input_shape = hidden_states.shape[:-1] @@ -321,12 +344,15 @@ def forward( if use_fp4_qk: q_full, k_full = _nvfp4_qk_proj(self, hidden_states) query_states, gate = torch.chunk( - q_full.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1, + q_full.view(*input_shape, -1, self.head_dim * 2), + 2, + dim=-1, ) else: query_states, gate = torch.chunk( self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), - 2, dim=-1, + 2, + dim=-1, ) gate = gate.reshape(*input_shape, -1) @@ -351,8 +377,12 @@ def forward( or past_key_values is not None ): return orig_forward( - self, hidden_states, position_embeddings, attention_mask, - past_key_values, **kwargs, + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_values, + **kwargs, ) # Break the Inductor graph on BOTH sides of the FP4 attention block. # Fused with the FP4 plugin's quantized q/k/v/o_proj autograd, the @@ -432,14 +462,18 @@ def forward( self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) ) if past_key_values is not None: - k_roped, _ = apply_rotary_pos_emb( - key_states, key_states, cos, sin - ) + k_roped, _ = apply_rotary_pos_emb(key_states, key_states, cos, sin) past_key_values.update(k_roped, value_states, self.layer_idx) attn_output = _nvfp4_attention( - self, query_states, key_states, value_states, cos, sin, - self.scaling, causal=(kind == "causal"), + self, + query_states, + key_states, + value_states, + cos, + sin, + self.scaling, + causal=(kind == "causal"), hidden_states=hidden_states if fuse_v else None, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -449,8 +483,12 @@ def forward( return self.o_proj(attn_output), None return orig_forward( - self, hidden_states, position_embeddings, attention_mask, - past_key_values, **kwargs, + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_values, + **kwargs, ) return forward @@ -517,7 +555,12 @@ def patch_qwen3_5_nvfp4_attention( "nvfp4 attention: patched %d Qwen3.5 full-attention layers " "(fuse_vproj=%s, train_backward=%s, backward_rtn_grad_packs=%s, " "save_backward_packs=%s, dkdv_scratch_bf16=%s, compile_custom_op=%s)", - patched, fuse_vproj, train_backward, backward_rtn_grad_packs, - save_backward_packs, dkdv_scratch_bf16, compile_custom_op, + patched, + fuse_vproj, + train_backward, + backward_rtn_grad_packs, + save_backward_packs, + dkdv_scratch_bf16, + compile_custom_op, ) return patched diff --git a/src/axolotl/utils/nvfp4_training.py b/src/axolotl/utils/nvfp4_training.py index 9d43e5610e..eb1ce0b53f 100644 --- a/src/axolotl/utils/nvfp4_training.py +++ b/src/axolotl/utils/nvfp4_training.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="assignment, misc" """NVFP4-GEMM training: real FP4 forward + backward GEMMs on Blackwell. This is a throughput feature, not a memory feature: master weights and @@ -24,8 +25,8 @@ from dataclasses import dataclass import torch -from torch import nn import triton +from torch import nn from triton import language as tl LOG = logging.getLogger(__name__) @@ -482,7 +483,9 @@ def backward(ctx, grad_out): # GEMM (the stored layout is quantized along K for the forward). w_hp = w_q.dequantize(torch.bfloat16) grad_x = _fp4_mm( - g_p, w_hp, QuantPolicy(stochastic=ctx.recipe.stochastic_rounding), + g_p, + w_hp, + QuantPolicy(stochastic=ctx.recipe.stochastic_rounding), QuantPolicy(), )[:m] grad_x = grad_x.reshape(ctx.x_shape) @@ -521,7 +524,11 @@ def fsdp_post_all_gather( ctx, per_tensor_scale = metadata if out is not None: return - inner = {"qdata": qdata, "scale": scale, "per_tensor_scale": per_tensor_scale} + inner = { + "qdata": qdata, + "scale": scale, + "per_tensor_scale": per_tensor_scale, + } rebuilt = type(self).__tensor_unflatten__(inner, ctx, None, None) return rebuilt, (qdata, scale) @@ -776,9 +783,9 @@ def backward(ctx, grad_out): g_p, m = _pad_to_block(g, 0) # dgrad: gx[M,K] = g[M,N] @ W[N,K]; W pre-quantized along N (contraction) g_q = _quantize(g_p, QuantPolicy(stochastic=ctx.recipe.stochastic_rounding)) - grad_x = _addmm_nvfp4_dispatch( - g_q, ctx.w_dgrad, torch.ops.aten.mm.default - )[:m] + grad_x = _addmm_nvfp4_dispatch(g_q, ctx.w_dgrad, torch.ops.aten.mm.default)[ + :m + ] grad_x = grad_x.reshape(ctx.x_shape) return grad_x, None, None, None @@ -803,9 +810,7 @@ def __init__(self, w_fprop, w_dgrad, bias, recipe: NVFP4Recipe): self.out_features = w_fprop.shape[1] def forward(self, x): - out = NVFP4ComputeBaseFunction.apply( - x, self.w_fprop, self.w_dgrad, self.recipe - ) + out = NVFP4ComputeBaseFunction.apply(x, self.w_fprop, self.w_dgrad, self.recipe) return out if self.bias is None else out + self.bias @property @@ -932,13 +937,29 @@ def _recipe_load_lane( else: mask = None other = None - return tl.load(x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other).to(tl.float32) + return tl.load( + x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other + ).to(tl.float32) @triton.jit def _recipe_hadamard16( - x0, x1, x2, x3, x4, x5, x6, x7, - x8, x9, x10, x11, x12, x13, x14, x15, + x0, + x1, + x2, + x3, + x4, + x5, + x6, + x7, + x8, + x9, + x10, + x11, + x12, + x13, + x14, + x15, ): a0 = x0 + x1 a1 = x0 - x1 @@ -1030,8 +1051,22 @@ def _recipe_hadamard16( @triton.jit def _recipe_lane_amax( - y0, y1, y2, y3, y4, y5, y6, y7, - y8, y9, y10, y11, y12, y13, y14, y15, + y0, + y1, + y2, + y3, + y4, + y5, + y6, + y7, + y8, + y9, + y10, + y11, + y12, + y13, + y14, + y15, ): a = tl.maximum(tl.abs(y0), tl.abs(y1)) a = tl.maximum(a, tl.abs(y2)) @@ -1070,26 +1105,60 @@ def _recipe_rht_amax_kernel( if USE_INT64_INDEXING: offs_m = offs_m.to(tl.int64) - x0 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 0, USE_MASK) - x1 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 1, USE_MASK) - x2 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 2, USE_MASK) - x3 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 3, USE_MASK) - x4 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 4, USE_MASK) - x5 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 5, USE_MASK) - x6 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 6, USE_MASK) - x7 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 7, USE_MASK) - x8 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 8, USE_MASK) - x9 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 9, USE_MASK) - x10 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 10, USE_MASK) - x11 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 11, USE_MASK) - x12 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 12, USE_MASK) - x13 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 13, USE_MASK) - x14 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 14, USE_MASK) - x15 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 15, USE_MASK) + x0 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 0, USE_MASK + ) + x1 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 1, USE_MASK + ) + x2 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 2, USE_MASK + ) + x3 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 3, USE_MASK + ) + x4 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 4, USE_MASK + ) + x5 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 5, USE_MASK + ) + x6 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 6, USE_MASK + ) + x7 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 7, USE_MASK + ) + x8 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 8, USE_MASK + ) + x9 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 9, USE_MASK + ) + x10 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 10, USE_MASK + ) + x11 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 11, USE_MASK + ) + x12 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 12, USE_MASK + ) + x13 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 13, USE_MASK + ) + x14 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 14, USE_MASK + ) + x15 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 15, USE_MASK + ) if HADAMARD: - y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15 = _recipe_hadamard16( - x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 + y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15 = ( + _recipe_hadamard16( + x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 + ) ) else: y0, y1, y2, y3, y4, y5, y6, y7 = x0, x1, x2, x3, x4, x5, x6, x7 @@ -1102,7 +1171,9 @@ def _recipe_rht_amax_kernel( @triton.jit -def _recipe_norm_lane(y, scales, global_scale, seed, base_off, lane_off, STOCHASTIC: tl.constexpr): +def _recipe_norm_lane( + y, scales, global_scale, seed, base_off, lane_off, STOCHASTIC: tl.constexpr +): yn = y * (global_scale / scales.to(tl.float32)) if STOCHASTIC: ax = tl.abs(yn) @@ -1151,26 +1222,60 @@ def _recipe_quantize_kernel( if USE_INT64_INDEXING: offs_m = offs_m.to(tl.int64) - x0 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 0, USE_MASK) - x1 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 1, USE_MASK) - x2 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 2, USE_MASK) - x3 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 3, USE_MASK) - x4 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 4, USE_MASK) - x5 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 5, USE_MASK) - x6 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 6, USE_MASK) - x7 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 7, USE_MASK) - x8 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 8, USE_MASK) - x9 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 9, USE_MASK) - x10 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 10, USE_MASK) - x11 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 11, USE_MASK) - x12 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 12, USE_MASK) - x13 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 13, USE_MASK) - x14 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 14, USE_MASK) - x15 = _recipe_load_lane(x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 15, USE_MASK) + x0 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 0, USE_MASK + ) + x1 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 1, USE_MASK + ) + x2 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 2, USE_MASK + ) + x3 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 3, USE_MASK + ) + x4 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 4, USE_MASK + ) + x5 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 5, USE_MASK + ) + x6 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 6, USE_MASK + ) + x7 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 7, USE_MASK + ) + x8 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 8, USE_MASK + ) + x9 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 9, USE_MASK + ) + x10 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 10, USE_MASK + ) + x11 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 11, USE_MASK + ) + x12 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 12, USE_MASK + ) + x13 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 13, USE_MASK + ) + x14 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 14, USE_MASK + ) + x15 = _recipe_load_lane( + x_ptr, offs_m, group, pid_n, stride_xm, stride_xn, M, N, 15, USE_MASK + ) if HADAMARD: - y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15 = _recipe_hadamard16( - x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 + y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15 = ( + _recipe_hadamard16( + x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 + ) ) else: y0, y1, y2, y3, y4, y5, y6, y7 = x0, x1, x2, x3, x4, x5, x6, x7 @@ -1189,28 +1294,62 @@ def _recipe_quantize_kernel( scales = tl.where(scale_mask, scales, 0.0) offs_m_in_layout = (pid_m * M_PER_BLOCK % 128) + tl.arange(0, M_PER_BLOCK)[:, None] - layout_off = ((pid_m * M_PER_BLOCK) // 128) * NUM_N_BLOCKS * NUM_ELEM_PER_LAYOUT + pid_n * NUM_ELEM_PER_LAYOUT + layout_off = ( + (pid_m * M_PER_BLOCK) // 128 + ) * NUM_N_BLOCKS * NUM_ELEM_PER_LAYOUT + pid_n * NUM_ELEM_PER_LAYOUT scale_offs = layout_off + _recipe_scale_swizzle(offs_m_in_layout) tl.store(s_ptr + scale_offs, scales) row_base = offs_m * N col_base = pid_n * 64 + group * 16 - n0 = _recipe_norm_lane(y0, scales, global_scale, seed, row_base + col_base, 0, STOCHASTIC) - n1 = _recipe_norm_lane(y1, scales, global_scale, seed, row_base + col_base, 1, STOCHASTIC) - n2 = _recipe_norm_lane(y2, scales, global_scale, seed, row_base + col_base, 2, STOCHASTIC) - n3 = _recipe_norm_lane(y3, scales, global_scale, seed, row_base + col_base, 3, STOCHASTIC) - n4 = _recipe_norm_lane(y4, scales, global_scale, seed, row_base + col_base, 4, STOCHASTIC) - n5 = _recipe_norm_lane(y5, scales, global_scale, seed, row_base + col_base, 5, STOCHASTIC) - n6 = _recipe_norm_lane(y6, scales, global_scale, seed, row_base + col_base, 6, STOCHASTIC) - n7 = _recipe_norm_lane(y7, scales, global_scale, seed, row_base + col_base, 7, STOCHASTIC) - n8 = _recipe_norm_lane(y8, scales, global_scale, seed, row_base + col_base, 8, STOCHASTIC) - n9 = _recipe_norm_lane(y9, scales, global_scale, seed, row_base + col_base, 9, STOCHASTIC) - n10 = _recipe_norm_lane(y10, scales, global_scale, seed, row_base + col_base, 10, STOCHASTIC) - n11 = _recipe_norm_lane(y11, scales, global_scale, seed, row_base + col_base, 11, STOCHASTIC) - n12 = _recipe_norm_lane(y12, scales, global_scale, seed, row_base + col_base, 12, STOCHASTIC) - n13 = _recipe_norm_lane(y13, scales, global_scale, seed, row_base + col_base, 13, STOCHASTIC) - n14 = _recipe_norm_lane(y14, scales, global_scale, seed, row_base + col_base, 14, STOCHASTIC) - n15 = _recipe_norm_lane(y15, scales, global_scale, seed, row_base + col_base, 15, STOCHASTIC) + n0 = _recipe_norm_lane( + y0, scales, global_scale, seed, row_base + col_base, 0, STOCHASTIC + ) + n1 = _recipe_norm_lane( + y1, scales, global_scale, seed, row_base + col_base, 1, STOCHASTIC + ) + n2 = _recipe_norm_lane( + y2, scales, global_scale, seed, row_base + col_base, 2, STOCHASTIC + ) + n3 = _recipe_norm_lane( + y3, scales, global_scale, seed, row_base + col_base, 3, STOCHASTIC + ) + n4 = _recipe_norm_lane( + y4, scales, global_scale, seed, row_base + col_base, 4, STOCHASTIC + ) + n5 = _recipe_norm_lane( + y5, scales, global_scale, seed, row_base + col_base, 5, STOCHASTIC + ) + n6 = _recipe_norm_lane( + y6, scales, global_scale, seed, row_base + col_base, 6, STOCHASTIC + ) + n7 = _recipe_norm_lane( + y7, scales, global_scale, seed, row_base + col_base, 7, STOCHASTIC + ) + n8 = _recipe_norm_lane( + y8, scales, global_scale, seed, row_base + col_base, 8, STOCHASTIC + ) + n9 = _recipe_norm_lane( + y9, scales, global_scale, seed, row_base + col_base, 9, STOCHASTIC + ) + n10 = _recipe_norm_lane( + y10, scales, global_scale, seed, row_base + col_base, 10, STOCHASTIC + ) + n11 = _recipe_norm_lane( + y11, scales, global_scale, seed, row_base + col_base, 11, STOCHASTIC + ) + n12 = _recipe_norm_lane( + y12, scales, global_scale, seed, row_base + col_base, 12, STOCHASTIC + ) + n13 = _recipe_norm_lane( + y13, scales, global_scale, seed, row_base + col_base, 13, STOCHASTIC + ) + n14 = _recipe_norm_lane( + y14, scales, global_scale, seed, row_base + col_base, 14, STOCHASTIC + ) + n15 = _recipe_norm_lane( + y15, scales, global_scale, seed, row_base + col_base, 15, STOCHASTIC + ) q0 = _recipe_fp32_to_fp4_packed((n0, n1)) q1 = _recipe_fp32_to_fp4_packed((n2, n3)) @@ -1632,9 +1771,9 @@ def backward(ctx, grad_out): gp, QuantPolicy(stochastic=ctx.recipe.stochastic_rounding) ) wdq, wdsc, wd_inv = _mslk_quantize(w_hp.t().contiguous()) # B = W.t() - grad_x = _mslk_scaled_mm( - gq, gsc, g_inv, wdq, wdsc, wd_inv, grad_out.dtype - )[:m] + grad_x = _mslk_scaled_mm(gq, gsc, g_inv, wdq, wdsc, wd_inv, grad_out.dtype)[ + :m + ] grad_x = grad_x.reshape(ctx.x_shape) return grad_x, None, None, None, None, None, None @@ -1753,7 +1892,11 @@ def nvfp4_base_dgrad(g: torch.Tensor, base) -> torch.Tensor: elif isinstance(base, NVFP4FastFrozenBaseLinear): # single FP4 layout: dequantize the stored weight for the dgrad GEMM. w_hp = _mslk_dequant( - base.wq, base.wsc, base.w_inv, (base.out_features, base.in_features), g.dtype + base.wq, + base.wsc, + base.w_inv, + (base.out_features, base.in_features), + g.dtype, ) out = _fp4_mm(gp, w_hp, sr, QuantPolicy()) elif isinstance(base, NVFP4FrozenBaseLinear): @@ -1768,8 +1911,7 @@ def _is_swappable(module: nn.Linear) -> bool: # the _scaled_mm packed-contraction rule (logical %32, not just block %16) — # an out_features of 16 packs to 8 and trips "trailing dim divisible by 16". return ( - module.in_features % _GEMM_ALIGN == 0 - and module.out_features % _GEMM_ALIGN == 0 + module.in_features % _GEMM_ALIGN == 0 and module.out_features % _GEMM_ALIGN == 0 ) @@ -2117,9 +2259,7 @@ def swap_tied_embedding_and_lm_head_to_nvfp4( tied_head = NVFP4TiedLMHead(new_embed, lm_head_bias, recipe) _set_submodule(model, lm_head_name, tied_head) _dynamo_disable_forward(tied_head) - LOG.info( - "NVFP4 training: tied embedding/lm_head quantized once (shared FP4 store)" - ) + LOG.info("NVFP4 training: tied embedding/lm_head quantized once (shared FP4 store)") return True @@ -2220,8 +2360,7 @@ def convert_vision_tower_to_nvfp4( _stream_quantize_swap(vt, name, module, build) swapped += 1 LOG.info( - "nvfp4_training.quantize_vision_tower: swapped %d linears under %s " - "(mode=%s)", + "nvfp4_training.quantize_vision_tower: swapped %d linears under %s (mode=%s)", swapped, vt_name, base_mode, @@ -2323,7 +2462,6 @@ def convert_lora_base_to_te_nvfp4( constraint. Weights are copied in; dims must be divisible by 16. """ import transformer_engine.pytorch as te - from peft.tuners.lora import Linear as LoraLinear recipe = recipe or NVFP4Recipe() @@ -2394,9 +2532,7 @@ def convert_lora_base_to_nvfp4( # plus one transient layer. Weights already on the GPU are quantized in place. # A materialized named_modules() list would pin every base_layer and defeat # the per-layer free, so keep only the lora.Linear references. - target = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) + target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") lora_modules = [ (n, m) for n, m in model.named_modules() if isinstance(m, LoraLinear) ] @@ -2570,7 +2706,7 @@ def load_nvfp4_packed(model: nn.Module, model_dir) -> int: path = os.path.join(str(model_dir), NVFP4_PACKED_SIDECAR) if not os.path.isfile(path): return 0 - packed = torch.load(path, weights_only=False, map_location="cpu") + packed = torch.load(path, weights_only=False, map_location="cpu") # nosec B614 by_module: dict[str, dict] = {} for key, tensor in packed.items(): mod_name, buf_name = key.rsplit(".", 1) @@ -2604,7 +2740,5 @@ def load_nvfp4_packed(model: nn.Module, model_dir) -> int: else: module.register_buffer(bname, tensor) restored += 1 - LOG.info( - "NVFP4 save_nvfp4: restored %d packed tensor(s) from %s", restored, path - ) + LOG.info("NVFP4 save_nvfp4: restored %d packed tensor(s) from %s", restored, path) return restored diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index fd1b1fcd23..34c71dd808 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -59,9 +59,9 @@ ) from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.nvfp4 import NVFP4TrainingConfig -from axolotl.utils.schemas.sage import SageAttentionConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.quantization import PTQConfig, QATConfig +from axolotl.utils.schemas.sage import SageAttentionConfig from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.validation import ValidationMixin @@ -1714,7 +1714,8 @@ def check_nvfp4_training(self): if ( model_config_type is not None and any(qwen3_5_native_flags) - and model_config_type not in ( + and model_config_type + not in ( "qwen3_5", "qwen3_5_moe", ) @@ -1772,9 +1773,11 @@ def check_nvfp4_training(self): # The standard fused linear cross-entropy kernel reads the lm_head weight # directly, bypassing the NVFP4 lm_head forward. The FP4-aware fused CE # path is the explicit opt-in exception. - if self.nvfp4_training.quantize_lm_head and getattr( - self, "cut_cross_entropy", None - ) and not self.nvfp4_training.fused_fp4_cross_entropy: + if ( + self.nvfp4_training.quantize_lm_head + and getattr(self, "cut_cross_entropy", None) + and not self.nvfp4_training.fused_fp4_cross_entropy + ): raise ValueError( "nvfp4_training.quantize_lm_head is incompatible with " "cut_cross_entropy: the fused linear cross-entropy kernel consumes " diff --git a/src/axolotl/utils/schemas/nvfp4.py b/src/axolotl/utils/schemas/nvfp4.py index 2ef1a9537c..8f100bf7e4 100644 --- a/src/axolotl/utils/schemas/nvfp4.py +++ b/src/axolotl/utils/schemas/nvfp4.py @@ -398,10 +398,19 @@ def _migrate_legacy_attention_flags(cls, data): "qwen3_5_native_attention": ("enabled",), "qwen3_5_fuse_vproj": ("fuse_vproj",), "qwen3_5_native_attention_backward": ("backward", "enabled"), - "qwen3_5_native_attention_backward_rtn_grad_packs": ("backward", "rtn_grad_packs"), + "qwen3_5_native_attention_backward_rtn_grad_packs": ( + "backward", + "rtn_grad_packs", + ), "qwen3_5_native_attention_save_backward_packs": ("backward", "save_packs"), - "qwen3_5_native_attention_dkdv_scratch_bf16": ("backward", "dkdv_scratch_bf16"), - "qwen3_5_native_attention_compile_custom_op": ("backward", "compile_custom_op"), + "qwen3_5_native_attention_dkdv_scratch_bf16": ( + "backward", + "dkdv_scratch_bf16", + ), + "qwen3_5_native_attention_compile_custom_op": ( + "backward", + "compile_custom_op", + ), } top_map = { "qwen3_5_native_linear_attn": "linear_attn", diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py index 7e7f66bd28..444f8458a3 100644 --- a/tests/e2e/kernels/test_lora.py +++ b/tests/e2e/kernels/test_lora.py @@ -606,11 +606,34 @@ def run(batched): kA, kB = _lora_pair(ko, in_f, rank) vA, vB = _lora_pair(vo, in_f, rank) q, k, v = LoRA_QKV.apply( - X, None, - weights[0], None, None, qA, qB, 2.0, None, None, - weights[1], None, None, kA, kB, 2.0, None, None, - weights[2], None, None, vA, vB, 2.0, None, None, - inplace, batched, + X, + None, + weights[0], + None, + None, + qA, + qB, + 2.0, + None, + None, + weights[1], + None, + None, + kA, + kB, + 2.0, + None, + None, + weights[2], + None, + None, + vA, + vB, + 2.0, + None, + None, + inplace, + batched, ) loss = (q.float() ** 2).sum() + (k.float() ** 2).sum() + (v.float() ** 2).sum() loss.backward() @@ -625,7 +648,7 @@ def run(batched): assert torch.isfinite(o2).all() and torch.isfinite(x2).all() assert torch.equal(o1, o2) assert torch.equal(x1, x2) - for a, b in zip(g1, g2): + for a, b in zip(g1, g2, strict=False): assert torch.equal(a, b) @@ -646,12 +669,36 @@ def run(batched): uA, uB = _lora_pair(inter, in_f, rank) dA, dB = _lora_pair(in_f, inter, rank) out = LoRA_MLP.apply( - X, None, - gW, None, None, gA, gB, 2.0, None, None, - uW, None, None, uA, uB, 2.0, None, None, - dW, None, None, dA, dB, 2.0, None, None, - swiglu_forward, swiglu_backward, - inplace, batched, + X, + None, + gW, + None, + None, + gA, + gB, + 2.0, + None, + None, + uW, + None, + None, + uA, + uB, + 2.0, + None, + None, + dW, + None, + None, + dA, + dB, + 2.0, + None, + None, + swiglu_forward, + swiglu_backward, + inplace, + batched, ) (out.float() ** 2).sum().backward() return ( @@ -665,5 +712,5 @@ def run(batched): assert torch.isfinite(o2).all() and torch.isfinite(x2).all() assert torch.equal(o1, o2) assert torch.equal(x1, x2) - for a, b in zip(g1, g2): + for a, b in zip(g1, g2, strict=False): assert torch.equal(a, b) diff --git a/tests/e2e/test_nvfp4_integration.py b/tests/e2e/test_nvfp4_integration.py index b32b204fd0..17a6aabfad 100644 --- a/tests/e2e/test_nvfp4_integration.py +++ b/tests/e2e/test_nvfp4_integration.py @@ -1,7 +1,5 @@ """Schema, gate, and (GPU-gated) end-to-end tests for nvfp4_training.""" -import os - import pytest import axolotl.utils.nvfp4_training as nvfp4_mod @@ -22,7 +20,12 @@ # Capability gates live on AxolotlConfigWCapabilities; supply the two capability # blocks so the gate validator runs without touching real hardware. CAPS = { - "capabilities": {"bf16": True, "fp8": True, "n_gpu": 1, "compute_capability": "sm_120"}, + "capabilities": { + "bf16": True, + "fp8": True, + "n_gpu": 1, + "compute_capability": "sm_120", + }, "env_capabilities": {"torch_version": "2.8.0"}, } @@ -51,7 +54,9 @@ def test_schema_accepts_valid_nvfp4_config(monkeypatch): def test_schema_backend_defaults_native_and_accepts_te(monkeypatch): _supported(monkeypatch, True) assert ( - AxolotlInputConfig(**BASE, nvfp4_training={"enabled": True}).nvfp4_training.backend + AxolotlInputConfig( + **BASE, nvfp4_training={"enabled": True} + ).nvfp4_training.backend == "native" ) cfg = AxolotlInputConfig(**BASE, nvfp4_training={"enabled": True, "backend": "te"}) @@ -222,9 +227,7 @@ def test_gate_refuses_qwen3_5_switch_on_other_model(monkeypatch): def test_gate_refuses_unsupported_hardware(monkeypatch): _supported(monkeypatch, False, "no Blackwell here") with pytest.raises(ValueError, match="no Blackwell here"): - AxolotlConfigWCapabilities( - **BASE, **CAPS, nvfp4_training={"enabled": True} - ) + AxolotlConfigWCapabilities(**BASE, **CAPS, nvfp4_training={"enabled": True}) def test_gate_allows_lora(monkeypatch): @@ -292,9 +295,7 @@ def test_gate_refuses_fp16(monkeypatch): def test_disabled_nvfp4_skips_gate(monkeypatch): _supported(monkeypatch, False, "should not be raised") - cfg = AxolotlConfigWCapabilities( - **BASE, **CAPS, nvfp4_training={"enabled": False} - ) + cfg = AxolotlConfigWCapabilities(**BASE, **CAPS, nvfp4_training={"enabled": False}) assert cfg.nvfp4_training.enabled is False @@ -444,7 +445,6 @@ def test_qwen3_5_compile_custom_op_explicit_optout_under_torch_compile(monkeypat def _tiny_lora_model(): """A 2-layer toy model wrapped with a PEFT LoRA adapter (CPU-friendly).""" - import torch from peft import LoraConfig, get_peft_model from torch import nn @@ -486,9 +486,7 @@ def fake_patch(_model, **kwargs): captured.update(kwargs) return 1 - monkeypatch.setattr( - nvfp4_flash_attn, "patch_qwen3_5_nvfp4_attention", fake_patch - ) + monkeypatch.setattr(nvfp4_flash_attn, "patch_qwen3_5_nvfp4_attention", fake_patch) pm = _patch_manager( { "model_config_type": "qwen3_5", @@ -563,26 +561,21 @@ def test_apply_selects_lora_compute_mode(monkeypatch): from axolotl.utils.nvfp4_training import ( NVFP4ComputeBaseLinear, NVFP4FastComputeBaseLinear, - NVFP4FrozenBaseLinear, NVFP4FastFrozenBaseLinear, + NVFP4FrozenBaseLinear, ) model = _tiny_lora_model() - pm = _patch_manager( - {"adapter": "lora", "nvfp4_training": {"enabled": True}} - ) + pm = _patch_manager({"adapter": "lora", "nvfp4_training": {"enabled": True}}) pm._apply_nvfp4_training(model) - bases = [ - m.base_layer for m in model.modules() if isinstance(m, LoraLinear) - ] + bases = [m.base_layer for m in model.modules() if isinstance(m, LoraLinear)] assert bases and all( isinstance(b, (NVFP4ComputeBaseLinear, NVFP4FastComputeBaseLinear)) for b in bases ) assert not any( - isinstance(b, (NVFP4FrozenBaseLinear, NVFP4FastFrozenBaseLinear)) - for b in bases + isinstance(b, (NVFP4FrozenBaseLinear, NVFP4FastFrozenBaseLinear)) for b in bases ) @@ -606,7 +599,6 @@ def test_apply_selects_hp_mode_when_requested(monkeypatch): def test_apply_selects_fft_mode_when_no_adapter(monkeypatch): """No adapter -> raw nn.Linear swapped to NVFP4Linear (full fine-tune).""" _supported(monkeypatch, True) - import torch from torch import nn from axolotl.utils.nvfp4_training import NVFP4Linear @@ -685,9 +677,9 @@ def test_e2e_lora_swap_and_train_step(quantize_base): from transformers import AutoModelForCausalLM from axolotl.utils.nvfp4_training import ( - NVFP4FrozenBaseLinear, - NVFP4FastFrozenBaseLinear, NVFP4FastComputeBaseLinear, + NVFP4FastFrozenBaseLinear, + NVFP4FrozenBaseLinear, NVFP4Linear, NVFP4Recipe, convert_lora_base_to_nvfp4, diff --git a/tests/kernels/test_nvfp4_rope_quant_strided.py b/tests/kernels/test_nvfp4_rope_quant_strided.py index 2e65dc4536..5059cfc110 100644 --- a/tests/kernels/test_nvfp4_rope_quant_strided.py +++ b/tests/kernels/test_nvfp4_rope_quant_strided.py @@ -2,6 +2,7 @@ view and produce BIT-IDENTICAL packs to the contiguous path — the invariant behind dropping the per-layer .contiguous() copy (prefill grab #2). The production caller passes q_norm(...).transpose(1,2), exactly this layout.""" + import pytest import torch @@ -12,7 +13,9 @@ from axolotl.kernels.nvfp4_fused_producers import fused_rope_quant_qk -@pytest.mark.parametrize("Z,H,S,D", [(1, 16, 300, 256), (1, 8, 256, 128), (2, 4, 128, 256)]) +@pytest.mark.parametrize( + "Z,H,S,D", [(1, 16, 300, 256), (1, 8, 256, 128), (2, 4, 128, 256)] +) def test_strided_matches_contiguous(Z, H, S, D): torch.manual_seed(0) rot = D @@ -22,7 +25,7 @@ def test_strided_matches_contiguous(Z, H, S, D): cos = torch.randn(Z, S, rot, device="cuda", dtype=torch.bfloat16) sin = torch.randn(Z, S, rot, device="cuda", dtype=torch.bfloat16) - q_s, sc_s = fused_rope_quant_qk(x_t, cos, sin) # strided (no copy) + q_s, sc_s = fused_rope_quant_qk(x_t, cos, sin) # strided (no copy) q_c, sc_c = fused_rope_quant_qk(x_t.contiguous(), cos, sin) # contiguous reference assert torch.equal(q_s, q_c), "packed FP4 differs between strided and contiguous" @@ -36,10 +39,14 @@ def test_noncontiguous_d_falls_back(): torch.manual_seed(0) Z, H, S, D = 1, 4, 64, 128 # make D non-unit-stride by transposing S<->D, then take a view where D is dim 3 - x = torch.randn(Z, H, D, S, device="cuda", dtype=torch.bfloat16).transpose(2, 3) # [Z,H,S,D], D stride = S + x = torch.randn(Z, H, D, S, device="cuda", dtype=torch.bfloat16).transpose( + 2, 3 + ) # [Z,H,S,D], D stride = S assert x.stride(3) != 1 cos = torch.randn(Z, S, D, device="cuda", dtype=torch.bfloat16) sin = torch.randn(Z, S, D, device="cuda", dtype=torch.bfloat16) q_fb, sc_fb = fused_rope_quant_qk(x, cos, sin) q_ref, sc_ref = fused_rope_quant_qk(x.contiguous(), cos, sin) - assert torch.equal(q_fb, q_ref) and torch.equal(sc_fb.view(torch.uint8), sc_ref.view(torch.uint8)) + assert torch.equal(q_fb, q_ref) and torch.equal( + sc_fb.view(torch.uint8), sc_ref.view(torch.uint8) + ) From 0a5b3a6450ae903e68726a54af1182731eafe088 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 4 Jun 2026 15:40:35 -0700 Subject: [PATCH 17/17] chore: repo-wide pre-commit cleanup (ruff-format, mypy/ruff suppressions) Apply ruff-format across the pre-existing nvfp4 files, file-level mypy disable-error-code for the loosely-typed kernels/scripts, E741 noqa for the flash-kernel O accumulator, and B007/B023/B905 fixes in the e2e tests. Brings pre-commit --all-files fully green (the beta base was not pre-commit-clean). --- scripts/bench_nvfp4_ce_kernel.py | 47 ++-- scripts/nvfp4_cuda_graph_loop.py | 8 +- src/axolotl/kernels/attn_nvfp4.py | 6 +- src/axolotl/kernels/attn_nvfp4_custom_op.py | 57 ++++- src/axolotl/kernels/attn_qat_flash.py | 219 ++++++++++++++---- src/axolotl/kernels/fp8_fused_ce.py | 21 +- src/axolotl/kernels/nvfp4_fused_ce.py | 33 +-- src/axolotl/kernels/nvfp4_linear.py | 59 +++-- .../kernels/nvfp4_quant_fusion_proto.py | 75 ++++-- .../attention/nvfp4_linear_attn.py | 42 ++-- .../monkeypatch/attention/sage_fp4_attn.py | 4 +- .../monkeypatch/models/qwen3_5/modeling.py | 4 +- .../monkeypatch/models/qwen_fused_attn.py | 4 +- src/axolotl/utils/attn_qat.py | 16 +- src/axolotl/utils/nvfp4_cuda_graph_loop.py | 27 ++- tests/e2e/test_nvfp4_training.py | 37 +-- tests/e2e/test_quantization.py | 2 +- 17 files changed, 451 insertions(+), 210 deletions(-) diff --git a/scripts/bench_nvfp4_ce_kernel.py b/scripts/bench_nvfp4_ce_kernel.py index 0b4ff7f542..36c59abe49 100644 --- a/scripts/bench_nvfp4_ce_kernel.py +++ b/scripts/bench_nvfp4_ce_kernel.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="operator" #!/usr/bin/env python """Microbenchmark NVFP4 lm_head + CE variants. @@ -110,7 +111,9 @@ def main() -> None: print(f"liger=unavailable ({type(liger_exc).__name__}: {liger_exc})") def materialized_loss(x): - return F.cross_entropy((x @ dense_weight.t()).float(), labels, ignore_index=-100) + return F.cross_entropy( + (x @ dense_weight.t()).float(), labels, ignore_index=-100 + ) def old_loss(x): return fused_fp4_cross_entropy(x, head, labels, shift=False, fp4_matmul=False) @@ -139,6 +142,7 @@ def new_loss(x): def wrap_loss(loss_fn): if not args.backward: + def run(): with torch.no_grad(): loss_fn(hidden) @@ -155,28 +159,37 @@ def run(): return run timings: list[tuple[str, float | None]] = [] - timings.append(( - "materialized_bf16_ce", - _time_cuda(wrap_loss(materialized_loss), args.warmup, args.iters), - )) + timings.append( + ( + "materialized_bf16_ce", + _time_cuda(wrap_loss(materialized_loss), args.warmup, args.iters), + ) + ) if liger is not None: + def liger_loss(x): return liger(dense_weight, x, labels) - timings.append(( - "liger_fused_linear_ce", - _time_cuda(wrap_loss(liger_loss), args.warmup, args.iters), - )) + timings.append( + ( + "liger_fused_linear_ce", + _time_cuda(wrap_loss(liger_loss), args.warmup, args.iters), + ) + ) old_probe = old_loss(hidden) if old_probe is not None: - timings.append(( - "existing_nvfp4_fused_ce", - _time_cuda(wrap_loss(old_loss), args.warmup, args.iters), - )) - timings.append(( - "fp4_scaled_mm_ce", - _time_cuda(wrap_loss(new_loss), args.warmup, args.iters), - )) + timings.append( + ( + "existing_nvfp4_fused_ce", + _time_cuda(wrap_loss(old_loss), args.warmup, args.iters), + ) + ) + timings.append( + ( + "fp4_scaled_mm_ce", + _time_cuda(wrap_loss(new_loss), args.warmup, args.iters), + ) + ) for name, ms in timings: tok_s = args.tokens / (ms / 1000.0) diff --git a/scripts/nvfp4_cuda_graph_loop.py b/scripts/nvfp4_cuda_graph_loop.py index 9ed8918143..d5ad993b8e 100755 --- a/scripts/nvfp4_cuda_graph_loop.py +++ b/scripts/nvfp4_cuda_graph_loop.py @@ -27,11 +27,15 @@ def parse_args(): ) compile_group = parser.add_mutually_exclusive_group() compile_group.add_argument("--compile", dest="compile_model", action="store_true") - compile_group.add_argument("--no-compile", dest="compile_model", action="store_false") + compile_group.add_argument( + "--no-compile", dest="compile_model", action="store_false" + ) parser.set_defaults(compile_model=None) parser.add_argument("--fullgraph", action="store_true") parser.add_argument("--probe-only", action="store_true") - parser.add_argument("--no-probe-on-fail", dest="probe_on_fail", action="store_false") + parser.add_argument( + "--no-probe-on-fail", dest="probe_on_fail", action="store_false" + ) parser.set_defaults(probe_on_fail=True) return parser.parse_args() diff --git a/src/axolotl/kernels/attn_nvfp4.py b/src/axolotl/kernels/attn_nvfp4.py index 9e07d5754c..aa94de6393 100644 --- a/src/axolotl/kernels/attn_nvfp4.py +++ b/src/axolotl/kernels/attn_nvfp4.py @@ -54,11 +54,7 @@ def _repeat_kv(t: torch.Tensor, n_rep: int) -> torch.Tensor: if n_rep == 1: return t z, hk, s, d = t.shape - return ( - t[:, :, None, :, :] - .expand(z, hk, n_rep, s, d) - .reshape(z, hk * n_rep, s, d) - ) + return t[:, :, None, :, :].expand(z, hk, n_rep, s, d).reshape(z, hk * n_rep, s, d) def nvfp4_attention( diff --git a/src/axolotl/kernels/attn_nvfp4_custom_op.py b/src/axolotl/kernels/attn_nvfp4_custom_op.py index 3b08a8cc5c..681b7841aa 100644 --- a/src/axolotl/kernels/attn_nvfp4_custom_op.py +++ b/src/axolotl/kernels/attn_nvfp4_custom_op.py @@ -30,7 +30,9 @@ def _code_to_dtype(code: int) -> torch.dtype: try: return _CODE_TO_DTYPE[int(code)] except KeyError as exc: - raise TypeError(f"unsupported NVFP4 attention output dtype code: {code}") from exc + raise TypeError( + f"unsupported NVFP4 attention output dtype code: {code}" + ) from exc @torch.library.custom_op("axolotl_nvfp4::flash_attention_packed", mutates_args=()) @@ -315,9 +317,19 @@ def _flash_attention_train_bwd_op( do, out.reshape(z * h, s_q, d).contiguous(), bias, - z, h, hk, s_q, s_kv, d, - scaling, causal, sr, - block_m, block_n, num_warps, num_stages, + z, + h, + hk, + s_q, + s_kv, + d, + scaling, + causal, + sr, + block_m, + block_n, + num_warps, + num_stages, lse=None, sr_p_dv=backward_p_dv_sr, sr_dot_dv=backward_dot_dv_sr, @@ -389,9 +401,22 @@ def _flash_attention_train_backward(ctx, grad_out): ) # one grad slot per forward input (16 inputs) return ( - dq, dk, dv, - None, None, None, None, None, None, None, None, - None, None, None, None, None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, ) @@ -434,9 +459,21 @@ def nvfp4_flash_attn_train_custom_op( else key_pad_bias ) sr = stochastic_rounding - p_dv = sr if backward_p_dv_stochastic_rounding is None else backward_p_dv_stochastic_rounding - dot_dv = sr if backward_dot_dv_stochastic_rounding is None else backward_dot_dv_stochastic_rounding - ds_dq = sr if backward_ds_dq_stochastic_rounding is None else backward_ds_dq_stochastic_rounding + p_dv = ( + sr + if backward_p_dv_stochastic_rounding is None + else backward_p_dv_stochastic_rounding + ) + dot_dv = ( + sr + if backward_dot_dv_stochastic_rounding is None + else backward_dot_dv_stochastic_rounding + ) + ds_dq = ( + sr + if backward_ds_dq_stochastic_rounding is None + else backward_ds_dq_stochastic_rounding + ) return torch.ops.axolotl_nvfp4.flash_attention_train( query, key, diff --git a/src/axolotl/kernels/attn_qat_flash.py b/src/axolotl/kernels/attn_qat_flash.py index 740313e1f7..5e3aca4293 100644 --- a/src/axolotl/kernels/attn_qat_flash.py +++ b/src/axolotl/kernels/attn_qat_flash.py @@ -1,3 +1,4 @@ +# ruff: noqa: E741 """Fused FlashAttention-style NVFP4 fake-quant attention (Attn-QAT, arXiv 2603.00040). Long-context follow-up to the eager v1 in ``axolotl.utils.attn_qat``: instead of @@ -56,7 +57,6 @@ if HAS_TRITON: - _BLK = tl.constexpr(16) _F4MAX = tl.constexpr(6.0) _EPS = tl.constexpr(0.015625) @@ -73,13 +73,27 @@ def _round_e2m1(x): s = tl.where(x >= 0, 1.0, -1.0) a = tl.abs(x) # nearest grid point by midpoint thresholds - r = tl.where(a < 0.25, 0.0, - tl.where(a < 0.75, 0.5, - tl.where(a < 1.25, 1.0, - tl.where(a < 1.75, 1.5, - tl.where(a < 2.5, 2.0, - tl.where(a < 3.5, 3.0, - tl.where(a < 5.0, 4.0, 6.0))))))) + r = tl.where( + a < 0.25, + 0.0, + tl.where( + a < 0.75, + 0.5, + tl.where( + a < 1.25, + 1.0, + tl.where( + a < 1.75, + 1.5, + tl.where( + a < 2.5, + 2.0, + tl.where(a < 3.5, 3.0, tl.where(a < 5.0, 4.0, 6.0)), + ), + ), + ), + ), + ) # ties-to-even at the midpoints r = tl.where(a == 0.25, 0.0, r) r = tl.where(a == 0.75, 1.0, r) @@ -134,14 +148,35 @@ def _fake_quant_p(p, n_valid, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): @triton.jit def _attn_qat_fwd( - Q, K, V, sm_scale, B, - O, Op, M, # noqa: E741 - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_ok, + Q, + K, + V, + sm_scale, + B, + O, + Op, + M, # noqa: E741 + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_ok, stride_bz, - Z, H, N_CTX, N_KV, + Z, + H, + N_CTX, + N_KV, HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -246,16 +281,42 @@ def _attn_qat_fwd( @triton.jit def _attn_qat_bwd( - Q, K, V, sm_scale, B, - DO, Op, M, - DQ, DK, DV, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_ok, - stride_dkz, stride_dkh, stride_dkn, stride_dkk, + Q, + K, + V, + sm_scale, + B, + DO, + Op, + M, + DQ, + DK, + DV, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_ok, + stride_dkz, + stride_dkh, + stride_dkn, + stride_dkk, stride_bz, - Z, H, N_CTX, N_KV, + Z, + H, + N_CTX, + N_KV, HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -308,8 +369,12 @@ def _attn_qat_bwd( m_mask = offs_m < N_CTX q_ptrs = q_base + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - do_ptrs = do_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok - op_ptrs = op_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok + do_ptrs = ( + do_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok + ) + op_ptrs = ( + op_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok + ) q = tl.load(q_ptrs, mask=m_mask[:, None], other=0.0) do = tl.load(do_ptrs, mask=m_mask[:, None], other=0.0).to(tl.float32) op = tl.load(op_ptrs, mask=m_mask[:, None], other=0.0).to(tl.float32) @@ -337,7 +402,9 @@ def _attn_qat_bwd( ds = tl.where(n_mask[None, :], ds, 0.0) dk += tl.dot(tl.trans(ds).to(qf.dtype), qf) dq = tl.dot(ds.to(kf.dtype), kf) - dq_ptrs = dq_base + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + dq_ptrs = ( + dq_base + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + ) tl.atomic_add(dq_ptrs, dq, mask=m_mask[:, None]) dk_ptrs = dk_base + offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkk @@ -375,16 +442,43 @@ def forward(ctx, q, k, v, sm_scale, causal, gqa_group, bias): m = torch.empty((Z * H, N_CTX), device=q.device, dtype=torch.float32) grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H) _attn_qat_fwd[grid]( - q, k, v, sm_scale, b, o, op, m, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q, + k, + v, + sm_scale, + b, + o, + op, + m, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), stride_bz, - Z, H, N_CTX, N_KV, - HEAD_DIM=D, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - CAUSAL=causal, HAS_BIAS=has_bias, GQA_GROUP=gqa_group, - num_warps=4, num_stages=1, + Z, + H, + N_CTX, + N_KV, + HEAD_DIM=D, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + CAUSAL=causal, + HAS_BIAS=has_bias, + GQA_GROUP=gqa_group, + num_warps=4, + num_stages=1, ) ctx.save_for_backward(q, k, v, op, m, bias) ctx.sm_scale = sm_scale @@ -410,17 +504,50 @@ def backward(ctx, do): dv = torch.empty_like(q) grid = (triton.cdiv(N_KV, BLOCK_N), Z * H) _attn_qat_bwd[grid]( - q, k, v, ctx.sm_scale, b, do, op, m, dq, dk, dv, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - do.stride(0), do.stride(1), do.stride(2), do.stride(3), - dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), + q, + k, + v, + ctx.sm_scale, + b, + do, + op, + m, + dq, + dk, + dv, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), stride_bz, - Z, H, N_CTX, N_KV, - HEAD_DIM=D, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - CAUSAL=ctx.causal, HAS_BIAS=has_bias, GQA_GROUP=ctx.gqa_group, - num_warps=4, num_stages=1, + Z, + H, + N_CTX, + N_KV, + HEAD_DIM=D, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + CAUSAL=ctx.causal, + HAS_BIAS=has_bias, + GQA_GROUP=ctx.gqa_group, + num_warps=4, + num_stages=1, ) dq = dq.to(q.dtype) if ctx.gqa_group > 1: diff --git a/src/axolotl/kernels/fp8_fused_ce.py b/src/axolotl/kernels/fp8_fused_ce.py index e00019ad32..76dba38773 100644 --- a/src/axolotl/kernels/fp8_fused_ce.py +++ b/src/axolotl/kernels/fp8_fused_ce.py @@ -50,9 +50,7 @@ def prepack_lm_head_weight_fp8_ce( dgrad_scale = _scale_from_amax(weight.abs().max()) else: dgrad_scale = _scale_from_amax(weight.abs().amax(dim=0, keepdim=True)) - dgrad_weight = _to_col_major_for_scaled_mm( - _quantize_e4m3(weight, dgrad_scale) - ) + dgrad_weight = _to_col_major_for_scaled_mm(_quantize_e4m3(weight, dgrad_scale)) return FP8FusedCEWeight( fprop=fprop, dgrad_weight=dgrad_weight, @@ -148,14 +146,9 @@ def backward(ctx, grad_loss): H = packed.fprop.in_features M = hidden_fp8.shape[0] rows = torch.arange(M, device=hidden_fp8.device) - grad_hidden = torch.zeros( - M, H, device=hidden_fp8.device, dtype=torch.float32 - ) + grad_hidden = torch.zeros(M, H, device=hidden_fp8.device, dtype=torch.float32) coef = ( - grad_loss.float() - * ctx.grad_scale - * valid.float() - * ctx.logit_scale + grad_loss.float() * ctx.grad_scale * valid.float() * ctx.logit_scale ).unsqueeze(1) for lo in range(0, V, _VOCAB_BLOCK): @@ -266,9 +259,7 @@ def forward(self, *args, **kwargs): labels, num_items_in_batch=num_items_in_batch, shift=True, - granularity=getattr( - self, "_axolotl_fp8_lm_head_ce_granularity", "rowwise" - ), + granularity=getattr(self, "_axolotl_fp8_lm_head_ce_granularity", "rowwise"), ) if loss is None: kwargs["labels"] = labels @@ -312,9 +303,7 @@ def patch_model_fp8_lm_head_cross_entropy( ) return False if lm_head.bias is not None or lm_head.weight.requires_grad: - LOG.warning( - "fp8_lm_head_cross_entropy: requires a frozen bias-free lm_head" - ) + LOG.warning("fp8_lm_head_cross_entropy: requires a frozen bias-free lm_head") return False if lm_head.weight.shape[0] % 16 or lm_head.weight.shape[1] % 16: LOG.warning("fp8_lm_head_cross_entropy: lm_head dims are not FP8-eligible") diff --git a/src/axolotl/kernels/nvfp4_fused_ce.py b/src/axolotl/kernels/nvfp4_fused_ce.py index 2a742d252a..46ed5b43a5 100644 --- a/src/axolotl/kernels/nvfp4_fused_ce.py +++ b/src/axolotl/kernels/nvfp4_fused_ce.py @@ -83,6 +83,8 @@ def _nvfp4_lm_head_store(module: nn.Module): def _nvfp4_lm_head_fp4_store(module: nn.Module): """Return a ``[V, H]`` FP4 store usable by tiled ``torch._scaled_mm``.""" + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + from axolotl.utils.nvfp4_training import ( NVFP4ComputeBaseLinear, NVFP4FastComputeBaseLinear, @@ -90,7 +92,6 @@ def _nvfp4_lm_head_fp4_store(module: nn.Module): NVFP4FrozenBaseLinear, NVFP4TiedLMHead, ) - from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor if isinstance(module, (NVFP4FrozenBaseLinear, NVFP4TiedLMHead)): return module.w_q @@ -175,7 +176,9 @@ def _fp4_logits_tile( return logits -def _quantize_hidden_sl(hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | None: +def _quantize_hidden_sl( + hidden: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor] | None: try: from axolotl.utils.nvfp4_training import _mslk_quantize_sl @@ -208,7 +211,9 @@ def forward(ctx, hidden, store, labels, ignore_index, scale, grad_scale): valid = labels != ignore_index safe_labels = torch.where(valid, labels, labels.new_zeros(())) - running_max = torch.full((M,), float("-inf"), device=device, dtype=torch.float32) + running_max = torch.full( + (M,), float("-inf"), device=device, dtype=torch.float32 + ) running_sum = torch.zeros(M, device=device, dtype=torch.float32) label_logit = torch.zeros(M, device=device, dtype=torch.float32) @@ -252,7 +257,9 @@ def backward(ctx, grad_loss): dtype = hidden.dtype # d(loss)/d(logit_v) = grad_loss * grad_scale * mask * (softmax_v - onehot_v) * scale - coef = (grad_loss * ctx.grad_scale * valid.float() * scale).unsqueeze(1) # [M,1] + coef = (grad_loss * ctx.grad_scale * valid.float() * scale).unsqueeze( + 1 + ) # [M,1] rows = torch.arange(M, device=hidden.device) grad_hidden = torch.zeros(M, H, device=hidden.device, dtype=dtype) @@ -284,13 +291,17 @@ def forward(ctx, hidden, store, labels, ignore_index, scale, grad_scale): hidden_q = _quantize_hidden_sl(hidden) if hidden_q is None: - raise RuntimeError("MSLK single-level NVFP4 activation quant is unavailable") + raise RuntimeError( + "MSLK single-level NVFP4 activation quant is unavailable" + ) hidden_qdata, hidden_scale = hidden_q valid = labels != ignore_index safe_labels = torch.where(valid, labels, labels.new_zeros(())) - running_max = torch.full((M,), float("-inf"), device=device, dtype=torch.float32) + running_max = torch.full( + (M,), float("-inf"), device=device, dtype=torch.float32 + ) running_sum = torch.zeros(M, device=device, dtype=torch.float32) label_logit = torch.zeros(M, device=device, dtype=torch.float32) @@ -386,9 +397,7 @@ def fused_fp4_cross_entropy( valid = labels1d != ignore_index if num_items_in_batch is not None: denom = num_items_in_batch - grad_scale = ( - 1.0 / denom if torch.is_tensor(denom) else 1.0 / float(denom) - ) + grad_scale = 1.0 / denom if torch.is_tensor(denom) else 1.0 / float(denom) else: grad_scale = 1.0 / valid.sum().clamp(min=1).float() @@ -447,11 +456,7 @@ def forward(self, *args, **kwargs): ) # Only intercept the training path with an FP4, tile-able head. Anything # else (generation, non-FP4 head, logits_to_keep slicing) -> original. - if ( - labels is None - or kwargs.get("logits_to_keep") - or not has_fp4_ce - ): + if labels is None or kwargs.get("logits_to_keep") or not has_fp4_ce: return orig_forward(self, *args, **kwargs) # Run the base model to get hidden states (mirror the HF forward prologue). diff --git a/src/axolotl/kernels/nvfp4_linear.py b/src/axolotl/kernels/nvfp4_linear.py index 18dac9952f..4a59222ec8 100644 --- a/src/axolotl/kernels/nvfp4_linear.py +++ b/src/axolotl/kernels/nvfp4_linear.py @@ -29,13 +29,22 @@ @triton.jit def _nvfp4_gemm_kernel( - anv_ptr, asc_ptr, # [M, K//2] uint8, [M, K//16] e4m3 (activation) - wnv_ptr, wsc_ptr, # [N, K//2] uint8, [N, K//16] e4m3 (weight, W[N,K]) - out_ptr, # [M, N] out_dtype - M, N, K, - s_am, s_wn, s_om, # row strides (col stride = 1) - s_asc, s_wsc, # scale row strides - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + anv_ptr, + asc_ptr, # [M, K//2] uint8, [M, K//16] e4m3 (activation) + wnv_ptr, + wsc_ptr, # [N, K//2] uint8, [N, K//16] e4m3 (weight, W[N,K]) + out_ptr, # [M, N] out_dtype + M, + N, + K, + s_am, + s_wn, + s_om, # row strides (col stride = 1) + s_asc, + s_wsc, # scale row strides + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -52,19 +61,23 @@ def _nvfp4_gemm_kernel( offk16 = k0 // 16 + tl.arange(0, KP16) a = tl.load( anv_ptr + offs_m[:, None] * s_am + offk2[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ) asc = tl.load( asc_ptr + offs_m[:, None] * s_asc + offk16[None, :], - mask=mmask[:, None], other=0, + mask=mmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) w = tl.load( wnv_ptr + offs_n[:, None] * s_wn + offk2[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ) wsc = tl.load( wsc_ptr + offs_n[:, None] * s_wsc + offk16[None, :], - mask=nmask[:, None], other=0, + mask=nmask[:, None], + other=0, ).to(tl.float8e4nv, bitcast=True) acc = tl.dot_scaled(a, asc, "e2m1", w.T, wsc, "e2m1", acc=acc) @@ -118,13 +131,23 @@ def nvfp4_linear( out = torch.empty(m, n_out, device=x.device, dtype=out_dtype) grid = (triton.cdiv(m, block_m), triton.cdiv(n_out, block_n)) _nvfp4_gemm_kernel[grid]( - anv.view(torch.uint8), asc.view(torch.uint8), - wnv.view(torch.uint8), wsc.view(torch.uint8), + anv.view(torch.uint8), + asc.view(torch.uint8), + wnv.view(torch.uint8), + wsc.view(torch.uint8), out, - m, n_out, k, - anv.stride(0), wnv.stride(0), out.stride(0), - asc.stride(0), wsc.stride(0), - BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, - num_warps=num_warps, num_stages=num_stages, + m, + n_out, + k, + anv.stride(0), + wnv.stride(0), + out.stride(0), + asc.stride(0), + wsc.stride(0), + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=num_warps, + num_stages=num_stages, ) return out.reshape(*lead, n_out) diff --git a/src/axolotl/kernels/nvfp4_quant_fusion_proto.py b/src/axolotl/kernels/nvfp4_quant_fusion_proto.py index 650c739bc7..0bb1cc6dd8 100644 --- a/src/axolotl/kernels/nvfp4_quant_fusion_proto.py +++ b/src/axolotl/kernels/nvfp4_quant_fusion_proto.py @@ -28,7 +28,6 @@ import torch import triton import triton.language as tl - from mslk.quantize.triton.fp4_quantize import ( convert_fp32_to_fp4_packed, nvfp4_scale_swizzle, @@ -140,8 +139,9 @@ def _rope_quant_kernel( x_rot = x * cos + rot * sin x_blocks = x_rot.reshape(M_PER_BLOCK, 4, 16) - _quant_emit(x_blocks, q_ptr, s_ptr, gscale, pid_m, pid_n, M, N, - M_PER_BLOCK, NUM_N_BLOCKS) + _quant_emit( + x_blocks, q_ptr, s_ptr, gscale, pid_m, pid_n, M, N, M_PER_BLOCK, NUM_N_BLOCKS + ) def fused_rope_quant(x2d, cos2d, sin2d, two_level: bool = False): @@ -171,11 +171,23 @@ def fused_rope_quant(x2d, cos2d, sin2d, two_level: bool = False): M_PER_BLOCK = min(triton.next_power_of_2(M), 128) grid = (triton.cdiv(N, 64), triton.cdiv(M, M_PER_BLOCK)) _rope_quant_kernel[grid]( - x2d, cos2d, sin2d, q, s, float(gscale), M, N, N // 2, - M_PER_BLOCK=M_PER_BLOCK, NUM_N_BLOCKS=triton.cdiv(N, 64), + x2d, + cos2d, + sin2d, + q, + s, + float(gscale), + M, + N, + N // 2, + M_PER_BLOCK=M_PER_BLOCK, + NUM_N_BLOCKS=triton.cdiv(N, 64), + ) + return ( + q.view(torch.float4_e2m1fn_x2), + s.view(torch.float8_e4m3fn), + (1.0 / gscale).to(x2d.dtype), ) - return (q.view(torch.float4_e2m1fn_x2), s.view(torch.float8_e4m3fn), - (1.0 / gscale).to(x2d.dtype)) # ---------------------------------------------------------------------------- @@ -232,10 +244,9 @@ def _softmax_quant_kernel( sb = tl.arange(0, N_BLOCKS) # group-block index along N pid_n = sb // 4 in4 = sb % 4 - layout_off = ( - (padded_r // 128) * (tl.cdiv(N, 64)) * NUM_ELEM_PER_LAYOUT - + pid_n * NUM_ELEM_PER_LAYOUT - ) + layout_off = (padded_r // 128) * ( + tl.cdiv(N, 64) + ) * NUM_ELEM_PER_LAYOUT + pid_n * NUM_ELEM_PER_LAYOUT offs_m_in_layout = (padded_r % 128).to(tl.int32) # swizzle for a single row m: sub_layout_off + (m//32)*4 + in4 m = offs_m_in_layout @@ -257,8 +268,14 @@ def fused_softmax_quant(scores: torch.Tensor, out_dtype=torch.bfloat16): BLOCK_N = triton.next_power_of_2(N) grid = (M,) _softmax_quant_kernel[grid]( - scores, q, s, gscale, M, N, - BLOCK_N=BLOCK_N, N_BLOCKS=N // 16, + scores, + q, + s, + gscale, + M, + N, + BLOCK_N=BLOCK_N, + N_BLOCKS=N // 16, ) inv_gs = torch.tensor(1.0 / gscale, device=scores.device, dtype=out_dtype) return q.view(torch.float4_e2m1fn_x2), s.view(torch.float8_e4m3fn), inv_gs @@ -269,8 +286,14 @@ def fused_softmax_quant(scores: torch.Tensor, out_dtype=torch.bfloat16): # ---------------------------------------------------------------------------- @triton.jit def _identity_quant_kernel( - x_ptr, q_ptr, s_ptr, gscale, M, N, - M_PER_BLOCK: tl.constexpr, NUM_N_BLOCKS: tl.constexpr, + x_ptr, + q_ptr, + s_ptr, + gscale, + M, + N, + M_PER_BLOCK: tl.constexpr, + NUM_N_BLOCKS: tl.constexpr, ): pid_m = tl.program_id(1) pid_n = tl.program_id(0) @@ -279,8 +302,9 @@ def _identity_quant_kernel( mask = (offs_m < M) & (offs_n < N) x = tl.load(x_ptr + offs_m * N + offs_n, mask=mask, other=0.0).to(tl.float32) x_blocks = x.reshape(M_PER_BLOCK, 4, 16) - _quant_emit(x_blocks, q_ptr, s_ptr, gscale, pid_m, pid_n, M, N, - M_PER_BLOCK, NUM_N_BLOCKS) + _quant_emit( + x_blocks, q_ptr, s_ptr, gscale, pid_m, pid_n, M, N, M_PER_BLOCK, NUM_N_BLOCKS + ) def fused_vproj_quant(v2d: torch.Tensor, two_level: bool = False): @@ -303,8 +327,17 @@ def fused_vproj_quant(v2d: torch.Tensor, two_level: bool = False): M_PER_BLOCK = min(triton.next_power_of_2(M), 128) grid = (triton.cdiv(N, 64), triton.cdiv(M, M_PER_BLOCK)) _identity_quant_kernel[grid]( - v2d, q, s, float(gscale), M, N, - M_PER_BLOCK=M_PER_BLOCK, NUM_N_BLOCKS=triton.cdiv(N, 64), + v2d, + q, + s, + float(gscale), + M, + N, + M_PER_BLOCK=M_PER_BLOCK, + NUM_N_BLOCKS=triton.cdiv(N, 64), + ) + return ( + q.view(torch.float4_e2m1fn_x2), + s.view(torch.float8_e4m3fn), + (1.0 / gscale).to(v2d.dtype), ) - return (q.view(torch.float4_e2m1fn_x2), s.view(torch.float8_e4m3fn), - (1.0 / gscale).to(v2d.dtype)) diff --git a/src/axolotl/monkeypatch/attention/nvfp4_linear_attn.py b/src/axolotl/monkeypatch/attention/nvfp4_linear_attn.py index ce5e819188..d89276fbbf 100644 --- a/src/axolotl/monkeypatch/attention/nvfp4_linear_attn.py +++ b/src/axolotl/monkeypatch/attention/nvfp4_linear_attn.py @@ -35,6 +35,7 @@ import torch import torch.nn.functional as F +import triton from torch import nn from axolotl.kernels.attn_nvfp4_flash import _quant_nvfp4 @@ -45,8 +46,6 @@ ) from axolotl.utils.logging import get_logger -import triton - LOG = get_logger(__name__) @@ -113,14 +112,24 @@ def _gemm_from_packed_act(anv, asc, wnv, wsc, m, n_out, k, out_dtype): block_m, block_n, block_k = 128, 128, 256 grid = (triton.cdiv(m, block_m), triton.cdiv(n_out, block_n)) _nvfp4_gemm_kernel[grid]( - anv.view(torch.uint8), asc.view(torch.uint8), - wnv.view(torch.uint8), wsc.view(torch.uint8), + anv.view(torch.uint8), + asc.view(torch.uint8), + wnv.view(torch.uint8), + wsc.view(torch.uint8), out, - m, n_out, k, - anv.stride(0), wnv.stride(0), out.stride(0), - asc.stride(0), wsc.stride(0), - BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, - num_warps=8, num_stages=3, + m, + n_out, + k, + anv.stride(0), + wnv.stride(0), + out.stride(0), + asc.stride(0), + wsc.stride(0), + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=8, + num_stages=3, ) return out @@ -150,7 +159,9 @@ def forward( torch.is_grad_enabled() or use_cache or kwargs - or not _position_ids_are_dense_unpacked(position_ids, hidden_states.shape[1]) + or not _position_ids_are_dense_unpacked( + position_ids, hidden_states.shape[1] + ) ): return _call_orig_forward( orig_forward, @@ -226,8 +237,11 @@ def forward( key = key.repeat_interleave(rep, dim=2) core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( - query, key, value, - g=g, beta=beta, + query, + key, + value, + g=g, + beta=beta, initial_state=None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, @@ -242,9 +256,7 @@ def forward( # out_proj: separate activation (post delta-rule), its own NVFP4 path. out_wnv, out_wsc = _get_packed_weight(self, "_out_packed", self.out_proj) - output = nvfp4_linear( - core_attn_out, out_wnv, out_wsc, self.hidden_size - ) + output = nvfp4_linear(core_attn_out, out_wnv, out_wsc, self.hidden_size) return output return forward diff --git a/src/axolotl/monkeypatch/attention/sage_fp4_attn.py b/src/axolotl/monkeypatch/attention/sage_fp4_attn.py index 27a6425a1c..4c3105eeeb 100644 --- a/src/axolotl/monkeypatch/attention/sage_fp4_attn.py +++ b/src/axolotl/monkeypatch/attention/sage_fp4_attn.py @@ -64,9 +64,7 @@ def _flash_available() -> bool: return importlib.util.find_spec("flash_attn") is not None -def _fallback_kind( - attention_mask: torch.Tensor | None, q_len: int, kv_len: int -) -> str: +def _fallback_kind(attention_mask: torch.Tensor | None, q_len: int, kv_len: int) -> str: """Classify the fallback for inputs the FP4 kernel can't serve. ``"causal"`` / ``"full"`` are flash-attention-eligible (dense, no per-key diff --git a/src/axolotl/monkeypatch/models/qwen3_5/modeling.py b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py index 11da3e7a4d..9dbc6eed93 100644 --- a/src/axolotl/monkeypatch/models/qwen3_5/modeling.py +++ b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py @@ -295,9 +295,7 @@ def _apply_packing_patches( ) -def patch_qwen3_5_modeling_packing( - *, fla_causal_conv_compile_boundary: bool = False -): +def patch_qwen3_5_modeling_packing(*, fla_causal_conv_compile_boundary: bool = False): _apply_packing_patches( "qwen3_5", "Qwen3_5", diff --git a/src/axolotl/monkeypatch/models/qwen_fused_attn.py b/src/axolotl/monkeypatch/models/qwen_fused_attn.py index 67f5c18a46..ce0a9e3665 100644 --- a/src/axolotl/monkeypatch/models/qwen_fused_attn.py +++ b/src/axolotl/monkeypatch/models/qwen_fused_attn.py @@ -11,7 +11,9 @@ LOG = get_logger(__name__) -def _attention_interface(functions, implementation: str, fallback: Callable) -> Callable: +def _attention_interface( + functions, implementation: str, fallback: Callable +) -> Callable: if hasattr(functions, "get_interface"): return functions.get_interface(implementation, fallback) if implementation == "eager": diff --git a/src/axolotl/utils/attn_qat.py b/src/axolotl/utils/attn_qat.py index e2e5e45ade..87ef178e3f 100644 --- a/src/axolotl/utils/attn_qat.py +++ b/src/axolotl/utils/attn_qat.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="attr-defined" """FP4-attention quantization-aware training (Attn-QAT, arXiv 2603.00040). QAT for FP4 attention: during SFT the attention operands (Q, K, V and the @@ -184,8 +185,13 @@ def nvfp4_qat_attention_forward( ) nvfp4_qat_attention_forward._fused_logged = True out = _fused_nvfp4_qat_attention( - query, key, value, scaling, causal, - module.num_key_value_groups, key_pad_bias, + query, + key, + value, + scaling, + causal, + module.num_key_value_groups, + key_pad_bias, ) return out.transpose(1, 2).contiguous(), None @@ -204,9 +210,9 @@ def nvfp4_qat_attention_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask[..., : k_fq.shape[-2]] - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) attn_weights = nn.functional.dropout( attn_weights, p=dropout, training=module.training ) diff --git a/src/axolotl/utils/nvfp4_cuda_graph_loop.py b/src/axolotl/utils/nvfp4_cuda_graph_loop.py index cf38f1924f..76b5d5fc6e 100644 --- a/src/axolotl/utils/nvfp4_cuda_graph_loop.py +++ b/src/axolotl/utils/nvfp4_cuda_graph_loop.py @@ -96,7 +96,9 @@ def validate_loop_cfg(cfg): def maybe_compile_model(model: torch.nn.Module, cfg, options: GraphLoopOptions): - compile_model = cfg.torch_compile if options.compile_model is None else options.compile_model + compile_model = ( + cfg.torch_compile if options.compile_model is None else options.compile_model + ) if not compile_model: return model, "compile=off" @@ -188,7 +190,9 @@ def build_dataloader(cfg, train_dataset, tokenizer) -> DataLoader: ) -def build_optimizer(model: torch.nn.Module, cfg, capturable: bool) -> torch.optim.Optimizer: +def build_optimizer( + model: torch.nn.Module, cfg, capturable: bool +) -> torch.optim.Optimizer: params = [p for p in model.parameters() if p.requires_grad] if not params: raise ValueError("model has no trainable parameters") @@ -213,10 +217,14 @@ def _model_device(model: torch.nn.Module) -> torch.device: def _tensor_batch(batch: dict[str, Any]) -> dict[str, torch.Tensor]: - return {key: value for key, value in batch.items() if isinstance(value, torch.Tensor)} + return { + key: value for key, value in batch.items() if isinstance(value, torch.Tensor) + } -def _trim_batch_for_model(batch: dict[str, torch.Tensor], cfg) -> dict[str, torch.Tensor]: +def _trim_batch_for_model( + batch: dict[str, torch.Tensor], cfg +) -> dict[str, torch.Tensor]: out = dict(batch) out.pop("length", None) if ( @@ -444,7 +452,9 @@ def run_loop(cfg, options: GraphLoopOptions) -> GraphLoopResult: "No LR scheduler is applied in this prototype.", ] if cfg.sample_packing: - notes.append("DataLoader uses Axolotl MultipackBatchSampler and packed collator.") + notes.append( + "DataLoader uses Axolotl MultipackBatchSampler and packed collator." + ) if options.reuse_static_batch: notes.append("Reuses one static batch; this is capture feasibility only.") @@ -528,7 +538,9 @@ def run_loop(cfg, options: GraphLoopOptions) -> GraphLoopResult: notes=notes, probes=probes, ) - notes.append("Graph capture failed; auto mode fell back to eager static loop.") + notes.append( + "Graph capture failed; auto mode fell back to eager static loop." + ) optimizer = build_optimizer(model, cfg, capturable=False) else: probes = [] @@ -594,8 +606,7 @@ def format_result(result: GraphLoopResult) -> str: ] if result.median_ms is not None: lines.append( - f"cuda_ms_per_step median={result.median_ms:.3f} " - f"mean={result.mean_ms:.3f}" + f"cuda_ms_per_step median={result.median_ms:.3f} mean={result.mean_ms:.3f}" ) lines.append(f"input_tokens_per_second median={result.tokens_per_second:.1f}") if result.loss_first is not None: diff --git a/tests/e2e/test_nvfp4_training.py b/tests/e2e/test_nvfp4_training.py index 182190816e..271dc8ad52 100644 --- a/tests/e2e/test_nvfp4_training.py +++ b/tests/e2e/test_nvfp4_training.py @@ -62,7 +62,7 @@ def test_learns(self): def test_gemm_wiring_matches_dequant_matmul(self): """The FP4 GEMM must equal dequant(a) @ dequant(b).""" - from axolotl.utils.nvfp4_training import _fp4_mm, QuantPolicy, _quantize + from axolotl.utils.nvfp4_training import QuantPolicy, _fp4_mm, _quantize torch.manual_seed(0) a = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16) @@ -247,7 +247,7 @@ def __init__(self): super().__init__() self.lm_head = nn.Linear(2048, 4096, bias=False) - for mode, cls in ( + for mode, _cls in ( ("compute", NVFP4ComputeBaseLinear), ("storage", NVFP4FrozenBaseLinear), ): @@ -336,7 +336,7 @@ def forward(s, x): n = xf * torch.rsqrt( xf.pow(2).mean(-1, keepdim=True) + s.variance_epsilon ) - g = (1.0 + s.weight.float()) if zero_centered else s.weight.float() + g = (1.0 + s.weight.float()) if zero_centered else s.weight.float() # noqa: B023 return (n * g).to(x.dtype) norm = Norm().cuda() @@ -624,21 +624,10 @@ def test_attention_dkdv_bf16_scratch_matches_fp32_scratch(self): torch.manual_seed(123) amp = 0.5 base = { - "q": torch.randn( - 1, 4, 96, 128, device="cuda", dtype=torch.bfloat16 - ) - * amp, - "k": torch.randn( - 1, 2, 96, 128, device="cuda", dtype=torch.bfloat16 - ) - * amp, - "v": torch.randn( - 1, 2, 96, 128, device="cuda", dtype=torch.bfloat16 - ) - * amp, - "upstream": torch.randn( - 1, 4, 96, 128, device="cuda", dtype=torch.bfloat16 - ) + "q": torch.randn(1, 4, 96, 128, device="cuda", dtype=torch.bfloat16) * amp, + "k": torch.randn(1, 2, 96, 128, device="cuda", dtype=torch.bfloat16) * amp, + "v": torch.randn(1, 2, 96, 128, device="cuda", dtype=torch.bfloat16) * amp, + "upstream": torch.randn(1, 4, 96, 128, device="cuda", dtype=torch.bfloat16) * amp, } @@ -668,7 +657,7 @@ def run(dkdv_scratch_bf16: bool): bf16 = run(True) assert torch.equal(fp32[0], bf16[0]) - for ref, got in zip(fp32[1:], bf16[1:]): + for ref, got in zip(fp32[1:], bf16[1:], strict=False): assert torch.isfinite(got).all().item() rel = (ref.float() - got.float()).norm() / (ref.float().norm() + 1e-9) assert rel < 1e-2, rel.item() @@ -830,10 +819,10 @@ def __init__(self): assert isinstance(m.lm_head, nn.Linear) def test_compile_no_graph_breaks(self): - from axolotl.utils.nvfp4_training import NVFP4Linear, NVFP4Recipe - import torch._dynamo as dyn + from axolotl.utils.nvfp4_training import NVFP4Linear, NVFP4Recipe + torch.manual_seed(0) linear = nn.Linear(512, 256).cuda().bfloat16() # recipe on: SR + RHT live at the quant boundary, must not break the graph @@ -1006,8 +995,8 @@ def test_lora_fp4_compute_base_prequant(self): NVFP4ComputeBaseLinear, NVFP4Linear, NVFP4Recipe, - is_nvfp4_base, convert_lora_base_to_nvfp4, + is_nvfp4_base, ) # bit-identical to the per-step path (same quantization, just cached) @@ -1078,9 +1067,7 @@ def test_fsdp_hooks_present_and_reconstruct(self): fq = _to_fsdp_nvfp4(wq) assert hasattr(fq, "fsdp_pre_all_gather") assert hasattr(fq, "fsdp_post_all_gather") - assert torch.equal( - fq.dequantize(torch.bfloat16), wq.dequantize(torch.bfloat16) - ) + assert torch.equal(fq.dequantize(torch.bfloat16), wq.dequantize(torch.bfloat16)) # simulate the gather: split qdata/scale by row, concat, reconstruct (qd, sc), (ctx, pts) = fq.fsdp_pre_all_gather(mesh=None) diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index 2dafcd31bf..866647b0e3 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -374,8 +374,8 @@ def test_nvfp4_qat_trains_finite_loss(self): Uses dims divisible by 16 (the NVFP4 forward GEMM requires it); this is why the full-model swap test above never runs forward through lm_head. """ - from torchao.quantization import quantize_ from torchao.prototype.qat import NVFP4FakeQuantizedLinear + from torchao.quantization import quantize_ from axolotl.utils.quantization import _make_qat_config