Skip to content

Commit 74116dc

Browse files
committed
Add CLI Mamba option and clean up
Signed-off-by: Rishi Astra <[email protected]>
1 parent 1400b95 commit 74116dc

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,9 @@ def __init__(
466466
# The tuple is (conv_state, ssm_state)
467467
self.kv_cache = (torch.tensor([]), torch.tensor([]))
468468

469+
overrides = getattr(model_config, "hf_overrides", {}) or {}
470+
self.use_fast_kernel = bool(overrides.get("mamba2_fast_kernel", False))
471+
469472
self.model_config = model_config
470473
self.cache_config = cache_config
471474
self.prefix = prefix
@@ -712,6 +715,7 @@ def forward_cuda(
712715
dt_limit=(0.0, float("inf")),
713716
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
714717
state_dtype=ssm_state.dtype,
718+
use_fused_kernel=self.use_fast_kernel,
715719
)
716720

717721
if prefix_caching_enabled:

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
4545
dt_softplus=False,
4646
dt_limit=(0.0, float("inf")),
4747
state_dtype=None,
48-
fused=False,
48+
use_fused_kernel=False,
4949
):
5050
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
5151
seqlen, nheads, headdim = x.shape
@@ -80,8 +80,8 @@ def _mamba_chunk_scan_combined_fwd(
8080
if initial_states is not None:
8181
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)
8282

83-
if fused: # all 5 kernels fused
84-
_, states, _, dA_cumsum, dt = _fused5_ssd(
83+
if use_fused_kernel: # all 5 kernels fused
84+
_, states, dA_cumsum, dt = _fused5_ssd(
8585
x,
8686
dt,
8787
A,
@@ -201,6 +201,7 @@ def mamba_chunk_scan_combined_varlen(
201201
dt_limit=(0.0, float("inf")),
202202
return_intermediate_states=False,
203203
state_dtype=None,
204+
use_fused_kernel=False,
204205
):
205206
"""
206207
Argument:
@@ -249,6 +250,7 @@ def mamba_chunk_scan_combined_varlen(
249250
dt_softplus=dt_softplus,
250251
dt_limit=dt_limit,
251252
state_dtype=state_dtype,
253+
use_fused_kernel=use_fused_kernel,
252254
)
253255

254256
return varlen_states

vllm/model_executor/layers/mamba/ops/ssd_fused5.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,9 +1051,4 @@ def _fused5_ssd(
10511051
CB_COMP_FP32=cb_comp_fp32,
10521052
)
10531053

1054-
# states_G holds both states and final states
1055-
# TODO: can skip this copy if copied outside of function
1056-
final_states = states_G[nchunks].to(
1057-
states_G.dtype, copy=True
1058-
) # copy and convert to expected dtype
1059-
return out_x, states_G[1:], final_states, dA_cumsum, dt_out
1054+
return out_x, states_G[1:], dA_cumsum, dt_out

0 commit comments

Comments
 (0)