Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def end_boundary(n: int):
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):

Expand Down Expand Up @@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths

# odd chunk_size
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
(21, 15)]), # irregular sizes

# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences

# we also need to test some large seqlen
# to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]),
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
Expand All @@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,

seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be investigated
if chunk_size not in {8, 256}:
atol, rtol = 5e-1, 5e-1
# This test can have larger error for longer sequences
if seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3

Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr +
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
other=0.0).to(tl.float32)

if HAS_SEQ_IDX:
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/mamba/ops/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')


def is_int_pow_2(n):
    """Return True iff *n* is an int and a positive power of two.

    Uses the bit trick that a power of two has exactly one set bit,
    so clearing the lowest set bit (``n & (n - 1)``) yields zero.
    """
    if not isinstance(n, int):
        return False
    if n <= 0:
        return False
    return n & (n - 1) == 0


def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
Expand All @@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
Expand Down