diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index 1d9b664bfd3c..ba19cbc13fa3 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -116,9 +116,6 @@ srt_musa = [ "torch_musa", "torchada>=0.1.45", "mthreads-ml-py", - "mate", - "mate-deep_gemm", - "mate-flash-attention", "numpy<2.0", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 966e7d620ecc..a7f66c8443f5 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -563,12 +563,6 @@ def _derive_model_shapes(self): self.num_key_value_heads = getattr( self.hf_text_config, "num_key_value_heads", None ) - self.first_k_dense_replace = getattr( - self.hf_text_config, "first_k_dense_replace", None - ) - self.full_attention_interval = getattr( - self.hf_text_config, "full_attention_interval", None - ) # for Dbrx and MPT models if self.hf_config.model_type in ["dbrx", "mpt"]: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 7d375459ab15..4d0964efd44f 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -331,9 +331,6 @@ class Envs: SGLANG_USE_AG_AFTER_QLORA = EnvBool(False) SGLANG_NPU_FUSED_MOE_MODE = EnvInt(1) - # MTHREADS & MUSA - SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA = EnvBool(False) - # Quantization SGLANG_INT4_WEIGHT = EnvBool(False) SGLANG_CPU_QUANTIZATION = EnvBool(False) diff --git a/python/sglang/srt/hardware_backend/musa/__init__.py b/python/sglang/srt/hardware_backend/musa/__init__.py deleted file mode 100644 index be2fb35b43db..000000000000 --- a/python/sglang/srt/hardware_backend/musa/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# MUSA (Moore Threads GPU) hardware backend diff --git a/python/sglang/srt/hardware_backend/musa/attention/__init__.py b/python/sglang/srt/hardware_backend/musa/attention/__init__.py deleted file mode 100644 index 5b522a6408c6..000000000000 --- a/python/sglang/srt/hardware_backend/musa/attention/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# MUSA attention backend -from sglang.srt.hardware_backend.musa.attention.flash_attention import ( - FlashAttentionContext, - FlashAttentionContextManager, - flash_attn_with_kvcache, - update_flash_attention_context, -) - -__all__ = [ - "FlashAttentionContext", - "FlashAttentionContextManager", - "update_flash_attention_context", - "flash_attn_with_kvcache", -] diff --git a/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py b/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py deleted file mode 100644 index 7a2d3c63522a..000000000000 --- a/python/sglang/srt/hardware_backend/musa/attention/flash_attention.py +++ /dev/null @@ -1,254 +0,0 @@ -"""MUSA Flash Attention wrapper with automatic scheduler_metadata injection. - -This module provides a wrapper for mate's flash_attn_with_kvcache that automatically -computes and injects scheduler_metadata based on the current FlashAttentionContext. -""" - -from __future__ import annotations - -import threading -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Tuple, Union - -import torch -from flash_attn import flash_attn_with_kvcache as _mate_flash_attn_with_kvcache -from flash_attn import get_scheduler_metadata - -from sglang.srt.distributed import get_pp_group, get_pp_indices -from sglang.srt.environ import envs - -if TYPE_CHECKING: - from sglang.srt.layers.radix_attention import RadixAttention - -# Global workspace buffer for MLA -_MATE_MLA_WORKSPACE_BUFFER: torch.Tensor | None = None - -# Cache for non-MLA scheduler metadata by prefix -_MATE_NO_MLA_SCHEDULER_METADATA_DICT: dict = {} - -# Thread-local storage for flash attention context -_flash_attention_context = threading.local() - - -@dataclass -class FlashAttentionContext: - """Context for MUSA flash attention calls. - - This context stores the information needed to compute scheduler_metadata - for mate's flash_attn_with_kvcache. - """ - - # Static config (set once per backend) - device: torch.device - use_mla: bool - num_hidden_layers: int - first_k_dense_replace: int - full_attention_interval: Optional[int] - - # Dynamic state (set per forward call) - layer: "RadixAttention" - prefix: str - max_seqlen_k: int - can_run_tbo: bool - - -class FlashAttentionContextManager: - """Context manager for MUSA flash attention. - - Automatically sets and clears the flash attention context on entry/exit. - This ensures cleanup happens even on early returns or exceptions. - - Usage: - with FlashAttentionContextManager(ctx): - # flash_attn_with_kvcache calls will auto-inject scheduler_metadata - ... - """ - - def __init__(self, ctx: FlashAttentionContext): - self.ctx = ctx - - def __enter__(self) -> "FlashAttentionContextManager": - _flash_attention_context.current = self.ctx - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - _flash_attention_context.current = None - return None # Don't suppress exceptions - - -def get_flash_attention_context() -> Optional[FlashAttentionContext]: - """Get the current flash attention context.""" - return getattr(_flash_attention_context, "current", None) - - -def update_flash_attention_context( - prefix: Optional[str] = None, - max_seqlen_k: Optional[int] = None, -) -> None: - """Update specific fields of the current flash attention context. - - This is useful for cascade attention where prefix and max_seqlen_k change. - """ - ctx = get_flash_attention_context() - if ctx is not None: - if prefix is not None: - ctx.prefix = prefix - if max_seqlen_k is not None: - ctx.max_seqlen_k = max_seqlen_k - - -def _compute_scheduler_metadata( - ctx: FlashAttentionContext, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k_new: Optional[torch.Tensor], - cache_seqlens: torch.Tensor, - max_seqlen_q: int, - page_size: int, - causal: bool, - window_size: Tuple[int, int], - num_splits: int, -) -> Tuple[torch.Tensor, bool] | torch.Tensor: - """Compute scheduler metadata based on context.""" - global _MATE_MLA_WORKSPACE_BUFFER, _MATE_NO_MLA_SCHEDULER_METADATA_DICT - - layer = ctx.layer - current_layer_id = layer.layer_id - batch_size = cu_seqlens_q.shape[-1] - 1 - - # Determine if scheduler metadata should be updated - should_update = True - pp_group = get_pp_group() - pp_rank = pp_group.rank_in_group - start_layer_id, _ = get_pp_indices( - ctx.num_hidden_layers, pp_group.rank_in_group, pp_group.world_size - ) - if ctx.can_run_tbo and pp_rank == 0: - start_layer_id += ( - ctx.first_k_dense_replace if ctx.first_k_dense_replace is not None else 0 - ) - - if ctx.full_attention_interval is not None: - start_layer_id += ctx.full_attention_interval - 1 - - if current_layer_id > start_layer_id: - should_update = False - - if envs.SGLANG_MUSA_FA3_FORCE_UPDATE_METADATA.get(): - should_update = True - - if ctx.use_mla: - if _MATE_MLA_WORKSPACE_BUFFER is None: - _MATE_MLA_WORKSPACE_BUFFER = torch.empty( - 128 * 1024 * 1024, device=ctx.device, dtype=torch.uint8 - ) - return (_MATE_MLA_WORKSPACE_BUFFER, not should_update) - else: - if should_update or ctx.prefix not in _MATE_NO_MLA_SCHEDULER_METADATA_DICT: - _MATE_NO_MLA_SCHEDULER_METADATA_DICT[ctx.prefix] = get_scheduler_metadata( - batch_size=batch_size, - num_heads_q=layer.tp_q_head_num, - num_heads_kv=layer.tp_k_head_num, - headdim=layer.qk_head_dim, - headdim_v=layer.v_head_dim, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - # XXX (MUSA): cu_seqlens_k_new is not supported on MATE - # cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - page_size=page_size, - causal=causal, - window_size=window_size, - num_splits=num_splits, - ) - return _MATE_NO_MLA_SCHEDULER_METADATA_DICT[ctx.prefix] - - -def flash_attn_with_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k: Optional[torch.Tensor] = None, - v: Optional[torch.Tensor] = None, - qv: Optional[torch.Tensor] = None, - rotary_cos: Optional[torch.Tensor] = None, - rotary_sin: Optional[torch.Tensor] = None, - cache_seqlens: Optional[Union[int, torch.Tensor]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - rotary_seqlens: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - attention_chunk: int = 0, - softcap: float = 0.0, - rotary_interleaved: bool = True, - scheduler_metadata: Optional[torch.Tensor] = None, - num_splits: int = 0, - pack_gqa=None, - sm_margin: int = 0, - return_softmax_lse: bool = False, - **kwargs, -): - """MUSA flash_attn_with_kvcache wrapper that auto-injects scheduler_metadata. - - This wrapper retrieves the current FlashAttentionContext and computes - scheduler_metadata automatically, so call sites don't need to be modified. - """ - # Get context and compute scheduler_metadata if not provided - if scheduler_metadata is None: - ctx = get_flash_attention_context() - if ctx is not None: - page_size = k_cache.shape[1] if k_cache is not None else 1 - scheduler_metadata = _compute_scheduler_metadata( - ctx=ctx, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - cache_seqlens=cache_seqlens, - max_seqlen_q=max_seqlen_q, - page_size=page_size, - causal=causal, - window_size=window_size, - num_splits=num_splits, - ) - - return _mate_flash_attn_with_kvcache( - q=q, - k_cache=k_cache, - v_cache=v_cache, - k=k, - v=v, - qv=qv, - rotary_cos=rotary_cos, - rotary_sin=rotary_sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - # XXX (MUSA): cu_seqlens_k_new is not supported on MATE - # cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - rotary_seqlens=rotary_seqlens, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - attention_chunk=attention_chunk, - softcap=softcap, - rotary_interleaved=rotary_interleaved, - scheduler_metadata=scheduler_metadata, - num_splits=num_splits, - pack_gqa=pack_gqa, - sm_margin=sm_margin, - return_softmax_lse=return_softmax_lse, - ) diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 185d474a7ad7..2353c15993fd 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -1,10 +1,6 @@ import logging from typing import TYPE_CHECKING -from sglang.srt.utils import get_device_capability, is_musa - -_is_musa = is_musa() - logger = logging.getLogger(__name__) @@ -129,19 +125,14 @@ def create_flashmla_backend(runner): @register_attention_backend("fa3") def create_flashattention_v3_backend(runner): + import torch - major, minor = get_device_capability() - if not _is_musa: - assert (major == 8 and not runner.use_mla_backend) or major == 9, ( - "FlashAttention v3 Backend requires SM>=80 and SM<=90. " - "Please use `--attention-backend flashinfer`." - ) - else: - assert major >= 3 and minor >= 1, ( - "FlashAttention v3 Backend requires MP>=31. " - "Please use `--attention-backend triton`." - ) - + assert ( + torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend + ) or torch.cuda.get_device_capability()[0] == 9, ( + "FlashAttention v3 Backend requires SM>=80 and SM<=90. " + "Please use `--attention-backend flashinfer`." + ) from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend return FlashAttentionBackend(runner) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 4143c98395b4..7262e503458f 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -20,35 +19,16 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import get_compiler_backend, is_musa +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner from sgl_kernel import merge_state_v2 +from sgl_kernel.flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3 +from sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_with_kvcache_fa3 -_is_musa = is_musa() -if not _is_musa: - from sgl_kernel.flash_attn import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa3, - ) - from sgl_kernel.flash_attn import ( - flash_attn_with_kvcache as flash_attn_with_kvcache_fa3, - ) -else: - from flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3 - - from sglang.srt.hardware_backend.musa.attention import ( - FlashAttentionContext, - FlashAttentionContextManager, - ) - from sglang.srt.hardware_backend.musa.attention import ( - flash_attn_with_kvcache as flash_attn_with_kvcache_fa3, - ) - from sglang.srt.hardware_backend.musa.attention import ( - update_flash_attention_context, - ) flash_attn_varlen_func = flash_attn_varlen_func_fa3 flash_attn_with_kvcache = flash_attn_with_kvcache_fa3 @@ -84,8 +64,6 @@ class FlashAttentionMetadata: page_table: torch.Tensor = None # Page table for Sliding Window Attention swa_page_table: torch.Tensor = None - # Extend from cached prefix tokens - extend_with_prefix: bool = False # Encoder metadata # Cumulative sequence lengths for encoder key @@ -424,33 +402,6 @@ def __init__( else 0 ) - if _is_musa: - self.num_hidden_layers = model_runner.model_config.num_hidden_layers - self.first_k_dense_replace = model_runner.model_config.first_k_dense_replace - self.full_attention_interval = ( - model_runner.model_config.full_attention_interval - ) - - # TODO: This function is currently designed to create a context for MUSA devices - # to automatically inject scheduler metadata. Refactoring may be required - # if other devices use it in the future. - def get_flash_attention_context(self, layer, prefix, max_seqlen_k, can_run_tbo): - if not _is_musa: - return nullcontext() - return FlashAttentionContextManager( - FlashAttentionContext( - device=self.device, - use_mla=self.use_mla, - num_hidden_layers=self.num_hidden_layers, - first_k_dense_replace=self.first_k_dense_replace, - full_attention_interval=self.full_attention_interval, - layer=layer, - prefix=prefix, - max_seqlen_k=max_seqlen_k, - can_run_tbo=can_run_tbo, - ) - ) - def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata hence all layers in the forward pass can reuse it.""" metadata = FlashAttentionMetadata() @@ -684,11 +635,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - metadata.extend_with_prefix = any(forward_batch.extend_prefix_lens_cpu) - if ( - metadata.extend_with_prefix - or forward_batch.forward_mode.is_draft_extend(include_v2=True) - ): + if any( + forward_batch.extend_prefix_lens_cpu + ) or forward_batch.forward_mode.is_draft_extend(include_v2=True): extend_seq_lens = forward_batch.extend_seq_lens metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) metadata.cu_seqlens_q = torch.nn.functional.pad( @@ -906,7 +855,6 @@ def forward_extend( cu_seqlens_q = local_metadata.local_query_start_loc cache_seqlens = local_metadata.local_seqused_k max_seqlen_q = local_metadata.local_max_query_len - max_seqlen_k = local_metadata.local_max_seq_len elif is_swa_layer and metadata.swa_spec_metadata is not None: swa_spec_metadata = metadata.swa_spec_metadata page_table = swa_spec_metadata.page_table @@ -914,7 +862,6 @@ def forward_extend( cache_seqlens = swa_spec_metadata.cache_seqlens_int32 max_seqlen_q = swa_spec_metadata.max_seq_len_q cu_seqlens_k = swa_spec_metadata.cu_seqlens_k - max_seqlen_k = swa_spec_metadata.max_seq_len_k else: page_table = metadata.page_table if is_swa_layer and self.use_sliding_window_kv_pool: @@ -928,60 +875,7 @@ def forward_extend( cache_seqlens = metadata.cache_seqlens_int32 max_seqlen_q = metadata.max_seq_len_q cu_seqlens_k = metadata.cu_seqlens_k - max_seqlen_k = metadata.max_seq_len_k - - with self.get_flash_attention_context( - layer, "forward_extend", max_seqlen_k, forward_batch.can_run_tbo - ): - return self._forward_extend_impl( - q=q, - k=k, - v=v, - layer=layer, - forward_batch=forward_batch, - metadata=metadata, - page_table=page_table, - cache_seqlens=cache_seqlens, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - window_size=window_size, - k_descale=k_descale, - v_descale=v_descale, - causal=causal, - use_cascade_attn=use_cascade_attn, - use_local_attn=use_local_attn, - q_rope=q_rope, - k_rope=k_rope, - **kwargs, - ) - def _forward_extend_impl( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: "RadixAttention", - forward_batch: ForwardBatch, - metadata: FlashAttentionMetadata, - page_table, - cache_seqlens, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - window_size, - k_descale, - v_descale, - causal, - use_cascade_attn, - use_local_attn, - q_rope, - k_rope, - **kwargs, - ): - """Internal implementation of forward_extend, wrapped by context manager.""" # Use Flash Attention for prefill if not self.use_mla: # Do multi-head attention @@ -1036,12 +930,7 @@ def _fa_cp_attn( self.device, _fa_cp_attn, ) - elif ( - not _is_musa - or metadata.extend_with_prefix - or forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend() - ): + else: result = flash_attn_with_kvcache( q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), k_cache=key_cache, @@ -1062,71 +951,40 @@ def _fa_cp_attn( **kwargs, ) - if use_cascade_attn: - if _is_musa: - # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters - update_flash_attention_context( - prefix="forward_extend_use_cascade_attn", - max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, - ) - o, softmax_lse, *rest = result - o_expand, softmax_lse_expand, *rest_expand = ( - flash_attn_with_kvcache( - q=q.contiguous().view( - -1, layer.tp_q_head_num, layer.head_dim - ), - # Here metadata_expand.page_table is not divided with page_size. - # This is because we loose the fine control of what token to attend, - # but has to attend to some block completely. - k_cache=key_cache.view( - -1, 1, layer.tp_k_head_num, layer.head_dim - ), - v_cache=value_cache.view( - -1, 1, layer.tp_v_head_num, layer.head_dim - ), - page_table=self.forward_metadata_spec_decode_expand.page_table, - cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, - cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, - cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, - max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, - softmax_scale=layer.scaling, - causal=False, - window_size=window_size, - softcap=layer.logit_cap, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=True, - num_splits=self.num_splits, - **kwargs, - ) - ) - o, _ = merge_state_v2_wrapper( - o, - softmax_lse.T.contiguous(), - o_expand, - softmax_lse_expand.T.contiguous(), - ) - else: - o = result - else: - # MATE's MHA for extend part of sequence without attending prefix kv cache - output = flash_attn_varlen_func( - q=q.view(-1, layer.tp_q_head_num, layer.head_dim), - k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), - v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), - cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k=metadata.cu_seqlens_q, - max_seqlen_q=metadata.max_seq_len_q, - max_seqlen_k=metadata.max_seq_len_q, + if use_cascade_attn: + o, softmax_lse, *rest = result + o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + # Here metadata_expand.page_table is not divided with page_size. + # This is because we loose the fine control of what token to attend, + # but has to attend to some block completely. + k_cache=key_cache.view(-1, 1, layer.tp_k_head_num, layer.head_dim), + v_cache=value_cache.view( + -1, 1, layer.tp_v_head_num, layer.head_dim + ), + page_table=self.forward_metadata_spec_decode_expand.page_table, + cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, + cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, + cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, softmax_scale=layer.scaling, - causal=True, - return_softmax_lse=forward_batch.mha_return_lse, + causal=False, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=True, + num_splits=self.num_splits, + **kwargs, ) - if forward_batch.mha_return_lse: - output, lse, *rest = output - lse = torch.transpose(lse, 0, 1).contiguous() - return output.view(-1, layer.tp_q_head_num * layer.v_head_dim), lse - return output.view(-1, layer.tp_q_head_num * layer.v_head_dim) + o, _ = merge_state_v2_wrapper( + o, + softmax_lse.T.contiguous(), + o_expand, + softmax_lse_expand.T.contiguous(), + ) + else: + o = result else: if ( forward_batch.attn_attend_prefix_cache is not None @@ -1331,43 +1189,6 @@ def forward_decode( if sinks is not None: kwargs["sinks"] = sinks - with self.get_flash_attention_context( - layer, "forward_decode", metadata.max_seq_len_k, forward_batch.can_run_tbo - ): - return self._forward_decode_impl( - q=q, - layer=layer, - forward_batch=forward_batch, - metadata=metadata, - is_swa_layer=is_swa_layer, - window_size=window_size, - causal=causal, - use_cascade_attn=use_cascade_attn, - use_local_attn=use_local_attn, - local_attn_metadata=local_attn_metadata, - q_rope=q_rope, - k_rope=k_rope, - **kwargs, - ) - - def _forward_decode_impl( - self, - q: torch.Tensor, - layer: "RadixAttention", - forward_batch: ForwardBatch, - metadata: FlashAttentionMetadata, - is_swa_layer, - window_size, - causal, - use_cascade_attn, - use_local_attn, - local_attn_metadata, - q_rope, - k_rope, - **kwargs, - ): - """Internal implementation of forward_decode, wrapped by context manager.""" - k_descale, v_descale = None, None flash_attn_with_kvcache_base = flash_attn_with_kvcache_fa3 flash_attn_with_kvcache = ( @@ -1479,12 +1300,6 @@ def _forward_decode_impl( **kwargs, ) if use_cascade_attn: - if _is_musa: - # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters - update_flash_attention_context( - prefix="forward_decode_use_cascade_attn", - max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, - ) o, softmax_lse, *rest = result o_expand, softmax_lse_expand, *rest_expand = ( flash_attn_with_kvcache( @@ -1562,12 +1377,6 @@ def _forward_decode_impl( num_splits=self.num_splits, ) if use_cascade_attn: - if _is_musa: - # Need to re-create scheduler metadata for different flash_attn_with_kvcache parameters - update_flash_attention_context( - prefix="forward_decode_use_cascade_attn", - max_seqlen_k=self.forward_metadata_spec_decode_expand.max_seq_len_k, - ) o, softmax_lse, *rest = result o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache( q=q_rope, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c60250678cd1..4229f40932f7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -52,7 +52,6 @@ is_hip, is_hopper_with_cuda_12_3, is_mps, - is_musa, is_no_spec_infer_or_topk_one, is_npu, is_remote_url, @@ -2413,13 +2412,6 @@ def _handle_attention_backend_compatibility(self): if model_config.context_len > 8192: self.mem_fraction_static *= 0.85 - # MUSA platforms compatible backends - if is_musa() and self.attention_backend == "fa3": - logger.warning( - "FA3 attention backend on MUSA ignores any user-provided page_size and enforces a fixed value of 64." - ) - self.page_size = 64 - # Other platforms backends if ( self.attention_backend == "intel_amx"