Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions vllm/model_executor/layers/fla/ops/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@

from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_CHUNK_SIZE, check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
from .utils import FLA_CHUNK_SIZE, check_shared_mem

# BKV and NUM_WARPS are tile dimensions / warp counts for the Triton
# autotuner. The autotuner compiles each combination and skips any that
# exceed available shared memory, so offering a generous range is safe
# and lets the autotuner find the best fit for each GPU.
BKV_LIST = [32, 64, 128]
NUM_WARPS = [2, 4, 8]


@triton.heuristics(
Expand Down
35 changes: 27 additions & 8 deletions vllm/model_executor/layers/fla/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0)
or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)


def get_all_max_shared_mem():
Expand All @@ -172,6 +164,33 @@ def get_all_max_shared_mem():
return [-1]


# TMA code paths in FLA require significant shared memory for the Triton
# autotuner to compile tile configurations. SM12x GPUs (RTX 5090/5080,
# DGX Spark GB10) have TMA hardware but only ~101KB SMEM per SM, which is
# insufficient and causes OOM in fla/solve_tril. Gate on SMEM capacity so
# future SM12x variants with datacenter-class SMEM get TMA automatically.
MIN_SMEM_FOR_TMA = 131072 # 128KB


def check_tma_supported(max_shared_mem: int) -> bool:
    """Decide whether the TMA (tensor memory accelerator) code paths are usable.

    The gate is a shared-memory capacity threshold plus a probe for Triton's
    tensor-descriptor API, rather than an architecture-family check, so any
    GPU with enough SMEM qualifies regardless of arch.

    NOTE(review): there is deliberately no SM90+ capability check here, so a
    pre-Hopper part with large SMEM (e.g. A100, ~164KB per SM) also clears
    the threshold — presumably Triton's descriptor path degrades gracefully
    on such devices; confirm against the kernels consuming this flag.
    """
    # Guard first: without enough SMEM the TMA tile configs OOM regardless
    # of API availability.
    if max_shared_mem < MIN_SMEM_FOR_TMA:
        return False
    tl = triton.language
    descriptor_apis = (
        "_experimental_make_tensor_descriptor",
        "make_tensor_descriptor",
    )
    return any(hasattr(tl, api) for api in descriptor_apis)


# Module-level flag consumed by the FLA kernels: TMA paths are enabled only
# on NVIDIA devices whose first GPU clears the shared-memory threshold and
# whose Triton build exposes a tensor-descriptor API.
is_tma_supported = is_nvidia and check_tma_supported(get_all_max_shared_mem()[0])


class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
Expand Down
Loading