Skip to content

Commit f09ef4d

Browse files
committed
lint fix
1 parent 248a150 commit f09ef4d

File tree

2 files changed

+38
-48
lines changed

2 files changed

+38
-48
lines changed

maint/gemm_v2/correctness_evaluation.py

Lines changed: 37 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,7 @@ def _compile_and_check(
7171

7272
print(kernel.get_kernel_source())
7373

74-
profiler = kernel.get_profiler(
75-
tensor_supply_type=tilelang.TensorSupplyType.Normal)
74+
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
7675

7776
def ref_program(A, B):
7877
import torch
@@ -82,8 +81,8 @@ def ref_program(A, B):
8281
if trans_B:
8382
B = B.T
8483
if in_dtype == "float32":
85-
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
86-
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
84+
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
85+
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
8786
C = torch.matmul(A.to(torch.float), B.to(torch.float))
8887
C = C.to(torch.__getattribute__(out_dtype))
8988
return C
@@ -383,51 +382,42 @@ def run_gemm_rr(
383382

384383
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
385384

385+
386386
M_VALUES = [64, 128, 256]
387387
N_VALUES = [16, 32, 64, 128]
388388
K_VALUES = [16, 32, 64, 128]
389389
K_VALUES_8Bit = [32, 64, 128]
390-
FALSE_TRUE_CASES = (
391-
[
392-
pytest.param(
393-
k,
394-
"float16",
395-
"float16",
396-
"float16",
397-
id=f"K{k}-float16-float16-float16",
398-
)
399-
for k in K_VALUES
400-
]
401-
+ [
402-
pytest.param(
403-
k,
404-
"int8",
405-
"int32",
406-
"int32",
407-
id="K32-int8-int32-int32",
408-
) for k in K_VALUES_8Bit
409-
]
410-
+ [
411-
pytest.param(
412-
k,
413-
"float8_e5m2",
414-
"float32",
415-
"float32",
416-
id="K32-float8_e5m2-float32-float32",
417-
)
418-
for k in K_VALUES_8Bit
419-
]
420-
+
421-
[pytest.param(
422-
k,
423-
"float8_e4m3",
424-
"float32",
425-
"float32",
426-
id="K32-float8_e4m3-float32-float32",
427-
)
428-
for k in K_VALUES_8Bit
429-
]
430-
)
390+
FALSE_TRUE_CASES = ([
391+
pytest.param(
392+
k,
393+
"float16",
394+
"float16",
395+
"float16",
396+
id=f"K{k}-float16-float16-float16",
397+
) for k in K_VALUES
398+
] + [pytest.param(
399+
k,
400+
"int8",
401+
"int32",
402+
"int32",
403+
id="K32-int8-int32-int32",
404+
) for k in K_VALUES_8Bit] + [
405+
pytest.param(
406+
k,
407+
"float8_e5m2",
408+
"float32",
409+
"float32",
410+
id="K32-float8_e5m2-float32-float32",
411+
) for k in K_VALUES_8Bit
412+
] + [
413+
pytest.param(
414+
k,
415+
"float8_e4m3",
416+
"float32",
417+
"float32",
418+
id="K32-float8_e4m3-float32-float32",
419+
) for k in K_VALUES_8Bit
420+
])
431421

432422

433423
def _ensure_torch_dtypes(*dtype_names):
@@ -485,6 +475,7 @@ def run_gemm_rr_true_false(m, n, k):
485475
def run_gemm_rr_true_true(m, n, k):
486476
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
487477

478+
488479
TRANS_CASES = [
489480
pytest.param(False, False, id="nn"),
490481
pytest.param(False, True, id="nt"),
@@ -699,7 +690,7 @@ def test_gemm_rr_true_true(m, n, k):
699690
# print(f"======================= Test {m} {n} {k} False True =============================")
700691
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
701692
# print(f"Test {m} {n} {k} Pass")
702-
693+
703694
# # Test Pass
704695
# for m in [64, 128, 256]:
705696
# for n in [16, 32, 64, 128]:
@@ -717,7 +708,6 @@ def test_gemm_rr_true_true(m, n, k):
717708
# print(f"Test {m}, {n} {k} Pass")
718709
# print(f"Test {n} Pass")
719710

720-
721711
# # Test Pass
722712
# for m in [64, 128, 256]:
723713
# for n in [16, 32, 64, 128]:
@@ -727,7 +717,6 @@ def test_gemm_rr_true_true(m, n, k):
727717
# print(f"Test {m}, {n} {k} Pass")
728718
# print(f"Test {n} Pass")
729719

730-
731720
# Test Pass
732721
# for m in [64, 128, 256]:
733722
# for n in [16, 32, 64, 128]:

maint/gemm_v2/latency.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
use_v2 = args.use_v2
1010

11+
1112
# @tilelang.jit(target="cuda")
1213
# target currently can be "cuda" or "hip" or "cpu".
1314
# if not specified, it will be inferred from the input tensors during compile time

0 commit comments

Comments
 (0)