Skip to content
Merged
55 changes: 30 additions & 25 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import math
from functools import lru_cache, partial
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -89,6 +89,27 @@ def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Ten
return cu_seqlens


def resolve_seqlens(
    cu_seqlens: torch.Tensor | SingletonCache | None,
    bsz: int,
    seq_len: int,
    *,
    device: torch.device,
) -> torch.Tensor:
    """Normalize the accepted ``cu_seqlens`` forms into a single tensor.

    Args:
        cu_seqlens: One of three forms: a ready cumulative-sequence-length
            tensor, a ``SingletonCache`` that lazily holds one, or ``None``
            to compute it from ``bsz`` and ``seq_len``.
        bsz: Batch size, used only when the tensor must be computed.
        seq_len: Per-sample sequence length, used only when the tensor must
            be computed.
        device: Device on which a freshly computed tensor is placed.

    Returns:
        The resolved cumulative sequence-length tensor.

    Raises:
        TypeError: If the resolved value is not a ``torch.Tensor``.
    """
    if cu_seqlens is None:
        # Nothing supplied: derive uniform cu_seqlens from the batch shape.
        resolved_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=device)
    elif isinstance(cu_seqlens, SingletonCache):
        # Populate the cache exactly once, then reuse its stored tensor on
        # subsequent calls.
        if cu_seqlens.empty():
            cu_seqlens.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=device))
        resolved_seqlens = cu_seqlens.get_data()
    else:
        resolved_seqlens = cu_seqlens
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and this is real input validation, not a debug-only check.
    if not isinstance(resolved_seqlens, torch.Tensor):
        raise TypeError("cu_seqlens must be a torch.Tensor")
    return resolved_seqlens


class VisionSdpaAttention(nn.Module):
r"""
Scaled Dot Product Attention inner product
Expand Down Expand Up @@ -258,7 +279,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[torch.Tensor],
cu_seqlens: torch.Tensor | SingletonCache | None,
bsz: int,
seq_len: int,
**kwargs,
Expand All @@ -269,8 +290,7 @@ def forward(
Returns:
[b * s, h, head_size]
"""
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)

# [b * s, head, head_size]
output = torch.empty_like(q)
Expand Down Expand Up @@ -304,7 +324,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
cu_seqlens: torch.Tensor | SingletonCache | None,
bsz: int,
seq_len: int,
**kwargs,
Expand All @@ -315,14 +335,7 @@ def forward(
Returns:
[b * s, h, head_size]
"""
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
elif isinstance(cu_seqlens, SingletonCache):
if cu_seqlens.empty():
cu_seqlens.set_data(
_get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
)
cu_seqlens = cu_seqlens.get_data()
cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)

cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
Expand Down Expand Up @@ -363,19 +376,12 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
cu_seqlens: torch.Tensor | SingletonCache | None,
bsz: int,
seq_len: int,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
elif isinstance(cu_seqlens, SingletonCache):
if cu_seqlens.empty():
cu_seqlens.set_data(
_get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
)
cu_seqlens = cu_seqlens.get_data()
cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)

cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
Expand Down Expand Up @@ -407,7 +413,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
cu_seqlens: torch.Tensor | SingletonCache | None,
bsz: int,
seq_len: int,
**kwargs,
Expand All @@ -418,8 +424,7 @@ def forward(
Returns:
[b * s, h, head_size]
"""
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device)

seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
if seq_lens.is_npu:
Expand Down
Loading