diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 18e17a5110c1..398a1811a2ad 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -152,10 +152,13 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: ) use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" is_gather_supported = hasattr(triton.language, "gather") -is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + +is_tma_supported = ( + is_nvidia and 9 <= torch.cuda.get_device_capability(0)[0] < 12 +) and ( hasattr(triton.language, "_experimental_make_tensor_descriptor") or hasattr(triton.language, "make_tensor_descriptor") -) +) # Upper bound < 12 disables TMA on Blackwell (sm_12x), where the Triton autotuner OOMs def get_all_max_shared_mem():