Skip to content

Commit 9a238f1

Browse files
sharadmv authored and Google-ML-Automation committed
[Pallas/Hijax] Let MemoryRef be parameterized by a type, instead of forcing it to be ShapedArray
This will allow us to generalize to types beyond ShapedArray (new HiTypes specifically).

PiperOrigin-RevId: 807400923
1 parent a27ef01 commit 9a238f1

File tree

4 files changed

+47
-23
lines changed

4 files changed

+47
-23
lines changed

jax/_src/pallas/core.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,24 +210,34 @@ def update(
210210
@dataclasses.dataclass(frozen=True)
211211
class MemoryRef:
212212
"""Like jax.ShapeDtypeStruct but with memory spaces."""
213-
shape: tuple[int, ...]
214-
dtype: jnp.dtype | dtypes.ExtendedDType
213+
inner_aval: jax_core.AbstractValue
215214
# TODO(b/368122763): Unify memory space types across backends
216215
memory_space: Any
217216

218217
def get_array_aval(self) -> jax_core.ShapedArray:
219-
dtype = self.dtype
218+
if not isinstance(self.inner_aval, jax_core.ShapedArray):
219+
raise ValueError(
220+
f"MemoryRef type must be a ShapedArray, got {type(self.inner_aval)}"
221+
)
222+
dtype = self.inner_aval.dtype
220223
if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)):
221224
dtype = jnp.dtype(dtype)
222225
return ShapedArrayWithMemorySpace(
223-
self.shape, dtype, memory_space=self.memory_space
226+
self.inner_aval.shape, dtype, memory_space=self.memory_space
224227
)
225228

226229
def get_ref_aval(self) -> TransformedRef | state.AbstractRef:
227230
# TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
228231
# try to apply JAX ops to it.
229-
return state.AbstractRef(
230-
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
232+
return state.AbstractRef(self.inner_aval, self.memory_space)
233+
234+
@property
235+
def dtype(self):
236+
return self.inner_aval.dtype
237+
238+
@property
239+
def shape(self):
240+
return self.inner_aval.shape
231241

232242

233243
class MemorySpace(enum.Enum):
@@ -242,10 +252,13 @@ class MemorySpace(enum.Enum):
242252
KEY = "key" # Memory space for PRNG keys.
243253
HOST = "host" # Host memory space.
244254

245-
def __call__(self, shape, dtype):
246-
if self == MemorySpace.ANY:
247-
return jax.ShapeDtypeStruct(shape, dtype)
248-
return MemoryRef(shape, dtype, self)
255+
def from_type(self, type: jax_core.AbstractValue) -> MemoryRef:
256+
return MemoryRef(type, memory_space=self)
257+
258+
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
259+
# A convenience function for constructing MemoryRef types of ShapedArrays.
260+
return self.from_type(jax_core.ShapedArray(shape, dtype))
261+
249262

250263
def __str__(self) -> str:
251264
return self.value

jax/_src/pallas/mosaic/core.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
from collections.abc import Mapping
2424

2525
import jax
26+
import jax.numpy as jnp
2627
from jax.extend import backend as jex_backend
2728
from jax._src import core as jax_core
2829
from jax._src import state
2930
from jax._src import util
3031
from jax._src.frozen_dict import FrozenDict
3132
from jax._src.pallas import core as pallas_core
32-
import jax.numpy as jnp
3333
import numpy as np
3434

3535

@@ -166,9 +166,12 @@ class MemorySpace(enum.Enum):
166166
def __str__(self) -> str:
167167
return self.value
168168

169+
def from_type(self, ty):
170+
return pallas_core.MemoryRef(ty, memory_space=self)
171+
169172
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
170-
# A convenience function for constructing MemoryRef types.
171-
return pallas_core.MemoryRef(shape, dtype, self)
173+
# A convenience function for constructing MemoryRef types of ShapedArrays.
174+
return self.from_type(jax_core.ShapedArray(shape, dtype))
172175

173176
class dma_semaphore(pallas_core.semaphore_dtype): pass
174177

@@ -189,7 +192,8 @@ def __call__(self, shape: tuple[int, ...]):
189192
dtype = pallas_core.BarrierSemaphore()
190193
else:
191194
dtype = pallas_core.Semaphore()
192-
return pallas_core.MemoryRef(shape, dtype, MemorySpace.SEMAPHORE)
195+
return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype),
196+
MemorySpace.SEMAPHORE)
193197

194198
def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace:
195199
return self(()).get_array_aval()

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __call__(
150150
collective: bool | None = None,
151151
layout: TMEMLayout | None = None,
152152
) -> pallas_core.MemoryRef:
153+
# TODO(sharadmv): Add HiType constructor support.
153154
if self == MemorySpace.TMEM:
154155
if transforms:
155156
raise ValueError("transforms are not supported for TMEM")
@@ -174,8 +175,9 @@ def __call__(
174175
if packed is not None or collective is not None or layout is not None:
175176
raise ValueError("packed, collective and layout arguments are only supported for TMEM.")
176177
mgpu_layout = None
177-
return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms,
178-
layout=mgpu_layout, collective=collective)
178+
return GPUMemoryRef(jax_core.ShapedArray(shape, dtype), memory_space=self,
179+
transforms=transforms, layout=mgpu_layout,
180+
collective=collective)
179181

180182

181183
class SemaphoreType(enum.Enum):
@@ -188,7 +190,8 @@ def __call__(self, shape: tuple[int, ...]):
188190
dtype = pallas_core.BarrierSemaphore()
189191
else:
190192
dtype = pallas_core.Semaphore()
191-
return pallas_core.MemoryRef(shape, dtype, MemorySpace.GMEM)
193+
return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype),
194+
MemorySpace.GMEM)
192195

193196
def get_array_aval(self) -> jax_core.ShapedArray:
194197
return self(()).get_array_aval()
@@ -484,8 +487,9 @@ def __init__(self, *refs: _GPUMemoryRefTree):
484487
object.__setattr__(self, "refs", refs)
485488
num_bytes = max(map(_ref_group_size, self.refs))
486489
super().__init__(
487-
shape=(num_bytes,),
488-
dtype=jnp.int8,
490+
inner_aval=jax_core.ShapedArray(
491+
(num_bytes,), jnp.int8
492+
),
489493
memory_space=SMEM,
490494
transforms=(),
491495
)
@@ -498,8 +502,10 @@ def __init__(self, *refs: _GPUMemoryRefTree):
498502
"Some aliased TMEM references are collective and some are not."
499503
)
500504
super().__init__(
501-
shape=(128, max_cols,),
502-
dtype=jnp.int32,
505+
inner_aval=jax_core.ShapedArray(
506+
shape=(128, max_cols,),
507+
dtype=jnp.int32,
508+
),
503509
memory_space=TMEM,
504510
transforms=(),
505511
layout=tcgen05.tmem_default_layout(packing=1),

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,11 +748,12 @@ def ref_for_aval(aval: ShapedAbstractValue):
748748
return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype)
749749
elif isinstance(aval, gpu_core.AbstractTMEMRef):
750750
return gpu_core.GPUMemoryRef(
751-
aval.shape, aval.dtype, gpu_core.TMEM,
751+
jax_core.ShapedArray(aval.shape, aval.dtype), gpu_core.TMEM,
752752
transforms=(), layout=aval.layout, collective=aval.collective,
753753
)
754754
elif isinstance(aval, state_types.AbstractRef):
755-
return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space)
755+
return pallas_core.MemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
756+
aval.memory_space)
756757
else:
757758
return gpu_core.SMEM(aval.shape, aval.dtype)
758759

0 commit comments

Comments (0)