Skip to content

Commit

Permalink
[Numpy] Port nd.random.multinomial to npx.sample_categorical (apache#…
Browse files Browse the repository at this point in the history
…18272)

* port nd.multinomial to npx.sample_categorical

* move to npx.random
  • Loading branch information
xidulu committed May 11, 2020
1 parent 1d14bf3 commit 9d44086
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/operator/random/sample_multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ DMLC_REGISTER_PARAMETER(SampleMultinomialParam);

NNVM_REGISTER_OP(_sample_multinomial)
.add_alias("sample_multinomial")
.add_alias("_npx__random_categorical")
.describe(R"code(Concurrent sampling from multiple multinomial distributions.
*data* is an *n* dimensional array whose last dimension has length *k*, where
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4543,6 +4543,33 @@ def hybrid_forward(self, F, mean, cov):
assert list(desired_shape) == list(actual_shape)


@with_seed()
@use_np
def test_npx_categorical():
class TestNumpyCategorical(HybridBlock):
def __init__(self, size=None):
super(TestNumpyCategorical, self).__init__()
self.size = size

def hybrid_forward(self, F, prob):
if self.size is None:
return F.npx.random.categorical(prob)
return F.npx.random.categorical(prob, shape=self.size)

batch_sizes = [(2,), (2, 3)]
event_shapes = [None, (10,), (10, 12)]
num_event = [2, 4, 10]
for batch_size, num_event, event_shape in itertools.product(batch_sizes, num_event, event_shapes):
for hybridize in [True, False]:
prob = np.ones(batch_size + (num_event,)) / num_event
net = TestNumpyCategorical(event_shape)
if hybridize:
net.hybridize()
mx_out = net(prob)
desired_shape = batch_size + event_shape if event_shape is not None else batch_size
assert mx_out.shape == desired_shape


@with_seed()
@use_np
def test_random_seed():
Expand Down

0 comments on commit 9d44086

Please sign in to comment.