Skip to content

Commit da56c89

Browse files
authored
[Dlight] Enhance vectorization for gpu matmul (#16894)
* [Dlight] Enhance vectorization for gpu matmul * fix
1 parent b3ffd97 commit da56c89

File tree

3 files changed

+54
-52
lines changed

3 files changed

+54
-52
lines changed

python/tvm/dlight/gpu/matmul.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
874874
x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]
875875
)
876876
ko, ki = sch.split(k, factors=[None, config.micro_size_k])
877-
sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
877+
reordered_loops = [by, bx, vy, vx, ty, tx, ko, ki] + (
878+
[yi, xi] if config.inner_x else [xi, yi]
879+
)
880+
sch.reorder(*reordered_loops)
878881
by = sch.fuse(batch, by)
879882
sch.bind(bx, "blockIdx.x")
880883
sch.bind(by, "blockIdx.y")
@@ -884,7 +887,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
884887
sch.bind(tx, "threadIdx.x")
885888
inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y
886889
if inner_loop % config.vector_size == 0:
887-
_, v = sch.split(xi, [None, config.vector_size])
890+
_, v = sch.split(reordered_loops[-1], [None, config.vector_size])
888891
sch.vectorize(v)
889892

890893
if config.unroll > 0:

0 commit comments

Comments
 (0)