diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index fcc8c6316d7a..00158319fb46 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -38,6 +38,10 @@ class AscendGDNAttnBackend(GDNAttnBackend): def __init__(self, model_runner: ModelRunner): super().__init__(model_runner) + # transpose last two dim for _init_npu_conv_state + self.conv_states_shape = torch.Size( + (*self.conv_states_shape[:-2], self.conv_states_shape[-1], self.conv_states_shape[-2]) + ) decode_backend = get_linear_attn_decode_backend() prefill_backend = get_linear_attn_prefill_backend() self.kernel_dispatcher = AscendGDNKernelDispatcher( @@ -60,10 +64,10 @@ def prepare_gdn_inputs( if forward_mode.is_target_verify(): seq_len = spec_info.draft_token_num self.actual_seq_lengths = self.actual_seq_lengths * seq_len - start_indices = cache_indices * seq_len - offset = torch.arange(seq_len, device=start_indices.device) - ranges = start_indices.unsqueeze(1) + offset - self.ssm_state_indices = ranges.flatten().to(torch.int32) + # indices + self.ssm_state_indices = torch.arange( + cache_indices.shape[0] * seq_len, dtype=torch.int32, device=cache_indices.device + ) else: self.ssm_state_indices = cache_indices @@ -262,7 +266,7 @@ def forward_extend( :, forward_metadata.track_conv_indices ].transpose(0, 1) mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0] - conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track + conv_states.transpose(1, 2)[conv_dst[mask_indices]] = mixed_qkv_to_track kernel_size = layer.conv_weights.shape[-1] conv_states_for_prefill = conv_states[:, -(kernel_size - 1) :, :] conv_states_tmp = conv_states_for_prefill.contiguous() diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py 
b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 75d0fac2b6d4..4ca421a8694e 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -487,12 +487,9 @@ def _capture_metadata( self.query_start_loc_list[bs - 1].copy_( self.cached_cuda_graph_verify_query_start_loc[: bs + 1] ) - start_indices = mamba_indices * spec_info.draft_token_num - offset = torch.arange( - spec_info.draft_token_num, device=start_indices.device + ssm_state_indices = torch.arange( + mamba_indices.shape[0] * spec_info.draft_token_num, dtype=torch.int32, device=mamba_indices.device ) - ranges = start_indices.unsqueeze(1) + offset - ssm_state_indices = ranges.flatten().to(torch.int32) self.state_indices_list_gdn[bs - 1][ : len(mamba_indices) * spec_info.draft_token_num ].copy_(ssm_state_indices) @@ -547,14 +544,10 @@ def _replay_metadata( bs - num_padding ) elif forward_mode.is_target_verify(): - start_indices = ( - mamba_indices[: bs - num_padding] * spec_info.draft_token_num + ssm_state_indices = torch.arange( + len(mamba_indices[:bs - num_padding]) * spec_info.draft_token_num, + dtype=torch.int32, device=mamba_indices.device ) - offset = torch.arange( - spec_info.draft_token_num, device=start_indices.device - ) - ranges = start_indices.unsqueeze(1) + offset - ssm_state_indices = ranges.flatten().to(torch.int32) self.state_indices_list_gdn[bs - 1][ : len(mamba_indices[: bs - num_padding]) * spec_info.draft_token_num ].copy_(ssm_state_indices) @@ -1001,18 +994,21 @@ def update_mamba_state_after_mtp_verify( intermediate_state_cache = mamba_caches.intermediate_ssm intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0] if is_npu(): - valid_state_indices = state_indices_tensor.to(torch.int64) # [N] + dst_indices_tensor = state_indices_tensor.to(torch.int64) # [N] + src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], + device=dst_indices_tensor.device, + 
dtype=torch.int64) last_steps = accepted_steps.to(torch.int64) # [N] move_intermediate_cache( - ssm_states, intermediate_state_cache, valid_state_indices, last_steps + ssm_states, intermediate_state_cache, dst_indices_tensor, src_indices_tensor, last_steps ) draft_token_num = intermediate_state_cache.shape[2] - if valid_state_indices.numel() > 0: + if dst_indices_tensor.numel() > 0: conv_state_rollback( conv_states, - valid_state_indices, + dst_indices_tensor, last_steps, draft_token_num, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e7db9bf9f579..8fb64fc080e5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2140,7 +2140,7 @@ def _handle_mamba_radix_cache( ) assert ( - is_cuda() + is_cuda() or is_npu() ), "Mamba extra_buffer is only supported on CUDA devices with FLA backend" if self.speculative_num_draft_tokens is not None: assert (