Skip to content

Commit e14fe32

Browse files
committed
format
1 parent a5070fc commit e14fe32

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

examples/flash_attention/test_example_flash_attention.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,35 @@ def test_example_gqa_bwd_wgmma_pipelined():
3434
@tilelang.testing.requires_cuda
3535
def test_example_mha_bwd():
3636
example_mha_bwd.main(
37-
BATCH = 1,
38-
H = 16,
39-
N_CTX = 512,
40-
D_HEAD = 64,
41-
causal = False,)
37+
BATCH=1,
38+
H=16,
39+
N_CTX=512,
40+
D_HEAD=64,
41+
causal=False,
42+
)
4243

4344

4445
@tilelang.testing.requires_cuda
4546
def test_example_mha_bwd_bhsd():
4647
example_mha_bwd_bhsd.main(
47-
BATCH = 1,
48-
H = 16,
49-
N_CTX = 512,
50-
D_HEAD = 64,
51-
causal = False,)
48+
BATCH=1,
49+
H=16,
50+
N_CTX=512,
51+
D_HEAD=64,
52+
causal=False,
53+
)
5254

5355

5456
@tilelang.testing.requires_cuda
5557
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
5658
def test_example_mha_bwd_wgmma_pipelined():
5759
example_mha_bwd_wgmma_pipelined.main(
58-
BATCH = 1,
59-
H = 16,
60-
N_CTX = 512,
61-
D_HEAD = 64,
62-
causal = False,)
60+
BATCH=1,
61+
H=16,
62+
N_CTX=512,
63+
D_HEAD=64,
64+
causal=False,
65+
)
6366

6467

6568
@tilelang.testing.requires_cuda
@@ -99,7 +102,7 @@ def test_example_mha_fwd_bshd():
99102

100103
@tilelang.testing.requires_cuda
101104
def test_example_mha_fwd_varlen():
102-
example_mha_fwd_varlen.main(batch = 4, heads = 16, seq_len = 512, dim = 64)
105+
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
103106

104107

105108
if __name__ == "__main__":

examples/flash_decoding/example_mha_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def flash_split_ref(Q, K, V, causal):
302302
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
303303

304304

305-
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
305+
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
306306
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
307307
total_flops = 2 * flops_per_matmul
308308
if causal:

examples/flash_decoding/test_example_flash_decoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():
1212

1313

1414
def test_example_example_mha_inference():
15-
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
15+
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
1616

1717

1818
if __name__ == "__main__":

0 commit comments

Comments
 (0)