@@ -149,17 +149,17 @@ def _dense_legalize(attrs, inputs, arg_types):
149149 # Pad input and output channels to use tensorcore schedule.
150150 if dtype in ["float16" , "int8" , "uint8" ]:
151151 # The shape of (M, K, N) must be multiple of
152- # (16, 16, 16) or (32, 16, 8) or (8, 16, 32) or (4, 4, 4)
152+ # (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
153+ # from https://arxiv.org/pdf/1811.09736.pdf
153154 if (
154155 (M % 8 == 0 and K % 16 == 0 and N % 32 == 0 )
155156 or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0 )
156157 or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0 )
157- or (M % 4 == 0 and K % 4 == 0 and N % 4 == 0 )
158158 ):
159159 # no need to pad
160160 return None
161161
162- candidates = [(16 , 16 , 16 ), (32 , 16 , 8 ), (8 , 16 , 32 ), ( 4 , 4 , 4 ) ]
162+ candidates = [(16 , 16 , 16 ), (32 , 16 , 8 ), (8 , 16 , 32 )]
163163 elif dtype in ["int4" , "uint4" ]:
164164 if M % 8 == 0 and K % 32 == 0 and N % 8 == 0 :
165165 # no need to pad
@@ -172,7 +172,19 @@ def _dense_legalize(attrs, inputs, arg_types):
172172
173173 if extra_flops_ratio > 2 :
174174 logger .info ("dense pad_to_tensorcore skipped, extra_flops_ratio %s" , extra_flops_ratio )
175- return None
175+
176+ # If tensorcore schedule padding fails, pad up to the nearest multiple of 4x4x4 as long as
177+ # the extra flops do not exceed twice the original flops (extra_flops_ratio <= 2).
178+ # Note that 4x4x4 is invalid for tensorcore scheduling, but padding upwards to 4x4x4
179+ # doesn't hurt if tensorcore padding has already failed.
180+ if M % 4 == 0 and K % 4 == 0 and N % 4 == 0 :
181+ # No need to pad
182+ return None
183+ (dm , dk , dn ) = _pad_to (M , K , N , (4 , 4 , 4 ))
184+ extra_flops_ratio = _extra_flops (M , K , N , dm , dk , dn ) / (M * K * N )
185+
186+ if extra_flops_ratio > 2 :
187+ return None
176188
177189 logger .info ("dense pad_to_tensorcore, extra_flops_ratio %s" , extra_flops_ratio )
178190
@@ -200,14 +212,18 @@ def pad_to_tensorcore(M, K, N, candidates):
200212 best_pad = (0 , 0 , 0 )
201213 for padding in candidates :
202214 dm , dk , dn = _pad_to (M , K , N , padding )
203- e = ( M + dm ) * ( N + dn ) * ( K + dk ) - M * N * K
215+ e = _extra_flops ( M , K , N , dm , dk , dn )
204216 # print(dm, dk, dn, e, flops)
205217 if e < extra_flops :
206218 extra_flops = e
207219 best_pad = (dm , dk , dn )
208220 return best_pad , extra_flops / flops
209221
210222
223+ def _extra_flops (M , K , N , dm , dk , dn ):
224+ return (M + dm ) * (N + dn ) * (K + dk ) - M * N * K
225+
226+
211227def _pad_to (M , K , N , PADDING ):
212228 dm , dk , dn = 0 , 0 , 0
213229
0 commit comments