diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py
index 472c08793..37be7173e 100644
--- a/atom/model_ops/attention_mla.py
+++ b/atom/model_ops/attention_mla.py
@@ -4,6 +4,7 @@
 import logging
 from dataclasses import dataclass
 from typing import Optional, Tuple
+from functools import partial as functools_partial

 import torch
 from aiter import (
@@ -28,15 +29,16 @@
 )
 from aiter import (
     QuantType,
-    dtypes,
     get_hip_quant,
 )

 if use_triton_gemm():
     try:
         from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
+        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
     except:
         fused_gemm_afp4wfp4_preshuffle_split_cat = None
+        fused_gemm_a8w8_blockscale_preshuffle_split_cat = None

 # from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
 # # from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
@@ -279,38 +281,41 @@ def _forward_prefill_mha(
                 self.v_head_dim,
                 output_dtype
             )
-        else:  # FP8 GEMM + split + cat
+        elif fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None and weight.dtype == dtypes.fp8:  # FP8 GEMM + split + cat
+            weight_shuffled = weight.reshape(
+                weight.shape[0] // 16,
+                weight.shape[1] * 16
+            )
+
+            output_dtype = kv_c_normed.dtype
+
+            quant_func = functools_partial(
+                get_hip_quant(QuantType.per_1x128),
+                transpose_scale=True
+            )
+            q_input, x_scale = quant_func(
+                kv_c_normed,
+                quant_dtype=dtypes.fp8,
+                scale=getattr(self.kv_b_proj, "input_scale", None)
+            )
+
+            k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat(
+                q_input,
+                weight_shuffled,
+                k_rope.expand((-1, self.num_heads, -1)),
+                x_scale,
+                weight_scale,
+                self.qk_nope_head_dim,
+                self.v_head_dim,
+                output_dtype
+            )
+        else:
             kv_nope = self.kv_b_proj(kv_c_normed).view(
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
             )
             k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)
-
-            # from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat
-            # import aiter as rocm_aiter
-            # from aiter import get_hip_quant
-            # aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
-            # from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8
-
-            # input = kv_c_normed
-            # weight = self.kv_b_proj.weight
-            # block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
-            # weight_scale = self.kv_b_proj.weight_scale
-
-            # input_2d = input.view(-1, input.shape[-1])
-            # output_dtype = input.dtype
-
-            # if current_platform.is_fp8_fnuz():
-            #     q_input, x_scale = aiter_per1x128_quant(
-            #         input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
-            # else:
-            #     q_input, x_scale = per_token_group_quant_fp8(
-            #         input_2d, block_size[1], column_major_scales=False)
-
-            # k, v = fused_gemm_a8w8_blockscale_split_cat(
-            #     q_input, weight, k_rope.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
-            # )
         else:
             kv_nope = self.kv_b_proj(kv_c_normed).view(
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py
index 9c8deb7fa..f4bc1f778 100644
--- a/atom/model_ops/linear.py
+++ b/atom/model_ops/linear.py
@@ -29,25 +29,29 @@
 from aiter.ops.shuffle import shuffle_weight
 from aiter.tuned_gemm import tgemm
 from aiter.utility import fp4_utils
+from aiter import gemm_a4w4, per_1x32_f4_quant_hip
 from atom.utils import envs
-
-def divide(numerator, denominator):
-    assert (
-        numerator % denominator == 0
-    ), f"numerator {numerator} denominator {denominator}"
-    return numerator // denominator
-
 def use_triton_gemm() -> bool:
     return envs.ATOM_USE_TRITON_GEMM

 if use_triton_gemm():
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
+    try:
+        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
+        # The Triton FP8 blockscale GEMM is mostly slower than the AITER GEMM, so the Triton FP8 GEMM stays disabled
+        # from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle
+        gemm_a8w8_blockscale_preshuffle = None
+    except:
+        gemm_afp4wfp4_preshuffle = None
+        gemm_a8w8_blockscale_preshuffle = None
 else:
     gemm_afp4wfp4_preshuffle = None
+    gemm_a8w8_blockscale_preshuffle = None

-from aiter.jit.utils.torch_guard import torch_compile_guard
-from aiter import gemm_a4w4, per_1x32_f4_quant_hip
+def divide(numerator, denominator):
+    assert (
+        numerator % denominator == 0
+    ), f"numerator {numerator} denominator {denominator}"
+    return numerator // denominator

 def gemm_a4w4_quant_fake(
     x: torch.Tensor,
@@ -132,6 +136,29 @@ def gemm_a4w4_quant(
     return y[:m, ...]


+def gemm_a8w8_blockscale_preshuffle_fake(x: torch.Tensor, weight: torch.Tensor,
+                                         x_scale: torch.Tensor, w_scale: torch.Tensor,
+                                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
+    return torch.empty(
+        (*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device
+    )
+
+
+@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[])
+def gemm_a8w8_blockscale_preshuffle_impl(x: torch.Tensor, weight: torch.Tensor,
+                                         x_scale: torch.Tensor, w_scale: torch.Tensor,
+                                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
+    if gemm_a8w8_blockscale_preshuffle is not None:
+        weight_shuffled = weight.reshape(
+            weight.shape[0] // 16,
+            weight.shape[1] * 16
+        )
+        y = gemm_a8w8_blockscale_preshuffle(x, weight_shuffled, x_scale, w_scale, dtype)
+    else:
+        y = gemm_a8w8_blockscale_bpreshuffle(x, weight, x_scale, w_scale, dtype)
+    return y
+
+
 class LinearBase(nn.Module):

     def __init__(
@@ -360,7 +387,7 @@ def forward(
             if self.bias is not None:
                 y += self.bias
         elif self.quant_type.value == QuantType.per_1x128.value:
-            y = gemm_a8w8_blockscale_bpreshuffle(
+            y = gemm_a8w8_blockscale_preshuffle_impl(
                 x, self.weight, x_scale, self.weight_scale, dtype=otype
             )
             if self.bias is not None:
diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py
index 89839c3c3..3d3105607 100644
--- a/atom/models/deepseek_v2.py
+++ b/atom/models/deepseek_v2.py
@@ -48,13 +48,14 @@
     deepgemm_fp8_paged_mqa_logits_stage1,
 )
 from aiter.rotary_embedding import get_rope
-from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
+from aiter.ops.triton.fused_fp8_quant import (
+    fused_rms_fp8_group_quant,
+    fused_reduce_rms_fp8_group_quant
+)
 from aiter.ops.triton.fused_mxfp4_quant import (
     fused_rms_mxfp4_quant,
     fused_reduce_rms_mxfp4_quant,
 )
-from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
-from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle
 from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
 from torch import nn
 from transformers import PretrainedConfig
@@ -73,6 +74,17 @@
     MergedReplicatedLinear,
     use_triton_gemm,
 )
+
+from aiter import gemm_a8w8_blockscale_bpreshuffle
+if use_triton_gemm():
+    try:
+        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
+        from aiter.ops.triton.gemm_a16wfp4 import gemm_a16wfp4_preshuffle
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle
+    except:
+        gemm_afp4wfp4_preshuffle = None
+        gemm_a16wfp4_preshuffle = None
+        gemm_a8w8_blockscale_preshuffle = None
 from atom.model_ops.attention_mla import is_rocm_aiter_fp4bmm_enabled
 from atom.model_ops.moe import FusedMoE
 from atom.model_ops.topK import (
@@ -99,7 +111,7 @@

 ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION
 ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION
-ENABLE_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_RMSNORM_QUANT_FUSION
+ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION


 def _fuse_rmsnorm_fp4_quant_fake(
@@ -142,7 +154,43 @@
         out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device)
         out1_unquantized = None

+    return out1_quantized, out1_bs, out1_unquantized, out2, out_res1
+
+
+def _fused_rms_fp8_group_quant_fake(
+    x1: torch.Tensor,
+    x1_weight: torch.Tensor,
+    x1_epsilon: float,
+    x2: Optional[torch.Tensor] = None,
+    x2_weight: Optional[torch.Tensor] = None,
+    x2_epsilon: Optional[float] = None,
+    res1: Optional[torch.Tensor] = None,
+    dtype_quant: torch.dtype = dtypes.fp8,
+    group_size: int = 128,
+    output_unquantized_inp1: bool = False,
+    transpose_scale: bool = False,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    m, n1 = x1.shape
+    out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=x1.device)
+    out1_bs = torch.empty((m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=x1.device)
+    if transpose_scale:
+        out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape)
+    out1_unquantized = None
+    if output_unquantized_inp1:
+        out1_unquantized = torch.empty_like(x1)
+    out2 = None
+    if x2 is not None:
+        _, n2 = x2.shape
+        out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device)
+    out_res1 = None
+    if res1 is not None:
+        out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device)
     return out1_quantized, out1_bs, out1_unquantized, out2, out_res1
@@ -186,20 +234,56 @@
     return out1_quantized, out1_bs, out1_unquantized, out2, out_res1


+@torch_compile_guard(gen_fake=_fused_rms_fp8_group_quant_fake)
+def _fused_rms_fp8_group_quant(
+    x1: torch.Tensor,
+    x1_weight: torch.Tensor,
+    x1_epsilon: float,
+    x2: Optional[torch.Tensor] = None,
+    x2_weight: Optional[torch.Tensor] = None,
+    x2_epsilon: Optional[float] = None,
+    res1: Optional[torch.Tensor] = None,
+    dtype_quant: torch.dtype = dtypes.fp8,
+    group_size: int = 128,
+    output_unquantized_inp1: bool = False,
+    transpose_scale: bool = False,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant(
+        x1,
+        x1_weight,
+        x1_epsilon,
+        x2,
+        x2_weight,
+        x2_epsilon,
+        group_size,
+        dtype_quant,
+        res1,
+        output_unquantized_inp1,
+        transpose_scale,
+    )
+    return out1_quantized, out1_bs, out1_unquantized, out2, out_res1
+
+
 def _fuse_rmsnorm_quant(
     x1: torch.Tensor,
     x1_weight: torch.Tensor,
     x1_epsilon: float,
     x2: Optional[torch.Tensor] = None,
     x2_weight: Optional[torch.Tensor] = None,
-    x2_epsilon: float = 0.0,
+    x2_epsilon: Optional[float] = None,
     res1: Optional[torch.Tensor] = None,
-    dtype_quant=dtypes.fp8,
-    shuffle: Optional[bool] = True,
-    scale_shuffle_padding: Optional[bool] = True,
-    group_size=128,
-    output_unquantized_inp1=False,
-    transpose_scale=False,
+    dtype_quant: torch.dtype = dtypes.fp8,
+    shuffle: bool = True,
+    scale_shuffle_padding: bool = False,
+    group_size: int = 128,
+    output_unquantized_inp1: bool = False,
+    transpose_scale: bool = False,
 ):
     if dtype_quant == dtypes.fp4x2:
         out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fuse_rmsnorm_fp4_quant(
@@ -215,16 +299,16 @@
             output_unquantized_inp1,
         )
     elif dtype_quant == dtypes.fp8:
-        (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant(
+        out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fused_rms_fp8_group_quant(
             x1,
             x1_weight,
             x1_epsilon,
             x2,
             x2_weight,
             x2_epsilon,
-            group_size,
-            dtype_quant,
             res1,
+            dtype_quant,
+            group_size,
             output_unquantized_inp1,
             transpose_scale,
         )
@@ -267,6 +351,37 @@
     return q_c, q_c_scale, kv_c_normed, k_pe
+def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake(
+    hidden_states_quant: torch.Tensor,
+    weight_qkv_a_proj: torch.Tensor,
+    weight_scale_qkv_a_proj: torch.Tensor,
+    q_a_layernorm_weight: torch.Tensor,
+    q_a_layernorm_variance_epsilon: float,
+    kv_a_layernorm_weight: torch.Tensor,
+    kv_a_layernorm_variance_epsilon: float,
+    q_lora_rank: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    hidden_states_quant_scale: Optional[torch.Tensor] = None,
+    output_unquantized_inp1: Optional[bool] = False,
+    transpose_scale: bool = True,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor
+]:
+    M = hidden_states_quant.shape[0]
+    FP8_QUANT_BLOCK_SIZE = 128
+    device = hidden_states_quant.device
+    q_c = torch.empty((M, q_lora_rank), dtype=dtypes.fp8, device=device)
+    scale_n = (q_lora_rank + FP8_QUANT_BLOCK_SIZE - 1) // FP8_QUANT_BLOCK_SIZE
+    q_c_scale = torch.empty((M, scale_n), dtype=dtypes.fp8, device=device)
+    kv_c_normed = torch.empty((M, kv_lora_rank), dtype=torch.bfloat16, device=device)
+    k_pe = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+    return q_c, q_c_scale, kv_c_normed, k_pe
+
+
 @torch_compile_guard(gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4_fake, mutates_args=[])
 def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4(
     hidden_states_quant: torch.Tensor,
@@ -367,6 +482,105 @@ def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4(
         k_pe = k_pe_reduced_out

     return q_c, q_c_scale, kv_c_normed, k_pe
+
+@torch_compile_guard(gen_fake=_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8_fake, mutates_args=[])
+def _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8(
+    hidden_states_quant: torch.Tensor,
+    weight_qkv_a_proj: torch.Tensor,
+    weight_scale_qkv_a_proj: torch.Tensor,
+    q_a_layernorm_weight: torch.Tensor,
+    q_a_layernorm_variance_epsilon: float,
+    kv_a_layernorm_weight: torch.Tensor,
+    kv_a_layernorm_variance_epsilon: float,
+    q_lora_rank: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    hidden_states_quant_scale: Optional[torch.Tensor] = None,
+    output_unquantized_inp1: Optional[bool] = False,
+    transpose_scale: bool = True,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor
+]:
+    M = hidden_states_quant.shape[0]
+
+    if hidden_states_quant_scale is None:
+        quant_func = get_hip_quant(QuantType.per_1x128)
+        x, x_scale = quant_func(
+            hidden_states_quant,
+            quant_dtype=dtypes.fp8,
+            transpose_scale=transpose_scale
+        )
+        if M <= 256:
+            qkv_lora = gemm_a8w8_blockscale_preshuffle(
+                x,
+                weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1),
+                x_scale,
+                weight_scale_qkv_a_proj,
+                skip_reduce=True
+            )
+        else:
+            qkv_lora = gemm_a8w8_blockscale_bpreshuffle(
+                x,
+                weight_qkv_a_proj,
+                x_scale,
+                weight_scale_qkv_a_proj,
+                torch.bfloat16
+            )
+    else:
+        if M <= 256:
+            qkv_lora = gemm_a8w8_blockscale_preshuffle(
+                hidden_states_quant,
+                weight_qkv_a_proj.view(weight_qkv_a_proj.shape[0] // 16, -1),
+                hidden_states_quant_scale,
+                weight_scale_qkv_a_proj,
+                skip_reduce=True
+            )
+        else:
+            qkv_lora = gemm_a8w8_blockscale_bpreshuffle(
+                hidden_states_quant,
+                weight_qkv_a_proj,
+                hidden_states_quant_scale,
+                weight_scale_qkv_a_proj,
+                torch.bfloat16
+            )
+
+    q_c, kv_c, k_pe = torch.split(
+        qkv_lora,
+        [q_lora_rank, kv_lora_rank, qk_rope_head_dim],
+        dim=-1,
+    )
+
+    k_pe_reduced = None
+    k_pe_reduced_out = None
+    if k_pe.dim() == 3:
+        device = hidden_states_quant.device
+        k_pe_reduced = k_pe
+        k_pe_reduced_out = torch.empty((M, q_lora_rank + kv_lora_rank + qk_rope_head_dim), dtype=torch.bfloat16, device=device)[..., :qk_rope_head_dim]
+    (q_c, q_c_scale), _, kv_c_normed, _, k_pe_reduced_out = fused_reduce_rms_fp8_group_quant(
+        q_c,
+        q_a_layernorm_weight,
+        q_a_layernorm_variance_epsilon,
+        kv_c,
+        kv_a_layernorm_weight,
+        kv_a_layernorm_variance_epsilon,
+        k_pe_reduced,
+        res1=None,
+        output_unquantized_inp1=output_unquantized_inp1,
+        dtype=torch.bfloat16,
+        out3=k_pe_reduced_out,
+        transpose_scale=transpose_scale,
+    )
+
+    if k_pe_reduced_out is not None:
+        k_pe = k_pe_reduced_out
+
+    return q_c, q_c_scale, kv_c_normed, k_pe
+
+
 def _fuse_qkv_a_proj_reduce_rmsnorm_quant(
     hidden_states_quant: torch.Tensor,
     weight_qkv_a_proj: torch.Tensor,
@@ -404,8 +618,21 @@
             output_unquantized_inp1,
         )
     elif dtype_quant == dtypes.fp8:
-        # TODO add
-        raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.")
+        q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant_fp8(
+            hidden_states_quant,
+            weight_qkv_a_proj,
+            weight_scale_qkv_a_proj,
+            q_a_layernorm_weight,
+            q_a_layernorm_variance_epsilon,
+            kv_a_layernorm_weight,
+            kv_a_layernorm_variance_epsilon,
+            q_lora_rank,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            hidden_states_quant_scale,
+            output_unquantized_inp1,
+            transpose_scale,
+        )
     else:
         raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.")
@@ -832,9 +1059,8 @@
         self.max_position_embeddings = max_position_embeddings
         self.layer_num = layer_num

-        # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains BF16 GEMM,
-        # Otherwise, if use_triton_gemm() is off, all projections are BF16 GEMMs
-        # For FP8, use_triton_gemm() is ignored and AITER FP8 GEMM is used
+        # For FP4 with use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains an AITER BF16 GEMM.
+        # For FP8 with use_triton_gemm(), fused_qkv_a_proj is an AITER-Triton FP8 GEMM while the other projections remain AITER FP8 GEMMs.
         if quant_config["quant_dtype"] == dtypes.fp4x2:
             if not use_triton_gemm():
                 # TODO use ignore layer for mxfp4 attention
@@ -973,13 +1199,12 @@
             prefix=prefix,
         )

-        # self.fuse_qknorm_quant is turned on only if FP8 or (FP4 and use_triton_gemm()),
-        # because for (FP4 and not use_triton_gemm()) case, BF16 GEMMs are used
+        # When ATOM_ENABLE_DS_QKNORM_QUANT_FUSION is turned on, self.fuse_qknorm_quant is enabled only if use_triton_gemm() and the quant dtype is FP8 or FP4.
         self.prefix = prefix
         self.quant_dtype = None
         self.fuse_qknorm_quant = False
         if quant_config is not None and ENABLE_DS_QKNORM_QUANT_FUSION:
-            if quant_config["quant_dtype"] == dtypes.fp8 or (quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()):
+            if (quant_config["quant_dtype"] == dtypes.fp8 or quant_config["quant_dtype"] == dtypes.fp4x2) and use_triton_gemm():
                 self.quant_dtype = quant_config["quant_dtype"]
                 self.fuse_qknorm_quant = True
@@ -993,7 +1218,7 @@ def forward(
             hidden_states, hidden_states_scale = hidden_states

         if self.q_lora_rank is not None:
-            if self.fuse_qknorm_quant and self.quant_dtype != dtypes.fp8:
+            if self.fuse_qknorm_quant:
                 q_c, q_c_scale, kv_c_normed, k_pe = _fuse_qkv_a_proj_reduce_rmsnorm_quant(
                     hidden_states,
                     self.fused_qkv_a_proj.weight,
@@ -1011,7 +1236,7 @@
                     scale_shuffle_padding=True,
                     group_size=128,
                     output_unquantized_inp1=False,
-                    transpose_scale=False,
+                    transpose_scale=True,
                 )
                 hidden_states_or_q_c = q_c
                 hidden_states_or_q_c_scale = q_c_scale
@@ -1024,25 +1249,7 @@
                     dim=-1,
                 )
                 # fuse q_c norm + kv_c norm + quant of hidden_states_or_q_c
-                if self.fuse_qknorm_quant:
-                    (hidden_states_or_q_c,
-                     hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant(
-                        q_c,
-                        self.q_a_layernorm.weight,
-                        self.q_a_layernorm.eps,
-                        kv_c,
-                        self.kv_a_layernorm.weight,
-                        self.kv_a_layernorm.eps,
-                        None,
-                        dtype_quant=self.quant_dtype,
-                        shuffle=True,
-                        scale_shuffle_padding=True,
-                        group_size=128,
-                        output_unquantized_inp1=False,
-                        transpose_scale=False,
-                    )
-                else:
-                    hidden_states_or_q_c = self.q_a_layernorm(q_c)
+                hidden_states_or_q_c = self.q_a_layernorm(q_c)
         else:
             hidden_states_or_q_c = hidden_states
             kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states, hidden_states_scale),
@@ -1109,21 +1316,23 @@
             topk_indices_buffer=topk_indices_buffer,
         )

-        # self.fuse_input_norm_quant is turned on only if FP8 or (FP4 and use_triton_gemm()),
-        # because for (FP4 and not use_triton_gemm()) case, BF16 GEMMs are used
-        # note that ENABLE_ALLREDUCE_RMSNORM_FUSION will be turned off if it was on
+        # When ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on, self.fuse_input_norm_quant is enabled only if use_triton_gemm() and the quant dtype is FP8 or FP4.
+        # Because AR_RMS and RMS_Quant cannot co-exist for input_layernorm, this block of code ensures three things when ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on:
+        # 1. RMS_Quant fusion is only used for input_layernorm
+        # 2. The reduce_results flag is re-enabled for the feed-forward layers (MoE and MLP), because AR_RMS is now disabled at the beginning of the next layer
+        # 3. AR_RMS is turned off for input_layernorm but stays enabled for post_attention_layernorm if ENABLE_ALLREDUCE_RMSNORM_FUSION is turned on
         self.quant_dtype = None
         self.fuse_input_norm_quant = False
         self.fuse_ar_input_norm = ENABLE_ALLREDUCE_RMSNORM_FUSION
-
-        # this part enforce fuse_rms_quant and use standalone AR
-        if quant_config is not None and ENABLE_RMSNORM_QUANT_FUSION:
-            if quant_config["quant_dtype"] == dtypes.fp8 or (quant_config["quant_dtype"] == dtypes.fp4x2 and use_triton_gemm()):
+        if quant_config is not None and ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION:
+            if (quant_config["quant_dtype"] == dtypes.fp8 or quant_config["quant_dtype"] == dtypes.fp4x2) and use_triton_gemm():
                 self.quant_dtype = quant_config["quant_dtype"]
                 self.fuse_input_norm_quant = True
                 if self.fuse_ar_input_norm:
                     self.fuse_ar_input_norm = False
-                    logger.info("Warning: Because ENABLE_RMSNORM_QUANT_FUSION is turned on, AR + RMS fusion is turned off for input_layernorm and reduce_results is re-enabled for first k dense layer down_proj")
+                    logger.info("Warning: Because ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on, AR + RMS fusion is turned off for input_layernorm and reduce_results is re-enabled for the first k dense layers' down_proj")
+            else:
+                logger.info("Info: Because ATOM_USE_TRITON_GEMM is not turned on for DeepSeek-R1, ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned off automatically")

         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
@@ -1135,11 +1344,6 @@
                 prefix=f"{prefix}.mlp",
             )
         else:
-            # next_layer_dense = True
-            # if (config.n_routed_experts is not None
-            #         and (layer_idx + 1) >= config.first_k_dense_replace
-            #         and (layer_idx + 1) % config.moe_layer_freq == 0):
-            #     next_layer_dense = False
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
@@ -1155,6 +1359,8 @@
                                                eps=config.rms_norm_eps,
                                                fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION)
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.quant_dtype = quant_config["quant_dtype"] if quant_config else None
+        self.fuse_rmsnorm_quant = ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None

     def forward(
@@ -1183,7 +1389,7 @@
                 scale_shuffle_padding=True,
                 group_size=128,
                 output_unquantized_inp1=False,
-                transpose_scale=False,
+                transpose_scale=True,
             )
         else:
             (hidden_states_quant, hidden_states_quant_scale), _, _, residual = _fuse_rmsnorm_quant(
@@ -1199,7 +1405,7 @@
                 scale_shuffle_padding=True,
                 group_size=128,
                 output_unquantized_inp1=False,
-                transpose_scale=False,
+                transpose_scale=True,
             )
             hidden_states = (hidden_states_quant, hidden_states_quant_scale)
@@ -1242,7 +1448,7 @@
         return hidden_states, residual


-@support_torch_compile
+# @support_torch_compile
 class DeepseekV2Model(nn.Module):
     def __init__(
         self,
@@ -1292,11 +1498,11 @@
             layer_num_offset=0
         )

+        # fused_allreduce has to be turned off here if fuse_ar_input_norm is False in the last layer
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size,
                                 eps=config.rms_norm_eps,
-                                fused_allreduce=self.layers[self.end_layer - 1].fuse_ar_input_norm,)
-                                # fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION,)
+                                fused_allreduce=self.layers[self.end_layer - 1].fuse_ar_input_norm)
         else:
             self.norm = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
diff --git a/atom/utils/envs.py b/atom/utils/envs.py
index 641bf9a40..cb6bb50f7 100644
--- a/atom/utils/envs.py
+++ b/atom/utils/envs.py
@@ -13,10 +13,10 @@
     "ATOM_DP_MASTER_PORT": lambda: int(os.getenv("ATOM_DP_MASTER_PORT", "29500")),
     "ATOM_ENFORCE_EAGER": lambda: os.getenv("ATOM_ENFORCE_EAGER", "0") == "1",
     "ATOM_ENABLE_DS_QKNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", "1") == "1",
+    "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1") == "1",
     "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1") == "1",
     "ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1",
     "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1",
-    "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1",
     "ATOM_ENABLE_QK_NORM_ROPE_FUSION": lambda: os.getenv("ATOM_ENABLE_QK_NORM_ROPE_FUSION", "1") == "1",
     # add qk-norm-rope-cache-quant fusion for Qwen3-Moe model, default disabled,
     # Qwen3-Moe model should enable this for better performance.
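
Note on the per-1x128 FP8 path introduced above: activations get one scale per row per 128-column group (QuantType.per_1x128), and the FP8 data plus float scales are fed into a block-scale GEMM, optionally with the weight reinterpreted into a preshuffled layout. The plain-PyTorch sketch below only illustrates the numerics such kernels approximate; the helper names are made up, the mapping of aiter's dtypes.fp8 to torch.float8_e4m3fn is an assumption, and the real kernels additionally handle scale transposition and 128x128 weight blocks.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: dtypes.fp8 corresponds to an e4m3 format
GROUP = 128                      # per-1x128: one scale per row per 128-column group


def quant_per_1x128(x: torch.Tensor):
    # Hypothetical reference for per-1x128 quantization: scale each 128-wide group
    # of every row by its absmax, store data in FP8 and scales in fp32.
    m, k = x.shape
    assert k % GROUP == 0
    xg = x.view(m, k // GROUP, GROUP).float()
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(FP8_DTYPE).max
    x_q = (xg / scale).to(FP8_DTYPE).view(m, k)
    return x_q, scale.squeeze(-1)


def gemm_blockscale_reference(x_q, x_scale, w_q, w_scale, out_dtype=torch.bfloat16):
    # Dequantize-then-matmul reference for an a8w8 block-scale GEMM. Both operands
    # carry per-row, per-128-column-group scales here for brevity.
    m, k = x_q.shape
    n = w_q.shape[0]
    x = (x_q.float().view(m, k // GROUP, GROUP) * x_scale.unsqueeze(-1)).view(m, k)
    w = (w_q.float().view(n, k // GROUP, GROUP) * w_scale.unsqueeze(-1)).view(n, k)
    return (x @ w.t()).to(out_dtype)


x_q, x_s = quant_per_1x128(torch.randn(4, 256))
w_q, w_s = quant_per_1x128(torch.randn(512, 256))
y = gemm_blockscale_reference(x_q, x_s, w_q, w_s)  # (4, 512) bfloat16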
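
The *_fake helpers added in linear.py and deepseek_v2.py all follow one pattern: the real wrapper calls the AITER kernel, while the fake only allocates outputs of the right shape and dtype so the graph compiler can trace the op without executing it, and the two are tied together with @torch_compile_guard(gen_fake=..., mutates_args=[]). The sketch below shows the same real/fake split using PyTorch's stock custom-op registration (recent PyTorch) instead of the aiter decorator; the op name and toy kernel are hypothetical and only illustrate the idea.

import torch


@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # "Real" implementation: stands in for a hand-written kernel.
    return x * scale.unsqueeze(-1)


@scale_rows.register_fake
def _(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Fake implementation: shape/dtype bookkeeping only, never touches data,
    # which is exactly what the *_fake functions in this patch do.
    return torch.empty_like(x)


@torch.compile
def f(x, s):
    # During tracing the fake runs; at execution time the real op runs.
    return scale_rows(x, s) + 1.0


out = f(torch.randn(4, 8), torch.rand(4))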
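
Both the MLA prefill branch and gemm_a8w8_blockscale_preshuffle_impl reinterpret an already preshuffled (n, k) weight as (n // 16, k * 16) before calling the preshuffle GEMM. Assuming the weight is contiguous, the reshape is pure shape bookkeeping over the same storage, as this small check with made-up dimensions illustrates.

import torch

n, k = 64, 32                              # made-up dims; any n divisible by 16 works
w = torch.zeros(n, k, dtype=torch.int8)    # int8 stands in for the FP8 storage

w_shuffled = w.reshape(n // 16, k * 16)    # same bytes, 16 rows packed per shuffled row
assert w_shuffled.numel() == w.numel()
assert w_shuffled.data_ptr() == w.data_ptr()  # reshape of a contiguous tensor is a view
print(w.shape, "->", w_shuffled.shape)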