diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index 1cabc32c131b..b7f4411cffc1 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -114,8 +114,11 @@ srt_musa = [ "sglang[runtime_common]", "torch", "torch_musa", - "torchada>=0.1.45", + "torchada>=0.1.48", "mthreads-ml-py", + "mate", + "mate-deep_gemm", + "mate-flash-attention", "numpy<2.0", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 450ea6b12c2b..e90ee3881e87 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -575,6 +575,12 @@ def _derive_model_shapes(self): self.num_key_value_heads = getattr( self.hf_text_config, "num_key_value_heads", None ) + self.first_k_dense_replace = getattr( + self.hf_text_config, "first_k_dense_replace", None + ) + self.full_attention_interval = getattr( + self.hf_text_config, "full_attention_interval", None + ) # for Dbrx and MPT models if self.hf_config.model_type in ["dbrx", "mpt"]: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index c13d58575f6a..64be12fdcacb 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -331,6 +331,9 @@ class Envs: SGLANG_USE_AG_AFTER_QLORA = EnvBool(False) SGLANG_NPU_FUSED_MOE_MODE = EnvInt(1) + # MTHREADS & MUSA + SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA = EnvBool(False) + # Quantization SGLANG_INT4_WEIGHT = EnvBool(False) SGLANG_CPU_QUANTIZATION = EnvBool(False) diff --git a/python/sglang/srt/hardware_backend/musa/__init__.py b/python/sglang/srt/hardware_backend/musa/__init__.py new file mode 100644 index 000000000000..be2fb35b43db --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/__init__.py @@ -0,0 +1 @@ +# MUSA (Moore Threads GPU) hardware backend diff --git a/python/sglang/srt/hardware_backend/musa/attention/__init__.py b/python/sglang/srt/hardware_backend/musa/attention/__init__.py new file mode 100644 index 000000000000..305888efea79 --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/attention/__init__.py @@ -0,0 +1,3 @@ +from .flashattention_backend import MusaFlashAttentionBackend + +__all__ = ["MusaFlashAttentionBackend"] diff --git a/python/sglang/srt/hardware_backend/musa/attention/flashattention_backend.py b/python/sglang/srt/hardware_backend/musa/attention/flashattention_backend.py new file mode 100644 index 000000000000..4583c1d82b61 --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/attention/flashattention_backend.py @@ -0,0 +1,913 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Optional, Tuple, Union + +import torch +from flash_attn import flash_attn_varlen_func +from flash_attn import flash_attn_with_kvcache as mate_flash_attn_with_kvcache +from flash_attn import get_scheduler_metadata + +from sglang.srt.distributed import get_pp_group, get_pp_indices +from sglang.srt.environ import envs +from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + merge_state_v2_wrapper, +) +from sglang.srt.layers.radix_attention import AttentionType +from sglang.srt.layers.utils.cp_utils import ( + cp_allgather_and_save_kv_cache, + cp_attn_forward_extend, +) +from sglang.srt.server_args import get_global_server_args + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.model_executor.model_runner import ModelRunner + +# Global workspace buffer 
for MLA +_MATE_MLA_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024 +_MATE_MLA_WORKSPACE_BUFFER: torch.Tensor | None = None + +# Cache for non-MLA scheduler metadata by prefix +_MATE_NO_MLA_SCHEDULER_METADATA_DICT: dict = {} +_MATE_NO_MLA_SCHEDULER_METADATA_LOCK = threading.Lock() + +# Global reference to the current backend instance (set during __init__) +_CURRENT_BACKEND: Optional["MusaFlashAttentionBackend"] = None + + +def _compute_scheduler_metadata( + backend: "MusaFlashAttentionBackend", + cu_seqlens_q: torch.Tensor, + cu_seqlens_k_new: Optional[torch.Tensor], + cache_seqlens: torch.Tensor, + max_seqlen_q: int, + page_size: int, + causal: bool, + window_size: Tuple[int, int], + num_splits: int, +) -> Tuple[torch.Tensor, bool] | torch.Tensor: + """Compute scheduler metadata based on backend's current state.""" + global _MATE_MLA_WORKSPACE_BUFFER, _MATE_NO_MLA_SCHEDULER_METADATA_DICT + + layer = backend._current_layer + current_layer_id = layer.layer_id + batch_size = cu_seqlens_q.shape[-1] - 1 + + # Determine if scheduler metadata should be updated + should_update = True + pp_group = get_pp_group() + pp_rank = pp_group.rank_in_group + start_layer_id, _ = get_pp_indices( + backend.num_hidden_layers, pp_group.rank_in_group, pp_group.world_size + ) + if backend._current_can_run_tbo and pp_rank == 0: + start_layer_id += ( + backend.first_k_dense_replace + if backend.first_k_dense_replace is not None + else 0 + ) + + if backend.full_attention_interval is not None: + start_layer_id += backend.full_attention_interval - 1 + + if current_layer_id > start_layer_id: + should_update = False + + if envs.SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA.get(): + should_update = True + + if backend.use_mla: + if _MATE_MLA_WORKSPACE_BUFFER is None: + _MATE_MLA_WORKSPACE_BUFFER = torch.empty( + _MATE_MLA_WORKSPACE_SIZE_BYTES, device=backend.device, dtype=torch.uint8 + ) + return (_MATE_MLA_WORKSPACE_BUFFER, not should_update) + else: + with _MATE_NO_MLA_SCHEDULER_METADATA_LOCK: + if ( + should_update + or backend._current_prefix not in _MATE_NO_MLA_SCHEDULER_METADATA_DICT + ): + _MATE_NO_MLA_SCHEDULER_METADATA_DICT[backend._current_prefix] = ( + get_scheduler_metadata( + batch_size=batch_size, + num_heads_q=layer.tp_q_head_num, + num_heads_kv=layer.tp_k_head_num, + headdim=layer.qk_head_dim, + headdim_v=layer.v_head_dim, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=backend._current_max_seqlen_k, + page_size=page_size, + causal=causal, + window_size=window_size, + num_splits=num_splits, + ) + ) + return _MATE_NO_MLA_SCHEDULER_METADATA_DICT[backend._current_prefix] + + +def flash_attn_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = 
False, + window_size: Tuple[int, int] = (-1, -1), + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 0, + pack_gqa=None, + sm_margin: int = 0, + return_softmax_lse: bool = False, + ver: int = 3, + **kwargs, +): + """MUSA flash_attn_with_kvcache wrapper that auto-injects scheduler_metadata.""" + if scheduler_metadata is None and _CURRENT_BACKEND is not None: + backend = _CURRENT_BACKEND + # Ensure backend has been properly set up for this call + if backend._current_layer is not None: + page_size = k_cache.shape[1] if k_cache is not None else 1 + scheduler_metadata = _compute_scheduler_metadata( + backend=backend, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + cache_seqlens=cache_seqlens, + max_seqlen_q=max_seqlen_q, + page_size=page_size, + causal=causal, + window_size=window_size, + num_splits=num_splits, + ) + + return mate_flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + ) + + +class MusaFlashAttentionBackend(FlashAttentionBackend): + def __init__(self, model_runner: ModelRunner, **kwargs): + super().__init__(model_runner, **kwargs) + self.num_hidden_layers = model_runner.model_config.num_hidden_layers + self.first_k_dense_replace = model_runner.model_config.first_k_dense_replace + self.full_attention_interval = model_runner.model_config.full_attention_interval + + # State for current attention call (simplified from thread‑local context) + self._current_layer: Optional[RadixAttention] = None + self._current_prefix: str = "" + self._current_max_seqlen_k: int = 0 + self._current_can_run_tbo: bool = False + + # Register this backend as the global current instance for the wrapper + global _CURRENT_BACKEND + _CURRENT_BACKEND = self + + def _set_current_state( + self, layer: RadixAttention, prefix: str, max_seqlen_k: int, can_run_tbo: bool + ): + """Set the dynamic state for the upcoming flash attention call.""" + self._current_layer = layer + self._current_prefix = prefix + self._current_max_seqlen_k = max_seqlen_k + self._current_can_run_tbo = can_run_tbo + + def init_forward_metadata(self, forward_batch: ForwardBatch): + super().init_forward_metadata(forward_batch) + metadata = self.forward_metadata + if not hasattr(metadata, "extend_with_prefix"): + metadata.extend_with_prefix = False + + if forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed( + include_draft_extend_v2=True + ): + metadata.extend_with_prefix = any(forward_batch.extend_prefix_lens_cpu) + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = 
None, + ): + if k is not None: + assert v is not None + + is_cp_mode = ( + forward_batch.forward_mode.is_context_parallel_extend() + and forward_batch.attn_cp_metadata is not None + and self.attn_cp_size > 1 + ) + + if save_kv_cache and not is_cp_mode: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + if is_cp_mode: + cp_allgather_and_save_kv_cache( + forward_batch, layer, k, v, self.attn_cp_size + ) + + metadata = self.forward_metadata + + is_swa_layer = ( + layer.sliding_window_size is not None and layer.sliding_window_size > -1 + ) + window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1) + k_descale, v_descale = None, None + if ( + self.kv_cache_dtype_str != "auto" + and layer.head_dim <= 256 + and self.fa_impl_ver != 4 + ): + if layer.k_scale is not None: + descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + k_descale = layer.k_scale.expand(descale_shape) + v_descale = layer.v_scale.expand(descale_shape) + q = q.to(self.kv_cache_dtype) + q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None + k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None + causal = True + if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY: + causal = False + + use_local_attn = ( + self.has_local_attention + and self.attention_chunk_size is not None + and metadata.local_attn_metadata is not None + and (hasattr(layer, "use_irope") and layer.use_irope) + ) + + use_cascade_attn = ( + forward_batch.forward_mode.is_target_verify() + and self.topk > 1 + and not is_swa_layer + ) + + kwargs = {} + if sinks is not None: + kwargs["sinks"] = sinks + + if use_local_attn: + local_metadata = metadata.local_attn_metadata + page_table = local_metadata.local_block_table + cu_seqlens_q = local_metadata.local_query_start_loc + cache_seqlens = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + elif is_swa_layer and metadata.swa_spec_metadata is not None: + swa_spec_metadata = metadata.swa_spec_metadata + page_table = swa_spec_metadata.page_table + cu_seqlens_q = swa_spec_metadata.cu_seqlens_q + cache_seqlens = swa_spec_metadata.cache_seqlens_int32 + max_seqlen_q = swa_spec_metadata.max_seq_len_q + cu_seqlens_k = swa_spec_metadata.cu_seqlens_k + max_seqlen_k = swa_spec_metadata.max_seq_len_k + else: + page_table = metadata.page_table + if is_swa_layer and self.use_sliding_window_kv_pool: + if metadata.swa_page_table is not None: + page_table = metadata.swa_page_table + else: + page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) + cu_seqlens_q = metadata.cu_seqlens_q + cache_seqlens = metadata.cache_seqlens_int32 + max_seqlen_q = metadata.max_seq_len_q + cu_seqlens_k = metadata.cu_seqlens_k + max_seqlen_k = metadata.max_seq_len_k + + # Set current state for the flash attention call + self._set_current_state( + layer=layer, + prefix="forward_extend", + max_seqlen_k=max_seqlen_k, + can_run_tbo=forward_batch.can_run_tbo, + ) + if not self.use_mla: + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, 
layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + if layer.is_cross_attention: + page_table = metadata.encoder_page_table + cache_seqlens = metadata.encoder_lens_int32 + cu_seqlens_k = metadata.encoder_cu_seqlens_k + window_size = (-1, -1) + + if ( + forward_batch.forward_mode.is_context_parallel_extend() + and forward_batch.attn_cp_metadata is not None + and self.attn_cp_size > 1 + ): + + def _fa_cp_attn( + q_chunk, cu_seqlens_q_cp, cache_seqlens_cp, max_seqlen_q_cp + ): + return flash_attn_with_kvcache( + q=q_chunk, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=cache_seqlens_cp, + cu_seqlens_q=cu_seqlens_q_cp, + cu_seqlens_k_new=(cu_seqlens_k if not use_local_attn else None), + max_seqlen_q=max_seqlen_q_cp, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, + **kwargs, + ) + + result = cp_attn_forward_extend( + forward_batch, + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + self.device, + _fa_cp_attn, + ) + elif ( + metadata.extend_with_prefix + or forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): + result = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, + **kwargs, + ) + + if use_cascade_attn: + # Update state for the second call + self._current_prefix = "forward_extend_use_cascade_attn" + self._current_max_seqlen_k = ( + self.forward_metadata_spec_decode_expand.max_seq_len_k + ) + + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q.contiguous().view( + -1, layer.tp_q_head_num, layer.head_dim + ), + k_cache=key_cache.view( + -1, 1, layer.tp_k_head_num, layer.head_dim + ), + v_cache=value_cache.view( + -1, 1, layer.tp_v_head_num, layer.head_dim + ), + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + **kwargs, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + else: + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + 
cu_seqlens_k=metadata.cu_seqlens_q, + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=metadata.max_seq_len_q, + softmax_scale=layer.scaling, + causal=True, + return_softmax_lse=forward_batch.mha_return_lse, + ) + if forward_batch.mha_return_lse: + output, lse, *rest = output + lse = torch.transpose(lse, 0, 1).contiguous() + return ( + output.view(-1, layer.tp_q_head_num * layer.v_head_dim), + lse, + ) + return output.view(-1, layer.tp_q_head_num * layer.v_head_dim) + else: + if ( + forward_batch.attn_attend_prefix_cache is not None + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend(include_v2=True) + ): + if forward_batch.attn_attend_prefix_cache: + assert not get_global_server_args().disable_chunked_prefix_cache + assert forward_batch.prefix_chunk_idx is not None + assert forward_batch.prefix_chunk_cu_seq_lens is not None + assert forward_batch.prefix_chunk_max_seq_lens is not None + + chunk_idx = forward_batch.prefix_chunk_idx + assert chunk_idx >= 0 + + assert forward_batch.mha_return_lse + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx], + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx], + softmax_scale=layer.scaling, + causal=False, + return_softmax_lse=True, + **kwargs, + ) + else: + cu_seqlens_k = ( + metadata.cu_seqlens_q + if not forward_batch.mha_one_shot + else metadata.cu_seqlens_k + ) + max_seqlen_k = ( + metadata.max_seq_len_q + if not forward_batch.mha_one_shot + else metadata.max_seq_len_k + ) + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=layer.scaling, + causal=True, + return_softmax_lse=forward_batch.mha_return_lse, + **kwargs, + ) + if forward_batch.mha_return_lse: + output, lse, *rest = output + lse = torch.transpose(lse, 0, 1).contiguous() + return output, lse + return output + else: + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ).to(q.dtype) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + + result = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + 
softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, + ) + if use_cascade_attn: + self._current_prefix = "forward_extend_use_cascade_attn" + self._current_max_seqlen_k = ( + self.forward_metadata_spec_decode_expand.max_seq_len_k + ) + + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + + metadata = self.forward_metadata + local_attn_metadata = getattr(metadata, "local_attn_metadata", None) + use_local_attn = ( + self.has_local_attention + and self.attention_chunk_size is not None + and local_attn_metadata is not None + and (hasattr(layer, "use_irope") and layer.use_irope) + ) + + use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1 + + is_swa_layer = ( + layer.sliding_window_size is not None and layer.sliding_window_size > -1 + ) + window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1) + + causal = True + if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY: + causal = False + + kwargs = {} + if sinks is not None: + kwargs["sinks"] = sinks + + k_descale, v_descale = None, None + if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256: + if layer.k_scale is not None: + descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + k_descale = layer.k_scale.expand(descale_shape) + v_descale = layer.v_scale.expand(descale_shape) + q = q.to(self.kv_cache_dtype) + q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None + k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None + + # Set current state for the flash attention call + self._set_current_state( + layer=layer, + prefix="forward_decode", + max_seqlen_k=metadata.max_seq_len_k, + can_run_tbo=forward_batch.can_run_tbo, + ) + if not self.use_mla: + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + 
key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + + if layer.is_cross_attention: + o = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=metadata.encoder_page_table, + cache_seqlens=metadata.encoder_lens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.encoder_cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=layer.scaling, + causal=False, + window_size=(-1, -1), + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + num_splits=self.num_splits, + **kwargs, + ) + elif use_local_attn: + o = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=local_attn_metadata.local_block_table, + cache_seqlens=local_attn_metadata.local_seqused_k, + cu_seqlens_q=local_attn_metadata.local_query_start_loc, + cu_seqlens_k_new=None, + max_seqlen_q=local_attn_metadata.local_max_query_len, + softmax_scale=layer.scaling, + causal=True, + window_size=(-1, -1), + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + num_splits=self.num_splits, + **kwargs, + ) + else: + page_table = metadata.page_table + if is_swa_layer and self.use_sliding_window_kv_pool: + if metadata.swa_page_table is not None: + page_table = metadata.swa_page_table + else: + page_table = ( + self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) + ) + cache_seqlens = metadata.cache_seqlens_int32 + cu_seqlens_k = metadata.cu_seqlens_k + max_seqlen_q = metadata.max_seq_len_q + q_reshaped = q.contiguous().view( + -1, layer.tp_q_head_num, layer.head_dim + ) + + result = flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=metadata.cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, + **kwargs, + ) + if use_cascade_attn: + self._current_prefix = "forward_decode_use_cascade_attn" + self._current_max_seqlen_k = ( + self.forward_metadata_spec_decode_expand.max_seq_len_k + ) + + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + **kwargs, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + else: + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to( + q.dtype + ) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = 
kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + max_seqlen_q = metadata.max_seq_len_q + + result = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=metadata.page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + softmax_scale=layer.scaling, + causal=False if use_cascade_attn else causal, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=use_cascade_attn, + num_splits=self.num_splits, + ) + if use_cascade_attn: + self._current_prefix = "forward_decode_use_cascade_attn" + self._current_max_seqlen_k = ( + self.forward_metadata_spec_decode_expand.max_seq_len_k + ) + + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 0a5920575c0f..cb25eb2e896a 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -5,6 +5,9 @@ get_linear_attn_config, import_backend_class, ) +from sglang.srt.utils import get_device_capability, is_musa + +_is_musa = is_musa() logger = logging.getLogger(__name__) @@ -130,17 +133,28 @@ def create_flashmla_backend(runner): @register_attention_backend("fa3") def create_flashattention_v3_backend(runner): - import torch - assert ( - torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend - ) or torch.cuda.get_device_capability()[0] == 9, ( - "FlashAttention v3 Backend requires SM>=80 and SM<=90. " - "Please use `--attention-backend flashinfer`." - ) - from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend + major, minor = get_device_capability() + if not _is_musa: + assert (major == 8 and not runner.use_mla_backend) or major == 9, ( + "FlashAttention v3 Backend requires SM>=80 and SM<=90. " + "Please use `--attention-backend flashinfer`." 
+ ) + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + ) + + return FlashAttentionBackend(runner) + else: + assert major == 3 and minor >= 1, ( + "FlashAttention v3 Backend requires MP>=31. " + "Please use `--attention-backend triton`." + ) + from sglang.srt.hardware_backend.musa.attention import ( + MusaFlashAttentionBackend, + ) - return FlashAttentionBackend(runner) + return MusaFlashAttentionBackend(runner) @register_attention_backend("fa4") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9ec274d2a75c..b56c19045bfe 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -53,6 +53,7 @@ is_hip, is_hopper_with_cuda_12_3, is_mps, + is_musa, is_no_spec_infer_or_topk_one, is_npu, is_remote_url, @@ -2591,7 +2592,10 @@ def _handle_kv4_compatibility(self): def _handle_page_size(self): if self.page_size is None: - self.page_size = 1 + if not is_musa(): + self.page_size = 1 + else: + self.page_size = 64 def _handle_amd_specifics(self): if is_hip():
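
Note for context (not part of the patch): the core trick in the new MUSA backend above is that `get_scheduler_metadata` is only recomputed for the first attention layer of a pipeline stage per forward pass, and the result is cached by a call-site prefix (e.g. "forward_decode") and reused by the remaining layers, unless SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA is set; the diff also defaults `page_size` to 64 on MUSA devices. Below is a minimal, standalone sketch of that caching pattern under assumed, hypothetical names (it is illustrative only and does not reuse the identifiers from the diff):

    # Hypothetical sketch of prefix-keyed metadata caching, recomputed only
    # for the first attention layer of a stage (or when forced via env var).
    import os
    import threading
    from typing import Any, Callable, Dict

    _cache: Dict[str, Any] = {}
    _lock = threading.Lock()

    def get_or_update_metadata(
        prefix: str,
        layer_id: int,
        first_attn_layer_id: int,
        compute_fn: Callable[[], Any],
    ) -> Any:
        force = os.environ.get("SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA", "0") == "1"
        # Recompute for the stage's first attention layer, on force, or on a cache miss.
        should_update = force or layer_id <= first_attn_layer_id
        with _lock:
            if should_update or prefix not in _cache:
                _cache[prefix] = compute_fn()
            return _cache[prefix]

    if __name__ == "__main__":
        calls = []
        def expensive_metadata():
            calls.append(1)
            return object()
        # Layer 0 computes the metadata; layers 1..3 reuse the cached value.
        for layer_id in range(4):
            get_or_update_metadata("forward_decode", layer_id, 0, expensive_metadata)
        assert len(calls) == 1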