Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
290 changes: 206 additions & 84 deletions benchmarks/routines/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from triton_reference.selective_state_update import (
selective_state_update_triton as selective_state_update_triton_reference,
)
from triton_reference.selective_state_update_varlen import (
selective_state_update_varlen_triton,
)


# ==============================================================================
Expand Down Expand Up @@ -151,6 +154,21 @@ def parse_mamba_args(line, parser):
default=False,
help="Apply softplus to dt before use.",
)
parser.add_argument(
"--varlen",
action="store_true",
default=False,
help="Use varlen mode: 3D flattened inputs with cu_seqlens, 2D state indices, "
"and dst_state_batch_indices. Requires --cache_steps >= 1 (= max_seqlen).",
)
parser.add_argument(
"--algorithm",
type=str,
required=False,
default="auto",
choices=["auto", "simple", "vertical", "horizontal"],
help="FlashInfer kernel algorithm. Default: auto (picks best for GPU arch).",
)
parser.add_argument(
"--backends",
type=str,
Expand Down Expand Up @@ -196,6 +214,16 @@ def parse_mamba_args(line, parser):
f"Supported ratios: {supported_ratios}."
)

if args.varlen:
if args.cache_steps < 1:
raise ValueError(
"--varlen requires --cache_steps >= 1 (specifies max_seqlen)"
)
if args.cache_steps >= 1 and args.algorithm in ("vertical", "horizontal"):
raise ValueError(
f"MTP/varlen mode only supports 'auto' or 'simple' algorithm, got '{args.algorithm}'"
)

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
Expand Down Expand Up @@ -245,6 +273,8 @@ def testSelectiveStateUpdate(args):
cache_steps = args.cache_steps
has_z = args.has_z
dt_softplus = args.dt_softplus
is_varlen = args.varlen
algorithm = args.algorithm
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
res = []
Expand All @@ -259,49 +289,65 @@ def testSelectiveStateUpdate(args):
weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype)
## Done parsing input arguments

## Determine STP vs MTP mode
is_mtp = cache_steps >= 1
## Determine mode
is_mtp = cache_steps >= 1 and not is_varlen
T = cache_steps if is_mtp else None

## Prepare input tensors (mirrors tests/mamba/utils.py::create_test_inputs)
if is_varlen:
n_seqs = batch_size
max_seqlen = cache_steps
total_tokens = n_seqs * max_seqlen

Comment on lines +296 to +300
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Grow the cache before materializing varlen src/dst indices.

This path needs 2 * n_seqs * max_seqlen distinct slots, but the benchmark still sizes state_cache independently of max_seqlen. Once cache_steps gets large enough, the second perm[...] slice is too short and the reshape fails before timing starts.

πŸ› οΈ Proposed fix
-    ssm_state_cache_size = max(384, batch_size * 10)
+    ssm_state_cache_size = max(384, batch_size * 10)
+    if is_varlen:
+        ssm_state_cache_size = max(
+            ssm_state_cache_size, 2 * n_seqs * max_seqlen
+        )

Also applies to: 360-370

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/mamba.py` around lines 296 - 300, When is_varlen is true
the code needs 2 * n_seqs * max_seqlen cache slots but state_cache is still
sized independently; before materializing the varlen src/dst indices (the
perm[...] slices and subsequent reshape) ensure state_cache is grown/resized to
at least 2 * n_seqs * max_seqlen (use n_seqs = batch_size and max_seqlen =
cache_steps) so the second perm slice and reshape won't be too short; apply the
same change to the analogous block around the perm/reshape at lines 360-370 so
both varlen paths expand state_cache before slicing.

## Prepare input tensors
ssm_state_cache_size = max(384, batch_size * 10)

# State cache: (total_entries, nheads, dim, dstate) - contiguous
state_cache = torch.randn(
ssm_state_cache_size, nheads, dim, dstate, dtype=state_dtype, device=device
)

# Input x: (batch_size, [T,] nheads, dim)
if T is not None:
if is_varlen:
# Varlen: 3D flattened inputs (total_tokens, ...)
x = torch.randn(total_tokens, nheads, dim, dtype=input_dtype, device=device)
dt_base = torch.randn(total_tokens, nheads, dtype=weight_dtype, device=device)
dt = dt_base.as_strided((total_tokens, nheads, dim), (nheads, 1, 0))
B = torch.randn(total_tokens, ngroups, dstate, dtype=input_dtype, device=device)
C = torch.randn(total_tokens, ngroups, dstate, dtype=input_dtype, device=device)
z = None
if has_z:
z = torch.randn(total_tokens, nheads, dim, dtype=input_dtype, device=device)
elif T is not None:
# MTP: 4D inputs (batch_size, T, ...)
x = torch.randn(batch_size, T, nheads, dim, dtype=input_dtype, device=device)
else:
x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device)

# dt: broadcasting across dim (one value per head)
if T is not None:
dt_base = torch.randn(batch_size, T, nheads, dtype=weight_dtype, device=device)
dt = dt_base.as_strided(
(batch_size, T, nheads, dim), (T * nheads, nheads, 1, 0)
)
else:
dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device)
dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0))

# A: (nheads, dim, dstate) - negative values, broadcasting (one value per head)
A_base = -torch.rand(nheads, dtype=torch.float32, device=device) - 1.0
A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0))

# B, C: (batch_size, [T,] ngroups, dstate)
if T is not None:
B = torch.randn(
batch_size, T, ngroups, dstate, dtype=input_dtype, device=device
)
C = torch.randn(
batch_size, T, ngroups, dstate, dtype=input_dtype, device=device
)
z = None
if has_z:
z = torch.randn(
batch_size, T, nheads, dim, dtype=input_dtype, device=device
)
else:
# STP: 3D inputs (batch_size, ...)
x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device)
dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device)
dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0))
B = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device)
C = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device)
z = None
if has_z:
z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device)

# A: (nheads, dim, dstate) - negative values, broadcasting (one value per head)
A_base = -torch.rand(nheads, dtype=torch.float32, device=device) - 1.0
A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0))

# D: (nheads, dim) - broadcasting (one value per head)
D_base = torch.randn(nheads, dtype=weight_dtype, device=device)
Expand All @@ -311,23 +357,30 @@ def testSelectiveStateUpdate(args):
dt_bias_base = torch.rand(nheads, dtype=weight_dtype, device=device) - 4.0
dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0))

# Slot indices for state batching
slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int64, device=device)[
:batch_size
]

# Optional z tensor for gating
z = None
if has_z:
if T is not None:
z = torch.randn(
batch_size, T, nheads, dim, dtype=input_dtype, device=device
)
else:
z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device)
# Varlen-specific tensors
if is_varlen:
cu_seqlens = torch.arange(
0, total_tokens + 1, max_seqlen, device=device, dtype=torch.int32
)
perm = torch.randperm(ssm_state_cache_size, device=device).to(torch.int32)
src_indices = perm[: n_seqs * max_seqlen].reshape(n_seqs, max_seqlen)
dst_indices = perm[n_seqs * max_seqlen : 2 * n_seqs * max_seqlen].reshape(
n_seqs, max_seqlen
)
num_accepted = torch.ones(n_seqs, device=device, dtype=torch.int32)
slot_idx = None
else:
cu_seqlens = None
src_indices = None
dst_indices = None
num_accepted = None
slot_idx = torch.randperm(
ssm_state_cache_size, dtype=torch.int64, device=device
)[:batch_size]

if args.verbose >= 2:
print(f"[VVERBOSE] Mode: {'MTP' if is_mtp else 'STP'}")
mode_str = "varlen" if is_varlen else ("MTP" if is_mtp else "STP")
print(f"[VVERBOSE] Mode: {mode_str}")
print(f"[VVERBOSE] {state_cache.shape = }, {state_cache.dtype = }")
print(f"[VVERBOSE] {x.shape = }, {x.dtype = }")
print(f"[VVERBOSE] {dt.shape = }, {dt.dtype = }")
Expand All @@ -336,7 +389,11 @@ def testSelectiveStateUpdate(args):
print(f"[VVERBOSE] {C.shape = }, {C.dtype = }")
print(f"[VVERBOSE] {D.shape = }, {D.dtype = }")
print(f"[VVERBOSE] {dt_bias.shape = }, {dt_bias.dtype = }")
print(f"[VVERBOSE] {slot_idx.shape = }")
if slot_idx is not None:
print(f"[VVERBOSE] {slot_idx.shape = }")
if cu_seqlens is not None:
print(f"[VVERBOSE] {cu_seqlens.shape = }")
print(f"[VVERBOSE] {src_indices.shape = }, {dst_indices.shape = }")
print(f"[VVERBOSE] {has_z = }, {dt_softplus = }")
if z is not None:
print(f"[VVERBOSE] {z.shape = }, {z.dtype = }")
Expand All @@ -345,38 +402,79 @@ def testSelectiveStateUpdate(args):
triton_cache_steps = cache_steps if cache_steps > 0 else None

def run_backend(backend, state, x, dt, A, B, C, D):
if backend == "flashinfer":
return flashinfer.mamba.selective_state_update(
state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=cache_steps,
)
elif backend == "triton":
return selective_state_update_triton_reference(
state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=triton_cache_steps,
)
if is_varlen:
if backend == "flashinfer":
return flashinfer.mamba.selective_state_update(
state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=src_indices,
algorithm=algorithm,
dst_state_batch_indices=dst_indices,
num_accepted_tokens=num_accepted,
cu_seqlens=cu_seqlens,
cache_steps=cache_steps,
)
elif backend == "triton":
return selective_state_update_varlen_triton(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=src_indices,
dst_state_batch_indices=dst_indices,
num_accepted_tokens=num_accepted,
cu_seqlens=cu_seqlens,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
else:
raise ValueError(f"Unsupported backend: {backend}")
if backend == "flashinfer":
return flashinfer.mamba.selective_state_update(
state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=cache_steps,
algorithm=algorithm,
)
elif backend == "triton":
return selective_state_update_triton_reference(
state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=triton_cache_steps,
)
else:
raise ValueError(f"Unsupported backend: {backend}")

# Reference check: use Triton as golden reference
# Save a clean snapshot of state_cache before any benchmarking, because
Expand All @@ -386,24 +484,46 @@ def run_backend(backend, state, x, dt, A, B, C, D):
clean_state_snapshot = state_cache.clone() if run_refcheck else None
if run_refcheck:
ref_state = clean_state_snapshot.clone()
reference_output = (
selective_state_update_triton_reference(
ref_state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=triton_cache_steps,
if is_varlen:
reference_output = (
selective_state_update_varlen_triton(
ref_state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=src_indices,
dst_state_batch_indices=dst_indices,
num_accepted_tokens=num_accepted,
cu_seqlens=cu_seqlens,
)
.detach()
.clone()
)
else:
reference_output = (
selective_state_update_triton_reference(
ref_state,
x,
dt,
A,
B,
C,
D,
z=z,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
state_batch_indices=slot_idx,
cache_steps=triton_cache_steps,
)
.detach()
.clone()
)
.detach()
.clone()
)
has_reference_output = True

# Storage for timing results and outputs
Expand Down Expand Up @@ -541,6 +661,8 @@ def run_backend(backend, state, x, dt, A, B, C, D):
cur_res["weight_dtype"] = str(weight_dtype)
cur_res["has_z"] = has_z
cur_res["dt_softplus"] = dt_softplus
cur_res["varlen"] = is_varlen
cur_res["algorithm"] = algorithm
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
Loading
Loading