Skip to content

Commit 749cc16

Browse files
committed
extra 4x4x4 step
1 parent a3ba9ad commit 749cc16

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

python/tvm/topi/cuda/tensorcore_alter_op.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,17 @@ def _dense_legalize(attrs, inputs, arg_types):
149149
# Pad input and output channels to use tensorcore schedule.
150150
if dtype in ["float16", "int8", "uint8"]:
151151
# The shape of (M, K, N) must be multiple of
152-
# (16, 16, 16) or (32, 16, 8) or (8, 16, 32) or (4, 4, 4)
152+
# (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
153+
# from https://arxiv.org/pdf/1811.09736.pdf
153154
if (
154155
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
155156
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
156157
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
157-
or (M % 4 == 0 and K % 4 == 0 and N % 4 == 0)
158158
):
159159
# no need to pad
160160
return None
161161

162-
candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32), (4, 4, 4)]
162+
candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
163163
elif dtype in ["int4", "uint4"]:
164164
if M % 8 == 0 and K % 32 == 0 and N % 8 == 0:
165165
# no need to pad
@@ -172,7 +172,19 @@ def _dense_legalize(attrs, inputs, arg_types):
172172

173173
if extra_flops_ratio > 2:
174174
logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
175-
return None
175+
176+
# If tensorcore schedule padding fails, pad to nearest upward 4x4x4 as long as
177+
# the additional flops ratio isn't double or more.
178+
# Note that 4x4x4 is invalid for tensorcore scheduling, but padding upwards to 4x4x4
179+
# doesn't hurt if tensorcore padding has already failed.
180+
if M % 4 == 0 and K % 4 == 0 and N % 4 == 0:
181+
# No need to pad
182+
return None
183+
(dm, dk, dn) = _pad_to(M, K, N, (4, 4, 4))
184+
extra_flops_ratio = _extra_flops(M, K, N, dm, dk, dn) / (M * K * N)
185+
186+
if extra_flops_ratio > 2:
187+
return None
176188

177189
logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)
178190

@@ -200,14 +212,18 @@ def pad_to_tensorcore(M, K, N, candidates):
200212
best_pad = (0, 0, 0)
201213
for padding in candidates:
202214
dm, dk, dn = _pad_to(M, K, N, padding)
203-
e = (M + dm) * (N + dn) * (K + dk) - M * N * K
215+
e = _extra_flops(M, K, N, dm, dk, dn)
204216
# print(dm, dk, dn, e, flops)
205217
if e < extra_flops:
206218
extra_flops = e
207219
best_pad = (dm, dk, dn)
208220
return best_pad, extra_flops / flops
209221

210222

223+
def _extra_flops(M, K, N, dm, dk, dn):
224+
return (M + dm) * (N + dn) * (K + dk) - M * N * K
225+
226+
211227
def _pad_to(M, K, N, PADDING):
212228
dm, dk, dn = 0, 0, 0
213229

tests/python/relay/test_pass_legalize_tensorcore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def expected():
259259
_test_legalize_dense((8, 15), (32, 15), (0, 1, 0), dtype)
260260
_test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
261261
_test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
262-
_test_legalize_dense((3, 16), (32, 16), (1, 0, 0), dtype)
262+
_test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
263263
_test_legalize_dense((1, 16), (32, 16), (0, 0, 0), dtype, False)
264264

265265
# Test if units parameter is correctly updated
@@ -272,7 +272,7 @@ def expected():
272272
_test_legalize_dense((7, 31), (31, 31), (1, 1, 1), "int4")
273273
_test_legalize_dense((3, 32), (32, 32), (5, 0, 0), "int4")
274274
_test_legalize_dense((8, 16), (32, 16), (0, 16, 0), "int4")
275-
_test_legalize_dense((2, 16), (32, 16), (0, 0, 0), "int4", False)
275+
_test_legalize_dense((1, 16), (32, 16), (0, 0, 0), "int4", False)
276276

277277

278278
@tvm.testing.uses_gpu

0 commit comments

Comments
 (0)