diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index 4348f673dc8c..8b0b40519b21 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -116,6 +116,9 @@ srt_musa = [ "torch_musa", "torchada>=0.1.25", "mthreads-ml-py", + "mate", + "mate-deep_gemm", + "mate-flash-attention", "numpy<2.0", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 7142f56afc06..acc0c3eb1720 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -552,6 +552,12 @@ def _derive_model_shapes(self): self.num_key_value_heads = getattr( self.hf_text_config, "num_key_value_heads", None ) + self.first_k_dense_replace = getattr( + self.hf_text_config, "first_k_dense_replace", None + ) + self.full_attention_interval = getattr( + self.hf_text_config, "full_attention_interval", None + ) # for Dbrx and MPT models if self.hf_config.model_type in ["dbrx", "mpt"]: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index adc21301ba15..f8e3b0944206 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -327,6 +327,9 @@ class Envs: SGLANG_USE_AG_AFTER_QLORA = EnvBool(False) SGLANG_NPU_FUSED_MOE_MODE = EnvInt(1) + # MTHREADS & MUSA + SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA = EnvBool(False) + # Quantization SGLANG_INT4_WEIGHT = EnvBool(False) SGLANG_CPU_QUANTIZATION = EnvBool(False) diff --git a/python/sglang/srt/hardware_backend/musa/__init__.py b/python/sglang/srt/hardware_backend/musa/__init__.py new file mode 100644 index 000000000000..be2fb35b43db --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/__init__.py @@ -0,0 +1 @@ +# MUSA (Moore Threads GPU) hardware backend diff --git a/python/sglang/srt/hardware_backend/musa/attention/__init__.py b/python/sglang/srt/hardware_backend/musa/attention/__init__.py new file mode 100644 index 000000000000..5b522a6408c6 --- /dev/null +++ 
b/python/sglang/srt/hardware_backend/musa/attention/__init__.py @@ -0,0 +1,14 @@ +# MUSA attention backend +from sglang.srt.hardware_backend.musa.attention.flash_attention import ( + FlashAttentionContext, + FlashAttentionContextManager, + flash_attn_with_kvcache, + update_flash_attention_context, +) + +__all__ = [ + "FlashAttentionContext", + "FlashAttentionContextManager", + "update_flash_attention_context", + "flash_attn_with_kvcache", +] diff --git a/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py b/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py new file mode 100644 index 000000000000..7a2d3c63522a --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py @@ -0,0 +1,254 @@ +"""MUSA Flash Attention wrapper with automatic scheduler_metadata injection. + +This module provides a wrapper for mate's flash_attn_with_kvcache that automatically +computes and injects scheduler_metadata based on the current FlashAttentionContext. +""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Tuple, Union + +import torch +from flash_attn import flash_attn_with_kvcache as _mate_flash_attn_with_kvcache +from flash_attn import get_scheduler_metadata + +from sglang.srt.distributed import get_pp_group, get_pp_indices +from sglang.srt.environ import envs + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + +# Global workspace buffer for MLA +_MATE_MLA_WORKSPACE_BUFFER: torch.Tensor | None = None + +# Cache for non-MLA scheduler metadata by prefix +_MATE_NO_MLA_SCHEDULER_METADATA_DICT: dict = {} + +# Thread-local storage for flash attention context +_flash_attention_context = threading.local() + + +@dataclass +class FlashAttentionContext: + """Context for MUSA flash attention calls. + + This context stores the information needed to compute scheduler_metadata + for mate's flash_attn_with_kvcache. 
+ """ + + # Static config (set once per backend) + device: torch.device + use_mla: bool + num_hidden_layers: int + first_k_dense_replace: int + full_attention_interval: Optional[int] + + # Dynamic state (set per forward call) + layer: "RadixAttention" + prefix: str + max_seqlen_k: int + can_run_tbo: bool + + +class FlashAttentionContextManager: + """Context manager for MUSA flash attention. + + Automatically sets and clears the flash attention context on entry/exit. + This ensures cleanup happens even on early returns or exceptions. + + Usage: + with FlashAttentionContextManager(ctx): + # flash_attn_with_kvcache calls will auto-inject scheduler_metadata + ... + """ + + def __init__(self, ctx: FlashAttentionContext): + self.ctx = ctx + + def __enter__(self) -> "FlashAttentionContextManager": + _flash_attention_context.current = self.ctx + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + _flash_attention_context.current = None + return None # Don't suppress exceptions + + +def get_flash_attention_context() -> Optional[FlashAttentionContext]: + """Get the current flash attention context.""" + return getattr(_flash_attention_context, "current", None) + + +def update_flash_attention_context( + prefix: Optional[str] = None, + max_seqlen_k: Optional[int] = None, +) -> None: + """Update specific fields of the current flash attention context. + + This is useful for cascade attention where prefix and max_seqlen_k change. 
+ """ + ctx = get_flash_attention_context() + if ctx is not None: + if prefix is not None: + ctx.prefix = prefix + if max_seqlen_k is not None: + ctx.max_seqlen_k = max_seqlen_k + + +def _compute_scheduler_metadata( + ctx: FlashAttentionContext, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k_new: Optional[torch.Tensor], + cache_seqlens: torch.Tensor, + max_seqlen_q: int, + page_size: int, + causal: bool, + window_size: Tuple[int, int], + num_splits: int, +) -> Tuple[torch.Tensor, bool] | torch.Tensor: + """Compute scheduler metadata based on context.""" + global _MATE_MLA_WORKSPACE_BUFFER, _MATE_NO_MLA_SCHEDULER_METADATA_DICT + + layer = ctx.layer + current_layer_id = layer.layer_id + batch_size = cu_seqlens_q.shape[-1] - 1 + + # Determine if scheduler metadata should be updated + should_update = True + pp_group = get_pp_group() + pp_rank = pp_group.rank_in_group + start_layer_id, _ = get_pp_indices( + ctx.num_hidden_layers, pp_group.rank_in_group, pp_group.world_size + ) + if ctx.can_run_tbo and pp_rank == 0: + start_layer_id += ( + ctx.first_k_dense_replace if ctx.first_k_dense_replace is not None else 0 + ) + + if ctx.full_attention_interval is not None: + start_layer_id += ctx.full_attention_interval - 1 + + if current_layer_id > start_layer_id: + should_update = False + + if envs.SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA.get(): + should_update = True + + if ctx.use_mla: + if _MATE_MLA_WORKSPACE_BUFFER is None: + _MATE_MLA_WORKSPACE_BUFFER = torch.empty( + 128 * 1024 * 1024, device=ctx.device, dtype=torch.uint8 + ) + return (_MATE_MLA_WORKSPACE_BUFFER, not should_update) + else: + if should_update or ctx.prefix not in _MATE_NO_MLA_SCHEDULER_METADATA_DICT: + _MATE_NO_MLA_SCHEDULER_METADATA_DICT[ctx.prefix] = get_scheduler_metadata( + batch_size=batch_size, + num_heads_q=layer.tp_q_head_num, + num_heads_kv=layer.tp_k_head_num, + headdim=layer.qk_head_dim, + headdim_v=layer.v_head_dim, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + # XXX (MUSA): 
cu_seqlens_k_new is not supported on MATE + # cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + page_size=page_size, + causal=causal, + window_size=window_size, + num_splits=num_splits, + ) + return _MATE_NO_MLA_SCHEDULER_METADATA_DICT[ctx.prefix] + + +def flash_attn_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 0, + pack_gqa=None, + sm_margin: int = 0, + return_softmax_lse: bool = False, + **kwargs, +): + """MUSA flash_attn_with_kvcache wrapper that auto-injects scheduler_metadata. + + This wrapper retrieves the current FlashAttentionContext and computes + scheduler_metadata automatically, so call sites don't need to be modified. 
+ """ + # Get context and compute scheduler_metadata if not provided + if scheduler_metadata is None: + ctx = get_flash_attention_context() + if ctx is not None: + page_size = k_cache.shape[1] if k_cache is not None else 1 + scheduler_metadata = _compute_scheduler_metadata( + ctx=ctx, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + cache_seqlens=cache_seqlens, + max_seqlen_q=max_seqlen_q, + page_size=page_size, + causal=causal, + window_size=window_size, + num_splits=num_splits, + ) + + return _mate_flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + # XXX (MUSA): cu_seqlens_k_new is not supported on MATE + # cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + ) diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 2353c15993fd..185d474a7ad7 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -1,6 +1,10 @@ import logging from typing import TYPE_CHECKING +from sglang.srt.utils import get_device_capability, is_musa + +_is_musa = is_musa() + logger = logging.getLogger(__name__) @@ -125,14 +129,19 @@ def create_flashmla_backend(runner): @register_attention_backend("fa3") def create_flashattention_v3_backend(runner): - 
import torch - assert ( - torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend - ) or torch.cuda.get_device_capability()[0] == 9, ( - "FlashAttention v3 Backend requires SM>=80 and SM<=90. " - "Please use `--attention-backend flashinfer`." - ) + major, minor = get_device_capability() + if not _is_musa: + assert (major == 8 and not runner.use_mla_backend) or major == 9, ( + "FlashAttention v3 Backend requires SM>=80 and SM<=90. " + "Please use `--attention-backend flashinfer`." + ) + else: + assert (major, minor) >= (3, 1), ( + "FlashAttention v3 Backend requires MP>=31. " + "Please use `--attention-backend triton`." + ) + from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend return FlashAttentionBackend(runner) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 7262e503458f..4143c98395b4 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -19,16 +20,35 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import get_compiler_backend +from sglang.srt.utils import get_compiler_backend, is_musa if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner from sgl_kernel import merge_state_v2 -from sgl_kernel.flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3 -from sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_with_kvcache_fa3 +_is_musa = is_musa() +if not _is_musa: + from 
sgl_kernel.flash_attn import ( + flash_attn_varlen_func as flash_attn_varlen_func_fa3, + ) + from sgl_kernel.flash_attn import ( + flash_attn_with_kvcache as flash_attn_with_kvcache_fa3, + ) +else: + from flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3 + + from sglang.srt.hardware_backend.musa.attention import ( + FlashAttentionContext, + FlashAttentionContextManager, + ) + from sglang.srt.hardware_backend.musa.attention import ( + flash_attn_with_kvcache as flash_attn_with_kvcache_fa3, + ) + from sglang.srt.hardware_backend.musa.attention import ( + update_flash_attention_context, + ) flash_attn_varlen_func = flash_attn_varlen_func_fa3 flash_attn_with_kvcache = flash_attn_with_kvcache_fa3 @@ -64,6 +84,8 @@ class FlashAttentionMetadata: page_table: torch.Tensor = None # Page table for Sliding Window Attention swa_page_table: torch.Tensor = None + # Extend from cached prefix tokens + extend_with_prefix: bool = False # Encoder metadata # Cumulative sequence lengths for encoder key @@ -402,6 +424,33 @@ def __init__( else 0 ) + if _is_musa: + self.num_hidden_layers = model_runner.model_config.num_hidden_layers + self.first_k_dense_replace = model_runner.model_config.first_k_dense_replace + self.full_attention_interval = ( + model_runner.model_config.full_attention_interval + ) + + # TODO: This function is currently designed to create a context for MUSA devices + # to automatically inject scheduler metadata. Refactoring may be required + # if other devices use it in the future. 
+ def get_flash_attention_context(self, layer, prefix, max_seqlen_k, can_run_tbo): + if not _is_musa: + return nullcontext() + return FlashAttentionContextManager( + FlashAttentionContext( + device=self.device, + use_mla=self.use_mla, + num_hidden_layers=self.num_hidden_layers, + first_k_dense_replace=self.first_k_dense_replace, + full_attention_interval=self.full_attention_interval, + layer=layer, + prefix=prefix, + max_seqlen_k=max_seqlen_k, + can_run_tbo=can_run_tbo, + ) + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata hence all layers in the forward pass can reuse it.""" metadata = FlashAttentionMetadata() @@ -635,9 +684,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - if any( - forward_batch.extend_prefix_lens_cpu - ) or forward_batch.forward_mode.is_draft_extend(include_v2=True): + metadata.extend_with_prefix = any(forward_batch.extend_prefix_lens_cpu) + if ( + metadata.extend_with_prefix + or forward_batch.forward_mode.is_draft_extend(include_v2=True) + ): extend_seq_lens = forward_batch.extend_seq_lens metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) metadata.cu_seqlens_q = torch.nn.functional.pad( @@ -855,6 +906,7 @@ def forward_extend( cu_seqlens_q = local_metadata.local_query_start_loc cache_seqlens = local_metadata.local_seqused_k max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len elif is_swa_layer and metadata.swa_spec_metadata is not None: swa_spec_metadata = metadata.swa_spec_metadata page_table = swa_spec_metadata.page_table @@ -862,6 +914,7 @@ def forward_extend( cache_seqlens = swa_spec_metadata.cache_seqlens_int32 max_seqlen_q = swa_spec_metadata.max_seq_len_q cu_seqlens_k = swa_spec_metadata.cu_seqlens_k + max_seqlen_k = swa_spec_metadata.max_seq_len_k else: page_table = metadata.page_table if is_swa_layer and self.use_sliding_window_kv_pool: @@ -875,7 
+928,60 @@ def forward_extend( cache_seqlens = metadata.cache_seqlens_int32 max_seqlen_q = metadata.max_seq_len_q cu_seqlens_k = metadata.cu_seqlens_k + max_seqlen_k = metadata.max_seq_len_k + + with self.get_flash_attention_context( + layer, "forward_extend", max_seqlen_k, forward_batch.can_run_tbo + ): + return self._forward_extend_impl( + q=q, + k=k, + v=v, + layer=layer, + forward_batch=forward_batch, + metadata=metadata, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + window_size=window_size, + k_descale=k_descale, + v_descale=v_descale, + causal=causal, + use_cascade_attn=use_cascade_attn, + use_local_attn=use_local_attn, + q_rope=q_rope, + k_rope=k_rope, + **kwargs, + ) + def _forward_extend_impl( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: "RadixAttention", + forward_batch: ForwardBatch, + metadata: FlashAttentionMetadata, + page_table, + cache_seqlens, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + window_size, + k_descale, + v_descale, + causal, + use_cascade_attn, + use_local_attn, + q_rope, + k_rope, + **kwargs, + ): + """Internal implementation of forward_extend, wrapped by context manager.""" # Use Flash Attention for prefill if not self.use_mla: # Do multi-head attention @@ -930,7 +1036,12 @@ def _fa_cp_attn( self.device, _fa_cp_attn, ) - else: + elif ( + not _is_musa + or metadata.extend_with_prefix + or forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): result = flash_attn_with_kvcache( q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), k_cache=key_cache, @@ -951,40 +1062,71 @@ def _fa_cp_attn( **kwargs, ) - if use_cascade_attn: - o, softmax_lse, *rest = result - o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( - q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - # Here 
metadata_expand.page_table is not divided with page_size. - # This is because we loose the fine control of what token to attend, - # but has to attend to some block completely. - k_cache=key_cache.view(-1, 1, layer.tp_k_head_num, layer.head_dim), - v_cache=value_cache.view( - -1, 1, layer.tp_v_head_num, layer.head_dim - ), - page_table=self.forward_metadata_spec_decode_expand.page_table, - cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, - cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, - cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, - max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + if use_cascade_attn: + if _is_musa: + # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters + update_flash_attention_context( + prefix="forward_extend_use_cascade_attn", + max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, + ) + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = ( + flash_attn_with_kvcache( + q=q.contiguous().view( + -1, layer.tp_q_head_num, layer.head_dim + ), + # Here metadata_expand.page_table is not divided by page_size. + # This is because we lose the fine control of which tokens to attend to, + # and have to attend to some blocks completely. 
+ k_cache=key_cache.view( + -1, 1, layer.tp_k_head_num, layer.head_dim + ), + v_cache=value_cache.view( + -1, 1, layer.tp_v_head_num, layer.head_dim + ), + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, + softmax_scale=layer.scaling, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + **kwargs, + ) + ) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result + else: + # MATE's MHA for extend part of sequence without attending prefix kv cache + output = flash_attn_varlen_func( + q=q.view(-1, layer.tp_q_head_num, layer.head_dim), + k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k=metadata.cu_seqlens_q, + max_seqlen_q=metadata.max_seq_len_q, + max_seqlen_k=metadata.max_seq_len_q, softmax_scale=layer.scaling, - causal=False, - window_size=window_size, - softcap=layer.logit_cap, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=True, - num_splits=self.num_splits, - **kwargs, - ) - o, _ = merge_state_v2_wrapper( - o, - softmax_lse.T.contiguous(), - o_expand, - softmax_lse_expand.T.contiguous(), + causal=True, + return_softmax_lse=forward_batch.mha_return_lse, ) - else: - o = result + if forward_batch.mha_return_lse: + output, lse, *rest = output + lse = torch.transpose(lse, 0, 1).contiguous() + return output.view(-1, layer.tp_q_head_num * layer.v_head_dim), lse + return output.view(-1, layer.tp_q_head_num * 
layer.v_head_dim) else: if ( forward_batch.attn_attend_prefix_cache is not None @@ -1189,6 +1331,43 @@ def forward_decode( if sinks is not None: kwargs["sinks"] = sinks + with self.get_flash_attention_context( + layer, "forward_decode", metadata.max_seq_len_k, forward_batch.can_run_tbo + ): + return self._forward_decode_impl( + q=q, + layer=layer, + forward_batch=forward_batch, + metadata=metadata, + is_swa_layer=is_swa_layer, + window_size=window_size, + causal=causal, + use_cascade_attn=use_cascade_attn, + use_local_attn=use_local_attn, + local_attn_metadata=local_attn_metadata, + q_rope=q_rope, + k_rope=k_rope, + **kwargs, + ) + + def _forward_decode_impl( + self, + q: torch.Tensor, + layer: "RadixAttention", + forward_batch: ForwardBatch, + metadata: FlashAttentionMetadata, + is_swa_layer, + window_size, + causal, + use_cascade_attn, + use_local_attn, + local_attn_metadata, + q_rope, + k_rope, + **kwargs, + ): + """Internal implementation of forward_decode, wrapped by context manager.""" + k_descale, v_descale = None, None flash_attn_with_kvcache_base = flash_attn_with_kvcache_fa3 flash_attn_with_kvcache = ( @@ -1300,6 +1479,12 @@ def forward_decode( **kwargs, ) if use_cascade_attn: + if _is_musa: + # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters + update_flash_attention_context( + prefix="forward_decode_use_cascade_attn", + max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, + ) o, softmax_lse, *rest = result o_expand, softmax_lse_expand, *rest_expand = ( flash_attn_with_kvcache( @@ -1377,6 +1562,12 @@ def forward_decode( num_splits=self.num_splits, ) if use_cascade_attn: + if _is_musa: + # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters + update_flash_attention_context( + prefix="forward_decode_use_cascade_attn", + max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, + ) o, softmax_lse, *rest = result o_expand, softmax_lse_expand, *rest_expand = 
flash_attn_with_kvcache( q=q_rope, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f2dfa1c456c6..c9f846c693fa 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -52,6 +52,7 @@ is_hip, is_hopper_with_cuda_12_3, is_mps, + is_musa, is_no_spec_infer_or_topk_one, is_npu, is_remote_url, @@ -2405,6 +2406,13 @@ def _handle_attention_backend_compatibility(self): if model_config.context_len > 8192: self.mem_fraction_static *= 0.85 + # MUSA platforms compatible backends + if is_musa() and self.attention_backend == "fa3": + logger.warning( + "FA3 attention backend on MUSA ignores any user-provided page_size and enforces a fixed value of 64." + ) + self.page_size = 64 + # Other platforms backends if ( self.attention_backend == "intel_amx"