From d41078fb9578bef16f39fe437dbe40de732c7da1 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Wed, 1 Nov 2023 17:38:38 -0700 Subject: [PATCH] Properly pack and unpack int4 arrays on CPU in PJRT. Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays. PiperOrigin-RevId: 578692796 --- jax/_src/public_test_util.py | 24 ++++++++++++++++++++---- tests/lax_numpy_test.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index bb4b414dffb8..c40d0e1c963e 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -92,10 +92,26 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): if a.dtype == b.dtype == _dtypes.float0: np.testing.assert_array_equal(a, b, err_msg=err_msg) return - custom_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, - _dtypes.float8_e5m2, _dtypes.bfloat16] - a = a.astype(np.float32) if a.dtype in custom_dtypes else a - b = b.astype(np.float32) if b.dtype in custom_dtypes else b + + custom_float_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, + _dtypes.float8_e5m2, _dtypes.bfloat16] + def maybe_upcast(x): + if x.dtype in custom_float_dtypes: + return x.astype(np.float32) + # TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once + # ml_dtypes has a stable release with commit + # https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10. + # Remove these checks once JAX depends on a version on ml_dtypes with that + # commit. + if x.dtype == _dtypes.int4: + return x.astype(np.int8) + if x.dtype == _dtypes.uint4: + return x.astype(np.uint8) + return x + + a = maybe_upcast(a) + b = maybe_upcast(b) + kw = {} if atol: kw["atol"] = atol if rtol: kw["rtol"] = rtol diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5b0b5dab53b0..4a1829311b2c 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -49,6 +49,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_extension_version from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps from jax._src.util import safe_zip, NumpyComplexWarning @@ -3526,6 +3527,24 @@ def testAstypeNone(self): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @unittest.skipIf(xla_extension_version < 210, 'jaxlib version too old') + def testAstypeInt4(self): + # Test converting from int4 to int8 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int8) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + # Test converting from int8 to int4 + x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) + args_maker = lambda: [x] + np_op = lambda x: np.asarray(x).astype(jnp.int4) + jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( shape=array_shapes, dtype=all_dtypes,