Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ def name(self) -> str:
# 3: kernelInstanceGEMM1( 256, 256, 128, 128, 2, 2, 3,),
}

# bns gemm1 out:bf16/fp16 A:mxfp4 B:mxfp4
a4w4_bns_gemm1_kernels_list= {
0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 3,),
1: kernelInstanceGEMM1( 256, 64, 64, 128, 2, 2, 3,),
2: kernelInstanceGEMM1( 256, 128, 64, 128, 2, 2, 3,),
}

gemm1_kernels_dict = {
"a16w16_gfx950": a16w16_gemm1_kernels_list_gfx950,
"a16w16": a16w16_gemm1_kernels_list,
Expand All @@ -205,6 +212,7 @@ def name(self) -> str:
"a8w8blkscale": a8w8_gemm1_blockscale_kernels_list,
"a8w4": a8w4_gemm1_kernels_list,
"a4w4": a4w4_gemm1_kernels_list,
"a4w4_bns": a4w4_bns_gemm1_kernels_list,
}


Expand Down Expand Up @@ -284,6 +292,15 @@ def name(self) -> str:
# 6: kernelInstanceGEMM2( 256, 128, 64, 128, 2, 2, 3,),
# 7: kernelInstanceGEMM2( 256, 256, 64, 128, 2, 2, 3,),
}
# gemm2 out:bf16/fp16 A:fp8 B:in4
a4w4_bns_gemm2_kernels_list= {
0: kernelInstanceGEMM2( 64, 32, 32, 128, 1, 1, 1,),
1: kernelInstanceGEMM2( 64, 64, 64, 128, 1, 1, 1,),
2: kernelInstanceGEMM2( 64, 128, 128, 128, 1, 1, 1,),
4: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 3,),
5: kernelInstanceGEMM2( 256, 64, 64, 128, 2, 2, 3,),
6: kernelInstanceGEMM2( 256, 128, 64, 128, 2, 2, 3,),
}

# fmt: on
gemm2_kernels_dict = {
Expand All @@ -294,6 +311,7 @@ def name(self) -> str:
"a8w8blkscale": a8w8_gemm2_blockscale_kernels_list,
"a8w4": a8w4_gemm2_kernels_list,
"a4w4": a4w4_gemm2_kernels_list,
"a4w4_bns": a4w4_bns_gemm2_kernels_list,
}


Expand All @@ -302,6 +320,7 @@ def name(self) -> str:
bit4_list = ["I4", "i4", "FP4X2", "fp4x2"]
QuantType_list = [3, 4]

bns_or_preslf = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please set the flag the parameter user passed by.


def get_gemm1_kernels_list(
Adtype: str,
Expand All @@ -312,6 +331,7 @@ def get_gemm1_kernels_list(
ActOP: str,
MulRoutedWeight: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()
if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
if arch == "gfx950":
Expand All @@ -337,7 +357,10 @@ def get_gemm1_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
tag = "a4w4"
if bns_or_preslf:
tag = "a4w4_bns"
else:
tag = "a4w4"
else:
raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
kernels_list = gemm1_kernels_dict[tag]
Expand Down Expand Up @@ -372,6 +395,7 @@ def get_gemm2_kernels_list(
QuantType: str,
MulRoutedWeight: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()

if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
Expand All @@ -398,7 +422,10 @@ def get_gemm2_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
tag = "a4w4"
if bns_or_preslf:
tag = "a4w4_bns"
else:
tag = "a4w4"
else:
raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
kernels_list = gemm2_kernels_dict[tag]
Expand Down
Loading
Loading