Skip to content

Latest commit

 

History

History
1584 lines (1351 loc) · 74.8 KB

CHANGELOG.md

File metadata and controls

1584 lines (1351 loc) · 74.8 KB

Change log

Best viewed here.

jax 0.4.2

  • Breaking changes

    • Deleted jax.experimental.callback
  • Changes

    • {func}jax2tf.call_tf has a new parameter has_side_effects (default True) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination ({jax-issue}#13980).

jaxlib 0.4.2

jax 0.4.1 (Dec 13, 2022)

  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's {ref}version-support-policy.
    • We introduce jax.Array which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking change to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
    • PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
    • with_sharding_constraints new public endpoint is jax.lax.with_sharding_constraint.
    • If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.
    • The jax2tf.call_tf function now uses for TF lowering the first TF device of the same platform as used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
    • A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
    • jnp.msort is now deprecated, following the deprecation of np.msort in numpy 1.24. It will be removed in a future release, in accordance with the {ref}api-compatibility policy. It can be replaced with jnp.sort(a, axis=0).

jaxlib 0.4.1 (Dec 13, 2022)

  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's {ref}version-support-policy.
    • The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
    • The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.

jax 0.4.0 (Dec 12, 2022)

  • The release was yanked.

jaxlib 0.4.0 (Dec 12, 2022)

  • The release was yanked.

jax 0.3.25 (Nov 15, 2022)

  • Changes
    • {func}jax.numpy.linalg.pinv now supports the hermitian option.
    • {func}jax.scipy.linalg.hessenberg is now supported on CPU only. Requires jaxlib > 0.3.24.
    • New functions {func}jax.lax.linalg.hessenberg, {func}jax.lax.linalg.tridiagonal, and {func}jax.lax.linalg.householder_product were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
    • The gradients of svd and jax.numpy.linalg.pinv are now computed more economically for non-square matrices.
  • Breaking Changes
    • Deleted the jax_experimental_name_stack config option.
    • Convert a string axis_names arguments to the {class}jax.experimental.maps.Mesh constructor into a singleton tuple instead of unpacking the string into a sequence of character axis names.

jaxlib 0.3.25 (Nov 15, 2022)

  • Changes
    • Added support for tridiagonal reductions on CPU and GPU.
    • Added support for upper Hessenberg reductions on CPU.
  • Bugs
    • Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+

jax 0.3.24 (Nov 4, 2022)

  • Changes
    • JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX's import time.
    • Setting the env var JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
      • Added {func}jax.scipy.stats.mode.
    • The default device order used by pmap on TPU if no order is specified now matches jax.devices() for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
  • Breaking Changes
    • {func}jax.numpy.gradient now behaves like most other functions in {mod}jax.numpy, and forbids passing lists or tuples in place of arrays ({jax-issue}#12958)
    • Functions in {mod}jax.numpy.linalg and {mod}jax.numpy.fft now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of {jax-issue}#7737.
  • Deprecations
    • jax.sharding.MeshPspecSharding has been renamed to jax.sharding.NamedSharding. jax.sharding.MeshPspecSharding name will be removed in 3 months.

jaxlib 0.3.24 (Nov 4, 2022)

  • Changes
    • Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.

jax 0.3.23 (Oct 12, 2022)

  • Changes
    • Update Colab TPU driver version for new jaxlib release.

jax 0.3.22 (Oct 11, 2022)

  • Changes
    • Add JAX_PLATFORMS=tpu,cpu as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of if the TPU is available.
  • Deprecations
    • Several test utilities deprecated in JAX v0.3.8 are now removed from {mod}jax.test_util.

jaxlib 0.3.22 (Oct 11, 2022)

jax 0.3.21 (Sep 30, 2022)

  • GitHub commits.
  • Changes
    • The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.

jax 0.3.20 (Sep 28, 2022)

  • Bug fixes:
    • Adds missing .pyi files that were missing from the previous release ({jax-issue}#12536).
    • Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned ({jax-issue}#12550). Requires jaxlib 0.3.20.
    • Fix incorrect pip url in setup.py comment ({jax-issue}#12528).

jaxlib 0.3.20 (Sep 28, 2022)

  • GitHub commits.
  • Bug fixes
    • Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}#12533).

jax 0.3.19 (Sep 27, 2022)

jax 0.3.18 (Sep 26, 2022)

  • GitHub commits.
  • Changes
    • Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}#7733) is stable and public. See the overview and the API docs for {mod}jax.stages.
    • Introduced {class}jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance works for {class}jax.numpy.ndarray for jax-internal objects, as {class}jax.numpy.ndarray is now a simple alias of {class}jax.Array.
  • Breaking changes
    • jax._src is no longer imported into the from the public jax namespace. This may break users that were using JAX internals.
    • jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.

jax 0.3.17 (Aug 31, 2022)

  • GitHub commits.
  • Bugs
    • Fix corner case issue in gradient of lax.pow with an exponent of zero ({jax-issue}12041)
  • Breaking changes
    • {func}jax.checkpoint, also known as {func}jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
  • Changes
    • Added {func}jax.pure_callback that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
  • Deprecations:
    • The deprecated DeviceArray.tile() method has been removed. Use {func}jax.numpy.tile ({jax-issue}#11944).
    • DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.

jax 0.3.16

  • GitHub commits.
  • Breaking changes
    • Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
  • Changes
    • Added {mod}jax.debug that includes utilities for runtime value debugging such at {func}jax.debug.print and {func}jax.debug.breakpoint.
    • Added new documentation for runtime value debugging
  • Deprecations
    • {func}jax.mask {func}jax.shapecheck APIs have been removed. See {jax-issue}#11557.
    • {mod}jax.experimental.loops has been removed. See {jax-issue}#10278 for an alternative API.
    • {func}jax.tree_util.tree_multimap has been removed. It has been deprecated since JAX release 0.3.5, and {func}jax.tree_util.tree_map is a direct replacement.
    • Removed jax.experimental.stax; it has long been a deprecated alias of {mod}jax.example_libraries.stax.
    • Removed jax.experimental.optimizers; it has long been a deprecated alias of {mod}jax.example_libraries.optimizers.
    • {func}jax.checkpoint, also known as {func}jax.remat, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.

jax 0.3.15 (July 22, 2022)

  • GitHub commits.
  • Changes
    • JaxTestCase and JaxTestLoader have been removed from jax.test_util. These classes have been deprecated since v0.3.1 ({jax-issue}#11248).
    • Added {class}jax.scipy.gaussian_kde ({jax-issue}#11237).
    • Binary operations between JAX arrays and built-in collections (dict, list, set, tuple) now raise a TypeError in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}#11234).
    • Several {mod}jax.tree_util routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the {ref}api-compatibility policy:
      • {func}jax.treedef_is_leaf is deprecated in favor of {func}jax.tree_util.treedef_is_leaf
      • {func}jax.tree_flatten is deprecated in favor of {func}jax.tree_util.tree_flatten
      • {func}jax.tree_leaves is deprecated in favor of {func}jax.tree_util.tree_leaves
      • {func}jax.tree_structure is deprecated in favor of {func}jax.tree_util.tree_structure
      • {func}jax.tree_transpose is deprecated in favor of {func}jax.tree_util.tree_transpose
      • {func}jax.tree_unflatten is deprecated in favor of {func}jax.tree_util.tree_unflatten
    • The sym_pos argument of {func}jax.scipy.linalg.solve is deprecated in favor of assume_a='pos', following a similar deprecation in {func}scipy.linalg.solve.

jaxlib 0.3.15 (July 22, 2022)

jax 0.3.14 (June 27, 2022)

  • GitHub commits.
  • Breaking changes
    • {func}jax.experimental.compilation_cache.initialize_cache does not support max_cache_size_ bytes anymore and will not get that as an input.
    • JAX_PLATFORMS now raises an exception when platform initialization fails.
  • Changes
    • Fixed compatibility problems with NumPy 1.23.
    • {func}jax.numpy.linalg.slogdet now accepts an optional method argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
    • {func}jax.numpy.linalg.qr now supports mode="raw".
    • pickle, copy.copy, and copy.deepcopy now have more complete support when used on jax arrays ({jax-issue}#10659). In particular:
      • pickle and deepcopy previously returned np.ndarray objects when used on a DeviceArray; now DeviceArray objects are returned. For deepcopy, the copied array is on the same device as the original. For pickle the deserialized array will be on the default device.
      • Within function transformations (i.e. traced code), deepcopy and copy previously were no-ops. Now they use the same mechanism as DeviceArray.copy().
      • Calling pickle on a traced array now results in an explicit ConcretizationTypeError.
    • The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
    • {func}jax.numpy.ldexp no longer silently promotes all inputs to float64, instead it promotes to float32 for integer inputs of size int32 or smaller ({jax-issue}#10921).
    • Add a create_perfetto_link option to {func}jax.profiler.start_trace and {func}jax.profiler.start_trace. When used, the profiler will generate a link to the Perfetto UI to view the trace.
    • Changed the semantics of {func}jax.profiler.start_server(...) to store the keepalive globally, rather than requiring the user to keep a reference to it.
    • Added {func}jax.random.generalized_normal.
    • Added {func}jax.random.ball.
    • Added {func}jax.default_device.
    • Added a python -m jax.collect_profile script to manually capture program traces as an alternative to the Tensorboard UI.
    • Added a jax.named_scope context manager that adds profiler metadata to Python programs (similar to jax.named_call).
    • In scatter-update operations (i.e. :attr:jax.numpy.ndarray.at), unsafe implicit dtype casts are deprecated, and now result in a FutureWarning. In a future release, this will become an error. An example of an unsafe implicit cast is jnp.zeros(4, dtype=int).at[0].set(1.5), in which 1.5 previously was silently truncated to 1.
    • {func}jax.experimental.compilation_cache.initialize_cache now supports gcs bucket path as input.
    • Added {func}jax.scipy.stats.gennorm.
    • {func}jax.numpy.roots is now better behaved when strip_zeros=False when coefficients have leading zeros ({jax-issue}#11215).

jaxlib 0.3.14 (June 27, 2022)

  • GitHub commits.
    • x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
    • The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
    • The Python flatbuffers package is no longer a dependency of jaxlib.

jax 0.3.13 (May 16, 2022)

jax 0.3.12 (May 15, 2022)

jax 0.3.11 (May 15, 2022)

  • GitHub commits.
  • Changes
    • {func}jax.lax.eigh now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
  • Deprecations
    • Non-array arguments to functions in {mod}jax.lax.linalg are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use {mod}jax.numpy.linalg instead.
    • {func}jax.scipy.linalg.polar_unitary, which was a JAX extension to the scipy API, is deprecated. Use {func}jax.scipy.linalg.polar instead.

jax 0.3.10 (May 3, 2022)

jaxlib 0.3.10 (May 3, 2022)

  • GitHub commits.
  • Changes
    • TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.

jax 0.3.9 (May 2, 2022)

  • GitHub commits.
  • Changes
    • Added support for fully asynchronous checkpointing for GlobalDeviceArray.

jax 0.3.8 (April 29 2022)

  • GitHub commits.
  • Changes
    • {func}jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • {func}jax.numpy.linalg.cond on TPUs now accepts complex input.
    • {func}jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • {func}jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • {func}jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • {func}jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ({jax-issue}#10452)
    • {func}jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • {func}jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • {func}jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of {func}numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
    • {func}jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of {func}numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
    • {func}jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of {func}numpy.split. Previously non-integer axis was silently cast to integers.
    • {func}jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of {func}numpy.indices. Previously non-integer dimensions were silently cast to integers.
    • {func}jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of {func}numpy.diag. Previously non-integer k was silently cast to integers.
    • Added {func}jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in {mod}jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance ({jax-issue}#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilites can be replaced by calls to standard python & numpy testing utilities found in e.g. {mod}unittest, {mod}absl.testing, {mod}numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}jax.devices. Many of the deprecated utilities will still exist in {mod}jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.

jax 0.3.7 (April 15, 2022)

  • GitHub commits.
  • Changes:
    • Fixed a performance problem if the indices passed to {func}jax.numpy.take_along_axis were broadcasted ({jax-issue}#10281).
    • {func}jax.scipy.special.expit and {func}jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
    • The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use {func}jax.numpy.tile ({jax-issue}#10266).

jaxlib 0.3.7 (April 15, 2022)

  • Changes:
    • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

jax 0.3.6 (April 12, 2022)

  • GitHub commits.
  • Changes:
    • Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
  • Deprecations:
    • {mod}jax.experimental.loops is being deprecated. See {jax-issue}#10278 for an alternative API.

jax 0.3.5 (April 7, 2022)

  • GitHub commits.
  • Changes:
    • added {func}jax.random.loggamma & improved behavior of {func}jax.random.beta and {func}jax.random.dirichlet for small parameter values ({jax-issue}#9906).
    • the private lax_numpy submodule is no longer exposed in the jax.numpy namespace ({jax-issue}#10029).
    • added array creation routines {func}jax.numpy.frombuffer, {func}jax.numpy.fromfunction, and {func}jax.numpy.fromstring ({jax-issue}#10049).
    • DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray ({jax-issue}#10069)
    • added {func}jax.scipy.linalg.rsf2csf
    • jax.experimental.sharded_jit has been deprecated and will be removed soon.
  • Deprecations:
    • {func}jax.nn.normalize is being deprecated. Use {func}jax.nn.standardize instead ({jax-issue}#9899).
    • {func}jax.tree_util.tree_multimap is deprecated. Use {func}jax.tree_util.tree_map instead ({jax-issue}#5746).
    • jax.experimental.sharded_jit is deprecated. Use pjit instead.

jaxlib 0.3.5 (April 7, 2022)

  • Bug fixes
    • Fixed a bug where double-precision complex-to-real IRFFTs would mutate their input buffers on GPU ({jax-issue}#9946).
    • Fixed incorrect constant-folding of complex scatters ({jax-issue}#10159)

jax 0.3.4 (March 18, 2022)

jax 0.3.3 (March 17, 2022)

jax 0.3.2 (March 16, 2022)

  • GitHub commits.
  • Changes:
    • The functions jax.ops.index_update, jax.ops.index_add, which were deprecated in 0.2.22, have been removed. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y).
    • Moved jax.experimental.ann.approx_*_k into jax.lax. These functions are optimized alternatives to jax.lax.top_k.
    • {func}jax.numpy.broadcast_arrays and {func}jax.numpy.broadcast_to now require scalar or array-like inputs, and will fail if they are passed lists (part of {jax-issue}#7737).
    • The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
    • pjit now works on CPU (in addition to previous TPU and GPU support).

jaxlib 0.3.2 (March 16, 2022)

  • Changes
    • XlaComputation.as_hlo_text() now supports printing large constants by passing boolean flag print_large_constants=True.
  • Deprecations:
    • The .block_host_until_ready() method on JAX arrays has been deprecated. Use .block_until_ready() instead.

jax 0.3.1 (Feb 18, 2022)

  • GitHub commits.

  • Changes:

    • jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated. The suggested replacement is to use parametrized.TestCase directly. For tests that rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement is to use standard numpy testing utilities such as {func}numpy.testing.assert_allclose(), which work directly with JAX arrays ({jax-issue}#9620).
    • jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default ({jax-issue}#9562). To recover the previous behavior, use the new jax.test_util.with_config decorator:
      @jtu.with_config(jax_numpy_rank_promotion='allow')
      class MyTestCase(jtu.JaxTestCase):
        ...
    • Added {func}jax.scipy.linalg.schur, {func}jax.scipy.linalg.sqrtm, {func}jax.scipy.signal.csd, {func}jax.scipy.signal.stft, {func}jax.scipy.signal.welch.

jax 0.3.0 (Feb 10, 2022)

jaxlib 0.3.0 (Feb 10, 2022)

  • Changes
    • Bazel 5.0.0 is now required to build jaxlib.
    • jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.

jax 0.2.28 (Feb 1, 2022)

  • GitHub commits.
    • jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
    • The jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.

jaxlib 0.1.76 (Jan 27, 2022)

  • New features
    • Includes precompiled SASS for NVidia compute capability 8.0 GPUS (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
    • With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
  • Breaking changes
    • Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
  • Bug fixes
    • Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
    • The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).

jax 0.2.27 (Jan 18 2022)

  • GitHub commits.

  • Breaking changes:

    • Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
    • The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --flax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}#8678).
    • Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally all NaN representations are now treated as equivalent and sorted to the end of the array. Previously negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns ({jax-issue}#9178).
    • {func}jax.numpy.unique now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output ({jax-issue}9184).
  • Bug fixes:

    • host_callback now supports ad_checkpoint.checkpoint ({jax-issue}#8907).
  • New features:

    • add jax.block_until_ready ({jax-issue}`#8941)
    • Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
    • Added jax.ensure_compile_time_eval to the public api ({jax-issue}#7987).
    • jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details ({jax-issue}#9189).

jaxlib 0.1.75 (Dec 8, 2021)

  • New features:
    • Support for python 3.10.

jax 0.2.26 (Dec 8, 2021)

  • GitHub commits.

  • Bug fixes:

    • Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily afects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634).
    • jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions ({jax-issue}#7839).

jaxlib 0.1.74 (Nov 17, 2021)

  • Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
  • Added experimental MLIR Python bindings for use by JAX.

jax 0.2.25 (Nov 10, 2021)

  • GitHub commits.

  • New features:

    • (Experimental) jax.distributed.initialize exposes multi-host GPU backend.
    • jax.random.permutation supports new independent keyword argument ({jax-issue}#8430)
  • Breaking changes

    • Moved jax.experimental.stax to jax.example_libraries.stax
    • Moved jax.experimental.optimizers to jax.example_libraries.optimizers
  • New features:

    • Added jax.lax.linalg.qdwh.

jax 0.2.24 (Oct 19, 2021)

  • GitHub commits.

  • New features:

    • jax.random.choice and jax.random.permutation now support multidimensional arrays and an optional axis argument ({jax-issue}#8158)
  • Breaking changes:

    • jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs (see {jax-issue}#7737)

jaxlib 0.1.73 (Oct 18, 2021)

  • Multiple cuDNN versions are now supported for jaxlib GPU cuda11 wheels.

    • cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
    • cuDNN 8.0.5 or newer.
  • Breaking changes:

    • The install commands for GPU jaxlib are as follows:

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html

jax 0.2.22 (Oct 12, 2021)

  • GitHub commits.
  • Breaking Changes
    • Static arguments to jax.pmap must now be hashable.

      Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap; jax.pmap compared unhashable static arguments using object identity.

      This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object.

    • jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.

  • Deprecations
    • The functions jax.ops.index_update, jax.ops.index_add etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.
  • New features:
    • An optimized C++ code-path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the --experimental_cpp_pmap flag (or JAX_CPP_PMAP environment variable).
    • jax.numpy.unique now supports an optional fill_value argument ({jax-issue}#8121)

jaxlib 0.1.72 (Oct 12, 2021)

  • Breaking changes:
    • Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
  • Bug fixes:
    • Fixes jax-ml#7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.

jax 0.2.21 (Sept 23, 2021)

  • GitHub commits.
  • Breaking Changes
    • jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.
    • jax.partial, and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.
    • Boolean scalar indices now raise a TypeError; previously this silently returned wrong results ({jax-issue}#7925).
    • Many more jax.numpy functions now require array-like inputs, and will error if passed a list ({jax-issue}#7747 {jax-issue}#7802 {jax-issue}#7907). See {jax-issue}#7737 for a discussion of the rationale behind this change.
    • When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce a on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
    • jnp.ndarray is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array x, isinstance(x, jnp.ndarray) will now return False ({jax-issue}7927).
  • New features:
    • Added {func}jax.numpy.insert implementation ({jax-issue}#7936).

jax 0.2.20 (Sept 2, 2021)

  • GitHub commits.
  • Breaking Changes
    • jnp.poly* functions now require array-like inputs ({jax-issue}#7732)
    • jnp.unique and other set-like operations now require array-like inputs ({jax-issue}#7662)

jaxlib 0.1.71 (Sep 1, 2021)

  • Breaking changes:
    • Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.

jax 0.2.19 (Aug 12, 2021)

  • GitHub commits.
  • Breaking changes:
    • Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

    • The jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as +.

      This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).

  • New features:
    • Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., jnp.mean. ({jax-issue}#7317)
  • Bug fixes:
    • Some leaked trace errors from the previous release ({jax-issue}#7613)

jaxlib 0.1.70 (Aug 9, 2021)

  • Breaking changes:
    • Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.

    • Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.

    • The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.

jax 0.2.18 (July 21 2021)

  • GitHub commits.

  • Breaking changes:

    • Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
    • The minimum jaxlib version is now 0.1.69.
    • The backend argument to {py:func}jax.dlpack.from_dlpack has been removed.
  • New features:

    • Added a polar decomposition ({py:func}jax.scipy.linalg.polar).
  • Bug fixes:

    • Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension. ({jax-issue}#7196)

jaxlib 0.1.69 (July 9 2021)

  • Fix bugs in TFRT CPU backend that results in incorrect results.

jax 0.2.17 (July 9 2021)

  • GitHub commits.
  • Bug fixes:
    • Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
  • New features:
    • New SciPy function {py:func}jax.scipy.special.sph_harm.
    • Reverse-mode autodiff functions ({func}jax.grad, {func}jax.value_and_grad, {func}jax.vjp, and {func}jax.linear_transpose) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only {func}jax.experimental.maps.xmap) ({jax-issue}#6950).

jax 0.2.16 (June 23 2021)

jax 0.2.15 (June 23 2021)

  • GitHub commits.

  • New features:

    • #7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.
    • The {func}jax2tf.convert supports inequalities and min/max for booleans ({jax-issue}#6956).
    • New SciPy function {py:func}jax.scipy.special.lpmn_values.
  • Breaking changes:

  • Bug fixes:

    • Fixed bug that prevented round-tripping from JAX to TF and back: jax2tf.call_tf(jax2tf.convert) ({jax-issue}#6947).

jaxlib 0.1.68 (June 23 2021)

  • Bug fixes:
    • Fixed bug in TFRT CPU backend that gets nans when transfer TPU buffer to CPU.

jax 0.2.14 (June 10 2021)

  • GitHub commits.

  • New features:

    • The {func}jax2tf.convert now has support for pjit and sharded_jit.
    • A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
    • A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.
    • The {func}jax2tf.convert supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., jnp.reshape(-1) ({jax-issue}#6827).
    • The {func}jax2tf.convert generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
    • New SciPy function {py:func}jax.scipy.special.lpmn.
  • Bug fixes:

    • The {func}jax2tf.convert now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX ({jax-issue}#6883).
    • The {func}jax2tf.convert now scopes the enable_xla conversion parameter properly to apply only during the just-in-time conversion ({jax-issue}#6720).
    • The {func}jax2tf.convert now converts lax.dot_general using the XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision ({jax-issue}#6717).
    • The {func}jax2tf.convert now has support for inequality comparisons and min/max for complex numbers ({jax-issue}#6892).

jaxlib 0.1.67 (May 17 2021)

jaxlib 0.1.66 (May 11 2021)

  • New features:
    • CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

      NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.

      There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).

    • Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.

    • Added automatic support for static keyword arguments to the {func}jit implementation.

    • Added support for pretransformation exception traces.

    • Initial support for pruning unused arguments from {func}jit -transformed computations. Pruning is still a work in progress.

    • Improved the string representation of {class}PyTreeDef objects.

    • Added support for XLA's variadic ReduceWindow.

  • Bug fixes:
    • Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
    • Fix a bug that meant that JAX garbage collection was not triggered by {func}jit transformed functions.

jax 0.2.13 (May 3 2021)

  • GitHub commits.

  • New features:

    • When combined with jaxlib 0.1.66, {func}jax.jit now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static.
    • {func}jax.nonzero has a new optional size argument that allows it to be used within jit ({jax-issue}#6501)
    • {func}jax.numpy.unique now supports the axis argument ({jax-issue}#6532).
    • {func}jax.experimental.host_callback.call now supports pjit.pjit ({jax-issue}#6569).
    • Added {func}jax.scipy.linalg.eigh_tridiagonal that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
    • The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception. Filtered stack traces now also work with Python 3.6.
    • If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
  • Breaking changes:

    • The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
      • host_id --> {func}~jax.process_index
      • host_count --> {func}~jax.process_count
      • host_ids --> range(jax.process_count())
    • Similarly, the argument to {func}~jax.local_devices has been renamed from host_id to process_index.
    • Arguments to {func}jax.jit other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.
  • Bug fixes:

    • The {func}jax2tf.convert now works in presence of gradients for functions with integer inputs ({jax-issue}#6360).
    • Fixed assertion failure in {func}jax2tf.call_tf when used with captured tf.Variable ({jax-issue}#6572).

jaxlib 0.1.65 (April 7 2021)

jax 0.2.12 (April 1 2021)

  • GitHub commits.
  • New features
    • New profiling APIs: {func}jax.profiler.start_trace, {func}jax.profiler.stop_trace, and {func}jax.profiler.trace
    • {func}jax.lax.reduce is now differentiable.
  • Breaking changes:
    • The minimum jaxlib version is now 0.1.64.
    • Some profiler APIs names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
      • TraceContext --> {func}~jax.profiler.TraceAnnotation
      • StepTraceContext --> {func}~jax.profiler.StepTraceAnnotation
      • trace_function --> {func}~jax.profiler.annotate_function
    • Omnistaging can no longer be disabled. See omnistaging for more information.
    • Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases ({jax-issue}#6047).
    • Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.
  • Bug fixes:
    • host_callback now supports empty arrays in arguments and results ({jax-issue}#6262).
    • {func}jax.random.randint clips rather than wraps of out-of-bounds limits, and can now generate integers in the full range of the specified dtype ({jax-issue}#5868)

jax 0.2.11 (March 23 2021)

  • GitHub commits.

  • New features:

    • #6112 added context managers: jax.enable_checks, jax.check_tracer_leaks, jax.debug_nans, jax.debug_infs, jax.log_compiles.
    • #6085 added jnp.delete
  • Bug fixes:

    • #6136 generalized jax.flatten_util.ravel_pytree to handle integer dtypes.
    • #6129 fixed a bug with handling some constants like enum.IntEnums
    • #6145 fixed batching issues with incomplete beta functions
    • #6014 fixed H2D transfers during tracing
    • #6165 avoids OverflowErrors when converting some large Python integers to floats
  • Breaking changes:

    • The minimum jaxlib version is now 0.1.62.

jaxlib 0.1.64 (March 18 2021)

jaxlib 0.1.63 (March 17 2021)

jax 0.2.10 (March 5 2021)

  • GitHub commits.
  • New features:
    • {func}jax.scipy.stats.chi2 is now available as a distribution with logpdf and pdf methods.
    • {func}jax.scipy.stats.betabinom is now available as a distribution with logpmf and pmf methods.
    • Added {func}jax.experimental.jax2tf.call_tf to call TensorFlow functions from JAX ({jax-issue}#5627) and README).
    • Extended the batching rule for lax.pad to support batching of the padding values.
  • Bug fixes:
    • {func}jax.numpy.take properly handles negative indices ({jax-issue}#5768)
  • Breaking changes:
    • JAX's promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX's type promotion behavior is described at {ref}type-promotion.
    • {func}jax.numpy.linspace now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
    • {func}jax.numpy.i0 no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
    • Several {mod}jax.numpy functions no longer accept tuples or lists in place of array arguments: {func}jax.numpy.pad, :funcjax.numpy.ravel, {func}jax.numpy.repeat, {func}jax.numpy.reshape. In general, {mod}jax.numpy functions should be used with scalars or array arguments.

jaxlib 0.1.62 (March 9 2021)

  • New features:
    • jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn't support AVX, you can build a jaxlib from source using the --target_cpu_features flag to build.py. --target_cpu_features also replaces --enable_march_native.

jaxlib 0.1.61 (February 12 2021)

jaxlib 0.1.60 (Febuary 3 2021)

  • Bug fixes:
    • Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
    • bool, int8, and uint8 are now considered safe to cast to bfloat16 NumPy extension type.

jax 0.2.9 (January 26 2021)

  • GitHub commits.
  • New features:
    • Extend the {mod}jax.experimental.loops module with support for pytrees. Improved error checking and error messages.
    • Add {func}jax.experimental.enable_x64 and {func}jax.experimental.disable_x64. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
  • Breaking changes:
    • {func}jax.ops.segment_sum now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.

jaxlib 0.1.59 (January 15 2021)

jax 0.2.8 (January 12 2021)

  • GitHub commits.
  • New features:
    • Add {func}jax.closure_convert for use with higher-order custom derivative functions. ({jax-issue}#5244)
    • Add {func}jax.experimental.host_callback.call to call a custom Python function on the host and return a result to the device computation. ({jax-issue}#5243)
  • Bug fixes:
    • jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs ({jax-issue}#5156)
    • host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function ({jax-issue}#5182).
  • Breaking changes:
    • jax.numpy.pad now takes keyword arguments. Positional argument constant_values has been removed. In addition, passing unsupported keyword arguments raises an error.
    • Changes for {func}jax.experimental.host_callback.id_tap ({jax-issue}#5243):
      • Removed support for kwargs for {func}jax.experimental.host_callback.id_tap. (This support has been deprecated for a few months.)
      • Changed the printing of tuples for {func}jax.experimental.host_callback.id_print to use '(' instead of '['.
      • Changed the {func}jax.experimental.host_callback.id_print in presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
      • host_callback.outfeed_receiver has been removed (it is not necessary, and was deprecated a few months ago).
  • New features:
    • New flag for debugging inf, analagous to that for NaN ({jax-issue}#5224).

jax 0.2.7 (Dec 4 2020)

  • GitHub commits.
  • New features:
    • Add jax.device_put_replicated
    • Add multi-host support to jax.experimental.sharded_jit
    • Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig
    • Add support for building on Windows platforms
    • Add support for general in_axes and out_axes in jax.pmap
    • Add complex support for jax.numpy.linalg.slogdet
  • Bug fixes:
    • Fix higher-than-second order derivatives of jax.numpy.sinc at zero
    • Fix some hard-to-hit bugs around symbolic zeros in transpose rules
  • Breaking changes:
    • jax.experimental.optix has been deleted, in favor of the standalone optax Python package.
    • indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in Numpy since v1.16, and in JAX since v0.2.4. See {jax-issue}#4564.

jax 0.2.6 (Nov 18 2020)

  • GitHub commits.

  • New Features:

    • Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
  • Breaking change cleanup

    • Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.

    • Improve consistency of type promotion behavior ({jax-issue}#4744):

      • Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
      • Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
    • The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.

    • jax.random.PRNGKey now produces the same results in and out of JIT compilation ({jax-issue}#4877). This required changing the result for a given seed in a few particular cases:

      • With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.
      • Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.

      To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • DeviceArray now raises RuntimeError instead of ValueError when trying to access its value while it has been deleted.

jaxlib 0.1.58 (January 12ish 2021)

  • Fixed a bug that meant JAX sometimes return platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32). (#4903)
  • Fixed a crash when constant-folding certain int16 operations. (#4971)
  • Added an is_leaf predicate to {func}pytree.flatten.

jaxlib 0.1.57 (November 12 2020)

  • Fixed manylinux2010 compliance issues in GPU wheels.
  • Switched the CPU FFT implementation from Eigen to PocketFFT.
  • Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
  • Add support for retaining ownership when passing arrays to DLPack (#4636).
  • Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
  • Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
  • Fixed a bug in profiler where tools are missing (#4427).
  • Dropped support for CUDA 10.0.

jax 0.2.5 (October 27 2020)

jax 0.2.4 (October 19 2020)

  • GitHub commits.

  • Improvements:

    • Add support for remat to jax.experimental.host_callback. See {jax-issue}#4608.
  • Deprecations

    • Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy. In a future release, this will result in a TypeError. See {jax-issue}#4564.

jaxlib 0.1.56 (October 14, 2020)

jax 0.2.3 (October 14 2020)

  • GitHub commits.
  • The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation

jax 0.2.2 (October 13 2020)

jax 0.2.1 (October 6 2020)

  • GitHub commits.
  • Improvements:
    • As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}jax.experimental.host_callback.id_print/ {py:func}jax.experimental.host_callback.id_tap is not used in the computation.

jax (0.2.0) (September 23 2020)

jax (0.1.77) (September 15 2020)

  • Breaking changes:
    • New simplified interface for {py:func}jax.experimental.host_callback.id_tap (#4101)

jaxlib 0.1.55 (September 8, 2020)

  • Update XLA:
    • Fix bug in DLPackManagedTensorToBuffer (#4196)

jax 0.1.76 (September 8, 2020)

jax 0.1.75 (July 30, 2020)

  • GitHub commits.
  • Bug Fixes:
    • make jnp.abs() work for unsigned inputs (#3914)
  • Improvements:
    • "Omnistaging" behavior added behind a flag, disabled by default (#3370)

jax 0.1.74 (July 29, 2020)

  • GitHub commits.
  • New Features:
    • BFGS (#3101)
    • TPU support for half-precision arithmetic (#3878)
  • Bug Fixes:
    • Prevent some accidental dtype warnings (#3874)
    • Fix a multi-threading bug in custom derivatives (#3845, #3869)
  • Improvements:
    • Faster searchsorted implementation (#3873)
    • Better test coverage for jax.numpy sorting algorithms (#3836)

jaxlib 0.1.52 (July 22, 2020)

  • Update XLA.

jax 0.1.73 (July 22, 2020)

  • GitHub commits.
  • The minimum jaxlib version is now 0.1.51.
  • New Features:
    • jax.image.resize. (#3703)
    • hfft and ihfft (#3664)
    • jax.numpy.intersect1d (#3726)
    • jax.numpy.lexsort (#3812)
    • lax.scan and the scan primitive support an unroll parameter for loop unrolling when lowering to XLA ({jax-issue}#3738).
  • Bug Fixes:
    • Fix reduction repeated axis error (#3618)
    • Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
    • make psum transpose handle zero cotangents (#3653)
    • Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
    • Support differentiation through jax.lax.all_to_all (#3733)
    • address nan issue in jax.scipy.special.zeta (#3777)
  • Improvements:
    • Many improvements to jax2tf
    • Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
    • Enable XLA SPMD partitioning by default. (#3151)
    • Add support for 0d transpose convolution (#3643)
    • Make LU gradient work for low-rank matrices (#3610)
    • support multiple_results and custom JVPs in jet (#3657)
    • Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
    • Implement complex convolutions on CPU and GPU. (#3735)
    • Make jnp.take work for empty slices of empty arrays. (#3751)
    • Relax dimension ordering rules for dot_general. (#3778)
    • Enable buffer donation for GPU. (#3800)
    • Add support for base dilation and window dilation to reduce window op… (#3803)

jaxlib 0.1.51 (July 2, 2020)

  • Update XLA.
  • Add new runtime support for host_callback.

jax 0.1.72 (June 28, 2020)

  • GitHub commits.
  • Bug fixes:
    • Fix an odeint bug introduced in the previous release, see {jax-issue}#3587.

jax 0.1.71 (June 25, 2020)

  • GitHub commits.
  • The minimum jaxlib version is now 0.1.48.
  • Bug fixes:
    • Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we're differentiating {jax-issue}#3562.

jaxlib 0.1.50 (June 25, 2020)

  • Add support for CUDA 11.0.
  • Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
  • Update XLA.

jaxlib 0.1.49 (June 19, 2020)

jaxlib 0.1.48 (June 12, 2020)

  • New features:
    • Adds support for fast traceback collection.
    • Adds preliminary support for on-device heap profiling.
    • Implements np.nextafter for bfloat16 types.
    • Complex128 support for FFTs on CPU and GPU.
  • Bugfixes:
    • Improved float64 tanh accuracy on GPU.
    • float64 scatters on GPU are much faster.
    • Complex matrix multiplication on CPU should be much faster.
    • Stable sorts on CPU should actually be stable now.
    • Concurrency bug fix in CPU backend.

jax 0.1.70 (June 8, 2020)

  • GitHub commits.
  • New features:
    • lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive {jax-issue}#3318.

jax 0.1.69 (June 3, 2020)

jax 0.1.68 (May 21, 2020)

  • GitHub commits.
  • New features:
    • {func}lax.cond supports a single-operand form, taken as the argument to both branches {jax-issue}#2993.
  • Notable changes:
    • The format of the transforms keyword for the {func}jax.experimental.host_callback.id_tap primitive has changed {jax-issue}#3132.

jax 0.1.67 (May 12, 2020)

  • GitHub commits.
  • New features:
    • Support for reduction over subsets of a pmapped axis using axis_index_groups {jax-issue}#2382.
    • Experimental support for printing and calling host-side Python function from compiled code. See id_print and id_tap ({jax-issue}#3006).
  • Notable changes:
    • The visibility of names exported from {mod}jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.

jaxlib 0.1.47 (May 8, 2020)

  • Fixes crash for outfeed.

jax 0.1.66 (May 5, 2020)

  • GitHub commits.
  • New features:
    • Support for in_axes=None on {func}pmap {jax-issue}#2896.

jaxlib 0.1.46 (May 5, 2020)

  • Fixes crash for linear algebra functions on Mac OS X (#432).
  • Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).

jax 0.1.65 (April 30, 2020)

  • GitHub commits.
  • New features:
    • Differentiation of determinants of singular matrices {jax-issue}#2809.
  • Bug fixes:
    • Fix {func}odeint differentiation with respect to time of ODEs with time-dependent dynamics {jax-issue}#2817, also add ODE CI testing.
    • Fix {func}lax_linalg.qr differentiation {jax-issue}#2867.

jaxlib 0.1.45 (April 21, 2020)

  • Fixes segfault: {jax-issue}#2755
  • Plumb is_stable option on Sort HLO through to Python.

jax 0.1.64 (April 21, 2020)

  • GitHub commits.
  • New features:
    • Add syntactic sugar for functional indexed updates {jax-issue}#2684.
    • Add {func}jax.numpy.linalg.multi_dot {jax-issue}#2726.
    • Add {func}jax.numpy.unique {jax-issue}#2760.
    • Add {func}jax.numpy.rint {jax-issue}#2724.
    • Add {func}jax.numpy.rint {jax-issue}#2724.
    • Add more primitive rules for {func}jax.experimental.jet.
  • Bug fixes:
    • Fix {func}logaddexp and {func}logaddexp2 differentiation at zero {jax-issue}#2107.
    • Improve memory usage in reverse-mode autodiff without {func}jit {jax-issue}#2719.
  • Better errors:
    • Improves error message for reverse-mode differentiation of {func}lax.while_loop {jax-issue}#2129.

jaxlib 0.1.44 (April 16, 2020)

  • Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
  • Bugfix for batch_group_count convolutions.
  • Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.

jax 0.1.63 (April 12, 2020)

  • GitHub commits.
  • Added jax.custom_jvp and jax.custom_vjp from {jax-issue}#2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).
  • Add scipy.sparse.linalg.cg {jax-issue}#2566.
  • Changed how Tracers are printed to show more useful information for debugging {jax-issue}#2591.
  • Made jax.numpy.isclose handle nan and inf correctly {jax-issue}#2501.
  • Added several new rules for jax.experimental.jet {jax-issue}#2537.
  • Fixed jax.experimental.stax.BatchNorm when scale/center isn't provided.
  • Fix some missing cases of broadcasting in jax.numpy.einsum {jax-issue}#2512.
  • Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan {jax-issue}#2596 and make reduce_prod differentiable to arbitray order {jax-issue}#2597.
  • Add batch_group_count to conv_general_dilated {jax-issue}#2635.
  • Add docstring for test_util.check_grads {jax-issue}#2656.
  • Add callback_transform {jax-issue}#2665.
  • Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.

jaxlib 0.1.43 (March 31, 2020)

  • Fixed a performance regression for Resnet-50 on GPU.

jax 0.1.62 (March 21, 2020)

  • GitHub commits.
  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
  • Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0.. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
  • Added an all_gather parallel convenience function.
  • More type annotations in core code.

jaxlib 0.1.42 (March 19, 2020)

  • jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
  • JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

jax 0.1.61 (March 17, 2020)

  • GitHub commits.
  • Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.

jax 0.1.60 (March 17, 2020)

  • GitHub commits.
  • New features:
    • {py:func}jax.pmap has static_broadcast_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcasted to all devices. It works analogously to static_argnums in {py:func}jax.jit.
    • Improved error messages for when tracers are mistakenly saved in global state.
    • Added {py:func}jax.nn.one_hot utility function.
    • Added {mod}jax.experimental.jet for exponentially faster higher-order automatic differentiation.
    • Added more correctness checking to arguments of {py:func}jax.lax.broadcast_in_dim.
  • The minimum jaxlib version is now 0.1.41.

jaxlib 0.1.40 (March 4, 2020)

  • Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
  • Includes prototype support for multihost GPU computations that communicate via NCCL.
  • Improves performance of NCCL collectives on GPU.
  • Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
  • Supports device assignments known at XLA compilation time.

jax 0.1.59 (February 11, 2020)

  • GitHub commits.

  • Breaking changes

    • The minimum jaxlib version is now 0.1.38.
    • Simplified {py:class}Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully-closed (no constvars) jaxpr. Also, added a new field call_primitive to primitives.
  • New features:

    • Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes ({jax-issue}#2091)
    • JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
    • JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
    • JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
    • Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.

jaxlib 0.1.39 (February 11, 2020)

  • Updates XLA.

jaxlib 0.1.38 (January 29, 2020)

  • CUDA 9.0 is no longer supported.
  • CUDA 10.2 wheels are now built by default.

jax 0.1.58 (January 28, 2020)

  • GitHub commits.

  • Breaking changes

    • JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
  • New features

    • Forward-mode automatic differentiation (jvp) of while loop ({jax-issue}#1980)
    • New NumPy and SciPy functions:

      • {py:func}jax.numpy.fft.fft2
      • {py:func}jax.numpy.fft.ifft2
      • {py:func}jax.numpy.fft.rfft
      • {py:func}jax.numpy.fft.irfft
      • {py:func}jax.numpy.fft.rfft2
      • {py:func}jax.numpy.fft.irfft2
      • {py:func}jax.numpy.fft.rfftn
      • {py:func}jax.numpy.fft.irfftn
      • {py:func}jax.numpy.fft.fftfreq
      • {py:func}jax.numpy.fft.rfftfreq
      • {py:func}jax.numpy.linalg.matrix_rank
      • {py:func}jax.numpy.linalg.matrix_power
      • {py:func}jax.scipy.special.betainc
    • Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.

Notable bug fixes

  • With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.