Skip to content

Commit

Permalink
[Pallas] Simplify sign and erf_inv tests
Browse files Browse the repository at this point in the history
Removed the method to locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64.

PiperOrigin-RevId: 677019633
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Sep 21, 2024
1 parent d63afd8 commit b0176f3
Showing 1 changed file with 26 additions and 34 deletions.
60 changes: 26 additions & 34 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,21 +669,25 @@ def run(interpret=False):
actual = run(False)
self.assertAllClose(actual, expected)

SIGN_PARAMS = [
(jnp.int32, (-3, 0, 5)),
(jnp.uint32, (0, 5)),
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
(jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
]

@parameterized.named_parameters(
(f"{dtype.__name__}_{value}", dtype, value)
for dtype, values in SIGN_PARAMS
for dtypes, values in [
((jnp.uint16, jnp.uint32, jnp.uint64), (0, 5)),
((jnp.int16, jnp.int32, jnp.int64), (-3, 0, 5)),
(
(jnp.float16, jnp.float32, jnp.float64),
(-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf),
),
]
for dtype in dtypes
for value in values
)
def test_sign(self, dtype, value):
if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64:
self.skipTest("float64 is not supported on TPU")
if (
not jax.config.x64_enabled
and dtype in (jnp.uint64, jnp.int64, jnp.float64)
):
self.skipTest("64-bit types require x64_enabled")

@functools.partial(
self.pallas_call,
Expand All @@ -692,38 +696,26 @@ def test_sign(self, dtype, value):
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sign(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))

x = jnp.full((8, 128,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)
x = jnp.full((8, 128,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)

@parameterized.product(
dtype=[jnp.float32, jnp.float64],
value=[-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4],
@parameterized.parameters(
-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4
)
def test_erf_inv(self, dtype, value):
if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64:
self.skipTest("float64 is not supported on TPU")

def test_erf_inv(self, value):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
out_shape=jax.ShapeDtypeStruct((8, 128), floatx),
)
def kernel(x_ref, o_ref):
o_ref[...] = lax.erf_inv(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))

x = jnp.full((8, 128), value, dtype=dtype)
out = kernel(x)
expected = lax.erf_inv(x)
np.testing.assert_array_equal(out, expected)
x = jnp.full((8, 128), value, dtype=floatx)
out = kernel(x)
expected = lax.erf_inv(x)
np.testing.assert_array_equal(out, expected)


class OpsInterpretTest(OpsTest):
Expand Down

0 comments on commit b0176f3

Please sign in to comment.