Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add unit test for backward
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Apr 2, 2019
1 parent 40b2546 commit 3f6583d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/operator/image/crop-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ inline void CropBackwardImpl(int x,
begin[1] = y;
begin[2] = x;
}
MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
output_grad.shape_.get<ndim>(), input_grad.shape_.get<ndim>(), begin, step);
input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), begin, step);
})
})
})
Expand Down
40 changes: 32 additions & 8 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
from collections import namedtuple

import mxnet as mx
import mxnet.ndarray as nd
from mxnet.base import MXNetError
from mxnet import gluon
from mxnet.base import MXNetError
from mxnet.gluon.data.vision import transforms
from mxnet import image
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import almost_equal
from mxnet.test_utils import *
from common import assertRaises, setup_module, with_seed, teardown

import numpy as np
Expand Down Expand Up @@ -158,12 +159,35 @@ def _test_crop_resize_with_diff_type(dtype):

for dtype in ['uint8', 'float32', 'float64']:
_test_crop_resize_with_diff_type(dtype)
# test for gradient
data = mx.sym.Variable('data')
slice_sym = mx.sym.slice(data, begin=begin, end=end, step=step)
expected_in_grad = np.zeros_like(a_np)
expected_in_grad[index] = b_np
check_symbolic_backward(slice_sym, [a_np], [b_np], [expected_in_grad])

# test nd.image.crop backward
def test_crop_backward(test_nd_arr, TestCase):
a_np = test_nd_arr.asnumpy()
b_np = a_np[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))]

data = mx.sym.Variable('data')
crop_sym = mx.sym.image.crop(data, TestCase.x, TestCase.y, TestCase.width, TestCase.height)

expected_in_grad = np.zeros_like(a_np)
expected_in_grad[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] = b_np
check_symbolic_backward(crop_sym, [a_np], [b_np], [expected_in_grad])

TestCase = namedtuple('TestCase', ['x', 'y', 'width', 'height'])
test_list = [TestCase(0, 0, 3, 3), TestCase(2, 1, 1, 2), TestCase(0, 1, 3, 2)]

for dtype in ['uint8', 'float32', 'float64']:
data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
for test_case in test_list:
test_crop_backward(data_in, test_case)



# check numeric gradient of nd.image.crop
# in_data = np.arange(36).reshape(3, 4, 3)
# data = mx.sym.Variable('data')
# image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2)
# check_numeric_gradient(image_crop_sym, [in_data])


@with_seed()
def test_flip_left_right():
Expand Down

0 comments on commit 3f6583d

Please sign in to comment.