Merged
71 commits
20f3726
Use correct init method for fp8
ex-rzr Dec 12, 2025
aaa6777
Use smaller tolerance for LSE, check if before OUT
ex-rzr Dec 12, 2025
c7f724b
Create a base for MX FMHA: FP8 with 16x16x128 & 32x32x64 tiles withou…
ex-rzr Dec 12, 2025
5f1eb84
Extend enums, structs and args with mx-related values
ex-rzr Dec 15, 2025
026d1a2
Implement host side: generation of scales, validation with mx gemm
ex-rzr Dec 17, 2025
a0037ae
Pass dram windows for scales from kernel to pipeline
ex-rzr Dec 18, 2025
a58354c
Clone BlockGemmARegBSmemCRegV2 as BlockGemmMxARegBSmemCRegV1
ex-rzr Dec 19, 2025
cad0d3d
Support scales in WarpGemmAttributeMfma...
ex-rzr Dec 19, 2025
7a86e3d
Add scales for Q to the first GEMM
ex-rzr Dec 19, 2025
330b3c9
Add scales for K to the first GEMM
ex-rzr Dec 19, 2025
e26f744
Add scales for V to the second GEMM
ex-rzr Jan 6, 2026
ea3e97e
Shuffle K during loading from DRAM instead of LDS
ex-rzr Jan 8, 2026
fea9090
Implement simple calculation of scales for P
ex-rzr Jan 12, 2026
3a6d38a
Use cvt_scalef32_pk_fp8_f32 for P
ex-rzr Jan 12, 2026
c66c8f5
Use full fp8 range for P (448.0 instead of 1.0)
ex-rzr Jan 12, 2026
61db443
Fix cases when N0 != K1 (k1_loops > 1)
ex-rzr Jan 14, 2026
ab7f751
Support hdim=128 and 32x32x64 MFMA
ex-rzr Jan 14, 2026
c97cb10
Support bf8
ex-rzr Jan 15, 2026
dfc82f4
Fix get_y_sliced_thread_data (and hence get_slice_tile) for pk_fp4_t
ex-rzr Jan 15, 2026
a8ace9b
Support fp4
ex-rzr Jan 15, 2026
28ba17f
Fix bias initialization range with init_method=3
ex-rzr Jan 20, 2026
6c82bdc
Move K and K scales shuffling into BlockGemmMx
ex-rzr Jan 20, 2026
c0c3c05
Extract P mx casting into a separate function
ex-rzr Jan 21, 2026
d526919
Add fp4 traits, instances and tests
ex-rzr Jan 21, 2026
e9437d9
Fix errors after rebasing onto recent blockscale changes
ex-rzr Jan 22, 2026
4625670
Fix alignment of Q, K, V for fp4
ex-rzr Jan 22, 2026
0ac81de
Replace NaN e8m0_t with 0 for invalid (padded) scale values
ex-rzr Jan 23, 2026
1b78f66
Implement group mode
ex-rzr Jan 23, 2026
496c1a5
Use PackedSize for pointers modified with head/batch offsets
ex-rzr Jan 23, 2026
4aecdc4
Ensure that hdim_q and seqlen_k are even for fp4
ex-rzr Jan 23, 2026
c361f15
Fix compilation of pipelines and types without MX support
ex-rzr Jan 26, 2026
0571980
Ignore seqlen_kpads as it is not supported yet
ex-rzr Jan 27, 2026
96ca0e8
Fix V scale loading with windowed masks (like t:64,64)
ex-rzr Jan 27, 2026
2238668
Enable padded pipelines
ex-rzr Jan 27, 2026
957801a
Use PackedSize for fp4 alignment and LDS size calculations
ex-rzr Jan 27, 2026
12b8955
Add kScaleGranularity to WarpGemmAttributeMfmaImpl
ex-rzr Feb 3, 2026
e6710df
Update changelog
ex-rzr Feb 3, 2026
cb0d423
Fix MakePScaleRegTileDistribution
ex-rzr Feb 3, 2026
a04d839
Use faster tile size for hdim=256
ex-rzr Feb 3, 2026
e64cb22
Ensure that kv_eff_lens_per_batch values are even for fp4
ex-rzr Feb 3, 2026
a2fdd9c
Do not build mx tests on old archs without mx support
ex-rzr Feb 3, 2026
1fb0d96
Disable PaddingCases tests for mx types (like for fp8)
ex-rzr Feb 3, 2026
709862d
Refactor common constants in BlockGemmMxARegBSmemCRegV1
ex-rzr Feb 4, 2026
b3f3190
Move TargetCMPerLane out of block gemm
ex-rzr Feb 4, 2026
16b60ec
Fix reference_batched_mx.hpp file name
ex-rzr Feb 4, 2026
0e82f1a
Add new mx tests to REGRESSION_TESTS
ex-rzr Feb 6, 2026
b4d3985
Update help message for -qscale
ex-rzr Feb 6, 2026
3f88131
Extend static checks for Q and Q scale DRAM windows in K dim
ex-rzr Feb 9, 2026
c655e47
Improve comments with cast_tile_mx's implementation details
ex-rzr Feb 9, 2026
01cc47c
Format with clang-format-18
ex-rzr Feb 9, 2026
8c6095e
Fix compilation for c++17
ex-rzr Feb 9, 2026
3093308
Merge branch 'develop'
ex-rzr Feb 11, 2026
123422d
Update after merging develop
ex-rzr Feb 11, 2026
654a5ef
Build FMHA tests per type based on available instances
ex-rzr Feb 11, 2026
147c500
Merge branch 'develop'
ex-rzr Feb 12, 2026
5698038
Make cast_tile_mx more generic
ex-rzr Feb 12, 2026
f175504
Replace exp2 and log2 in cast_tile_mx with faster arithmetic
ex-rzr Feb 12, 2026
821d892
Merge branch 'develop'
ex-rzr Feb 19, 2026
f560eae
Merge branch 'develop'
ex-rzr Feb 20, 2026
e7cc973
Replace muls of P and O acc with adjusting P scales
ex-rzr Feb 20, 2026
d796f65
Fix rounding seqlen to even for mxfp4 when seqlen_k = -1
ex-rzr Feb 20, 2026
4a6456c
Merge branch 'develop'
ex-rzr Feb 21, 2026
06b1368
Merge branch 'develop'
ex-rzr Feb 23, 2026
0054db8
Use more precise algorithm of MX scale calculation
ex-rzr Feb 23, 2026
ebbadc4
Merge branch 'develop'
ex-rzr Feb 24, 2026
63aebd6
Merge branch 'develop'
ex-rzr Feb 27, 2026
2a378cd
Fix ambiguity in oob loading with customized_value for e8m0_t
ex-rzr Feb 27, 2026
591df34
Merge branch 'develop'
ex-rzr Mar 5, 2026
6be97d6
Merge branch 'develop'
ex-rzr Mar 10, 2026
911dd5d
Remove duplicated code
ex-rzr Mar 10, 2026
477ebee
Merge branch 'develop'
ex-rzr Mar 11, 2026
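Several of the commits above deal with computing per-block MX scales for P: a simple scale calculation, use of the full fp8 range (448.0 instead of 1.0), replacing exp2/log2 with faster arithmetic, and a more precise scale algorithm. The sketch below illustrates the general idea of picking a shared power-of-two e8m0 scale per 32-element block. It is a hypothetical reference, not the kernel's `cast_tile_mx` implementation, and the ceil-based clamping strategy is an assumption.

```python
import math

FP8_E4M3_MAX = 448.0  # full fp8 e4m3 range, as in the "448.0 instead of 1.0" commit
E8M0_BIAS = 127       # e8m0 is exponent-only: byte b encodes 2**(b - 127)

def mx_quantize_block(block):
    """Pick a shared power-of-two scale for one block of values.

    Returns (scale_byte, scaled_values); the scaled values would then be
    cast to fp8. Hypothetical reference, not the kernel code.
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return E8M0_BIAS, [0.0] * len(block)  # scale = 2**0 = 1
    # smallest power-of-two scale that maps amax into the fp8 range
    exp = math.ceil(math.log2(amax / FP8_E4M3_MAX))
    exp = max(-E8M0_BIAS, min(E8M0_BIAS, exp))  # representable e8m0 exponents
    scale = 2.0 ** exp
    return exp + E8M0_BIAS, [v / scale for v in block]
```

A faster variant would derive `exp` from the float's raw exponent bits rather than calling `log2`, which appears to be the point of the "faster arithmetic" commit.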
1 change: 1 addition & 0 deletions projects/composablekernel/CHANGELOG.md
@@ -21,6 +21,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
* Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).

### Changed

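For reference, the e8m0 scale type used by the new MX types is an 8-bit, exponent-only encoding; one commit in this PR replaces NaN e8m0 values with 0 for padded scales. A minimal decode sketch, assuming the standard OCP microscaling encoding with bias 127 and 0xFF reserved for NaN:

```python
import math

E8M0_NAN = 0xFF  # the only non-finite e8m0 encoding

def e8m0_decode(byte):
    """Decode an e8m0 scale byte to a float: 2**(byte - 127); 0xFF -> NaN."""
    if byte == E8M0_NAN:
        return float("nan")
    return 2.0 ** (byte - 127)
```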
@@ -9,6 +9,8 @@
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32",
"mxfp8": "FmhaFwdMxFp8",
"mxfp4": "FmhaFwdMxFp4",
}

BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
@@ -79,13 +81,15 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}

QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"mx": "quant_scale_enum::mx",
}

BIAS_MAP = {
@@ -38,6 +38,8 @@
"fp8bf16": 8,
"fp8fp32": 8,
"bf8": 8,
"mxfp8": 8,
"mxfp4": 4,
}

K0_MAX_SUBMAX_MAP = {
@@ -836,7 +838,8 @@ def get_rules(cls) -> List[CompatibilityRule]:
def check_hdim_tile(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
if problem_ctx.dtype != "fp32":
# FIX: too confusing that it has to know about mx types
if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"):
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
(
@@ -966,8 +969,6 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
return {
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")

# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@@ -1035,9 +1036,6 @@ def get_pipelines(
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass
return pipelines


@@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950(
):
arch = ArchTrait("gfx950")

_DT_MXFP8 = ("mxfp8",)
_DT_MXFP4 = ("mxfp4",)

@classmethod
def supported_dtypes(cls) -> Tuple[str]:
return (
KernelComponentFactoryGfx9.supported_dtypes()
+ cls._DT_MXFP8
+ cls._DT_MXFP4
)

@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
@@ -1054,6 +1063,18 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
elif dtype in cls._DT_MXFP8:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
elif dtype in cls._DT_MXFP4:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
return result

@classmethod
@@ -1091,6 +1112,19 @@ def get_pipelines(
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip

elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
# no need dropout kernels
lse = "t"
dropout = "f"
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["mx"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
return pipelines


@@ -48,8 +48,12 @@ auto create_args(int argc, char* argv[])
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("qscale",
"n",
"n or 0, no scale\n"
"pt or 1, per-tensor scale\n")
"quant scale:\n"
" n or 0, no scale\n"
" pt or 1, per-tensor scale\n"
" bs or 2, block scale\n"
" kvbs or 3, Q per-tensor, K/V per-page block scale\n"
" mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("iperm",
"1",
@@ -61,7 +65,7 @@
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -231,6 +235,10 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
@@ -239,6 +247,14 @@
{
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp8")
{
return run<FmhaFwdMxFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp4")
{
return run<FmhaFwdMxFp4>(arg_parser) == fwd_result::success ? 0 : -2;
}
std::cerr << "Unsupported precision: " << data_type << std::endl;
return -1;
}
67 changes: 67 additions & 0 deletions projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -50,6 +50,14 @@ struct FmhaFwdFp8Fp32
{
};

struct FmhaFwdMxFp8
{
};

struct FmhaFwdMxFp4
{
};

template <typename DataType>
struct FmhaFwdTypeConfig;

@@ -165,6 +173,54 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
using ODataType = float;
};

template <>
struct FmhaFwdTypeConfig<FmhaFwdMxFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;

using QScaleDataType = ck_tile::e8m0_t;
using KScaleDataType = ck_tile::e8m0_t;
using VScaleDataType = ck_tile::e8m0_t;
using PScaleDataType = ck_tile::e8m0_t;

static constexpr ck_tile::index_t kQKScaleGranularity = 32;
static constexpr ck_tile::index_t kVScaleGranularity = 32;
};

template <>
struct FmhaFwdTypeConfig<FmhaFwdMxFp4>
{
using QDataType = ck_tile::pk_fp4_t;
using KDataType = ck_tile::pk_fp4_t;
using VDataType = ck_tile::pk_fp4_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::pk_fp4_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;

using QScaleDataType = ck_tile::e8m0_t;
using KScaleDataType = ck_tile::e8m0_t;
using VScaleDataType = ck_tile::e8m0_t;
using PScaleDataType = ck_tile::e8m0_t;

static constexpr ck_tile::index_t kQKScaleGranularity = 32;
static constexpr ck_tile::index_t kVScaleGranularity = 32;
};

struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
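The scale types and the 32-element `kQKScaleGranularity`/`kVScaleGranularity` constants above suggest a host-side reference along the lines of the "validation with mx gemm" commit. The sketch below is hypothetical: `mx_matmul_ref`, its argument shapes, and the per-row scale layout are assumptions, not the layouts the kernel actually uses.

```python
def mx_matmul_ref(a, a_scales, b, b_scales, granularity=32):
    """C = A @ B where every `granularity`-wide chunk of the K dimension
    carries a raw e8m0 scale byte (decoded as 2**(byte - 127)).

    a: M x K rows, b: K x N rows,
    a_scales: M x ceil(K/granularity), b_scales: N x ceil(K/granularity).
    """
    m, k_dim, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for k in range(k_dim):
                g = k // granularity
                sa = 2.0 ** (a_scales[i][g] - 127)  # de-scale the A element
                sb = 2.0 ** (b_scales[j][g] - 127)  # de-scale the B element
                acc += (a[i][k] * sa) * (b[k][j] * sb)
            c[i][j] = acc
    return c
```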
@@ -232,6 +288,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* seqstart_v_scale_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
@@ -252,6 +309,9 @@
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t stride_q_descale;
ck_tile::index_t stride_k_descale;
ck_tile::index_t stride_v_descale;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
@@ -635,6 +695,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqlen_k_ptr,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.seqstart_v_scale_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -647,6 +708,9 @@
args.stride_bias,
args.stride_randval,
args.stride_o,
args.stride_q_descale,
args.stride_k_descale,
args.stride_v_descale,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
@@ -697,6 +761,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_o,
args.stride_q_descale,
args.stride_k_descale,
args.stride_v_descale,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,