Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix dtype inference in arange_like operator (#15930)
Browse files Browse the repository at this point in the history
* fix dtype in arange_like operator

* add unit test
  • Loading branch information
TaoLv authored and eric-haibin-lin committed Aug 25, 2019
1 parent 3b7e484 commit c7a8a78
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Examples::
.set_num_outputs(1)
.set_attr_parser(ParamParser<RangeLikeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", RangeLikeShape)
.set_attr<nnvm::FInferType>("FInferType", InitType<RangeLikeParam, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 0); })
.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu, RangeLikeParam>)
Expand Down
4 changes: 0 additions & 4 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ struct RangeLikeParam : public dmlc::Parameter<RangeLikeParam> {
double step;
int repeat;
std::string ctx;
int dtype;
dmlc::optional<int> axis;

DMLC_DECLARE_PARAMETER(RangeLikeParam) {
Expand All @@ -197,9 +196,6 @@ struct RangeLikeParam : public dmlc::Parameter<RangeLikeParam> {
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
MXNET_ADD_ALL_TYPES
.describe("Target data type.");
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<int>())
.describe("Arange elements according to the size of a certain axis of input array."
Expand Down
16 changes: 16 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,6 +2331,22 @@ def test_math():
for op in ops:
run_math(op, shape, dtype, check_value=check_value)

@with_seed()
def test_arange_like_dtype():
dtypes = [np.float16, np.float32, np.float64]

for t in dtypes:
x = mx.sym.Variable('x', dtype=t)
y = mx.sym.reshape(x, shape=(0, 0, -1))
z = mx.sym.contrib.arange_like(y, axis=-1)

mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t)
out = mod.forward(is_train=False)
for v in out:
assert v.dtype == t


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit c7a8a78

Please sign in to comment.