Skip to content

Commit

Permalink
Lookup shape and dtype directly on state.AbstractRef instead of…
Browse files Browse the repository at this point in the history
… going through `inner_aval`

This is just a cleanup. No behavior changes are expected.

PiperOrigin-RevId: 675964703
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 18, 2024
1 parent e903369 commit b7c91e9
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 16 deletions.
7 changes: 4 additions & 3 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,10 @@ def slice_scratch_ops(self):
@property
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *inputs."""
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):])
index_shapes = (
jax.ShapeDtypeStruct(ia.shape, ia.dtype)
for ia in self.index_map_avals[len(self.grid) :]
)
inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[:self.num_inputs])
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, tiling: tuple[int, ...]):
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error
block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
Expand Down
5 changes: 1 addition & 4 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,7 @@ def lower_jaxpr_to_module(
for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_structs_smem = [
jax.ShapeDtypeStruct(
[num_stages, *bm.ref_aval.inner_aval.shape],
bm.ref_aval.inner_aval.dtype,
)
jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype)
if in_smem
else None
for bm, in_smem in zip(
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,7 @@ def eval_jaxpr(*refs):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error
for a in res_ref_avals]
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]

def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
Expand Down
18 changes: 12 additions & 6 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,21 @@ def join(self, other):

@property
def shape(self):
if not isinstance(self.inner_aval, core.ShapedArray):
raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
return self.inner_aval.shape
try:
return self.inner_aval.shape # pytype: disable=attribute-error
except AttributeError:
raise AttributeError(
f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`."
) from None

@property
def dtype(self):
if not isinstance(self.inner_aval, core.UnshapedArray):
raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
return self.inner_aval.dtype
try:
return self.inner_aval.dtype # pytype: disable=attribute-error
except AttributeError:
raise AttributeError(
f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`."
) from None

@core.aval_property
def at(self):
Expand Down

0 comments on commit b7c91e9

Please sign in to comment.