diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py
index d812ec433720..32cd1dade96a 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_o.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_o.py
@@ -16,10 +16,14 @@
 
 from .index import prepare_chunk_indices
 from .op import exp
-from .utils import FLA_CHUNK_SIZE, check_shared_mem, is_nvidia_hopper
-
-BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
-NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
+from .utils import FLA_CHUNK_SIZE, check_shared_mem
+
+# BKV and NUM_WARPS are tile dimensions / warp counts for the Triton
+# autotuner. The autotuner compiles each combination and skips any that
+# exceed available shared memory, so offering a generous range is safe
+# and lets the autotuner find the best fit for each GPU.
+BKV_LIST = [32, 64, 128]
+NUM_WARPS = [2, 4, 8]
 
 
 @triton.heuristics(
diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py
index 83b75e6853d1..3f6dd0e8a38d 100644
--- a/vllm/model_executor/layers/fla/ops/utils.py
+++ b/vllm/model_executor/layers/fla/ops/utils.py
@@ -148,16 +148,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
 is_intel = device_platform == "intel"
 is_nvidia = device_platform == "nvidia"
 is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
-is_nvidia_hopper = is_nvidia and (
-    "NVIDIA H" in torch.cuda.get_device_name(0)
-    or torch.cuda.get_device_capability()[0] >= 9
-)
 use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
 is_gather_supported = hasattr(triton.language, "gather")
-is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
-    hasattr(triton.language, "_experimental_make_tensor_descriptor")
-    or hasattr(triton.language, "make_tensor_descriptor")
-)
 
 
 def get_all_max_shared_mem():
@@ -172,6 +164,33 @@ def get_all_max_shared_mem():
     return [-1]
 
 
+# TMA code paths in FLA require significant shared memory for the Triton
+# autotuner to compile tile configurations. SM12x GPUs (RTX 5090/5080,
+# DGX Spark GB10) have TMA hardware but only ~101KB SMEM per SM, which is
+# insufficient and causes OOM in fla/solve_tril. Gate on SMEM capacity so
+# future SM12x variants with datacenter-class SMEM get TMA automatically.
+MIN_SMEM_FOR_TMA = 131072  # 128KB
+
+
+def check_tma_supported(max_shared_mem: int) -> bool:
+    """Check if TMA code paths should be enabled.
+
+    Uses a shared-memory threshold rather than architecture checks so that
+    any GPU with sufficient SMEM gets TMA support regardless of arch family.
+    """
+    if max_shared_mem < MIN_SMEM_FOR_TMA:
+        return False
+    return (
+        hasattr(triton.language, "_experimental_make_tensor_descriptor")
+        or hasattr(triton.language, "make_tensor_descriptor")
+    )
+
+
+is_tma_supported = is_nvidia and check_tma_supported(
+    get_all_max_shared_mem()[0]
+)
+
+
 class Backend(Enum):
     ADA = 101376  # RTX 4090
     AMPERE = 166912  # A100