82 changes: 41 additions & 41 deletions python/sglang/srt/configs/mamba_utils.py
@@ -13,6 +13,7 @@
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""

import os
+from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional

@@ -34,9 +35,45 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    return tp_size - ngroups


+@dataclass(kw_only=True, frozen=True)
+class Mamba2StateDType:
+    conv: torch.dtype
+    temporal: torch.dtype
+
+
+CONV_DTYPE = torch.bfloat16
+
+
+def mamba2_state_dtype() -> Mamba2StateDType:
+    dtype_map = {
+        "float32": torch.float32,
+        "bfloat16": torch.bfloat16,
+    }
+    ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
+    return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
+
+
+@dataclass(kw_only=True, frozen=True)
+class BaseLinearStateParams(ABC):
+    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
+    layers: list[int]
+
+    @property
+    def mamba_cache_per_req(self) -> int:
+        conv_numel = int(
+            np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv])
+        )
+
+        ssm_numel = int(np.prod(self.shape.temporal))
+        return (
+            conv_numel * self.dtype.conv.itemsize
+            + ssm_numel * self.dtype.temporal.itemsize
+        ) * len(self.layers)
+
+
@dataclass(kw_only=True, frozen=True)
class Mamba2StateShape:
-    conv: tuple[int, int]
+    conv: list[tuple[int, int]]
    temporal: tuple[int, int, int]

    intermediate_size: int
@@ -74,7 +111,7 @@ def create(
        # e.g., QWen3-Next: (32, 128, 128)
        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return Mamba2StateShape(
-            conv=conv_state_shape,
+            conv=[conv_state_shape],
            temporal=temporal_state_shape,
            intermediate_size=intermediate_size,
            conv_dim=conv_dim,
@@ -87,35 +124,8 @@ def create(


-@dataclass(kw_only=True, frozen=True)
-class Mamba2StateDType:
-    conv: torch.dtype
-    temporal: torch.dtype
-
-
-CONV_DTYPE = torch.bfloat16
-
-
-def mamba2_state_dtype() -> Mamba2StateDType:
-    dtype_map = {
-        "float32": torch.float32,
-        "bfloat16": torch.bfloat16,
-    }
-    ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
-    return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
-
-
@dataclass(kw_only=True, frozen=True)
-class Mamba2CacheParams:
+class Mamba2CacheParams(BaseLinearStateParams):
    shape: Mamba2StateShape
-    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
-    layers: list[int]
-
-    @property
-    def mamba_cache_per_req(self) -> int:
-        return (
-            int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
-            + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
-        ) * len(self.layers)


@dataclass(kw_only=True, frozen=True)
@@ -169,15 +179,5 @@ def create(


@dataclass(kw_only=True, frozen=True)
-class KimiLinearCacheParams:
+class KimiLinearCacheParams(BaseLinearStateParams):
    shape: KimiLinearStateShape
-    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
-    layers: list[int]
-
-    @property
-    def mamba_cache_per_req(self) -> int:
-        return (
-            int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
-            * self.dtype.conv.itemsize
-            + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
-        ) * len(self.layers)
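Note on the refactor above: both cache-param classes now inherit mamba_cache_per_req from BaseLinearStateParams, which sums every conv state in shape.conv plus the temporal state and multiplies by the layer count. A minimal standalone sketch of that arithmetic, using made-up shapes rather than any real model config (and assuming a PyTorch recent enough that torch.dtype exposes .itemsize):

import numpy as np
import torch

# Hypothetical shapes, for illustration only.
conv_shapes = [(4352, 3)]        # list of (dim, state_len); KDA models would carry several entries
temporal_shape = (8, 128, 128)   # (num_heads, head_dim, state_size)
conv_dtype, ssm_dtype = torch.bfloat16, torch.float32
num_layers = 12

conv_numel = int(np.sum([np.prod(s) for s in conv_shapes]))  # 13056
ssm_numel = int(np.prod(temporal_shape))                     # 131072
bytes_per_req = (
    conv_numel * conv_dtype.itemsize + ssm_numel * ssm_dtype.itemsize
) * num_layers
print(bytes_per_req)  # 6604800, i.e. roughly 6.3 MiB per request for this toy config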
@@ -545,7 +545,7 @@ def forward_decode(
        layer_id = kwargs["layer_id"]

        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
-        conv_states = layer_cache.conv
+        conv_states = layer_cache.conv[0]
        ssm_states = layer_cache.temporal
        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices
@@ -628,12 +628,14 @@ def forward_extend(
        retrieve_parent_token = self.forward_metadata.retrieve_parent_token

        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
-        conv_states = mamba_cache_params.conv
+        conv_states = mamba_cache_params.conv[0]
        ssm_states = mamba_cache_params.temporal
        if is_target_verify:
            assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
            intermediate_state_cache = mamba_cache_params.intermediate_ssm
-            intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
+            intermediate_conv_window_cache = (
+                mamba_cache_params.intermediate_conv_window[0]
+            )
            has_initial_states = torch.ones(
                seq_len // forward_batch.spec_info.draft_token_num,
                dtype=torch.bool,
@@ -973,10 +975,10 @@ def update_mamba_state_after_mtp_verify(self, accepted_indices, model):
            self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
        )

-        conv_states = mamba_caches.conv
+        conv_states = mamba_caches.conv[0]
        ssm_states = mamba_caches.temporal
        intermediate_state_cache = mamba_caches.intermediate_ssm
-        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
+        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0]

        # SSM state updates (chunked to reduce peak memory)
        valid_mask = accepted_indices >= 0
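Note on the conv[0] accesses above: conv is now always a list of tensors; Mamba2-style models build it with exactly one entry, so indexing [0] recovers the old single-tensor behavior, while KDA-style models carry several entries. A toy, runnable illustration of that shape contract (the class and shapes here are invented, not the real MambaPool.State):

from dataclasses import dataclass
from typing import List

import torch


@dataclass(frozen=True, kw_only=True)
class ToyState:
    conv: List[torch.Tensor]   # one entry per conv component
    temporal: torch.Tensor


state = ToyState(
    conv=[torch.zeros(5, 16, 3)],      # single entry -> Mamba2-style
    temporal=torch.zeros(5, 2, 4, 8),
)
conv_states = state.conv[0]            # what forward_decode/forward_extend now read
assert conv_states.shape == (5, 16, 3)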
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/attention/mamba/mamba.py
@@ -364,7 +364,7 @@ def forward(
        # modes; they are computed at top-level model forward since they
        # stay the same and reused for all mamba layers in the same iteration
        state_indices_tensor = metadata.mamba_cache_indices
-        conv_state = layer_cache.conv
+        conv_state = layer_cache.conv[0]
        ssm_state = layer_cache.temporal

        query_start_loc = metadata.query_start_loc
102 changes: 38 additions & 64 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -17,7 +17,7 @@

from dataclasses import dataclass

-from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
+from sglang.srt.configs.mamba_utils import BaseLinearStateParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -124,30 +124,31 @@ def clear(self):
class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class State:
-        conv: Union[torch.Tensor, List[torch.Tensor]]
+        conv: List[torch.Tensor]
        temporal: torch.Tensor

        def at_layer_idx(self, layer: int):
-            if isinstance(self.conv, list):
-                return type(self)(
-                    conv=[v[layer] for v in self.conv],
-                    temporal=self.temporal[layer],
-                )
-            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+            kwargs = {}
+            for k, v in vars(self).items():
+                if k == "conv" or k == "intermediate_conv_window":
+                    kwargs[k] = [conv[layer] for conv in v]
+                else:
+                    kwargs[k] = v[layer]
+            return type(self)(**kwargs)

        def mem_usage_bytes(self):
            return sum(get_tensor_size_bytes(t) for t in vars(self).values())

    @dataclass(frozen=True, kw_only=True)
    class SpeculativeState(State):
-        intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
-        intermediate_conv_window: torch.Tensor
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: List[torch.Tensor]

    def __init__(
        self,
        *,
        size: int,
-        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
+        cache_params: BaseLinearStateParams,
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
@@ -165,29 +166,19 @@ def __init__(
            maybe_init_custom_mem_pool(device=self.device)
        )

-        self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.enable_custom_mem_pool
            else nullcontext()
        ):
-            if self.is_kda_cache:
-                conv_state = [
-                    torch.zeros(
-                        size=(num_mamba_layers, size + 1) + conv_shape,
-                        dtype=conv_dtype,
-                        device=device,
-                    )
-                    for conv_shape in conv_state_shape
-                ]
-            else:
-                # assume conv_state = (dim, state_len)
-                assert conv_state_shape[0] > conv_state_shape[1]
-                conv_state = torch.zeros(
-                    size=(num_mamba_layers, size + 1) + conv_state_shape,
+            conv_state = [
+                torch.zeros(
+                    size=(num_mamba_layers, size + 1) + conv_shape,
                    dtype=conv_dtype,
                    device=device,
                )
+                for conv_shape in conv_state_shape
+            ]
            temporal_state = torch.zeros(
                size=(num_mamba_layers, size + 1) + temporal_state_shape,
                dtype=ssm_dtype,
@@ -210,34 +201,20 @@ def __init__(
                )
                # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
-
-                if self.is_kda_cache:
-                    intermediate_conv_window_cache = [
-                        torch.zeros(
-                            size=(
-                                num_mamba_layers,
-                                size + 1,
-                                speculative_num_draft_tokens,
-                                conv_shape[0],
-                                conv_shape[1],
-                            ),
-                            dtype=conv_dtype,
-                            device="cuda",
-                        )
-                        for conv_shape in conv_state_shape
-                    ]
-                else:
-                    intermediate_conv_window_cache = torch.zeros(
+                intermediate_conv_window_cache = [
+                    torch.zeros(
                        size=(
                            num_mamba_layers,
                            size + 1,
                            speculative_num_draft_tokens,
-                            conv_state_shape[0],
-                            conv_state_shape[1],
+                            conv_shape[0],
+                            conv_shape[1],
                        ),
                        dtype=conv_dtype,
                        device="cuda",
                    )
+                    for conv_shape in conv_state_shape
+                ]
                self.mamba_cache = self.SpeculativeState(
                    conv=conv_state,
                    temporal=temporal_state,
@@ -289,25 +266,18 @@ def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        self.free_slots = torch.cat((self.free_slots, free_index))
-        if self.is_kda_cache:
-            for i in range(len(self.mamba_cache.conv)):
-                self.mamba_cache.conv[i][:, free_index] = 0
-        else:
-            self.mamba_cache.conv[:, free_index] = 0
+        for i in range(len(self.mamba_cache.conv)):
+            self.mamba_cache.conv[i][:, free_index] = 0
        self.mamba_cache.temporal[:, free_index] = 0

    def clear(self):
        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)

    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
-        if self.is_kda_cache:
-            for i in range(len(self.mamba_cache.conv)):
-                self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
-                    :, src_index
-                ]
-        else:
-            self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
-
+        for i in range(len(self.mamba_cache.conv)):
+            self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
+                :, src_index
+            ]
        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
            :, src_index
        ]
@@ -321,9 +291,13 @@ def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
        return dst_index

    def get_contiguous_buf_infos(self):
-        state_tensors = [
-            getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
-        ]
+        state_tensors = []
+        for field in vars(self.mamba_cache):
+            value = getattr(self.mamba_cache, field)
+            if isinstance(value, list):
+                state_tensors.extend(value)
+            else:
+                state_tensors.append(value)
        data_ptrs, data_lens, item_lens = [], [], []

        for _, state_tensor in enumerate(state_tensors):
@@ -348,7 +322,7 @@ def __init__(
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
-        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
+        cache_params: BaseLinearStateParams,
        speculative_num_draft_tokens: int = None,
    ):
        super().__init__(
@@ -367,7 +341,7 @@ def __init__(
    def _init_mamba_pool(
        self,
        size: int,
-        cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
+        cache_params: BaseLinearStateParams,
        device: str,
        speculative_num_draft_tokens: int = None,
    ):
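Note on the generic at_layer_idx above: list-valued fields (conv, intermediate_conv_window) are sliced entry by entry, while tensor fields are indexed directly, which is what lets one code path serve both Mamba2 and KimiLinear caches. A self-contained sketch of the same logic with invented shapes:

from dataclasses import dataclass
from typing import List

import torch


@dataclass(frozen=True, kw_only=True)
class ToyState:
    conv: List[torch.Tensor]   # each tensor: (num_layers, ...)
    temporal: torch.Tensor     # (num_layers, ...)

    def at_layer_idx(self, layer: int):
        kwargs = {}
        for k, v in vars(self).items():
            if k == "conv":
                kwargs[k] = [conv[layer] for conv in v]
            else:
                kwargs[k] = v[layer]
        return type(self)(**kwargs)


full = ToyState(
    conv=[torch.zeros(4, 9, 16, 3), torch.zeros(4, 9, 8, 2)],  # two conv components
    temporal=torch.zeros(4, 9, 2, 4, 8),
)
layer0 = full.at_layer_idx(0)
assert [t.shape for t in layer0.conv] == [torch.Size([9, 16, 3]), torch.Size([9, 8, 2])]
assert layer0.temporal.shape == (9, 2, 4, 8)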
4 changes: 2 additions & 2 deletions test/srt/test_mamba_unittest.py
@@ -321,8 +321,8 @@ def make_dummy_req():
        kv_indices, last_node = result.device_indices, result.last_device_node
        assert req9.mamba_pool_idx is not None
        assert torch.all(
-            mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx]
-            == mamba_pool.mamba_cache.conv[:, last_node.mamba_value]
+            mamba_pool.mamba_cache.conv[0][:, req9.mamba_pool_idx]
+            == mamba_pool.mamba_cache.conv[0][:, last_node.mamba_value]
        )
        assert torch.all(
            mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx]