@@ -57,18 +57,15 @@ def traverse(op):
5757 return s
5858
5959
60- def _get_threads (ib , nthread_tx , nthread_bx , nthread_by , nthread_bz ):
60+ def _get_threads (ib , nthread_tx , nthread_bx , nthread_by ):
6161 tx = te .thread_axis ("threadIdx.x" )
6262 bx = te .thread_axis ("blockIdx.x" )
6363 ib .scope_attr (tx , "thread_extent" , nthread_tx )
6464 ib .scope_attr (bx , "thread_extent" , nthread_bx )
6565
6666 by = te .thread_axis ("blockIdx.y" )
67- bz = te .thread_axis ("blockIdx.z" )
6867 ib .scope_attr (by , "thread_extent" , nthread_by )
69- ib .scope_attr (bz , "thread_extent" , nthread_bz )
70-
71- return tx , bx , by , bz
68+ return tx , bx , by
7269
7370
7471def _sort_init (ib , shape , axis , keys_in , keys_out , values_out = None , value_init_func = None ):
@@ -87,14 +84,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
8784 max_threads = int (tvm .target .Target .current (allow_none = False ).max_num_threads )
8885 nthread_tx = max_threads
8986 nthread_bx = ceil_div (shape [axis ], max_threads )
90- nthread_by = axis_mul_before
91- nthread_bz = axis_mul_after
87+ nthread_by = axis_mul_before * axis_mul_after
9288
9389 # Copy the keys_in to initial output
9490 with ib .new_scope ():
95- tx , bx , by , bz = _get_threads (ib , nthread_tx , nthread_bx , nthread_by , nthread_bz )
91+ tx , bx , by = _get_threads (ib , nthread_tx , nthread_bx , nthread_by )
9692 tid = bx * nthread_tx + tx
97- idx = (by * shape [axis ] + tid ) * axis_mul_after + bz
93+ idx = (by // axis_mul_after * shape [axis ] + tid ) * axis_mul_after + ( by % axis_mul_after )
9894 with ib .if_scope (tid < shape [axis ]):
9995 keys_out [idx ] = keys_in [idx ]
10096 if values_out is not None :
@@ -122,11 +118,10 @@ def _odd_even_sort(
122118):
123119 nthread_tx = block_size // 2
124120 nthread_bx = ceil_div (size , block_size )
125- nthread_by = axis_mul_before
126- nthread_bz = axis_mul_after
121+ nthread_by = axis_mul_before * axis_mul_after
127122 with ib .new_scope ():
128123 ib .scope_attr (tvm .tir .const (0 ), "hand_threaded" , 0 )
129- tx , bx , by , bz = _get_threads (ib , nthread_tx , nthread_bx , nthread_by , nthread_bz )
124+ tx , bx , by = _get_threads (ib , nthread_tx , nthread_bx , nthread_by )
130125 tid = 2 * tx
131126 start = bx * block_size
132127
@@ -153,7 +148,7 @@ def _odd_even_sort(
153148 temp_cond1 = ib .allocate (keys_swap .dtype , (1 ,), name = "temp_cond1" , scope = "local" )
154149 temp_cond2 = ib .allocate (keys_swap .dtype , (1 ,), name = "temp_cond2" , scope = "local" )
155150 # Copy data to scratch space
156- base_idx = by * size * axis_mul_after + bz
151+ base_idx = ( by // axis_mul_after ) * size * axis_mul_after + ( by % axis_mul_after )
157152 with ib .for_range (0 , 2 ) as n :
158153 with ib .if_scope ((tid + n + start ) < size ):
159154 tmp_keys_swap [tid + n ] = keys [base_idx + (tid + n + start ) * axis_mul_after ]
@@ -222,7 +217,6 @@ def _sort_common(
222217
223218 max_threads = int (tvm .target .Target .current (allow_none = False ).max_num_threads )
224219 nthread_by = axis_mul_before * axis_mul_after
225- nthread_bz = 1
226220 nthread_tx = max_threads
227221 nthread_bx = ceil_div (size , nthread_tx )
228222
@@ -334,12 +328,12 @@ def assign_j():
334328 ntx = max_threads
335329 nbx = tvm .tir .generic .cast (ceil_div (width , max_threads * thread_work ), "int32" )
336330 nbz = tvm .tir .generic .cast (ceil_div (size , width ), "int32" )
337- tx , bx , by , bz = _get_threads (ib , ntx , nbx , nthread_by , nbz )
331+ tx , bx , by = _get_threads (ib , ntx , nbx , nthread_by * nbz )
338332 else :
339333 ntx = tvm .tir .generic .cast (tvm .te .min (max_threads , width ), "int32" )
340334 nbx = tvm .tir .generic .cast (ceil_div (width , max_threads * thread_work ), "int32" )
341335 nbz = tvm .tir .generic .cast (ceil_div (size , width ), "int32" )
342- tx , bx , by , bz = _get_threads (ib , ntx , nbx , nthread_by , nbz )
336+ tx , bx , by = _get_threads (ib , ntx , nbx , nthread_by * nbz )
343337
344338 def mergepath (
345339 source ,
@@ -392,7 +386,7 @@ def merge(source, dest, source_idx, dest_idx):
392386
393387 def mergesort (source , dest , source_idx , dest_idx , size , width , even ):
394388 # calculate the start, mid, and end points of this section
395- start = width * bz
389+ start = width * ( by % nbz )
396390 middle = cast (tvm .te .min (start + tvm .tir .indexdiv (width , 2 ), size ), "int64" )
397391 end = cast (tvm .te .min (start + width , size ), "int64" )
398392 with ib .if_scope (start < size ):
@@ -471,18 +465,17 @@ def do_merge(first, last):
471465 width ,
472466 tvm .tir .indexmod (l2_width , 2 ) == 0 ,
473467 )
474- nthread_by = axis_mul_before
475- nthread_bz = axis_mul_after
468+ nthread_by = axis_mul_before * axis_mul_after
476469 nthread_tx = max_threads
477470 nthread_bx = ceil_div (size , nthread_tx )
478471 ## if the final sorted data ended up in the swap, copy it to the real output
479472 with ib .if_scope (
480473 tvm .tir .all (upper_lim > lower_lim , tvm .tir .indexmod (upper_lim - lower_lim , 2 ) == 1 )
481474 ):
482475 with ib .new_scope ():
483- tx , bx , by , bz = _get_threads (ib , nthread_tx , nthread_bx , nthread_by , nthread_bz )
476+ tx , bx , by = _get_threads (ib , nthread_tx , nthread_bx , nthread_by )
484477 tid = bx * nthread_tx + tx
485- idx = ( by * axis_mul_after + bz ) * size + tid
478+ idx = by * size + tid
486479 with ib .if_scope (tid < size ):
487480 keys [idx ] = keys_swap [idx ]
488481 if values is not None :
0 commit comments