Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] fix #18936, #18937 (#19878)
Browse files Browse the repository at this point in the history
* fix #18938

* fix #18939, #18940

* fix #18936 and #18937

Co-authored-by: r3stl355 <[email protected]>
  • Loading branch information
r3stl355 and ulmasov authored Apr 30, 2021
1 parent 059a055 commit 6f4ac54
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/operator/random/pdf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ void PdfOpForward(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kAddTo);
CHECK_EQ(inputs.size(), pnum + 1);
CHECK_EQ(outputs.size(), 1);

// Skip kernel launch for zero-size tensors
if (inputs[1].shape_.Size() == 0U || outputs[0].Size() == 0U) {
return;
}

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Expand Down
51 changes: 51 additions & 0 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import unittest
import pytest
from mxnet.test_utils import *
from mxnet.base import MXNetError
from common import assertRaises

def same(a, b):
return np.sum(a != b) == 0
Expand Down Expand Up @@ -1029,3 +1031,52 @@ def test_sample_multinomial_num_outputs():
assert isinstance(out, list)
assert len(out) == 2


@use_np
def test_dirichlet_zero_size_dim():
""" Tests for no error when dealing with zero-size array in calculating PDF of Poisson distribution
Issue: https://github.com/apache/incubator-mxnet/issues/18936
"""

def test_valid_zero_dim():
alpha = mx.nd.array(np.random.rand(0))
sample = mx.nd.array(np.random.rand(4, 0))
res = mx.nd.op.random_pdf_dirichlet(sample=sample, alpha=alpha)
assert res.shape == sample.shape[:-1]

def test_valid_zero_multi_dim():
alpha = mx.nd.array(np.random.rand(4, 0))
sample = mx.nd.array(np.random.rand(4, 3, 0))
res = mx.nd.op.random_pdf_dirichlet(sample=sample, alpha=alpha)
assert res.shape == sample.shape[:-1]

def test_invalid_zero_dim():
"""The shape of *alpha* must match the left-most part of the *sample* shape"""
alpha = mx.nd.array(np.random.rand(1))
sample = mx.nd.array(np.random.rand(4, 0))
assertRaises(MXNetError, mx.nd.op.random_pdf_dirichlet, sample, alpha)

test_valid_zero_dim()
test_valid_zero_multi_dim()
test_invalid_zero_dim()

@use_np
def test_poisson_zero_size_dim():
""" Tests for no error when dealing with zero-size array in calculating PDF of Poisson distribution
Issue: https://github.com/apache/incubator-mxnet/issues/18937
"""

def test_valid_zero_dim():
lam = mx.nd.array(np.random.rand(0))
sample = mx.nd.array(np.random.rand(0, 2))
res = mx.nd.op.random_pdf_poisson(sample=sample, lam=lam)
assert res.shape == sample.shape

def test_invalid_zero_dim():
"""The shape of *lam* must match the leftmost part of the *sample* shape"""
lam = mx.nd.array(np.random.rand(0))
sample = mx.nd.array(np.random.rand(1, 2))
assertRaises(MXNetError, mx.nd.op.random_pdf_poisson, sample, lam)

test_valid_zero_dim()
test_invalid_zero_dim()

0 comments on commit 6f4ac54

Please sign in to comment.