Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] Allow user to pass 64-bit indices to pl.{load,store,...}. #24782

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
Expand All @@ -39,8 +40,8 @@
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import types as state_types
from jax._src.state import primitives as sp
from jax._src.state import types as state_types
from jax.interpreters import mlir
import jax.numpy as jnp

Expand All @@ -51,6 +52,16 @@
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


def _enable_x64(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
with config.enable_x64():
return f(*args, **kwargs)

return wrapper


program_id_p = jax_core.Primitive("program_id")
batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule

Expand Down Expand Up @@ -182,6 +193,7 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
return _swap_abstract_eval(*avals_flat, args_tree=args_tree)


@_enable_x64
def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
atomic_type: AtomicOpType):
x_ref, transforms = sp.get_ref_and_transforms(
Expand Down Expand Up @@ -629,6 +641,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
return (x_new,) + (None,) * (len(in_avals) - 1), out


@_enable_x64
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
eviction_policy=None, volatile=False) -> jax.Array:
"""Returns an array loaded from the given index.
Expand Down Expand Up @@ -659,6 +672,8 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
is_volatile=volatile,
)


@_enable_x64
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
_function_name="swap") -> jax.Array:
"""Swaps the value at the given index and returns the old value.
Expand Down
Loading