@@ -71,6 +71,8 @@
 @triton.jit
 def pack_bitmatrix(
     bitmatrix,
+    bm_row_stride,
+    bm_col_stride,
     topk_ids,
     n_rows,  # n_rows in bitmatrix / topk_ids
     bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
@@ -92,19 +94,25 @@ def pack_bitmatrix(
     div = indices // 32
     rem = indices % 32
     one = tl.cast(1, tl.uint32)
+    zero = tl.cast(0, tl.uint32)

     # Iterate through all the relevant bitmatrix columns.
     for i in range(bm_cols):
         # When BLOCK_SIZE_K=32, offs is just the column index.
         offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
         # All topks that need to go into this column have the correct bit set.
         # Other bits are 0. x is a 2D tensor.
-        x = tl.where(
-            div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0
+        is_valid_expert = (div[:, :, None] == offs[None, None, :]) & (
+            indices[:, :, None] >= 0
         )
+        x = tl.where(is_valid_expert, (one << rem)[:, :, None], zero)
         # Reduce x to get a single int32_t bitpack.
         y = tl.reduce_or(x, axis=1)
-        bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_cols + offs[None, :]
+        bitmatrix_ptrs = (
+            bitmatrix
+            + offsets_m[:, None] * bm_row_stride
+            + offs[None, :] * bm_col_stride
+        )
         tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)


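For context on the kernel hunk above: the new is_valid_expert mask skips negative top-k ids (used as padding), and the output is now addressed through the new bm_row_stride / bm_col_stride arguments instead of assuming a contiguous (n_rows, bm_cols) layout. Below is a minimal pure-PyTorch sketch of the same bit-packing; the function name pack_bitmatrix_reference and the Python loops are illustrative only (not part of this PR), and the int64 accumulator is just to keep the bit operations simple.

import torch

def pack_bitmatrix_reference(topk_ids: torch.Tensor, num_local_experts: int) -> torch.Tensor:
    # One 32-bit bitpack per 32 experts: bit (e % 32) of column (e // 32) is set
    # iff expert e appears among the row's top-k ids.
    n_rows = topk_ids.shape[0]
    bm_cols = (num_local_experts + 31) // 32
    out = torch.zeros((n_rows, bm_cols), dtype=torch.int64)  # int64 just for the bit ops
    for r in range(n_rows):
        for e in topk_ids[r].tolist():
            if e < 0:  # padded / invalid entry; mirrors the new is_valid_expert mask
                continue
            out[r, e // 32] |= 1 << (e % 32)
    return out  # the actual kernel buffer is torch.uint32

The kernel computes the same thing block-wise, folding the top-k axis with tl.reduce_or instead of looping.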
@@ -506,14 +514,19 @@ def make_routing_data(
     BLOCK_SIZE_M = 512
     BLOCK_SIZE_K = 32

-    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bm_cols = triton.cdiv(num_local_experts, 32)  # n_bitpacks
     bitmatrix = torch.zeros(
-        (n_rows, bm_cols), dtype=torch.uint32, device=topk_ids.device
+        (bm_cols, triton.cdiv(n_rows, 32) * 32),
+        dtype=torch.uint32,
+        device=topk_ids.device,
     )
Contributor Author (inline comment): Defensively use 32 directly; aliasing with BLOCK_SIZE_K would lead to a wrong definition of the bitmatrix if BLOCK_SIZE_K changes.

+    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows]

     grid = (triton.cdiv(n_rows, BLOCK_SIZE_M),)
     pack_bitmatrix[grid](
         bitmatrix,
+        bitmatrix.stride(0),
+        bitmatrix.stride(1),
         topk_ids,
         n_rows,
         bm_cols,
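The allocation change above is why the kernel now takes explicit strides: the buffer is allocated as (bm_cols, cdiv(n_rows, 32) * 32) and then transposed and sliced to a logical (n_rows, bm_cols) view, which is no longer contiguous, so the row stride cannot be inferred from bm_cols. A small sketch of the resulting layout; the concrete sizes (100 rows, 96 local experts) are made up for illustration.

import torch
import triton

n_rows, num_local_experts = 100, 96                     # illustrative sizes only
bm_cols = triton.cdiv(num_local_experts, 32)            # 3 bitpacks per row

buf = torch.zeros((bm_cols, triton.cdiv(n_rows, 32) * 32), dtype=torch.uint32)
bitmatrix = torch.transpose(buf, 0, 1)[:n_rows]         # logical view: (n_rows, bm_cols)

print(bitmatrix.shape)             # torch.Size([100, 3])
print(bitmatrix.stride())          # (1, 128): column-major, rows padded to 128
print(bitmatrix.is_contiguous())   # False -> pass stride(0) / stride(1) to the kernel

With this layout, the old offsets_m[:, None] * bm_cols addressing would scatter rows incorrectly; offsets_m[:, None] * bm_row_stride + offs[None, :] * bm_col_stride stays correct for both the contiguous and the transposed buffer.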
@@ -525,9 +538,7 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = (
-        Bitmatrix(
-            bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
-        )
+        Bitmatrix(bitmatrix, dtype=BIT, shape=bitmatrix_shape)
         if not use_legacy_triton_kernels
         else Bitmatrix(
             bitmatrix,