From ecfad707ba0c8c56c79747d13b46f5a056dd71cc Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Mon, 9 Feb 2026 17:25:10 +0000 Subject: [PATCH 1/3] Fall back to the old triton_kernels API in case of ROCm before we can support triton 3.6 Signed-off-by: Gregory Shtrasberg --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 3801814d95f3..1706c71d3c47 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -38,10 +39,11 @@ from triton_kernels.tensor import ( BIT, Bitmatrix, - SparseMatrix, - make_ragged_tensor_metadata, ) from triton_kernels.topk import topk + + if not current_platform.is_rocm(): + from triton_kernels.tensor import SparseMatrix, make_ragged_tensor_metadata except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -101,6 +103,12 @@ def legacy_routing_from_bitmatrix( Replacement for the removed triton_kernels.routing.routing_from_bitmatrix. Creates routing data from a bitmatrix representation. """ + if current_platform.is_rocm(): + from triton_kernels.routing import routing_from_bitmatrix + + return routing_from_bitmatrix( + bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act + ) sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix) dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx combine_indx = sparse_logits.mask_metadata.col_sorted_indx @@ -130,6 +138,10 @@ def legacy_routing( Replacement for the removed triton_kernels.routing.routing function. Computes routing data from gating logits. """ + if current_platform.is_rocm(): + from triton_kernels.routing import routing + + return routing(logits, n_expts_act, sm_first=sm_first) if sm_first: logits = torch.softmax(logits, dim=-1) sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first) @@ -231,11 +243,22 @@ def triton_kernel_fused_experts( ) output_tensor = _resize_cache(output_tensor, (batch_dim, M, K)) - act = FusedActivation( - FnSpecs( - "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2 - ), - (swiglu_alpha, swiglu_limit), + act = ( + FusedActivation( + FnSpecs( + "swiglu", + triton_kernels.swiglu.swiglu_fn, + ("alpha", "limit"), + reduction_n=2, + ), + (swiglu_alpha, swiglu_limit), + ) + if not current_platform.is_rocm() + else FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), + 2, + ) ) gammas = routing_data.gate_scal if routing_data else None @@ -296,8 +319,17 @@ def make_routing_data( bitmatrix_shape = [n_rows, bm_cols * 32] bitmatrix_shape_max = [n_rows, None] - bitmatrix = Bitmatrix( - bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max + bitmatrix = ( + Bitmatrix( + bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max + ) + if not current_platform.is_rocm() + else Bitmatrix( + bitmatrix, + shape=bitmatrix_shape, + shape_max=bitmatrix_shape_max, + scratchpad=None, + ) ) # matmul_ogs expects invalid topk_weights to be -1s From 9b243f939cfa0dcc495794c26795129c13885b1d Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Mon, 9 Feb 2026 17:33:21 +0000 Subject: [PATCH 2/3] Allow use of new API on ROCm when available Signed-off-by: Gregory Shtrasberg --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 1706c71d3c47..4e04e81ac118 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -25,6 +25,8 @@ logger = init_logger(__name__) +use_legacy_triton_kernels = False + if has_triton_kernels(): try: import triton_kernels.swiglu @@ -42,8 +44,15 @@ ) from triton_kernels.topk import topk - if not current_platform.is_rocm(): - from triton_kernels.tensor import SparseMatrix, make_ragged_tensor_metadata + if current_platform.is_rocm(): + try: + from triton_kernels.tensor import ( + SparseMatrix, + make_ragged_tensor_metadata, + ) + except ImportError: + logger.warning_once("Using legacy triton_kernels on ROCm") + use_legacy_triton_kernels = True except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -103,7 +112,7 @@ def legacy_routing_from_bitmatrix( Replacement for the removed triton_kernels.routing.routing_from_bitmatrix. Creates routing data from a bitmatrix representation. """ - if current_platform.is_rocm(): + if use_legacy_triton_kernels: from triton_kernels.routing import routing_from_bitmatrix return routing_from_bitmatrix( @@ -138,7 +147,7 @@ def legacy_routing( Replacement for the removed triton_kernels.routing.routing function. Computes routing data from gating logits. """ - if current_platform.is_rocm(): + if use_legacy_triton_kernels: from triton_kernels.routing import routing return routing(logits, n_expts_act, sm_first=sm_first) @@ -253,7 +262,7 @@ def triton_kernel_fused_experts( ), (swiglu_alpha, swiglu_limit), ) - if not current_platform.is_rocm() + if not use_legacy_triton_kernels else FusedActivation( FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit), @@ -323,7 +332,7 @@ def make_routing_data( Bitmatrix( bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max ) - if not current_platform.is_rocm() + if not use_legacy_triton_kernels else Bitmatrix( bitmatrix, shape=bitmatrix_shape, From db430879eb8f6190f7ee3fa10261670de13ee2d1 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Mon, 9 Feb 2026 17:47:18 +0000 Subject: [PATCH 3/3] Fix import to happen on all platforms Signed-off-by: Gregory Shtrasberg --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 4e04e81ac118..eafdf97a9575 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -44,15 +44,17 @@ ) from triton_kernels.topk import topk - if current_platform.is_rocm(): - try: - from triton_kernels.tensor import ( - SparseMatrix, - make_ragged_tensor_metadata, - ) - except ImportError: + try: + from triton_kernels.tensor import ( + SparseMatrix, + make_ragged_tensor_metadata, + ) + except ImportError: + if current_platform.is_rocm(): logger.warning_once("Using legacy triton_kernels on ROCm") use_legacy_triton_kernels = True + else: + raise except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton "