Skip to content
79 changes: 41 additions & 38 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,37 +50,39 @@ def use_triton_gemm() -> bool:
from aiter import gemm_a4w4, per_1x32_f4_quant_hip

def gemm_a4w4_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
otype: torch.dtype,
weight_scale: torch.Tensor,
x: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
otype: torch.dtype,
weight_scale: torch.Tensor,
params_dtype: torch.dtype,
input_scale: torch.Tensor,
output_size: int,
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=otype, device=x.device)

input_scale: torch.Tensor,
output_size: int) -> torch.Tensor:
return torch.empty(
(*x.shape[:-1], weight.shape[0]), dtype=otype, device=x.device
)

# It's important to use mutates_args=[] to avoid functionized_v2 op generation
@torch_compile_guard(gen_fake=gemm_a4w4_quant_fake, mutates_args=[])
def gemm_a4w4_quant(
x: torch.Tensor,
weight: torch.Tensor,
otype: torch.dtype,
weight_scale: torch.Tensor,
x: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
otype: torch.dtype,
weight_scale: torch.Tensor,
params_dtype: torch.dtype,
input_scale: torch.Tensor,
output_size: int,
) -> torch.Tensor:

input_scale: torch.Tensor,
output_size: int) -> torch.Tensor:

if gemm_afp4wfp4_preshuffle is None:
quant_func = get_hip_quant(QuantType.per_1x32)
x, x_scale = quant_func(
x,
quant_dtype=params_dtype,
scale=input_scale,
shuffle=True,
)
if x_scale is None:
quant_func = get_hip_quant(QuantType.per_1x32)
x, x_scale = quant_func(
x,
quant_dtype=params_dtype,
scale=input_scale,
shuffle=True,
)

m = x.view(-1, x.size(-1)).shape[0]
y = torch.empty(
Expand All @@ -104,13 +106,13 @@ def gemm_a4w4_quant(
dtype=otype,
device=x.device,
)

quant_func = get_hip_quant(QuantType.per_1x32)
x, x_scale = quant_func(
x,
quant_dtype=params_dtype,
shuffle=(m >= 32),
)
if x_scale is None:
quant_func = get_hip_quant(QuantType.per_1x32)
x, x_scale = quant_func(
x,
quant_dtype=params_dtype,
shuffle=(m >= 32),
)

if m >= 32:
x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // 32, -1)
Expand Down Expand Up @@ -360,13 +362,14 @@ def forward(
y += self.bias
elif self.quant_type.value == QuantType.per_1x32.value:
y = gemm_a4w4_quant(
x,
self.weight,
otype,
self.weight_scale.data,
self.params_dtype,
getattr(self, "input_scale", None),
self.output_size,
x,
x_scale,
self.weight,
otype,
self.weight_scale.data,
self.params_dtype,
getattr(self, "input_scale", None),
self.output_size
)
if self.bias is not None:
y += self.bias
Expand Down
196 changes: 167 additions & 29 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,97 @@
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.decorators import support_torch_compile
from atom.utils import envs
from aiter.jit.utils.torch_guard import torch_compile_guard
# from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8

ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION
ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION
ENABLE_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_RMSNORM_QUANT_FUSION


def _fuse_rmsnorm_fp4_quant_fake(
x1: torch.Tensor,
x1_weight: torch.Tensor,
x1_epsilon: float,
x2: Optional[torch.Tensor] = None,
x2_weight: Optional[torch.Tensor] = None,
x2_epsilon: Optional[float] = None,
res1: Optional[torch.Tensor] = None,
shuffle: bool = True,
scale_shuffle_padding: bool = True,
output_unquantized_inp1: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
m, n1 = x1.shape
n2 = x2.shape[1] if x2 is not None else 0
MXFP4_QUANT_BLOCK_SIZE = 32

out1_quantized = torch.empty((m, n1 // 2), dtype=torch.uint8, device=x1.device)

scale_n_valid = (n1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE

scale_m = ((m + 255) // 256) * 256
scale_n = ((scale_n_valid + 7) // 8) * 8

out1_bs = torch.empty((scale_m, scale_n), dtype=torch.uint8, device=x1.device)

out2 = None
if x2 is not None:
out2 = torch.empty((m, n2), dtype=x1.dtype, device=x1.device)

out_res1 = None
if res1 is not None:
out_res1 = torch.empty((m, n1), dtype=x1.dtype, device=x1.device)

out1_unquantized = None

return out1_quantized, out1_bs, out1_unquantized, out2, out_res1


@torch_compile_guard(gen_fake=_fuse_rmsnorm_fp4_quant_fake)
def _fuse_rmsnorm_fp4_quant(
    x1: torch.Tensor,
    x1_weight: torch.Tensor,
    x1_epsilon: float,
    x2: Optional[torch.Tensor] = None,
    x2_weight: Optional[torch.Tensor] = None,
    x2_epsilon: Optional[float] = None,
    res1: Optional[torch.Tensor] = None,
    shuffle: bool = True,
    scale_shuffle_padding: bool = True,
    output_unquantized_inp1: bool = False,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Run the fused RMSNorm + MXFP4 quantization kernel.

    Thin wrapper over ``fused_rms_mxfp4_quant`` guarded by
    ``torch_compile_guard`` so torch.compile traces against the fake above.

    NOTE(review): the caller-supplied ``shuffle`` and
    ``scale_shuffle_padding`` flags are not forwarded — shuffling is derived
    from the row count and scale padding is always enabled.
    """
    num_rows = x1.shape[0]
    # Shuffled layout is only enabled once there are at least 32 rows.
    use_shuffle = num_rows >= 32
    eps2 = x2_epsilon if x2_epsilon is not None else 0.0

    result = fused_rms_mxfp4_quant(
        x1=x1,
        x1_weight=x1_weight,
        x1_epsilon=x1_epsilon,
        x2=x2,
        x2_weight=x2_weight,
        x2_epsilon=eps2,
        res1=res1,
        shuffle=use_shuffle,
        scale_shuffle_padding=True,
        output_unquantized_inp1=output_unquantized_inp1,
    )
    (out1_quantized, out1_bs), _unused_unquantized, out2, out_res1 = result

    # The unquantized slot is always returned as None, matching the fake.
    return out1_quantized, out1_bs, None, out2, out_res1

# only for DS MLA attention
def _fuse_rmsnorm_quant(
x1: torch.Tensor,
x1_weight: torch.Tensor,
Expand All @@ -106,7 +191,20 @@ def _fuse_rmsnorm_quant(
output_unquantized_inp1=False,
transpose_scale=False,
):
if dtype_quant == dtypes.fp8:
if dtype_quant == dtypes.fp4x2:
out1_quantized, out1_bs, out1_unquantized, out2, out_res1 = _fuse_rmsnorm_fp4_quant(
x1,
x1_weight,
x1_epsilon,
x2,
x2_weight,
x2_epsilon,
res1,
shuffle,
scale_shuffle_padding,
output_unquantized_inp1,
)
elif dtype_quant == dtypes.fp8:
(out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_fp8_group_quant(
x1,
x1_weight,
Expand All @@ -120,20 +218,6 @@ def _fuse_rmsnorm_quant(
output_unquantized_inp1,
transpose_scale,
)
elif dtype_quant == dtypes.fp4x2:
(out1_quantized, out1_bs), out1_unquantized, out2, out_res1 = fused_rms_mxfp4_quant(
x1,
x1_weight,
x1_epsilon,
x2,
x2_weight,
x2_epsilon,
res1,
shuffle,
scale_shuffle_padding,
output_unquantized_inp1,
)
# out1_unquantized = None
else:
raise ValueError(f"No fused rmsnorm quant kernel availble for quant dtype: {dtype_quant}.")
return (out1_quantized, out1_bs), out1_unquantized, out2, out_res1
Expand Down Expand Up @@ -679,16 +763,28 @@ def __init__(
)

self.prefix = prefix
self.quant_dtype = non_proj_quant_config["quant_dtype"] if non_proj_quant_config else None
self.fuse_qknorm_quant = ENABLE_DS_QKNORM_QUANT_FUSION and self.quant_dtype is not None
if quant_config["quant_dtype"] == torch.float4_e2m1fn_x2:
if use_triton_gemm():
self.quant_dtype = quant_config["quant_dtype"]
self.fuse_qknorm_quant = True
else:
self.quant_dtype = None
self.fuse_qknorm_quant = False

# self.quant_dtype = non_proj_quant_config["quant_dtype"] if non_proj_quant_config else None
# self.fuse_qknorm_quant = ENABLE_DS_QKNORM_QUANT_FUSION and self.quant_dtype is not None

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states_scale = None
if isinstance(hidden_states, tuple):
hidden_states, hidden_states_scale = hidden_states

if self.q_lora_rank is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states)
qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale)
# ckq = self.q_a_proj(hidden_states)
q_c, kv_c, k_pe = torch.split(
qkv_lora,
Expand All @@ -697,8 +793,8 @@ def forward(
)
# fuse q_c norm + kv_c norm + quant of hidden_states_or_q_c
if self.fuse_qknorm_quant:
(hidden_states_or_q_c,
hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant(
(hidden_states_or_q_c,
hidden_states_or_q_c_scale), _, kv_c_normed, _ = _fuse_rmsnorm_quant(
q_c,
self.q_a_layernorm.weight,
self.q_a_layernorm.eps,
Expand All @@ -707,8 +803,8 @@ def forward(
self.kv_a_layernorm.eps,
None,
dtype_quant=self.quant_dtype,
shuffle=False,
scale_shuffle_padding=False,
shuffle=True,
scale_shuffle_padding=True,
group_size=128,
output_unquantized_inp1=False,
transpose_scale=False,
Expand All @@ -717,7 +813,7 @@ def forward(
hidden_states_or_q_c = self.q_a_layernorm(q_c)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states),
kv_c, k_pe = torch.split(self.kv_a_proj_with_mqa(hidden_states, hidden_states_scale),
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not self.fuse_qknorm_quant:
kv_c_normed = self.kv_a_layernorm(kv_c)
Expand Down Expand Up @@ -798,6 +894,7 @@ def __init__(
eps=config.rms_norm_eps,
fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION)
self.routed_scaling_factor = config.routed_scaling_factor
self.quant_dtype = quant_config["quant_dtype"] if quant_config else None

def forward(
self,
Expand All @@ -806,12 +903,53 @@ def forward(
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if ENABLE_RMSNORM_QUANT_FUSION:
assert self.quant_dtype is not None
weight = self.input_layernorm.weight
eps = self.input_layernorm.eps
if residual is None:
residual = hidden_states
(hidden_states_quant, hidden_states_quant_scale), _, _, _ = _fuse_rmsnorm_quant(
hidden_states,
weight,
eps,
None,
None,
None,
None,
dtype_quant=self.quant_dtype,
shuffle=False,
scale_shuffle_padding=False,
group_size=128,
output_unquantized_inp1=False,
transpose_scale=False,
)
else:
(hidden_states_quant, hidden_states_quant_scale), _, _, residual = _fuse_rmsnorm_quant(
hidden_states,
weight,
eps,
None,
None,
None,
residual,
dtype_quant=self.quant_dtype,
shuffle=False,
scale_shuffle_padding=False,
group_size=128,
output_unquantized_inp1=False,
transpose_scale=False,
)

hidden_states = (hidden_states_quant, hidden_states_quant_scale)

else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
Expand Down
1 change: 1 addition & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: os.getenv("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1") == "1",
"ATOM_USE_TRITON_MXFP4_BMM": lambda: os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1",
"ATOM_USE_TRITON_GEMM": lambda: os.getenv("ATOM_USE_TRITON_GEMM", "0") == "1",
"ATOM_ENABLE_RMSNORM_QUANT_FUSION": lambda: os.getenv("ATOM_ENABLE_RMSNORM_QUANT_FUSION", "1") == "1",
}

def __getattr__(name: str):
Expand Down