Skip to content

Commit

Permalink
Merge pull request #23812 from mattjj:custom-primal-tangent-dtype-helper
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677269012
  • Loading branch information
Google-ML-Automation committed Sep 21, 2024
2 parents a2b3919 + 43cc70b commit bceceab
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 28 deletions.
21 changes: 11 additions & 10 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,16 +1596,17 @@ def physical_aval(aval: DShapedArray) -> DShapedArray: ...
def physical_aval(aval: AbstractValue) -> AbstractValue: ...

def physical_aval(aval):
aval_dtype = getattr(aval, 'dtype', None)
if aval_dtype and isinstance(aval_dtype, dtypes.ExtendedDType):
ctor = type(aval)
aval_shape = getattr(aval, 'shape', None)
assert aval_shape is not None, (ctor, aval)
elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype)
assert type(elt_aval) is ShapedArray
return ctor((*aval_shape, *elt_aval.shape), elt_aval.dtype) # pytype: disable=wrong-arg-count
else:
return aval
if (isinstance(aval, (ShapedArray, DShapedArray)) and
isinstance(aval.dtype, dtypes.ExtendedDType)):
elt_aval = physical_element_aval(aval.dtype)
if isinstance(aval, ShapedArray):
return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype)
return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype)
return aval

def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray:
duck = edtype._rules.physical_element_aval(edtype) # type: ignore
return ShapedArray(duck.shape, dtypes.dtype(duck.dtype))

def _short_dtype_name(dtype) -> str:
if isinstance(dtype, dtypes.ExtendedDType):
Expand Down
22 changes: 22 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

import abc
import builtins
import dataclasses
import functools
import types
from typing import cast, overload, Any, Literal, Union
import warnings

Expand Down Expand Up @@ -834,3 +836,23 @@ def safe_to_cast(input_dtype_or_value: Any,
# We deliberately use output_dtype rather than output_dtype_or_value here:
# this effectively treats the output dtype as always strongly-typed.
return result_type(input_dtype_or_value, output_dtype) == output_dtype

def primal_tangent_dtype(primal_dtype, tangent_dtype,
name: str | None = None) -> ExtendedDType:
name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}'
rules = types.SimpleNamespace(
physical_element_aval=
lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype),
tangent_dtype=lambda dtype: tangent_dtype,
convert_from=lambda _, other: other == primal_dtype,
convert_to=lambda other, _: other == primal_dtype)

class primal_tangent_dtype_scalar(extended): ...

@dataclasses.dataclass(frozen=True)
class PrimalTangentDType(ExtendedDType):
name = name_
_rules = rules
type = primal_tangent_dtype_scalar

return PrimalTangentDType()
12 changes: 4 additions & 8 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,8 +2158,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
# op is broadcast.
# Lower a possibly-dynamic broadcast_in_dim
if dtypes.issubdtype(aval_out.dtype, dtypes.extended): # type: ignore
elt_shape = aval_out.dtype._rules.physical_element_aval( # type: ignore
aval_out.dtype).shape # type: ignore
elt_shape = core.physical_element_aval(aval_out.dtype).shape # type: ignore
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] # type: ignore
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
physical_aval_out = core.physical_aval(aval_out)
Expand Down Expand Up @@ -2213,8 +2212,7 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
elt_shape = core.physical_element_aval(aval_out.dtype).shape
trailing_zeros = [0] * len(elt_shape)
trailing_ones = [1] * len(elt_shape)
start_indices = (*start_indices, *trailing_zeros)
Expand All @@ -2241,8 +2239,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
start_indices) -> ir.Value:
x_aval = ctx.avals_in[0]
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
elt_shape = core.physical_element_aval(aval_out.dtype).shape
index_avals = ctx.avals_in[1:]
dtype = dtypes.canonicalize_dtype(
index_avals[0].dtype if index_avals else 'int64') # type: ignore
Expand Down Expand Up @@ -2275,8 +2272,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
start_indices) -> ir.Value:
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
elt_shape = core.physical_element_aval(aval_out.dtype).shape
index_avals = ctx.avals_in[2:]
dtype = dtypes.canonicalize_dtype(
index_avals[0].dtype if index_avals else 'int64') # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ def _hlo_shard(aval, axis_env, x, in_axis):
return x
elif isinstance(aval, core.ShapedArray):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
aval = aval.dtype._rules.physical_element_aval(aval.dtype)
aval = core.physical_element_aval(aval.dtype)
dims = list(aval.shape)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [zero] * len(dims)
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3866,8 +3866,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
elt_shape = aval_out.dtype._rules.physical_element_aval(
aval_out.dtype).shape
elt_shape = core.physical_element_aval(aval_out.dtype).shape
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
permutation = [*permutation, *trailing_dims]
return [hlo.transpose(x, mlir.dense_int_array(permutation))]
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,7 +1783,7 @@ def _gather_lower_opaque(ctx, operand, indices, *,
indices_are_sorted, mode, fill_value) -> ir.Value:
aval_x, aval_indices = ctx.avals_in
aval_y, = ctx.avals_out
elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
elt_shape = core.physical_element_aval(aval_x.dtype).shape
trailing_offset_dims = [aval_y.ndim + i for i in range(len(elt_shape))]
dimension_numbers = dimension_numbers._replace(
offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
Expand Down Expand Up @@ -2436,7 +2436,7 @@ def _scatter_lower_opaque(ctx, operand, indices, updates, *,
unique_indices, indices_are_sorted, mode):
aval_x, aval_indices, aval_updates = ctx.avals_in
aval_y, = ctx.avals_out
elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
elt_shape = core.physical_element_aval(aval_x.dtype).shape
trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))]
dimension_numbers = dimension_numbers._replace(
update_window_dims=(*dimension_numbers.update_window_dims,
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def num_addressable_indices(


def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
elt_aval = core.physical_element_aval(aval.dtype)
new_op_sharding = hlo_sharding.to_proto().clone()
partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
Expand All @@ -1526,15 +1526,15 @@ def make_key_array_phys_sharding(aval, sharding):
if is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
elt_aval = core.physical_element_aval(aval.dtype)
trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, NamedSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
elt_aval = core.physical_element_aval(aval.dtype)
trailing_spec = [None] * elt_aval.ndim
return NamedSharding(
sharding.mesh,
Expand All @@ -1551,7 +1551,7 @@ def physical_sharding(


def get_logical_gspmd_sharding(aval, phys_sharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
elt_aval = core.physical_element_aval(aval.dtype)
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + elt_aval.ndim)
partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding)
Expand Down Expand Up @@ -1583,7 +1583,7 @@ def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
if is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
elt_aval = core.physical_element_aval(aval.dtype)
logical_sharding_spec = sharding_specs.ShardingSpec(
sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim],
mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
Expand Down
3 changes: 3 additions & 0 deletions jax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from jax._src.callback import (
io_callback as io_callback
)
from jax._src.dtypes import (
primal_tangent_dtype as primal_tangent_dtype,
)
from jax._src.earray import (
EArray as EArray
)
28 changes: 28 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,34 @@ def test_check_dtype_array(self):
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(dtypes.check_user_dtype_supported)(x)

@parameterized.parameters([True]) # TODO(mattjj): make jit=False work
def test_primal_tangent_dtype(self, jit):
dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16)

x = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10
).astype(jnp.int8)
g = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10
).astype(jnp.bfloat16)

@jax.custom_gradient
def f(x):
def bwd(g):
return 2 * g,
return jnp.int8(x).astype(g.dtype) * 2 + 1, bwd

def h():
result, bwd = jax.vjp(f, x.astype(dt))
bwd_result, = bwd(g)
return result, bwd_result

if jit:
h = jax.jit(h)

result, bwd_result = h()
self.assertEqual(result.dtype, jnp.bfloat16)
self.assertEqual(bwd_result.dtype, jnp.bfloat16)
self.assertAllClose(bwd_result, 2 * g)


class EArrayTest(jtu.JaxTestCase):

Expand Down

0 comments on commit bceceab

Please sign in to comment.