From 612bd7e65255f75b05af903319e4590bc5875d61 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Fri, 12 Dec 2025 00:46:03 +0000 Subject: [PATCH 01/17] Integreated fused rmsnorm + quant in decoder layer --- atom/models/deepseek_v2.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 0ab1aa352..12b1322f7 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1174,9 +1174,32 @@ def forward( residual *= 1. / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) + if ENABLE_FP8_RMSNORM_QUANT_FUSION: + weight = self.post_attention_layernorm.weight + eps = self.post_attention_layernorm.eps + (hidden_states_quant, hidden_states_quant_scale), hidden_states_unquant, _, residual = _fuse_rmsnorm_quant( + hidden_states, + weight, + eps, + None, + None, + None, + residual, + quant_dtype, + False, + False, + 128, + False, + False + ) + if isinstance(self.mlp, + DeepseekV2MoE): + hidden_states = ((hidden_states_quant, hidden_states_quant_scale), hidden_states_unquant) + else: + hidden_states = (hidden_states_quant, hidden_states_quant_scale) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: From e52e7221f86c5f3dcd26feb29c53bf5a9ca41f08 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Fri, 12 Dec 2025 18:17:37 +0000 Subject: [PATCH 02/17] No need to fuse post attention --- atom/models/deepseek_v2.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 12b1322f7..49d60b024 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1102,6 +1102,7 @@ def __init__( eps=config.rms_norm_eps, 
fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION) self.routed_scaling_factor = config.routed_scaling_factor + self.quant_dtype = quant_config["quant_dtype"] if quant_config else None def forward( @@ -1174,32 +1175,9 @@ def forward( residual *= 1. / self.routed_scaling_factor # Fully Connected - if ENABLE_FP8_RMSNORM_QUANT_FUSION: - weight = self.post_attention_layernorm.weight - eps = self.post_attention_layernorm.eps - (hidden_states_quant, hidden_states_quant_scale), hidden_states_unquant, _, residual = _fuse_rmsnorm_quant( - hidden_states, - weight, - eps, - None, - None, - None, - residual, - quant_dtype, - False, - False, - 128, - False, - False - ) - if isinstance(self.mlp, - DeepseekV2MoE): - hidden_states = ((hidden_states_quant, hidden_states_quant_scale), hidden_states_unquant) - else: - hidden_states = (hidden_states_quant, hidden_states_quant_scale) - else: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: From ec27b856511c56d9c62af0751f6bb9275fc95954 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Tue, 16 Dec 2025 20:38:45 +0000 Subject: [PATCH 03/17] Refactored fusion condition --- atom/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 49d60b024..0016477eb 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1103,6 +1103,7 @@ def __init__( fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION) self.routed_scaling_factor = config.routed_scaling_factor self.quant_dtype = quant_config["quant_dtype"] if quant_config else None + self.fuse_rmsnorm_quant = ENABLE_RMSNORM_QUANT_FUSION and self.quant_dtype is not None def forward( From 9640efc65251d5824721e9cc262f03722c835f87 Mon Sep 17 00:00:00 2001 From: Farel Lukas 
Date: Wed, 17 Dec 2025 20:27:17 +0000 Subject: [PATCH 04/17] Transpose scales for input layernorm --- atom/models/deepseek_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 0016477eb..b420ffb92 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1132,7 +1132,7 @@ def forward( scale_shuffle_padding=True, group_size=128, output_unquantized_inp1=False, - transpose_scale=False, + transpose_scale=True, ) else: (hidden_states_quant, hidden_states_quant_scale), _, _, residual = _fuse_rmsnorm_quant( @@ -1148,7 +1148,7 @@ def forward( scale_shuffle_padding=True, group_size=128, output_unquantized_inp1=False, - transpose_scale=False, + transpose_scale=True, ) hidden_states = (hidden_states_quant, hidden_states_quant_scale) From c8584cb00ce5269410f551cac1e3e092d766dd4b Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Thu, 18 Dec 2025 21:07:49 +0000 Subject: [PATCH 05/17] Added torch compile guards on fusion to enable torch compiler --- atom/models/deepseek_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index b420ffb92..7f44998a6 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -54,6 +54,7 @@ fused_reduce_rms_mxfp4_quant, ) from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle +from aiter.dist.utils import cdiv from torch import nn from transformers import PretrainedConfig @@ -191,7 +192,7 @@ def _fuse_rmsnorm_quant( x1_epsilon: float, x2: Optional[torch.Tensor] = None, x2_weight: Optional[torch.Tensor] = None, - x2_epsilon: float = 0.0, + x2_epsilon: Optional[float] = None, res1: Optional[torch.Tensor] = None, dtype_quant=dtypes.fp8, shuffle: Optional[bool] = True, @@ -229,7 +230,7 @@ def _fuse_rmsnorm_quant( ) else: raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.") - return (out1_quantized, 
out1_bs), out1_unquantized, out2, out_res1 + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake( From 35e79c413152d112824cef9b903440ac4bf3227f Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Fri, 19 Dec 2025 17:54:17 +0000 Subject: [PATCH 06/17] Refactored fp8 fused rms quant function --- atom/models/deepseek_v2.py | 89 ++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 7f44998a6..ac951fd48 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -142,7 +142,42 @@ def _fuse_rmsnorm_fp4_quant_fake( out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + +def _fused_rms_fp8_group_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: torch.dtype = dtypes.fp8, + group_size: int = 128, + output_unquantized_inp1: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + m, n1 = x1.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) + out1_bs = torch.empty((m, cdiv(n1, group_size)), dtype=torch.float32, device=x1.device) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out1_unquantized = None + if output_unquantized_inp1: + out1_unquantized = torch.empty_like(x1) + out2 = None + if x2 is not None: + _, n2 = x2.shape + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 @@ 
-186,6 +221,42 @@ def _fuse_rmsnorm_fp4_quant( return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 +@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake) +def _fused_rms_fp8_group_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + dtype_quant: torch.dtype = dtypes.fp8, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + group_size, + dtype_quant, + res1, + output_unquantized_inp1, + transpose_scale, + ) + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + def _fuse_rmsnorm_quant( x1: torch.Tensor, x1_weight: torch.Tensor, @@ -194,12 +265,12 @@ def _fuse_rmsnorm_quant( x2_weight: Optional[torch.Tensor] = None, x2_epsilon: Optional[float] = None, res1: Optional[torch.Tensor] = None, - dtype_quant=dtypes.fp8, - shuffle: Optional[bool] = True, - scale_shuffle_padding: Optional[bool] = True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=False, + dtype_quant: torch.dtype = dtypes.fp8, + shuffle: bool = True, + scale_shuffle_padding: bool = False, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = False, ): if dtype_quant == dtypes.fp4x2: out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fuse_rmsnorm_fp4_quant( @@ -222,15 +293,15 @@ def _fuse_rmsnorm_quant( x2, x2_weight, x2_epsilon, - group_size, - dtype_quant, res1, + dtype_quant, + group_size, output_unquantized_inp1, transpose_scale, ) else: raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: 
{dtype_quant}.") - return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake( From a1f16ea0bf066e03c2dbd33f724c6ade88aae0d8 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Wed, 31 Dec 2025 21:46:31 +0000 Subject: [PATCH 07/17] Added fp8 triton preshuffled gemm --- atom/model_ops/linear.py | 33 ++++++++++++++++++++++++++++++++- atom/utils/envs.py | 1 + 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index d12bcaa0e..f52a8dbda 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -32,6 +32,14 @@ from atom.utils import envs +def use_triton_gemm() -> bool: + return envs.ATOM_USE_TRITON_GEMM + +if use_triton_gemm(): + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle +else: + gemm_a8w8_blockscale_preshuffle = None + def divide(numerator, denominator): assert ( numerator % denominator == 0 @@ -130,6 +138,29 @@ def gemm_a4w4_quant( return y[:m, ...] 
+def gemm_a8w8_blockscale_preshuffle_fake(x: torch.Tensor, weight: torch.Tensor, + x_scale: torch.Tensor, w_scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device + ) + + +@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[]) +def gemm_a8w8_blockscale_preshuffle_impl(x: torch.Tensor, weight: torch.Tensor, + x_scale: torch.Tensor, w_scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + if gemm_a8w8_blockscale_preshuffle is None: + weight_shuffled = weight.reshape( + weight.shape[0] // 16, + weight.shape[1] * 16 + ) + y = gemm_a8w8_blockscale_preshuffle(x, weight_shuffled, x_scale, w_scale, dtype) + else: + y = gemm_a8w8_blockscale_bpreshuffle(x, weight, x_scale, w_scale, dtype) + return y + + class LinearBase(nn.Module): def __init__( @@ -355,7 +386,7 @@ def forward( if self.bias is not None: y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: - y = gemm_a8w8_blockscale_bpreshuffle( + y = gemm_a8w8_blockscale_preshuffle_impl( x, self.weight, x_scale, self.weight_scale, dtype=otype ) if self.bias is not None: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index cbd276ed0..85b7966d0 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -17,6 +17,7 @@ "ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1", "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1", + "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", } def __getattr__(name: str): From 3d01e0210654a005013f9bfd497cbceaebe3b198 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Fri, 2 Jan 2026 16:03:50 +0000 Subject: [PATCH 08/17] Fixed triton gemm condition --- atom/model_ops/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index f52a8dbda..2160d647a 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -150,7 +150,7 @@ def gemm_a8w8_blockscale_preshuffle_fake(x: torch.Tensor, weight: torch.Tensor, def gemm_a8w8_blockscale_preshuffle_impl(x: torch.Tensor, weight: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: - if gemm_a8w8_blockscale_preshuffle is None: + if gemm_a8w8_blockscale_preshuffle is not None: weight_shuffled = weight.reshape( weight.shape[0] // 16, weight.shape[1] * 16 From d7e3e803d90b98bad5ce8870c5bdb074085c6dd2 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Fri, 2 Jan 2026 18:19:41 +0000 Subject: [PATCH 09/17] Added fused rmsnorm quant fp8 back in --- atom/model_ops/linear.py | 15 +++------------ atom/models/deepseek_v2.py | 2 +- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 2160d647a..2409fb115 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -29,15 +29,17 @@ from aiter.ops.shuffle import shuffle_weight from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from aiter import gemm_a4w4, per_1x32_f4_quant_hip from atom.utils import envs - def use_triton_gemm() -> bool: return envs.ATOM_USE_TRITON_GEMM if use_triton_gemm(): + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle else: + gemm_afp4wfp4_preshuffle = None gemm_a8w8_blockscale_preshuffle = None def divide(numerator, denominator): @@ -46,17 +48,6 @@ def divide(numerator, denominator): ), f"numerator {numerator} denominator {denominator}" return numerator // denominator -def use_triton_gemm() -> bool: - return envs.ATOM_USE_TRITON_GEMM - -if use_triton_gemm(): - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle -else: - 
gemm_afp4wfp4_preshuffle = None - -from aiter.jit.utils.torch_guard import torch_compile_guard -from aiter import gemm_a4w4, per_1x32_f4_quant_hip - def gemm_a4w4_quant_fake( x: torch.Tensor, x_scale: torch.Tensor, diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index ac951fd48..c0fc8c214 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -286,7 +286,7 @@ def _fuse_rmsnorm_quant( output_unquantized_inp1, ) elif dtype_quant == dtypes.fp8: - (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant( + (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = _fused_rms_fp8_group_quant( x1, x1_weight, x1_epsilon, From 9dc331cb733a63c78cc7e965a0bca8fb207c01bb Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Tue, 6 Jan 2026 17:35:26 +0000 Subject: [PATCH 10/17] Added transpose_scale back to fp8 fake function --- atom/models/deepseek_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index c0fc8c214..82bdd7e77 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -156,6 +156,7 @@ def _fused_rms_fp8_group_quant_fake( dtype_quant: torch.dtype = dtypes.fp8, group_size: int = 128, output_unquantized_inp1: bool = False, + transpose_scale: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -286,7 +287,7 @@ def _fuse_rmsnorm_quant( output_unquantized_inp1, ) elif dtype_quant == dtypes.fp8: - (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = _fused_rms_fp8_group_quant( + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fused_rms_fp8_group_quant( x1, x1_weight, x1_epsilon, From 20ac85063d987778da934f2f95ae829d797b73db Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Tue, 6 Jan 2026 17:42:42 +0000 Subject: [PATCH 11/17] Remove duplicate env --- atom/utils/envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 
85b7966d0..cbd276ed0 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -17,7 +17,6 @@ "ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1", "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1", - "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", } def __getattr__(name: str): From 279d7ddf0b7ec23f6bfc435f1771af68ed3a747e Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Tue, 6 Jan 2026 18:45:02 +0000 Subject: [PATCH 12/17] Implemented fp8 gemm preshuffled + split + cat --- atom/model_ops/attention_mla.py | 59 ++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index f7b72fc9d..f60cac2a3 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -4,6 +4,7 @@ import logging from dataclasses import dataclass from typing import Optional, Tuple +from functools import partial as functools_partial import torch from aiter import ( @@ -28,15 +29,16 @@ ) from aiter import ( QuantType, - dtypes, get_hip_quant, ) if use_triton_gemm(): try: from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat + from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat except: fused_gemm_afp4wfp4_preshuffle_split_cat = None + fused_gemm_a8w8_blockscale_preshuffle_split_cat = None # from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla # from aiter import fused_qk_rope_concat_and_cache_mla @@ -313,38 +315,41 @@ def _forward_prefill( self.v_head_dim, output_dtype ) - else: # FP8 GEMM + split + cat + elif fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None and weight.dtype == dtypes.fp8: # FP8 GEMM + split + cat + weight_shuffled = 
weight.reshape( + weight.shape[0] // 16, + weight.shape[1] * 16 + ) + + output_dtype = kv_c_normed.dtype + + quant_func = functools_partial( + get_hip_quant(QuantType.per_1x128), + transpose_scale=True + ) + q_input, x_scale = quant_func( + kv_c_normed, + quant_dtype=dtypes.fp8, + scale=getattr(self.kv_b_proj, "input_scale", None) + ) + + k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat( + q_input, + weight_shuffled, + k_rope.expand((-1, self.num_heads, -1)), + x_scale, + weight_scale, + self.qk_nope_head_dim, + self.v_head_dim, + output_dtype + ) + else: kv_nope = self.kv_b_proj(kv_c_normed).view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat - # import aiter as rocm_aiter - # from aiter import get_hip_quant - # aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) - # from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 - - # input = kv_c_normed - # weight = self.kv_b_proj.weight - # block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size - # weight_scale = self.kv_b_proj.weight_scale - - # input_2d = input.view(-1, input.shape[-1]) - # output_dtype = input.dtype - - # if current_platform.is_fp8_fnuz(): - # q_input, x_scale = aiter_per1x128_quant( - # input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - # else: - # q_input, x_scale = per_token_group_quant_fp8( - # input_2d, block_size[1], column_major_scales=False) - - # k, v = fused_gemm_a8w8_blockscale_split_cat( - # q_input, weight, k_rope.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype - # ) else: kv_nope = self.kv_b_proj(kv_c_normed).view( -1, self.num_heads, self.qk_nope_head_dim + 
self.v_head_dim From 678ca28db865e297dd4824bbee9c07703faa0fb4 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Wed, 7 Jan 2026 18:05:29 +0000 Subject: [PATCH 13/17] Implemented fp8 fused reduce rms quant --- atom/models/deepseek_v2.py | 138 +++++++++++++++++++++++++++++++++++-- 1 file changed, 132 insertions(+), 6 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 93985474d..041088518 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -48,13 +48,16 @@ deepgemm_fp8_paged_mqa_logits_stage1, ) from aiter.rotary_embedding import get_rope -from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant +from aiter.ops.triton.fused_fp8_quant import ( + fused_rms_fp8_group_quant, + fused_reduce_rms_fp8_group_quant +) from aiter.ops.triton.fused_mxfp4_quant import ( fused_rms_mxfp4_quant, fused_reduce_rms_mxfp4_quant, ) from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle -from aiter.dist.utils import cdiv +from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from torch import nn @@ -167,7 +170,7 @@ def _fused_rms_fp8_group_quant_fake( ]: m, n1 = x1.shape out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device) - out1_bs = torch.empty((m, cdiv(n1, group_size)), dtype=torch.float32, device=x1.device) + out1_bs = torch.empty((m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device) if transpose_scale: out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) out1_unquantized = None @@ -340,6 +343,37 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake( return q_c, q_c_scale, kv_c_normed, k_pe +def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake( + hidden_states_quant: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: 
torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + hidden_states_quant_scale: Optional[torch.Tensor] = None, + output_unquantized_inp1: Optional[bool] = False, + transpose_scale: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor +]: + M = hidden_states_quant.shape[0] + FP8_QUANT_BLOCK_SIZE = 128 + device = hidden_states_quant.device + q_c = torch.empty((M, q_lora_rank), dtype=dtypes.fp8, device=device) + scale_n = (q_lora_rank + FP8_QUANT_BLOCK_SIZE - 1) // FP8_QUANT_BLOCK_SIZE + q_c_scale = torch.empty((M, scale_n), dtype=dtypes.fp8, device=device) + kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device) + k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + return q_c, q_c_scale, kv_c_normed, k_pe + + @torch_compile_guard(gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake, mutates_args=[]) def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4( hidden_states_quant: torch.Tensor, @@ -440,6 +474,85 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4( k_pe = k_pe_reduced_out return q_c, q_c_scale, kv_c_normed, k_pe + +@torch_compile_guard(gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake, mutates_args=[]) +def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( + hidden_states_quant: torch.Tensor, + weight_qkv_a_proj: torch.Tensor, + weight_scale_qkv_a_proj: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + q_lora_rank: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + hidden_states_quant_scale: Optional[torch.Tensor] = None, + output_unquantized_inp1: Optional[bool] = False, + transpose_scale: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + 
torch.Tensor, + torch.Tensor +]: + M = hidden_states_quant.shape[0] + + if hidden_states_quant_scale is None: + quant_func = get_hip_quant(QuantType.per_1x128), + x, x_scale = quant_func( + hidden_states_quant, + quant_dtype=dtypes.fp8, + transpose_scale=transpose_scale + ) + qkv_lora = gemm_a8w8_blockscale_preshuffle( + x, + weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), + x_scale, + weight_scale_qkv_a_proj, + skip_reduce=True + ) + else: + qkv_lora = gemm_a8w8_blockscale_preshuffle( + hidden_states_quant, + weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), + hidden_states_quant_scale, + weight_scale_qkv_a_proj, + skip_reduce=True + ) + + q_c, kv_c, k_pe = torch.split( + qkv_lora, + [q_lora_rank, kv_lora_rank, qk_rope_head_dim], + dim=-1, + ) + + k_pe_reduced = None + k_pe_reduced_out = None + if k_pe.dim() == 3: + device = hidden_states_quant.device + k_pe_reduced = k_pe + k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim] + (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant( + q_c, + q_a_layernorm_weight, + q_a_layernorm_variance_epsilon, + kv_c, + kv_a_layernorm_weight, + kv_a_layernorm_variance_epsilon, + k_pe_reduced, + res1=None, + output_unquantized_inp1=output_unquantized_inp1, + dtype=torch.bfloat16, + out3=k_pe_reduced_out, + ) + + if k_pe_reduced_out is not None: + k_pe = k_pe_reduced_out + + return q_c, q_c_scale, kv_c_normed, k_pe + + def _fuse_qkv_a_proj_reduce_rmsnorm_quant( hidden_states_quant: torch.Tensor, weight_qkv_a_proj: torch.Tensor, @@ -477,8 +590,21 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant( output_unquantized_inp1, ) elif dtype_quant == dtypes.fp8: - # TODO add - raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.") + q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( + hidden_states_quant, + weight_qkv_a_proj, + 
weight_scale_qkv_a_proj, + q_a_layernorm_weight, + q_a_layernorm_variance_epsilon, + kv_a_layernorm_weight, + kv_a_layernorm_variance_epsilon, + q_lora_rank, + kv_lora_rank, + qk_rope_head_dim, + hidden_states_quant_scale, + output_unquantized_inp1, + transpose_scale, + ) else: raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.") @@ -1066,7 +1192,7 @@ def forward( hidden_states, hidden_states_scale = hidden_states if self.q_lora_rank is not None: - if self.fuse_qknorm_quant and self.quant_dtype != dtypes.fp8: + if self.fuse_qknorm_quant: q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant( hidden_states, self.fused_qkv_a_proj.weight, From b26f81f8590d9fd578508fc18641486526971806 Mon Sep 17 00:00:00 2001 From: Farel Lukas Date: Wed, 7 Jan 2026 19:36:26 +0000 Subject: [PATCH 14/17] Removed unreachable branch --- atom/models/deepseek_v2.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 041088518..87e965129 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1223,25 +1223,7 @@ def forward( dim=-1, ) # fuse q_c norm + kv_c norm + quant of hidden_states_or_q_c - if self.fuse_qknorm_quant: - (hidden_states_or_q_c, - hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant( - q_c, - self.q_a_layernorm.weight, - self.q_a_layernorm.eps, - kv_c, - self.kv_a_layernorm.weight, - self.kv_a_layernorm.eps, - None, - dtype_quant=self.quant_dtype, - shuffle=True, - scale_shuffle_padding=True, - group_size=128, - output_unquantized_inp1=False, - transpose_scale=False, - ) - else: - hidden_states_or_q_c = self.q_a_layernorm(q_c) + hidden_states_or_q_c = self.q_a_layernorm(q_c) else: hidden_states_or_q_c = hidden_states kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states, hidden_states_scale), From b46db4485186980fe9e3ded4cff544e0f79b33a3 Mon Sep 17 00:00:00 2001 From: Farel 
Lukas Date: Thu, 8 Jan 2026 00:17:21 +0000 Subject: [PATCH 15/17] Added transpose_scale to fused reduce rms quant --- atom/models/deepseek_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 87e965129..17549488c 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -545,6 +545,7 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( output_unquantized_inp1=output_unquantized_inp1, dtype=torch.bfloat16, out3=k_pe_reduced_out, + transpose_scale=transpose_scale, ) if k_pe_reduced_out is not None: @@ -1210,7 +1211,7 @@ def forward( scale_shuffle_padding=True, group_size=128, output_unquantized_inp1=False, - transpose_scale=False, + transpose_scale=True, ) hidden_states_or_q_c = q_c hidden_states_or_q_c_scale = q_c_scale From d9fd150f834c1a97827a62af477ca0afa72d0932 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 8 Jan 2026 04:20:01 +0000 Subject: [PATCH 16/17] fix --- atom/models/deepseek_v2.py | 61 ++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 17549488c..0b88acf5b 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -56,9 +56,6 @@ fused_rms_mxfp4_quant, fused_reduce_rms_mxfp4_quant, ) -from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle -from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle -from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from torch import nn from transformers import PretrainedConfig @@ -77,6 +74,13 @@ MergedReplicatedLinear, use_triton_gemm, ) +from aiter import gemm_a8w8_blockscale_bpreshuffle +from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle +from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle +if use_triton_gemm(): + from 
aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle +else: + gemm_a8w8_blockscale_preshuffle = None from atom.model_ops.attention_mla import is_rocm_aiter_fp4bmm_enabled from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import ( @@ -499,27 +503,46 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( M = hidden_states_quant.shape[0] if hidden_states_quant_scale is None: - quant_func = get_hip_quant(QuantType.per_1x128), + quant_func = get_hip_quant(QuantType.per_1x128) x, x_scale = quant_func( hidden_states_quant, quant_dtype=dtypes.fp8, transpose_scale=transpose_scale ) - qkv_lora = gemm_a8w8_blockscale_preshuffle( - x, - weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), - x_scale, - weight_scale_qkv_a_proj, - skip_reduce=True - ) + if gemm_a8w8_blockscale_preshuffle is not None: + qkv_lora = gemm_a8w8_blockscale_preshuffle( + x, + weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), + x_scale, + weight_scale_qkv_a_proj, + skip_reduce=True + ) + else: + qkv_lora = gemm_a8w8_blockscale_bpreshuffle( + x, + weight_qkv_a_proj, + x_scale, + weight_scale_qkv_a_proj, + torch.bfloat16 + ) else: - qkv_lora = gemm_a8w8_blockscale_preshuffle( - hidden_states_quant, - weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), - hidden_states_quant_scale, - weight_scale_qkv_a_proj, - skip_reduce=True - ) + if gemm_a8w8_blockscale_preshuffle is not None: + qkv_lora = gemm_a8w8_blockscale_preshuffle( + hidden_states_quant, + weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), + hidden_states_quant_scale, + weight_scale_qkv_a_proj, + skip_reduce=True + ) + else: + qkv_lora = gemm_a8w8_blockscale_bpreshuffle( + x, + weight_qkv_a_proj, + hidden_states_quant_scale, + weight_scale_qkv_a_proj, + torch.bfloat16 + ) + q_c, kv_c, k_pe = torch.split( qkv_lora, @@ -1426,7 +1449,7 @@ def forward( return hidden_states, residual -@support_torch_compile +# @support_torch_compile class DeepseekV2Model(nn.Module): def 
__init__( self, From bbd41987a015d2c383de1fa5ede5073fff9d604 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 8 Jan 2026 19:16:56 +0000 Subject: [PATCH 17/17] clean --- atom/model_ops/linear.py | 9 ++++-- atom/models/deepseek_v2.py | 59 +++++++++++++++++++------------------- atom/utils/envs.py | 2 +- 3 files changed, 37 insertions(+), 33 deletions(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 714ee54df..f4bc1f778 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -36,8 +36,13 @@ def use_triton_gemm() -> bool: return envs.ATOM_USE_TRITON_GEMM if use_triton_gemm(): - from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle + try: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle + # Because Triton FP8 Blockscale GEMM is mostly slower than AITER GEMM, we turn off Triton FP8 GEMM + # from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle + except: + gemm_afp4wfp4_preshuffle = None + gemm_a8w8_blockscale_preshuffle = None else: gemm_afp4wfp4_preshuffle = None gemm_a8w8_blockscale_preshuffle = None diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 0b88acf5b..3d3105607 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -74,13 +74,17 @@ MergedReplicatedLinear, use_triton_gemm, ) + from aiter import gemm_a8w8_blockscale_bpreshuffle -from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle -from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle if use_triton_gemm(): - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle -else: - gemm_a8w8_blockscale_preshuffle = None + try: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle + from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle + from aiter.ops.triton.gemm_a8w8_blockscale import 
gemm_a8w8_blockscale_preshuffle + except: + gemm_afp4wfp4_preshuffle = None + gemm_a16wfp4_preshuffle = None + gemm_a8w8_blockscale_preshuffle = None from atom.model_ops.attention_mla import is_rocm_aiter_fp4bmm_enabled from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import ( @@ -107,7 +111,7 @@ ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION -ENABLE_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_RMSNORM_QUANT_FUSION +ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION def _fuse_rmsnorm_fp4_quant_fake( @@ -509,7 +513,7 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( quant_dtype=dtypes.fp8, transpose_scale=transpose_scale ) - if gemm_a8w8_blockscale_preshuffle is not None: + if M <= 256: qkv_lora = gemm_a8w8_blockscale_preshuffle( x, weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), @@ -526,7 +530,7 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8( torch.bfloat16 ) else: - if gemm_a8w8_blockscale_preshuffle is not None: + if M <= 256: qkv_lora = gemm_a8w8_blockscale_preshuffle( hidden_states_quant, weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1), @@ -1055,9 +1059,8 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.layer_num = layer_num - # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains BF16 GEMM, - # Otherwise, if use_triton_gemm() is off, all projections are BF16 GEMMs - # For FP8, use_triton_gemm() is ignored and AITER FP8 GEMM is used + # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs, + # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs if quant_config["quant_dtype"] == dtypes.fp4x2: if not use_triton_gemm(): # TODO use ignore layer for mxfp4 attention @@ -1196,13 +1199,12 
@@ def __init__( prefix=prefix, ) - # self.fuse_qknorm_quant is turned on only if FP8 or (FP4 and use_triton_gemm()), - # because for (FP4 and not use_triton_gemm()) case, BF16 GEMMs are used + # When ATOM_ENABLE_DS_QKNORM_QUANT_FUSION is turned on, self.fuse_qknorm_quant is turned on only if use_triton_gemm() and (FP8 or FP4), self.prefix = prefix self.quant_dtype = None self.fuse_qknorm_quant = False if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION: - if quant_config["quant_dtype"] == dtypes.fp8 or (quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()): + if (quant_config["quant_dtype"] == dtypes.fp8 or quant_config["quant_dtype"] == dtypes.fp4x2) and use_triton_gemm(): self.quant_dtype = quant_config["quant_dtype"] self.fuse_qknorm_quant = True @@ -1314,21 +1316,23 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - # self.fuse_input_norm_quant is turned on only if FP8 or (FP4 and use_triton_gemm()), - # because for (FP4 and not use_triton_gemm()) case, BF16 GEMMs are used - # note that ENABLE_ALLREDUCE_RMSNORM_FUSION will be turned off if it was on + # When ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on, self.fuse_input_norm_quant is turned on only if use_triton_gemm() and (FP8 or FP4), + # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block of code ensures 3 things when ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on: + # 1. RMS_Quant fusion is only used for input_layernorm + # 2. The reduce_results variable is re-enabled for feed forward layers (MOE and MLP), because AR_RMS is now disabled at the beginning of the next layer + # 3. 
AR_RMS is turned off for input_layernorm but still enabled for post_attention_layernorm if ENABLE_ALLREDUCE_RMSNORM_FUSION is turned on self.quant_dtype = None self.fuse_input_norm_quant = False self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION - - # this part enforce fuse_rms_quant and use standalone AR - if quant_config is not None and ENABLE_RMSNORM_QUANT_FUSION: - if quant_config["quant_dtype"] == dtypes.fp8 or (quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()): + if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION: + if (quant_config["quant_dtype"] == dtypes.fp8 or quant_config["quant_dtype"] == dtypes.fp4x2) and use_triton_gemm(): self.quant_dtype = quant_config["quant_dtype"] self.fuse_input_norm_quant = True if self.fuse_ar_input_norm: self.fuse_ar_input_norm = False - logger.info("Warning: Because ENABLE_RMSNORM_QUANT_FUSION is turned on, AR + RMS fusion is turned off for input_layernorm and reduce_results is re-enabled for first k dense layer down_proj") + logger.info("Warning: Because ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on, AR + RMS fusion is turned off for input_layernorm and reduce_results is re-enabled for first k dense layer down_proj") + else: + logger.info("Info: Because ATOM_USE_TRITON_GEMM is not turned on in DeepSeek-R1, ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned off automatically") if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace @@ -1340,11 +1344,6 @@ def __init__( prefix=f"{prefix}.mlp", ) else: - # next_layer_dense = True - # if (config.n_routed_experts is not None - # and (layer_idx + 1) >= config.first_k_dense_replace - # and (layer_idx + 1) % config.moe_layer_freq == 0): - # next_layer_dense = False self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -1361,7 +1360,7 @@ def __init__( fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION) self.routed_scaling_factor = 
config.routed_scaling_factor self.quant_dtype = quant_config["quant_dtype"] if quant_config else None - self.fuse_rmsnorm_quant = ENABLE_RMSNORM_QUANT_FUSION and self.quant_dtype is not None + self.fuse_rmsnorm_quant = ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None def forward( @@ -1499,11 +1498,11 @@ def __init__( layer_num_offset=0 ) + # fused_allreduce will have to be turned off here if the fuse_ar_input_norm variable is False in the last layer if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, - fused_allreduce=self.layers[self.end_layer - 1].fuse_ar_input_norm,) - # fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,) + fused_allreduce=self.layers[self.end_layer - 1].fuse_ar_input_norm) else: self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = ( diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 641bf9a40..cb6bb50f7 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -13,10 +13,10 @@ "ATOM_DP_MASTER_PORT": lambda: int(os.getenv("ATOM_DP_MASTER_PORT", "29500")), "ATOM_ENFORCE_EAGER": lambda: os.getenv("ATOM_ENFORCE_EAGER", "0") == "1", "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", "1") == "1", + "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1") == "1", "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1") == "1", "ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1", "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", - "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1", "ATOM_ENABLE_QK_NORM_ROPE_FUSION": lambda: os.getenv("ATOM_ENABLE_QK_NORM_ROPE_FUSION", "1") == "1", # add qk-norm-rope-cache-quant fusion for Qwen3-Moe model, default disabled, # Qwen3-Moe model should enable this for better 
performance.