From 1adc2a482259152f313499ea0602ee9a2b024212 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 20 Feb 2026 01:15:49 +0000 Subject: [PATCH 1/2] fix triton kernels Signed-off-by: Varun Sundar Rabindranath --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index e03ecd01ae79..dc1cb3933caa 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -71,6 +71,8 @@ @triton.jit def pack_bitmatrix( bitmatrix, + bm_row_stride, + bm_col_stride, topk_ids, n_rows, # n_rows in bitmatrix / topk_ids bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix @@ -92,6 +94,7 @@ def pack_bitmatrix( div = indices // 32 rem = indices % 32 one = tl.cast(1, tl.uint32) + zero = tl.cast(0, tl.uint32) # Iterate through all the relevant bitmatrix columns. for i in range(bm_cols): @@ -99,12 +102,11 @@ def pack_bitmatrix( offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) # All topks that need to go into this column has the correct bit set. # Other bits are 0. x is a 2D tensor. - x = tl.where( - div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0 - ) + is_valid_expert = (div[:, :, None] == offs[None, None, :]) & (indices[:, :, None] >= 0) + x = tl.where(is_valid_expert, (one << rem)[:, :, None], zero) # Reduce x to get a single int32_t bitpack. y = tl.reduce_or(x, axis=1) - bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_cols + offs[None, :] + bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_row_stride + offs[None, :] * bm_col_stride tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) @@ -506,14 +508,15 @@ def make_routing_data( BLOCK_SIZE_M = 512 BLOCK_SIZE_K = 32 - bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks - bitmatrix = torch.zeros( - (n_rows, bm_cols), dtype=torch.uint32, device=topk_ids.device - ) + bm_cols = triton.cdiv(num_local_experts, 32) # n_bitpacks + bitmatrix = torch.zeros((bm_cols, triton.cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=topk_ids.device) + bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows] grid = (triton.cdiv(n_rows, BLOCK_SIZE_M),) pack_bitmatrix[grid]( bitmatrix, + bitmatrix.stride(0), + bitmatrix.stride(1), topk_ids, n_rows, bm_cols, @@ -525,9 +528,7 @@ def make_routing_data( bitmatrix_shape = [n_rows, bm_cols * 32] bitmatrix_shape_max = [n_rows, None] bitmatrix = ( - Bitmatrix( - bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max - ) + Bitmatrix(bitmatrix, dtype=BIT, shape=bitmatrix_shape) if not use_legacy_triton_kernels else Bitmatrix( bitmatrix, From 6ee270e48d4b0d84250c2ea41e93384bda45af64 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 2 Mar 2026 20:11:21 +0000 Subject: [PATCH 2/2] fix lint Signed-off-by: Varun Sundar Rabindranath --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index dc1cb3933caa..9ee6a6f42fb6 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -94,7 +94,7 @@ def pack_bitmatrix( div = indices // 32 rem = indices % 32 one = tl.cast(1, tl.uint32) - zero = tl.cast(0, tl.uint32) + zero = tl.cast(0, tl.uint32) # Iterate through all the relevant bitmatrix columns. for i in range(bm_cols): @@ -102,11 +102,17 @@ def pack_bitmatrix( offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) # All topks that need to go into this column has the correct bit set. # Other bits are 0. x is a 2D tensor. - is_valid_expert = (div[:, :, None] == offs[None, None, :]) & (indices[:, :, None] >= 0) + is_valid_expert = (div[:, :, None] == offs[None, None, :]) & ( + indices[:, :, None] >= 0 + ) x = tl.where(is_valid_expert, (one << rem)[:, :, None], zero) # Reduce x to get a single int32_t bitpack. y = tl.reduce_or(x, axis=1) - bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_row_stride + offs[None, :] * bm_col_stride + bitmatrix_ptrs = ( + bitmatrix + + offsets_m[:, None] * bm_row_stride + + offs[None, :] * bm_col_stride + ) tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) @@ -509,7 +515,11 @@ def make_routing_data( BLOCK_SIZE_K = 32 bm_cols = triton.cdiv(num_local_experts, 32) # n_bitpacks - bitmatrix = torch.zeros((bm_cols, triton.cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=topk_ids.device) + bitmatrix = torch.zeros( + (bm_cols, triton.cdiv(n_rows, 32) * 32), + dtype=torch.uint32, + device=topk_ids.device, + ) bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows] grid = (triton.cdiv(n_rows, BLOCK_SIZE_M),)