diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py index c0a632c0b93c..d2ff3762b140 100644 --- a/python/sglang/srt/configs/mamba_utils.py +++ b/python/sglang/srt/configs/mamba_utils.py @@ -13,6 +13,7 @@ """Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc.""" import os +from abc import ABC from dataclasses import dataclass, field from typing import List, Optional @@ -34,9 +35,45 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): return tp_size - ngroups +@dataclass(kw_only=True, frozen=True) +class Mamba2StateDType: + conv: torch.dtype + temporal: torch.dtype + + +CONV_DTYPE = torch.bfloat16 + + +def mamba2_state_dtype() -> Mamba2StateDType: + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] + return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) + + +@dataclass(kw_only=True, frozen=True) +class BaseLinearStateParams(ABC): + dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype) + layers: list[int] + + @property + def mamba_cache_per_req(self) -> int: + conv_numel = int( + np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]) + ) + + ssm_numel = int(np.prod(self.shape.temporal)) + return ( + conv_numel * self.dtype.conv.itemsize + + ssm_numel * self.dtype.temporal.itemsize + ) * len(self.layers) + + @dataclass(kw_only=True, frozen=True) class Mamba2StateShape: - conv: tuple[int, int] + conv: list[tuple[int, int]] temporal: tuple[int, int, int] intermediate_size: int @@ -74,7 +111,7 @@ def create( # e.g., QWen3-Next: (32, 128, 128) temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) return Mamba2StateShape( - conv=conv_state_shape, + conv=[conv_state_shape], temporal=temporal_state_shape, intermediate_size=intermediate_size, conv_dim=conv_dim, @@ -87,35 +124,8 @@ def create( @dataclass(kw_only=True, frozen=True) -class Mamba2StateDType: - conv: torch.dtype - temporal: torch.dtype - - -CONV_DTYPE = torch.bfloat16 - - -def mamba2_state_dtype() -> Mamba2StateDType: - dtype_map = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - } - ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] - return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) - - -@dataclass(kw_only=True, frozen=True) -class Mamba2CacheParams: +class Mamba2CacheParams(BaseLinearStateParams): shape: Mamba2StateShape - dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype) - layers: list[int] - - @property - def mamba_cache_per_req(self) -> int: - return ( - int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize - + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize - ) * len(self.layers) @dataclass(kw_only=True, frozen=True) @@ -169,15 +179,5 @@ def create( @dataclass(kw_only=True, frozen=True) -class KimiLinearCacheParams: +class KimiLinearCacheParams(BaseLinearStateParams): shape: KimiLinearStateShape - dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype) - layers: list[int] - - @property - def mamba_cache_per_req(self) -> int: - return ( - int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv])) - * self.dtype.conv.itemsize - + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize - ) * len(self.layers) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 99785403346b..5cfaac9e4418 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -545,7 +545,7 @@ def forward_decode( layer_id = kwargs["layer_id"] layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id) - conv_states = layer_cache.conv + conv_states = layer_cache.conv[0] ssm_states = layer_cache.temporal query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices @@ -628,12 +628,14 @@ def forward_extend( retrieve_parent_token = self.forward_metadata.retrieve_parent_token mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id) - conv_states = mamba_cache_params.conv + conv_states = mamba_cache_params.conv[0] ssm_states = mamba_cache_params.temporal if is_target_verify: assert isinstance(mamba_cache_params, MambaPool.SpeculativeState) intermediate_state_cache = mamba_cache_params.intermediate_ssm - intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window + intermediate_conv_window_cache = ( + mamba_cache_params.intermediate_conv_window[0] + ) has_initial_states = torch.ones( seq_len // forward_batch.spec_info.draft_token_num, dtype=torch.bool, @@ -973,10 +975,10 @@ def update_mamba_state_after_mtp_verify(self, accepted_indices, model): self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers() ) - conv_states = mamba_caches.conv + conv_states = mamba_caches.conv[0] ssm_states = mamba_caches.temporal intermediate_state_cache = mamba_caches.intermediate_ssm - intermediate_conv_window_cache = mamba_caches.intermediate_conv_window + intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0] # SSM state updates (chunked to reduce peak memory) valid_mask = accepted_indices >= 0 diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 2520284899ee..fcaaf2900c9f 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -364,7 +364,7 @@ def forward( # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration state_indices_tensor = metadata.mamba_cache_indices - conv_state = layer_cache.conv + conv_state = layer_cache.conv[0] ssm_state = layer_cache.temporal query_start_loc = metadata.query_start_loc diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 739e289439e4..b5acbcd8f748 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -17,7 +17,7 @@ from dataclasses import dataclass -from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams +from sglang.srt.configs.mamba_utils import BaseLinearStateParams from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -124,30 +124,31 @@ def clear(self): class MambaPool: @dataclass(frozen=True, kw_only=True) class State: - conv: Union[torch.Tensor, List[torch.Tensor]] + conv: List[torch.Tensor] temporal: torch.Tensor def at_layer_idx(self, layer: int): - if isinstance(self.conv, list): - return type(self)( - conv=[v[layer] for v in self.conv], - temporal=self.temporal[layer], - ) - return type(self)(**{k: v[layer] for k, v in vars(self).items()}) + kwargs = {} + for k, v in vars(self).items(): + if k == "conv" or k == "intermediate_conv_window": + kwargs[k] = [conv[layer] for conv in v] + else: + kwargs[k] = v[layer] + return type(self)(**kwargs) def mem_usage_bytes(self): return sum(get_tensor_size_bytes(t) for t in vars(self).values()) @dataclass(frozen=True, kw_only=True) class SpeculativeState(State): - intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]] - intermediate_conv_window: torch.Tensor + intermediate_ssm: torch.Tensor + intermediate_conv_window: List[torch.Tensor] def __init__( self, *, size: int, - cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"], + cache_params: BaseLinearStateParams, device: str, speculative_num_draft_tokens: Optional[int] = None, ): @@ -165,29 +166,19 @@ def __init__( maybe_init_custom_mem_pool(device=self.device) ) - self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams) with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.enable_custom_mem_pool else nullcontext() ): - if self.is_kda_cache: - conv_state = [ - torch.zeros( - size=(num_mamba_layers, size + 1) + conv_shape, - dtype=conv_dtype, - device=device, - ) - for conv_shape in conv_state_shape - ] - else: - # assume conv_state = (dim, state_len) - assert conv_state_shape[0] > conv_state_shape[1] - conv_state = torch.zeros( - size=(num_mamba_layers, size + 1) + conv_state_shape, + conv_state = [ + torch.zeros( + size=(num_mamba_layers, size + 1) + conv_shape, dtype=conv_dtype, device=device, ) + for conv_shape in conv_state_shape + ] temporal_state = torch.zeros( size=(num_mamba_layers, size + 1) + temporal_state_shape, dtype=ssm_dtype, @@ -210,34 +201,20 @@ def __init__( ) # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] - - if self.is_kda_cache: - intermediate_conv_window_cache = [ - torch.zeros( - size=( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - conv_shape[0], - conv_shape[1], - ), - dtype=conv_dtype, - device="cuda", - ) - for conv_shape in conv_state_shape - ] - else: - intermediate_conv_window_cache = torch.zeros( + intermediate_conv_window_cache = [ + torch.zeros( size=( num_mamba_layers, size + 1, speculative_num_draft_tokens, - conv_state_shape[0], - conv_state_shape[1], + conv_shape[0], + conv_shape[1], ), dtype=conv_dtype, device="cuda", ) + for conv_shape in conv_state_shape + ] self.mamba_cache = self.SpeculativeState( conv=conv_state, temporal=temporal_state, @@ -289,25 +266,18 @@ def free(self, free_index: torch.Tensor): if free_index.numel() == 0: return self.free_slots = torch.cat((self.free_slots, free_index)) - if self.is_kda_cache: - for i in range(len(self.mamba_cache.conv)): - self.mamba_cache.conv[i][:, free_index] = 0 - else: - self.mamba_cache.conv[:, free_index] = 0 + for i in range(len(self.mamba_cache.conv)): + self.mamba_cache.conv[i][:, free_index] = 0 self.mamba_cache.temporal[:, free_index] = 0 def clear(self): self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): - if self.is_kda_cache: - for i in range(len(self.mamba_cache.conv)): - self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][ - :, src_index - ] - else: - self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] - + for i in range(len(self.mamba_cache.conv)): + self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][ + :, src_index + ] self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ :, src_index ] @@ -321,9 +291,13 @@ def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]: return dst_index def get_contiguous_buf_infos(self): - state_tensors = [ - getattr(self.mamba_cache, field) for field in vars(self.mamba_cache) - ] + state_tensors = [] + for field in vars(self.mamba_cache): + value = getattr(self.mamba_cache, field) + if isinstance(value, list): + state_tensors.extend(value) + else: + state_tensors.append(value) data_ptrs, data_lens, item_lens = [], [], [] for _, state_tensor in enumerate(state_tensors): @@ -348,7 +322,7 @@ def __init__( max_context_len: int, device: str, enable_memory_saver: bool, - cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"], + cache_params: BaseLinearStateParams, speculative_num_draft_tokens: int = None, ): super().__init__( @@ -367,7 +341,7 @@ def __init__( def _init_mamba_pool( self, size: int, - cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"], + cache_params: BaseLinearStateParams, device: str, speculative_num_draft_tokens: int = None, ): diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index 7bbca75e1aae..2fea9ecdff7d 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -321,8 +321,8 @@ def make_dummy_req(): kv_indices, last_node = result.device_indices, result.last_device_node assert req9.mamba_pool_idx is not None assert torch.all( - mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx] - == mamba_pool.mamba_cache.conv[:, last_node.mamba_value] + mamba_pool.mamba_cache.conv[0][:, req9.mamba_pool_idx] + == mamba_pool.mamba_cache.conv[0][:, last_node.mamba_value] ) assert torch.all( mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx]