Skip to content

Commit

Permalink
Fix sample_multinomial number of outputs bug (apache#14873)
Browse files Browse the repository at this point in the history
* Fix sample_multinomial number of outputs bug

* Fix lint
  • Loading branch information
reminisce authored and haohuw committed Jun 23, 2019
1 parent e0bb223 commit 18a70a8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/operator/random/sample_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!shape_is_known(ishape)) return false;
if (!ndim_is_known(ishape)) return false;

MSHADOW_TYPE_SWITCH(param.dtype, DType, {
CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
Expand All @@ -95,7 +95,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1));
for (const auto& out_shape : *out_attrs) {
if (!shape_is_known(out_shape)) return false;
}
return true;
}


Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,18 @@ def test_randint_without_dtype():
a = mx.nd.random.randint(low=50000000, high=50000010, ctx=mx.context.current_context())
assert a.dtype == np.int32


@with_seed()
def test_sample_multinomial_num_outputs():
ctx = mx.context.current_context()
probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]]
out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False)
assert isinstance(out, mx.nd.NDArray)
out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True)
assert isinstance(out, list)
assert len(out) == 2


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 18a70a8

Please sign in to comment.