diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 265554ab3918..54a8d224ff42 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -141,8 +141,8 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
 
   if (v.empty()) {
     nnvm::NodePtr ng = nnvm::Node::Create();
-    ng->attrs.op = zeros_op;
-    ng->attrs.name = "zeros";
+    ng->attrs.op = Op::Get("_zeros_without_dtype");
+    ng->attrs.name = "zeros_without_dtype";
     ng->attrs.op->attr_parser(&(ng->attrs));
     return nnvm::NodeEntry{ng, 0, 0};
   }
diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc
index bb23f5d44f64..8554ba854178 100644
--- a/src/operator/tensor/init_op.cc
+++ b/src/operator/tensor/init_op.cc
@@ -30,9 +30,22 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(InitOpParam);
 DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
+DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam);
 DMLC_REGISTER_PARAMETER(RangeParam);
 DMLC_REGISTER_PARAMETER(EyeParam);
 
+NNVM_REGISTER_OP(_zeros_without_dtype)
+.describe("fill target with zeros without default dtype")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InitOpWithoutDTypeParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpWithoutDTypeParam>)
+.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpWithoutDTypeParam>)
+.set_attr<FInferStorageType>("FInferStorageType",
+                             InitStorageType<InitOpWithoutDTypeParam, true, true>)
+.set_attr<FCompute>("FCompute", FillCompute<cpu, 0>)
+.set_attr<FComputeEx>("FComputeEx", FillComputeZerosEx<cpu>)
+.add_arguments(InitOpWithoutDTypeParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_zeros)
 .describe("fill target with zeros")
diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu
index 81d835ee3bd2..902b567516bd 100644
--- a/src/operator/tensor/init_op.cu
+++ b/src/operator/tensor/init_op.cu
@@ -44,6 +44,9 @@ void FillZerosCsrImpl(mshadow::Stream<gpu> *s, const NDArray& dst) {
   });
 }
 
+NNVM_REGISTER_OP(_zeros_without_dtype)
+.set_attr<FCompute>("FCompute", FillCompute<gpu, 0>)
+.set_attr<FComputeEx>("FComputeEx", FillComputeZerosEx<gpu>);
 
 NNVM_REGISTER_OP(_zeros)
 .set_attr<FCompute>("FCompute", FillCompute<gpu, 0>)
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 304911a02a78..1a4790acdb26 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -61,6 +61,24 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
   }
 };
 
+struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
+  TShape shape;
+  std::string ctx;
+  int dtype;
+  DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
+    DMLC_DECLARE_FIELD(shape)
+    .set_default(TShape())
+    .describe("The shape of the output");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+              "Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+    .set_default(-1)
+    .describe("Target data type.");
+  }
+};
+
 struct EyeParam : public dmlc::Parameter<EyeParam> {
   nnvm::dim_t N;
   nnvm::dim_t M;
diff --git a/tests/python/unittest/test_infer_type.py b/tests/python/unittest/test_infer_type.py
new file mode 100644
index 000000000000..bad83f3ef01b
--- /dev/null
+++ b/tests/python/unittest/test_infer_type.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import mxnet as mx
+import numpy as np
+from common import models, with_seed
+from mxnet import autograd
+from nose.tools import *
+from mxnet.test_utils import assert_almost_equal
+
+@with_seed()
+def test_infer_multiout_op():
+    data = mx.nd.arange(16, dtype=np.float64).reshape((4, 4))
+    data.attach_grad()
+
+    with autograd.record():
+        y = mx.nd.split(data, axis=0, num_outputs=2)
+    y[0].backward()
+    assert data.grad.dtype == np.float64
+
+@with_seed()
+def test_infer_multiout_op2():
+    def test_func(a):
+        q, l = mx.nd.linalg.gelqf(a)
+        return mx.nd.sum(l)
+
+    data32 = mx.nd.random.normal(shape=(2, 3), ctx=mx.cpu(), dtype=np.float32)
+    data32.attach_grad()
+    with autograd.record():
+        test32 = test_func(data32)
+        test32.backward()
+
+    data64 = mx.nd.Cast(data32, dtype=np.float64)
+    data64.attach_grad()
+    with autograd.record():
+        test64 = test_func(data64)
+        test64.backward()
+    assert_almost_equal(data64.grad.asnumpy(), data32.grad.asnumpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
\ No newline at end of file