Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions csrc/flashinfer_mamba_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,29 @@ using tvm::ffi::Optional;

namespace flashinfer::mamba {

void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView output,
TensorView A, TensorView B, TensorView C, TensorView D,
Optional<TensorView> z, Optional<TensorView> dt_bias, bool dt_softplus,
Optional<TensorView> state_batch_indices, int64_t pad_slot_id);
void selective_state_update(
TensorView state, // (batch, dim, dstate) or (batch, nheads, dim, dstate)
TensorView x, // (batch, dim) or (batch, nheads, dim) for single-token
// or (batch, T, nheads, dim) for multi-token
TensorView dt, // (batch, dim) or (batch, nheads, dim) for single-token
// or (batch, T, nheads, dim) for multi-token
TensorView A, // (dim, dstate) or (nheads, dim, dstate)
TensorView B, // (batch, dstate) or (batch, ngroups, dstate) for single-token
// or (batch, T, ngroups, dstate) for multi-token
TensorView C, // (batch, dstate) or (batch, ngroups, dstate) for single-token
// or (batch, T, ngroups, dstate) for multi-token
TensorView D, // (dim,) or (nheads, dim)
Optional<TensorView> z, // (batch, dim) or (batch, nheads, dim) for single-token
// or (batch, T, nheads, dim) for multi-token
Optional<TensorView> dt_bias, // (dim,) or (nheads, dim)
bool dt_softplus,
Optional<TensorView> state_batch_indices, // (batch,)
int64_t pad_slot_id,
TensorView output, // same as x
bool disable_state_update,
Optional<TensorView> intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate)
Optional<TensorView> intermediate_state_indices, // (batch,)
int64_t cache_steps);

} // namespace flashinfer::mamba

Expand Down
708 changes: 572 additions & 136 deletions csrc/selective_state_update.cu

Large diffs are not rendered by default.

106 changes: 83 additions & 23 deletions flashinfer/mamba/selective_state_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def selective_state_update(
dt_softplus: bool = False,
state_batch_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = -1,
out: torch.Tensor | None = None,
out: Optional[torch.Tensor] = None,
disable_state_update: bool = False,
intermediate_states_buffer: Optional[torch.Tensor] = None,
intermediate_state_indices: Optional[torch.Tensor] = None,
cache_steps: int = 0,
) -> torch.Tensor:
r"""Selective state update operation for Mamba layers (the generation phase).

Expand All @@ -82,55 +86,94 @@ def selective_state_update(
state : torch.Tensor
State tensor with shape (state_cache_size, dim, dstate) or (state_cache_size, nheads, dim, dstate)
x : torch.Tensor
Input tensor with shape (batch, dim) or (batch, nheads, dim)
Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token
or (batch, T, nheads, dim) for multi-token
dt : torch.Tensor
Delta time tensor with shape (batch, dim) or (batch, nheads, dim)
Delta time tensor with shape (batch, dim) or (batch, nheads, dim) for single-token
or (batch, T, nheads, dim) for multi-token
A : torch.Tensor
A matrix with shape (dim, dstate) or (nheads, dim, dstate)
B : torch.Tensor
B matrix with shape (batch, dstate) or (batch, ngroups, dstate)
B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token
or (batch, T, ngroups, dstate) for multi-token
C : torch.Tensor
C matrix with shape (batch, dstate) or (batch, ngroups, dstate)
C matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token
or (batch, T, ngroups, dstate) for multi-token
D : torch.Tensor
D vector with shape (dim,) or (nheads, dim)
z : Optional[torch.Tensor]
Optional z tensor with shape (batch, dim) or (batch, nheads, dim)
Optional z tensor with shape (batch, dim) or (batch, nheads, dim) for single-token
or (batch, T, nheads, dim) for multi-token
dt_bias : Optional[torch.Tensor]
Optional dt bias with shape (dim,) or (nheads, dim)
dt_softplus : bool
Whether to apply softplus to dt
state_batch_indices : Optional[torch.Tensor]
Optional batch indices for cache processing
Optional batch indices for cache processing with shape (batch,)
pad_slot_id : int
If state_batch_indices is passed, lets the kernel identify padded entries
that will not be processed. For example, given state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id],
the kernel will not process the entries at indices 0 and 3.
out : torch.Tensor | None
Optional output tensor
out : Optional[torch.Tensor]
Optional output tensor (same shape as x)
disable_state_update : bool
If True, skip updating the state tensor (useful for speculative decoding verification)
intermediate_states_buffer : Optional[torch.Tensor]
Optional buffer for caching intermediate states during speculative decoding
with shape (batch, cache_steps, nheads, dim, dstate)
intermediate_state_indices : Optional[torch.Tensor]
Optional indices mapping batch elements to intermediate state buffer positions
with shape (batch,)
cache_steps : int
Number of steps/tokens to cache for speculative decoding

Returns
-------
output : torch.Tensor
Output tensor with shape (batch, dim) or (batch, nheads, dim)
Output tensor with shape (batch, dim) or (batch, nheads, dim) for single-token
or (batch, T, nheads, dim) for multi-token
"""
    # Determine if we're in multi-token mode (cache_steps >= 1)
    is_mtp = cache_steps >= 1

if state.dim() == 3:
state = state.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if D.dim() == 1:
D = D.unsqueeze(0)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)

# Handle x, dt, B, C, z dimensions based on mode
# For single-token: 2D -> 3D (batch, nheads, dim)
# For multi-token: 3D -> 4D (batch, T, nheads, dim)
if x.dim() == 2:
x = x.unsqueeze(1)
if is_mtp and x.dim() == 3:
# Add T dimension for MTP mode: (batch, nheads, dim) -> (batch, T, nheads, dim)
x = x.unsqueeze(1)

if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if is_mtp and dt.dim() == 3:
dt = dt.unsqueeze(1)

if B.dim() == 2:
B = B.unsqueeze(1)
if is_mtp and B.dim() == 3:
B = B.unsqueeze(1)

if C.dim() == 2:
C = C.unsqueeze(1)
if D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if is_mtp and C.dim() == 3:
C = C.unsqueeze(1)

if z is not None:
if z.dim() == 2:
z = z.unsqueeze(1)
if is_mtp and z.dim() == 3:
z = z.unsqueeze(1)
if out is None:
output = torch.empty_like(x)
else:
Expand All @@ -139,7 +182,6 @@ def selective_state_update(
state,
x,
dt,
output,
A,
B,
C,
Expand All @@ -149,18 +191,23 @@ def selective_state_update(
dt_softplus,
state_batch_indices,
pad_slot_id,
output,
disable_state_update,
intermediate_states_buffer,
intermediate_state_indices,
cache_steps,
)
return output


@register_custom_op(
"flashinfer::selective_state_update", mutates_args=("state", "output")
"flashinfer::selective_state_update",
mutates_args=("state", "output", "intermediate_states_buffer"),
)
def _selective_state_update(
state: torch.Tensor,
x: torch.Tensor,
dt: torch.Tensor,
output: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
Expand All @@ -170,13 +217,17 @@ def _selective_state_update(
dt_softplus: bool,
state_batch_indices: Optional[torch.Tensor],
pad_slot_id: int,
output: torch.Tensor,
disable_state_update: bool,
intermediate_states_buffer: Optional[torch.Tensor],
intermediate_state_indices: Optional[torch.Tensor],
cache_steps: int,
) -> None:
"""Internal function registered with torch.library for torch.compile() support."""
get_selective_state_update_module(state.device).selective_state_update(
state,
x,
dt,
output,
A,
B,
C,
Expand All @@ -186,6 +237,11 @@ def _selective_state_update(
dt_softplus,
state_batch_indices,
pad_slot_id,
output,
disable_state_update,
intermediate_states_buffer,
intermediate_state_indices,
cache_steps,
)


Expand All @@ -194,7 +250,6 @@ def _selective_state_update_fake(
state: torch.Tensor,
x: torch.Tensor,
dt: torch.Tensor,
output: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
Expand All @@ -204,6 +259,11 @@ def _selective_state_update_fake(
dt_softplus: bool,
state_batch_indices: Optional[torch.Tensor],
pad_slot_id: int,
output: torch.Tensor,
disable_state_update: bool,
intermediate_states_buffer: Optional[torch.Tensor],
intermediate_state_indices: Optional[torch.Tensor],
cache_steps: int,
) -> None:
"""Fake implementation for torch.compile() meta tensor propagation."""
pass
Loading
Loading