Skip to content

Commit 4b5cd42

Browse files
author
Siyuan Feng
committed
[TOPI] Remove blockIdx.z in topi sort
As `blockIdx.z` is not allowed in WebGPU, this PR split `blockIdx.z` into `blockIdx.y` to support WebGPU
1 parent c0a47ed commit 4b5cd42

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

python/tvm/topi/cuda/sort.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,15 @@ def traverse(op):
5757
return s
5858

5959

60-
def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
60+
def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
6161
tx = te.thread_axis("threadIdx.x")
6262
bx = te.thread_axis("blockIdx.x")
6363
ib.scope_attr(tx, "thread_extent", nthread_tx)
6464
ib.scope_attr(bx, "thread_extent", nthread_bx)
6565

6666
by = te.thread_axis("blockIdx.y")
67-
bz = te.thread_axis("blockIdx.z")
6867
ib.scope_attr(by, "thread_extent", nthread_by)
69-
ib.scope_attr(bz, "thread_extent", nthread_bz)
70-
71-
return tx, bx, by, bz
68+
return tx, bx, by
7269

7370

7471
def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
@@ -87,14 +84,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
8784
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
8885
nthread_tx = max_threads
8986
nthread_bx = ceil_div(shape[axis], max_threads)
90-
nthread_by = axis_mul_before
91-
nthread_bz = axis_mul_after
87+
nthread_by = axis_mul_before * axis_mul_after
9288

9389
# Copy the keys_in to initial output
9490
with ib.new_scope():
95-
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
91+
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
9692
tid = bx * nthread_tx + tx
97-
idx = (by * shape[axis] + tid) * axis_mul_after + bz
93+
idx = (by // axis_mul_after * shape[axis] + tid) * axis_mul_after + (by % axis_mul_after)
9894
with ib.if_scope(tid < shape[axis]):
9995
keys_out[idx] = keys_in[idx]
10096
if values_out is not None:
@@ -122,11 +118,10 @@ def _odd_even_sort(
122118
):
123119
nthread_tx = block_size // 2
124120
nthread_bx = ceil_div(size, block_size)
125-
nthread_by = axis_mul_before
126-
nthread_bz = axis_mul_after
121+
nthread_by = axis_mul_before * axis_mul_after
127122
with ib.new_scope():
128123
ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
129-
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
124+
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
130125
tid = 2 * tx
131126
start = bx * block_size
132127

@@ -153,7 +148,7 @@ def _odd_even_sort(
153148
temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
154149
temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
155150
# Copy data to scratch space
156-
base_idx = by * size * axis_mul_after + bz
151+
base_idx = (by // axis_mul_after) * size * axis_mul_after + (by % axis_mul_after)
157152
with ib.for_range(0, 2) as n:
158153
with ib.if_scope((tid + n + start) < size):
159154
tmp_keys_swap[tid + n] = keys[base_idx + (tid + n + start) * axis_mul_after]
@@ -222,7 +217,6 @@ def _sort_common(
222217

223218
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
224219
nthread_by = axis_mul_before * axis_mul_after
225-
nthread_bz = 1
226220
nthread_tx = max_threads
227221
nthread_bx = ceil_div(size, nthread_tx)
228222

@@ -334,12 +328,12 @@ def assign_j():
334328
ntx = max_threads
335329
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
336330
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
337-
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
331+
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
338332
else:
339333
ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
340334
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
341335
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
342-
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
336+
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
343337

344338
def mergepath(
345339
source,
@@ -392,7 +386,7 @@ def merge(source, dest, source_idx, dest_idx):
392386

393387
def mergesort(source, dest, source_idx, dest_idx, size, width, even):
394388
# calculate the start, mid, and end points of this section
395-
start = width * bz
389+
start = width * (by % nbz)
396390
middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
397391
end = cast(tvm.te.min(start + width, size), "int64")
398392
with ib.if_scope(start < size):
@@ -471,18 +465,17 @@ def do_merge(first, last):
471465
width,
472466
tvm.tir.indexmod(l2_width, 2) == 0,
473467
)
474-
nthread_by = axis_mul_before
475-
nthread_bz = axis_mul_after
468+
nthread_by = axis_mul_before * axis_mul_after
476469
nthread_tx = max_threads
477470
nthread_bx = ceil_div(size, nthread_tx)
478471
## if the final sorted data ended up in the swap, copy it to the real output
479472
with ib.if_scope(
480473
tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
481474
):
482475
with ib.new_scope():
483-
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
476+
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
484477
tid = bx * nthread_tx + tx
485-
idx = (by * axis_mul_after + bz) * size + tid
478+
idx = by * size + tid
486479
with ib.if_scope(tid < size):
487480
keys[idx] = keys_swap[idx]
488481
if values is not None:

0 commit comments

Comments
 (0)