Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add tests for backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeep-krishnamurthy committed Jan 17, 2019
1 parent 2d698a3 commit 7011380
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/operator/image/normalize_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ struct normalize_backward {
// Per-element backward kernel for image normalize.
// j      : pixel offset within one channel plane
// i      : channel index
// length : number of pixels per channel plane
// step   : offset of the current image within the batch (0 for 3-D input)
// The diff artifact left both the old and the new KERNEL_ASSIGN lines in
// place; only the new one (element-wise out_grad indexing) is kept here.
MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad,
                                const DType* in_data, const int i, const int length,
                                const int step, const DType std_dev) {
  // d/dx{(x - mean) / std_dev} => (1 / std_dev), applied element-wise:
  // each input-gradient element scales the MATCHING output-gradient element,
  // not out_grad[i] (the old, incorrect per-channel indexing).
  KERNEL_ASSIGN(in_grad[step + i*length + j], req,
                out_grad[step + i*length + j] * (1.0 / std_dev));
}
};

Expand All @@ -201,9 +202,6 @@ void NormalizeBackwardImpl(const OpContext &ctx,
const TBlob& in_grad = outputs[0];
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();

for (int i = 0; i < channel; ++i) {
DType std_dev = param.std[param.std.ndim() > 1 ? i : 0];
mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
Expand Down
67 changes: 67 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7296,6 +7296,73 @@ def test_invalid_max_pooling_pad_type_same():
name='pooling',
pooling_convention="same")


@with_seed()
def test_image_normalize():
    """Test forward and backward of mx.sym.image.normalize for 3D (C, H, W)
    and 4D (N, C, H, W) inputs.

    The normalize op computes (x - mean[c]) / std[c] per channel c, so the
    gradient w.r.t. the input is 1 / std[c] element-wise.

    NOTE: the original version indexed channels as ``a[:][:][c]`` which is a
    chained-slice antipattern: ``a[:]`` is a view of the whole array, so the
    expression collapses to ``a[c]`` and only worked by accident. Channels
    are addressed explicitly on the channel axis here.
    """
    mean = (0, 1, 2)
    std = (3, 2, 1)
    num_channel = 3

    data = mx.symbol.Variable('data')
    img_norm_sym = mx.sym.image.normalize(data=data, mean=mean, std=std)

    # Part 1 - Test 3D Input (channel-first: axis 0 is the channel axis)
    shape_3d = (num_channel, 28, 28)

    data_in_3d = mx.nd.random.uniform(0, 1, shape_3d)
    data_expected_3d = data_in_3d.asnumpy()
    for c in range(num_channel):
        data_expected_3d[c] = (data_expected_3d[c] - mean[c]) / float(std[c])

    # check forward
    check_symbolic_forward(img_norm_sym, [data_in_3d], [data_expected_3d],
                           rtol=1e-5, atol=1e-5)

    # Gradient is 1/std_dev per channel
    grad_expected_3d = np.ones(shape_3d)
    for c in range(num_channel):
        grad_expected_3d[c] = 1.0 / std[c]

    # check backward
    check_symbolic_backward(img_norm_sym, location=[data_in_3d], out_grads=[mx.nd.ones(shape_3d)],
                            expected=[grad_expected_3d], rtol=1e-5, atol=1e-5)

    # check backward using finite difference
    check_numeric_gradient(img_norm_sym, [data_in_3d], atol=0.001)

    # Part 2 - Test 4D Input (batch-first: axis 1 is the channel axis)
    shape_4d = (2, num_channel, 28, 28)

    data_in_4d = mx.nd.random.uniform(0, 1, shape_4d)
    data_expected_4d = data_in_4d.asnumpy()
    for c in range(num_channel):
        # Broadcast over the batch axis: every image gets the same per-channel
        # normalization, which matches the original per-batch-element updates.
        data_expected_4d[:, c] = (data_expected_4d[:, c] - mean[c]) / float(std[c])

    # check forward
    check_symbolic_forward(img_norm_sym, [data_in_4d], [data_expected_4d],
                           rtol=1e-5, atol=1e-5)

    # Gradient is 1/std_dev per channel, identical for every batch element
    grad_expected_4d = np.ones(shape_4d)
    for c in range(num_channel):
        grad_expected_4d[:, c] = 1.0 / std[c]

    # check backward
    check_symbolic_backward(img_norm_sym, location=[data_in_4d], out_grads=[mx.nd.ones(shape_4d)],
                            expected=[grad_expected_4d], rtol=1e-5, atol=1e-5)

    # check backward using finite difference
    check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001)


# Allow running this test module directly (e.g. `python test_operator.py`);
# nose discovers and runs every test_* function in the module.
if __name__ == '__main__':
    import nose
    nose.runmodule()

0 comments on commit 7011380

Please sign in to comment.