Skip to content

Commit e311ba3

Browse files
committed
dgrad compiled
1 parent 47f35be commit e311ba3

File tree

2 files changed

+47
-45
lines changed

2 files changed

+47
-45
lines changed

src/runtime/contrib/cudnn/conv_backward.cc

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,70 +32,70 @@ namespace contrib {
3232
using namespace runtime;
3333

3434
void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
35-
const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
36-
DLTensor* y, const std::string& conv_dtype) {
35+
const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
36+
DLTensor* dx, const std::string& conv_dtype) {
3737
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
3838
// Set Mode
3939
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
40-
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
41-
y->shape, x->dtype, conv_dtype);
40+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy->shape, w->shape,
41+
dx->shape, dy->dtype, conv_dtype);
4242
// Set Device
43-
entry_ptr->conv_entry.device = x->device;
43+
entry_ptr->conv_entry.device = dy->device;
4444
// Set Algo
45-
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
45+
entry_ptr->conv_entry.bwd_data_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);
4646

4747
// Set workspace
4848
size_t workspace_size = 0;
49-
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
50-
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
51-
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
52-
entry_ptr->conv_entry.fwd_algo, &workspace_size));
49+
CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
50+
entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
51+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
52+
entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
5353
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
54-
CUDNN_CALL(cudnnConvolutionForward(
54+
CUDNN_CALL(cudnnConvolutionBackwardData(
5555
entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
56-
entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
57-
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
56+
entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data,
57+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
5858
entry_ptr->conv_entry.workspace, workspace_size,
59-
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
60-
entry_ptr->conv_entry.output_desc, y->data));
59+
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc,
60+
dx->data));
6161
}
6262

6363
void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
64-
const int dilation[], const int x_dim[], const int w_dim[],
65-
const int y_dim[], const std::string& data_dtype,
64+
const int dilation[], const int dy_dim[], const int w_dim[],
65+
const int dx_dim[], const std::string& data_dtype,
6666
const std::string& conv_dtype, TVMRetValue* ret) {
6767
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
6868
const int full_dims = dims + 2;
69-
std::vector<int64_t> x_dim_int64(full_dims);
69+
std::vector<int64_t> dy_dim_int64(full_dims);
7070
std::vector<int64_t> w_dim_int64(full_dims);
71-
std::vector<int64_t> y_dim_int64(full_dims);
71+
std::vector<int64_t> dx_dim_int64(full_dims);
7272
for (int i = 0; i < full_dims; ++i) {
73-
x_dim_int64[i] = x_dim[i];
73+
dy_dim_int64[i] = dy_dim[i];
7474
w_dim_int64[i] = w_dim[i];
75-
y_dim_int64[i] = y_dim[i];
75+
dx_dim_int64[i] = dx_dim[i];
7676
}
77-
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
78-
w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
77+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy_dim_int64.data(),
78+
w_dim_int64.data(), dx_dim_int64.data(), String2DLDataType(data_dtype),
7979
conv_dtype);
8080

8181
int returned_algo_count = 0;
82-
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
83-
CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
84-
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
85-
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
86-
CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));
87-
88-
const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
89-
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
90-
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
91-
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
92-
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
93-
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
94-
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
95-
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};
82+
83+
cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT];
84+
CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
85+
entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
86+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
87+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));
88+
89+
const std::vector<std::string> fwd_algo_names{
90+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
91+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
92+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
93+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
94+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
95+
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};
9696

9797
auto best_algo = perf_results[0].algo;
98-
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
98+
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
9999
<< fwd_algo_names[best_algo];
100100
for (int i = 0; i < returned_algo_count; ++i) {
101101
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
@@ -117,13 +117,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
117117
stride_v[i] = args[5 + i];
118118
dilation_v[i] = args[7 + i];
119119
}
120-
DLTensor* x = args[9];
120+
DLTensor* dy = args[9];
121121
DLTensor* w = args[10];
122-
DLTensor* y = args[11];
122+
DLTensor* dx = args[11];
123123
std::string conv_dtype = args[12];
124124
int groups = args[13];
125125

126-
ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
126+
ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx,
127127
conv_dtype);
128128
});
129129

@@ -134,14 +134,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
134134
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
135135
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
136136
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
137-
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
137+
int* dy_dim = static_cast<int*>(static_cast<void*>(args[5]));
138138
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
139-
int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
139+
int* dx_dim = static_cast<int*>(static_cast<void*>(args[7]));
140140
std::string data_dtype = args[8];
141141
std::string conv_dtype = args[9];
142142
int groups = args[10];
143143

144-
BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim,
144+
BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim,
145145
data_dtype, conv_dtype, ret);
146146
});
147147

src/runtime/contrib/cudnn/cudnn_utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,14 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
6767
struct ConvEntry {
6868
cudnnConvolutionDescriptor_t conv_desc;
6969
cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
70-
cudnnFilterDescriptor_t filter_desc;
7170
cudnnDataType_t data_type;
7271
cudnnTensorFormat_t tensor_format;
7372
cudnnTensorDescriptor_t input_desc;
73+
cudnnFilterDescriptor_t filter_desc;
7474
cudnnTensorDescriptor_t output_desc;
7575
cudnnConvolutionFwdAlgo_t fwd_algo;
76+
cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
77+
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
7678
// cudnnMathType_t math_type;
7779
Device device;
7880
runtime::DeviceAPI* cuda_api;

0 commit comments

Comments
 (0)