
add cudnn_off parameter to SpatialTransformer Op and fix the inconsistency between CPU & GPU code (#12557)
haojin2 authored and eric-haibin-lin committed Sep 14, 2018
1 parent e213286 commit f8ed533
Showing 3 changed files with 24 additions and 13 deletions.
23 changes: 13 additions & 10 deletions src/operator/spatial_transformer-inl.h
@@ -54,6 +54,7 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
TShape target_shape;
int transform_type;
int sampler_type;
+ dmlc::optional<bool> cudnn_off;
DMLC_DECLARE_PARAMETER(SpatialTransformerParam) {
int shape[] = {0, 0};
DMLC_DECLARE_FIELD(target_shape).set_default(TShape(shape, shape + 2))
@@ -62,6 +63,8 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
.describe("transformation type");
DMLC_DECLARE_FIELD(sampler_type).add_enum("bilinear", st::kBilinear)
.describe("sampling type");
+ DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>())
+   .describe("whether to turn cudnn off");
}
};
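The new flag is a three-state dmlc::optional<bool>: left unset it keeps the old behaviour (use cuDNN whenever the build provides it), while an explicit value forces the backend choice on GPU. At the Python level it surfaces as an ordinary optional keyword argument; a minimal sketch of the default and explicit usage (variable names are illustrative, not taken from this commit):

import mxnet as mx

data = mx.sym.Variable('data')
loc = mx.sym.FullyConnected(data=data, num_hidden=6)  # 6 = flattened 2x3 affine matrix

# cudnn_off left unset: previous behaviour, cuDNN is used whenever the build has it
st_default = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
                                       transform_type="affine", sampler_type="bilinear")

# cudnn_off=True: force the plain CUDA implementation even on a cuDNN build
st_plain = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
                                     transform_type="affine", sampler_type="bilinear",
                                     cudnn_off=True)

Because the default is the empty optional, existing call sites keep their behaviour unchanged.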

@@ -101,11 +104,11 @@ class SpatialTransformerOp : public Operator {
}
Copy(grid_dst, workspace, grid_dst.stream_);
for (index_t batch = 0; batch < data.size(0); batch++) {
- if (param_.transform_type == st::kAffine) {
-   // Legacy approach shown here for comparison:
-   // grid_src[batch] = dot(loc[batch], grid_dst);
-   linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s);
- }
+ if (param_.transform_type == st::kAffine) {
+   // Legacy approach shown here for comparison:
+   // grid_src[batch] = dot(loc[batch], grid_dst);
+   linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s);
+ }
}
if (param_.sampler_type == st::kBilinear) {
BilinearSamplingForward(out, data, grid_src);
@@ -136,11 +139,11 @@ class SpatialTransformerOp : public Operator {
BilinearSamplingBackward(gdata, grid_src, grad, data);
}
for (index_t batch = 0; batch < data.size(0); batch++) {
- if (param_.transform_type == st::kAffine) {
-   // Legacy approach shown here for comparison:
-   // gloc[batch] = dot(grid_src[batch], grid_dst.T());
-   linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s);
- }
+ if (param_.transform_type == st::kAffine) {
+   // Legacy approach shown here for comparison:
+   // gloc[batch] = dot(grid_src[batch], grid_dst.T());
+   linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s);
+ }
}
}

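For the affine case, both loops are per-batch matrix products: the forward pass computes grid_src[batch] = loc[batch] * grid_dst, and the backward pass computes gloc[batch] = grid_src[batch] * grid_dst^T (the final true passed to linalg_gemm transposes the second operand). A NumPy sketch of the same algebra, assuming the conventional 2x3 affine matrix over a homogeneous 3xN target grid (the shapes are the standard convention, not read from this header):

import numpy as np

n = 10 * 10                     # number of points in a 10x10 target grid
theta = np.random.rand(2, 3)    # loc[batch]: affine parameters for one sample
grid_dst = np.vstack([np.random.uniform(-1, 1, (2, n)),  # normalized target coords
                      np.ones((1, n))])                   # homogeneous row of ones

# forward: source sampling coordinates for every target point
grid_src = theta @ grid_dst          # (2, 3) @ (3, n) -> (2, n)

# backward: gradient w.r.t. the affine parameters, given d(loss)/d(grid_src)
dgrid_src = np.random.rand(2, n)
dtheta = dgrid_src @ grid_dst.T      # (2, n) @ (n, 3) -> (2, 3), the transposed gemm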
7 changes: 6 additions & 1 deletion src/operator/spatial_transformer.cu
@@ -121,6 +121,7 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd((g_input + data_index + i_w),
*(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
+ bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd((g_input + data_index + i_w + 1),
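The added line appears to be the CPU/GPU consistency fix in this kernel: bottom_left_v is now read only inside the same between() bounds check that guards the corresponding atomicAdd, so an out-of-range bottom-left corner contributes zero rather than triggering an out-of-bounds read. A NumPy sketch of that corner bookkeeping for one output pixel (names mirror the kernel; the zero-fill convention is an assumption based on the commit message):

import numpy as np

def between(value, lower, upper):
    return lower <= value <= upper

i_h, i_w = 8, 8
data = np.random.rand(i_h, i_w)
top_left_y, top_left_x = 6, 7   # bottom/right neighbours fall outside the image

corners = {}
for name, (dy, dx) in {'top_left': (0, 0), 'top_right': (0, 1),
                       'bottom_left': (1, 0), 'bottom_right': (1, 1)}.items():
    y, x = top_left_y + dy, top_left_x + dx
    # read the input value only when in bounds; otherwise it stays zero
    corners[name] = data[y, x] if between(x, 0, i_w - 1) and between(y, 0, i_h - 1) else 0.0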
@@ -194,7 +195,11 @@ Operator* CreateOp<gpu>(SpatialTransformerParam param, int dtype) {
Operator *op = NULL;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- op = new CuDNNSpatialTransformerOp<DType>(param);
+ if (param.cudnn_off.has_value() && param.cudnn_off.value()) {
+   op = new SpatialTransformerOp<gpu, DType>(param);
+ } else {
+   op = new CuDNNSpatialTransformerOp<DType>(param);
+ }
})
#else
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
7 changes: 5 additions & 2 deletions tests/python/gpu/test_operator_gpu.py
@@ -749,7 +749,6 @@ def test_grid_generator_with_type():
check_consistency(sym, ctx_list, grad_req="add")


- @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. https://github.com/apache/incubator-mxnet/issues/11839")
@with_seed()
def test_spatial_transformer_with_type():
data = mx.sym.Variable('data')
Expand All @@ -758,11 +757,15 @@ def test_spatial_transformer_with_type():
loc = mx.sym.Activation(data=loc, act_type='relu')
loc = mx.sym.FullyConnected(data=loc, num_hidden=6)
sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
-                                 transform_type="affine", sampler_type="bilinear")
+                                 transform_type="affine", sampler_type="bilinear", cudnn_off=True)
ctx_list = [{'ctx': mx.gpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}},
{'ctx': mx.cpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}}]
check_consistency(sym, ctx_list)
check_consistency(sym, ctx_list, grad_req="add")
+ sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
+                                 transform_type="affine", sampler_type="bilinear", cudnn_off=False)
+ check_consistency(sym, ctx_list)
+ check_consistency(sym, ctx_list, grad_req="add")


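The re-enabled test builds the same symbol for mx.gpu(0) and mx.cpu(0) in float64 and compares forward outputs and gradients once per GPU path (cudnn_off=True, then False). The same check can be sketched imperatively; the random inputs and tolerances below are illustrative, not values taken from check_consistency:

import numpy as np
import mxnet as mx

data_np = np.random.uniform(size=(1, 5, 10, 10))
loc_np = np.random.uniform(-0.5, 0.5, size=(1, 6))

outputs = []
for ctx in (mx.cpu(0), mx.gpu(0)):
    out = mx.nd.SpatialTransformer(data=mx.nd.array(data_np, ctx=ctx, dtype=np.float64),
                                   loc=mx.nd.array(loc_np, ctx=ctx, dtype=np.float64),
                                   target_shape=(10, 10), transform_type="affine",
                                   sampler_type="bilinear", cudnn_off=True)
    outputs.append(out.asnumpy())

# with the fix, the generic CPU and GPU kernels should agree to float64 precision
np.testing.assert_allclose(outputs[0], outputs[1], rtol=1e-8, atol=1e-10)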
