diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index bb4b414dffb8..c40d0e1c963e 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -92,10 +92,26 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): if a.dtype == b.dtype == _dtypes.float0: np.testing.assert_array_equal(a, b, err_msg=err_msg) return - custom_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, - _dtypes.float8_e5m2, _dtypes.bfloat16] - a = a.astype(np.float32) if a.dtype in custom_dtypes else a - b = b.astype(np.float32) if b.dtype in custom_dtypes else b + + custom_float_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, + _dtypes.float8_e5m2, _dtypes.bfloat16] + def maybe_upcast(x): + if x.dtype in custom_float_dtypes: + return x.astype(np.float32) + # TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once + # ml_dtypes has a stable release with commit + # https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10. + # Remove these checks once JAX depends on a version on ml_dtypes with that + # commit. + if x.dtype == _dtypes.int4: + return x.astype(np.int8) + if x.dtype == _dtypes.uint4: + return x.astype(np.uint8) + return x + + a = maybe_upcast(a) + b = maybe_upcast(b) + kw = {} if atol: kw["atol"] = atol if rtol: kw["rtol"] = rtol diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5b0b5dab53b0..4a1829311b2c 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -49,6 +49,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_extension_version from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps from jax._src.util import safe_zip, NumpyComplexWarning @@ -3526,6 +3527,24 @@ def testAstypeNone(self): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @unittest.skipIf(xla_extension_version < 210, 'jaxlib version too old') + def testAstypeInt4(self): + # Test converting from int4 to int8 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int8) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # Test converting from int8 to int4 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int4) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( shape=array_shapes, dtype=all_dtypes,