@@ -37,8 +37,8 @@ void ConvolutionBackwardData(int mode, int format, int algo, int dims, int group
3737 CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
3838 // Set Mode
3939 entry_ptr->conv_entry .mode = static_cast <cudnnConvolutionMode_t>(mode);
40- SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, dy ->shape , w->shape ,
41- dx ->shape , dy->dtype , conv_dtype);
40+ SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, dx ->shape , w->shape ,
41+ dy ->shape , dy->dtype , conv_dtype);
4242 // Set Device
4343 entry_ptr->conv_entry .device = dy->device ;
4444 // Set Algo
@@ -74,8 +74,8 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
7474 w_dim_int64[i] = w_dim[i];
7575 dx_dim_int64[i] = dx_dim[i];
7676 }
77- SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, dy_dim_int64 .data (),
78- w_dim_int64.data (), dx_dim_int64 .data (), String2DLDataType (data_dtype),
77+ SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64 .data (),
78+ w_dim_int64.data (), dy_dim_int64 .data (), String2DLDataType (data_dtype),
7979 conv_dtype);
8080
8181 int returned_algo_count = 0 ;
@@ -86,8 +86,8 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
8686 entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .input_desc ,
8787 CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));
8888
89- const std::vector<std::string> fwd_algo_names {
90- " CUDNN_CONVOLUTION_BWD_DATA_ALGO_0" ,
89+ const std::vector<std::string> bwd_data_algo_names {
90+ " CUDNN_CONVOLUTION_BWD_DATA_ALGO_0" , // non-deterministic
9191 " CUDNN_CONVOLUTION_BWD_DATA_ALGO_1" ,
9292 " CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT" ,
9393 " CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING" ,
@@ -96,9 +96,86 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
9696
9797 auto best_algo = perf_results[0 ].algo ;
9898 LOG (INFO) << " \t CUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
99- << fwd_algo_names [best_algo];
99+ << bwd_data_algo_names [best_algo];
100100 for (int i = 0 ; i < returned_algo_count; ++i) {
101- LOG (INFO) << " \t\t " << i << " ) " << fwd_algo_names[perf_results[i].algo ]
101+ LOG (INFO) << " \t\t " << i << " ) " << bwd_data_algo_names[perf_results[i].algo ]
102+ << " - time: " << perf_results[i].time << " ms"
103+ << " , Memory: " << perf_results[i].memory ;
104+ }
105+
106+ ret[0 ] = best_algo;
107+ }
108+
109+ void ConvolutionBackwardFilter (int mode, int format, int algo, int dims, int groups,
110+ const int pad[], const int stride[], const int dilation[],
111+ DLTensor* x, DLTensor* dy, DLTensor* dw,
112+ const std::string& conv_dtype) {
113+ CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
114+ // Set Mode
115+ entry_ptr->conv_entry .mode = static_cast <cudnnConvolutionMode_t>(mode);
116+ SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, x->shape , dw->shape ,
117+ dy->shape , x->dtype , conv_dtype);
118+ // Set Device
119+ entry_ptr->conv_entry .device = x->device ;
120+ // Set Algo
121+ entry_ptr->conv_entry .bwd_filter_algo = static_cast <cudnnConvolutionBwdFilterAlgo_t>(algo);
122+
123+ // Set workspace
124+ size_t workspace_size = 0 ;
125+ CUDNN_CALL (cudnnGetConvolutionBackwardFilterWorkspaceSize (
126+ entry_ptr->handle , entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .output_desc ,
127+ entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .filter_desc ,
128+ entry_ptr->conv_entry .bwd_filter_algo , &workspace_size));
129+ entry_ptr->conv_entry .UpdateWorkspace (workspace_size);
130+ CUDNN_CALL (cudnnConvolutionBackwardFilter (
131+ entry_ptr->handle , CuDNNDataType::GetConst<1 >(entry_ptr->conv_entry .data_type ),
132+ entry_ptr->conv_entry .input_desc , x->data , entry_ptr->conv_entry .output_desc , dy->data ,
133+ entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .bwd_filter_algo ,
134+ entry_ptr->conv_entry .workspace , workspace_size,
135+ CuDNNDataType::GetConst<0 >(entry_ptr->conv_entry .data_type ),
136+ entry_ptr->conv_entry .filter_desc , dw->data ));
137+ }
138+
139+ void BackwardFilterFindAlgo (int format, int dims, int groups, const int pad[], const int stride[],
140+ const int dilation[], const int x_dim[], const int dy_dim[],
141+ const int dw_dim[], const std::string& data_dtype,
142+ const std::string& conv_dtype, TVMRetValue* ret) {
143+ CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
144+ const int full_dims = dims + 2 ;
145+ std::vector<int64_t > x_dim_int64 (full_dims);
146+ std::vector<int64_t > dy_dim_int64 (full_dims);
147+ std::vector<int64_t > dw_dim_int64 (full_dims);
148+ for (int i = 0 ; i < full_dims; ++i) {
149+ x_dim_int64[i] = x_dim[i];
150+ dy_dim_int64[i] = dy_dim[i];
151+ dw_dim_int64[i] = dw_dim[i];
152+ }
153+ SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data (),
154+ dw_dim_int64.data (), dy_dim_int64.data (), String2DLDataType (data_dtype),
155+ conv_dtype);
156+
157+ int returned_algo_count = 0 ;
158+
159+ cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT];
160+ CUDNN_CALL (cudnnFindConvolutionBackwardFilterAlgorithm (
161+ entry_ptr->handle , entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .output_desc ,
162+ entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .filter_desc ,
163+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results));
164+
165+ const std::vector<std::string> bwd_filter_algo_names{
166+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0" , // non-deterministic
167+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1" ,
168+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT" ,
169+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3" ,
170+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED" ,
171+ " CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING" ,
172+ };
173+
174+ auto best_algo = perf_results[0 ].algo ;
175+ LOG (INFO) << " \t CUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing "
176+ << bwd_filter_algo_names[best_algo];
177+ for (int i = 0 ; i < returned_algo_count; ++i) {
178+ LOG (INFO) << " \t\t " << i << " ) " << bwd_filter_algo_names[perf_results[i].algo ]
102179 << " - time: " << perf_results[i].time << " ms"
103180 << " , Memory: " << perf_results[i].memory ;
104181 }
@@ -145,5 +222,44 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
145222 data_dtype, conv_dtype, ret);
146223 });
147224
225+ TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv2d.backward_filter" )
226+ .set_body([](TVMArgs args, TVMRetValue* ret) {
227+ int mode = args[0 ];
228+ int format = args[1 ];
229+ int algo = args[2 ];
230+ int pad_v[2 ], stride_v[2 ], dilation_v[2 ];
231+ for (int i = 0 ; i < 2 ; i++) {
232+ pad_v[i] = args[3 + i];
233+ stride_v[i] = args[5 + i];
234+ dilation_v[i] = args[7 + i];
235+ }
236+ DLTensor* x = args[9 ];
237+ DLTensor* dy = args[10 ];
238+ DLTensor* dw = args[11 ];
239+ std::string conv_dtype = args[12 ];
240+ int groups = args[13 ];
241+
242+ ConvolutionBackwardFilter (mode, format, algo, 2 , groups, pad_v, stride_v, dilation_v, x, dy,
243+ dw, conv_dtype);
244+ });
245+
246+ TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv.backward_filter_find_algo" )
247+ .set_body([](TVMArgs args, TVMRetValue* ret) {
248+ int format = args[0 ];
249+ int dims = args[1 ];
250+ int * pad = static_cast <int *>(static_cast <void *>(args[2 ]));
251+ int * stride = static_cast <int *>(static_cast <void *>(args[3 ]));
252+ int * dilation = static_cast <int *>(static_cast <void *>(args[4 ]));
253+ int * x_dim = static_cast <int *>(static_cast <void *>(args[5 ]));
254+ int * dy_dim = static_cast <int *>(static_cast <void *>(args[6 ]));
255+ int * dw_dim = static_cast <int *>(static_cast <void *>(args[7 ]));
256+ std::string data_dtype = args[8 ];
257+ std::string conv_dtype = args[9 ];
258+ int groups = args[10 ];
259+
260+ BackwardFilterFindAlgo (format, dims, groups, pad, stride, dilation, x_dim, dy_dim, dw_dim,
261+ data_dtype, conv_dtype, ret);
262+ });
263+
148264} // namespace contrib
149265} // namespace tvm
0 commit comments