Skip to content

Commit

Permalink
add jax.experimental.primal_tangent_dtype helper
Browse files Browse the repository at this point in the history
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
  • Loading branch information
mattjj committed Sep 20, 2024
1 parent 0d96f39 commit 5fd1607
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
22 changes: 22 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@

import abc
import builtins
import dataclasses
import functools
import types
from typing import cast, overload, Any, Literal, Union
import warnings

import ml_dtypes
import numpy as np

from jax._src import config
from jax._src import core
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC

Expand Down Expand Up @@ -834,3 +837,22 @@ def safe_to_cast(input_dtype_or_value: Any,
# We deliberately use output_dtype rather than output_dtype_or_value here:
# this effectively treats the output dtype as always strongly-typed.
return result_type(input_dtype_or_value, output_dtype) == output_dtype

def primal_tangent_dtype(primal_dtype, tangent_dtype,
name: str | None = None) -> ExtendedDType:
name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}'
rules = types.SimpleNamespace(
physical_element_aval=lambda dtype: core.ShapedArray((), primal_dtype),
tangent_dtype=lambda dtype: tangent_dtype,
convert_from=lambda _, other: other == primal_dtype,
convert_to=lambda other, _: other == primal_dtype)

class primal_tangent_dtype_scalar(extended): ...

@dataclasses.dataclass(frozen=True)
class PrimalTangentDType(ExtendedDType):
name = name_
_rules = rules
type = primal_tangent_dtype_scalar

return PrimalTangentDType()
3 changes: 3 additions & 0 deletions jax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from jax._src.callback import (
io_callback as io_callback
)
from jax._src.dtypes import (
primal_tangent_dtype as primal_tangent_dtype,
)
from jax._src.earray import (
EArray as EArray
)
28 changes: 28 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,34 @@ def test_check_dtype_array(self):
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(dtypes.check_user_dtype_supported)(x)

@parameterized.parameters([True]) # TODO(mattjj): make jit=False work
def test_primal_tangent_dtype(self, jit):
dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16)

x = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10
).astype(jnp.int8)
g = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10
).astype(jnp.bfloat16)

@jax.custom_gradient
def f(x):
def bwd(g):
return 2 * g,
return jnp.int8(x).astype(g.dtype) * 2 + 1, bwd

def h():
result, bwd = jax.vjp(f, x.astype(dt))
bwd_result, = bwd(g)
return result, bwd_result

if jit:
h = jax.jit(h)

result, bwd_result = h()
self.assertEqual(result.dtype, jnp.bfloat16)
self.assertEqual(bwd_result.dtype, jnp.bfloat16)
self.assertAllClose(bwd_result, 2 * g)


class EArrayTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 5fd1607

Please sign in to comment.