Skip to content

Commit 4b1c1e1

Browse files
committed
fixes, tests, and cleanup
Signed-off-by: Rishi Astra <[email protected]>
1 parent 74116dc commit 4b1c1e1

File tree

2 files changed

+34
-27
lines changed

2 files changed

+34
-27
lines changed

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ def end_boundary(n: int):
192192
@pytest.mark.parametrize("n_heads", [4, 16, 32])
193193
@pytest.mark.parametrize("d_head", [5, 8, 32, 128])
194194
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
195-
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
195+
@pytest.mark.parametrize("mamba2_fast_kernel", [False, True])
196+
def test_mamba_chunk_scan_single_example(
197+
d_head, n_heads, seq_len_chunk_size, itype, mamba2_fast_kernel
198+
):
196199
# this tests the kernels on a single example (bs=1)
197200

198201
# TODO: the bfloat16 case requires higher thresholds. To be investigated
@@ -239,6 +242,7 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
239242
seq_idx=seq_idx_chunks,
240243
out=Y,
241244
D=None,
245+
use_fused_kernel=mamba2_fast_kernel,
242246
)
243247

244248
# just test the last in sequence
@@ -257,6 +261,7 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
257261
@pytest.mark.parametrize("itype", [torch.float32])
258262
@pytest.mark.parametrize("n_heads", [4, 8])
259263
@pytest.mark.parametrize("d_head", [5, 16, 32])
264+
@pytest.mark.parametrize("mamba2_fast_kernel", [False, True])
260265
@pytest.mark.parametrize(
261266
"seq_len_chunk_size_cases",
262267
[
@@ -283,7 +288,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
283288
(768, 128, 2, [(138, 225), (138, 225)]),
284289
],
285290
)
286-
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
291+
def test_mamba_chunk_scan_cont_batch(
292+
d_head, n_heads, seq_len_chunk_size_cases, itype, mamba2_fast_kernel
293+
):
287294
# this test with multiple examples in a continuous batch
288295
# (i.e. chunked prefill)
289296

@@ -329,6 +336,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
329336
out=Y,
330337
D=None,
331338
initial_states=states,
339+
use_fused_kernel=mamba2_fast_kernel,
332340
)
333341

334342
# just test the last in sequence
@@ -347,11 +355,14 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
347355

348356

349357
@pytest.mark.parametrize("chunk_size", [8, 256])
358+
@pytest.mark.parametrize("mamba2_fast_kernel", [False, True])
350359
@pytest.mark.parametrize(
351360
"seqlens",
352361
[(16, 20), (270, 88, 212, 203)],
353362
)
354-
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
363+
def test_mamba_chunk_scan_cont_batch_prefill_chunking(
364+
chunk_size, seqlens, mamba2_fast_kernel
365+
):
355366
# This test verifies the correctness of the chunked prefill implementation
356367
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
357368
# dimension) of chunked results with the full sequence result.
@@ -414,6 +425,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
414425
out=Y_ref,
415426
D=None,
416427
initial_states=None,
428+
use_fused_kernel=mamba2_fast_kernel,
417429
)
418430

419431
## chunked seqlen computation

vllm/model_executor/layers/mamba/ops/ssd_fused5.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def _fused5_ssd_kernel(
398398

399399
# advance ptrs up front to simplify and slightly reduce register pressure
400400
# does actually provide a small benefit vs the original separate ptrs per step
401-
states_G_ptr += pid_h * stride_states_G_head + pid_c * stride_states_G_chunk
401+
states_G_ptr += pid_h * stride_states_G_head + (pid_c - 1) * stride_states_G_chunk
402402
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
403403
b_ptr += (
404404
chunk_seqlen_start * stride_b_seqlen
@@ -531,11 +531,6 @@ def _fused5_ssd_kernel(
531531
)
532532
seq_idx_new = tl.load(seq_idx_ptr + pid_c * stride_seq_idx_chunk)
533533

534-
if not HAS_INITSTATES: # don't let previous chunk affect if new sequence
535-
scale = tl.where(
536-
seq_idx_new == seq_idx_prev, scale, 0.0
537-
) # TODO: can avoid load instead
538-
539534
# sync
540535
# the atomic represents which pid_c is ready
541536
# therefore, wait for it to reach our pid_c
@@ -555,10 +550,15 @@ def _fused5_ssd_kernel(
555550
+ offs_ds[None, :] * stride_initial_states_dstate
556551
)
557552

558-
if (not NEED_MASK_HD) and (not NEED_MASK_1_DS):
559-
states_prev = tl.load(states_prev_ptrs)
553+
if seq_idx_new != seq_idx_prev and not HAS_INITSTATES:
554+
states_prev = tl.zeros(
555+
states.shape, dtype=states_prev_ptrs.dtype.element_ty
556+
)
560557
else:
561-
states_prev = tl.load(states_prev_ptrs, mask=main_mask, other=0.0)
558+
if (not NEED_MASK_HD) and (not NEED_MASK_1_DS):
559+
states_prev = tl.load(states_prev_ptrs)
560+
else:
561+
states_prev = tl.load(states_prev_ptrs, mask=main_mask, other=0.0)
562562

563563
states_mod = (
564564
scale * states_prev + states
@@ -589,7 +589,9 @@ def _fused5_ssd_kernel(
589589
prev_state_stride_hdim = stride_states_G_hdim
590590
prev_state_stride_dstate = stride_states_G_dstate
591591

592-
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=0)
592+
seq_idx_prev = tl.load(
593+
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
594+
)
593595
seq_idx_m = tl.load(seq_idx_ptr) # current seq idx
594596
if HAS_INITSTATES: # if new sequence, switch to initial states
595597
if seq_idx_prev != seq_idx_m:
@@ -639,7 +641,8 @@ def _fused5_ssd_kernel(
639641
+ offs_k_dstate[:, None] * prev_state_stride_dstate
640642
)
641643

642-
if seq_idx_prev == seq_idx_m: # if new sequence, add previous chunk affect
644+
# add previous chunk effect if needed
645+
if seq_idx_prev == seq_idx_m or HAS_INITSTATES:
643646
scale_m = tl.exp(dA_cs_m)
644647
if BLOCK_SIZE_DSTATE <= CS_WHOLEBLOCK_DS:
645648
if (not NEED_MASK_HD) and (not NEED_MASK_CS_DS):
@@ -855,8 +858,8 @@ def _fused5_ssd(
855858
:param dt_limit: clamp for dt
856859
"""
857860
# precision settings
858-
cb_store_fp32 = False
859-
cb_scale_fp32 = False
861+
cb_store_fp32 = True
862+
cb_scale_fp32 = True
860863
cs_acc_fp32 = False
861864
cb_comp_fp32 = True
862865

@@ -894,9 +897,8 @@ def _fused5_ssd(
894897
assert dA_chunk_cumsum.shape == (nheads, nchunks)
895898
assert seq_idx is not None and seq_idx.shape == (nchunks,)
896899

897-
# +1 for the final states
898900
states_G = torch.empty(
899-
(nchunks + 1, nheads, hdim, dstate), device=x.device, dtype=state_dtype
901+
(nchunks, nheads, hdim, dstate), device=x.device, dtype=state_dtype
900902
)
901903
# setup from chunk scan
902904
assert C.shape == (seqlen, ngroups, dstate)
@@ -910,14 +912,7 @@ def _fused5_ssd(
910912
if initial_states is not None:
911913
num_varlen_seqs = initial_states.shape[0]
912914
assert initial_states.shape == (num_varlen_seqs, nheads, hdim, dstate)
913-
# with initial states, we need to take care of how
914-
# seq_idx crosses the boundaries
915-
916-
# TODO: try copying all init states here if it's cheaper
917-
states_G[0, :, :, :] = initial_states[0, :, :, :] # initialize to zero
918-
919-
else:
920-
states_G[0, :, :, :] = 0 # initialize to zero
915+
assert initial_states.dtype == states_G.dtype
921916

922917
initial_states_strides = (
923918
(
@@ -1051,4 +1046,4 @@ def _fused5_ssd(
10511046
CB_COMP_FP32=cb_comp_fp32,
10521047
)
10531048

1054-
return out_x, states_G[1:], dA_cumsum, dt_out
1049+
return out_x, states_G, dA_cumsum, dt_out

0 commit comments

Comments (0)