diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index ffadb03b5..d12bcaa0e 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -50,37 +50,39 @@ def use_triton_gemm() -> bool: from aiter import gemm_a4w4, per_1x32_f4_quant_hip def gemm_a4w4_quant_fake( - x: torch.Tensor, - weight: torch.Tensor, - otype: torch.dtype, - weight_scale: torch.Tensor, + x: torch.Tensor, + x_scale: torch.Tensor, + weight: torch.Tensor, + otype: torch.dtype, + weight_scale: torch.Tensor, params_dtype: torch.dtype, - input_scale: torch.Tensor, - output_size: int, -) -> torch.Tensor: - return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=otype, device=x.device) - + input_scale: torch.Tensor, + output_size: int) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], weight.shape[0]), dtype=otype, device=x.device + ) # It's important to use mutates_args=[] to avoid functionized_v2 op generation @torch_compile_guard(gen_fake=gemm_a4w4_quant_fake, mutates_args=[]) def gemm_a4w4_quant( - x: torch.Tensor, - weight: torch.Tensor, - otype: torch.dtype, - weight_scale: torch.Tensor, + x: torch.Tensor, + x_scale: torch.Tensor, + weight: torch.Tensor, + otype: torch.dtype, + weight_scale: torch.Tensor, params_dtype: torch.dtype, - input_scale: torch.Tensor, - output_size: int, -) -> torch.Tensor: - + input_scale: torch.Tensor, + output_size: int) -> torch.Tensor: + if gemm_afp4wfp4_preshuffle is None: - quant_func = get_hip_quant(QuantType.per_1x32) - x, x_scale = quant_func( - x, - quant_dtype=params_dtype, - scale=input_scale, - shuffle=True, - ) + if x_scale is None: + quant_func = get_hip_quant(QuantType.per_1x32) + x, x_scale = quant_func( + x, + quant_dtype=params_dtype, + scale=input_scale, + shuffle=True, + ) m = x.view(-1, x.size(-1)).shape[0] y = torch.empty( @@ -104,13 +106,13 @@ def gemm_a4w4_quant( dtype=otype, device=x.device, ) - - quant_func = get_hip_quant(QuantType.per_1x32) - x, x_scale = quant_func( - x, - quant_dtype=params_dtype, - shuffle=(m >= 32), - ) + if x_scale is None: + quant_func = get_hip_quant(QuantType.per_1x32) + x, x_scale = quant_func( + x, + quant_dtype=params_dtype, + shuffle=(m >= 32), + ) if m >= 32: x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // 32, -1) @@ -360,13 +362,14 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x32.value: y = gemm_a4w4_quant( - x, - self.weight, - otype, - self.weight_scale.data, - self.params_dtype, - getattr(self, "input_scale", None), - self.output_size, + x, + x_scale, + self.weight, + otype, + self.weight_scale.data, + self.params_dtype, + getattr(self, "input_scale", None), + self.output_size ) if self.bias is not None: y += self.bias diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index ead7e680b..6e7d5a063 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -85,12 +85,97 @@ from atom.utils.custom_register import direct_register_custom_op from atom.utils.decorators import support_torch_compile from atom.utils import envs +from aiter.jit.utils.torch_guard import torch_compile_guard # from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION +ENABLE_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_RMSNORM_QUANT_FUSION + + +def _fuse_rmsnorm_fp4_quant_fake( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + m, n1 = x1.shape + n2 = x2.shape[1] if x2 is not None else 0 + MXFP4_QUANT_BLOCK_SIZE = 32 + + out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device) + + scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + + scale_m = ((m + 255) // 256) * 256 + scale_n = ((scale_n_valid + 7) // 8) * 8 + + out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device) + + out2 = None + if x2 is not None: + out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device) + + out_res1 = None + if res1 is not None: + out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device) + + out1_unquantized = None + + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 + + +@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake) +def _fuse_rmsnorm_fp4_quant( + x1: torch.Tensor, + x1_weight: torch.Tensor, + x1_epsilon: float, + x2: Optional[torch.Tensor] = None, + x2_weight: Optional[torch.Tensor] = None, + x2_epsilon: Optional[float] = None, + res1: Optional[torch.Tensor] = None, + shuffle: bool = True, + scale_shuffle_padding: bool = True, + output_unquantized_inp1: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + m = x1.shape[0] + + shuffle_bool = True and (m >= 32) + + (out1_quantized, out1_bs), _out1_unquantized, out2, out_res1 = fused_rms_mxfp4_quant( + x1=x1, + x1_weight=x1_weight, + x1_epsilon=x1_epsilon, + x2=x2, + x2_weight=x2_weight, + x2_epsilon=0.0 if x2_epsilon is None else x2_epsilon, + res1=res1, + shuffle=shuffle_bool, + scale_shuffle_padding=True, + output_unquantized_inp1=output_unquantized_inp1, + ) + + out1_unquantized = None + return out1_quantized, out1_bs, out1_unquantized, out2, out_res1 -# only for DS MLA attention def _fuse_rmsnorm_quant( x1: torch.Tensor, x1_weight: torch.Tensor, @@ -106,7 +191,20 @@ def _fuse_rmsnorm_quant( output_unquantized_inp1=False, transpose_scale=False, ): - if dtype_quant == dtypes.fp8: + if dtype_quant == dtypes.fp4x2: + out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fuse_rmsnorm_fp4_quant( + x1, + x1_weight, + x1_epsilon, + x2, + x2_weight, + x2_epsilon, + res1, + shuffle, + scale_shuffle_padding, + output_unquantized_inp1, + ) + elif dtype_quant == dtypes.fp8: (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant( x1, x1_weight, @@ -120,20 +218,6 @@ def _fuse_rmsnorm_quant( output_unquantized_inp1, transpose_scale, ) - elif dtype_quant == dtypes.fp4x2: - (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_mxfp4_quant( - x1, - x1_weight, - x1_epsilon, - x2, - x2_weight, - x2_epsilon, - res1, - shuffle, - scale_shuffle_padding, - output_unquantized_inp1, - ) - # out1_unquantized = None else: raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.") return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1 @@ -679,16 +763,28 @@ def __init__( ) self.prefix = prefix - self.quant_dtype = non_proj_quant_config["quant_dtype"] if non_proj_quant_config else None - self.fuse_qknorm_quant = ENABLE_DS_QKNORM_QUANT_FUSION and self.quant_dtype is not None + if quant_config["quant_dtype"] == torch.float4_e2m1fn_x2: + if use_triton_gemm(): + self.quant_dtype = quant_config["quant_dtype"] + self.fuse_qknorm_quant = True + else: + self.quant_dtype = None + self.fuse_qknorm_quant = False + + # self.quant_dtype = non_proj_quant_config["quant_dtype"] if non_proj_quant_config else None + # self.fuse_qknorm_quant = ENABLE_DS_QKNORM_QUANT_FUSION and self.quant_dtype is not None def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + if self.q_lora_rank is not None: - qkv_lora = self.fused_qkv_a_proj(hidden_states) + qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) # ckq = self.q_a_proj(hidden_states) q_c, kv_c, k_pe = torch.split( qkv_lora, @@ -697,8 +793,8 @@ def forward( ) # fuse q_c norm + kv_c norm + quant of hidden_states_or_q_c if self.fuse_qknorm_quant: - (hidden_states_or_q_c, - hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant( + (hidden_states_or_q_c, + hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant( q_c, self.q_a_layernorm.weight, self.q_a_layernorm.eps, @@ -707,8 +803,8 @@ def forward( self.kv_a_layernorm.eps, None, dtype_quant=self.quant_dtype, - shuffle=False, - scale_shuffle_padding=False, + shuffle=True, + scale_shuffle_padding=True, group_size=128, output_unquantized_inp1=False, transpose_scale=False, @@ -717,7 +813,7 @@ def forward( hidden_states_or_q_c = self.q_a_layernorm(q_c) else: hidden_states_or_q_c = hidden_states - kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states), + kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states, hidden_states_scale), [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) if not self.fuse_qknorm_quant: kv_c_normed = self.kv_a_layernorm(kv_c) @@ -798,6 +894,7 @@ def __init__( eps=config.rms_norm_eps, fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION) self.routed_scaling_factor = config.routed_scaling_factor + self.quant_dtype = quant_config["quant_dtype"] if quant_config else None def forward( self, @@ -806,12 +903,53 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if ENABLE_RMSNORM_QUANT_FUSION: + assert self.quant_dtype is not None + weight = self.input_layernorm.weight + eps = self.input_layernorm.eps + if residual is None: + residual = hidden_states + (hidden_states_quant, hidden_states_quant_scale), _, _, _ = _fuse_rmsnorm_quant( + hidden_states, + weight, + eps, + None, + None, + None, + None, + dtype_quant=self.quant_dtype, + shuffle=False, + scale_shuffle_padding=False, + group_size=128, + output_unquantized_inp1=False, + transpose_scale=False, + ) + else: + (hidden_states_quant, hidden_states_quant_scale), _, _, residual = _fuse_rmsnorm_quant( + hidden_states, + weight, + eps, + None, + None, + None, + residual, + dtype_quant=self.quant_dtype, + shuffle=False, + scale_shuffle_padding=False, + group_size=128, + output_unquantized_inp1=False, + transpose_scale=False, + ) + + hidden_states = (hidden_states_quant, hidden_states_quant_scale) + else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 609559077..cbd276ed0 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -16,6 +16,7 @@ "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1") == "1", "ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1", "ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1", + "ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1", } def __getattr__(name: str):