Merged
61 commits
8741164
tmp
k50112113 Dec 11, 2025
28b9bc9
fix
k50112113 Dec 11, 2025
1831a5d
clean
k50112113 Dec 11, 2025
fcf4daf
Making the BMM use fp4 weights
omuhamma Dec 12, 2025
0f08fae
add ATOM_USE_TRITON_GEMM and a16wfp4 gemm for o_proj
k50112113 Dec 12, 2025
4058924
Cleaning up the code and ensuring other weights won't crash
omuhamma Dec 12, 2025
3c79ba8
add import check for gemm_a16wfp4_preshuffle
k50112113 Dec 12, 2025
dadb3f4
Merge branch 'shaoclee/ds_fp4_gemm' into omuhamma/bmm
omuhamma Dec 12, 2025
7c4e8ae
Merge pull request #44 from ROCm/omuhamma/bmm
k50112113 Dec 12, 2025
b6ee629
clean
k50112113 Dec 15, 2025
2fd1842
clean
k50112113 Dec 15, 2025
53a2ca1
disable FP4 triton GEMM on o_proj on DS FP4
k50112113 Dec 15, 2025
05c39f5
Fused rms for fp4
omuhamma Dec 16, 2025
8b575e8
Adding the x_scale change in linear.py to choose when to quantize or not
omuhamma Dec 16, 2025
352d916
Enabling the second fused rms before attention
omuhamma Dec 16, 2025
49b8d14
Merge remote-tracking branch 'origin/main' into shaoclee/ds_fp4_gemm
k50112113 Dec 17, 2025
1f623a3
Fixing issue where there was a shape mismatch when running the second…
omuhamma Dec 17, 2025
c9e5280
Merge branch 'shaoclee/ds_fp4_gemm' into omuhamma/dsfp4-rms
omuhamma Dec 17, 2025
bec4b20
Marking shuffle and shuffle padding as true temporarily always
omuhamma Dec 17, 2025
8848851
Working implementation of fused_rms for fp4
omuhamma Dec 19, 2025
f3b8681
Formatting fixes
omuhamma Dec 19, 2025
35e8338
Fix syntax error
omuhamma Dec 19, 2025
c14a5d8
Remove some commented code from the fp4 section
omuhamma Dec 19, 2025
98fdbb4
Merge pull request #61 from ROCm/omuhamma/dsfp4-rms
k50112113 Dec 26, 2025
dcff526
disable only AR + input layernorm with ATOM_ENABLE_RMSNORM_QUANT_FUSI…
k50112113 Dec 26, 2025
04b1301
add _fuse_qkv_a_proj_reduce_rmsnorm_quant for DS FP4
k50112113 Dec 29, 2025
ea52f43
add gemm split + cat for DS FP4
k50112113 Dec 29, 2025
612bd7e
Integrated fused rmsnorm + quant in decoder layer
farlukas Dec 12, 2025
e52e722
No need to fuse post attention
farlukas Dec 12, 2025
ec27b85
Refactored fusion condition
farlukas Dec 16, 2025
9640efc
Transpose scales for input layernorm
farlukas Dec 17, 2025
c8584cb
Added torch compile guards on fusion to enable torch compiler
farlukas Dec 18, 2025
35e79c4
Refactored fp8 fused rms quant function
farlukas Dec 19, 2025
a1f16ea
Added fp8 triton preshuffled gemm
farlukas Dec 31, 2025
3d01e02
Fixed triton gemm condition
farlukas Jan 2, 2026
d7e3e80
Added fused rmsnorm quant fp8 back in
farlukas Jan 2, 2026
9dc331c
Added transpose_scale back to fp8 fake function
farlukas Jan 6, 2026
20ac850
Remove duplicate env
farlukas Jan 6, 2026
279d7dd
Implemented fp8 gemm preshuffled + split + cat
farlukas Jan 6, 2026
e93a053
Merge remote-tracking branch 'origin/main' into shaoclee/ds_fp4_gemm
k50112113 Jan 6, 2026
866a01f
add back triton fusk_rope_kv_cache
k50112113 Jan 6, 2026
e18974e
consider both AR_RMS + Quant and AR + RMS_Quant condition via ATOM_EN…
k50112113 Jan 7, 2026
53f00e3
Merge branch 'shaoclee/ds_fp4_gemm' into farlukas/dsfp8-fusedrmsnorm
farlukas Jan 7, 2026
678ca28
Implemented fp8 fused reduce rms quant
farlukas Jan 7, 2026
3bf537c
change boundary
k50112113 Jan 7, 2026
b26f81f
Removed unreachable branch
farlukas Jan 7, 2026
b46db44
Added transpose_scale to fused reduce rms quant
farlukas Jan 8, 2026
d9fd150
fix
k50112113 Jan 8, 2026
bbd4198
clean
k50112113 Jan 8, 2026
55fafbe
Merge pull request #116 from ROCm/farlukas/dsfp8-fusedrmsnorm
k50112113 Jan 8, 2026
e09758f
add a16w8 preshuffle gemm
k50112113 Jan 8, 2026
19e74a7
clean
k50112113 Jan 8, 2026
26994c2
change fp8 gemm boundary
k50112113 Jan 8, 2026
357f3ba
triton fp8 gemm rename
k50112113 Jan 9, 2026
b6e1b89
Merge remote-tracking branch 'origin/main' into shaoclee/ds_fp4_gemm
k50112113 Jan 9, 2026
5817f32
remove loader change
k50112113 Jan 9, 2026
788dc30
remove comments
k50112113 Jan 9, 2026
3446254
address comments
k50112113 Jan 12, 2026
0b7e98b
Merge branch 'main' into shaoclee/ds_fp4_gemm
k50112113 Jan 13, 2026
d9b116e
Merge remote-tracking branch 'origin/main' into shaoclee/ds_fp4_gemm
k50112113 Jan 13, 2026
57aecda
Merge branch 'main' into shaoclee/ds_fp4_gemm
k50112113 Jan 14, 2026
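The new GEMM paths are opt-in via the environment flags named in the commits above. As a minimal sketch, one might enable them before model load — the flag names are taken from the commit titles and the diff below, but treating "1" as true is an assumption about how atom's envs module parses them:

import os

# Assumption: atom's envs module treats "1" as true for these flags.
os.environ["ATOM_USE_TRITON_GEMM"] = "1"       # preshuffled Triton GEMM + split/cat paths
os.environ["ATOM_USE_TRITON_MXFP4_BMM"] = "1"  # Triton MXFP4 BMM path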
117 changes: 107 additions & 10 deletions atom/model_ops/attention_mla.py
@@ -4,6 +4,7 @@
import logging
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import partial as functools_partial

import torch
from aiter import (
@@ -22,12 +23,17 @@
    ForwardContext,
    get_forward_context,
)

from atom.model_ops.linear import use_triton_gemm
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)
from aiter import (
    QuantType,
    get_hip_quant,
)

# from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
from aiter import fused_qk_rope_concat_and_cache_mla
from aiter.dist.parallel_state import get_dp_group

Expand All @@ -37,14 +43,20 @@

logger = logging.getLogger("atom")

if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
    except ImportError as e:
        logger.warning(f"Triton fused GEMM split_cat not available: {e}")
        fused_gemm_afp4wfp4_preshuffle_split_cat = None
        fused_gemm_a8w8_blockscale_preshuffle_split_cat = None

def is_rocm_aiter_fp4bmm_enabled() -> bool:
    return envs.ATOM_USE_TRITON_MXFP4_BMM

if is_rocm_aiter_fp4bmm_enabled():
    from atom.model_ops.utils import quark_post_load_weights

    # from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant
    from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

@@ -234,14 +246,84 @@ def _forward_prefill_mha(
    ) -> torch.Tensor:
        assert attn_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed).view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        if k_rope.dim() == 2:
            k_rope = k_rope.unsqueeze(1)
        k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

        if use_triton_gemm():
            weight = self.kv_b_proj.weight
            weight_scale = self.kv_b_proj.weight_scale
            if fused_gemm_afp4wfp4_preshuffle_split_cat is not None and weight.dtype == dtypes.fp4x2:  # FP4 GEMM + split + cat
                m = kv_c_normed.shape[0]
                # from aiter.ops.triton.quant import dynamic_mxfp4_quant
                # input = kv_c_normed
                # input_2d = input.view(-1, input.shape[-1])
                output_dtype = kv_c_normed.dtype

                # q_input, x_scale = dynamic_mxfp4_quant(input_2d)
                quant_func = get_hip_quant(QuantType.per_1x32)
                q_input, x_scale = quant_func(
                    kv_c_normed,
                    quant_dtype=dtypes.fp4x2,
                    shuffle=(m >= 32),
                )

                if m >= 32:
                    x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // 32, -1)
                else:
                    x_scale = x_scale[:m, ...].view(torch.uint8)

                k, v = fused_gemm_afp4wfp4_preshuffle_split_cat(
                    q_input.view(torch.uint8),
                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                    k_rope.expand((-1, self.num_heads, -1)),
                    x_scale,
                    weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1),
                    self.qk_nope_head_dim,
                    self.v_head_dim,
                    output_dtype,
                )
            elif fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None and weight.dtype == dtypes.fp8:  # FP8 GEMM + split + cat
                weight_shuffled = weight.reshape(
                    weight.shape[0] // 16,
                    weight.shape[1] * 16,
                )

                output_dtype = kv_c_normed.dtype

                quant_func = functools_partial(
                    get_hip_quant(QuantType.per_1x128),
                    transpose_scale=True,
                )
                q_input, x_scale = quant_func(
                    kv_c_normed,
                    quant_dtype=dtypes.fp8,
                    scale=getattr(self.kv_b_proj, "input_scale", None),
                )

                k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat(
                    q_input,
                    weight_shuffled,
                    k_rope.expand((-1, self.num_heads, -1)),
                    x_scale,
                    weight_scale,
                    self.qk_nope_head_dim,
                    self.v_head_dim,
                    output_dtype,
                )
            else:
                kv_nope = self.kv_b_proj(kv_c_normed).view(
                    -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
                )
                k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

                k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)
        else:
            kv_nope = self.kv_b_proj(kv_c_normed).view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

        output = flash_attn_varlen_func(
            q=q,
@@ -485,6 +567,21 @@ def forward(
            is_neox=self.rotary_emb.is_neox_style,
            is_nope_first=True,
        )
        # from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
        # decode_q, _, _, _ = fused_qk_rope_cat_and_cache_mla(
        #     q_nope,
        #     q_rope,
        #     k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank),
        #     k_rope.view(-1, self.num_kv_heads, self.qk_rope_head_dim),
        #     kv_cache,
        #     attn_metadata.slot_mapping,
        #     positions,
        #     self.rotary_emb.cos_cache,
        #     self.rotary_emb.sin_cache,
        #     k_scale=self._k_scale,
        #     is_neox=self.rotary_emb.is_neox_style,
        #     q_out_dtype=kv_cache.dtype,
        # )

        if context.is_prefill:
            output = self._forward_prefill_mla(q_out, kv_cache, attn_metadata)
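For reference, a minimal unfused sketch of what the two fused `*_preshuffle_split_cat` kernels compute in `_forward_prefill_mha`: one GEMM against the `kv_b_proj` weight, a per-head split into the nope and value parts, and concatenation of the shared RoPE part onto `k`. The function name and the plain `@` GEMM are illustrative only; the real kernels additionally take quantized activations, preshuffled weights, and block scales:

import torch

def kv_b_proj_split_cat_reference(
    kv_c_normed: torch.Tensor,  # [num_tokens, kv_lora_rank]
    weight: torch.Tensor,       # [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
    k_rope: torch.Tensor,       # [num_tokens, 1, qk_rope_head_dim]
    num_heads: int,
    qk_nope_head_dim: int,
    v_head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Plain GEMM standing in for the quantized, preshuffled Triton kernel.
    kv_nope = (kv_c_normed @ weight.t()).view(
        -1, num_heads, qk_nope_head_dim + v_head_dim
    )
    # Split each head's projection into the non-RoPE key part and the value.
    k_nope, v = kv_nope.split([qk_nope_head_dim, v_head_dim], dim=-1)
    # Broadcast the shared RoPE key part across heads and append it to k_nope.
    k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)
    return k, v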