@@ -169,59 +169,39 @@ def profile(
169169 If profile_all is False, return immediately after the first applicable kernel is found.
170170 If use_multiprocessing is True, compile all profiler executables in parallel.
171171 """
172- if True :
173- ops = GENERATOR_FUNC_TABLE [self .sm ](out_dtype , op_creator = create_conv2d_operator )
174- N , H , W , IC = d_shape
175- OC , R , S , _ = w_shape
176- ops = list (filter (lambda op : self .check_align (op ["name" ], IC , OC ), ops ))
177-
178- for op in ops :
179- op ["runtime" ] = - 1
180-
181- if profile_all :
182- self .engine .compile_all (ops , use_multiprocessing )
183-
184- args = [
185- "--n=%d" % N ,
186- "--h=%d" % H ,
187- "--w=%d" % W ,
188- "--k=%d" % OC ,
189- "--c=%d" % IC ,
190- "--r=%d" % R ,
191- "--s=%d" % S ,
192- "--pad_h=%d" % padding [0 ],
193- "--pad_w=%d" % padding [1 ],
194- "--stride_h=%d" % stride [0 ],
195- "--stride_w=%d" % stride [1 ],
196- "--dilation_h=%d" % dilation [0 ],
197- "--dilation_w=%d" % dilation [1 ],
198- ]
199- for op in ops :
200- out = self .engine .evaluate (op , args )
201- op ["runtime" ] = out
202- if out > 0 and profile_all is False :
203- break
204-
205- valid_ops = filter (lambda op : op ["runtime" ] > 0 , ops )
206- output = sorted (valid_ops , key = lambda i : i ["runtime" ])
207- # self.cache[(M, N, K)] = output[0]
208- return output [0 ]
209-
210- else :
211- B , _ , _ , IC = d_shape
212- OC , R , S , _ = w_shape
213- _ , P , Q , _ = out_shape
214-
215- M = B * P * Q
216- N = OC
217- K = R * S * IC
218-
219- gemm_profile_result = self .gemm_profiler .profile (
220- M , N , K , out_dtype , profile_all = profile_all , use_multiprocessing = use_multiprocessing
221- )
222-
223- tile_description = gemm_profile_result ["tile_description" ]
224- alignment = gemm_profile_result ["alignment" ]
225- data_type = gemm_profile_result ["data_type" ]
226-
227- return create_conv2d_operator ([tile_description ], data_type , [alignment ])[0 ]
172+ ops = GENERATOR_FUNC_TABLE [self .sm ](out_dtype , op_creator = create_conv2d_operator )
173+ N , H , W , IC = d_shape
174+ OC , R , S , _ = w_shape
175+ ops = list (filter (lambda op : self .check_align (op ["name" ], IC , OC ), ops ))
176+
177+ for op in ops :
178+ op ["runtime" ] = - 1
179+
180+ if profile_all :
181+ self .engine .compile_all (ops , use_multiprocessing )
182+
183+ args = [
184+ "--n=%d" % N ,
185+ "--h=%d" % H ,
186+ "--w=%d" % W ,
187+ "--k=%d" % OC ,
188+ "--c=%d" % IC ,
189+ "--r=%d" % R ,
190+ "--s=%d" % S ,
191+ "--pad_h=%d" % padding [0 ],
192+ "--pad_w=%d" % padding [1 ],
193+ "--stride_h=%d" % stride [0 ],
194+ "--stride_w=%d" % stride [1 ],
195+ "--dilation_h=%d" % dilation [0 ],
196+ "--dilation_w=%d" % dilation [1 ],
197+ ]
198+ for op in ops :
199+ out = self .engine .evaluate (op , args )
200+ op ["runtime" ] = out
201+ if out > 0 and profile_all is False :
202+ break
203+
204+ valid_ops = filter (lambda op : op ["runtime" ] > 0 , ops )
205+ output = sorted (valid_ops , key = lambda i : i ["runtime" ])
206+ # self.cache[(M, N, K)] = output[0]
207+ return output [0 ]
0 commit comments