From 0b79840bc2f6b76535e40fa3fb7f69bf33cdef32 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 May 2024 21:05:53 +0800 Subject: [PATCH] [TOPI] Remove `blockIdx.z` in topi sort As `blockIdx.z` is not allowed in WebGPU, this PR folds the `blockIdx.z` extent into `blockIdx.y` (recovering both indices with `%` and `//`) to support WebGPU. --- python/tvm/topi/cuda/sort.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index dc72aa8cc13b..9151744b6961 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -57,18 +57,16 @@ def traverse(op): return s -def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz): +def _get_threads(ib, nthread_tx, nthread_bx, nthread_by): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) by = te.thread_axis("blockIdx.y") - bz = te.thread_axis("blockIdx.z") ib.scope_attr(by, "thread_extent", nthread_by) - ib.scope_attr(bz, "thread_extent", nthread_bz) - return tx, bx, by, bz + return tx, bx, by def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): @@ -87,13 +85,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = ceil_div(shape[axis], max_threads) - nthread_by = axis_mul_before - nthread_bz = axis_mul_after + nthread_by = axis_mul_before * axis_mul_after # Copy the keys_in to initial output with ib.new_scope(): - tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz) + tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by) tid = bx * nthread_tx + tx + by, bz = by % axis_mul_before, by // axis_mul_before idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): keys_out[idx] =
keys_in[idx] @@ -122,11 +120,11 @@ def _odd_even_sort( ): nthread_tx = block_size // 2 nthread_bx = ceil_div(size, block_size) - nthread_by = axis_mul_before - nthread_bz = axis_mul_after + nthread_by = axis_mul_before * axis_mul_after with ib.new_scope(): ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0) - tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz) + tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by) + by, bz = by % axis_mul_before, by // axis_mul_before tid = 2 * tx start = bx * block_size @@ -222,7 +220,6 @@ def _sort_common( max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_by = axis_mul_before * axis_mul_after - nthread_bz = 1 nthread_tx = max_threads nthread_bx = ceil_div(size, nthread_tx) @@ -334,12 +331,13 @@ def assign_j(): ntx = max_threads nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32") nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32") - tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz) + tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz) else: ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32") nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32") nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32") - tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz) + tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz) + by, bz = by % nthread_by, by // nthread_by def mergepath( source, @@ -471,8 +469,7 @@ def do_merge(first, last): width, tvm.tir.indexmod(l2_width, 2) == 0, ) - nthread_by = axis_mul_before - nthread_bz = axis_mul_after + nthread_by = axis_mul_before * axis_mul_after nthread_tx = max_threads nthread_bx = ceil_div(size, nthread_tx) ## if the final sorted data ended up in the swap, copy it to the real output @@ -480,9 +477,9 @@ def do_merge(first, last): tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1) ): 
with ib.new_scope(): - tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz) + tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by) tid = bx * nthread_tx + tx - idx = (by * axis_mul_after + bz) * size + tid + idx = by * size + tid with ib.if_scope(tid < size): keys[idx] = keys_swap[idx] if values is not None: