diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 3801814d95f3..eafdf97a9575 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -19,11 +19,14 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
 )
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
 logger = init_logger(__name__)
 
+use_legacy_triton_kernels = False
+
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
@@ -38,10 +41,20 @@
         from triton_kernels.tensor import (
             BIT,
             Bitmatrix,
-            SparseMatrix,
-            make_ragged_tensor_metadata,
         )
         from triton_kernels.topk import topk
+
+        try:
+            from triton_kernels.tensor import (
+                SparseMatrix,
+                make_ragged_tensor_metadata,
+            )
+        except ImportError:
+            if current_platform.is_rocm():
+                logger.warning_once("Using legacy triton_kernels on ROCm")
+                use_legacy_triton_kernels = True
+            else:
+                raise
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
     Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
     Creates routing data from a bitmatrix representation.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing_from_bitmatrix
+
+        return routing_from_bitmatrix(
+            bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+        )
     sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
     dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
     combine_indx = sparse_logits.mask_metadata.col_sorted_indx
@@ -130,6 +149,10 @@ def legacy_routing(
     Replacement for the removed triton_kernels.routing.routing function.
     Computes routing data from gating logits.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing
+
+        return routing(logits, n_expts_act, sm_first=sm_first)
     if sm_first:
         logits = torch.softmax(logits, dim=-1)
     sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
     )
     output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
 
-    act = FusedActivation(
-        FnSpecs(
-            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
-        ),
-        (swiglu_alpha, swiglu_limit),
+    act = (
+        FusedActivation(
+            FnSpecs(
+                "swiglu",
+                triton_kernels.swiglu.swiglu_fn,
+                ("alpha", "limit"),
+                reduction_n=2,
+            ),
+            (swiglu_alpha, swiglu_limit),
+        )
+        if not use_legacy_triton_kernels
+        else FusedActivation(
+            FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+            (swiglu_alpha, swiglu_limit),
+            2,
+        )
     )
 
     gammas = routing_data.gate_scal if routing_data else None
@@ -296,8 +330,17 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
 
-    bitmatrix = Bitmatrix(
-        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+    bitmatrix = (
+        Bitmatrix(
+            bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+        )
+        if not use_legacy_triton_kernels
+        else Bitmatrix(
+            bitmatrix,
+            shape=bitmatrix_shape,
+            shape_max=bitmatrix_shape_max,
+            scratchpad=None,
+        )
    )
 
     # matmul_ogs expects invalid topk_weights to be -1s
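
The patch pivots on a single module-level feature probe: if the installed `triton_kernels` predates the tensor refactor and lacks `SparseMatrix`/`make_ragged_tensor_metadata`, `use_legacy_triton_kernels` flips to `True` and every call site (routing, the fused SwiGLU activation, `Bitmatrix` construction) branches back to the original upstream entry points. A minimal standalone sketch of that probe-and-fallback shape is below for reference: the `triton_kernels` imports and call signatures mirror the diff, while `route_tokens` and the surrounding scaffolding are illustrative only, and the patch's ROCm-only gate (`current_platform.is_rocm()`) is elided.

```python
# Minimal sketch of the probe-and-fallback pattern the patch uses.
# The triton_kernels names mirror the diff; route_tokens is illustrative
# only, and the patch's ROCm-only gate is elided for brevity.
import torch

use_legacy_triton_kernels = False
try:
    # Present only in post-refactor triton_kernels builds.
    from triton_kernels.tensor import SparseMatrix  # noqa: F401
except ImportError:
    # Pre-refactor builds (e.g. current ROCm wheels) lack SparseMatrix but
    # still ship the original triton_kernels.routing entry points.
    use_legacy_triton_kernels = True


def route_tokens(logits: torch.Tensor, n_expts_act: int, sm_first: bool = False):
    """Pick whichever routing implementation the installed wheel supports."""
    if use_legacy_triton_kernels:
        # Legacy path: defer the import so it only runs when the old API exists.
        from triton_kernels.routing import routing

        return routing(logits, n_expts_act, sm_first=sm_first)
    # New path: start from topk(); the full RoutingData reconstruction
    # (SparseMatrix, gather/scatter indices) lives in legacy_routing() above.
    from triton_kernels.topk import topk

    if sm_first:
        logits = torch.softmax(logits, dim=-1)
    return topk(logits, n_expts_act, apply_softmax=not sm_first)
```

Because the probe runs once at import time, the per-call branches reduce to a cheap flag check, and the fallback imports stay inside the legacy branches so they are never evaluated on installs where the old API has been removed.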