Skip to content

Commit

Permalink
Merge branch 'master' into rc/stopping_criteria
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Oct 30, 2024
2 parents 90a0c0a + 831e7bd commit dd40060
Show file tree
Hide file tree
Showing 55 changed files with 1,708 additions and 1,984 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/mpl_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ jobs:
- mpl-version: 3.7.0
- mpl-version: 3.7.1
- mpl-version: 3.7.2
- mpl-version: 3.7.3
- mpl-version: 3.7.4
- mpl-version: 3.7.5
- mpl-version: 3.8.0
- mpl-version: 3.8.1
- mpl-version: 3.8.2
- mpl-version: 3.8.3
- mpl-version: 3.8.4
- mpl-version: 3.9.0
- mpl-version: 3.9.2

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/notebook_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Test notebooks with pytest and nbmake
if: env.has_changes == 'true'
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/regression_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
source .venv-${{ matrix.combos.python_version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand Down
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ repos:
rev: v3.2.0
hooks:
- id: pyupgrade
- repo: local
hooks:
- id: check_unmarked_tests
name: check_unmarked_tests
entry: devtools/check_unmarked_tests.sh
language: script
files: ^tests/
types: [python]
pass_filenames: true
7 changes: 5 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,12 @@ def put(arr, inds, vals):
Returns
-------
arr : array-like
Input array with vals inserted at inds.
Copy of input array with vals inserted at inds.
In some cases JAX may decide a copy is not necessary.
"""
if isinstance(arr, np.ndarray):
arr = arr.copy()
arr[inds] = vals
return arr
return jnp.asarray(arr).at[inds].set(vals)
Expand Down Expand Up @@ -509,9 +511,10 @@ def put(arr, inds, vals):
Returns
-------
arr : array-like
Input array with vals inserted at inds.
Copy of input array with vals inserted at inds.
"""
arr = arr.copy()
arr[inds] = vals
return arr

Expand Down
177 changes: 177 additions & 0 deletions desc/batching.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from functools import partial
from typing import Callable, Optional

from jax._src.api import (
_check_input_dtype_jacfwd,
_check_input_dtype_jacrev,
_check_output_dtype_jacfwd,
_check_output_dtype_jacrev,
_jacfwd_unravel,
_jacrev_unravel,
_jvp,
_std_basis,
_vjp,
)
from jax._src.api_util import _ensure_index, argnums_partial, check_callable
from jax._src.tree_util import tree_map, tree_structure, tree_transpose
from jax._src.util import wraps

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
Expand Down Expand Up @@ -320,3 +336,164 @@ def wrapped(*args, **kwargs):
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped


# The following section of this code is derived from JAX
# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/
# jax/_src/api.py
#
# The original copyright notice is as follows
# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");


def jacfwd_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)
argnums = _ensure_index(argnums)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if not has_aux:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args)
y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args))
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
else:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)(
_std_basis(dyn_args)
)
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
aux = tree_map(lambda x: x[0], aux)
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun


def jacrev_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
allow_int=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with respect to integer valued inputs. The
gradient of an integer input will have a trivial vector-space dtype (float0).
Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if not has_aux:
y, pullback = _vjp(f_partial, *dyn_args)
else:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
jac_tree = tree_transpose(
tree_structure(example_args), tree_structure(y), jac_tree
)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
Loading

0 comments on commit dd40060

Please sign in to comment.