This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix ConcatType backward type inference (#15829)
* Fix ConcatType and add test

* Remove return false

* Change error message

* Run RNN test only when cuDNN is enabled

* Set default context for test_contrib_amp
anirudh2290 committed Aug 15, 2019
1 parent 9dbfc2d commit 40593c6
Showing 2 changed files with 39 additions and 12 deletions.
32 changes: 22 additions & 10 deletions src/operator/nn/concat.cc
@@ -144,6 +144,7 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   int dtype = -1;
 
+  // checks uniformity of input
   for (int i : *in_type) {
     if (dtype == -1) {
       dtype = i;
@@ -154,18 +155,29 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
     }
   }
 
-  if (dtype == -1) {
-    LOG(FATAL) << "Not enough information to infer type in Concat.";
-    return false;
-  }
-
   size_t nin = param_.num_args;
-  in_type->clear();
-  for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
 
-  out_type->clear();
-  out_type->push_back(dtype);
-
+  // if in types are known out types are unknown
+  if (dtype != -1 && (*out_type)[0] == -1) {
+    (*out_type)[0] = dtype;
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back(dtype);
+    }
+  // if out types are known in types are unknown
+  } else if ((*out_type)[0] != -1 && dtype == -1) {
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back((*out_type)[0]);
+    }
+  // if both out_types and in_types are known, and different
+  } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
+    std::ostringstream os;
+    os << "Type inconsistent, Provided output type = "
+       << mxnet::op::type_string((*out_type)[0]) << ','
+       << " inferred type = " << mxnet::op::type_string(dtype);
+    throw mxnet::op::InferTypeError(os.str(), 0);
+  }
   return true;
 }
 
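Taken together, the new branches make Concat's type inference bidirectional: a known input dtype propagates forward to the output, a known output dtype propagates backward to still-unknown inputs (the case that previously aborted with LOG(FATAL)), and a conflict between the two raises InferTypeError instead of killing the process. Below is a minimal Python sketch of the same rule, with None standing in for an unknown type; concat_infer_type is a hypothetical helper for illustration, not MXNet API:

    def concat_infer_type(in_types, out_type):
        # check uniformity of the known input dtypes
        known = [t for t in in_types if t is not None]
        if any(t != known[0] for t in known):
            raise TypeError("Non-uniform data type in Concat")
        dtype = known[0] if known else None

        if dtype is not None and out_type is None:
            # forward: known inputs determine the output
            out_type = dtype
        elif out_type is not None and dtype is None:
            # backward: a known output determines the inputs (the fixed case)
            dtype = out_type
        elif out_type is not None and dtype is not None and out_type != dtype:
            raise TypeError("Type inconsistent, Provided output type = %s, "
                            "inferred type = %s" % (out_type, dtype))

        return [dtype] * len(in_types), out_type

    # Backward inference, the case this commit fixes: only the output is known.
    print(concat_infer_type([None, None], "float16"))
    # -> (['float16', 'float16'], 'float16')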

19 changes: 17 additions & 2 deletions tests/python/gpu/test_contrib_amp.py
@@ -26,11 +26,12 @@
 from nose.tools import assert_raises
 from mxnet.test_utils import set_default_context, download_model, same_symbol_structure
 from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock
+from mxnet.gluon import SymbolBlock, nn, rnn
 from mxnet.contrib.amp import amp
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import with_seed, teardown
+from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
+set_default_context(mx.gpu(0))
 
 def test_amp_coverage():
     conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
@@ -305,6 +306,20 @@ def check_amp_convert_hybrid_block():
     check_amp_convert_model()
     check_amp_convert_hybrid_block()
 
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_amp_conversion_rnn():
+    with mx.Context(mx.gpu(0)):
+        model = nn.HybridSequential()
+        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
+        model.add(nn.Dense(2))
+        model.initialize()
+        model.hybridize()
+        out = model(mx.nd.ones((2, 3, 4)))
+        new_model = amp.convert_hybrid_block(model)
+        out2 = new_model(mx.nd.ones((2, 3, 4)))
+        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
+
 
 @with_seed()
 def test_module_backward_compatibility():
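The new test plausibly exercises the backward path directly: a bidirectional LSTM concatenates its forward and reverse outputs, so converting the block with amp.convert_hybrid_block makes Concat infer its input types from an already-known output type. The loose atol/rtol of 1e-2 allows for float16 rounding error, and the assert_raises_cudnn_not_satisfied decorator keeps the test off builds whose cuDNN is older than 5.1.10, the version the fused GPU RNN kernel needs.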
