Skip to content

Commit

Permalink
[MXNET-798] Fix the dtype cast from non float32 in Gradient computati…
Browse files Browse the repository at this point in the history
…on (apache#12290)

* Fix the dtype mismatch in derived _zeros node

* Add unittest for infer dtype

* Add one more unit test

* Add nose runmodule

* Add a zero operator with no default dtype

* Rename variables

* fix a bug: rename operator for gpu

* Increase atol and rtol to avoid flakiness
  • Loading branch information
apeforest authored and anirudh2290 committed Sep 19, 2018
1 parent b5a0852 commit c6bafde
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {

if (v.empty()) {
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = zeros_op;
ng->attrs.name = "zeros";
ng->attrs.op = Op::Get("_zeros_without_dtype");
ng->attrs.name = "zeros_without_dtype";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
}
Expand Down
13 changes: 13 additions & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,22 @@ namespace op {

DMLC_REGISTER_PARAMETER(InitOpParam);
DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);

NNVM_REGISTER_OP(_zeros_without_dtype)
.describe("fill target with zeros without default dtype")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InitOpWithoutDTypeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpWithoutDTypeParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpWithoutDTypeParam>)
.set_attr<FInferStorageType>("FInferStorageType",
InitStorageType<InitOpWithoutDTypeParam, true, true>)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FillComputeZerosEx<cpu>)
.add_arguments(InitOpWithoutDTypeParam::__FIELDS__());

NNVM_REGISTER_OP(_zeros)
.describe("fill target with zeros")
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst) {
});
}

NNVM_REGISTER_OP(_zeros_without_dtype)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
.set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);

NNVM_REGISTER_OP(_zeros)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
Expand Down
18 changes: 18 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
}
};

struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
TShape shape;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(TShape())
.describe("The shape of the output");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.set_default(-1)
.describe("Target data type.");
}
};

struct EyeParam : public dmlc::Parameter<EyeParam> {
nnvm::dim_t N;
nnvm::dim_t M;
Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_infer_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import mxnet as mx
import numpy as np
from common import models, with_seed
from mxnet import autograd
from nose.tools import *
from mxnet.test_utils import assert_almost_equal

@with_seed()
def test_infer_multiout_op():
data = mx.nd.arange(16, dtype=np.float64).reshape((4, 4))
data.attach_grad()

with autograd.record():
y = mx.nd.split(data, axis=0, num_outputs=2)
y[0].backward()
assert data.grad.dtype == np.float64

@with_seed()
def test_infer_multiout_op2():
def test_func(a):
q, l = mx.nd.linalg.gelqf(a)
return mx.nd.sum(l)

data32 = mx.nd.random.normal(shape=(2, 3), ctx=mx.cpu(), dtype=np.float32)
data32.attach_grad()
with autograd.record():
test32 = test_func(data32)
test32.backward()

data64 = mx.nd.Cast(data32, dtype=np.float64)
data64.attach_grad()
with autograd.record():
test64 = test_func(data64)
test64.backward()
assert_almost_equal(data64.grad.asnumpy(), data32.grad.asnumpy(), atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit c6bafde

Please sign in to comment.