@@ -874,7 +874,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
874874 x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]
875875 )
876876 ko, ki = sch.split(k, factors=[None, config.micro_size_k])
877- sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
877+ reordered_loops = [by, bx, vy, vx, ty, tx, ko, ki] + (
878+ [yi, xi] if config.inner_x else [xi, yi]
879+ )
880+ sch.reorder(*reordered_loops)
878881 by = sch.fuse(batch, by)
879882 sch.bind(bx, "blockIdx.x")
880883 sch.bind(by, "blockIdx.y")
@@ -884,7 +887,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
884887 sch.bind(tx, "threadIdx.x")
885888 inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y
886889 if inner_loop % config.vector_size == 0:
887- _, v = sch.split(xi , [None, config.vector_size])
890+ _, v = sch.split(reordered_loops[-1] , [None, config.vector_size])
888891 sch.vectorize(v)
889892
890893 if config.unroll > 0:
0 commit comments