Skip to content

Commit

Permalink
Replace use of deprecated jax test utility
Browse files Browse the repository at this point in the history
Why? It is unused in JAX, and we plan to remove it in jax-ml/jax#7476

PiperOrigin-RevId: 388538856
  • Loading branch information
Jake VanderPlas authored and DistraxDev committed Aug 4, 2021
1 parent 3e21f38 commit b05a161
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions distrax/_src/distributions/mvn_diag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from distrax._src.distributions import mvn_diag
from distrax._src.distributions import normal
from distrax._src.utils import equivalence
import jax
import jax.numpy as jnp
import jax.test_util as jtu
import numpy as np


Expand Down Expand Up @@ -147,7 +147,7 @@ def test_sample_shape(self, distr_params, sample_shape):
sample_shape=sample_shape)

@chex.all_variants
@jtu.disable_implicit_rank_promotion
@jax.numpy_rank_promotion('raise')
@parameterized.named_parameters(
('1d std normal, no shape',
{'scale_diag': np.ones((1,))},
Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/distributions/normal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import chex
from distrax._src.distributions import normal
from distrax._src.utils import equivalence
import jax.test_util as jtu
import jax
import numpy as np


Expand Down Expand Up @@ -65,7 +65,7 @@ def test_sample_shape(self, distr_params, sample_shape):
super()._test_sample_shape(distr_params, dict(), sample_shape)

@chex.all_variants
@jtu.disable_implicit_rank_promotion
@jax.numpy_rank_promotion('raise')
@parameterized.named_parameters(
('1d std normal, no shape', (0, 1), ()),
('1d std normal, int shape', (0, 1), 1),
Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/distributions/uniform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import chex
from distrax._src.distributions import uniform
from distrax._src.utils import equivalence
import jax.test_util as jtu
import jax
import numpy as np


Expand Down Expand Up @@ -62,7 +62,7 @@ def test_sample_shape(self, distr_params, sample_shape):
super()._test_sample_shape(distr_params, dict(), sample_shape)

@chex.all_variants
@jtu.disable_implicit_rank_promotion
@jax.numpy_rank_promotion('raise')
@parameterized.named_parameters(
('1d, no shape', (0., 1.), ()),
('1d, int shape', (0., 1.), 1),
Expand Down

0 comments on commit b05a161

Please sign in to comment.