diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index 89b6c6e14acd..40caae76bd8f 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -59,6 +59,8 @@ def program_id(axis: int) -> jax.Array:
   grid coordinates `(1, 2)`, `program_id(axis=0)` returns `1` and
   `program_id(axis=1)` returns `2`.
 
+  The returned value is an array of shape `()` and dtype `int32`.
+
   Args:
     axis: the axis of the grid along which to count the program.
   """
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index d8f890c06c32..65cdde30f7d9 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -752,8 +752,6 @@ class OpsExtraTest(PallasBaseTest):
 
   def setUp(self):
     super().setUp()
-    if jax.config.x64_enabled:
-      self.skipTest("Only works in 32-bit")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       # TODO: most tests fail on TPU in non-interpret mode
       self.skipTest("On TPU the test works only in interpret mode")
@@ -800,7 +798,7 @@ def kernel(x_ref, o_ref):
   def test_abs_weak_type(self):
     # see https://github.com/jax-ml/jax/issues/23191
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx),
     )
     def kernel(x_ref, o_ref):
       o_ref[...] = jnp.abs(x_ref[...])
@@ -1145,20 +1143,20 @@ def f(x_ref, o_ref):
   def test_num_programs(self):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
+        out_shape=jax.ShapeDtypeStruct((4,), intx),
         grid=4,
     )
     def kernel(o_ref):
       o_ref[pl.program_id(0)] = pl.num_programs(0)
 
     np.testing.assert_array_equal(
-        kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32)
+        kernel(), jnp.array([4, 4, 4, 4], dtype=intx)
     )
 
   def test_where_broadcasting(self):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx),
         grid=1,
     )
     def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
@@ -1225,11 +1223,12 @@ def dot(x_ref, y_ref, o_ref):
   def test_masked_load_store(self, size, block_size):
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)),
+        out_shape=(jax.ShapeDtypeStruct((size,), floatx)),
         grid=pl.cdiv(size, block_size),
     )
     def kernel(x_ref, o_ref):
-      idx = pl.program_id(0) * block_size + jnp.arange(block_size)
+      idx = pl.program_id(0) * block_size + jnp.arange(
+          block_size, dtype=jnp.int32)
       mask = idx < x_ref.shape[0]
       x = pl.load(x_ref, (idx,), mask=mask)
       pl.store(o_ref, (idx,), x + 1.0, mask=mask)
@@ -1243,7 +1242,7 @@ def test_masked_oob_load_store_slice(self):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)),
+        out_shape=(jax.ShapeDtypeStruct((n,), floatx)),
         grid=1,
     )
     def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref):
@@ -1276,7 +1275,7 @@ def test_broadcasted_load_store(self):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)),
+        out_shape=(jax.ShapeDtypeStruct((m, n), floatx)),
         grid=1,
     )
     def load(x_ref, o_ref):
@@ -1319,7 +1318,7 @@ def test_swap(self):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
+        out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
         grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
@@ -1339,7 +1338,7 @@ def test_masked_swap(self):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
+        out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
         grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
@@ -1360,8 +1359,8 @@ def test_masked_oob_swap_slice(self):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32),
-                   jax.ShapeDtypeStruct((m,), jnp.float32)),
+        out_shape=(jax.ShapeDtypeStruct((n,), floatx),
+                   jax.ShapeDtypeStruct((m,), floatx)),
         grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
@@ -1430,7 +1429,7 @@ def test_array_atomic_add(self, axis):
       grid = m
     else:
       grid = n
-    out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32)
+    out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx)
 
     @functools.partial(
         self.pallas_call,
@@ -1464,8 +1463,8 @@ def reduce(x_ref, _, y_ref):
   def test_atomic_cas(self, init_value, cmp, new_value):
     @functools.partial(
         self.pallas_call, out_shape=(
-          jax.ShapeDtypeStruct((), jnp.int32),
-          jax.ShapeDtypeStruct((), jnp.int32)),
+          jax.ShapeDtypeStruct((), intx),
+          jax.ShapeDtypeStruct((), intx)),
         input_output_aliases={0: 0})
     def swap(_, lock_ref, out_ref):
       out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value)
@@ -1528,14 +1527,31 @@ def reduce(x_ref, y_ref):
       ("argmin", jnp.argmin),
       ]
       for axis in [0, 1, (1,), (0, 1)]
-      for dtype in ["float16", "float32", "int32", "uint32"]
+      for dtype in [
+          "float16",
+          "float32",
+          "float64",
+          "int32",
+          "int64",
+          "uint32",
+          "uint64",
+      ]
      if isinstance(axis, int) or "arg" not in op_name
   ])
   def test_array_reduce(self, op, dtype, axis):
     m, n = 32, 8
-    out_dtype = dtype
-    if op in {jnp.argmin, jnp.argmax}:
-      out_dtype = jnp.int32
+
+    if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"):
+      self.skipTest("64-bit types require x64_enabled")
+
+    # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
+    # `index_type` to be i32
+    if (
+        jax.config.x64_enabled
+        and jtu.test_device_matches(["gpu"])
+        and op in {jnp.argmin, jnp.argmax}
+    ):
+      self.skipTest("Not supported on GPU in 64-bit mode")
 
     def make_x(key):
       if jnp.issubdtype(dtype, jnp.integer):
@@ -1545,9 +1561,10 @@ def make_x(key):
     else:
       return random.normal(key, (m, n), dtype=dtype)
 
+    # deduce `out_dtype` by executing the op on a single element
+    out_dtype = op(jnp.arange(1, dtype=dtype)).dtype
     out_shape = jax.ShapeDtypeStruct(
-        op(make_x(random.key(0)), axis=axis).shape, out_dtype
-    )
+        op(make_x(random.key(0)), axis=axis).shape, out_dtype)
     if isinstance(axis, int):
       grid = tuple(a for i, a in enumerate((m, n)) if i != axis)
     else:
@@ -1555,9 +1572,11 @@ def make_x(key):
 
     @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid)
     def reduce(x_ref, y_ref):
-      x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None]))
+      x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None],
+                          jnp.arange(n, dtype=jnp.int32)[None]))
       y = op(x, axis=axis)
-      pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y)
+      pl.store(y_ref,
+               tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape), y)
 
     for i, key in enumerate(random.split(random.key(0), 20)):
       x = make_x(key)
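
Reviewer note (not part of the diff): a minimal sketch of the behavior the new `program_id` docstring documents and that `test_num_programs` exercises. The kernel, grid size, and output shape below are illustrative; pass `interpret=True` to `pl.pallas_call` on platforms without a Pallas backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(o_ref):
  # pl.program_id(0) is a shape-() int32 array identifying the current grid
  # step, so it can be used directly as an index into the output ref.
  o_ref[pl.program_id(0)] = pl.num_programs(0)


out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
    grid=(4,),
)()
# out == [4, 4, 4, 4], dtype int32.
```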
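A second sketch (also illustrative) of why the tests now pass `dtype=jnp.int32` to `jnp.arange` when building index arrays: once the `setUp` skip on `x64_enabled` is gone, these tests also run in 64-bit mode, where `jnp.arange` defaults to int64.

```python
import jax

# Assumption for illustration: enable x64 before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

print(jnp.arange(4).dtype)                    # int64 under x64
print(jnp.arange(4, dtype=jnp.int32).dtype)   # int32, as the kernels now force
```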
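Finally, a hedged sketch of the dtype deduction that replaces the hard-coded `out_dtype` in `test_array_reduce`: evaluating the reduction on a one-element array of the input dtype yields the result dtype, so argmin/argmax pick up the default index dtype (int32 or int64 depending on x64) while other reductions typically keep the input dtype. The op/dtype pairs below are examples, not the test's full parameterization.

```python
import jax.numpy as jnp

for op in (jnp.sum, jnp.max, jnp.argmin):
  for dtype in ("float16", "float32", "int32", "uint32"):
    # One-element probe: cheap to evaluate and follows the same dtype
    # promotion rules as the full reduction inside the kernel.
    out_dtype = op(jnp.arange(1, dtype=dtype)).dtype
    print(f"{op.__name__}({dtype}) -> {out_dtype}")
```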