diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index d7513f69a85..0516c844811 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -166,10 +166,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 
 ## Kernel backend
 
-| Arguments | Description | Defaults |
-|----------|-------------|---------|
-| `attention_backend` | This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `flashmla`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. | None |
-| `sampling_backend` | Specifies the backend used for sampling. | None |
+| Arguments              | Description | Defaults |
+|------------------------|-------------|---------|
+| `attention_backend` | This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `flashmla`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. | None |
+| `sampling_backend` | Specifies the backend used for sampling. | None |
+| `mm_attention_backend` | Specifies the attention backend used for multimodal models, which can be `sdpa`, `fa3`, or `triton_attn`. | None |
 
 ## Constrained Decoding
 
diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index d65104beb58..429787ec86b 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from functools import lru_cache
+import math
+from functools import lru_cache, wraps
 from typing import Optional, Tuple
 
 import torch
@@ -8,6 +9,13 @@
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func
+
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
@@ -19,166 +27,31 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
-from sglang.srt.utils import add_prefix
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import add_prefix, logger
 
-
-class VisionAttention(nn.Module):
-    r"""
-    Multi-headed attention without any cache, mostly used for ViT.
+ROTARY_EMBED_CLASSES = {
+    "normal": apply_rotary_pos_emb,
+}
 
-    Args:
-        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
-        use_context_forward (bool, default to True):
-            if ``True``, a flash_attn style attention will be applied
-            Otherwise, a full-sequence attention will be applied.
- softmax_in_single_precision (bool, default to False): - if ``True``, the softmax will be performed in single-precision - Otherwise, it will be performed in half-precision +def execute_once(func): + has_run = None - """ + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal has_run + if not has_run: + func(*args, **kwargs) + has_run = True - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - use_qkv_parallel: bool, - quant_config: Optional[QuantizationConfig] = None, - dropout: float = 0.0, - use_context_forward: bool = True, - softmax_in_single_precision: bool = False, - flatten_batch: bool = False, - prefix: str = "", - ): - super().__init__() - self.use_context_forward = use_context_forward - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.dropout = dropout - self.head_size = embed_dim // num_heads - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads - ) - self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size - ) + return wrapper - if self.use_context_forward: - self.qkv_backend = VisionTritonAttention() - else: - self.qkv_backend = VisionSdpaAttention( - head_size=self.head_size, - dropout=dropout, - flatten_batch=flatten_batch, - softmax_in_single_precision=softmax_in_single_precision, - ) - self.use_qkv_parallel = use_qkv_parallel - if use_qkv_parallel: - self.qkv_proj = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.head_size, - total_num_heads=num_heads, - quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), - ) - else: - self.qkv_proj = ColumnParallelLinear( - input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), - ) - self.proj = RowParallelLinear( - input_size=embed_dim, - output_size=embed_dim, - quant_config=quant_config, - prefix=add_prefix("proj", prefix), - ) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - x: [b, s, embed_dim] - cu_seqlens: [b] - Returns: - [s, b, head * head_size] - """ - bsz, s, _ = x.shape - head = self.num_attention_heads_per_partition - if self.use_qkv_parallel: - # [b, s, embed_dim] --> [b, s, embed_dim] - qkv, _ = self.qkv_proj(x) - q, k, v = qkv.chunk(3, dim=-1) - - # [b, s, embed_dim] --> [b * s, head, head_size] - q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)] - else: - # [b, s, embed_dim] --> [s, b, embed_dim] - x = rearrange(x, "b s ... -> s b ...") - # [s, b, embed_dim] --> [s, b, head * 3 * head_size] - qkv, _ = self.qkv_proj(x) - # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] - new_x_shape = qkv.size()[:-1] + ( - head, - 3 * self.hidden_size_per_attention_head, - ) - qkv = qkv.view(*new_x_shape) - - # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] - q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) - - # [s, b, head, head_size] --> [b, s, head, head_size] - q, k, v = [ - rearrange(x, "s b ... 
-> b s ...").contiguous() for x in (q, k, v) - ] - - if position_embeddings is not None: - cos, sin = position_embeddings - original_shape = q.shape - # [total_tokens, head, head_size] - q = q.view(-1, head, self.head_size) - k = k.view(-1, head, self.head_size) - - q, k = apply_rotary_pos_emb(q, k, cos, sin) - - q = q.view(original_shape) - k = k.view(original_shape) - - if self.use_qkv_parallel: - pass - else: - # [b, s, head, head_size] --> [b * s, head, head_size] - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - - output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask) - - if self.use_qkv_parallel: - # [b * s, h, head_size] --> [b, s, h * head_size] - output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) - - # [b, s, h * head_size] --> [b, s, h * head_size] - output, _ = self.proj(output) - else: - # [b * s, h, head_size] --> [s, b, h * head_size] - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=bsz, s=s - ).contiguous() - - # [s, b, h * head_size] --> [s, b, h * head_size] - output, _ = self.proj(context_layer) - - # [s, b, h * head_size] --> [b, s, h * head_size] - output = output.view(bsz, s, -1) - - return output +@execute_once +def info_once(message: str): + logger.info(message) class VisionSdpaAttention(nn.Module): @@ -189,16 +62,22 @@ class VisionSdpaAttention(nn.Module): def __init__( self, - head_size: int, + head_dim: int, + num_heads: int, + num_kv_heads: int, dropout: float = 0.0, flatten_batch: bool = False, softmax_in_single_precision: bool = False, + **kwargs, ): super().__init__() - self.head_size = head_size + self.head_size = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads self.flatten_batch = flatten_batch self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout + self.scale = 1.0 / math.sqrt(self.head_size) @staticmethod @lru_cache(maxsize=128) @@ -212,7 +91,7 @@ def _generate_mask_cache( flatten_batch: whether to flatten batch dimension cu_seqlens: tuple of cumulative sequence lengths Returns: - attention mask tensor + attention mask tensor of shape [b, 1, s, s] or [1, s, s] """ if flatten_batch: mask = torch.zeros([1, s, s], dtype=torch.bool) @@ -241,7 +120,7 @@ def generate_patch_attention_mask( flatten_batch: bool = False, ) -> Optional[torch.Tensor]: r""" - Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`. Args: s: sequence length cu_seqlens: cumulative sequence lengths tensor. 
If not, returns an empty mask @@ -264,6 +143,7 @@ def forward( bsz: int, cu_seqlens: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: r""" Args: @@ -274,6 +154,8 @@ def forward( if self.flatten_batch: assert bsz == 1, "flatten_batch is True, bsz must be 1" + assert q.dim() == 3, q.shape + s = q.shape[0] // bsz # [b, 1, s, s] @@ -291,10 +173,10 @@ def forward( q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] if self.softmax_in_single_precision: - scale = self.head_size**-0.5 - k_transposed = rearrange(k, "b h s d -> b h d s") - attn_weights = torch.matmul(q, k_transposed) * scale - del k, k_transposed + k = rearrange(k, "b h s d -> b h d s") + attn_weights = torch.matmul(q, k) * self.scale + del k + # masking attention_mask = (~attention_mask) * torch.finfo(q.dtype).min attn_weights = attn_weights + attention_mask del attention_mask @@ -332,6 +214,7 @@ class VisionTritonAttention(nn.Module): def __init__( self, + **kwargs, ): super().__init__() @@ -340,8 +223,8 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - _bsz: int, cu_seqlens: Optional[torch.Tensor], + **kwargs, ) -> torch.Tensor: r""" Args: @@ -366,3 +249,247 @@ def forward( ) return output + + +class VisionFlash3Attention(nn.Module): + def __init__( + self, + **kwargs, + ): + if not _is_cuda: + raise Exception("VisionFlash3Attention is only available for cuda") + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda() + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + ) + + return output + + +QKV_BACKEND_IMPL = { + "triton_attn": VisionTritonAttention, + "sdpa": VisionSdpaAttention, + "fa3": VisionFlash3Attention, +} + + +class VisionAttention(nn.Module): + r""" + Multi-headed attention without any cache, mostly used for multimodal transformers. + + + Args: + use_qkv_parallel (bool, optional): If True, use QKV-parallel attention. 
+ softmax_in_single_precision (bool, default to False): + if ``True``, the softmax will be performed in single-precision + Otherwise, it will be performed in half-precision + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + use_qkv_parallel: bool, + qkv_backend: Optional[str] = None, + quant_config: Optional[QuantizationConfig] = None, + dropout: float = 0.0, + softmax_in_single_precision: bool = False, + flatten_batch: bool = False, + prefix: str = "", + proj_bias: bool = True, + **kwargs, + ): + super().__init__() + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.dropout = dropout + self.head_size = embed_dim // num_heads + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size + ) + self.num_attention_kv_heads_per_partition = dist_utils.divide( + num_heads, world_size + ) + + self.q_size = self.num_attention_heads_per_partition * self.head_size + self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size + + if global_server_args_dict["mm_attention_backend"] is None: + if qkv_backend is None: + qkv_backend = "sdpa" + info_once(f"Multimodal attention backend not set. Use {qkv_backend}.") + else: + qkv_backend = global_server_args_dict["mm_attention_backend"] + + info_once(f"Using {qkv_backend} as multimodal attention backend.") + + self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( + head_dim=self.head_size, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_kv_heads_per_partition, + dropout=dropout, + flatten_batch=flatten_batch, + softmax_in_single_precision=softmax_in_single_precision, + ) + + self.use_qkv_parallel = use_qkv_parallel + if use_qkv_parallel: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_size, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + else: + self.qkv_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.proj = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=proj_bias, + quant_config=quant_config, + prefix=add_prefix("proj", prefix), + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + x: [b, s, embed_dim] + cu_seqlens: [b] + Returns: + [s, b, head * head_size] + """ + if x.dim() == 2: + x = x.unsqueeze(0) + assert x.dim() == 3, x.shape + bsz, s, _ = x.shape + head = self.num_attention_heads_per_partition + kv_head = self.num_attention_kv_heads_per_partition + if self.use_qkv_parallel: + # [b, s, embed_dim] --> [b, s, embed_dim] + qkv, _ = self.qkv_proj(x) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # [b, s, embed_dim] --> [b * s, head, head_size] + q = q.reshape(bsz * s, head, -1).contiguous() + k = k.reshape(bsz * s, kv_head, -1).contiguous() + v = v.reshape(bsz * s, kv_head, -1).contiguous() + else: + # [b, s, embed_dim] --> [s, b, embed_dim] + x = rearrange(x, "b s ... 
-> s b ...") + # [s, b, embed_dim] --> [s, b, head * 3 * head_size] + qkv, _ = self.qkv_proj(x) + + # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] + new_x_shape = qkv.size()[:-1] + ( + head, + 3 * self.hidden_size_per_attention_head, + ) + qkv = qkv.view(*new_x_shape) + + # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] + q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) + # [s, b, head, head_size] --> [b, s, head, head_size] + q, k, v = [ + rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) + ] + + if position_embeddings is not None: + cos, sin = position_embeddings + original_shape = q.shape + # [total_tokens, head, head_size] + q = q.view(-1, head, self.head_size) + k = k.view(-1, head, self.head_size) + + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + q = q.view(original_shape) + k = k.view(original_shape) + + if q.dim() == 4: + # [b, s, head, head_size] --> [b * s, head, head_size] + q = rearrange(q, "b s ... -> (b s) ...") + if k.dim() == 4: + # [b, s, head, head_size] --> [b * s, head, head_size] + k = rearrange(k, "b s ... -> (b s) ...") + if v.dim() == 4: + # [b, s, head, head_size] --> [b * s, head, head_size] + v = rearrange(v, "b s ... -> (b s) ...") + + assert q.dim() == 3, q.dim() + assert k.dim() == 3, k.dim() + assert v.dim() == 3, v.dim() + + output = self.qkv_backend.forward( + q=q, + k=k, + v=v, + bsz=bsz, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + ) + + assert output.dim() == 3, output.shape + + if self.use_qkv_parallel: + # [b * s, h, head_size] --> [b, s, h * head_size] + output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) + + # [b, s, h * head_size] --> [b, s, h * head_size] + output, _ = self.proj(output) + else: + # [b * s, h, head_size] --> [s, b, h * head_size] + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=bsz, s=s + ).contiguous() + + # [s, b, h * head_size] --> [s, b, h * head_size] + output, _ = self.proj(context_layer) + + # [s, b, h * head_size] --> [b, s, h * head_size] + output = output.view(bsz, s, -1) + + return output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ff08f182add..d2291087165 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -173,6 +173,7 @@ def __init__( "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, "use_mla_backend": self.use_mla_backend, + "mm_attention_backend": server_args.mm_attention_backend, } ) diff --git a/python/sglang/srt/models/clip.py b/python/sglang/srt/models/clip.py index eecb9208def..999f77ead82 100644 --- a/python/sglang/srt/models/clip.py +++ b/python/sglang/srt/models/clip.py @@ -151,20 +151,20 @@ def __init__( self.layer_norm1 = norm_layer(config.hidden_size) self.layer_norm2 = norm_layer(config.hidden_size) if attn_implementation == "sdpa": - use_context_forward = False + qkv_backend = "sdpa" softmax_in_single_precision = False elif attn_implementation == "flash_attention_2": + qkv_backend = "triton_attn" softmax_in_single_precision = False - use_context_forward = True elif attn_implementation == "eager": + qkv_backend = "sdpa" softmax_in_single_precision = True - use_context_forward = False self.self_attn = VisionAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, projection_size=config.hidden_size, use_qkv_parallel=True, - 
use_context_forward=use_context_forward, + qkv_backend=qkv_backend, softmax_in_single_precision=softmax_in_single_precision, flatten_batch=True, quant_config=quant_config, diff --git a/python/sglang/srt/models/deepseek_janus_pro.py b/python/sglang/srt/models/deepseek_janus_pro.py index 04dcf787e2f..bef9c07ed32 100644 --- a/python/sglang/srt/models/deepseek_janus_pro.py +++ b/python/sglang/srt/models/deepseek_janus_pro.py @@ -532,7 +532,7 @@ def __init__( num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, - use_context_forward=False, + qkv_backend="sdpa", softmax_in_single_precision=False, dropout=attn_drop, ) diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 92a690b2db0..824c58916b4 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -281,7 +281,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]): pixel_values = torch.stack( flatten_nested_list([item.pixel_values for item in items]), dim=0 ) - pixel_values = pixel_values.to("cuda") + pixel_values = pixel_values.to(device=self.vision_tower.device) pixel_values = pixel_values.to(dtype=self.language_model.dtype()) vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 0323efb784e..5793fa819f5 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -197,7 +197,7 @@ def __init__( use_qkv_parallel=True, quant_config=quant_config, dropout=config.attention_dropout, - use_context_forward=False, + qkv_backend="sdpa", softmax_in_single_precision=True, flatten_batch=False, prefix=add_prefix("self_attn", prefix), diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index aea1fdf7162..8d63d9cfca4 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -203,7 +203,7 @@ def __init__( use_qkv_parallel=True, quant_config=quant_config, dropout=0.0, - use_context_forward=False, + qkv_backend="sdpa", softmax_in_single_precision=False, flatten_batch=False, prefix=add_prefix("self_attn", prefix), diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 17697a83554..1d52c92cd59 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -125,16 +125,20 @@ def __init__( self.norm1 = Qwen2RMSNorm(dim, eps=1e-6) self.norm2 = Qwen2RMSNorm(dim, eps=1e-6) if attn_implementation == "sdpa": - use_context_forward = False softmax_in_single_precision = False + qkv_backend = "sdpa" flatten_batch = True elif attn_implementation == "flash_attention_2": softmax_in_single_precision = False - use_context_forward = True + qkv_backend = "triton_attn" flatten_batch = True elif attn_implementation == "eager": softmax_in_single_precision = True - use_context_forward = False + qkv_backend = "sdpa" + flatten_batch = True + elif attn_implementation == "flash_attention_3": + softmax_in_single_precision = False + qkv_backend = "fa3" flatten_batch = True self.attn = VisionAttention( @@ -142,7 +146,7 @@ def __init__( num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, - use_context_forward=use_context_forward, + qkv_backend=qkv_backend, softmax_in_single_precision=softmax_in_single_precision, flatten_batch=flatten_batch, quant_config=quant_config, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 
7401111d996..f653401d81a 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -139,21 +139,21 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) if attn_implementation == "sdpa": - use_context_forward = False + qkv_backend = "sdpa" softmax_in_single_precision = False elif attn_implementation == "flash_attention_2": + qkv_backend = "triton_attn" softmax_in_single_precision = False - use_context_forward = True elif attn_implementation == "eager": + qkv_backend = "sdpa" softmax_in_single_precision = True - use_context_forward = False self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, - use_context_forward=use_context_forward, + qkv_backend=qkv_backend, softmax_in_single_precision=softmax_in_single_precision, flatten_batch=True, quant_config=quant_config, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 718208959d8..0aa71e34478 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -187,6 +187,7 @@ class ServerArgs: n_share_experts_fusion: int = 0 disable_chunked_prefix_cache: bool = False disable_fast_image_processor: bool = False + mm_attention_backend: Optional[str] = None # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -1265,6 +1266,14 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", ) + parser.add_argument( + "--mm-attention-backend", + type=str, + choices=["sdpa", "fa3", "triton_attn"], + default=ServerArgs.mm_attention_backend, + help="Set multimodal attention backend.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size
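Usage sketch (not part of the patch): the new selection is exposed as `--mm-attention-backend {sdpa,fa3,triton_attn}` on `python -m sglang.launch_server`, and the same field can be passed through the offline engine API, which forwards keyword arguments to `ServerArgs`. The model path below is illustrative, and `fa3` assumes a CUDA build where `sgl_kernel.flash_attn` is importable.

```python
import sglang as sgl

# The model path is illustrative; any multimodal model served by SGLang is configured
# the same way. "fa3" assumes a CUDA build with sgl_kernel available; "sdpa" and
# "triton_attn" are the other choices added by this patch.
llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-VL-7B-Instruct",
    mm_attention_backend="fa3",
)
```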
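Module-level sketch: after this refactor, each entry in `QKV_BACKEND_IMPL` is a standalone `nn.Module` with a uniform constructor (`head_dim`, `num_heads`, `num_kv_heads`, ...) and a `forward(q, k, v, bsz, cu_seqlens, attention_mask, **kwargs)` contract, so a backend can be exercised directly without building the full `VisionAttention` wrapper. The shapes below are illustrative; this assumes an environment where `sglang.srt.layers.attention.vision` imports cleanly (the SDPA path needs no GPU).

```python
import torch

from sglang.srt.layers.attention.vision import VisionSdpaAttention

# Illustrative ViT-like sizes: one image of 64 patches, 16 heads, head_dim 64.
attn = VisionSdpaAttention(head_dim=64, num_heads=16, num_kv_heads=16, flatten_batch=True)

# Backends consume flattened [b * s, head, head_dim] tensors plus cumulative
# sequence lengths; with flatten_batch=True the batch size must be 1.
q = torch.randn(64, 16, 64)
k = torch.randn(64, 16, 64)
v = torch.randn(64, 16, 64)
cu_seqlens = torch.tensor([0, 64], dtype=torch.int32)

out = attn(q, k, v, bsz=1, cu_seqlens=cu_seqlens)  # [b * s, head, head_dim]
```

Replacing the old `use_context_forward` boolean with this registry is what lets the call sites in `clip.py`, `qwen2_vl.py`, and `qwen2_5_vl.py` name a backend explicitly via `qkv_backend`, while `--mm-attention-backend` can override the per-model default globally.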