1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -86,6 +86,7 @@ diffusion = [
"PyYAML==6.0.1",
"cloudpickle",
"diffusers==0.36.0",
"flash-attn",
"imageio==2.36.0",
"imageio-ffmpeg==0.5.1",
"moviepy>=2.0.0",
python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py
@@ -7,25 +7,15 @@

import torch

from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum

try:
    from sgl_kernel.flash_attn import flash_attn_varlen_func

    # flash_attn 3 no longer have a different API, see following commit:
    # https://github.com/Dao-AILab/flash-attention/commit/ed209409acedbb2379f870bbd03abce31a7a51b7
    flash_attn_func = flash_attn_varlen_func
    from flash_attn_interface import flash_attn_varlen_func
except ImportError as e:
    raise e


try:
    from flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_upstream,
    )
except Exception:
    flash_attn_varlen_func_upstream = None
    raise ImportError(
        "flash-attention library is required. Please install it with: "
        "pip install flash-attn --no-build-isolation"
    ) from e

from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend,
@@ -37,8 +27,6 @@

logger = init_logger(__name__)

fa_ver = 3


@lru_cache(maxsize=128)
def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor:
@@ -51,45 +39,6 @@ def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor:
    )


@lru_cache(maxsize=256)
def _should_use_upstream_flash_attention(
    upstream_available: bool,
    upstream_heads_ok: bool,
    q_shape: tuple[int, ...],
    k_shape: tuple[int, ...],
    v_shape: tuple[int, ...],
) -> bool:
    if not upstream_available or not upstream_heads_ok:
        return False

    if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4:
        return False

    bsz, seqlen, nheads_q, d = q_shape
    bsz_k, seqlen_k, nheads_k, d_k = k_shape
    bsz_v, seqlen_v, nheads_v, d_v = v_shape

    if (
        bsz != bsz_k
        or bsz != bsz_v
        or seqlen != seqlen_k
        or seqlen != seqlen_v
        or d != d_k
        or d != d_v
    ):
        return False
    if nheads_k != nheads_v:
        return False
    if nheads_k == 0 or (nheads_q % nheads_k) != 0:
        return False
    return True


def set_fa_ver(ver: int):
    global fa_ver
    fa_ver = ver


@dataclass
class FlashAttentionMetadata:
    # Sequence lengths for the forward batch
@@ -162,13 +111,6 @@ def __init__(
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.attention_metadata = FlashAttentionMetadata()
        if self.num_kv_heads is None:
            self._upstream_heads_ok = True
        else:
            # For gqa, the num_heads must be a multiple of num_kv_heads
            self._upstream_heads_ok = (
                self.num_kv_heads > 0 and (self.num_heads % self.num_kv_heads) == 0
            )

    def forward(
        self,
@@ -179,58 +121,30 @@ def forward(
        *,
        return_softmax_lse: bool = False,
    ):
        attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata
        if attn_metadata is not None and attn_metadata.max_seqlen_q is None:
            attn_metadata.max_seqlen_q = query.shape[1]
            attn_metadata.max_seqlen_k = key.shape[1]
            max_seqlen_q = attn_metadata.max_seqlen_q
            max_seqlen_k = attn_metadata.max_seqlen_k
        else:
            max_seqlen_q = query.shape[1]
            max_seqlen_k = key.shape[1]
        q_shape = tuple(query.shape)
        k_shape = tuple(key.shape)
        v_shape = tuple(value.shape)
        use_upstream = _should_use_upstream_flash_attention(
            flash_attn_varlen_func_upstream is not None,
            self._upstream_heads_ok,
            q_shape,
            k_shape,
            v_shape,
        )

        if use_upstream:
            bsz, seqlen, nheads_q, d = q_shape
            bsz_k, seqlen_k, nheads_k, d_k = k_shape
            bsz_v, seqlen_v, nheads_v, d_v = v_shape
            q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
            k_ = key.contiguous().reshape(bsz * seqlen, nheads_k, d)
            v_ = value.contiguous().reshape(bsz * seqlen, nheads_v, d)
            cu = _get_cu_seqlens(q_.device.index, bsz, seqlen)
            out = flash_attn_varlen_func_upstream(
                q_,
                k_,
                v_,
                cu,
                cu,
                seqlen,
                seqlen,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )
            return out.reshape(bsz, seqlen, nheads_q, d)

        output = flash_attn_func(
            q=query,  # type: ignore[no-untyped-call]
            k=key,
            v=value,
            cu_seqlens_q=None,
            cu_seqlens_k=None,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
        bsz, seqlen, nheads_q, d = query.shape
        bsz_k, seqlen_k, nheads_k, d_k = key.shape

        q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
        k_ = key.contiguous().reshape(bsz * seqlen_k, nheads_k, d_k)
        v_ = value.contiguous().reshape(bsz * seqlen_k, nheads_k, value.shape[-1])

        cu_seqlens_q = _get_cu_seqlens(q_.device.index, bsz, seqlen)
        cu_seqlens_k = _get_cu_seqlens(k_.device.index, bsz, seqlen_k)

        out = flash_attn_varlen_func(
            q_,
            k_,
            v_,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen,
            seqlen_k,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
            return_softmax_lse=return_softmax_lse,
            ver=fa_ver,
            return_attn_probs=return_softmax_lse,
        )
        return output

        if return_softmax_lse:
            out, softmax_lse = out
            return out.reshape(bsz, seqlen, nheads_q, -1), softmax_lse
        return out.reshape(bsz, seqlen, nheads_q, -1)
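
For context on the rewritten forward(): it packs the batched (bsz, seqlen, nheads, head_dim) tensors into the flattened (bsz * seqlen, nheads, head_dim) layout that the varlen kernel consumes, and describes the batch boundaries with cumulative sequence lengths. Below is a minimal, self-contained sketch of that packing, not part of the diff: the shapes, the "cuda" device, the head_dim**-0.5 scale, and the assumption that flash_attn_varlen_func returns just the output tensor when attention probabilities are not requested are all illustrative.

# Illustrative sketch only -- not repository code. Shapes are hypothetical.
import torch
from flash_attn_interface import flash_attn_varlen_func

bsz, seqlen, nheads, head_dim = 2, 16, 8, 64
q = torch.randn(bsz, seqlen, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Pack (bsz, seqlen, nheads, d) -> (bsz * seqlen, nheads, d) and mark batch
# boundaries with cumulative lengths [0, seqlen, 2 * seqlen, ..., bsz * seqlen].
q_, k_, v_ = (t.reshape(bsz * seqlen, nheads, head_dim) for t in (q, k, v))
cu_seqlens = torch.arange(
    0, (bsz + 1) * seqlen, seqlen, dtype=torch.int32, device=q.device
)

out = flash_attn_varlen_func(
    q_,
    k_,
    v_,
    cu_seqlens,
    cu_seqlens,
    seqlen,
    seqlen,
    softmax_scale=head_dim**-0.5,
    causal=False,
)
# Unpack back to the batched layout callers expect.
out = out.reshape(bsz, seqlen, nheads, head_dim)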
@@ -4,14 +4,22 @@

import torch

try:
    from flash_attn_interface import flash_attn_varlen_func
except ImportError as e:
    raise ImportError(
        "flash-attention library is required. Please install it with: "
        "pip install flash-attn --no-build-isolation"
    ) from e

from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
    AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (
    flash_attn_func,
    _get_cu_seqlens,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -64,16 +72,33 @@ def forward(
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
        *,
        return_softmax_lse: bool = False,
    ):
        output = flash_attn_func(
            q=query,  # type: ignore[no-untyped-call]
            k=key,
            v=value,
            cu_seqlens_q=None,
            cu_seqlens_k=None,
            max_seqlen_q=None,
            max_seqlen_k=None,
        bsz, seqlen, nheads_q, d = query.shape
        bsz_k, seqlen_k, nheads_k, d_k = key.shape

        q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
        k_ = key.contiguous().reshape(bsz * seqlen_k, nheads_k, d_k)
        v_ = value.contiguous().reshape(bsz * seqlen_k, nheads_k, value.shape[-1])

        cu_seqlens_q = _get_cu_seqlens(q_.device.index, bsz, seqlen)
        cu_seqlens_k = _get_cu_seqlens(k_.device.index, bsz, seqlen_k)

        out = flash_attn_varlen_func(
            q_,
            k_,
            v_,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen,
            seqlen_k,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
            return_attn_probs=return_softmax_lse,
        )
        return output

        if return_softmax_lse:
            out, softmax_lse = out
            return out.reshape(bsz, seqlen, nheads_q, -1), softmax_lse
        return out.reshape(bsz, seqlen, nheads_q, -1)
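
Both backends now depend on the shared _get_cu_seqlens helper, whose body is collapsed in the diff above. The following sketch shows what a cached helper with that signature plausibly computes, based only on how it is called; the actual repository implementation may differ.

# Sketch of a cached cu_seqlens helper matching the signature used above; the
# body is an assumption, since the diff collapses it.
from functools import lru_cache

import torch


@lru_cache(maxsize=128)
def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor:
    # Cumulative boundaries [0, seqlen, ..., bsz * seqlen] as int32 on the
    # requested CUDA device, the format flash_attn_varlen_func expects.
    return torch.arange(
        0,
        (bsz + 1) * seqlen,
        seqlen,
        dtype=torch.int32,
        device=torch.device("cuda", device_index),
    )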