Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions aiter/ops/triton/_triton_kernels/moe_routing/bitmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,33 @@ def _sum_bitmatrix_rows(
)

# tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)


@triton.jit
def _sum_bitmatrix_rows_fused(
    B,
    shape_bm,
    stride_bm,
    stride_bn,
    Ret,
    N_BLKS_BITMATRIX: tl.constexpr,
    BLOCK_M: tl.constexpr,
    EVEN_M: tl.constexpr,
):
    """Column-wise popcount over a bitmatrix.

    For each of the ``N_BLKS_BITMATRIX`` 32-bit column blocks of ``B``, load
    one 32-bit word per row, reduce the bits per column with ``vpopc`` and
    store the 32 resulting counts into ``Ret``.  ``shape_bm`` may be passed
    either as the row count itself or as a device pointer to it.
    """
    # shape_bm can arrive as a scalar or as a pointer to the scalar; the
    # isinstance check is resolved at specialization time.
    if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
        shape_bm = tl.load(shape_bm)
    row_offs = tl.arange(0, BLOCK_M)
    lane = tl.arange(0, 32)
    for blk in tl.static_range(N_BLKS_BITMATRIX):
        word_ptrs = B + blk * stride_bn + row_offs * stride_bm
        if EVEN_M:
            words = tl.load(word_ptrs)
        else:
            # Rows beyond shape_bm contribute no set bits.
            words = tl.load(word_ptrs, mask=row_offs < shape_bm, other=0)
        counts = vpopc(tl.reshape(words, [1, BLOCK_M]))  # [1, 32]
        tl.store(Ret + blk * 32 + lane, tl.reshape(counts, [32]))
9 changes: 9 additions & 0 deletions aiter/ops/triton/_triton_kernels/moe_routing/expt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,12 @@ def _expt_data_compute_stage2(
data = (block_offs << 16) + expt_id
tl.store(TileInfo + block_offs, data, mask=block_offs < n_blocks)
block_offs += BLOCK


@triton.jit
def _expt_data_compute_stage2_fused(expt_id, Hist, TileStart, TileInfo):
    """Record ``expt_id`` at its first tile slot.

    Writes ``expt_id`` into ``TileInfo[TileStart[expt_id]]``; experts that
    received no tokens (``Hist[expt_id] == 0``) are skipped entirely.
    """
    if tl.load(Hist + expt_id) == 0:
        # Empty expert: nothing to record.
        return
    first_tile = tl.load(TileStart + expt_id)
    tl.store(TileInfo + first_tile, expt_id)
141 changes: 141 additions & 0 deletions aiter/ops/triton/_triton_kernels/moe_routing/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from aiter.ops.triton._triton_kernels.moe_routing.expt_data import (
_expt_data_compute_stage1,
_expt_data_compute_stage2,
_expt_data_compute_stage2_fused,
)
from aiter.ops.triton._triton_kernels.moe_routing.bitmatrix import (
_sum_bitmatrix_rows_fused,
)


Expand Down Expand Up @@ -86,6 +90,66 @@ def _routing_compute_indx(
tl.store(GateScal + gates, gate_scal, mask=mask)


@triton.jit
def _routing_compute_indx_fused(
    GatherIndx,
    ScatterIndx,
    GateScal,
    ExptScal,
    ExptIndx,
    TokensStart,
    n_gates,
    BLOCK_M: tl.constexpr,
    EVEN_M: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
):
    """Compute gather/scatter indices for all gates in a single program.

    Loads up to ``N_EXPTS_ACT * BLOCK_M`` expert ids, stable-sorts gates by
    expert (expert id in the high 16 bits of the key, original position in
    the low 16), derives each gate's position within its expert run via a
    keyed inclusive scan, and writes:

      * ``ScatterIndx[orig_pos] = gate``   (gate = TokensStart[expert] + run pos)
      * ``GatherIndx[gate]      = orig_pos``
      * ``GateScal[gate]        = ExptScal[orig_pos]``

    When ``EVEN_M`` is false, out-of-range slots are padded with expert id
    0xFFFF (from the ``other=-1`` load) and masked out of all stores.
    """
    # Positions must fit the low 16 bits of the sort key; the bound is 32768
    # (not 65536) — presumably a constraint of tl.sort's key handling, TODO
    # confirm.
    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)

    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
    offs = local_offs
    if EVEN_M:
        expert = tl.load(ExptIndx + offs).to(tl.uint32)
    else:
        # other=-1 wraps to 0xFFFFFFFF; after packing/sorting the padded
        # slots carry expert id 0xFFFF and sort to the end.
        expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)

    # stable-sort by expert ID: pack (expert, original position) into one key
    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
    kv_pairs = tl.sort(kv_pairs, 0)
    expert = kv_pairs >> 16
    offs = kv_pairs & 0xFFFF

    # compute run lengths in expert-sorted order (previously duplicated in
    # both EVEN_M branches — the computation is identical, so hoist it):
    x = kv_pairs & 0xFFFF0000 | 0x00000001
    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
    exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF

    if EVEN_M:
        gate_scal = tl.load(ExptScal + offs)
        gates = tl.load(TokensStart + expert)
        gates += exclusive_run_lengths

        tl.store(ScatterIndx + offs, gates)
        tl.store(GatherIndx + gates, offs)
        tl.store(GateScal + gates, gate_scal)
    else:
        mask = expert != 0xFFFF
        gate_scal = tl.load(ExptScal + offs, mask=mask)
        gates = tl.load(TokensStart + expert, mask=mask)
        gates += exclusive_run_lengths

        tl.store(ScatterIndx + offs, gates, mask=mask)
        tl.store(GatherIndx + gates, offs, mask=mask)
        tl.store(GateScal + gates, gate_scal, mask=mask)


@triton.jit
def _combined_routing(
GatherIndx,
Expand Down Expand Up @@ -148,3 +212,80 @@ def _combined_routing(
EVEN_M,
N_EXPTS_ACT,
)


@triton.jit
def _combined_routing_fused(
    GatherIndx,
    ScatterIndx,
    GateScal,
    ExptScal,
    ExptIndx,
    Bitmatrix,
    shape_bm,
    stride_bm,
    stride_bn,
    N_BLKS_BITMATRIX: tl.constexpr,
    n_gates,
    BLOCK_M: tl.constexpr,
    EVEN_M: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    N_EXPTS_TOT: tl.constexpr,
    ExpertHist,
    TokenStart,
    TileStart,
    blocks1a,
    MDTileInfo,
    max_num_tiles,
    tile_dim_log2: tl.constexpr,
    BLOCK_A: tl.constexpr,
    EQUAL_A: tl.constexpr,
):
    """Single-launch fused MoE routing.

    Grid layout: programs ``0 .. blocks1a-1`` build per-expert metadata
    (token/tile offsets, tile info), the remaining programs compute the
    token gather/scatter indices.  Every program first rebuilds the expert
    histogram in ``ExpertHist`` from ``Bitmatrix``.
    """

    pid = tl.program_id(0)

    # Every program recomputes and stores the full histogram.  All programs
    # store identical values, so the duplication is benign — and it means a
    # program's later loads from ExpertHist see its own stores rather than
    # depending on other programs (Triton has no grid-wide barrier).
    _sum_bitmatrix_rows_fused(
        Bitmatrix,
        shape_bm,
        stride_bm,
        stride_bn,
        ExpertHist,
        N_BLKS_BITMATRIX,
        BLOCK_M,
        EVEN_M,
    )

    # Metadata stage 1: one program per expert id; pid 0 is excluded —
    # presumably its slice is handled elsewhere, TODO confirm against
    # _expt_data_compute_stage1.  Empty experts return early (stage 2 would
    # also skip them, see _expt_data_compute_stage2_fused).
    if pid != 0 and pid < blocks1a:
        n_tokens = tl.load(ExpertHist + pid)
        if n_tokens == 0:
            return

        _expt_data_compute_stage1(
            pid,
            ExpertHist,
            N_EXPTS_TOT,
            TokenStart,
            TileStart,
            MDTileInfo,
            max_num_tiles,
            n_gates,
            tile_dim_log2,
            BLOCK_A,
            EQUAL_A,
        )

    if pid < blocks1a:
        # Metadata stage 2: record this expert id at its first tile slot.
        _expt_data_compute_stage2_fused(pid, ExpertHist, TileStart, MDTileInfo)
    else:
        # Index stage: sort gates by expert and emit gather/scatter indices.
        # NOTE(review): this reads TokenStart, which stage 1 appears to write
        # in *other* programs; with no grid-wide sync this only works if
        # TokenStart is already valid at launch (e.g. produced on the host by
        # _compute_expt_data_internal) — confirm against the host wrapper.
        _routing_compute_indx_fused(
            GatherIndx,
            ScatterIndx,
            GateScal,
            ExptScal,
            ExptIndx,
            TokenStart,
            n_gates,
            BLOCK_M,
            EVEN_M,
            N_EXPTS_ACT,
        )
18 changes: 14 additions & 4 deletions aiter/ops/triton/moe_op_gemm_a8w4.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def get_kernel_config(m, n, k, routing_data):
xcd_swizzle = num_xcds
w_cache_modifier = ".cg" if block_m <= 32 else None
num_stages = 2

split_k = 1
block_k = 256

if block_m == 16:
block_n = 128
block_k = 256
num_warps = 4

grid_m = routing_data.n_blocks(m, block_m)
Expand All @@ -90,10 +90,20 @@ def get_kernel_config(m, n, k, routing_data):
grid_m = routing_data.n_blocks(m, block_m)
grid_n = triton.cdiv(n, block_n)
grid = grid_m * grid_n * split_k

elif block_m == 32:
if n <= 1024:
block_n = 128
num_warps = 4
elif n <= 4096:
block_n = 256
num_warps = 8
else:
block_n = 512
num_warps = 8

else:
# for scale preshuffling
block_n = 512
block_k = 256
num_warps = 8

ret = {
Expand Down
105 changes: 95 additions & 10 deletions aiter/ops/triton/moe_routing/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import torch
import triton
from dataclasses import dataclass, field
from aiter.ops.triton._triton_kernels.moe_routing.routing import _combined_routing
from aiter.ops.triton._triton_kernels.moe_routing.routing import (
_combined_routing,
_combined_routing_fused,
)


@dataclass
Expand Down Expand Up @@ -125,6 +128,72 @@ def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOC
)


def sort_tokens_fused(
    expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M
):
    """Single-launch variant of ``sort_tokens``.

    Histogram computation, expert metadata construction and token index
    computation are all fused into one launch of ``_combined_routing_fused``.
    Returns the same 7-tuple as ``sort_tokens``:
    ``(hist, topk_indx, gate_indx, gate_scal, token_offs_raw,
    token_offs_pad, block_pid_map)``.
    """
    device = expt_scal.device
    n_tokens, n_expts_act = expt_scal.shape
    n_gates = n_tokens * n_expts_act

    # The bitmatrix scratchpad doubles as the per-expert histogram buffer.
    hist = bitmatrix.scratchpad[:n_expts_tot]
    assert hist.dtype == torch.int32
    num_blocks_bitmatrix = triton.cdiv(bitmatrix.shape[1], 32)

    # One scratch allocation backs both index arrays; the two halves are the
    # actual outputs.
    combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device)
    topk_indx = combined_indx[:n_gates]
    gate_indx = combined_indx[n_gates:]
    gate_scal = torch.empty(n_gates, dtype=expt_scal.dtype, device=device)

    token_offs_raw, token_offs_pad, block_pid_map, blocks1a, BLOCK_A, block_m_log2 = (
        _compute_expt_data_internal(n_expts_tot, n_gates, block_m, device)
    )

    # blocks1a programs handle expert metadata, blocks1b handle indices.
    blocks1b = triton.cdiv(n_tokens, HIST_BLOCK_M)

    _combined_routing_fused[(blocks1a + blocks1b,)](
        topk_indx,
        gate_indx,
        gate_scal,  # outputs
        expt_scal,
        expt_indx,
        bitmatrix.data,
        bitmatrix.shape[0],
        bitmatrix.data.stride(0),
        bitmatrix.data.stride(1),
        num_blocks_bitmatrix,
        n_gates,  # input shape
        HIST_BLOCK_M,
        n_tokens % HIST_BLOCK_M == 0,
        n_expts_act,  # constants
        n_expts_tot,
        hist,
        token_offs_raw,
        token_offs_pad,
        blocks1a,
        block_pid_map,
        block_pid_map.shape[0],
        block_m_log2,
        BLOCK_A=BLOCK_A,
        EQUAL_A=(hist.shape[0] == BLOCK_A),  # optimization parameters
        num_warps=1,
    )

    return (
        hist,
        topk_indx,
        gate_indx,
        gate_scal,
        token_offs_raw,
        token_offs_pad,
        block_pid_map,
    )


# --------------------------
# expt_data
# --------------------------
Expand Down Expand Up @@ -186,15 +255,31 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None):
m = num_tokens * n_expts_act
tokens_per_expt = max(1, m // n_expts_tot)
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
(
hist,
topk_indx,
gate_indx,
gate_scal,
token_offs_raw,
token_offs_pad,
block_pid_map,
) = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M)
if num_tokens <= 16:
HIST_BLOCK_M = triton.next_power_of_2(num_tokens)
(
hist,
topk_indx,
gate_indx,
gate_scal,
token_offs_raw,
token_offs_pad,
block_pid_map,
) = sort_tokens_fused(
expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M
)
else:
(
hist,
topk_indx,
gate_indx,
gate_scal,
token_offs_raw,
token_offs_pad,
block_pid_map,
) = sort_tokens(
expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M
)
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)

# pack the matmul data structure
Expand Down