diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 80469b5385eb..9e016bf884f2 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -144,6 +144,7 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   int dtype = -1;
 
+  // checks uniformity of input dtypes
   for (int i : *in_type) {
     if (dtype == -1) {
       dtype = i;
@@ -154,18 +155,29 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
     }
   }
 
-  if (dtype == -1) {
-    LOG(FATAL) << "Not enough information to infer type in Concat.";
-    return false;
-  }
-
   size_t nin = param_.num_args;
 
-  in_type->clear();
-  for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
-
-  out_type->clear();
-  out_type->push_back(dtype);
+  // if in types are known and out types are unknown
+  if (dtype != -1 && (*out_type)[0] == -1) {
+    (*out_type)[0] = dtype;
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back(dtype);
+    }
+  // if out types are known and in types are unknown
+  } else if ((*out_type)[0] != -1 && dtype == -1) {
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back((*out_type)[0]);
+    }
+  // if both out types and in types are known, and they differ
+  } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
+    std::ostringstream os;
+    os << "Type inconsistent, Provided output type = "
+       << mxnet::op::type_string((*out_type)[0]) << ','
+       << " inferred type = " << mxnet::op::type_string(dtype);
+    throw mxnet::op::InferTypeError(os.str(), 0);
+  }
   return true;
 }
 
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index 7927cc99160b..3daab0f7bb6a 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -26,11 +26,12 @@
 from nose.tools import assert_raises
 from mxnet.test_utils import set_default_context, download_model, same_symbol_structure
 from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock
+from mxnet.gluon import SymbolBlock, nn, rnn
 from mxnet.contrib.amp import amp
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import with_seed, teardown
+from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
+set_default_context(mx.gpu(0))
 
 def test_amp_coverage():
     conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
@@ -305,6 +306,20 @@ def check_amp_convert_hybrid_block():
     check_amp_convert_model()
     check_amp_convert_hybrid_block()
 
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_amp_conversion_rnn():
+    with mx.Context(mx.gpu(0)):
+        model = nn.HybridSequential()
+        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
+        model.add(nn.Dense(2))
+        model.initialize()
+        model.hybridize()
+        out = model(mx.nd.ones((2, 3, 4)))
+        new_model = amp.convert_hybrid_block(model)
+        out2 = new_model(mx.nd.ones((2, 3, 4)))
+        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
+
 @with_seed()
 def test_module_backward_compatibility():
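
For illustration, a minimal sketch of the inference behavior this change enables, assuming the MXNet 1.5 symbol API (`mx.sym.concat`, `Symbol.infer_type`, `Symbol.infer_type_partial`); the entry points here are illustrative usage, not part of the patch:

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    c = mx.sym.concat(a, b, dim=1)

    # Forward inference: one known input dtype propagates to the other
    # input and to the output.
    arg_types, out_types, aux_types = c.infer_type(a='float16')

    # Partial inference with no dtypes known: the removed code path hit
    # LOG(FATAL) ("Not enough information to infer type in Concat.");
    # the rewritten ConcatType leaves the types unknown and returns.
    arg_types, out_types, aux_types = c.infer_type_partial()

    # The new out->in branch (known output, unknown inputs) is exercised
    # internally during AMP conversion, as in test_amp_conversion_rnn above.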