@@ -71,8 +71,7 @@ def _compile_and_check(
7171
7272 print (kernel .get_kernel_source ())
7373
74- profiler = kernel .get_profiler (
75- tensor_supply_type = tilelang .TensorSupplyType .Normal )
74+ profiler = kernel .get_profiler (tensor_supply_type = tilelang .TensorSupplyType .Normal )
7675
7776 def ref_program (A , B ):
7877 import torch
@@ -82,8 +81,8 @@ def ref_program(A, B):
8281 if trans_B :
8382 B = B .T
8483 if in_dtype == "float32" :
85- A = (( A .view (torch .int32 ) - 0x1000 ) ).view (torch .float32 )
86- B = (( B .view (torch .int32 ) - 0x1000 ) ).view (torch .float32 )
84+ A = (A .view (torch .int32 ) - 0x1000 ).view (torch .float32 )
85+ B = (B .view (torch .int32 ) - 0x1000 ).view (torch .float32 )
8786 C = torch .matmul (A .to (torch .float ), B .to (torch .float ))
8887 C = C .to (torch .__getattribute__ (out_dtype ))
8988 return C
@@ -383,51 +382,42 @@ def run_gemm_rr(
383382
384383 _compile_and_check (program , trans_A , trans_B , in_dtype , out_dtype )
385384
385+
386386M_VALUES = [64 , 128 , 256 ]
387387N_VALUES = [16 , 32 , 64 , 128 ]
388388K_VALUES = [16 , 32 , 64 , 128 ]
389389K_VALUES_8Bit = [32 , 64 , 128 ]
390- FALSE_TRUE_CASES = (
391- [
392- pytest .param (
393- k ,
394- "float16" ,
395- "float16" ,
396- "float16" ,
397- id = f"K{ k } -float16-float16-float16" ,
398- )
399- for k in K_VALUES
400- ]
401- + [
402- pytest .param (
403- k ,
404- "int8" ,
405- "int32" ,
406- "int32" ,
407- id = "K32-int8-int32-int32" ,
408- ) for k in K_VALUES_8Bit
409- ]
410- + [
411- pytest .param (
412- k ,
413- "float8_e5m2" ,
414- "float32" ,
415- "float32" ,
416- id = "K32-float8_e5m2-float32-float32" ,
417- )
418- for k in K_VALUES_8Bit
419- ]
420- +
421- [pytest .param (
422- k ,
423- "float8_e4m3" ,
424- "float32" ,
425- "float32" ,
426- id = "K32-float8_e4m3-float32-float32" ,
427- )
428- for k in K_VALUES_8Bit
429- ]
430- )
390+ FALSE_TRUE_CASES = ([
391+ pytest .param (
392+ k ,
393+ "float16" ,
394+ "float16" ,
395+ "float16" ,
396+ id = f"K{ k } -float16-float16-float16" ,
397+ ) for k in K_VALUES
398+ ] + [pytest .param (
399+ k ,
400+ "int8" ,
401+ "int32" ,
402+ "int32" ,
403+ id = "K32-int8-int32-int32" ,
404+ ) for k in K_VALUES_8Bit ] + [
405+ pytest .param (
406+ k ,
407+ "float8_e5m2" ,
408+ "float32" ,
409+ "float32" ,
410+ id = "K32-float8_e5m2-float32-float32" ,
411+ ) for k in K_VALUES_8Bit
412+ ] + [
413+ pytest .param (
414+ k ,
415+ "float8_e4m3" ,
416+ "float32" ,
417+ "float32" ,
418+ id = "K32-float8_e4m3-float32-float32" ,
419+ ) for k in K_VALUES_8Bit
420+ ])
431421
432422
433423def _ensure_torch_dtypes (* dtype_names ):
@@ -485,6 +475,7 @@ def run_gemm_rr_true_false(m, n, k):
485475def run_gemm_rr_true_true (m , n , k ):
486476 run_gemm_rr (m , n , k * 3 , True , True , "float16" , "float16" , "float16" , m , n , k , 2 , 128 )
487477
478+
488479TRANS_CASES = [
489480 pytest .param (False , False , id = "nn" ),
490481 pytest .param (False , True , id = "nt" ),
@@ -699,7 +690,7 @@ def test_gemm_rr_true_true(m, n, k):
699690 # print(f"======================= Test {m} {n} {k} False True =============================")
700691 # run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
701692 # print(f"Test {m} {n} {k} Pass")
702-
693+
703694 # # Test Pass
704695 # for m in [64, 128, 256]:
705696 # for n in [16, 32, 64, 128]:
@@ -717,7 +708,6 @@ def test_gemm_rr_true_true(m, n, k):
717708 # print(f"Test {m}, {n} {k} Pass")
718709 # print(f"Test {n} Pass")
719710
720-
721711 # # Test Pass
722712 # for m in [64, 128, 256]:
723713 # for n in [16, 32, 64, 128]:
@@ -727,7 +717,6 @@ def test_gemm_rr_true_true(m, n, k):
727717 # print(f"Test {m}, {n} {k} Pass")
728718 # print(f"Test {n} Pass")
729719
730-
731720 # Test Pass
732721 # for m in [64, 128, 256]:
733722 # for n in [16, 32, 64, 128]:
0 commit comments