Skip to content

Commit 3166952

Browse files
committed
bwd filter compiled
1 parent e311ba3 commit 3166952

File tree

1 file changed

+124
-8
lines changed

1 file changed

+124
-8
lines changed

src/runtime/contrib/cudnn/conv_backward.cc

Lines changed: 124 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ void ConvolutionBackwardData(int mode, int format, int algo, int dims, int group
3737
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
3838
// Set Mode
3939
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
40-
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy->shape, w->shape,
41-
dx->shape, dy->dtype, conv_dtype);
40+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape,
41+
dy->shape, dy->dtype, conv_dtype);
4242
// Set Device
4343
entry_ptr->conv_entry.device = dy->device;
4444
// Set Algo
@@ -74,8 +74,8 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
7474
w_dim_int64[i] = w_dim[i];
7575
dx_dim_int64[i] = dx_dim[i];
7676
}
77-
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy_dim_int64.data(),
78-
w_dim_int64.data(), dx_dim_int64.data(), String2DLDataType(data_dtype),
77+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(),
78+
w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
7979
conv_dtype);
8080

8181
int returned_algo_count = 0;
@@ -86,8 +86,8 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
8686
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
8787
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));
8888

89-
const std::vector<std::string> fwd_algo_names{
90-
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
89+
const std::vector<std::string> bwd_data_algo_names{
90+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", // non-deterministic
9191
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
9292
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
9393
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
@@ -96,9 +96,86 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
9696

9797
auto best_algo = perf_results[0].algo;
9898
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
99-
<< fwd_algo_names[best_algo];
99+
<< bwd_data_algo_names[best_algo];
100100
for (int i = 0; i < returned_algo_count; ++i) {
101-
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
101+
LOG(INFO) << "\t\t" << i << ") " << bwd_data_algo_names[perf_results[i].algo]
102+
<< " - time: " << perf_results[i].time << " ms"
103+
<< ", Memory: " << perf_results[i].memory;
104+
}
105+
106+
ret[0] = best_algo;
107+
}
108+
109+
void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups,
110+
const int pad[], const int stride[], const int dilation[],
111+
DLTensor* x, DLTensor* dy, DLTensor* dw,
112+
const std::string& conv_dtype) {
113+
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
114+
// Set Mode
115+
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
116+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape,
117+
dy->shape, x->dtype, conv_dtype);
118+
// Set Device
119+
entry_ptr->conv_entry.device = x->device;
120+
// Set Algo
121+
entry_ptr->conv_entry.bwd_filter_algo = static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo);
122+
123+
// Set workspace
124+
size_t workspace_size = 0;
125+
CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
126+
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
127+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
128+
entry_ptr->conv_entry.bwd_filter_algo, &workspace_size));
129+
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
130+
CUDNN_CALL(cudnnConvolutionBackwardFilter(
131+
entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
132+
entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.output_desc, dy->data,
133+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_filter_algo,
134+
entry_ptr->conv_entry.workspace, workspace_size,
135+
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
136+
entry_ptr->conv_entry.filter_desc, dw->data));
137+
}
138+
139+
void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
140+
const int dilation[], const int x_dim[], const int dy_dim[],
141+
const int dw_dim[], const std::string& data_dtype,
142+
const std::string& conv_dtype, TVMRetValue* ret) {
143+
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
144+
const int full_dims = dims + 2;
145+
std::vector<int64_t> x_dim_int64(full_dims);
146+
std::vector<int64_t> dy_dim_int64(full_dims);
147+
std::vector<int64_t> dw_dim_int64(full_dims);
148+
for (int i = 0; i < full_dims; ++i) {
149+
x_dim_int64[i] = x_dim[i];
150+
dy_dim_int64[i] = dy_dim[i];
151+
dw_dim_int64[i] = dw_dim[i];
152+
}
153+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
154+
dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
155+
conv_dtype);
156+
157+
int returned_algo_count = 0;
158+
159+
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT];
160+
CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(
161+
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
162+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
163+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results));
164+
165+
const std::vector<std::string> bwd_filter_algo_names{
166+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0", // non-deterministic
167+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
168+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
169+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
170+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
171+
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
172+
};
173+
174+
auto best_algo = perf_results[0].algo;
175+
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing "
176+
<< bwd_filter_algo_names[best_algo];
177+
for (int i = 0; i < returned_algo_count; ++i) {
178+
LOG(INFO) << "\t\t" << i << ") " << bwd_filter_algo_names[perf_results[i].algo]
102179
<< " - time: " << perf_results[i].time << " ms"
103180
<< ", Memory: " << perf_results[i].memory;
104181
}
@@ -145,5 +222,44 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
145222
data_dtype, conv_dtype, ret);
146223
});
147224

225+
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter")
226+
.set_body([](TVMArgs args, TVMRetValue* ret) {
227+
int mode = args[0];
228+
int format = args[1];
229+
int algo = args[2];
230+
int pad_v[2], stride_v[2], dilation_v[2];
231+
for (int i = 0; i < 2; i++) {
232+
pad_v[i] = args[3 + i];
233+
stride_v[i] = args[5 + i];
234+
dilation_v[i] = args[7 + i];
235+
}
236+
DLTensor* x = args[9];
237+
DLTensor* dy = args[10];
238+
DLTensor* dw = args[11];
239+
std::string conv_dtype = args[12];
240+
int groups = args[13];
241+
242+
ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, dy,
243+
dw, conv_dtype);
244+
});
245+
246+
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo")
247+
.set_body([](TVMArgs args, TVMRetValue* ret) {
248+
int format = args[0];
249+
int dims = args[1];
250+
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
251+
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
252+
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
253+
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
254+
int* dy_dim = static_cast<int*>(static_cast<void*>(args[6]));
255+
int* dw_dim = static_cast<int*>(static_cast<void*>(args[7]));
256+
std::string data_dtype = args[8];
257+
std::string conv_dtype = args[9];
258+
int groups = args[10];
259+
260+
BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, dy_dim, dw_dim,
261+
data_dtype, conv_dtype, ret);
262+
});
263+
148264
} // namespace contrib
149265
} // namespace tvm

0 commit comments

Comments
 (0)