diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 172665fbbf12..4e8900be24ca 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -129,7 +129,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", RangeLikeShape) -.set_attr("FInferType", InitType) +.set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 0); }) .set_attr("FCompute", RangeCompute) diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 51c84363489a..f3c405d7103c 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -179,7 +179,6 @@ struct RangeLikeParam : public dmlc::Parameter { double step; int repeat; std::string ctx; - int dtype; dmlc::optional axis; DMLC_DECLARE_PARAMETER(RangeLikeParam) { @@ -197,9 +196,6 @@ struct RangeLikeParam : public dmlc::Parameter { .set_default("") .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." "Only used for imperative calls."); - DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) - MXNET_ADD_ALL_TYPES - .describe("Target data type."); DMLC_DECLARE_FIELD(axis) .set_default(dmlc::optional()) .describe("Arange elements according to the size of a certain axis of input array." diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f8d8b4496afc..db550e4254b7 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2331,6 +2331,22 @@ def test_math(): for op in ops: run_math(op, shape, dtype, check_value=check_value) +@with_seed() +def test_arange_like_dtype(): + dtypes = [np.float16, np.float32, np.float64] + + for t in dtypes: + x = mx.sym.Variable('x', dtype=t) + y = mx.sym.reshape(x, shape=(0, 0, -1)) + z = mx.sym.contrib.arange_like(y, axis=-1) + + mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null') + mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t) + out = mod.forward(is_train=False) + for v in out: + assert v.dtype == t + + if __name__ == '__main__': import nose nose.runmodule()