@@ -32,70 +32,70 @@ namespace contrib {
 using namespace runtime;
 
 void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
-                             const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
-                             DLTensor* y, const std::string& conv_dtype) {
+                             const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
+                             DLTensor* dx, const std::string& conv_dtype) {
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
-  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
-                     y->shape, x->dtype, conv_dtype);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy->shape, w->shape,
+                     dx->shape, dy->dtype, conv_dtype);
   // Set Device
-  entry_ptr->conv_entry.device = x->device;
+  entry_ptr->conv_entry.device = dy->device;
   // Set Algo
-  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+  entry_ptr->conv_entry.bwd_data_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);
 
   // Set workspace
   size_t workspace_size = 0;
-  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
-      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
-      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
-      entry_ptr->conv_entry.fwd_algo, &workspace_size));
+  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+      entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
   entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
-  CUDNN_CALL(cudnnConvolutionForward(
+  CUDNN_CALL(cudnnConvolutionBackwardData(
       entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
-      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
-      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+      entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
       entry_ptr->conv_entry.workspace, workspace_size,
-      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
-      entry_ptr->conv_entry.output_desc, y->data));
+      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc,
+      dx->data));
 }
 
 void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
-                          const int dilation[], const int x_dim[], const int w_dim[],
-                          const int y_dim[], const std::string& data_dtype,
+                          const int dilation[], const int dy_dim[], const int w_dim[],
+                          const int dx_dim[], const std::string& data_dtype,
                           const std::string& conv_dtype, TVMRetValue* ret) {
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
   const int full_dims = dims + 2;
-  std::vector<int64_t> x_dim_int64(full_dims);
+  std::vector<int64_t> dy_dim_int64(full_dims);
   std::vector<int64_t> w_dim_int64(full_dims);
-  std::vector<int64_t> y_dim_int64(full_dims);
+  std::vector<int64_t> dx_dim_int64(full_dims);
   for (int i = 0; i < full_dims; ++i) {
-    x_dim_int64[i] = x_dim[i];
+    dy_dim_int64[i] = dy_dim[i];
     w_dim_int64[i] = w_dim[i];
-    y_dim_int64[i] = y_dim[i];
+    dx_dim_int64[i] = dx_dim[i];
   }
-  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
-                     w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy_dim_int64.data(),
+                     w_dim_int64.data(), dx_dim_int64.data(), String2DLDataType(data_dtype),
                      conv_dtype);
 
   int returned_algo_count = 0;
-  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
-  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
-      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
-      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
-      CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));
-
-  const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
-                                                "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};
+
+  cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT];
+  CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
+      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));
+
+  const std::vector<std::string> fwd_algo_names{
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};
 
   auto best_algo = perf_results[0].algo;
-  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
+  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
             << fwd_algo_names[best_algo];
   for (int i = 0; i < returned_algo_count; ++i) {
     LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
@@ -117,13 +117,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
     stride_v[i] = args[5 + i];
     dilation_v[i] = args[7 + i];
   }
-  DLTensor* x = args[9];
+  DLTensor* dy = args[9];
   DLTensor* w = args[10];
-  DLTensor* y = args[11];
+  DLTensor* dx = args[11];
   std::string conv_dtype = args[12];
   int groups = args[13];
 
-  ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
+  ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx,
                           conv_dtype);
 });
 
@@ -134,14 +134,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
   int* pad = static_cast<int*>(static_cast<void*>(args[2]));
   int* stride = static_cast<int*>(static_cast<void*>(args[3]));
   int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
-  int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
+  int* dy_dim = static_cast<int*>(static_cast<void*>(args[5]));
   int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
-  int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
+  int* dx_dim = static_cast<int*>(static_cast<void*>(args[7]));
   std::string data_dtype = args[8];
   std::string conv_dtype = args[9];
   int groups = args[10];
 
-  BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim,
+  BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim,
                        data_dtype, conv_dtype, ret);
 });
 
0 commit comments