Skip to content

Commit f1aa27e

Browse files
committed
Update sparse MLA examples to support SKV adjustment and correctness checks
- Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests. - Added check_correctness parameter to test functions for validation of outputs. - Updated test cases to reflect new SKV values and correctness checks.
1 parent abc2f8c commit f1aa27e

File tree

4 files changed

+40
-16
lines changed

4 files changed

+40
-16
lines changed

examples/deepseek_v32/sparse_mla_bwd.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,14 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c
333333

334334
def test_sparse_mla_bwd(B=1,
335335
S=4096,
336-
SKV=32768,
336+
SKV=8192,
337337
H=64,
338338
HKV=1,
339339
DQKV=576,
340340
DV=512,
341341
topk=2048,
342-
dtype=torch.bfloat16):
342+
dtype=torch.bfloat16,
343+
check_correctness=True):
343344
# Prepare data
344345
q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
345346
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
@@ -359,7 +360,7 @@ def test_sparse_mla_bwd(B=1,
359360
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
360361
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None)
361362

362-
if SKV <= 4096:
363+
if check_correctness:
363364
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
364365
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
365366
print("assert_tensors_similar passed")
@@ -385,4 +386,13 @@ def fn():
385386

386387
if __name__ == "__main__":
387388
test_sparse_mla_bwd(
388-
B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
389+
B=1,
390+
S=4096,
391+
SKV=8192,
392+
H=64,
393+
HKV=1,
394+
DQKV=576,
395+
DV=512,
396+
topk=2048,
397+
dtype=torch.bfloat16,
398+
check_correctness=True)

examples/deepseek_v32/sparse_mla_fwd.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
234234

235235
def test_sparse_mla_fwd(B=1,
236236
S=4096,
237-
SKV=4096,
237+
SKV=8192,
238238
H=128,
239239
HKV=1,
240240
DQK=576,
241241
DV=512,
242242
topk=2048,
243-
dtype=torch.bfloat16):
243+
dtype=torch.bfloat16,
244+
check_correctness=True):
244245
torch.random.manual_seed(0)
245246
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
246247
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -254,7 +255,7 @@ def test_sparse_mla_fwd(B=1,
254255

255256
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
256257

257-
if SKV <= 4096:
258+
if check_correctness:
258259
# otherwise may cause out of memory
259260
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
260261
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
@@ -277,4 +278,13 @@ def fn():
277278

278279
if __name__ == "__main__":
279280
test_sparse_mla_fwd(
280-
B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
281+
B=1,
282+
S=4096,
283+
SKV=4096,
284+
H=128,
285+
HKV=1,
286+
DQK=576,
287+
DV=512,
288+
topk=2048,
289+
dtype=torch.bfloat16,
290+
check_correctness=True)

examples/deepseek_v32/sparse_mla_fwd_pipelined.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,15 @@ def ref_sparse_mla_fwd_interface(q,
399399

400400
def test_sparse_mla_fwd_pipelined(B=1,
401401
S=4096,
402-
SKV=4096,
402+
SKV=8192,
403403
H=128,
404404
HKV=1,
405405
DQK=576,
406406
DV=512,
407407
topk=2048,
408408
dtype=torch.bfloat16,
409-
q_start_s_index=1024):
409+
q_start_s_index=1024,
410+
check_correctness=True):
410411
KV_stride = 1
411412

412413
torch.random.manual_seed(0)
@@ -456,8 +457,8 @@ def fn():
456457
parser.add_argument("--test_correctness", action="store_true")
457458
args = parser.parse_args()
458459
if args.test_correctness:
459-
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
460+
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
460461
else:
461462
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
462-
test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
463-
test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
463+
test_sparse_mla_fwd_pipelined(
464+
B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,23 @@ def test_example_fp8_lighting_indexer():
2020
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
2121
def test_example_sparse_mla_fwd():
2222
# small shapes for testing
23-
test_sparse_mla_fwd(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
23+
test_sparse_mla_fwd(
24+
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
2425

2526

2627
@tilelang.testing.requires_cuda
2728
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
2829
def test_example_sparse_mla_fwd_pipelined():
2930
# small shapes for testing
30-
test_sparse_mla_fwd_pipelined(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
31+
test_sparse_mla_fwd_pipelined(
32+
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
3133

3234

3335
@tilelang.testing.requires_cuda
3436
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
3537
def test_example_sparse_mla_bwd():
36-
test_sparse_mla_bwd()
38+
test_sparse_mla_bwd(
39+
S=1024, SKV=2048, H=128, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
3740

3841

3942
if __name__ == "__main__":

0 commit comments

Comments
 (0)