
Commit 69e5d4e

Author: Siyuan Feng (committed)
[DLight] Fix Matmul rule for Conv3D
Currently, the matmul rule produces incorrect code for Conv3D due to incorrect reindexing of the input tensor: the index maps passed to `transform_layout` were computed before the `reindex` step. This commit fixes the issue by computing the index maps after the `reindex` process.
1 parent f52143e commit 69e5d4e
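
For context, a minimal sketch of the corrected ordering on a toy matmul (the workload and shapes below are invented for illustration; the real rule lives in python/tvm/dlight/gpu/matmul.py, where the internal helper get_index_map(sch.get(main_block)) is called at the point marked in the comments):

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def toy_matmul(
    A: T.Buffer((128, 64), "float16"),
    B: T.Buffer((256, 64), "float16"),
    C: T.Buffer((128, 256), "float16"),
):
    for i, j, k in T.grid(128, 256, 64):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float16(0.0)
            C[vi, vj] += A[vi, vk] * B[vj, vk]


sch = tir.Schedule(toy_matmul)
main_block = sch.get_block("C")

# Reindex the two reads and the write first, so the block's access regions
# are already in their canonical post-reindex form ...
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

# ... and only then analyze the block to derive the index maps (in the
# DLight rule: index_maps = get_index_map(sch.get(main_block))). Deriving
# the maps before reindex is what broke Conv3D, where reindex genuinely
# changes the access pattern.
print(sch.mod.script())  # inspect the reindexed module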

File tree

2 files changed (+172 lines, −48 lines)


python/tvm/dlight/gpu/matmul.py

Lines changed: 54 additions & 48 deletions
@@ -364,13 +364,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None

-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Step 0. Configs
         block_size_x: int = 16
         block_size_y: int = 16
@@ -382,12 +375,20 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         vector_size: int = 4

         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)

         # Step 2. Padding for dynamic shape kernels
@@ -508,13 +509,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None

-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -539,12 +533,20 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]

         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)

         # Step 2. Padding for dynamic shape kernels
@@ -729,13 +731,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None

-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -760,12 +755,20 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]

         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)

         # Step 2. Padding for dynamic shape kernels
@@ -979,9 +982,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring

         main_block = reduction_blocks[0]
         block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None

         main_block_info = get_block_info(sch, main_block)
         iter_infos = main_block_info.iters
@@ -1000,13 +1000,19 @@ def is_inner_reduction(block_stmt, iter_infos):
             return ret

         # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        # Reindex first and then analyze the index map
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        if index_maps is None:
+            return None
         matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)

         # Step 1. Check Tensor Core support
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("nvidia/geforce-gtx-1080-ti"):
+                # Use Matmul rule for Conv for now
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestConv3d(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(
+        A: T.Buffer((14308, 3, 2, 14, 14), "float16"),
+        W: T.Buffer((1280, 3, 2, 14, 14), "float16"),
+        C: T.Buffer((14308, 1280, 1, 1, 1), "float16"),
+    ):
+        pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16")
+        for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14):
+            with T.block("pad_A"):
+                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4]
+        for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14):
+            with T.block("C"):
+                v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
+                with T.init():
+                    C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0)
+                C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz] * W[v_ff, v_rc, v_ry, v_rx, v_rz]
+
+    @T.prim_func
+    def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")):
+        T.func_attr({"tir.is_scheduled": 1})
+        # with T.block("root"):
+        C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local")
+        pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared")
+        W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(20, thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding(448, thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax1_3_init, ax2_3_0_init in T.grid(4, 2):
+                                    for ax2_3_1_init in T.vectorized(2):
+                                        with T.block("C_init"):
+                                            v0 = T.axis.spatial(1, 0)
+                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
+                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init)
+                                            C_reindex_pad_local[0, v1, v2] = T.float16(0.0)
+                                for ax3_0 in range(74):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(2):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("pad_A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(4):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("W_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
+                                    for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2):
+                                        for ax2_3_1 in T.vectorized(2):
+                                            with T.block("C_update"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
+                                                v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1)
+                                                v3 = T.axis.reduce(1184, ax3_0 * 16 + ax3_1)
+                                                C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+                                    for ax2_1_1 in T.vectorized(2):
+                                        with T.block("C_reindex_pad_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1)
+                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+                                            T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < 14308)
+                                            C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
