Skip to content

Commit

Permalink
Fix dirichlet flaky tests (apache#18817)
Browse files Browse the repository at this point in the history
* make parameter smoother

* minor changes
  • Loading branch information
xidulu authored Jul 30, 2020
1 parent 6bbd531 commit 608afef
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
12 changes: 6 additions & 6 deletions tests/python/unittest/test_gluon_probability_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def hybrid_forward(self, F, loc, scale, *args):
for shape, hybridize in itertools.product(shapes, [True, False]):
loc = np.random.uniform(-1, 1, shape)
scale = np.random.uniform(0.5, 1.5, shape)
samples = np.random.uniform(size=shape, high=1.0-1e-4)
samples = np.random.uniform(size=shape, low=1e-4, high=1.0-1e-4)
net = TestCauchy("icdf")
if hybridize:
net.hybridize()
Expand Down Expand Up @@ -837,15 +837,15 @@ def hybrid_forward(self, F, alpha, *args):
dirichlet = mgp.Dirichlet(alpha, F, validate_args=True)
return _distribution_method_invoker(dirichlet, self._func, *args)

event_shapes = [2, 5, 10]
event_shapes = [2, 4, 6]
batch_shapes = [None, (2, 3)]

# Test sampling
for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes):
for hybridize in [True, False]:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, size=desired_shape)
net = TestDirichlet("sample")
if hybridize:
net.hybridize()
Expand All @@ -862,9 +862,9 @@ def hybrid_forward(self, F, alpha, *args):
for hybridize in [True, False]:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, desired_shape)
np_samples = _np.random.dirichlet(
[1 / event_shape] * event_shape, size=batch_shape)
[10.0 / event_shape] * event_shape, size=batch_shape)
net = TestDirichlet("log_prob")
if hybridize:
net.hybridize()
Expand All @@ -879,7 +879,7 @@ def hybrid_forward(self, F, alpha, *args):
for func in ['mean', 'variance', 'entropy']:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, desired_shape)
net = TestDirichlet(func)
if hybridize:
net.hybridize()
Expand Down
10 changes: 5 additions & 5 deletions tests/python/unittest/test_gluon_probability_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,15 +837,15 @@ def forward(self, alpha, *args):
dirichlet = mgp.Dirichlet(alpha, validate_args=True)
return _distribution_method_invoker(dirichlet, self._func, *args)

event_shapes = [2, 5, 10]
event_shapes = [2, 4, 6]
batch_shapes = [None, (2, 3)]

# Test sampling
for event_shape, batch_shape in itertools.product(event_shapes, batch_shapes):
for hybridize in [True, False]:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, size=desired_shape)
net = TestDirichlet("sample")
if hybridize:
net.hybridize()
Expand All @@ -862,9 +862,9 @@ def forward(self, alpha, *args):
for hybridize in [True, False]:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, size=desired_shape)
np_samples = _np.random.dirichlet(
[1 / event_shape] * event_shape, size=batch_shape)
[10.0 / event_shape] * event_shape, size=batch_shape)
net = TestDirichlet("log_prob")
if hybridize:
net.hybridize()
Expand All @@ -879,7 +879,7 @@ def forward(self, alpha, *args):
for func in ['mean', 'variance', 'entropy']:
desired_shape = (
batch_shape if batch_shape is not None else ()) + (event_shape,)
alpha = np.random.uniform(size=desired_shape)
alpha = np.random.uniform(1.0, 5.0, desired_shape)
net = TestDirichlet(func)
if hybridize:
net.hybridize()
Expand Down

0 comments on commit 608afef

Please sign in to comment.