diff --git a/docker/Dockerfile b/docker/Dockerfile index d4ecf96b1485..227f4a3355c8 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -134,8 +134,8 @@ ENV UV_LINK_MODE=copy # Verify GCC version RUN gcc --version -# Ensure CUDA compatibility library is loaded -RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/00-cuda-compat.conf && ldconfig +# Workaround for triton/pytorch issues +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # ============================================================ # SLOW-CHANGING DEPENDENCIES BELOW @@ -423,6 +423,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ libsm6 \ libxext6 \ libgl1 \ + git \ && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \ if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \ mkdir -p -m 0755 /etc/apt/keyrings ; \ @@ -473,8 +474,8 @@ ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" ENV UV_LINK_MODE=copy -# Ensure CUDA compatibility library is loaded -RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/00-cuda-compat.conf && ldconfig +# Workaround for triton/pytorch issues +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ # ============================================================ # SLOW-CHANGING DEPENDENCIES BELOW diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 1417fb99120b..380bbc30e3d1 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -11,3 +11,5 @@ torchaudio==2.9.1 torchvision==0.24.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.3 +# FA4 +flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute \ No newline at end of file diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 293045787a1c..8f7f20ef133a 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -67,7 +67,17 @@ def compute_hash(self) -> str: def validate_backend_before(cls, value: Any) -> Any: """Enable parsing of the `backend` enum type from string.""" if isinstance(value, str): - return AttentionBackendEnum[value.upper()] + value = AttentionBackendEnum[value.upper()] + + # Disallow ViT-only attention tags in the KV-cache attention config. + if value == AttentionBackendEnum.FLASH_ATTN_CUTE: + raise ValueError( + "AttentionConfig.backend does not support FLASH_ATTN_CUTE " + "(FA4 / flash_attn.cute). This is a ViT/MM-encoder-only attention " + "tag. Use --mm-encoder-attn-backend / " + "MultiModalConfig.mm_encoder_attn_backend instead." 
    def _forward_fa4(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # 0-dim tensor, forwarded to the FA4 varlen path
    ) -> torch.Tensor:
        """FA4 (flash_attn.cute) attention for multimodal encoder (no KV cache).

        Args:
            query/key/value: 4-D ``(batch, seq, heads, head_dim)`` tensors, or a
                flattened non-4-D layout (reshaped to 4-D via
                ``maybe_reshape_qkv_to_4d`` and flattened back on return).
            cu_seqlens: cumulative sequence lengths for varlen attention; must be
                set together with ``max_seqlen``, or both left ``None``.
            max_seqlen: maximum sequence length (0-dim tensor).

        Returns:
            Attention output in the same leading layout as ``query``.
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        # Remember whether the caller passed a flattened (non-4D) layout so the
        # output can be returned in the caller's original shape.
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        # Custom-op wrapper around flash_attn.cute's varlen forward (FA4).
        output = vit_fa4_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            # Flatten heads back: (b, s, h, d) -> (b, s, h*d).
            output = output.reshape(bsz, q_len, -1)
        return output
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.ROCM_AITER_FA, }: @@ -785,6 +786,7 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASH_ATTN_CUTE, AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 56c3db49ed77..44b45a08dc1e 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -396,6 +396,7 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASH_ATTN_CUTE, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.ROCM_AITER_FA, }: @@ -538,6 +539,7 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() diff --git a/vllm/model_executor/warmup/fa4_warmup.py b/vllm/model_executor/warmup/fa4_warmup.py new file mode 100644 index 000000000000..a9b0bcd61459 --- /dev/null +++ b/vllm/model_executor/warmup/fa4_warmup.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Warmup FA4 (flash_attn.cute) kernels for ViT/MM encoder attention. + +We specifically warm up the FlashAttention Cute-DSL (FA4) compile cache by +running a few representative varlen attention calls that differ only in +sequence length. This helps avoid JIT compilation in the hot path. 
def should_fa4_vit_warmup(worker: Worker) -> bool:
    """Cheap gate for FA4 ViT warmup.

    True only when all of the following hold: running on CUDA, the device is
    Blackwell (compute capability 10.x), and the multimodal encoder attention
    backend was explicitly set to FLASH_ATTN_CUTE.
    """
    if not current_platform.is_cuda():
        return False

    capability = current_platform.get_device_capability()
    if capability is None or capability.major != 10:
        return False

    mm_config = getattr(worker.model_config, "multimodal_config", None)
    if mm_config is None:
        return False
    return mm_config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE
def fa4_vit_warmup(worker: Worker) -> None:
    """Warm up FA4 kernels for Qwen3-VL(-MoE) ViT attention.

    Best-effort: every failure mode (wrong config, missing dependency, model
    without a vision tower, unexpected module structure, unsupported dtype)
    results in an early return, never an exception.
    """

    # Config gating: only warm up when explicitly selected for mm encoder.
    if not should_fa4_vit_warmup(worker):
        return

    # Dependency gating. Imported lazily so this module stays importable when
    # the FA4 extras are absent.
    from vllm.v1.attention.backends.fa4_utils import (
        is_flash_attn_cute_available,
        supports_dtype,
        warn_if_unoptimized_head_size,
    )

    if not is_flash_attn_cute_available():
        logger.warning(
            "Skipping FA4 warmup: `flash_attn.cute.interface` is not available."
        )
        return

    model = worker.get_model()
    visual = getattr(model, "visual", None)
    if visual is None:
        # Not a Qwen3-VL(-MoE) style model, or vision tower disabled.
        logger.warning("Skipping FA4 warmup: vision tower disabled or not found.")
        return

    # Derive head shape and dtype from the actual vision attention module so
    # warmup compiles kernels for the exact shapes used at inference time.
    try:
        first_attn = visual.blocks[0].attn  # Qwen2_5_VisionAttention
        head_size = int(first_attn.hidden_size_per_attention_head)
        num_heads = int(first_attn.num_attention_heads_per_partition)
        scale = float(first_attn.hidden_size_per_attention_head**-0.5)
        dtype = visual.dtype
    except Exception:
        # If the model structure is unexpected, skip warmup.
        return

    if not supports_dtype(dtype):
        # If dtype is not supported, the FA4 backend should not have been selected.
        logger.warning_once(
            "Skipping FA4 warmup: dtype %s is not supported by flash_attn.cute.",
            dtype,
        )
        return

    warn_if_unoptimized_head_size(head_size)

    seqlens = tuple(_get_default_qwen3_vit_warmup_seqlens())

    logger.info_once(
        "Warming up FA4 (flash_attn.cute) ViT kernels for seqlens=%s "
        "(head_size=%d, num_heads=%d, dtype=%s).",
        seqlens,
        head_size,
        num_heads,
        dtype,
    )

    # Run a small number of representative calls that only vary seqlen.
    # Compilation key can be found under `flash_attn/cute/interface.py`.
    from vllm.v1.attention.backends.fa4_utils import flash_attn_varlen_func

    device = torch.device("cuda")
    with torch.inference_mode():
        for seqlen in seqlens:
            # Contents are irrelevant for warmup; only shapes/dtypes matter.
            q = torch.empty((seqlen, num_heads, head_size), device=device, dtype=dtype)
            k = torch.empty_like(q)
            v = torch.empty_like(q)
            cu = torch.tensor([0, seqlen], device=device, dtype=torch.int32)

            # This call will populate FA4's internal compile cache (Cute-DSL).
            _ = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu,
                cu_seqlens_k=cu,
                max_seqlen_q=seqlen,
                max_seqlen_k=seqlen,
                softmax_scale=scale,
                causal=False,
            )
+ if backend == AttentionBackendEnum.FLASH_ATTN_CUTE: + if cc is None or cc.major != 10: + raise ValueError( + "FLASH_ATTN_CUTE (FA4 / flash_attn.cute) is only supported on " + "Blackwell GPUs (compute capability 10.x)." + ) + + from vllm.v1.attention.backends.fa4_utils import ( + is_flash_attn_cute_available, + warn_if_unoptimized_head_size, + ) + from vllm.v1.attention.backends.fa4_utils import ( + supports_dtype as fa4_supports_dtype, + ) + + if not fa4_supports_dtype(dtype): + raise ValueError( + "FLASH_ATTN_CUTE (FA4 / flash_attn.cute) only supports " + "float16/bfloat16 for ViT attention." + ) + + if not is_flash_attn_cute_available(): + raise ImportError( + "FLASH_ATTN_CUTE (FA4 / flash_attn.cute) selected, but " + "`flash_attn.cute.interface` is not available in this " + "environment." + ) + + warn_if_unoptimized_head_size(head_size) + logger.info_once(f"Using backend {backend} for vit attention") return backend diff --git a/vllm/v1/attention/backends/fa4_utils.py b/vllm/v1/attention/backends/fa4_utils.py new file mode 100644 index 000000000000..b1ce45e6bd4d --- /dev/null +++ b/vllm/v1/attention/backends/fa4_utils.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from importlib.util import find_spec + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +# flash_attn.cute.interface (Cute-DSL / FA4). +# +# NOTE: vLLM currently only enables this path for **Blackwell** GPUs +# (compute capability 10.x) and only for ViT/MM encoder attention. +# It is NOT a KV-cache attention backend. +_OPTIMIZED_HEAD_SIZES: tuple[int, ...] = (64, 96, 128, 192) + + +def warn_if_unoptimized_head_size(head_size: int) -> None: + """Warn if `head_size` is outside the known-optimized set. 
def supports_dtype(dtype: torch.dtype) -> bool:
    """Return True iff FA4 (flash_attn.cute) kernels support `dtype`.

    Only the half-precision floating types are supported.
    """
    return dtype is torch.float16 or dtype is torch.bfloat16
+ """ + if not current_platform.is_cuda(): + raise RuntimeError("FA4 (flash_attn.cute) is only supported on CUDA.") + + try: + from flash_attn.cute.interface import flash_attn_varlen_func as _fa4_varlen + except Exception as e: + raise ImportError( + "FA4 (flash_attn.cute) is not available. " + "Please ensure the Cute-DSL FlashAttention build is installed " + "(e.g. nvidia-cutlass-dsl) and cuda-python bindings are present." + ) from e + + out, _lse = _fa4_varlen( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + ) + return out diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index bd45702fa587..6bdf9691b402 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -42,6 +42,10 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + # FA4 (Cute-DSL) - this tag is only used for ViT (MM encoder) attention. + # NOTE: This backend does not implement the KV-cache attention path and + # should not be used with `--attention-config.backend`. 
def fa4_flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    batch_size: int,
    scale: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """FA4 (flash_attn.cute) wrapper for ViT attention.

    Inputs are (batch, seq, ...) tensors; they are flattened to the varlen
    (total_tokens, ...) layout expected by FA4 and the result is reshaped
    back to (batch, seq, heads, head_dim).

    flash_attn.cute returns (out, lse); we only return out.
    """
    # Lazy import keeps this module importable without the FA4 extras.
    from vllm.v1.attention.backends.fa4_utils import (
        flash_attn_varlen_func as fa4_flash_attn_varlen_func,
    )

    q_len = q.size(1)
    if cu_seqlens is None:
        # Dense case: every sequence in the batch has length q_len.
        cu_seqlens = torch.arange(
            0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
        )
    # NOTE(review): `.item()` forces a device->host sync when max_seqlen is a
    # GPU tensor — presumably acceptable on the encoder path; confirm.
    max_seqlen_int = q_len if max_seqlen is None else max_seqlen.item()

    # Flatten (b, s, ...) -> (b*s, ...) for the varlen kernel.
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = fa4_flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen_int,
        max_seqlen_k=max_seqlen_int,
        softmax_scale=scale,
        causal=False,
    )
    # Restore the (batch, seq, heads, head_dim) layout.
    context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
    return context_layer
-> (b s) ...") for x in [q, k, v]) + output = fa4_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_int, + max_seqlen_k=max_seqlen_int, + softmax_scale=scale, + causal=False, + ) + context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) + return context_layer + + +def fa4_flash_attn_maxseqlen_wrapper_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_size: int, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(q) + + +direct_register_custom_op( + op_name="fa4_flash_attn_maxseqlen_wrapper", + op_func=fa4_flash_attn_maxseqlen_wrapper, + fake_impl=fa4_flash_attn_maxseqlen_wrapper_fake, +) + + +def vit_fa4_flash_attn_wrapper( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_size: int, + scale: float | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.ops.vllm.fa4_flash_attn_maxseqlen_wrapper( + q, + k, + v, + batch_size, + scale, + cu_seqlens, + max_seqlen, + ) + + def apply_sdpa( q: torch.Tensor, k: torch.Tensor,