diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 65cdde30f7d9..9688a574ab16 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -669,21 +669,25 @@ def run(interpret=False): actual = run(False) self.assertAllClose(actual, expected) - SIGN_PARAMS = [ - (jnp.int32, (-3, 0, 5)), - (jnp.uint32, (0, 5)), - (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - (jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - ] - @parameterized.named_parameters( (f"{dtype.__name__}_{value}", dtype, value) - for dtype, values in SIGN_PARAMS + for dtypes, values in ( + ((jnp.uint16, jnp.uint32, jnp.uint64), (0, 5)), + ((jnp.int16, jnp.int32, jnp.int64), (-3, 0, 5)), + ( + (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64), + (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf), + ), + ) + for dtype in dtypes for value in values ) def test_sign(self, dtype, value): - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: - self.skipTest("float64 is not supported on TPU") + if ( + not jax.config.x64_enabled + and dtype in (jnp.uint64, jnp.int64, jnp.float64) + ): + self.skipTest("64-bit types require x64_enabled") @functools.partial( self.pallas_call, @@ -692,38 +696,26 @@ def test_sign(self, dtype, value): def kernel(x_ref, o_ref): o_ref[...] = jnp.sign(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - - x = jnp.full((8, 128,), value, dtype=dtype) - out = kernel(x) - expected = jnp.sign(x) - np.testing.assert_array_equal(out, expected) + x = jnp.full((8, 128,), value, dtype=dtype) + out = kernel(x) + expected = jnp.sign(x) + np.testing.assert_array_equal(out, expected) - @parameterized.product( - dtype=[jnp.float32, jnp.float64], - value=[-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4], + @parameterized.parameters( + -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) - def test_erf_inv(self, dtype, value): - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: - self.skipTest("float64 is not supported on TPU") - + def test_erf_inv(self, value): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + out_shape=jax.ShapeDtypeStruct((8, 128), floatx), ) def kernel(x_ref, o_ref): o_ref[...] = lax.erf_inv(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - - x = jnp.full((8, 128), value, dtype=dtype) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) + x = jnp.full((8, 128), value, dtype=floatx) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) class OpsInterpretTest(OpsTest):