diff --git a/python/pyproject.toml b/python/pyproject.toml
index bfe0dd6495ee..b8b63aa5aa6e 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -86,6 +86,7 @@ diffusion = [
     "PyYAML==6.0.1",
     "cloudpickle",
     "diffusers==0.36.0",
+    "flash-attn",
     "imageio==2.36.0",
     "imageio-ffmpeg==0.5.1",
     "moviepy>=2.0.0",
diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py
index 17a9cd576de8..6cf4862cf75b 100644
--- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py
+++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py
@@ -7,25 +7,15 @@

 import torch

-from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum

 try:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func
-
-    # flash_attn 3 no longer have a different API, see following commit:
-    # https://github.com/Dao-AILab/flash-attention/commit/ed209409acedbb2379f870bbd03abce31a7a51b7
-    flash_attn_func = flash_attn_varlen_func
+    from flash_attn_interface import flash_attn_varlen_func
 except ImportError as e:
-    raise e
-
-
-try:
-    from flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_upstream,
-    )
-except Exception:
-    flash_attn_varlen_func_upstream = None
+    raise ImportError(
+        "flash-attention library is required. Please install it with: "
+        "pip install flash-attn --no-build-isolation"
+    ) from e

 from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
     AttentionBackend,
@@ -37,8 +27,6 @@

 logger = init_logger(__name__)

-fa_ver = 3
-

 @lru_cache(maxsize=128)
 def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor:
@@ -51,45 +39,6 @@ def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor:
     )


-@lru_cache(maxsize=256)
-def _should_use_upstream_flash_attention(
-    upstream_available: bool,
-    upstream_heads_ok: bool,
-    q_shape: tuple[int, ...],
-    k_shape: tuple[int, ...],
-    v_shape: tuple[int, ...],
-) -> bool:
-    if not upstream_available or not upstream_heads_ok:
-        return False
-
-    if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4:
-        return False
-
-    bsz, seqlen, nheads_q, d = q_shape
-    bsz_k, seqlen_k, nheads_k, d_k = k_shape
-    bsz_v, seqlen_v, nheads_v, d_v = v_shape
-
-    if (
-        bsz != bsz_k
-        or bsz != bsz_v
-        or seqlen != seqlen_k
-        or seqlen != seqlen_v
-        or d != d_k
-        or d != d_v
-    ):
-        return False
-    if nheads_k != nheads_v:
-        return False
-    if nheads_k == 0 or (nheads_q % nheads_k) != 0:
-        return False
-    return True
-
-
-def set_fa_ver(ver: int):
-    global fa_ver
-    fa_ver = ver
-
-
 @dataclass
 class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
@@ -162,13 +111,6 @@ def __init__(
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.attention_metadata = FlashAttentionMetadata()
-        if self.num_kv_heads is None:
-            self._upstream_heads_ok = True
-        else:
-            # For gqa, the num_heads must be a multiple of num_kv_heads
-            self._upstream_heads_ok = (
-                self.num_kv_heads > 0 and (self.num_heads % self.num_kv_heads) == 0
-            )

     def forward(
         self,
@@ -179,58 +121,30 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         *,
         return_softmax_lse: bool = False,
     ):
-        attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata is not None and attn_metadata.max_seqlen_q is None:
-            attn_metadata.max_seqlen_q = query.shape[1]
-            attn_metadata.max_seqlen_k = key.shape[1]
-            max_seqlen_q = attn_metadata.max_seqlen_q
-            max_seqlen_k = attn_metadata.max_seqlen_k
-        else:
-            max_seqlen_q = query.shape[1]
-            max_seqlen_k = key.shape[1]
-        q_shape = tuple(query.shape)
-        k_shape = tuple(key.shape)
-        v_shape = tuple(value.shape)
-        use_upstream = _should_use_upstream_flash_attention(
-            flash_attn_varlen_func_upstream is not None,
-            self._upstream_heads_ok,
-            q_shape,
-            k_shape,
-            v_shape,
-        )
-
-        if use_upstream:
-            bsz, seqlen, nheads_q, d = q_shape
-            bsz_k, seqlen_k, nheads_k, d_k = k_shape
-            bsz_v, seqlen_v, nheads_v, d_v = v_shape
-            q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
-            k_ = key.contiguous().reshape(bsz * seqlen, nheads_k, d)
-            v_ = value.contiguous().reshape(bsz * seqlen, nheads_v, d)
-            cu = _get_cu_seqlens(q_.device.index, bsz, seqlen)
-            out = flash_attn_varlen_func_upstream(
-                q_,
-                k_,
-                v_,
-                cu,
-                cu,
-                seqlen,
-                seqlen,
-                softmax_scale=self.softmax_scale,
-                causal=self.causal,
-            )
-            return out.reshape(bsz, seqlen, nheads_q, d)
-
-        output = flash_attn_func(
-            q=query,  # type: ignore[no-untyped-call]
-            k=key,
-            v=value,
-            cu_seqlens_q=None,
-            cu_seqlens_k=None,
-            max_seqlen_q=max_seqlen_q,
-            max_seqlen_k=max_seqlen_k,
+        bsz, seqlen, nheads_q, d = query.shape
+        bsz_k, seqlen_k, nheads_k, d_k = key.shape
+
+        q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
+        k_ = key.contiguous().reshape(bsz * seqlen_k, nheads_k, d_k)
+        v_ = value.contiguous().reshape(bsz * seqlen_k, nheads_k, value.shape[-1])
+
+        cu_seqlens_q = _get_cu_seqlens(q_.device.index, bsz, seqlen)
+        cu_seqlens_k = _get_cu_seqlens(k_.device.index, bsz, seqlen_k)
+
+        out = flash_attn_varlen_func(
+            q_,
+            k_,
+            v_,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqlen,
+            seqlen_k,
             softmax_scale=self.softmax_scale,
             causal=self.causal,
-            return_softmax_lse=return_softmax_lse,
-            ver=fa_ver,
+            return_attn_probs=return_softmax_lse,
         )
-        return output
+
+        if return_softmax_lse:
+            out, softmax_lse = out
+            return out.reshape(bsz, seqlen, nheads_q, -1), softmax_lse
+        return out.reshape(bsz, seqlen, nheads_q, -1)
diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py
index 62a1974adc4e..c53d022de0c6 100644
--- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py
+++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py
@@ -4,6 +4,14 @@

 import torch

+try:
+    from flash_attn_interface import flash_attn_varlen_func
+except ImportError as e:
+    raise ImportError(
+        "flash-attention library is required. Please install it with: "
+        "pip install flash-attn --no-build-isolation"
+    ) from e
+
 from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
     AttentionBackend,
     AttentionImpl,
@@ -11,7 +19,7 @@
     AttentionMetadataBuilder,
 )
 from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (
-    flash_attn_func,
+    _get_cu_seqlens,
 )
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -64,16 +72,33 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        *,
+        return_softmax_lse: bool = False,
     ):
-        output = flash_attn_func(
-            q=query,  # type: ignore[no-untyped-call]
-            k=key,
-            v=value,
-            cu_seqlens_q=None,
-            cu_seqlens_k=None,
-            max_seqlen_q=None,
-            max_seqlen_k=None,
+        bsz, seqlen, nheads_q, d = query.shape
+        bsz_k, seqlen_k, nheads_k, d_k = key.shape
+
+        q_ = query.contiguous().reshape(bsz * seqlen, nheads_q, d)
+        k_ = key.contiguous().reshape(bsz * seqlen_k, nheads_k, d_k)
+        v_ = value.contiguous().reshape(bsz * seqlen_k, nheads_k, value.shape[-1])
+
+        cu_seqlens_q = _get_cu_seqlens(q_.device.index, bsz, seqlen)
+        cu_seqlens_k = _get_cu_seqlens(k_.device.index, bsz, seqlen_k)
+
+        out = flash_attn_varlen_func(
+            q_,
+            k_,
+            v_,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqlen,
+            seqlen_k,
             softmax_scale=self.softmax_scale,
             causal=self.causal,
+            return_attn_probs=return_softmax_lse,
         )
-        return output
+
+        if return_softmax_lse:
+            out, softmax_lse = out
+            return out.reshape(bsz, seqlen, nheads_q, -1), softmax_lse
+        return out.reshape(bsz, seqlen, nheads_q, -1)
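Note for reviewers: both backends now funnel batched `(bsz, seqlen, nheads, head_dim)` tensors through the packed layout that `flash_attn_varlen_func` expects. Below is a minimal standalone sketch of that packing step. The `make_cu_seqlens` helper and the concrete shapes are illustrative only: the body of `_get_cu_seqlens` is not shown in this diff, so the `arange` construction is an assumption about what it computes for equal-length sequences.

```python
import torch


def make_cu_seqlens(bsz: int, seqlen: int, device: torch.device) -> torch.Tensor:
    # Hypothetical stand-in for _get_cu_seqlens: with equal-length sequences,
    # the cumulative offsets are [0, seqlen, 2*seqlen, ..., bsz*seqlen].
    return torch.arange(
        0, (bsz + 1) * seqlen, seqlen, dtype=torch.int32, device=device
    )


# Illustrative shapes: batch 2, sequence length 4, 8 query heads,
# 2 KV heads (GQA), head dim 64.
q = torch.randn(2, 4, 8, 64)
k = torch.randn(2, 4, 2, 64)
v = torch.randn(2, 4, 2, 64)

# The varlen kernel takes packed (total_tokens, nheads, head_dim) tensors
# plus cumulative sequence offsets instead of an explicit batch dimension.
q_packed = q.contiguous().reshape(2 * 4, 8, 64)
k_packed = k.contiguous().reshape(2 * 4, 2, 64)
v_packed = v.contiguous().reshape(2 * 4, 2, 64)
cu_seqlens = make_cu_seqlens(2, 4, q.device)  # tensor([0, 4, 8], dtype=torch.int32)
```

The output comes back in the same packed layout and is reshaped to `(bsz, seqlen, nheads_q, -1)`; the `-1` lets the value head dim differ from the query/key head dim, matching the `value.shape[-1]` handling in both `forward()` implementations.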