Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,16 @@ def traverse(op):
return s


def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
    """Declare the launch thread axes on ``ib`` and return them.

    Three thread axes are created -- ``threadIdx.x``, ``blockIdx.x`` and
    ``blockIdx.y`` -- and each one gets a ``"thread_extent"`` scope
    attribute on ``ib`` with the matching extent.

    No ``blockIdx.z`` axis is declared. Callers that need a third logical
    dimension fold it into ``nthread_by`` (for example
    ``axis_mul_before * axis_mul_after``) and split the returned ``by``
    themselves with ``%`` and ``//``.

    Parameters
    ----------
    ib
        Builder object exposing ``scope_attr``; presumably a TVM TIR
        ``ir_builder`` instance -- confirm against the callers.
    nthread_tx
        Extent attached to ``threadIdx.x``.
    nthread_bx
        Extent attached to ``blockIdx.x``.
    nthread_by
        Extent attached to ``blockIdx.y``.

    Returns
    -------
    tuple
        ``(tx, bx, by)``, the three thread axes in that order.
    """
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)

    by = te.thread_axis("blockIdx.y")
    ib.scope_attr(by, "thread_extent", nthread_by)

    return tx, bx, by


def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
Expand All @@ -87,13 +85,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after

# Copy the keys_in to initial output
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
by, bz = by % axis_mul_before, by // axis_mul_before
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
keys_out[idx] = keys_in[idx]
Expand Down Expand Up @@ -122,11 +120,11 @@ def _odd_even_sort(
):
nthread_tx = block_size // 2
nthread_bx = ceil_div(size, block_size)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after
with ib.new_scope():
ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
by, bz = by % axis_mul_before, by // axis_mul_before
tid = 2 * tx
start = bx * block_size

Expand Down Expand Up @@ -222,7 +220,6 @@ def _sort_common(

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_by = axis_mul_before * axis_mul_after
nthread_bz = 1
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)

Expand Down Expand Up @@ -334,12 +331,13 @@ def assign_j():
ntx = max_threads
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
else:
ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
by, bz = by % nthread_by, by // nthread_by

def mergepath(
source,
Expand Down Expand Up @@ -471,18 +469,17 @@ def do_merge(first, last):
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
## if the final sorted data ended up in the swap, copy it to the real output
with ib.if_scope(
tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
):
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
idx = (by * axis_mul_after + bz) * size + tid
idx = by * size + tid
with ib.if_scope(tid < size):
keys[idx] = keys_swap[idx]
if values is not None:
Expand Down