@@ -60,6 +60,44 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
6060 entry_ptr->conv_entry .output_desc , y->data ));
6161}
6262
63+ void ConvolutionBiasActivationForward (int mode, int format, int algo, int dims, int groups, int act,
64+ double coef, const int pad[], const int stride[],
65+ const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
66+ DLTensor* bias, const std::string& conv_dtype) {
67+ CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
68+ // Set Mode
69+ entry_ptr->conv_entry .mode = static_cast <cudnnConvolutionMode_t>(mode);
70+ CUDNN_CALL (cudnnSetActivationDescriptor (entry_ptr->conv_entry .activation_desc ,
71+ static_cast <cudnnActivationMode_t>(act),
72+ cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
73+ CUDNN_CALL (cudnnSetTensor4dDescriptor (
74+ entry_ptr->conv_entry .bias_desc , entry_ptr->conv_entry .tensor_format ,
75+ CuDNNDataType::DLTypeToCuDNNType (bias->dtype ), 1 , static_cast <int >(w->shape [0 ]), 1 , 1 ));
76+
77+ SetConvDescriptors (entry_ptr, format, dims, groups, pad, stride, dilation, x->shape , w->shape ,
78+ y->shape , x->dtype , conv_dtype);
79+ // Set Device
80+ entry_ptr->conv_entry .device = x->device ;
81+ // Set Algo
82+ entry_ptr->conv_entry .fwd_algo = static_cast <cudnnConvolutionFwdAlgo_t>(algo);
83+
84+ // Set workspace
85+ size_t workspace_size = 0 ;
86+ CUDNN_CALL (cudnnGetConvolutionForwardWorkspaceSize (
87+ entry_ptr->handle , entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .filter_desc ,
88+ entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .output_desc ,
89+ entry_ptr->conv_entry .fwd_algo , &workspace_size));
90+ entry_ptr->conv_entry .UpdateWorkspace (workspace_size);
91+ CUDNN_CALL (cudnnConvolutionBiasActivationForward (
92+ entry_ptr->handle , CuDNNDataType::GetConst<1 >(entry_ptr->conv_entry .data_type ),
93+ entry_ptr->conv_entry .input_desc , x->data , entry_ptr->conv_entry .filter_desc , w->data ,
94+ entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .fwd_algo ,
95+ entry_ptr->conv_entry .workspace , workspace_size,
96+ CuDNNDataType::GetConst<0 >(entry_ptr->conv_entry .data_type ),
97+ entry_ptr->conv_entry .output_desc , y->data , entry_ptr->conv_entry .bias_desc , bias->data ,
98+ entry_ptr->conv_entry .activation_desc , entry_ptr->conv_entry .output_desc , y->data ));
99+ }
100+
63101void FindAlgo (int format, int dims, int groups, const int pad[], const int stride[],
64102 const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
65103 const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -126,6 +164,30 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
126164 conv_dtype);
127165 });
128166
167+ TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv2d+bias+act.forward" )
168+ .set_body([](TVMArgs args, TVMRetValue* ret) {
169+ int mode = args[0 ];
170+ int format = args[1 ];
171+ int algo = args[2 ];
172+ int pad_v[2 ], stride_v[2 ], dilation_v[2 ];
173+ for (int i = 0 ; i < 2 ; i++) {
174+ pad_v[i] = args[3 + i];
175+ stride_v[i] = args[5 + i];
176+ dilation_v[i] = args[7 + i];
177+ }
178+ int act = args[9 ];
179+ double coef = args[10 ];
180+ DLTensor* x = args[11 ];
181+ DLTensor* w = args[12 ];
182+ DLTensor* bias = args[13 ];
183+ DLTensor* y = args[14 ];
184+ std::string conv_dtype = args[15 ];
185+ int groups = args[16 ];
186+
187+ ConvolutionBiasActivationForward (mode, format, algo, 2 , groups, act, coef, pad_v, stride_v,
188+ dilation_v, x, w, y, bias, conv_dtype);
189+ });
190+
129191TVM_REGISTER_GLOBAL (" tvm.contrib.cudnn.conv3d.forward" )
130192 .set_body([](TVMArgs args, TVMRetValue* ret) {
131193 int mode = args[0 ];
0 commit comments