20 changes: 19 additions & 1 deletion docker/Dockerfile
@@ -25,6 +25,7 @@
ARG CUDA_VERSION=13.0.2
ARG PYTHON_VERSION=3.12
ARG UBUNTU_VERSION=22.04
ARG FLASHINFER_VERSION=0.6.8.post1

# By parameterizing the base images, we allow third-party to use their own
# base images. One use case is hermetic builds with base images stored in
@@ -101,6 +102,7 @@ FROM ${BUILD_BASE_IMAGE} AS base

ARG CUDA_VERSION
ARG PYTHON_VERSION
ARG FLASHINFER_VERSION
ARG BUILD_OS

ENV DEBIAN_FRONTEND=noninteractive
@@ -212,6 +214,14 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
fi

# `flashinfer-python` is already installed via requirements/cuda.txt above;
# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
uv pip install --python /opt/venv/bin/python3 \
"flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
fi

# Track PyTorch lib versions used during build and match in downstream instances.
# We do this for both nightly and release so we can strip dependencies/*.txt as needed.
# Otherwise library dependencies can upgrade/downgrade torch incorrectly.
@@ -522,6 +532,7 @@ ARG PYTHON_VERSION
ARG DEADSNAKES_MIRROR_URL
ARG DEADSNAKES_GPGKEY_URL
ARG GET_PIP_URL
ARG FLASHINFER_VERSION

ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /vllm-workspace
@@ -620,10 +631,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
rm /tmp/requirements-cuda.txt /tmp/common.txt

# `flashinfer-python` is already installed via requirements/cuda.txt above;
# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
uv pip install --system \
"flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
fi

# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
ARG FLASHINFER_VERSION=0.6.8.post1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
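
The RUN lines above derive two values from CUDA_VERSION with shell one-liners that are easy to misread. A minimal Python sketch of the same transformations (illustrative only, not part of the Dockerfile):

# `${CUDA_VERSION%%.*}` keeps the major version: "13.0.2" -> "13".
# `cut -d. -f1,2 | tr -d '.'` builds the cuXY wheel-index suffix: "13.0.2" -> "130".
def cuda_major(cuda_version: str) -> str:
    return cuda_version.split(".", 1)[0]

def cuda_index_suffix(cuda_version: str) -> str:
    major, minor = cuda_version.split(".")[:2]
    return f"cu{major}{minor}"

assert cuda_major("13.0.2") == "13"
assert cuda_index_suffix("13.0.2") == "cu130"
assert cuda_index_suffix("12.8.1") == "cu128"

This is why the `[cu13]` extra is gated on the major version being "13", while the flashinfer-jit-cache index URL for this build ends in cu130.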
6 changes: 3 additions & 3 deletions docker/versions.json
@@ -10,6 +10,9 @@
"UBUNTU_VERSION": {
"default": "22.04"
},
"FLASHINFER_VERSION": {
"default": "0.6.8.post1"
},
"BUILD_BASE_IMAGE": {
"default": "nvidia/cuda:13.0.2-devel-ubuntu22.04"
},
@@ -67,9 +70,6 @@
"RUN_WHEEL_CHECK": {
"default": "true"
},
"FLASHINFER_VERSION": {
"default": "0.6.8.post1"
},
"GDRCOPY_CUDA_VERSION": {
"default": "12.8"
},
8 changes: 8 additions & 0 deletions setup.py
@@ -969,6 +969,14 @@ def _read_requirements(filename: str) -> list[str]:
# vllm-flash-attn is built only for CUDA 12.x.
# Skip for other versions.
continue
if req.startswith("flashinfer-python") and cuda_major == "13":
# Activate FI's `[cu13]` extra on cu13 builds (cu13 deps for
# the SM100 GDN kernel). Mirrors the Dockerfile cu13 path.
req = req.replace(
"flashinfer-python",
"flashinfer-python[cu13]",
1,
)
modified_requirements.append(req)
requirements = modified_requirements
elif _is_hip():
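
The net effect of this change on a requirement string can be shown with a short sketch, assuming a CUDA 13 build and a requirements entry of the form pinned in this PR (the exact pin in requirements/cuda.txt may differ):

# Illustrative only; mirrors the replace() call added to setup.py above.
req = "flashinfer-python==0.6.8.post1"
cuda_major = "13"
if req.startswith("flashinfer-python") and cuda_major == "13":
    # Only the first occurrence of the distribution name is replaced,
    # so any version specifier after it is preserved.
    req = req.replace("flashinfer-python", "flashinfer-python[cu13]", 1)
assert req == "flashinfer-python[cu13]==0.6.8.post1"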
91 changes: 64 additions & 27 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,6 +67,61 @@
logger = init_logger(__name__)


def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) -> bool:
"""Whether to use FlashInfer's GDN prefill kernel instead of the
Triton/FLA fallback.

Requirements:
* ``requested in ["flashinfer", "auto"]``;
* ``platform == cuda``;
* one of the following:
- Hopper (SM90) — no further constraints;
- Blackwell (SM10.x) with ``head_k_dim == 128`` and ``cuda_runtime >= 13``.
"""
if backend not in ["flashinfer", "auto"]:
return False
if not current_platform.is_cuda():
return False
if current_platform.is_device_capability(90):
return True # Hopper — no further constraints.
if not current_platform.is_device_capability_family(100):
return False # Neither Hopper nor Blackwell.
if head_k_dim != 128:
return False
return current_platform.get_cuda_runtime_major() >= 13


def _log_gdn_backend_decision(
backend: str, head_k_dim: int | None, use_flashinfer: bool
) -> None:
"""Dump the inputs to the backend decision and the final choice."""
is_cuda = current_platform.is_cuda()
platform = "cuda" if is_cuda else current_platform.device_name
cuda_runtime = torch.version.cuda or "n/a"
device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
logger.info_once(
"GDN prefill backend inputs:\n"
" requested=%s\n"
" platform=%s, cuda_runtime=%s, device_capability=%s\n"
" head_k_dim=%s",
backend,
platform,
cuda_runtime,
device_cap,
head_k_dim,
scope="local",
)
if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time.",
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")


def fi_chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
@@ -118,39 +173,21 @@ def fi_chunk_gated_delta_rule(

@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
def __init__(self, head_k_dim: int | None = None) -> None:
super().__init__()
additional_config = get_current_vllm_config().additional_config
assert isinstance(additional_config, dict)
backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
backend = str(backend_cfg).strip().lower()

supports_flashinfer = (
current_platform.is_cuda() and current_platform.is_device_capability(90)
)

if backend == "flashinfer":
use_flashinfer = supports_flashinfer
if not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
elif backend == "triton":
use_flashinfer = False
else:
use_flashinfer = supports_flashinfer

if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time.",
use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
if backend == "flashinfer" and not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")
_log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)

self._forward_method = (
self.forward_cuda if use_flashinfer else self.forward_native
Comment from Collaborator:

My preference would be to place the `if use_flashinfer` check inside forward_cuda:

if use_flashinfer:
    return fi_chunk_gated_delta_rule()
else:
    return forward_native()

Reply from Contributor Author:

Are you sure? Currently we set self._forward_method only once in __init__. Moving this if statement inside forward_cuda means we would evaluate it on every call at runtime.
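
For context on the trade-off discussed in this thread, a toy sketch (not the actual vLLM class) of binding the forward method once at construction instead of branching on every call:

class Dispatcher:
    def __init__(self, use_flashinfer: bool) -> None:
        # Branch once here; forward() never re-checks the flag.
        self._forward_method = (
            self._forward_cuda if use_flashinfer else self._forward_native
        )

    def forward(self, x: float) -> float:
        return self._forward_method(x)

    def _forward_cuda(self, x: float) -> float:
        return x * 2.0  # stands in for the FlashInfer path

    def _forward_native(self, x: float) -> float:
        return x + 1.0  # stands in for the Triton/FLA path

assert Dispatcher(use_flashinfer=True).forward(3.0) == 6.0
assert Dispatcher(use_flashinfer=False).forward(3.0) == 4.0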

@@ -380,7 +417,7 @@ def __init__(
prefix=f"{prefix}.out_proj",
)

self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.chunk_gated_delta_rule = ChunkGatedDeltaRule(head_k_dim=self.head_k_dim)
self.enable_packed_recurrent_decode = (
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)
6 changes: 6 additions & 0 deletions vllm/platforms/interface.py
@@ -362,6 +362,12 @@ def is_device_capability_family(
return False
return (current_capability.to_int() // 10) == (capability // 10)

@classmethod
def get_cuda_runtime_major(cls) -> int:
"""Major ``torch.version.cuda`` version, or ``0`` if undetermined."""
major = (torch.version.cuda or "0").split(".", 1)[0]
return int(major) if major.isdigit() else 0

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
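
A standalone sketch of the parsing done by get_cuda_runtime_major(), with torch.version.cuda passed in explicitly (illustrative only):

def cuda_runtime_major(version: str | None) -> int:
    major = (version or "0").split(".", 1)[0]
    return int(major) if major.isdigit() else 0

assert cuda_runtime_major("13.0") == 13
assert cuda_runtime_major("12.8") == 12
assert cuda_runtime_major(None) == 0  # undetermined -> 0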