Commit

add test, local test pass
hgt312 committed Jul 17, 2019
1 parent dfba360 commit d576306
Showing 5 changed files with 139 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/operator/numpy/indexing_op.cc
@@ -147,6 +147,10 @@ NNVM_REGISTER_OP(_backward_np_take)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<NumpyTakeParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyTakeOpBackward<cpu>);

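The new FResourceRequest attribute lets the registered _backward_np_take kernel request temporary workspace (kTempSpace) from the executor, presumably for the scatter-add of the output gradient that the take backward performs. A minimal Python sketch of the path this serves, assuming a build that contains this commit (imports follow the conventions of the test files below):

import mxnet as mx
from mxnet import np

a = np.arange(12.0).reshape(3, 4)
idx = np.array([2, 0, 0])
a.attach_grad()
with mx.autograd.record():
    out = np.take(a, idx, axis=0)
out.backward()  # head gradient defaults to ones; this invokes _backward_np_take
# a.grad is a scatter-add of ones into the selected rows:
# row 0 accumulates 2.0, row 2 accumulates 1.0, row 1 stays 0.0
print(a.grad)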
2 changes: 1 addition & 1 deletion src/operator/numpy/indexing_op.cu
@@ -36,7 +36,7 @@ void NumpyTakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[take_::kOut] == kNullOp) return;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
const NumpyTakeParam& param = nnvm::get<NumpyTakeParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

11 changes: 10 additions & 1 deletion src/operator/numpy/indexing_op.h
@@ -137,7 +137,16 @@ void NumpyTakeOpBackward(const nnvm::NodeAttrs& attrs,
s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>());
}

if (!param.axis.has_value() || (param.axis.has_value() && param.axis.value() == 0)) {
bool flag = false;
if (!param.axis.has_value()) {
flag = true;
} else if (param.axis.value() == 0) {
flag = true;
} else if (param.axis.value() + arrshape.ndim() == 0) {
flag = true;
}

if (flag) {
int idxndim = idxshape.ndim();
Tensor<xpu, 1, IType> idx = inputs[1].get_with_shape<xpu, 1, IType>(
Shape1(idxshape.ProdShape(0, idxndim)), s);
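The rewritten condition expands the old single-line check into an explicit flag: the leading-axis gradient path is taken when no axis was given, when axis == 0, and now also when axis equals -arrshape.ndim() (the case param.axis.value() + arrshape.ndim() == 0), i.e. the negative spelling of the leading axis. A plain-NumPy sketch of the equivalence the new branch relies on:

import numpy as np

a = np.random.rand(3, 4, 5)
idx = np.array([2, 0, 2])
# axis=-a.ndim addresses the same (leading) axis as axis=0
assert np.array_equal(np.take(a, idx, axis=0), np.take(a, idx, axis=-a.ndim))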
124 changes: 114 additions & 10 deletions tests/python/unittest/test_numpy_op.py
@@ -37,7 +37,7 @@ class TestTensordot(HybridBlock):
def __init__(self, axes):
super(TestTensordot, self).__init__()
self._axes = axes

def hybrid_forward(self, F, a, b):
return F.np.tensordot(a, b, self._axes)

@@ -58,7 +58,7 @@ def tensordot_backward(a, b, axes=2):
b_axes_summed = b_axes_summed,

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

a_axes_remained = []
for i in range(a.ndim):
@@ -254,7 +254,7 @@ def test_np_dot():
((3, 4, 5), (5, )), # Case 4
((3, 4, 5), (5, 2)), # Case 5
((5,), (5, 2)),
((3, 5, 4), (5, 4, 3)),
((3, 4), (5, 4, 3)),
((4,), (5, 4, 3))
]
@@ -1013,15 +1013,15 @@ def hybrid_forward(self, F, a):
return F.np.argsort(a, self._axis)

shapes = [
(),
(1,),
(5,4),
(5,0,4),
(5,0,0),
(0,0,5),
(0,0,0),
(5,3,4)
]
for hybridize in [True, False]:
for shape in shapes:
for ax in list(range(len(shape))) + [-1, None]:
@@ -1129,15 +1129,15 @@ def test_np_hstack():
class TestHStack(HybridBlock):
def __init__(self):
super(TestHStack, self).__init__()

def hybrid_forward(self, F, a, *args):
return F.np.hstack([a] + list(args))

def get_new_shape(shape):
if len(shape) == 0:
l = random.randint(0,3)
if l == 0:
return shape
else:
return (l,)
shape_lst = list(shape)
@@ -1188,6 +1188,110 @@ def get_new_shape(shape):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@npx.use_np_shape
def test_np_take():
configs = [
((4, 4), (4, 0), None),
((4, 4), (4, 0), 0),
((4, 4), (4, 0), 1),
((), (4, 0), None),
((), (5, ), None),
((), (4, 5), None),
((), (), None),
((3, 4), (), None),
((3, 4), (), 0),
((3, 4), (), 1),
]

class TestTake(HybridBlock):
def __init__(self, axis, mode):
super(TestTake, self).__init__()
self._axis = axis
self._mode = mode

def hybrid_forward(self, F, a, indices):
return F.np.take(a, indices, axis=self._axis, mode=self._mode)

def grad_helper(grad_in, axis, idx, mode):
k = grad_in.shape[axis]
if mode == 'clip':
idx = 0 if idx < 0 else idx
idx = k - 1 if idx >= k else idx
else:
idx = idx % k
if axis is None:
grad_in[idx] += 1.0
elif axis == 0:
if axis == len(grad_in.shape) - 1:
grad_in[idx] += 1.0
else:
grad_in[idx, :] += 1.0
elif axis == 1:
if axis == len(grad_in.shape) - 1:
grad_in[:, idx] += 1.0
else:
grad_in[:, idx, :] += 1.0
elif axis == 2:
if axis == len(grad_in.shape) - 1:
grad_in[:, :, idx] += 1.0
else:
grad_in[:, :, idx, :] += 1.0
elif axis == 3:
if axis == len(grad_in.shape) - 1:
grad_in[:, :, :, idx] += 1.0
else:
grad_in[:, :, :, idx, :] += 1.0
elif axis == 4:
grad_in[:, :, :, :, idx] += 1.0
else:
raise ValueError("axis %d is not supported..." % axis)

def check_output_n_grad(data_shape, idx_shape, axis, mode):
data_real = _np.random.normal(size=data_shape).astype('float32')
idx_real = _np.random.randint(low=-100, high=100, size=idx_shape)
same(np.take(np.array(data_real), np.array(idx_real), axis=axis, mode=mode).asnumpy(),
_np.take(data_real, idx_real, axis=axis, mode=mode))

grad_in = _np.zeros(data_shape, dtype='float32')

test_take = TestTake(axis=axis, mode=mode)
if hybridize:
test_take.hybridize()
x = np.array(data_real)
x.attach_grad()
with mx.autograd.record():
mx_out = test_take(x, np.array(idx_real))
same(mx_out.asnumpy(), _np.take(data_real, idx_real, axis=axis, mode=mode))

if axis and axis < 0:
axis += len(data_shape)
try:
for i in _np.nditer(idx_real):
grad_helper(grad_in, axis, i, mode)
except Exception:  # grad_helper may not cover every config (e.g. 0-d data); skip the reference gradient then
pass

mx_out.backward()
same(x.grad.asnumpy(), grad_in)

for hybridize in [True, False]:
for mode in ['clip', 'wrap']:
for data_ndim in range(1, 5):
for idx_ndim in range(1, 4):
for axis in range(-data_ndim, data_ndim):
data_shape = ()
for _ in range(data_ndim):
data_shape += (_np.random.randint(low=1, high=5), )
idx_shape = ()
for _ in range(idx_ndim):
idx_shape += (_np.random.randint(low=1, high=5), )
check_output_n_grad(data_shape, idx_shape, axis, mode)

for config in configs:
check_output_n_grad(config[0], config[1], config[2], mode)


@with_seed()
@npx.use_np_shape
def test_np_swapaxes():
@@ -1499,10 +1603,10 @@ def __init__(self, axis1, axis2, offset):
self._axis1 = axis1
self._axis2 = axis2
self._offset = offset

def hybrid_forward(self, F, data):
return F.np.trace(data, axis1=self._axis1, axis2=self._axis2, offset=self._offset)

def g(data, axis1, axis2, offset):
idx = _np.indices(data.shape)
ret = _np.zeros_like(data)
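The gradient reference built by grad_helper in test_np_take above rests on two facts: out-of-range indices are first normalized ('clip' saturates them into [0, k-1], 'wrap' reduces them modulo k), and the gradient of take with respect to the data is a scatter-add of the head gradient at the normalized indices. A short plain-NumPy sketch of both, for the 1-D case:

import numpy as _np

k = 4
raw_idx = _np.array([-7, -1, 0, 3, 9])
clip_idx = _np.clip(raw_idx, 0, k - 1)  # mode='clip' reads [0, 0, 0, 3, 3]
wrap_idx = raw_idx % k                  # mode='wrap' reads [1, 3, 0, 3, 1]

grad = _np.zeros(k, dtype='float32')
_np.add.at(grad, clip_idx, 1.0)         # scatter-add of a ones head gradient
# grad == [3., 0., 0., 2.]: each slot counts how often it was read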
14 changes: 10 additions & 4 deletions tests/python/unittest/test_operator.py
@@ -4159,7 +4159,13 @@ def test_blockgrad():

@with_seed()
def test_take():
def grad_helper(grad_in, axis, idx):
def grad_helper(grad_in, axis, idx, mode):
k = grad_in.shape[axis]
if mode == 'clip':
idx = 0 if idx < 0 else idx
idx = k - 1 if idx >= k else idx
else:
idx = idx % k
if axis == 0:
if axis == len(grad_in.shape) - 1:
grad_in[idx] += 1.0
@@ -4193,7 +4199,7 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
exe = result.simple_bind(default_context(), a=data_shape,
indices=idx_shape, axis=axis, mode=mode)
data_real = np.random.normal(size=data_shape).astype('float32')
idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
idx_real = np.random.randint(low=data_shape[axis]*-100, high=data_shape[axis]*100, size=idx_shape)
if axis < 0:
axis += len(data_shape)

@@ -4206,10 +4212,10 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode))

for i in np.nditer(idx_real):
grad_helper(grad_in, axis, i)
grad_helper(grad_in, axis, i, mode)

exe.backward([mx.nd.array(grad_out)])
assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in)
assert same(exe.grad_dict['a'].asnumpy(), grad_in)

def check_autograd_req():
row_len = 2
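The widened index range in check_output_n_grad above (low=data_shape[axis]*-100) deliberately produces out-of-range indices, so the symbolic take must agree with NumPy's 'clip' and 'wrap' handling rather than relying on in-range values. A quick imperative spot-check of that contract (assuming the usual mx.nd.take(a, indices, axis, mode) signature):

import mxnet as mx
import numpy as np

a = np.arange(12.0, dtype='float32').reshape(3, 4)
idx = np.array([-5, 7])
for mode in ['clip', 'wrap']:
    np_out = np.take(a, idx, axis=0, mode=mode)
    mx_out = mx.nd.take(mx.nd.array(a), mx.nd.array(idx, dtype='float32'), axis=0, mode=mode)
    assert np.allclose(np_out, mx_out.asnumpy())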
