66import torch
77from dequantize_utils import torch_convert_bit_twiddling , assert_similar
88from tilelang .autotuner import set_autotune_inputs
9+ import argparse
910
1011
1112def get_configs ():
@@ -433,13 +434,18 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
433434 expert_ids = torch .tensor (expert_ids , dtype = torch .int32 , device = 'cuda' ) # (padding_M,)
434435 padding_M = sorted_token_ids .shape [0 ] # padding_M: token number after padding
435436
436- print (f'{ sorted_token_ids = } ' )
437- print (f'{ expert_ids = } ' )
438-
439437 return A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids , padding_M
440438
441439
442- def main (m = 256 , n = 256 , k = 256 , scale_size = 32 , fast_dequant = True , with_bias = False , topk = 4 , E = 32 ):
440+ def main (m = 256 ,
441+ n = 256 ,
442+ k = 256 ,
443+ scale_size = 32 ,
444+ topk = 4 ,
445+ E = 32 ,
446+ fast_dequant = True ,
447+ with_bias = False ,
448+ tune = False ):
443449 # Tunable parameters
444450 block_M , block_N , block_K = 128 , 256 , 128 # noqa: F841
445451 num_stages = 1 # noqa: F841
@@ -453,8 +459,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
453459 A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids , padding_M = get_data (
454460 m , n , k , qk , scale_size , topk , E , block_M )
455461
456- with set_autotune_inputs ([A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids ]):
457- # Autotune with inputs manually composed
462+ if tune :
463+ with set_autotune_inputs ([A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids ]):
464+ # Autotune with inputs manually composed
465+ kernel = matmul (
466+ m ,
467+ n ,
468+ k ,
469+ topk ,
470+ E ,
471+ padding_M ,
472+ "bfloat16" ,
473+ "bfloat16" ,
474+ "float32" ,
475+ num_bits = num_bits ,
476+ scale_size = scale_size ,
477+ fast_dequant = fast_dequant ,
478+ with_bias = with_bias ,
479+ )
480+ else :
458481 kernel = matmul (
459482 m ,
460483 n ,
@@ -469,8 +492,13 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
469492 scale_size = scale_size ,
470493 fast_dequant = fast_dequant ,
471494 with_bias = with_bias ,
495+ block_M = block_M ,
496+ block_N = block_N ,
497+ block_K = block_K ,
498+ num_stages = num_stages ,
499+ threads = threads ,
500+ split = split ,
472501 )
473- print (f'Best config: { kernel .config } ' )
474502
475503 output = kernel (
476504 A ,
@@ -504,8 +532,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
504532
505533
if __name__ == "__main__":
    # CLI entry point for the MoE dequantize-GEMM example.
    parser = argparse.ArgumentParser()
    # Default M/N/K come from gpt-oss-20b MoE's first gemm.
    parser.add_argument("--M", type=int, default=16384, help="M")
    parser.add_argument("--N", type=int, default=5760, help="N")
    parser.add_argument("--K", type=int, default=2944, help="K")
    parser.add_argument("--scale_size", type=int, default=32, help="scale size")
    # topk: experts activated for each token.
    parser.add_argument("--topk", type=int, default=4, help="topk")
    # E: total number of experts.
    parser.add_argument("--E", type=int, default=32, help="E")
    parser.add_argument("--tune", action="store_true", help="tune configs")
    args = parser.parse_args()

    main(
        args.M,
        args.N,
        args.K,
        args.scale_size,
        topk=args.topk,
        E=args.E,
        fast_dequant=True,
        with_bias=True,
        tune=args.tune)
0 commit comments