Skip to content

Commit

Permalink
Add an absolute tolerance in multinomial_test to unblock an XLA optim…
Browse files Browse the repository at this point in the history
…ization.

PiperOrigin-RevId: 410337145
  • Loading branch information
bloops authored and DistraxDev committed Nov 18, 2021
1 parent ccedc58 commit 53bf52d
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions distrax/_src/distributions/multinomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@


RTOL = 1e-3
ATOL = 1e-6


class MultinomialTest(equivalence.EquivalenceTest, parameterized.TestCase):
Expand All @@ -39,7 +40,8 @@ def setUp(self):
[4, 3], dtype=np.float32) # float dtype required for TFP
self.probs = 0.5 * np.asarray([0.1, 0.4, 0.2, 0.3]) # unnormalized
self.logits = np.log(self.probs)
self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)
self.assertion_fn = lambda x, y: np.testing.assert_allclose( # pylint: disable=g-long-lambda
x, y, rtol=RTOL, atol=ATOL)

@parameterized.named_parameters(
('from probs', False),
Expand Down Expand Up @@ -554,10 +556,12 @@ def test_method(self, function_string, dist_params):
assertion_fn=self.assertion_fn)

def test_jittable(self):
super()._test_jittable(dist_kwargs={
'probs': np.asarray([1.0, 0.0, 0.0]),
'total_count': np.asarray([3, 10]),
})
super()._test_jittable(
dist_kwargs={
'probs': np.asarray([1.0, 0.0, 0.0]),
'total_count': np.asarray([3, 10])
},
assertion_fn=self.assertion_fn)

@parameterized.named_parameters(
('single element', 2),
Expand Down

0 comments on commit 53bf52d

Please sign in to comment.