diff --git a/src/operator/spatial_transformer-inl.h b/src/operator/spatial_transformer-inl.h index 301c55c93719..3e863d877b08 100644 --- a/src/operator/spatial_transformer-inl.h +++ b/src/operator/spatial_transformer-inl.h @@ -54,6 +54,7 @@ struct SpatialTransformerParam : public dmlc::Parameter TShape target_shape; int transform_type; int sampler_type; + dmlc::optional cudnn_off; DMLC_DECLARE_PARAMETER(SpatialTransformerParam) { int shape[] = {0, 0}; DMLC_DECLARE_FIELD(target_shape).set_default(TShape(shape, shape + 2)) @@ -62,6 +63,8 @@ struct SpatialTransformerParam : public dmlc::Parameter .describe("transformation type"); DMLC_DECLARE_FIELD(sampler_type).add_enum("bilinear", st::kBilinear) .describe("sampling type"); + DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional()) + .describe("whether to turn cudnn off"); } }; @@ -101,11 +104,11 @@ class SpatialTransformerOp : public Operator { } Copy(grid_dst, workspace, grid_dst.stream_); for (index_t batch = 0; batch < data.size(0); batch++) { - if (param_.transform_type == st::kAffine) { - // Legacy approach shown here for comparison: - // grid_src[batch] = dot(loc[batch], grid_dst); - linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s); - } + if (param_.transform_type == st::kAffine) { + // Legacy approach shown here for comparison: + // grid_src[batch] = dot(loc[batch], grid_dst); + linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s); + } } if (param_.sampler_type == st::kBilinear) { BilinearSamplingForward(out, data, grid_src); @@ -136,11 +139,11 @@ class SpatialTransformerOp : public Operator { BilinearSamplingBackward(gdata, grid_src, grad, data); } for (index_t batch = 0; batch < data.size(0); batch++) { - if (param_.transform_type == st::kAffine) { - // Legacy approach shown here for comparison: - // gloc[batch] = dot(grid_src[batch], grid_dst.T()); - linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s); - } + if (param_.transform_type == st::kAffine) { + // Legacy approach shown here for comparison: + // gloc[batch] = dot(grid_src[batch], grid_dst.T()); + linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s); + } } } diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index f1d69f7618e6..33dbe3e7c069 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -121,6 +121,7 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h, if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { atomicAdd((g_input + data_index + i_w), *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w); + bottom_left_v = *(data + data_index + i_w); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { atomicAdd((g_input + data_index + i_w + 1), @@ -194,7 +195,11 @@ Operator* CreateOp(SpatialTransformerParam param, int dtype) { Operator *op = NULL; #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new CuDNNSpatialTransformerOp(param); + if (param.cudnn_off.has_value() && param.cudnn_off.value()) { + op = new SpatialTransformerOp(param); + } else { + op = new CuDNNSpatialTransformerOp(param); + } }) #else MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d201a2e09c6d..dd7ec985c7c8 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -749,7 +749,6 @@ def test_grid_generator_with_type(): check_consistency(sym, ctx_list, grad_req="add") -@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. https://github.com/apache/incubator-mxnet/issues/11839") @with_seed() def test_spatial_transformer_with_type(): data = mx.sym.Variable('data') @@ -758,11 +757,15 @@ def test_spatial_transformer_with_type(): loc = mx.sym.Activation(data=loc, act_type='relu') loc = mx.sym.FullyConnected(data=loc, num_hidden=6) sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10), - transform_type="affine", sampler_type="bilinear") + transform_type="affine", sampler_type="bilinear", cudnn_off=True) ctx_list = [{'ctx': mx.gpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}}, {'ctx': mx.cpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}}] check_consistency(sym, ctx_list) check_consistency(sym, ctx_list, grad_req="add") + sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10), + transform_type="affine", sampler_type="bilinear", cudnn_off=False) + check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, grad_req="add") @with_seed()