diff --git a/aiter/ops/triton/_triton_kernels/moe_routing/bitmatrix.py b/aiter/ops/triton/_triton_kernels/moe_routing/bitmatrix.py index de4ca4a6c4..90459af4f5 100644 --- a/aiter/ops/triton/_triton_kernels/moe_routing/bitmatrix.py +++ b/aiter/ops/triton/_triton_kernels/moe_routing/bitmatrix.py @@ -102,3 +102,33 @@ def _sum_bitmatrix_rows( ) # tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret) + + +@triton.jit +def _sum_bitmatrix_rows_fused( + B, + shape_bm, + stride_bm, + stride_bn, + Ret, + N_BLKS_BITMATRIX: tl.constexpr, + BLOCK_M: tl.constexpr, + EVEN_M: tl.constexpr, +): + if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr(): + shape_bm = tl.load(shape_bm) + for i in tl.static_range(N_BLKS_BITMATRIX): + offs_m = tl.arange(0, BLOCK_M) + offs_n = i * 32 + tl.arange(0, 32) + n_rows = shape_bm + if EVEN_M: + bits = tl.load(B + i * stride_bn + offs_m * stride_bm) + else: + bits = tl.load( + B + i * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0 + ) + bits = tl.reshape(bits, [1, BLOCK_M]) + ret = vpopc(bits) # [1, 32] + ret = tl.reshape(ret, [32]) + + tl.store(Ret + offs_n, ret) diff --git a/aiter/ops/triton/_triton_kernels/moe_routing/expt_data.py b/aiter/ops/triton/_triton_kernels/moe_routing/expt_data.py index 7c20843789..ff0b32ac70 100644 --- a/aiter/ops/triton/_triton_kernels/moe_routing/expt_data.py +++ b/aiter/ops/triton/_triton_kernels/moe_routing/expt_data.py @@ -81,3 +81,12 @@ def _expt_data_compute_stage2( data = (block_offs << 16) + expt_id tl.store(TileInfo + block_offs, data, mask=block_offs < n_blocks) block_offs += BLOCK + + +@triton.jit +def _expt_data_compute_stage2_fused(expt_id, Hist, TileStart, TileInfo): + n_tokens = tl.load(Hist + expt_id) + if n_tokens == 0: + return + TileInfo += tl.load(TileStart + expt_id) + tl.store(TileInfo, expt_id) diff --git a/aiter/ops/triton/_triton_kernels/moe_routing/routing.py b/aiter/ops/triton/_triton_kernels/moe_routing/routing.py index 902279b7a4..f7c2c856e1 100644 --- a/aiter/ops/triton/_triton_kernels/moe_routing/routing.py +++ b/aiter/ops/triton/_triton_kernels/moe_routing/routing.py @@ -4,6 +4,10 @@ from aiter.ops.triton._triton_kernels.moe_routing.expt_data import ( _expt_data_compute_stage1, _expt_data_compute_stage2, + _expt_data_compute_stage2_fused, +) +from aiter.ops.triton._triton_kernels.moe_routing.bitmatrix import ( + _sum_bitmatrix_rows_fused, ) @@ -86,6 +90,66 @@ def _routing_compute_indx( tl.store(GateScal + gates, gate_scal, mask=mask) +@triton.jit +def _routing_compute_indx_fused( + GatherIndx, + ScatterIndx, + GateScal, + ExptScal, + ExptIndx, + TokensStart, + n_gates, + BLOCK_M: tl.constexpr, + EVEN_M: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, +): + + tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768) + + local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M) + offs = local_offs + if EVEN_M: + expert = tl.load(ExptIndx + offs).to(tl.uint32) + else: + expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32) + + # stable-sort by expert ID: + kv_pairs = ((expert << 16) | local_offs).to(tl.uint32) + kv_pairs = tl.sort(kv_pairs, 0) + expert = kv_pairs >> 16 + offs = kv_pairs & 0xFFFF + + if EVEN_M: + gate_scal = tl.load(ExptScal + offs) + + # compute run lengths in expert-sorted order: + x = kv_pairs & 0xFFFF0000 | 0x00000001 + expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add) + exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF + + gates = tl.load(TokensStart + expert) + gates += exclusive_run_lengths + + tl.store(ScatterIndx + offs, gates) + tl.store(GatherIndx + gates, offs) + tl.store(GateScal + gates, gate_scal) + else: + mask = expert != 0xFFFF + gate_scal = tl.load(ExptScal + offs, mask=mask) + + # compute run lengths in expert-sorted order: + x = kv_pairs & 0xFFFF0000 | 0x00000001 + expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add) + exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF + + gates = tl.load(TokensStart + expert, mask=mask) + gates += exclusive_run_lengths + + tl.store(ScatterIndx + offs, gates, mask=mask) + tl.store(GatherIndx + gates, offs, mask=mask) + tl.store(GateScal + gates, gate_scal, mask=mask) + + @triton.jit def _combined_routing( GatherIndx, @@ -148,3 +212,80 @@ def _combined_routing( EVEN_M, N_EXPTS_ACT, ) + + +@triton.jit +def _combined_routing_fused( + GatherIndx, + ScatterIndx, + GateScal, + ExptScal, + ExptIndx, + Bitmatrix, + shape_bm, + stride_bm, + stride_bn, + N_BLKS_BITMATRIX: tl.constexpr, + n_gates, + BLOCK_M: tl.constexpr, + EVEN_M: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, + N_EXPTS_TOT: tl.constexpr, + ExpertHist, + TokenStart, + TileStart, + blocks1a, + MDTileInfo, + max_num_tiles, + tile_dim_log2: tl.constexpr, + BLOCK_A: tl.constexpr, + EQUAL_A: tl.constexpr, +): + + pid = tl.program_id(0) + + _sum_bitmatrix_rows_fused( + Bitmatrix, + shape_bm, + stride_bm, + stride_bn, + ExpertHist, + N_BLKS_BITMATRIX, + BLOCK_M, + EVEN_M, + ) + + if pid != 0 and pid < blocks1a: + n_tokens = tl.load(ExpertHist + pid) + if n_tokens == 0: + return + + _expt_data_compute_stage1( + pid, + ExpertHist, + N_EXPTS_TOT, + TokenStart, + TileStart, + MDTileInfo, + max_num_tiles, + n_gates, + tile_dim_log2, + BLOCK_A, + EQUAL_A, + ) + + if pid < blocks1a: + _expt_data_compute_stage2_fused(pid, ExpertHist, TileStart, MDTileInfo) + else: + _routing_compute_indx_fused( + GatherIndx, + ScatterIndx, + GateScal, + ExptScal, + ExptIndx, + TokenStart, + n_gates, + BLOCK_M, + EVEN_M, + N_EXPTS_ACT, + ) diff --git a/aiter/ops/triton/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe_op_gemm_a8w4.py index 7cc5848851..8f433cc6f8 100644 --- a/aiter/ops/triton/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe_op_gemm_a8w4.py @@ -75,11 +75,11 @@ def get_kernel_config(m, n, k, routing_data): xcd_swizzle = num_xcds w_cache_modifier = ".cg" if block_m <= 32 else None num_stages = 2 - split_k = 1 + block_k = 256 + if block_m == 16: block_n = 128 - block_k = 256 num_warps = 4 grid_m = routing_data.n_blocks(m, block_m) @@ -90,10 +90,20 @@ def get_kernel_config(m, n, k, routing_data): grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) grid = grid_m * grid_n * split_k + + elif block_m == 32: + if n <= 1024: + block_n = 128 + num_warps = 4 + elif n <= 4096: + block_n = 256 + num_warps = 8 + else: + block_n = 512 + num_warps = 8 + else: - # for scale preshuffling block_n = 512 - block_k = 256 num_warps = 8 ret = { diff --git a/aiter/ops/triton/moe_routing/routing.py b/aiter/ops/triton/moe_routing/routing.py index c42d22874b..f2dd5337f3 100644 --- a/aiter/ops/triton/moe_routing/routing.py +++ b/aiter/ops/triton/moe_routing/routing.py @@ -2,7 +2,10 @@ import torch import triton from dataclasses import dataclass, field -from aiter.ops.triton._triton_kernels.moe_routing.routing import _combined_routing +from aiter.ops.triton._triton_kernels.moe_routing.routing import ( + _combined_routing, + _combined_routing_fused, +) @dataclass @@ -125,6 +128,72 @@ def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOC ) +def sort_tokens_fused( + expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M +): + cdiv = triton.cdiv + + device = expt_scal.device + dtype = expt_scal.dtype + n_tokens, n_expts_act = expt_scal.shape + n_gates = n_tokens * n_expts_act + + hist = bitmatrix.scratchpad + hist = hist[:n_expts_tot] + assert hist.dtype == torch.int32 + num_blocks_bitmatrix = cdiv(bitmatrix.shape[1], 32) + # scratchpad + combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device) + # output + topk_indx = combined_indx[:n_gates] + gate_indx = combined_indx[n_gates:] + gate_scal = torch.empty(n_gates, dtype=dtype, device=device) + + token_offs_raw, token_offs_pad, block_pid_map, blocks1a, BLOCK_A, block_m_log2 = ( + _compute_expt_data_internal(n_expts_tot, n_gates, block_m, device) + ) + + blocks1b = cdiv(n_tokens, HIST_BLOCK_M) + + _combined_routing_fused[(blocks1a + blocks1b,)]( + topk_indx, + gate_indx, + gate_scal, # outputs + expt_scal, + expt_indx, + bitmatrix.data, + bitmatrix.shape[0], + bitmatrix.data.stride(0), + bitmatrix.data.stride(1), + num_blocks_bitmatrix, + n_gates, # input shape + HIST_BLOCK_M, + n_tokens % HIST_BLOCK_M == 0, + n_expts_act, # constants + n_expts_tot, + hist, + token_offs_raw, + token_offs_pad, # + blocks1a, + block_pid_map, + block_pid_map.shape[0], # + block_m_log2, + BLOCK_A=BLOCK_A, + EQUAL_A=(hist.shape[0] == BLOCK_A), # optimization parameters + num_warps=1, + ) + + return ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) + + # -------------------------- # expt_data # -------------------------- @@ -186,15 +255,31 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None): m = num_tokens * n_expts_act tokens_per_expt = max(1, m // n_expts_tot) block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) - ( - hist, - topk_indx, - gate_indx, - gate_scal, - token_offs_raw, - token_offs_pad, - block_pid_map, - ) = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M) + if num_tokens <= 16: + HIST_BLOCK_M = triton.next_power_of_2(num_tokens) + ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) = sort_tokens_fused( + expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M + ) + else: + ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) = sort_tokens( + expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M + ) expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) # pack the matmul data structure