indentation and mem alloc
anirudhacharya authored and Ubuntu committed Aug 21, 2019
1 parent 7b540ec commit 8b01c06
Showing 2 changed files with 43 additions and 46 deletions.
10 changes: 4 additions & 6 deletions src/operator/tensor/ordering_op-inl.h
@@ -420,8 +420,6 @@ void TopKImpl(const RunContext &ctx,
                         mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
    temp_size = std::max(temp_size,
                         mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
-   temp_size = std::max(temp_size,
-                        mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
    // Additional temp space for gpu full sorts for batch ids.
    temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
    // Temp space for cpu sorts.
@@ -563,13 +561,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
          const std::vector<TBlob>& outputs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
-   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
      })
    });
  } else {
-   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
    });
  }
@@ -695,13 +693,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                   const std::vector<TBlob>& outputs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  if (param.ret_typ == topk_enum::kReturnBoth) {
-   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
      });
    });
  } else if (param.ret_typ == topk_enum::kReturnValue) {
-   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
    });
  } else {
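Replacing MSHADOW_REAL_TYPE_SWITCH with MSHADOW_TYPE_SWITCH lets the forward and backward paths dispatch TopKImpl/TopKBackwardImpl on every supported input dtype rather than only the real (floating-point) ones, while param.dtype still selects the index type IDType. A minimal sketch of what this permits from the Python NDArray API follows; it is an illustration only, not part of this diff, and assumes an MXNet build that includes this change.

import mxnet as mx

# Integer input: the REAL_TYPE switch would not have dispatched on int32,
# but the plain MSHADOW_TYPE_SWITCH above covers it.
x = mx.nd.array([[5, 2, 9], [1, 7, 3]], dtype='int32')

# Top-2 values per row, descending.
vals = mx.nd.topk(x, axis=1, k=2, ret_typ='value', is_ascend=False)

# Positions of the top-2 entries; the topk dtype argument maps to param.dtype
# (IDType in the C++ code) and controls the index output type.
idx = mx.nd.topk(x, axis=1, k=2, ret_typ='indices', dtype='int32')

print(vals.asnumpy())
print(idx.asnumpy())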
79 changes: 39 additions & 40 deletions tests/python/unittest/test_operator.py
@@ -4346,46 +4346,45 @@ def get_large_matrix():
check_symbolic_forward(b, location={'a': a_npy},
expected=[gt_topk(dat=a_npy, axis=1, ret_typ="indices", k=1,
- is_ascend=True)])
+ is_ascend=False)])
- for dtype in [np.float16, np.float32, np.float64]:
- dshape = (5, 5, 5, 5)
- a_npy = np.arange(np.prod(dshape)).astype(dtype)
- np.random.shuffle(a_npy)
- a_npy = a_npy.reshape(dshape)
- a = mx.sym.Variable('a')
- for axis in [1, 3, None]:
- K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
- for k in K:
- for is_ascend in [True, False]:
- b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
- out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
- check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
- check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
-
- b = mx.sym.topk(a, axis=1, is_ascend=is_ascend, ret_typ="indices", k=5)
- check_symbolic_backward(sym=b, location={'a': large_matrix_npy},
- out_grads=[np.random.normal(size=(100, 5))],
- expected=[np.zeros((100, 300096))])
- check_symbolic_forward(b, location={'a': large_matrix_npy},
- expected=[gt_topk(dat=large_matrix_npy, axis=1,
- ret_typ="indices", k=5,
- is_ascend=is_ascend)])
-
- b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
- check_symbolic_backward(sym=b, location={'a': a_npy},
- out_grads=[np.random.normal(size=(5, 5, 5, 3))],
- expected=[np.zeros((5, 5, 5, 5))])
- check_symbolic_forward(b, location={'a': a_npy},
- expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
- is_ascend=False)])
-
- b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
- check_symbolic_backward(sym=b, location={'a': a_npy},
- out_grads=[np.random.normal(size=(5, 5, 5, 5))],
- expected=[np.zeros((5, 5, 5, 5))])
- check_symbolic_forward(b, location={'a': a_npy},
- expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
- is_ascend=True)])

+ for dtype in [np.float16, np.float32, np.float64]:
+ dshape = (5, 5, 5, 5)
+ a_npy = np.arange(np.prod(dshape)).astype(dtype)
+ np.random.shuffle(a_npy)
+ a_npy = a_npy.reshape(dshape)
+ a = mx.sym.Variable('a')
+ for axis in [1, 3, None]:
+ K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
+ for k in K:
+ for is_ascend in [True, False]:
+ b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
+ out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
+ check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
+ check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
+
+ b = mx.sym.topk(a, axis=1, is_ascend=is_ascend, ret_typ="indices", k=5)
+ check_symbolic_backward(sym=b, location={'a': large_matrix_npy},
+ out_grads=[np.random.normal(size=(100, 5))],
+ expected=[np.zeros((100, 300096))])
+ check_symbolic_forward(b, location={'a': large_matrix_npy},
+ expected=[gt_topk(dat=large_matrix_npy, axis=1,
+ ret_typ="indices", k=5, is_ascend=is_ascend)])
+
+ b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
+ check_symbolic_backward(sym=b, location={'a': a_npy},
+ out_grads=[np.random.normal(size=(5, 5, 5, 3))],
+ expected=[np.zeros((5, 5, 5, 5))])
+ check_symbolic_forward(b, location={'a': a_npy},
+ expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
+ is_ascend=False)])
+
+ b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
+ check_symbolic_backward(sym=b, location={'a': a_npy},
+ out_grads=[np.random.normal(size=(5, 5, 5, 5))],
+ expected=[np.zeros((5, 5, 5, 5))])
+ check_symbolic_forward(b, location={'a': a_npy},
+ expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
+ is_ascend=True)])


@with_seed()
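The checks above compare mx.sym.topk against a NumPy reference gt_topk, which is defined earlier in test_operator.py and does not appear in this diff. Below is a minimal, hypothetical sketch of such a reference for the "value" and "indices" cases; the name gt_topk_sketch and its details are assumptions made here for illustration, not the test file's actual helper, and the "mask" and "both" cases are omitted.

import numpy as np

def gt_topk_sketch(dat, axis, ret_typ, k, is_ascend):
    # Positions of the entries in ascending order along `axis`;
    # axis=None sorts the flattened array, as in the tests above.
    order = np.argsort(dat, axis=axis)
    if not is_ascend:
        # Reverse for descending order; ties do not matter here because the
        # test data is a shuffled np.arange, so all values are unique.
        order = np.flip(order, axis=axis)
    # Keep the first k positions along the sorted axis.
    top_idx = np.take(order, np.arange(k), axis=axis)
    if ret_typ == "indices":
        return top_idx
    # ret_typ == "value": gather the corresponding entries.
    return np.take_along_axis(dat, top_idx, axis=axis)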
