diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index 5b3d690d5e96..0e5a6a16dcfc 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -38,10 +38,7 @@ JAX 201 :maxdepth: 1 parallelism - advanced-autodiff - gradient-checkpointing advanced-debugging - external-callbacks profiling-and-performance JAX 301 @@ -50,6 +47,4 @@ JAX 301 .. toctree:: :maxdepth: 1 - jax-primitives - jaxpr advanced-compilation diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/advanced-autodiff.md similarity index 100% rename from docs/_tutorials/advanced-autodiff.md rename to docs/advanced-autodiff.md diff --git a/docs/conf.py b/docs/conf.py index e77916e265ff..d57420dec881 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -360,4 +360,6 @@ def linkcode_resolve(domain, info): 'jax-101/07-state.md': 'stateful-computations.md', 'jax-101/08-pjit.rst': 'sharded-computation.md', 'jax-101/index.rst': 'tutorials.rst', + 'notebooks/external_callbacks.md': 'external-callbacks.md', + 'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md', } diff --git a/docs/extensions.rst b/docs/extensions.rst index 92963b71f20f..856153cd8723 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -10,8 +10,6 @@ that use or interface with JAX. :caption: Extensible JAX internals :maxdepth: 1 - notebooks/How_JAX_primitives_work - jaxpr notebooks/Writing_custom_interpreters_in_Jax Custom_Operation_for_GPUs jax.extend diff --git a/docs/_tutorials/external-callbacks.md b/docs/external-callbacks.md similarity index 100% rename from docs/_tutorials/external-callbacks.md rename to docs/external-callbacks.md diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 04ae80cbf5b1..a8cd5219d4b5 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -364,7 +364,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + "We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" ] }, { diff --git a/docs/ffi.md b/docs/ffi.md index 03acf876be08..cc3863ed99b2 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) ``` -We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: +We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: ```{code-cell} ipython3 jax.make_jaxpr(jax.vmap(rms_norm))(x) diff --git a/docs/glossary.rst b/docs/glossary.rst index 4bb9fa15667e..286b07e21a66 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -30,7 +30,7 @@ Glossary of terms jaxpr Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. - See :ref:`understanding-jaxprs` for more discussion and examples. + See :ref:`jax-internals-jaxpr` for more discussion and examples. 
JIT Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/gradient-checkpointing.md similarity index 100% rename from docs/_tutorials/gradient-checkpointing.md rename to docs/gradient-checkpointing.md diff --git a/docs/index.rst b/docs/index.rst index 92422edc069f..2dd856ab88ef 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -121,8 +121,6 @@ maintains an up-to-date list. installation quickstart - notebooks/Common_Gotchas_in_JAX - faq .. toctree:: :hidden: @@ -130,11 +128,14 @@ maintains an up-to-date list. tutorials + notebooks/Common_Gotchas_in_JAX + + faq .. toctree:: :hidden: :maxdepth: 2 - :caption: Resources + :caption: More guides/resources user_guides advanced_guide diff --git a/docs/_tutorials/jax-primitives.md b/docs/jax-primitives.md similarity index 100% rename from docs/_tutorials/jax-primitives.md rename to docs/jax-primitives.md diff --git a/docs/_tutorials/jaxpr.md b/docs/jaxpr.md similarity index 100% rename from docs/_tutorials/jaxpr.md rename to docs/jaxpr.md diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst deleted file mode 100644 index d7b50dcb301e..000000000000 --- a/docs/jaxpr.rst +++ /dev/null @@ -1,472 +0,0 @@ -.. _understanding-jaxprs: - -Understanding Jaxprs -==================== - -Updated: May 3, 2020 (for commit f1a46fe). - -Conceptually, one can think of JAX transformations as first trace-specializing -the Python function to be transformed into a small and well-behaved -intermediate form that is then interpreted with transformation-specific -interpretation rules. One of the reasons JAX can pack so much power into such a -small software package is that it starts with a familiar and flexible -programming interface (Python with NumPy) and it uses the actual Python -interpreter to do most of the heavy lifting to distill the essence of the -computation into a simple statically-typed expression language with limited -higher-order features. That language is the jaxpr language. - -Not all Python programs can be processed this way, but it turns out that many -scientific computing and machine learning programs can. - -Before we proceed, it is important to point out that not all JAX -transformations literally materialize a jaxpr as described above; some, e.g., -differentiation or batching, will apply transformations incrementally during -tracing. Nevertheless, if one wants to understand how JAX works internally, or -to make use of the result of JAX tracing, it is useful to understand jaxprs. - -A jaxpr instance represents a function with one or more typed parameters (input -variables) and one or more typed results. The results depend only on the input -variables; there are no free variables captured from enclosing scopes. The -inputs and outputs have types, which in JAX are represented as abstract values. -There are two related representations in the code for jaxprs, -:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A -:py:class:`jax.core.ClosedJaxpr` represents a partially-applied -:py:class:`jax.core.Jaxpr`, and is what you obtain when you use -:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields: - - * ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual - computation content of the function (described below). - * ``consts`` is a list of constants. 
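
*Editorial aside — not part of the deleted file above.* The two `ClosedJaxpr` fields just listed (`jaxpr` and `consts`) can be inspected directly on the object returned by `jax.make_jaxpr`. A minimal sketch, with a made-up function `f` and a made-up closed-over constant (non-scalar constants are hoisted to constvars, as the "Constant vars" section below explains):

```python
import numpy as np
import jax
import jax.numpy as jnp

const = np.arange(3.0)   # a non-scalar constant; it will be hoisted to a constvar

def f(x):
    return jnp.sum(x + const)

closed = jax.make_jaxpr(f)(np.ones(3))  # a ClosedJaxpr
print(closed.jaxpr)    # the underlying Jaxpr, with one constvar and one invar
print(closed.consts)   # the hoisted constant value(s) corresponding to the constvar
```
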
- -The most interesting part of the ClosedJaxpr is the actual execution content, -represented as a :py:class:`jax.core.Jaxpr` as printed using the following -grammar:: - - Jaxpr ::= { lambda Var* ; Var+. let - Eqn* - in [Expr+] } - -where: - * The parameters of the jaxpr are shown as two lists of variables separated by - ``;``. The first set of variables are the ones that have been introduced - to stand for constants that have been hoisted out. These are called the - ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts`` - field holds corresponding values. The second list of variables, called - ``invars``, correspond to the inputs of the traced Python function. - * ``Eqn*`` is a list of equations, defining intermediate variables referring to - intermediate expressions. Each equation defines one or more variables as the - result of applying a primitive on some atomic expressions. Each equation uses only - input variables and intermediate variables defined by previous equations. - * ``Expr+``: is a list of output atomic expressions (literals or variables) - for the jaxpr. - -Equations are printed as follows:: - - Eqn ::= Var+ = Primitive [ Param* ] Expr+ - -where: - * ``Var+`` are one or more intermediate variables to be defined as the output - of a primitive invocation (some primitives can return multiple values). - * ``Expr+`` are one or more atomic expressions, each either a variable or a - literal constant. A special variable ``unitvar`` or literal ``unit``, - printed as ``*``, represents a value that is not needed - in the rest of the computation and has been elided. That is, units are just - placeholders. - * ``Param*`` are zero or more named parameters to the primitive, printed in - square brackets. Each parameter is shown as ``Name = Value``. - - -Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments):: - - Primitive := add | sub | sin | mul | ... - - -The jaxpr primitives are documented in the :py:mod:`jax.lax` module. - -For example, here is the jaxpr produced for the function ``func1`` below - ->>> from jax import make_jaxpr ->>> import jax.numpy as jnp ->>> def func1(first, second): -... temp = first + jnp.sin(second) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - -Here there are no constvars, ``a`` and ``b`` are the input variables -and they correspond respectively to -``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept -inline. -The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the -operand ``e``. - -Note that even though execution of a program that calls into JAX builds a jaxpr, -Python-level control-flow and Python-level functions execute normally. -This means that just because a Python program contains functions and control-flow, -the resulting jaxpr does not have to contain control-flow or higher-order features. - -For example, when tracing the function ``func3`` JAX will inline the call to -``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same -jaxpr as before - ->>> def func2(inner, first, second): -... temp = first + inner(second) * 3. -... return jnp.sum(temp) -... ->>> def inner(second): -... if second.shape[0] > 4: -... return jnp.sin(second) -... else: -... assert False -... ->>> def func3(first, second): -... 
return func2(inner, first, second) -... ->>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - -Handling PyTrees ----------------- - -In jaxpr there are no tuple types; instead primitives take multiple inputs -and produce multiple outputs. When processing a function that has structured -inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists -of inputs and outputs. For more details, please see the documentation for -PyTrees (:ref:`pytrees`). - -For example, the following code produces an identical jaxpr to what we saw -before (with two input vars, one for each element of the input tuple) - - ->>> def func4(arg): # Arg is a pair -... temp = arg[0] + jnp.sin(arg[1]) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - - -Constant vars -------------- - -Some values in jaxprs are constants, in that their value does not depend on the -jaxpr's arguments. When these values are scalars they are represented directly -in the jaxpr equations; non-scalar array constants are instead hoisted out to -the top-level jaxpr, where they correspond to constant variables ("constvars"). -These constvars differ from the other jaxpr parameters ("invars") only as a -bookkeeping convention. - - -Higher-order primitives ------------------------ - -jaxpr includes several higher-order primitives. They are more complicated because -they include sub-jaxprs. - -Conditionals -^^^^^^^^^^^^ - -JAX traces through normal Python conditionals. To capture a -conditional expression for dynamic execution, one must use the -:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors, -which have the signatures:: - - lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B - - lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B - -Both of these will bind a primitive called ``cond`` internally. The -``cond`` primitive in jaxprs reflects the more general signature of -:py:func:`lax.switch`: it takes an integer denoting the index of the branch -to execute (clamped into valid indexing range). - -For example: - ->>> from jax import lax ->>> ->>> def one_of_three(index, arg): -... return lax.switch(index, [lambda x: x + 1., -... lambda x: x - 2., -... lambda x: x + 3.], -... arg) -... ->>> print(make_jaxpr(one_of_three)(1, 5.)) -{ lambda ; a:i32[] b:f32[]. let - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a - d:i32[] = clamp 0 c 2 - e:f32[] = cond[ - branches=( - { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) } - { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) } - { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) } - ) - ] d b - in (e,) } - -The `branches` parameter to the cond primitive corresponds to the branch -functionals. In this example, those functionals each take one input variable, -corresponding to ``x``. - -The above instance of the cond primitive takes two operands. The first -one (``d``) is the branch index, then ``b`` is the operand (``arg``) to -be passed to whichever jaxpr in ``branches`` is selected by the branch -index. - -Another example, using :py:func:`lax.cond`: - ->>> from jax import lax ->>> ->>> def func7(arg): -... return lax.cond(arg >= 0., -... lambda xtrue: xtrue + 3., -... 
lambda xfalse: xfalse - 3., -... arg) -... ->>> print(make_jaxpr(func7)(5.)) -{ lambda ; a:f32[]. let - b:bool[] = ge a 0.0 - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b - d:f32[] = cond[ - branches=( - { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) } - { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) } - ) - ] c a - in (d,) } - -In this case, the boolean predicate is converted to an integer index -(0 or 1), and ``branches`` are jaxprs that correspond to the false and -true branch functionals, in that order. Again, each functional takes -one input variable, corresponding to ``xfalse`` and ``xtrue`` -respectively. - -The following example shows a more complicated situation when the input -to the branch functionals is a tuple, and the `false` branch functional -contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` - ->>> def func8(arg1, arg2): # arg2 is a pair -... return lax.cond(arg1 >= 0., -... lambda xtrue: xtrue[0], -... lambda xfalse: jnp.array([1]) + xfalse[1], -... arg2) -... ->>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) -{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let - e:bool[] = ge b 0.0 - f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e - g:f32[1] = cond[ - branches=( - { lambda ; h:i32[1] i:f32[1] j:f32[]. let - k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h - l:f32[1] = add k j - in (l,) } - { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) } - ) - ] f a c d - in (g,) } - - - -While -^^^^^ - -Just like for conditionals, Python loops are inlined during tracing. -If you want to capture a loop for dynamic execution, you must use one of several -special operations, :py:func:`jax.lax.while_loop` (a primitive) -and :py:func:`jax.lax.fori_loop` -(a helper that generates a while_loop primitive):: - - lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C - lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C - - -In the above signature, “C” stands for the type of the loop “carry” value. -For example, here is an example fori loop - ->>> import numpy as np ->>> ->>> def func10(arg, n): -... ones = jnp.ones(arg.shape) # A constant -... return lax.fori_loop(0, n, -... lambda i, carry: carry + ones * 3. + arg, -... arg + ones) -... ->>> print(make_jaxpr(func10)(np.ones(16), 5)) -{ lambda ; a:f32[16] b:i32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[16] = add a c - _:i32[] _:i32[] e:f32[16] = while[ - body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let - k:i32[] = add h 1 - l:f32[16] = mul f 3.0 - m:f32[16] = add j l - n:f32[16] = add m g - in (k, i, n) } - body_nconsts=2 - cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let - r:bool[] = lt o p - in (r,) } - cond_nconsts=0 - ] c a 0 b d - in (e,) } - -The while primitive takes 5 arguments: ``c a 0 b d``, as follows: - - * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0) - * 2 constants for ``body_jaxpr`` (``c``, and ``a``) - * 3 parameters for the initial value of carry - -Scan -^^^^ - -JAX supports a special form of loop over the elements of an array (with -statically known shape). The fact that there are a fixed number of iterations -makes this form of looping easily reverse-differentiable. 
Such loops are -constructed with the :py:func:`jax.lax.scan` function:: - - lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) - -This is written in terms of a `Haskell Type Signature`_: -``C`` is the type of the scan carry, ``A`` is the element type of the -input array(s), and ``B`` is the element type of the output array(s). - -For the example consider the function ``func11`` below - ->>> def func11(arr, extra): -... ones = jnp.ones(arr.shape) # A constant -... def body(carry, aelems): -... # carry: running dot-product of the two arrays -... # aelems: a pair with corresponding elements from the two arrays -... ae1, ae2 = aelems -... return (carry + ae1 * ae2 + extra, carry) -... return lax.scan(body, 0., (arr, ones)) -... ->>> print(make_jaxpr(func11)(np.ones(16), 5.)) -{ lambda ; a:f32[16] b:f32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[] e:f32[16] = scan[ - _split_transpose=False - jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let - j:f32[] = mul h i - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g - l:f32[] = add k j - m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f - n:f32[] = add l m - in (n, g) } - length=16 - linear=(False, False, False, False) - num_carry=1 - num_consts=1 - reverse=False - unroll=1 - ] b 0.0 a c - in (d, e) } - -The ``linear`` parameter describes for each of the input variables whether they -are guaranteed to be used linearly in the body. Once the scan goes through -linearization, more arguments will be linear. - -The scan primitive takes 4 arguments: ``b 0.0 a c``, of which: - - * one is the free variable for the body - * one is the initial value of the carry - * The next 2 are the arrays over which the scan operates. - -XLA_call -^^^^^^^^ - -The call primitive arises from JIT compilation, and it encapsulates -a sub-jaxpr along with parameters that specify the backend and the device on -which the computation should run. For example - ->>> from jax import jit ->>> ->>> def func12(arg): -... @jit -... def inner(x): -... return x + arg * jnp.ones(1) # Include a constant in the inner function -... return arg + inner(arg - 2.) -... ->>> print(make_jaxpr(func12)(1.)) # doctest:+ELLIPSIS -{ lambda ; a:f32[]. let - b:f32[] = sub a 2.0 - c:f32[1] = pjit[ - name=inner - jaxpr={ lambda ; d:f32[] e:f32[]. let - f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - h:f32[1] = mul g f - i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e - j:f32[1] = add i h - in (j,) } - ] a b - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a - l:f32[1] = add k c - in (l,) } - - -XLA_pmap -^^^^^^^^ - -If you use the :py:func:`jax.pmap` transformation, the function to be mapped is -captured using the ``xla_pmap`` primitive. Consider this example - ->>> from jax import pmap ->>> ->>> def func13(arr, extra): -... def inner(x): -... # use a free variable "extra" and a constant jnp.ones(1) -... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows') -... return pmap(inner, axis_name='rows')(arr) -... ->>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) -{ lambda ; a:f32[1,3] b:f32[]. let - c:f32[1,3] = xla_pmap[ - axis_name=rows - axis_size=1 - backend=None - call_jaxpr={ lambda ; d:f32[] e:f32[3]. 
let - f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - g:f32[3] = add e f - h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - i:f32[3] = add g h - j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e - k:f32[3] = div i j - in (k,) } - devices=None - donated_invars=(False, False) - global_axis_size=1 - in_axes=(None, 0) - is_explicit_global_axis_size=False - name=inner - out_axes=(0,) - ] b a - in (c,) } - -The ``xla_pmap`` primitive specifies the name of the axis (parameter -``axis_name``) and the body of the function to be mapped as the ``call_jaxpr`` -parameter. The value of this parameter is a Jaxpr with 2 input variables. - -The parameter ``in_axes`` specifies which of the input variables should be -mapped and which should be broadcast. In our example, the value of ``extra`` -is broadcast and the value of ``arr`` is mapped. - -.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index bc6cb3c04cf8..59c7bbd8fb90 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -51,7 +51,7 @@ def log2(x): print(jax.make_jaxpr(log2)(3.0)) ``` -The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output. +The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output. Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb deleted file mode 100644 index e9924e18d023..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ /dev/null @@ -1,1532 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "vfxqky4PCUnh" - }, - "source": [ - "# How JAX primitives work\n", - "\n", - "\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", - "\n", - "*necula@google.com*, October 2019.\n", - "\n", - "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n", - "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n", - "which means that as the Python function executes\n", - "the only operations it applies to the data are either inspections of data\n", - "attributes such as shape or type, or special operations called JAX primitives.\n", - "In particular, a JAX-traceable function is sometimes invoked by JAX with\n", - "abstract arguments. 
An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n", - "which captures the type and the shape of values, but not the concrete data values.\n", - "JAX primitives know how to operate on both concrete data\n", - "values and on the JAX abstract values.\n", - "\n", - "\n", - "The JAX-transformed functions must themselves be JAX-traceable functions,\n", - "to ensure that these transformations\n", - "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n", - "\n", - "There are pre-defined JAX primitives corresponding to most XLA operations,\n", - "e.g., add, matmul, sin, cos, indexing.\n", - "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n", - "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n", - "Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.\n", - "\n", - "The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,\n", - "one can define a new primitive that encapsulates the behavior of the function.\n", - "\n", - "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n", - "\n", - "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n", - "as \"multiply_add(x, y, z) = x * y + z\".\n", - "This function operates on 3 identically-shaped tensors of floating point\n", - "values and performs the operations pointwise." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HIJYIHNTD1yI" - }, - "source": [ - "## Using existing primitives\n", - "\n", - "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n", - "functions that are themselves written using JAX primitives, e.g., those\n", - "defined in the `jax.lax` module:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "tbOF0LB0EMne", - "outputId": "3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "square_add_lax = 14.0\n", - "grad(square_add_lax) = 4.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n", - " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" - ] - } - ], - "source": [ - "from jax import lax\n", - "from jax._src import api\n", - "\n", - "def multiply_add_lax(x, y, z):\n", - " \"\"\"Implementation of multiply-add using the jax.lax primitives.\"\"\"\n", - " return lax.add(lax.mul(x, y), z)\n", - "\n", - "\n", - "def square_add_lax(a, b):\n", - " \"\"\"A square-add function using the newly defined multiply-add.\"\"\"\n", - " return multiply_add_lax(a, a, b)\n", - "\n", - "print(\"square_add_lax = \", square_add_lax(2., 10.))\n", - "# Differentiate w.r.t. the first argument\n", - "print(\"grad(square_add_lax) = \", api.grad(square_add_lax, argnums=0)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Cgv60Wm3E_D5" - }, - "source": [ - "In order to understand how JAX is internally using the primitives,\n", - "we add some helpers for tracing function calls." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "mQRQGEGiE53K" - }, - "outputs": [], - "source": [ - "#@title Helper functions (execute this cell)\n", - "import functools\n", - "import traceback\n", - "\n", - "_indentation = 0\n", - "def _trace(msg=None):\n", - " \"\"\"Print a message at current indentation.\"\"\"\n", - " if msg is not None:\n", - " print(\" \" * _indentation + msg)\n", - "\n", - "def _trace_indent(msg=None):\n", - " \"\"\"Print a message and then indent the rest.\"\"\"\n", - " global _indentation\n", - " _trace(msg)\n", - " _indentation = 1 + _indentation\n", - "\n", - "def _trace_unindent(msg=None):\n", - " \"\"\"Unindent then print a message.\"\"\"\n", - " global _indentation\n", - " _indentation = _indentation - 1\n", - " _trace(msg)\n", - "\n", - "def trace(name):\n", - " \"\"\"A decorator for functions to trace arguments and results.\"\"\"\n", - "\n", - " def trace_func(func): # pylint: disable=missing-docstring\n", - " def pp(v):\n", - " \"\"\"Print certain values more succinctly\"\"\"\n", - " vtype = str(type(v))\n", - " if \"jax._src.xla_bridge._JaxComputationBuilder\" in vtype:\n", - " return \"\"\n", - " elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n", - " return \"\".format(id(v))\n", - " elif (\"partial_eval.JaxprTracer\" in vtype or\n", - " \"batching.BatchTracer\" in vtype or\n", - " \"ad.JVPTracer\" in vtype):\n", - " return \"Traced<{}>\".format(v.aval)\n", - " elif isinstance(v, tuple):\n", - " return \"({})\".format(pp_values(v))\n", - " else:\n", - " return str(v)\n", - " def pp_values(args):\n", - " return \", \".join([pp(arg) for arg in args])\n", - "\n", - " @functools.wraps(func)\n", - " def func_wrapper(*args):\n", - " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n", - " res = func(*args)\n", - " _trace_unindent(\"|<- {} = {}\".format(name, pp(res)))\n", - " return res\n", - "\n", - " return func_wrapper\n", - "\n", - " return trace_func\n", - "\n", - "class expectNotImplementedError(object):\n", - " \"\"\"Context manager to check for NotImplementedError.\"\"\"\n", - " def __enter__(self): pass\n", - " def __exit__(self, type, value, tb):\n", - " global _indentation\n", - " _indentation = 0\n", - " if type is NotImplementedError:\n", - " print(\"\\nFound expected exception:\")\n", - " traceback.print_exc(limit=3)\n", - " return True\n", - " elif type is None: # No exception\n", - " assert False, \"Expected NotImplementedError\"\n", - " else:\n", - " return False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qf4eLrLCFYDl" - }, - "source": [ - "Instead of using `jax.lax` primitives directly, we can use other functions\n", - "that are already written in terms of those primitives, such as those in `jax.numpy`:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "QhKorz6cFRJb", - "outputId": "aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Normal evaluation:\n", - "call square_add_numpy(2.0, 10.0)\n", - " call multiply_add_numpy(2.0, 2.0, 10.0)\n", - " |<- multiply_add_numpy = 14.0\n", - "|<- square_add_numpy = 14.0\n", - "square_add_numpy = 14.0\n", - "\n", - "Gradient evaluation:\n", - "call square_add_numpy(Traced, 10.0)\n", - " call multiply_add_numpy(Traced, Traced, 10.0)\n", - " |<- multiply_add_numpy = Traced\n", - "|<- square_add_numpy = Traced\n", - "grad(square_add_numpy) = 4.0\n" - ] - } - ], - "source": [ - "import jax.numpy as 
jnp\n", - "import numpy as np\n", - "\n", - "@trace(\"multiply_add_numpy\")\n", - "def multiply_add_numpy(x, y, z):\n", - " return jnp.add(jnp.multiply(x, y), z)\n", - "\n", - "@trace(\"square_add_numpy\")\n", - "def square_add_numpy(a, b):\n", - " return multiply_add_numpy(a, a, b)\n", - "\n", - "print(\"\\nNormal evaluation:\")\n", - "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n", - "print(\"\\nGradient evaluation:\")\n", - "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sg-D8EdeFn4a" - }, - "source": [ - "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n", - "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n", - "below in this colab).\n", - "It is important to remember that a JAX-traceable function must be able to\n", - "operate not only on concrete arguments but also on special abstract arguments\n", - "that JAX may use to abstract the function execution.\n", - "\n", - "The JAX traceability property is satisfied as long as the function is written\n", - "in terms of JAX primitives." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WxrQO7-XGLcg" - }, - "source": [ - "## Defining new JAX primitives\n", - "\n", - "The right way to add support for multiply-add is in terms of existing\n", - "JAX primitives, as shown above. However, in order to demonstrate how JAX\n", - "primitives work let us pretend that we want to add a new primitive to\n", - "JAX for the multiply-add functionality." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cPqAH1XOGTN4" - }, - "outputs": [], - "source": [ - "from jax import core\n", - "multiply_add_p = core.Primitive(\"multiply_add\") # Create the primitive\n", - "\n", - "@trace(\"multiply_add_prim\")\n", - "def multiply_add_prim(x, y, z):\n", - " \"\"\"The JAX-traceable way to use the JAX primitive.\n", - "\n", - " Note that the traced arguments must be passed as positional arguments\n", - " to `bind`.\n", - " \"\"\"\n", - " return multiply_add_p.bind(x, y, z)\n", - "\n", - "@trace(\"square_add_prim\")\n", - "def square_add_prim(a, b):\n", - " \"\"\"A square-add function implemented using the new JAX-primitive.\"\"\"\n", - " return multiply_add_prim(a, a, b)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LMzs5PAKGr-4" - }, - "source": [ - "If we try to call the newly defined functions we get an error, because\n", - "we have not yet told JAX anything about the semantics of the new primitive." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "_X3PAYxhGpWd", - "outputId": "90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " square_add_prim(2., 10.)\n", - " File \"\", line 47, in func_wrapper\n", - " res = func(*args)\n", - " File \"\", line 16, in square_add_prim\n", - " return multiply_add_prim(a, a, b)\n", - "NotImplementedError: Evaluation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " square_add_prim(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "elha0FdgHSEF" - }, - "source": [ - "### Primal evaluation rules" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "FT34FFAGHARU", - "outputId": "4c54f1c2-8a50-4788-90e1-06aee412c43b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "@trace(\"multiply_add_impl\")\n", - "def multiply_add_impl(x, y, z):\n", - " \"\"\"Concrete implementation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable.\n", - " Args:\n", - " x, y, z: the concrete arguments of the primitive. Will only be called with\n", - " concrete values.\n", - " Returns:\n", - " the concrete result of the primitive.\n", - " \"\"\"\n", - " # Note that we can use the original numpy, which is not JAX traceable\n", - " return np.add(np.multiply(x, y), z)\n", - "\n", - "# Now we register the primal implementation with JAX\n", - "multiply_add_p.def_impl(multiply_add_impl)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "G5bstKaeNAVV", - "outputId": "deb94d5b-dfea-4e6f-9ec2-70b416c996c5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - "|<- square_add_prim = 14.0\n" - ] - } - ], - "source": [ - "assert square_add_prim(2., 10.) == 14." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "upBf-uAuHhPJ" - }, - "source": [ - "### JIT\n", - "\n", - "If we now try to use `jit` we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "QG-LULjiHk4b", - "outputId": "d4ef4406-8dae-4c96-97ca-b662340474ee" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: Abstract evaluation for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rHS1bAGHH44E" - }, - "source": [ - "#### Abstract evaluation rules\n", - "In order to JIT the function, and for other transformations as well,\n", - "JAX first evaluates it abstractly using only the\n", - "shape and type of the arguments. This abstract evaluation serves multiple\n", - "purposes:\n", - "\n", - " * Gets the sequence of JAX primitives that are used in the computation. This\n", - " sequence will be compiled.\n", - " * Computes the shape and type of all vectors and operations used in the computation.\n", - "\n", - "\n", - "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n", - "In the latter case, JAX uses the actual concrete value wrapped as an abstract value." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "ctQmEeckIbdo", - "outputId": "e751d0cc-460e-4ffd-df2e-fdabf9cffdc2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import core\n", - "@trace(\"multiply_add_abstract_eval\")\n", - "def multiply_add_abstract_eval(xs, ys, zs):\n", - " \"\"\"Abstract evaluation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable. 
It will be invoked with\n", - " abstractions of the actual arguments.\n", - " Args:\n", - " xs, ys, zs: abstractions of the arguments.\n", - " Result:\n", - " a ShapedArray for the result of the primitive.\n", - " \"\"\"\n", - " assert xs.shape == ys.shape\n", - " assert xs.shape == zs.shape\n", - " return core.ShapedArray(xs.shape, xs.dtype)\n", - "\n", - "# Now we register the abstract evaluation with JAX\n", - "multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RPN88X6YI43A" - }, - "source": [ - "If we re-attempt to JIT, we see how the abstract evaluation proceeds, but\n", - "we get another error, about missing the actual XLA compilation rule:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "eOcNR92SI2h-", - "outputId": "356ef229-3703-4696-cc3d-7c05de405fb0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: XLA translation rule for primitive 'multiply_add' not found\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9IOV1R-fJMHp" - }, - "source": [ - "#### XLA Compilation rules\n", - "\n", - "JAX compilation works by compiling each primitive into a graph of XLA operations.\n", - "\n", - "This is the biggest hurdle to adding new functionality to JAX, because the\n", - "set of XLA operations is limited, and JAX already has pre-defined primitives\n", - "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FYQWSSjKJaWP" - }, - "outputs": [], - "source": [ - "from jax._src.lib.mlir.dialects import hlo\n", - "@trace(\"multiply_add_lowering\")\n", - "def multiply_add_lowering(ctx, xc, yc, zc):\n", - " \"\"\"The compilation to XLA of the primitive.\n", - "\n", - " Given an mlir.ir.Value for each argument, return the mlir.ir.Values for\n", - " the results of the function.\n", - "\n", - " Does not need to be a JAX-traceable function.\n", - " \"\"\"\n", - " return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]\n", - "\n", - "# Now we register the lowering rule with JAX\n", - "# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)\n", - "# TODO: TPU?\n", - "from jax.interpreters import mlir\n", - "mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K98LX-VaJkFu" - }, - "source": [ - "Now we succeed to JIT. Notice below that JAX first evaluates the function\n", - "abstractly, which triggers the `multiply_add_abstract_eval` function, and\n", - "then compiles the set of primitives it has encountered, including `multiply_add`.\n", - "At this point JAX invokes `multiply_add_xla_translation`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "rj3TLsolJgEc", - "outputId": "e384bee4-1e9c-4344-f49c-d3b5ec08eb32" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Omrez-2_KFfo" - }, - "source": [ - "Below is another use of `jit` where we compile only\n", - "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n", - "in the third argument to `multiply_add_abstract_eval` being\n", - "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n", - "both `ShapedArray` and `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "mPfTwIBoKOEK", - "outputId": "b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(10.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y),\n", - " static_argnums=1)(2., 10.) == 14." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Ya3B5l4J1VA" - }, - "source": [ - "### Forward differentiation\n", - "\n", - "JAX implements forward differentiation in the form of\n", - "a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).\n", - "\n", - "If we attempt now to compute the `jvp` function we get an\n", - "error because we have not yet told JAX how to differentiate\n", - "the `multiply_add` primitive." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "OxDx6NQnKwMI", - "outputId": "ce659ef3-c03c-4856-f252-49ec4b6eb964" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 217, in process_primitive\n", - " jvp = primitive_jvps[primitive]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 978, in jvp\n", - " out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\", line 165, in call_wrapped\n", - " ans = self.f(*args, **dict(self.params, **kwargs))\n", - "NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The second argument `(2., 10.)` are the argument values\n", - "# where we evaluate the Jacobian, and the third `(1., 1.)`\n", - "# are the values of the tangents for the arguments.\n", - "with expectNotImplementedError():\n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zxG24C1JMIMM" - }, - "outputs": [], - "source": [ - "from jax.interpreters import ad\n", - "\n", - "\n", - "@trace(\"multiply_add_value_and_jvp\")\n", - "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n", - " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n", - "\n", - " Given values of the arguments and perturbation of the arguments (tangents),\n", - " compute the output of the primitive and the perturbation of the output.\n", - "\n", - " This method must be JAX-traceable. JAX may invoke it with abstract values\n", - " for the arguments and tangents.\n", - "\n", - " Args:\n", - " arg_values: a tuple of arguments\n", - " arg_tangents: a tuple with the tangents of the arguments. The tuple has\n", - " the same length as the arg_values. 
Some of the tangents may also be the\n", - " special value ad.Zero to specify a zero tangent.\n", - " Returns:\n", - " a pair of the primal output and the tangent.\n", - " \"\"\"\n", - " x, y, z = arg_values\n", - " xt, yt, zt = arg_tangents\n", - " _trace(\"Primal evaluation:\")\n", - " # Now we have a JAX-traceable computation of the output.\n", - " # Normally, we can use the ma primitive itself to compute the primal output.\n", - " primal_out = multiply_add_prim(x, y, z)\n", - "\n", - " _trace(\"Tangent evaluation:\")\n", - " # We must use a JAX-traceable way to compute the tangent. It turns out that\n", - " # the output tangent can be computed as (xt * y + x * yt + zt),\n", - " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n", - "\n", - " # We do need to deal specially with Zero. Here we just turn it into a\n", - " # proper tensor of 0s (of the same shape as 'x').\n", - " # An alternative would be to check for Zero and perform algebraic\n", - " # simplification of the output tangent computation.\n", - " def make_zero(tan):\n", - " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n", - "\n", - " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n", - " return (primal_out, output_tangent)\n", - "\n", - "# Register the forward differentiation rule with JAX\n", - "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "ma3KBkiAMfW1", - "outputId": "f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))\n", - " Primal evaluation:\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(2.0, 1.0, 1.0)\n", - " call multiply_add_impl(2.0, 1.0, 1.0)\n", - " |<- multiply_add_impl = 3.0\n", - " |<- multiply_add_prim = 3.0\n", - " call multiply_add_prim(1.0, 2.0, 3.0)\n", - " call multiply_add_impl(1.0, 2.0, 3.0)\n", - " |<- multiply_add_impl = 5.0\n", - " |<- multiply_add_prim = 5.0\n", - " |<- multiply_add_value_and_jvp = (14.0, 5.0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.\n", - "assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69QsEcu-lP4u" - }, - "source": [ - "TO EXPLAIN:\n", - "\n", - " * Why is JAX using ConcreteArray in square_add_prim? 
There is no abstract evaluation going on here.\n", - " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n", - " we do not call the multiply_add_abstract_eval.\n", - " * I think it would be useful to show the jaxpr here" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sb6e3ZAHOPHv" - }, - "source": [ - "#### JIT of forward differentiation\n", - "\n", - "We can apply JIT to the forward differentiation function:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "hg-hzVu-N-hv", - "outputId": "38d32067-e152-4046-ad80-7f95a31ba628" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Traced))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda arg_values, arg_tangents:\n", - " api.jvp(square_add_prim, arg_values, arg_tangents))(\n", - " (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jlZt1_v2mU88" - }, - "source": [ - "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n", - "evaluates abstractly both the primal and the tangent evaluation (a total of\n", - "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n", - "of the primitive." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "555yt6ZIOePB" - }, - "source": [ - "### Reverse differentiation\n", - "\n", - "If we attempt now to use reverse differentiation we\n", - "see that JAX starts by using the `multiply_add_value_and_jvp` to\n", - "compute the forward differentiation for abstract values, but then runs\n", - "into a `NotImplementedError`.\n", - "\n", - "When computing the reverse differentiation JAX first does abstract evaluation\n", - "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n", - "trace of primitives that compute the output tangent.\n", - "Observe that JAX performs this abstract evaluation with concrete values\n", - "for the differentiation point, and abstract values for the tangents.\n", - "Observe also that JAX uses the special abstract tangent value `Zero` for\n", - "the tangent corresponding to the 3rd argument of `ma`. This reflects the\n", - "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n", - "which flows to the 3rd argument to `multiply_add_prim`.\n", - "\n", - "Observe also that during the abstract evaluation of the tangent we pass the\n", - "value 0.0 as the tangent for the 3rd argument. This is due to the use\n", - "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "8eAVnexaOjBn", - "outputId": "e4ee89cf-ab4a-4505-9817-fa978a2865ab" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 198, in get_primitive_transpose\n", - " return primitive_transposes[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.grad(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n", - " _, g = value_and_grad_f(*args, **kwargs)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n", - " g = vjp_py(np.ones((), dtype=dtype))\n", - "NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not 
implemented\n" - ] - } - ], - "source": [ - "# This is reverse differentiation w.r.t. the first argument of square_add_prim\n", - "with expectNotImplementedError():\n", - " api.grad(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fSHLUMDN26AY" - }, - "source": [ - "The above error is because there is a missing piece for JAX to be able\n", - "to use the forward differentiation code to compute reverse differentiation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3ibDbGF-PjK9" - }, - "source": [ - "#### Transposition\n", - "\n", - "\n", - "As explained above, when computing reverse differentiation JAX obtains\n", - "a trace of primitives that compute the tangent using forward differentiation.\n", - "Then, **JAX interprets this trace abstractly backwards** and for each\n", - "primitive it applies a **transposition** rule.\n", - "\n", - "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n", - "```\n", - " a = xt * 4.\n", - " b = 2. * yt\n", - " c = a + b\n", - " ft = c + yt\n", - "```\n", - "\n", - "By construction, the tangent calculation is always linear in the input tangents.\n", - "The only non-linear operator that may arise in the tangent calculation is multiplication,\n", - "but then one of the operands is constant.\n", - "\n", - "JAX will produce the reverse differentiation computation by processing the\n", - "JVP computation backwards. For each operation in the tangent computation,\n", - "it accumulates the cotangents\n", - "of the variables used by the operation, using the cotangent of the result\n", - "of the operation:\n", - "```\n", - " # Initialize cotangents of inputs and intermediate vars\n", - " xct = yct = act = bct = cct = 0.\n", - " # Initialize cotangent of the output\n", - " fct = 1.\n", - " # Process \"ft = c + yt\"\n", - " cct += fct\n", - " yct += fct\n", - " # Process \"c = a + b\"\n", - " act += cct\n", - " bct += cct\n", - " # Process \"b = 2. * yt\"\n", - " yct += 2. * bct\n", - " # Process \"a = xt * 4.\"\n", - " xct += act * 4.\n", - "```\n", - "\n", - "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n", - "are the partial derivatives of the function `f`.\n", - "\n", - "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n", - "```\n", - "p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)\n", - "```\n", - "\n", - "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n", - "arguments it gets the actual constants. 
The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n", - "for the constant arguments.\n", - "\n", - "In particular,\n", - "```\n", - " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n", - " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n", - " mult_transpose(out_ct, _, y) = (out_ct * y, None)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JaHxFdkRO42r" - }, - "outputs": [], - "source": [ - "@trace(\"multiply_add_transpose\")\n", - "def multiply_add_transpose(ct, x, y, z):\n", - " \"\"\"Evaluates the transpose of a linear primitive.\n", - "\n", - " This method is only used when computing the backward gradient following\n", - " value_and_jvp, and is only needed for primitives that are used in the JVP\n", - " calculation for some other primitive. We need transposition for multiply_add_prim,\n", - " because we have used multiply_add_prim in the computation of the output_tangent in\n", - " multiply_add_value_and_jvp.\n", - "\n", - " In our case, multiply_add is not a linear primitive. However, it is used linearly\n", - " w.r.t. tangents in multiply_add_value_and_jvp:\n", - " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n", - "\n", - " Always one of the first two multiplicative arguments is a constant.\n", - "\n", - " Args:\n", - " ct: the cotangent of the output of the primitive.\n", - " x, y, z: values of the arguments. The arguments that are used linearly\n", - " get an ad.UndefinedPrimal value. The other arguments get a constant\n", - " value.\n", - " Returns:\n", - " a tuple with the cotangent of the inputs, with the value None\n", - " corresponding to the constant arguments.\n", - " \"\"\"\n", - " if not ad.is_undefined_primal(x):\n", - " # This use of multiply_add is with a constant \"x\"\n", - " assert ad.is_undefined_primal(y)\n", - " ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))\n", - " res = None, ct_y, ct\n", - " else:\n", - " # This use of multiply_add is with a constant \"y\"\n", - " assert ad.is_undefined_primal(x)\n", - " ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))\n", - " res = ct_x, None, ct\n", - " return res\n", - "\n", - "\n", - "ad.primitive_transposes[multiply_add_p] = multiply_add_transpose" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpChox-Jp7wb" - }, - "source": [ - "Now we can complete the run of the `grad`:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "PogPKS4MPevd", - "outputId": "d33328d4-3e87-45b5-9b31-21ad624b67af" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call 
multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, 2.0, _)\n", - " call multiply_add_prim(1.0, 2.0, 0.0)\n", - " call multiply_add_impl(1.0, 2.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (2.0, None, 1.0)\n", - "call multiply_add_transpose(1.0, 2.0, _, 0.0)\n", - " call multiply_add_prim(2.0, 1.0, 0.0)\n", - " call multiply_add_impl(2.0, 1.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (None, 2.0, 1.0)\n" - ] - } - ], - "source": [ - "assert api.grad(square_add_prim)(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8M1xLCXW4fK7" - }, - "source": [ - "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n", - "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n", - "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EIJs6FYmPg6c" - }, - "source": [ - "#### JIT of reverse differentiation\n", - "\n", - "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n", - "abstract values, while in the absence of JIT we used `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "FZ-JGbWZPq2-", - "outputId": "e42b5222-9c3e-4853-e13a-874f6605d178" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, Traced, _)\n", - " call multiply_add_prim(1.0, Traced, Traced)\n", - " call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (Traced, None, 1.0)\n", - "call multiply_add_transpose(1.0, Traced, _, Traced)\n", - 
" call multiply_add_prim(Traced, 1.0, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(1.0), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (None, Traced, 1.0)\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(api.grad(square_add_prim))(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-3lqPkdQPvl5" - }, - "source": [ - "### Batching\n", - "\n", - "The batching transformation takes a point-wise computation and turns it\n", - "into a computation on vectors. If we try it right now, we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "hFvBR3I9Pzh3", - "outputId": "434608bc-281f-4d3b-83bd-eaaf3b51b1cd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 163, in get_primitive_batcher\n", - " return primitive_batchers[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 3, in \n", - " np.array([10., 20.]))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n", - " lambda: _flatten_axes(out_tree(), out_axes))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n", - " out_vals, out_dims = batch2(fun, in_vals, in_dims)\n", - "NotImplementedError: Batching rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The arguments are two vectors instead of two scalars\n", - "with expectNotImplementedError():\n", - " api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n", - " np.array([10., 20.]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gILasMiP6elR" - }, - "source": [ - "We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KQfeqRIrP7zg" - }, - "outputs": [], - "source": [ - "from jax.interpreters import batching\n", - "\n", - "\n", - "@trace(\"multiply_add_batch\")\n", - "def multiply_add_batch(vector_arg_values, batch_axes):\n", - " \"\"\"Computes the batched version of the primitive.\n", - "\n", - " This must be a JAX-traceable function.\n", - "\n", - " Since the multiply_add primitive already operates pointwise on arbitrary\n", - " dimension tensors, to batch it we can use the primitive itself. This works as\n", - " long as both the inputs have the same dimensions and are batched along the\n", - " same axes. 
The result is batched along the axis that the inputs are batched.\n", - "\n", - " Args:\n", - " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n", - " shape.\n", - " batch_axes: the axes that are being batched. See vmap documentation.\n", - " Returns:\n", - " a tuple of the result, and the result axis that was batched.\n", - " \"\"\"\n", - " assert batch_axes[0] == batch_axes[1]\n", - " assert batch_axes[0] == batch_axes[2]\n", - " _trace(\"Using multiply_add to compute the batch:\")\n", - " res = multiply_add_prim(*vector_arg_values)\n", - " return res, batch_axes[0]\n", - "\n", - "\n", - "batching.primitive_batchers[multiply_add_p] = multiply_add_batch" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "VwxNk869P_YG", - "outputId": "9d22c921-5803-4d33-9e88-b6e439ba9738" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])\n", - " call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])\n", - " |<- multiply_add_impl = [14. 29.]\n", - " |<- multiply_add_prim = [14. 29.]\n", - " |<- multiply_add_batch = ([14. 29.], 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n", - " np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NmqLlV1TQDCC" - }, - "source": [ - "#### JIT of batching" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "xqEdXVUgQCTt", - "outputId": "9c22fd9c-919c-491d-bbeb-32c241b808fa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch((Traced, Traced, Traced), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[2])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_batch = (Traced, 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n", - " (np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "How JAX primitives work.ipynb", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/How_JAX_primitives_work.md 
b/docs/notebooks/How_JAX_primitives_work.md deleted file mode 100644 index 7c24ac11a6ce..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.md +++ /dev/null @@ -1,771 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "vfxqky4PCUnh"} - -# How JAX primitives work - - - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) - -*necula@google.com*, October 2019. - -JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, -`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, -which means that as the Python function executes -the only operations it applies to the data are either inspections of data -attributes such as shape or type, or special operations called JAX primitives. -In particular, a JAX-traceable function is sometimes invoked by JAX with -abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, -which captures the type and the shape of values, but not the concrete data values. -JAX primitives know how to operate on both concrete data -values and on the JAX abstract values. - - -The JAX-transformed functions must themselves be JAX-traceable functions, -to ensure that these transformations -can be composed, e.g., `jit(jacfwd(grad(f)))`. - -There are pre-defined JAX primitives corresponding to most XLA operations, -e.g., add, matmul, sin, cos, indexing. -JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs -using JAX’s implementation of numpy are JAX-traceable and therefore transformable. -Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives. - -The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, -one can define a new primitive that encapsulates the behavior of the function. - -**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.** - -Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically -as "multiply_add(x, y, z) = x * y + z". -This function operates on 3 identically-shaped tensors of floating point -values and performs the operations pointwise. 
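-
-Before implementing this with JAX machinery, here is a minimal NumPy-only sketch of the
-intended semantics (an illustration added for orientation; the helper name is illustrative
-only and is not used in the rest of this document):
-
-```python
-import numpy as np
-
-def multiply_add_reference(x, y, z):
-  """Reference semantics: pointwise x * y + z, using plain NumPy."""
-  return np.add(np.multiply(x, y), z)
-
-print(multiply_add_reference(2., 3., 10.))  # 16.0
-print(multiply_add_reference(np.ones(3), 2. * np.ones(3), np.zeros(3)))  # [2. 2. 2.]
-```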
- -+++ {"id": "HIJYIHNTD1yI"} - -## Using existing primitives - -The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other -functions that are themselves written using JAX primitives, e.g., those -defined in the `jax.lax` module: - -```{code-cell} ipython3 -:id: tbOF0LB0EMne -:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528 - -from jax import lax -from jax._src import api - -def multiply_add_lax(x, y, z): - """Implementation of multiply-add using the jax.lax primitives.""" - return lax.add(lax.mul(x, y), z) - - -def square_add_lax(a, b): - """A square-add function using the newly defined multiply-add.""" - return multiply_add_lax(a, a, b) - -print("square_add_lax = ", square_add_lax(2., 10.)) -# Differentiate w.r.t. the first argument -print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) -``` - -+++ {"id": "Cgv60Wm3E_D5"} - -In order to understand how JAX is internally using the primitives, -we add some helpers for tracing function calls. - -```{code-cell} ipython3 -:cellView: form -:id: mQRQGEGiE53K - -#@title Helper functions (execute this cell) -import functools -import traceback - -_indentation = 0 -def _trace(msg=None): - """Print a message at current indentation.""" - if msg is not None: - print(" " * _indentation + msg) - -def _trace_indent(msg=None): - """Print a message and then indent the rest.""" - global _indentation - _trace(msg) - _indentation = 1 + _indentation - -def _trace_unindent(msg=None): - """Unindent then print a message.""" - global _indentation - _indentation = _indentation - 1 - _trace(msg) - -def trace(name): - """A decorator for functions to trace arguments and results.""" - - def trace_func(func): # pylint: disable=missing-docstring - def pp(v): - """Print certain values more succinctly""" - vtype = str(type(v)) - if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: - return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: - return "".format(id(v)) - elif ("partial_eval.JaxprTracer" in vtype or - "batching.BatchTracer" in vtype or - "ad.JVPTracer" in vtype): - return "Traced<{}>".format(v.aval) - elif isinstance(v, tuple): - return "({})".format(pp_values(v)) - else: - return str(v) - def pp_values(args): - return ", ".join([pp(arg) for arg in args]) - - @functools.wraps(func) - def func_wrapper(*args): - _trace_indent("call {}({})".format(name, pp_values(args))) - res = func(*args) - _trace_unindent("|<- {} = {}".format(name, pp(res))) - return res - - return func_wrapper - - return trace_func - -class expectNotImplementedError(object): - """Context manager to check for NotImplementedError.""" - def __enter__(self): pass - def __exit__(self, type, value, tb): - global _indentation - _indentation = 0 - if type is NotImplementedError: - print("\nFound expected exception:") - traceback.print_exc(limit=3) - return True - elif type is None: # No exception - assert False, "Expected NotImplementedError" - else: - return False -``` - -+++ {"id": "Qf4eLrLCFYDl"} - -Instead of using `jax.lax` primitives directly, we can use other functions -that are already written in terms of those primitives, such as those in `jax.numpy`: - -```{code-cell} ipython3 -:id: QhKorz6cFRJb -:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a - -import jax.numpy as jnp -import numpy as np - -@trace("multiply_add_numpy") -def multiply_add_numpy(x, y, z): - return jnp.add(jnp.multiply(x, y), z) - -@trace("square_add_numpy") -def square_add_numpy(a, b): - return multiply_add_numpy(a, a, b) - -print("\nNormal 
evaluation:") -print("square_add_numpy = ", square_add_numpy(2., 10.)) -print("\nGradient evaluation:") -print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) -``` - -+++ {"id": "Sg-D8EdeFn4a"} - -Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and -`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further -below in this colab). -It is important to remember that a JAX-traceable function must be able to -operate not only on concrete arguments but also on special abstract arguments -that JAX may use to abstract the function execution. - -The JAX traceability property is satisfied as long as the function is written -in terms of JAX primitives. - -+++ {"id": "WxrQO7-XGLcg"} - -## Defining new JAX primitives - -The right way to add support for multiply-add is in terms of existing -JAX primitives, as shown above. However, in order to demonstrate how JAX -primitives work let us pretend that we want to add a new primitive to -JAX for the multiply-add functionality. - -```{code-cell} ipython3 -:id: cPqAH1XOGTN4 - -from jax import core -multiply_add_p = core.Primitive("multiply_add") # Create the primitive - -@trace("multiply_add_prim") -def multiply_add_prim(x, y, z): - """The JAX-traceable way to use the JAX primitive. - - Note that the traced arguments must be passed as positional arguments - to `bind`. - """ - return multiply_add_p.bind(x, y, z) - -@trace("square_add_prim") -def square_add_prim(a, b): - """A square-add function implemented using the new JAX-primitive.""" - return multiply_add_prim(a, a, b) -``` - -+++ {"id": "LMzs5PAKGr-4"} - -If we try to call the newly defined functions we get an error, because -we have not yet told JAX anything about the semantics of the new primitive. - -```{code-cell} ipython3 -:id: _X3PAYxhGpWd -:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8 - -with expectNotImplementedError(): - square_add_prim(2., 10.) -``` - -+++ {"id": "elha0FdgHSEF"} - -### Primal evaluation rules - -```{code-cell} ipython3 -:id: FT34FFAGHARU -:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b - -@trace("multiply_add_impl") -def multiply_add_impl(x, y, z): - """Concrete implementation of the primitive. - - This function does not need to be JAX traceable. - Args: - x, y, z: the concrete arguments of the primitive. Will only be called with - concrete values. - Returns: - the concrete result of the primitive. - """ - # Note that we can use the original numpy, which is not JAX traceable - return np.add(np.multiply(x, y), z) - -# Now we register the primal implementation with JAX -multiply_add_p.def_impl(multiply_add_impl) -``` - -```{code-cell} ipython3 -:id: G5bstKaeNAVV -:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5 - -assert square_add_prim(2., 10.) == 14. -``` - -+++ {"id": "upBf-uAuHhPJ"} - -### JIT - -If we now try to use `jit` we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: QG-LULjiHk4b -:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee - -with expectNotImplementedError(): - api.jit(square_add_prim)(2., 10.) -``` - -+++ {"id": "rHS1bAGHH44E"} - -#### Abstract evaluation rules -In order to JIT the function, and for other transformations as well, -JAX first evaluates it abstractly using only the -shape and type of the arguments. This abstract evaluation serves multiple -purposes: - - * Gets the sequence of JAX primitives that are used in the computation. This - sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. 
-
-
-For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
-In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
-
-```{code-cell} ipython3
-:id: ctQmEeckIbdo
-:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
-
-from jax import core
-@trace("multiply_add_abstract_eval")
-def multiply_add_abstract_eval(xs, ys, zs):
-  """Abstract evaluation of the primitive.
-
-  This function does not need to be JAX traceable. It will be invoked with
-  abstractions of the actual arguments.
-  Args:
-    xs, ys, zs: abstractions of the arguments.
-  Returns:
-    a ShapedArray for the result of the primitive.
-  """
-  assert xs.shape == ys.shape
-  assert xs.shape == zs.shape
-  return core.ShapedArray(xs.shape, xs.dtype)
-
-# Now we register the abstract evaluation with JAX
-multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
-```
-
-+++ {"id": "RPN88X6YI43A"}
-
-If we re-attempt to JIT, we see how the abstract evaluation proceeds, but
-we get another error, this time about the missing XLA compilation rule:
-
-```{code-cell} ipython3
-:id: eOcNR92SI2h-
-:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0
-
-with expectNotImplementedError():
-  api.jit(square_add_prim)(2., 10.)
-```
-
-+++ {"id": "9IOV1R-fJMHp"}
-
-#### XLA compilation rules
-
-JAX compilation works by compiling each primitive into a graph of XLA operations.
-
-This is the biggest hurdle to adding new functionality to JAX, because the
-set of XLA operations is limited, and JAX already has pre-defined primitives
-for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
-
-```{code-cell} ipython3
-:id: FYQWSSjKJaWP
-
-from jax._src.lib.mlir.dialects import hlo
-@trace("multiply_add_lowering")
-def multiply_add_lowering(ctx, xc, yc, zc):
-  """The compilation to XLA of the primitive.
-
-  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
-  the results of the function.
-
-  Does not need to be a JAX-traceable function.
-  """
-  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
-
-# Now we register the lowering rule with JAX
-# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
-# TODO: TPU?
-from jax.interpreters import mlir
-mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
-```
-
-+++ {"id": "K98LX-VaJkFu"}
-
-Now we can successfully apply `jit`. Notice below that JAX first evaluates the function
-abstractly, which triggers the `multiply_add_abstract_eval` function, and
-then compiles the set of primitives it has encountered, including `multiply_add`.
-At this point JAX invokes `multiply_add_lowering`.
-
-```{code-cell} ipython3
-:id: rj3TLsolJgEc
-:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32
-
-assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
-```
-
-+++ {"id": "Omrez-2_KFfo"}
-
-Below is another use of `jit`, where we compile only
-with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
-to the third argument of `multiply_add_abstract_eval` being a
-`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
-both `ShapedArray` and `ConcreteArray`.
-
-```{code-cell} ipython3
-:id: mPfTwIBoKOEK
-:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b
-
-assert api.jit(lambda x, y: square_add_prim(x, y),
-               static_argnums=1)(2., 10.) == 14.
-``` - -+++ {"id": "_Ya3B5l4J1VA"} - -### Forward differentiation - -JAX implements forward differentiation in the form of -a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)). - -If we attempt now to compute the `jvp` function we get an -error because we have not yet told JAX how to differentiate -the `multiply_add` primitive. - -```{code-cell} ipython3 -:id: OxDx6NQnKwMI -:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964 - -# The second argument `(2., 10.)` are the argument values -# where we evaluate the Jacobian, and the third `(1., 1.)` -# are the values of the tangents for the arguments. -with expectNotImplementedError(): - api.jvp(square_add_prim, (2., 10.), (1., 1.)) -``` - -```{code-cell} ipython3 -:id: zxG24C1JMIMM - -from jax.interpreters import ad - - -@trace("multiply_add_value_and_jvp") -def multiply_add_value_and_jvp(arg_values, arg_tangents): - """Evaluates the primal output and the tangents (Jacobian-vector product). - - Given values of the arguments and perturbation of the arguments (tangents), - compute the output of the primitive and the perturbation of the output. - - This method must be JAX-traceable. JAX may invoke it with abstract values - for the arguments and tangents. - - Args: - arg_values: a tuple of arguments - arg_tangents: a tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the - special value ad.Zero to specify a zero tangent. - Returns: - a pair of the primal output and the tangent. - """ - x, y, z = arg_values - xt, yt, zt = arg_tangents - _trace("Primal evaluation:") - # Now we have a JAX-traceable computation of the output. - # Normally, we can use the ma primitive itself to compute the primal output. - primal_out = multiply_add_prim(x, y, z) - - _trace("Tangent evaluation:") - # We must use a JAX-traceable way to compute the tangent. It turns out that - # the output tangent can be computed as (xt * y + x * yt + zt), - # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - - # We do need to deal specially with Zero. Here we just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for Zero and perform algebraic - # simplification of the output tangent computation. - def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan - - output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) - return (primal_out, output_tangent) - -# Register the forward differentiation rule with JAX -ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp -``` - -```{code-cell} ipython3 -:id: ma3KBkiAMfW1 -:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd - -# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5. -assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "69QsEcu-lP4u"} - -TO EXPLAIN: - - * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here. - * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet - we do not call the multiply_add_abstract_eval. 
- * I think it would be useful to show the jaxpr here - -+++ {"id": "Sb6e3ZAHOPHv"} - -#### JIT of forward differentiation - -We can apply JIT to the forward differentiation function: - -```{code-cell} ipython3 -:id: hg-hzVu-N-hv -:outputId: 38d32067-e152-4046-ad80-7f95a31ba628 - -assert api.jit(lambda arg_values, arg_tangents: - api.jvp(square_add_prim, arg_values, arg_tangents))( - (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "jlZt1_v2mU88"} - -Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn -evaluates abstractly both the primal and the tangent evaluation (a total of -3 invocations of the `ma` primitive). Then we compile the 3 occurrences -of the primitive. - -+++ {"id": "555yt6ZIOePB"} - -### Reverse differentiation - -If we attempt now to use reverse differentiation we -see that JAX starts by using the `multiply_add_value_and_jvp` to -compute the forward differentiation for abstract values, but then runs -into a `NotImplementedError`. - -When computing the reverse differentiation JAX first does abstract evaluation -of the forward differentiation code `multiply_add_value_and_jvp` to obtain a -trace of primitives that compute the output tangent. -Observe that JAX performs this abstract evaluation with concrete values -for the differentiation point, and abstract values for the tangents. -Observe also that JAX uses the special abstract tangent value `Zero` for -the tangent corresponding to the 3rd argument of `ma`. This reflects the -fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`, -which flows to the 3rd argument to `multiply_add_prim`. - -Observe also that during the abstract evaluation of the tangent we pass the -value 0.0 as the tangent for the 3rd argument. This is due to the use -of the `make_zero` function in the definition of `multiply_add_value_and_jvp`. - -```{code-cell} ipython3 -:id: 8eAVnexaOjBn -:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab - -# This is reverse differentiation w.r.t. the first argument of square_add_prim -with expectNotImplementedError(): - api.grad(square_add_prim)(2., 10.) -``` - -+++ {"id": "fSHLUMDN26AY"} - -The above error is because there is a missing piece for JAX to be able -to use the forward differentiation code to compute reverse differentiation. - -+++ {"id": "3ibDbGF-PjK9"} - -#### Transposition - - -As explained above, when computing reverse differentiation JAX obtains -a trace of primitives that compute the tangent using forward differentiation. -Then, **JAX interprets this trace abstractly backwards** and for each -primitive it applies a **transposition** rule. - -To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`: -``` - a = xt * 4. - b = 2. * yt - c = a + b - ft = c + yt -``` - -By construction, the tangent calculation is always linear in the input tangents. -The only non-linear operator that may arise in the tangent calculation is multiplication, -but then one of the operands is constant. - -JAX will produce the reverse differentiation computation by processing the -JVP computation backwards. 
For each operation in the tangent computation, -it accumulates the cotangents -of the variables used by the operation, using the cotangent of the result -of the operation: -``` - # Initialize cotangents of inputs and intermediate vars - xct = yct = act = bct = cct = 0. - # Initialize cotangent of the output - fct = 1. - # Process "ft = c + yt" - cct += fct - yct += fct - # Process "c = a + b" - act += cct - bct += cct - # Process "b = 2. * yt" - yct += 2. * bct - # Process "a = xt * 4." - xct += act * 4. -``` - -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. - -JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: -``` -p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) -``` - -Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other -arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned -for the constant arguments. - -In particular, -``` - add_transpose(out_ct, _, _) = (out_ct, out_ct) - mult_transpose(out_ct, x, _) = (None, x * out_ct) - mult_transpose(out_ct, _, y) = (out_ct * y, None) -``` - -```{code-cell} ipython3 -:id: JaHxFdkRO42r - -@trace("multiply_add_transpose") -def multiply_add_transpose(ct, x, y, z): - """Evaluates the transpose of a linear primitive. - - This method is only used when computing the backward gradient following - value_and_jvp, and is only needed for primitives that are used in the JVP - calculation for some other primitive. We need transposition for multiply_add_prim, - because we have used multiply_add_prim in the computation of the output_tangent in - multiply_add_value_and_jvp. - - In our case, multiply_add is not a linear primitive. However, it is used linearly - w.r.t. tangents in multiply_add_value_and_jvp: - output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)) - - Always one of the first two multiplicative arguments is a constant. - - Args: - ct: the cotangent of the output of the primitive. - x, y, z: values of the arguments. The arguments that are used linearly - get an ad.UndefinedPrimal value. The other arguments get a constant - value. - Returns: - a tuple with the cotangent of the inputs, with the value None - corresponding to the constant arguments. - """ - if not ad.is_undefined_primal(x): - # This use of multiply_add is with a constant "x" - assert ad.is_undefined_primal(y) - ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x)) - res = None, ct_y, ct - else: - # This use of multiply_add is with a constant "y" - assert ad.is_undefined_primal(x) - ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y)) - res = ct_x, None, ct - return res - - -ad.primitive_transposes[multiply_add_p] = multiply_add_transpose -``` - -+++ {"id": "PpChox-Jp7wb"} - -Now we can complete the run of the `grad`: - -```{code-cell} ipython3 -:id: PogPKS4MPevd -:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af - -assert api.grad(square_add_prim)(2., 10.) == 4. 
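-# The expected value: square_add_prim(a, b) computes a * a + b, so the derivative w.r.t. a is 2 * a, i.e. 4. at a = 2.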
-``` - -+++ {"id": "8M1xLCXW4fK7"} - -Notice the two calls to `multiply_add_transpose`. They correspond to the two -uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the -last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0. - -+++ {"id": "EIJs6FYmPg6c"} - -#### JIT of reverse differentiation - -Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only -abstract values, while in the absence of JIT we used `ConcreteArray`. - -```{code-cell} ipython3 -:id: FZ-JGbWZPq2- -:outputId: e42b5222-9c3e-4853-e13a-874f6605d178 - -assert api.jit(api.grad(square_add_prim))(2., 10.) == 4. -``` - -+++ {"id": "-3lqPkdQPvl5"} - -### Batching - -The batching transformation takes a point-wise computation and turns it -into a computation on vectors. If we try it right now, we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: hFvBR3I9Pzh3 -:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd - -# The arguments are two vectors instead of two scalars -with expectNotImplementedError(): - api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]), - np.array([10., 20.])) -``` - -+++ {"id": "gILasMiP6elR"} - -We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation. - -```{code-cell} ipython3 -:id: KQfeqRIrP7zg - -from jax.interpreters import batching - - -@trace("multiply_add_batch") -def multiply_add_batch(vector_arg_values, batch_axes): - """Computes the batched version of the primitive. - - This must be a JAX-traceable function. - - Since the multiply_add primitive already operates pointwise on arbitrary - dimension tensors, to batch it we can use the primitive itself. This works as - long as both the inputs have the same dimensions and are batched along the - same axes. The result is batched along the axis that the inputs are batched. - - Args: - vector_arg_values: a tuple of two arguments, each being a tensor of matching - shape. - batch_axes: the axes that are being batched. See vmap documentation. - Returns: - a tuple of the result, and the result axis that was batched. 
- """ - assert batch_axes[0] == batch_axes[1] - assert batch_axes[0] == batch_axes[2] - _trace("Using multiply_add to compute the batch:") - res = multiply_add_prim(*vector_arg_values) - return res, batch_axes[0] - - -batching.primitive_batchers[multiply_add_p] = multiply_add_batch -``` - -```{code-cell} ipython3 -:id: VwxNk869P_YG -:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738 - -assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)( - np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` - -+++ {"id": "NmqLlV1TQDCC"} - -#### JIT of batching - -```{code-cell} ipython3 -:id: xqEdXVUgQCTt -:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa - -assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0)) - (np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb deleted file mode 100644 index 3c022124e3cc..000000000000 --- a/docs/notebooks/external_callbacks.ipynb +++ /dev/null @@ -1,1121 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7XNMxdTwURqI" - }, - "source": [ - "# External callbacks\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h6lXo6bSUYGq" - }, - "source": [ - "This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xi_nhfpnlmbm" - }, - "source": [ - "## Why callbacks?\n", - "\n", - "A callback routine is a way to perform **host-side** execution of code at runtime.\n", - "As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n", - "Using a simple Python `print` statement, it looks like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "lz8rEL1Amb4r", - "outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: Tracedwith\n" - ] - } - ], - "source": [ - "import jax\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " print(\"intermediate value: {}\".format(y))\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yEy41sFAmxOp" - }, - "source": [ - "What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n", - "\n", - "To print the value at runtime we need a callback, for example `jax.debug.print`:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "wFfHmoQxnKDF", - "outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: 3\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " jax.debug.print(\"intermediate value: {}\", y)\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CvWv3pudn9X5" - }, - "source": [ - "This works by passing the runtime value represented by `y` back to the host process, where the host can print the value." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X0vR078znuT-" - }, - "source": [ - "## Flavors of Callback\n", - "\n", - "In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n", - "\n", - "- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.\n", - "- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.\n", - "- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n", - "\n", - "(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).\n", - "\n", - "From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n", - "\n", - "|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n", - "|-------------------------------------|----|----|----|----|----|----|\n", - "|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n", - "|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n", - "|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n", - "\n", - "¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff\n", - "\n", - "² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n", - "\n", - "³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hE_M8DaPvoym" - }, - "source": [ - "### Exploring `jax.pure_callback`\n", - "\n", - "`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n", - "\n", - "The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "4lQDzXy6t_-k", - "outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. 
, 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "def f_host(x):\n", - " # call a numpy (not jax.numpy) operation:\n", - " return np.sin(x).astype(x.dtype)\n", - "\n", - "def f(x):\n", - " result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " return jax.pure_callback(f_host, result_shape, x)\n", - "\n", - "x = jnp.arange(5.0)\n", - "f(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q7YCIr8qMrDs" - }, - "source": [ - "Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:\"" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "bgoZ0fxsuoWV", - "outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.jit(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "ajBRGWGfupu2", - "outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "xe7AOGexvC13", - "outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, f(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tMzAVs2VNj5G" - }, - "source": [ - "However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "4QAF4VhUu5bb", - "outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exception reporting mode: Minimal\n" - ] - } - ], - "source": [ - "%xmode minimal" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "qUpKPxlOurfY", - "outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n" - ] - } - ], - "source": [ - "jax.grad(f)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y9DAibV4Nwpo" - }, - "source": [ - "For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LrvdAloMZbIe" - }, - "source": [ - "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "mmFc_zawZrBq", - "outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "printing something\n" - ] - } - ], - "source": [ - "def print_something():\n", - " print('printing something')\n", - " return np.int32(0)\n", - "\n", - "@jax.jit\n", - "def f1():\n", - " return jax.pure_callback(print_something, np.int32(0))\n", - "f1();" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "tTwE4kpmaNei" - }, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f2():\n", - " jax.pure_callback(print_something, np.int32(0))\n", - " return 1.0\n", - "f2();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qfyGYbw4Z5U3" - }, - "source": [ - "In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n", - "In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JHcJybr7OEBM" - }, - "source": [ - "### Exploring `jax.experimental.io_callback`\n", - "\n", - "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n", - "\n", - "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "eAg5xIhrOiWV", - "outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[5]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax.experimental import io_callback\n", - "from functools import partial\n", - "\n", - "global_rng = np.random.default_rng(0)\n", - "\n", - "def host_side_random_like(x):\n", - " \"\"\"Generate a random array like x using the global_rng state\"\"\"\n", - " # We have two side-effects here:\n", - " # - printing the shape and dtype\n", - " # - calling global_rng, thus updating its state\n", - " print(f'generating {x.dtype}{list(x.shape)}')\n", - " return global_rng.uniform(size=x.shape).astype(x.dtype)\n", - "\n", - "@jax.jit\n", - "def numpy_random_like(x):\n", - " return io_callback(host_side_random_like, x, x)\n", - "\n", - "x = jnp.zeros(5)\n", - "numpy_random_like(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mAIF31MlXj33" - }, - "source": [ - "The `io_callback` is compatible with `vmap` by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "NY3o5dG6Vg6u", - "outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XXvSeeOXXquZ" - }, - "source": [ - "Note, however, that this may execute the mapped callbacks in any order. 
So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n", - "\n", - "If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "3aNmRsDrX3-2", - "outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def numpy_random_like_ordered(x):\n", - " return io_callback(host_side_random_like, x, x, ordered=True)\n", - "\n", - "jax.vmap(numpy_random_like_ordered)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fD2FTHlUYAZH" - }, - "source": [ - "On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "lMVzZlIEWL7F", - "outputId": "f9741c18-a30d-4d46-b706-8102849286b5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, numpy_random_like_ordered(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_sf8mCbbo8K" - }, - "source": [ - "Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "Cn6_RG4JcKZm", - "outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n" - ] - } - ], - "source": [ - "jax.grad(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "plvfn9lWcKu4" - }, - "source": [ - "However, if the callback is not dependent on a differentiated variable, it will execute:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "wxgfDmDfb5bx", - "outputId": 
"d8c0285c-cd04-4b4d-d15a-1b07f778882d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hello\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " io_callback(lambda: print('hello'), None)\n", - " return x\n", - "\n", - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "STLI40EZcVIY" - }, - "source": [ - "Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pkkM1ZmqclV-" - }, - "source": [ - "### Exploring `debug.callback`\n", - "\n", - "Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "74TdWyu9eqBa", - "outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "from jax import debug\n", - "\n", - "def log_value(x):\n", - " # This could be an actual logging call; we'll use\n", - " # print() for demonstration\n", - " print(\"log:\", x)\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " debug.callback(log_value, x)\n", - " return x\n", - "\n", - "f(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P848STlsfzmW" - }, - "source": [ - "The debug callback is compatible with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "2sSNsPB-fGVI", - "outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 0.0\n", - "log: 1.0\n", - "log: 2.0\n", - "log: 3.0\n", - "log: 4.0\n" - ] - } - ], - "source": [ - "x = jnp.arange(5.0)\n", - "jax.vmap(f)(x);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VDMacqpXf3La" - }, - "source": [ - "And is also compatible with `grad` and other autodiff transformations" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "wkFRle-tfTDe", - "outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w8t-SDZ3gRzE" - }, - "source": [ - "This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dF7hoWGQUneJ" - }, - "source": [ - "## Example: `pure_callback` with `custom_jvp`\n", - "\n", - "One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n", - "Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n", - "\n", - "Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n", - "We can start by defining a straightforward `pure_callback`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ge4fNPZdVSJY" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import scipy.special\n", - "\n", - "def jv(v, z):\n", - " v, z = jnp.asarray(v), jnp.asarray(z)\n", - "\n", - " # Require the order v to be integer type: this simplifies\n", - " # the JVP rule below.\n", - " assert jnp.issubdtype(v.dtype, jnp.integer)\n", - "\n", - " # Promote the input to inexact (float/complex).\n", - " # Note that jnp.result_type() accounts for the enable_x64 flag.\n", - " z = z.astype(jnp.result_type(float, z.dtype))\n", - "\n", - " # Wrap scipy function to return the expected dtype.\n", - " _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n", - "\n", - " # Define the expected shape & dtype of output.\n", - " result_shape_dtype = jax.ShapeDtypeStruct(\n", - " shape=jnp.broadcast_shapes(v.shape, z.shape),\n", - " dtype=z.dtype)\n", - "\n", - " # We use vectorize=True because scipy.special.jv handles broadcasted inputs.\n", - " return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vyjQj-0QVuoN" - }, - "source": [ - "This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f4e46670f4e4" - }, - "outputs": [], - "source": [ - "j1 = partial(jv, 1)\n", - "z = jnp.arange(5.0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6svImqFHWBwj", - "outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(j1(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d48eb4f2d48e" - }, - "source": [ - "Here is the same result with `jit`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "txvRqR9DWGdC", - "outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 
0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.jit(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d861a472d861" - }, - "source": [ - "And here is the same result again with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BS-Ve5u_WU0C", - "outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.vmap(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SCH2ii_dWXP6" - }, - "source": [ - "However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q3qh_4DrWxdQ", - "outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1092\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in 
\u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2655\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m 2657\u001b[0m flat_fun, primals_flat, reduce_axes=reduce_axes)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3425\u001b[0m \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 328\u001b[0m all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. 
Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 129\u001b[0m lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m result_avals=tuple(flat_result_avals), vectorized=vectorized)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. 
\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n", - "\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients." - ] - } - ], - "source": [ - "jax.grad(j1)(z)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PtYeJ_xUW09v" - }, - "source": [ - "Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:\n", - "\n", - "$$\n", - "d J_\\nu(z) = \\left\\{\n", - "\\begin{eqnarray}\n", - "-J_1(z),\\ &\\nu=0\\\\\n", - "[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2,\\ &\\nu\\ne 0\n", - "\\end{eqnarray}\\right.\n", - "$$\n", - "\n", - "The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n", - "\n", - "We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BOVQnt05XvLs" - }, - "outputs": [], - "source": [ - "jv = jax.custom_jvp(jv)\n", - "\n", - "@jv.defjvp\n", - "def _jv_jvp(primals, tangents):\n", - " v, z = primals\n", - " _, z_dot = tangents # Note: v_dot is always 0 because v is integer.\n", - " jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n", - " djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n", - " return jv(v, z), z_dot * djv_dz" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W1SxcvQSX44c" - }, - "source": [ - "Now computing the gradient of our function will work correctly:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sCGceBs-X8nL", - "outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.06447162\n" - ] - } - ], - "source": [ - "j1 = partial(jv, 1)\n", - "print(jax.grad(j1)(2.0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gWQ4phN5YB26" - }, - "source": [ - "Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QTe5mRAvYQBh", - "outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(-0.4003078, dtype=float32, weak_type=True)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.hessian(j1)(2.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QEXGxU4uYZii" - }, - "source": [ - "Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n", - "When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n", - "However, if you are running JAX on a single CPU (where the \"host\" and 
\"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md deleted file mode 100644 index 910d47bd72ae..000000000000 --- a/docs/notebooks/external_callbacks.md +++ /dev/null @@ -1,515 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "7XNMxdTwURqI"} - -# External callbacks - - - -+++ {"id": "h6lXo6bSUYGq"} - -This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. - -+++ {"id": "Xi_nhfpnlmbm"} - -## Why callbacks? - -A callback routine is a way to perform **host-side** execution of code at runtime. -As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation. -Using a simple Python `print` statement, it looks like this: - -```{code-cell} -:id: lz8rEL1Amb4r -:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97 - -import jax - -@jax.jit -def f(x): - y = x + 1 - print("intermediate value: {}".format(y)) - return y * 2 - -result = f(2) -``` - -+++ {"id": "yEy41sFAmxOp"} - -What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)). - -To print the value at runtime we need a callback, for example `jax.debug.print`: - -```{code-cell} -:id: wFfHmoQxnKDF -:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a - -@jax.jit -def f(x): - y = x + 1 - jax.debug.print("intermediate value: {}", y) - return y * 2 - -result = f(2) -``` - -+++ {"id": "CvWv3pudn9X5"} - -This works by passing the runtime value represented by `y` back to the host process, where the host can print the value. - -+++ {"id": "X0vR078znuT-"} - -## Flavors of Callback - -In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations: - -- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect. -- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk. -- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler. - -(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`). - -From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. 
- -|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | -|-------------------------------------|----|----|----|----|----|----| -|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ | -|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ | -|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | - -¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff - -² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`. - -³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases. - -+++ {"id": "hE_M8DaPvoym"} - -### Exploring `jax.pure_callback` - -`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.). - -The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times. - -```{code-cell} -:id: 4lQDzXy6t_-k -:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44 - -import jax -import jax.numpy as jnp -import numpy as np - -def f_host(x): - # call a numpy (not jax.numpy) operation: - return np.sin(x).astype(x.dtype) - -def f(x): - result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype) - return jax.pure_callback(f_host, result_shape, x) - -x = jnp.arange(5.0) -f(x) -``` - -+++ {"id": "q7YCIr8qMrDs"} - -Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:" - -```{code-cell} -:id: bgoZ0fxsuoWV -:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9 - -jax.jit(f)(x) -``` - -```{code-cell} -:id: ajBRGWGfupu2 -:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf - -jax.vmap(f)(x) -``` - -```{code-cell} -:id: xe7AOGexvC13 -:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e - -def body_fun(_, x): - return _, f(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "tMzAVs2VNj5G"} - -However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics: - -```{code-cell} -:id: 4QAF4VhUu5bb -:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480 - -%xmode minimal -``` - -```{code-cell} -:id: qUpKPxlOurfY -:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e -:tags: [raises-exception] - -jax.grad(f)(x) -``` - -+++ {"id": "y9DAibV4Nwpo"} - -For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below. 
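-
-As a brief aside, here is a minimal sketch (assuming the result specification accepts an arbitrary pytree of `jax.ShapeDtypeStruct` objects, and using a hypothetical `minmax_host` helper) showing that the declared result need not be a single array:
-
-```python
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-def minmax_host(x):
-  # Host-side NumPy computation that returns two arrays.
-  return np.min(x, keepdims=True), np.max(x, keepdims=True)
-
-def minmax(x):
-  spec = jax.ShapeDtypeStruct((1,), x.dtype)
-  # The result specification is a tuple (i.e. a pytree) of ShapeDtypeStructs.
-  return jax.pure_callback(minmax_host, (spec, spec), x)
-
-jax.jit(minmax)(jnp.arange(5.0))  # returns the pair (min, max): here (0.0, 4.0)
-```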
- -+++ {"id": "LrvdAloMZbIe"} - -By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely: - -```{code-cell} -:id: mmFc_zawZrBq -:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0 - -def print_something(): - print('printing something') - return np.int32(0) - -@jax.jit -def f1(): - return jax.pure_callback(print_something, np.int32(0)) -f1(); -``` - -```{code-cell} -:id: tTwE4kpmaNei - -@jax.jit -def f2(): - jax.pure_callback(print_something, np.int32(0)) - return 1.0 -f2(); -``` - -+++ {"id": "qfyGYbw4Z5U3"} - -In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. -In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects. - -+++ {"id": "JHcJybr7OEBM"} - -### Exploring `jax.experimental.io_callback` - -In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects. - -As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!). - -```{code-cell} -:id: eAg5xIhrOiWV -:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1 - -from jax.experimental import io_callback -from functools import partial - -global_rng = np.random.default_rng(0) - -def host_side_random_like(x): - """Generate a random array like x using the global_rng state""" - # We have two side-effects here: - # - printing the shape and dtype - # - calling global_rng, thus updating its state - print(f'generating {x.dtype}{list(x.shape)}') - return global_rng.uniform(size=x.shape).astype(x.dtype) - -@jax.jit -def numpy_random_like(x): - return io_callback(host_side_random_like, x, x) - -x = jnp.zeros(5) -numpy_random_like(x) -``` - -+++ {"id": "mAIF31MlXj33"} - -The `io_callback` is compatible with `vmap` by default: - -```{code-cell} -:id: NY3o5dG6Vg6u -:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e - -jax.vmap(numpy_random_like)(x) -``` - -+++ {"id": "XXvSeeOXXquZ"} - -Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run. 
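-
-One way to see why this matters is a minimal sketch (using a hypothetical `record` callback that appends to a host-side list) whose only side effect is logging the values it receives; under `vmap` the log entries may arrive in any order:
-
-```python
-import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.experimental import io_callback
-
-host_log = []
-
-def record(x):
-  # Side effect: append the received value to a host-side list.
-  host_log.append(np.asarray(x).item())
-  return np.asarray(x)
-
-@jax.jit
-def log_and_return(x):
-  return io_callback(record, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
-
-jax.vmap(log_and_return)(jnp.arange(4.0))
-# host_log now contains 0.0, 1.0, 2.0, 3.0, though not necessarily in that order.
-```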
- -If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error: - -```{code-cell} -:id: 3aNmRsDrX3-2 -:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274 -:tags: [raises-exception] - -@jax.jit -def numpy_random_like_ordered(x): - return io_callback(host_side_random_like, x, x, ordered=True) - -jax.vmap(numpy_random_like_ordered)(x) -``` - -+++ {"id": "fD2FTHlUYAZH"} - -On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced: - -```{code-cell} -:id: lMVzZlIEWL7F -:outputId: f9741c18-a30d-4d46-b706-8102849286b5 - -def body_fun(_, x): - return _, numpy_random_like_ordered(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "w_sf8mCbbo8K"} - -Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable: - -```{code-cell} -:id: Cn6_RG4JcKZm -:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07 -:tags: [raises-exception] - -jax.grad(numpy_random_like)(x) -``` - -+++ {"id": "plvfn9lWcKu4"} - -However, if the callback is not dependent on a differentiated variable, it will execute: - -```{code-cell} -:id: wxgfDmDfb5bx -:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d - -@jax.jit -def f(x): - io_callback(lambda: print('hello'), None) - return x - -jax.grad(f)(1.0); -``` - -+++ {"id": "STLI40EZcVIY"} - -Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation. - -+++ {"id": "pkkM1ZmqclV-"} - -### Exploring `debug.callback` - -Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program. - -```{code-cell} -:id: 74TdWyu9eqBa -:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18 - -from jax import debug - -def log_value(x): - # This could be an actual logging call; we'll use - # print() for demonstration - print("log:", x) - -@jax.jit -def f(x): - debug.callback(log_value, x) - return x - -f(1.0); -``` - -+++ {"id": "P848STlsfzmW"} - -The debug callback is compatible with `vmap`: - -```{code-cell} -:id: 2sSNsPB-fGVI -:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0 - -x = jnp.arange(5.0) -jax.vmap(f)(x); -``` - -+++ {"id": "VDMacqpXf3La"} - -And is also compatible with `grad` and other autodiff transformations - -```{code-cell} -:id: wkFRle-tfTDe -:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31 - -jax.grad(f)(1.0); -``` - -+++ {"id": "w8t-SDZ3gRzE"} - -This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`. - -+++ {"id": "dF7hoWGQUneJ"} - -## Example: `pure_callback` with `custom_jvp` - -One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`). -Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers. 
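-
-As a compact warm-up for the recipe below (purely illustrative — a sigmoid is of course already available as `jax.nn.sigmoid`, and the `expit` wrapper here is a hypothetical example), the pattern is: wrap the host function with `pure_callback`, then attach a `custom_jvp` rule written in terms of the wrapper itself:
-
-```python
-import jax
-import jax.numpy as jnp
-import scipy.special
-
-@jax.custom_jvp
-def expit(x):
-  x = jnp.asarray(x)
-  x = x.astype(jnp.result_type(float, x.dtype))
-  spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
-  # Host-side scipy call, wrapped so JAX knows the output shape and dtype.
-  return jax.pure_callback(lambda x: scipy.special.expit(x).astype(x.dtype), spec, x)
-
-@expit.defjvp
-def _expit_jvp(primals, tangents):
-  x, = primals
-  x_dot, = tangents
-  s = expit(x)
-  # d/dx expit(x) = expit(x) * (1 - expit(x))
-  return s, x_dot * s * (1 - s)
-
-print(jax.grad(expit)(0.5))  # ~0.2350, matching the analytic derivative
-```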
- -Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`. -We can start by defining a straightforward `pure_callback`: - -```{code-cell} -:id: Ge4fNPZdVSJY - -import jax -import jax.numpy as jnp -import scipy.special - -def jv(v, z): - v, z = jnp.asarray(v), jnp.asarray(z) - - # Require the order v to be integer type: this simplifies - # the JVP rule below. - assert jnp.issubdtype(v.dtype, jnp.integer) - - # Promote the input to inexact (float/complex). - # Note that jnp.result_type() accounts for the enable_x64 flag. - z = z.astype(jnp.result_type(float, z.dtype)) - - # Wrap scipy function to return the expected dtype. - _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype) - - # Define the expected shape & dtype of output. - result_shape_dtype = jax.ShapeDtypeStruct( - shape=jnp.broadcast_shapes(v.shape, z.shape), - dtype=z.dtype) - - # We use vectorize=True because scipy.special.jv handles broadcasted inputs. - return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) -``` - -+++ {"id": "vyjQj-0QVuoN"} - -This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`: - -```{code-cell} -:id: f4e46670f4e4 - -j1 = partial(jv, 1) -z = jnp.arange(5.0) -``` - -```{code-cell} -:id: 6svImqFHWBwj -:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9 - -print(j1(z)) -``` - -+++ {"id": "d48eb4f2d48e"} - -Here is the same result with `jit`: - -```{code-cell} -:id: txvRqR9DWGdC -:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87 - -print(jax.jit(j1)(z)) -``` - -+++ {"id": "d861a472d861"} - -And here is the same result again with `vmap`: - -```{code-cell} -:id: BS-Ve5u_WU0C -:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000 - -print(jax.vmap(j1)(z)) -``` - -+++ {"id": "SCH2ii_dWXP6"} - -However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function: - -```{code-cell} -:id: q3qh_4DrWxdQ -:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870 -:tags: [raises-exception] - -jax.grad(j1)(z) -``` - -+++ {"id": "PtYeJ_xUW09v"} - -Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`: - -$$ -d J_\nu(z) = \left\{ -\begin{eqnarray} --J_1(z),\ &\nu=0\\ -[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 -\end{eqnarray}\right. -$$ - -The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example. - -We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function: - -```{code-cell} -:id: BOVQnt05XvLs - -jv = jax.custom_jvp(jv) - -@jv.defjvp -def _jv_jvp(primals, tangents): - v, z = primals - _, z_dot = tangents # Note: v_dot is always 0 because v is integer. 
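-  # Apply the recurrence quoted above: dJ_v/dz = (J_{v-1}(z) - J_{v+1}(z)) / 2,
-  # with the special case dJ_0/dz = -J_1(z) selected by the jnp.where below.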
- jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z) - djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1)) - return jv(v, z), z_dot * djv_dz -``` - -+++ {"id": "W1SxcvQSX44c"} - -Now computing the gradient of our function will work correctly: - -```{code-cell} -:id: sCGceBs-X8nL -:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a - -j1 = partial(jv, 1) -print(jax.grad(j1)(2.0)) -``` - -+++ {"id": "gWQ4phN5YB26"} - -Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free: - -```{code-cell} -:id: QTe5mRAvYQBh -:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3 - -jax.hessian(j1)(2.0) -``` - -+++ {"id": "QEXGxU4uYZii"} - -Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device. -When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called. -However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities. diff --git a/docs/tutorials.rst b/docs/tutorials.rst index be70c6d41654..a31517155e1a 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -16,3 +16,13 @@ Tutorials working-with-pytrees sharded-computation stateful-computations + +.. toctree:: + :maxdepth: 1 + :caption: Advanced tutorials + + advanced-autodiff + external-callbacks + gradient-checkpointing + jax-primitives + jaxpr diff --git a/docs/user_guides.rst b/docs/user_guides.rst index e917cf2fee38..6481da7a31dd 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -33,7 +33,6 @@ or deployed codebases. :maxdepth: 1 :caption: Custom operations - notebooks/external_callbacks pallas/index ffi