From 5ba285bec12a6a9aed1e0f27e5c81f6e7f3b3540 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 3 May 2019 23:28:04 -0700 Subject: [PATCH] Fix sample_multinomial number of outputs bug (#14873) * Fix sample_multinomial number of outputs bug * Fix lint --- src/operator/random/sample_multinomial_op.h | 7 +++++-- tests/python/unittest/test_random.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index b38aefbc1634..377df4f313da 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U); const mxnet::TShape& ishape = (*in_attrs)[0]; - if (!shape_is_known(ishape)) return false; + if (!ndim_is_known(ishape)) return false; MSHADOW_TYPE_SWITCH(param.dtype, DType, { CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue()) @@ -95,7 +95,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs, } SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape); - return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)); + for (const auto& out_shape : *out_attrs) { + if (!shape_is_known(out_shape)) return false; + } + return true; } diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 8fbd97d8a162..5e809d383cdf 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -916,6 +916,18 @@ def test_randint_without_dtype(): a = mx.nd.random.randint(low=50000000, high=50000010, ctx=mx.context.current_context()) assert a.dtype == np.int32 + +@with_seed() +def test_sample_multinomial_num_outputs(): + ctx = mx.context.current_context() + probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]] + out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False) + assert isinstance(out, mx.nd.NDArray) + out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True) + assert isinstance(out, list) + assert len(out) == 2 + + if __name__ == '__main__': import nose nose.runmodule()