[Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up to #23747, which enabled Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the test code needs to be modified. There are three kinds of modifications:

1. Most of the modifications simply change `jnp.int32` to `intx` and `jnp.float32` to `floatx`, following the same approach as the previous PR #23747. `intx` and `floatx` are conventions used in the Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode (see the sketch after this list).
2. For the test `test_array_reduce`, the original code derives `out_dtype` directly from `dtype` (hard-coding `jnp.int32` for `argmin`/`argmax`), which no longer works in 64-bit mode. I therefore changed the code to deduce `out_dtype` by executing the operation on a single element first, as also sketched below.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array and is computed from `pl.program_id()` and `block_size`. In 64-bit mode, this computation would produce an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I changed the computation to produce an `int32` result as well. I also updated the `pl.program_id()` docstring to document that it always returns an `int32` result.
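
For illustration, here is a minimal sketch of the first two points. The alias definitions are an assumption about how the Pallas test helpers behave, not the exact code from #23747:

```python
import jax
import jax.numpy as jnp

# Assumed alias definitions: canonicalize_dtype keeps 64-bit dtypes in
# 64-bit mode and maps them to their 32-bit counterparts otherwise.
intx = jax.dtypes.canonicalize_dtype(jnp.int64)
floatx = jax.dtypes.canonicalize_dtype(jnp.float64)

# Deducing a reduction's output dtype by executing the op on a single
# element, instead of hard-coding jnp.int32 for argmin/argmax:
out_dtype = jnp.argmax(jnp.arange(1, dtype=floatx)).dtype  # int32 or int64
```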

PiperOrigin-RevId: 676838304
ayaka14732 authored and Google-ML-Automation committed Sep 20, 2024
1 parent 6b93b35 commit e8216c4
Showing 2 changed files with 46 additions and 25 deletions.
2 changes: 2 additions & 0 deletions jax/_src/pallas/primitives.py
@@ -59,6 +59,8 @@ def program_id(axis: int) -> jax.Array:
grid coordinates `(1, 2)`,
`program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.
+The returned value is an array of shape `()` and dtype `int32`.
Args:
axis: the axis of the grid along which to count the program.
"""
69 changes: 44 additions & 25 deletions tests/pallas/ops_test.py
@@ -752,8 +752,6 @@ class OpsExtraTest(PallasBaseTest):

def setUp(self):
super().setUp()
-if jax.config.x64_enabled:
-  self.skipTest("Only works in 32-bit")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
# TODO: most tests fail on TPU in non-interpret mode
self.skipTest("On TPU the test works only in interpret mode")
@@ -800,7 +798,7 @@ def kernel(x_ref, o_ref):
def test_abs_weak_type(self):
# see https://github.com/jax-ml/jax/issues/23191
@functools.partial(
-self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
+self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.abs(x_ref[...])
@@ -1145,20 +1143,20 @@ def f(x_ref, o_ref):
def test_num_programs(self):
@functools.partial(
self.pallas_call,
-out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
+out_shape=jax.ShapeDtypeStruct((4,), intx),
grid=4,
)
def kernel(o_ref):
o_ref[pl.program_id(0)] = pl.num_programs(0)

np.testing.assert_array_equal(
-kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32)
+kernel(), jnp.array([4, 4, 4, 4], dtype=intx)
)

def test_where_broadcasting(self):
@functools.partial(
self.pallas_call,
-out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32),
+out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx),
grid=1,
)
def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
@@ -1225,11 +1223,12 @@ def dot(x_ref, y_ref, o_ref):
def test_masked_load_store(self, size, block_size):
@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)),
+out_shape=(jax.ShapeDtypeStruct((size,), floatx)),
grid=pl.cdiv(size, block_size),
)
def kernel(x_ref, o_ref):
-idx = pl.program_id(0) * block_size + jnp.arange(block_size)
+idx = pl.program_id(0) * block_size + jnp.arange(
+    block_size, dtype=jnp.int32)
mask = idx < x_ref.shape[0]
x = pl.load(x_ref, (idx,), mask=mask)
pl.store(o_ref, (idx,), x + 1.0, mask=mask)
@@ -1243,7 +1242,7 @@ def test_masked_oob_load_store_slice(self):

@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)),
+out_shape=(jax.ShapeDtypeStruct((n,), floatx)),
grid=1,
)
def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref):
@@ -1276,7 +1275,7 @@ def test_broadcasted_load_store(self):

@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)),
+out_shape=(jax.ShapeDtypeStruct((m, n), floatx)),
grid=1,
)
def load(x_ref, o_ref):
@@ -1319,7 +1318,7 @@ def test_swap(self):

@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
+out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
grid=1,
input_output_aliases={0: 0, 1: 1},
)
@@ -1339,7 +1338,7 @@ def test_masked_swap(self):

@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
+out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
grid=1,
input_output_aliases={0: 0, 1: 1},
)
@@ -1360,8 +1359,8 @@ def test_masked_oob_swap_slice(self):

@functools.partial(
self.pallas_call,
-out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32),
-           jax.ShapeDtypeStruct((m,), jnp.float32)),
+out_shape=(jax.ShapeDtypeStruct((n,), floatx),
+           jax.ShapeDtypeStruct((m,), floatx)),
grid=1,
input_output_aliases={0: 0, 1: 1},
)
@@ -1430,7 +1429,7 @@ def test_array_atomic_add(self, axis):
grid = m
else:
grid = n
-out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32)
+out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx)

@functools.partial(
self.pallas_call,
@@ -1464,8 +1463,8 @@ def reduce(x_ref, _, y_ref):
def test_atomic_cas(self, init_value, cmp, new_value):
@functools.partial(
self.pallas_call, out_shape=(
-jax.ShapeDtypeStruct((), jnp.int32),
-jax.ShapeDtypeStruct((), jnp.int32)),
+jax.ShapeDtypeStruct((), intx),
+jax.ShapeDtypeStruct((), intx)),
input_output_aliases={0: 0})
def swap(_, lock_ref, out_ref):
out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value)
@@ -1528,14 +1527,31 @@ def reduce(x_ref, y_ref):
("argmin", jnp.argmin),
]
for axis in [0, 1, (1,), (0, 1)]
for dtype in ["float16", "float32", "int32", "uint32"]
for dtype in [
"float16",
"float32",
"float64",
"int32",
"int64",
"uint32",
"uint64",
]
if isinstance(axis, int) or "arg" not in op_name
])
def test_array_reduce(self, op, dtype, axis):
m, n = 32, 8
-out_dtype = dtype
-if op in {jnp.argmin, jnp.argmax}:
-  out_dtype = jnp.int32

+if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"):
+  self.skipTest("64-bit types require x64_enabled")
+
+# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
+# `index_type` to be i32
+if (
+    jax.config.x64_enabled
+    and jtu.test_device_matches(["gpu"])
+    and op in {jnp.argmin, jnp.argmax}
+):
+  self.skipTest("Not supported on GPU in 64-bit mode")

def make_x(key):
if jnp.issubdtype(dtype, jnp.integer):
@@ -1545,19 +1561,22 @@ def make_x(key):
else:
return random.normal(key, (m, n), dtype=dtype)

+# deduce `out_dtype` by executing the op on a single element
+out_dtype = op(jnp.arange(1, dtype=dtype)).dtype
out_shape = jax.ShapeDtypeStruct(
-op(make_x(random.key(0)), axis=axis).shape, out_dtype
-)
+op(make_x(random.key(0)), axis=axis).shape, out_dtype)
if isinstance(axis, int):
grid = tuple(a for i, a in enumerate((m, n)) if i != axis)
else:
grid = tuple(a for i, a in enumerate((m, n)) if i not in axis)

@functools.partial(self.pallas_call, out_shape=out_shape, grid=grid)
def reduce(x_ref, y_ref):
-x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None]))
+x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None],
+            jnp.arange(n, dtype=jnp.int32)[None]))
y = op(x, axis=axis)
-pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y)
+pl.store(y_ref,
+         tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape), y)

for i, key in enumerate(random.split(random.key(0), 20)):
x = make_x(key)
