diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index 17ba7bcfbc23..9f556a885862 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -4,7 +4,7 @@
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -89,6 +89,27 @@ def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Ten
     return cu_seqlens
 
 
+def resolve_seqlens(
+    cu_seqlens: torch.Tensor | SingletonCache | None,
+    bsz: int,
+    seq_len: int,
+    *,
+    device: torch.device,
+) -> torch.Tensor:
+    if cu_seqlens is None:
+        resolved_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=device)
+    elif isinstance(cu_seqlens, SingletonCache):
+        if cu_seqlens.empty():
+            cu_seqlens.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=device))
+        resolved_seqlens = cu_seqlens.get_data()
+    else:
+        resolved_seqlens = cu_seqlens
+    assert isinstance(
+        resolved_seqlens, torch.Tensor
+    ), "cu_seqlens must be a torch.Tensor"
+    return resolved_seqlens
+
+
 class VisionSdpaAttention(nn.Module):
     r"""
     Scaled Dot Product Attention inner product
@@ -258,7 +279,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
+        cu_seqlens: torch.Tensor | SingletonCache | None,
         bsz: int,
         seq_len: int,
         **kwargs,
@@ -269,8 +290,7 @@ def forward(
         Returns:
            [b * s, h, head_size]
         """
-        if cu_seqlens is None:
-            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)
 
         # [b * s, head, head_size]
         output = torch.empty_like(q)
@@ -304,7 +324,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        cu_seqlens: torch.Tensor | SingletonCache | None,
         bsz: int,
         seq_len: int,
         **kwargs,
@@ -315,14 +335,7 @@ def forward(
         Returns:
            [b * s, h, head_size]
         """
-        if cu_seqlens is None:
-            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-        elif isinstance(cu_seqlens, SingletonCache):
-            if cu_seqlens.empty():
-                cu_seqlens.set_data(
-                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-                )
-            cu_seqlens = cu_seqlens.get_data()
+        cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)
 
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
@@ -363,19 +376,12 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        cu_seqlens: torch.Tensor | SingletonCache | None,
         bsz: int,
         seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
-        if cu_seqlens is None:
-            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-        elif isinstance(cu_seqlens, SingletonCache):
-            if cu_seqlens.empty():
-                cu_seqlens.set_data(
-                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-                )
-            cu_seqlens = cu_seqlens.get_data()
+        cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)
 
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
@@ -407,7 +413,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        cu_seqlens: torch.Tensor | SingletonCache | None,
         bsz: int,
         seq_len: int,
         **kwargs,
@@ -418,8 +424,7 @@ def forward(
         Returns:
            [b * s, h, head_size]
         """
-        if cu_seqlens is None:
-            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)
 
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         if seq_lens.is_npu:
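
For reviewers unfamiliar with the cache handling, the snippet below is a minimal, self-contained sketch of what the new resolve_seqlens() helper does with its three possible inputs (None, a cache, or a ready tensor). _StubCache and _cu_seqlens_for_shape here are stand-ins written only for this example; the real SingletonCache and _get_cu_seqlens_for_shape live in sglang, and only the empty()/set_data()/get_data() calls and the three branches mirror the diff.

    import torch


    def _cu_seqlens_for_shape(bsz: int, seq_len: int, device) -> torch.Tensor:
        # Assumed equivalent of _get_cu_seqlens_for_shape: cumulative offsets
        # [0, seq_len, 2*seq_len, ..., bsz*seq_len] for equal-length sequences.
        return torch.arange(
            0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32, device=device
        )


    class _StubCache:
        # Stand-in exposing the SingletonCache methods the diff relies on.
        def __init__(self):
            self._data = None

        def empty(self) -> bool:
            return self._data is None

        def set_data(self, value: torch.Tensor) -> None:
            self._data = value

        def get_data(self) -> torch.Tensor:
            return self._data


    def resolve_seqlens(cu_seqlens, bsz, seq_len, *, device):
        # Same three branches as the helper added in the diff.
        if cu_seqlens is None:                  # no hint: recompute every call
            return _cu_seqlens_for_shape(bsz, seq_len, device)
        if isinstance(cu_seqlens, _StubCache):  # cache: compute once, then reuse
            if cu_seqlens.empty():
                cu_seqlens.set_data(_cu_seqlens_for_shape(bsz, seq_len, device))
            return cu_seqlens.get_data()
        return cu_seqlens                       # caller already passed a tensor


    cache = _StubCache()
    a = resolve_seqlens(cache, bsz=2, seq_len=4, device=torch.device("cpu"))
    b = resolve_seqlens(cache, bsz=2, seq_len=4, device=torch.device("cpu"))
    assert a is b          # second call reuses the cached tensor
    print(a)               # tensor([0, 4, 8], dtype=torch.int32)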