
DOC: replace old tutorials with new content #20819

Closed

Conversation


@jakevdp jakevdp commented Apr 18, 2024

This is the first part of the restructuring in #18585

The main change here is that the older jax-101 tutorials are replaced by newer, updated discussions, with redirects from the old URLs.

Preview here: https://jax--20819.org.readthedocs.build/en/20819/

Find the main new content listed here: https://jax--20819.org.readthedocs.build/en/20819/tutorials.html, and linked in the side-bar.

I've added the new sphinxext-rediraffe extension to add automatic redirects for the old URLs to the replacement content. Some examples of the new redirects in action:
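
For anyone curious about the mechanics, the conf.py wiring for sphinxext-rediraffe looks roughly like this (a sketch; the old-to-new paths below are illustrative placeholders, not the exact mapping added in this PR):

# docs/conf.py (sketch)
extensions = [
    # ... existing extensions ...
    "sphinxext.rediraffe",
]

# Map old document paths to their replacements; rediraffe generates an HTML
# redirect page for each old URL at build time.
rediraffe_redirects = {
    "jax-101/01-jax-basics.md": "key-concepts.md",
    "jax-101/02-jitting.md": "jit-compilation.md",
}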

@jakevdp jakevdp self-assigned this Apr 18, 2024
@jakevdp jakevdp force-pushed the doc-new-tutorials branch 2 times, most recently from fe5ce96 to 7897fda on April 18, 2024 at 20:03
@mattjj mattjj left a comment

Not sure how to attach this to the right place, but: there's a line at the top of the quickstart that says "a la NumPy", can we change that to "à la NumPy"? I don't want to offend the French.

@mattjj mattjj left a comment

(Apologies again, I'm not sure how to attach this comment to the right place due to how the file move is represented in GitHub.)

In quickstart.md, there's a line that says "the jax.jacobian() transformation can be used to compute gradients of vector-valued functions". That is, it speaks of "gradients" of vector-valued functions.

IMO we should reserve "gradient" only to apply to scalar-valued functions. That's consistent with Wikipedia, the jax.grad restriction, and my understanding of usual conventions.

How about we rephrase to something like:

"Beyond scalar-valued functions, we can use jax.jacobian to compute the full Jacobian matrix of vector-valued functions."

@mattjj mattjj left a comment

In quickstart.md there's a sentence that says

For more advanced autodiff operations, you can use jax.jacrev() for reverse-mode vector-Jacobian products, and jax.jacfwd() for forward-mode Jacobian-vector products

But actually 'vector-Jacobian products' and 'Jacobian-vector products' describe jax.vjp and jax.jvp, not jax.jacrev and jax.jacfwd. Let's change it to:

For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() or jax.linearize for forward-mode Jacobian-vector products

But then the text that just follows needs to be revised too, since it refers to jacfwd and jacrev... we could define those inline with vmap as ~1 liners each, or add a sentence introducing jacfwd and jacrev separately. I kind of lean towards defining them inline. WDYT?
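
For concreteness, the jvp/vjp calls the rephrased sentence points at look roughly like this (a sketch; the function f and tangent v are made up):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x
x = jnp.arange(1.0, 4.0)
v = jnp.ones_like(x)

y, jvp_out = jax.jvp(f, (x,), (v,))   # forward mode: Jacobian-vector product J @ v
y, pullback = jax.vjp(f, x)           # reverse mode: returns y plus a pullback function
vjp_out, = pullback(v)                # vector-Jacobian product v @ J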

@mattjj mattjj left a comment

In key-concepts.md, there's a line "JAX arrays are never constructed directly" that wasn't immediately clear to me what it meant. How about we change it to something like "We don't call the jax.Array object constructor directly, like jax.Array(...), but rather ..."?
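
To illustrate the point (a sketch, not the proposed doc wording):

import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])          # arrays come from jax.numpy functions and transformations
print(isinstance(x, jax.Array))   # True
# jax.Array is an abstract base class, so jax.Array(...) is never called directly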

@mattjj mattjj left a comment

In key-concepts.md I noticed some "jax" instead of "JAX", but we can fix those later.

I also noticed that "The magic behind transformations is the notion of a Tracers." should end with the word "Tracer" rather than "Tracers" I think (or the "a" should be deleted).

mattjj commented Apr 19, 2024

Maybe I'll try attaching comments to the deleted side of the moved files... EDIT nope couldn't do it, e.g. nowhere to attach comments to key-concepts.md line-by-line afaict.

mattjj commented Apr 19, 2024

Docs are such a great opportunity to check over the fundamentals.

In key-concepts.md, the Tracing section reminded me that I had a recent experiment trying to improve how Tracers are printed. The Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)> is pretty opaque to users.

This branch makes some small tweaks, so that when we run this test file:

import jax
import jax.numpy as jnp
jax.config.update('jax_platforms', 'cpu')

cool = []
def f(x):
  cool.append(repr(x))
  return x.sum()

def do(expr):
  eval(expr)
  print(f'{expr}\n\n{cool.pop()}')
  print('\n===\n')
  jax.clear_caches()

x = jnp.ones((3, 3))

do('jax.jit(f)(x)')
do('jax.grad(f)(x)')
do('jax.grad(lambda x: jax.grad(f)(x).sum())(x)')
do('jax.grad(jax.jit(f))(x)')
do('jax.jit(jax.grad(f))(x)')

instead of getting this like we do at HEAD:

jax.jit(f)(x)

Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=1/0)>

===

jax.grad(f)(x)

Traced<ConcreteArray([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()

===

jax.grad(lambda x: jax.grad(f)(x).sum())(x)

Traced<ConcreteArray([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], dtype=float32)>with<JVPTrace(level=4/0)> with
  primal = Traced<ConcreteArray([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], dtype=float32)>with<JVPTrace(level=2/0)> with
    primal = Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
    tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
      pval = (ShapedArray(float32[3,3]), None)
      recipe = LambdaBinding()
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=3/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()

===

jax.grad(jax.jit(f))(x)

Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=3/0)>

===

jax.jit(jax.grad(f))(x)

Traced<ShapedArray(float32[3,3])>with<JVPTrace(level=3/0)> with
  primal = Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=2/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()

===

We get this:

jax.jit(f)(x)

JitTracer(float32[3,3])

===

jax.grad(f)(x)

GradTracer(Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32))

===

jax.grad(lambda x: jax.grad(f)(x).sum())(x)

GradTracer(GradTracer(Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)))

===

jax.grad(jax.jit(f))(x)

JitTracer(float32[3,3])

===

jax.jit(jax.grad(f))(x)

GradTracer(JitTracer(float32[3,3]))

===

I think for this line in the docs we'd get JitTracer(int32[5]) instead of the current Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>.

Landing this is outside the scope of this PR review, but I wanted to jot this down for posterity so we don't forget. LMK if you have any thoughts.

@mattjj mattjj left a comment

In key-concepts.md there's a line "short for JAX eXPRession" that should be "short for JAX exPRession" to make the capitals work out.

How about we change that sentence from "A jaxpr (short for JAX eXPRession) represents a list of core units of computation called primitives that represent the effect of a computation." to something like "A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations."
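
If it helps to ground that sentence, the doc could also show a jaxpr directly via jax.make_jaxpr; a sketch (the printed form is approximate and varies a bit across JAX versions):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

print(jax.make_jaxpr(f)(jnp.arange(3.0)))
# roughly:
# { lambda ; a:f32[3]. let
#     b:f32[3] = sin a
#     c:f32[3] = mul b 2.0
#   in (c,) }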

mattjj commented Apr 19, 2024

In jit-compilation.md, I think this paragraph might make it sound like global_list.append has more of a chance at predictable behavior than it really does:

Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour under transformations. However, as a rule of thumb, you can expect (but shouldn’t rely on) the side-effects of a JIT-compiled function to run once (during the first call), and never again, due to JAX’s traced execution model.

The trouble is that in the example, we'll leak a tracer, and all heck will break loose!

Maybe we can change it to something like:

Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. (If you want debug printing, use jax.debug.print (link to debugging section of tutorial!). To express general side-effects at the cost of performance, see jax.experimental.io_callback (link to docs). To check for tracer leaks at the cost of performance, use with jax.check_tracer_leaks (link to docs).)

I'm not sure about all those forward pointers but they may be useful. I've liked the other forward pointers in here so far!
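
To make the failure mode concrete, a small sketch of the leaked-tracer case alongside the jax.debug.print alternative (the function names here are made up):

import jax
import jax.numpy as jnp

global_list = []

@jax.jit
def impure(x):
  global_list.append(x)   # side effect: at trace time this appends a Tracer, not an array
  return x * 2

impure(jnp.arange(3.0))
print(global_list)        # holds a leaked Tracer; doing math with it later raises an error

@jax.jit
def with_runtime_print(x):
  jax.debug.print("x = {x}", x=x)   # runtime debug print that works under jit
  return x * 2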

mattjj commented Apr 19, 2024

In jit-compilation.md, another possible forward-pointer: when we talk about print, we can say it's a trace-time print, and if you want a run-time debug print that does appear in the jaxpr, check out jax.debug.print.

Nevermind, that's exactly what is already said in debugging.md. So a forward pointer covers that.

mattjj commented Apr 19, 2024

This list item indentation in jit-compilation.md is a bit off:
(screenshot omitted)

mattjj commented Apr 19, 2024

In jit-compilation.md, I think this section is misleading:

(screenshot omitted)

We forgot the block_until_ready on the second example!

I think the whole section is misleading, though, because jit is likely always worth it. It should have lower overhead on each execution, and we have to compile in both cases; it's just that the jit-decorated case compiles one larger program rather than multiple smaller ones.

Maybe we can just excise the whole section? The message at the end, to jit the largest chunks of your program that you can, is a good one, so maybe we can just keep that last paragraph?
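
FWIW, a corrected timing sketch would look something like this (the function body and shapes are made up):

import time
import jax
import jax.numpy as jnp

def f(x):
  for _ in range(10):
    x = jnp.sin(x) @ x
  return x

x = jnp.ones((1000, 1000))
f_jit = jax.jit(f)
f_jit(x).block_until_ready()    # warm-up call: compilation happens here, outside the timing

t0 = time.perf_counter()
f(x).block_until_ready()        # un-jitted, blocked
print("no jit:", time.perf_counter() - t0)

t0 = time.perf_counter()
f_jit(x).block_until_ready()    # jitted, blocked
print("jit:   ", time.perf_counter() - t0)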

mattjj commented Apr 19, 2024

In automatic-differentiation.md, we write "automatic differentiation (autodiff)" twice (once in each of the first two sections). The second time could just say "autodiff".

jakevdp commented Apr 19, 2024

In quickstart.md there's a sentence that says "For more advanced autodiff operations, you can use jax.jacrev() for reverse-mode vector-Jacobian products, and jax.jacfwd() for forward-mode Jacobian-vector products"

It previously said something like this, but I found it confusing because the sentence mentions jax.jvp and jax.vjp, then gives an example using jax.jacfwd and jax.jacrev. I fixed this by changing the functions mentioned in the paragraph.

I'm fine with switching back, but that re-introduces the strangeness of the example not using the functions mentioned in the text. What do you suggest?

jakevdp commented Apr 19, 2024

I'm going to close this PR, because it's about moving camera-ready content to the right place. The volume of content-related comments is telling me it's not actually camera-ready! It will be cleanest to address those separately in-place, and then move the content once it's ready.

@jakevdp jakevdp closed this Apr 19, 2024
@jakevdp jakevdp deleted the doc-new-tutorials branch April 19, 2024 17:56
mattjj commented Apr 19, 2024

I'm fine with switching back, but that re-introduces the strangeness of the example not using the functions mentioned in the text. What do you suggest?

Yeah tricky call. Do you think it's okay to use vmap here (with a forward pointer), even though we haven't covered it in the tutorial order yet? If so then I'll try inlining simplified (no pytree handling) definitions of jacfwd and jacrev, so the example looks like vmap + jvp + vmap + vjp = hessian, and we can take a look at that version. Or if it's a priori obvious that we shouldn't mention vmap here, I'm not sure what to do... maybe move it to the vmap section and have a forward pointer here?
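
For reference, the simplified no-pytree inline definitions might look something like this, following the usual vmap-of-jvp / vmap-of-vjp pattern (a sketch, not necessarily the wording that would land in the doc; the test function at the end is made up):

import jax
import jax.numpy as jnp

def jacfwd(f):
  # forward mode: one jvp per standard-basis vector, batched with vmap
  def jacfun(x):
    pushfwd = lambda v: jax.jvp(f, (x,), (v,))[1]
    return jnp.transpose(jax.vmap(pushfwd)(jnp.eye(x.size, dtype=x.dtype)))
  return jacfun

def jacrev(f):
  # reverse mode: one vjp per standard-basis covector, batched with vmap
  def jacfun(x):
    y, pullback = jax.vjp(f, x)
    J, = jax.vmap(pullback)(jnp.eye(y.size, dtype=y.dtype))
    return J
  return jacfun

f = lambda x: jnp.sin(x) * x
x = jnp.arange(1.0, 4.0)
print(jacfwd(f)(x))   # matches jax.jacfwd(f)(x)
print(jacrev(f)(x))   # matches jax.jacrev(f)(x)

Then vmap + jvp + vmap + vjp = hessian falls out as jacfwd(jacrev(f)), once the scalar-output case is handled (as the real jax.jacrev does).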

jakevdp commented Apr 19, 2024

I don't think we should confuse things by adding vmap. If you want to mention jvp and vjp, then the example should use those. Or, if you want to keep the example as-is, maybe one more sentence explaining why it's relevant would be enough.

mattjj added a commit to mattjj/jax that referenced this pull request Apr 19, 2024
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request May 13, 2024
… 0.4.27

Adam Paszke (20):
      Fix Pallas' while_loop lowering to properly account for the loop length
      Initial commit for Mosaic GPU
      Increase sharding to avoid timeouts
      [Mosaic:GPU] Update lowering to match upstream changes in the LLVM dialect
      Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
      [Mosaic] Always use 32-bit selects while retiling
      [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
      Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
      [Mosaic GPU] Implement a simple profilng tool using CUDA events
      [Mosaic GPU] Add the first example: pipelined matmul
      [Mosaic GPU] Disable matmul tests in internal CI
      [Mosaic GPU] Use a custom TMA descriptor initialization method
      Fix imports in Mosaic GPU examples
      [Mosaic GPU] Use the profiler to compute approximate matmul TFLOPs
      [Mosaic GPU] Stop using the MLIR CUDA runtime
      [Mosaic GPU] Add the flash attention example
      [Mosaic GPU] Only call kernel initializer from inside a custom call
      [Mosaic GPU] Fix the diagnostic dump infrastructure
      [Mosaic] Add support for remote DMAs and semaphores in megacore mode
      [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args

Benjamin Bastian (2):
      Remove outdated section in cloud_tpu_colabs README
      Update outdated comment Pmap_Cookbook

Blake Hechtman (1):
      [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.

Chase Roberts (1):
      Add more imports to jax extend

Chi Zeng (1):
      Let caller switch implementation of reduction after import

Chris Jones (1):
      Allow replacing jaxpr `debug_info` with `None`.

Christian Sigg (1):
      Switch `llo` and `tpu` dialects to MLIR properties.

Clemens Giuliani (3):
      Add a common flag for the collectives implementations on cpu.
      try remove MpiCollectives from type annotation
      ignore type

Dan Zheng (1):
      [jaxpr.rst] Remove extraneous 'let' in Jaxpr grammar.

Daniel Ng (1):
      Move dtype settings out of  metadata field into the root of Tensorstore spec

David Dunleavy (2):
      Move `tsl/BUILD`, `tsl.bzl`, and `tsl.default.bzl` to XLA
      Update references to TSL `config_settings` to their new home in XLA

David Hall (3):
      set tensorstore settings in cloud_tpu that make serialization more robust
      revert
      the reverters have been reverted

Dinghua Li (2):
      Introduce an "inline_seq_dim" mode to paged attention kernel. The mode will fuse kernel instances along the sequence dimension into one kernel, hence reducing the number of kernels.
      Fix incorrect sequence length in batch megacore mode and enable megacore tests which were incorrectly disabled before.

Dougal (2):
      Add a zeros rule for mutable arrays and test it using a custom vjp.
      Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.

Enrique Piqueras (2):
      Add builtin cc dataclass pytree node for performance.
      Fix jax.tree_util.register_dataclass in older JAX versions.

Frederic Bastien (1):
      Current status + build script fixes

George Necula (6):
      [callback] Add a flag to implement host_callback in terms of io_callback.
      [jax2tf] Adjust tolerance for asinh test.
      Remove old ducc_fft custom call.
      Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
      Fix mlir.merge_mlir_modules to properly remember the inlined symbols
      [export] Add backwards compatibility test for Pallas call on GPUs.

Henning Becker (1):
      Move CUDA specific functions from asm_compiler to cuda_asm_compiler target

Jackson Stokes (1):
      [XLA:TPU] Support scan loops for parameter input and output streaming in host offloading.

Jake VanderPlas (66):
      README: improve TPU accuracy discussion
      test: work around issue with kstest in scipy>1.12
      Remove support for complex jnp.floor_divide
      Restore upstream-nightly github action
      ufunc: fix implements wrapper for at
      [array api] allow Python scalar arguments to functions
      BUG: fix sign of beta()
      CI: pin actions dependencies to most recent release
      CI: avoid deprecated ruff configurations
      Fix typos in comments
      Finalize the deprecation of the arr.device() method
      test: fix testClipStaticBounds complex warning
      Make complex_arr.astype(bool) follow NumPy's semantics
      remove test of deprecated jax.random.shuffle API
      softmax: deprecate initial argument & always set to -inf internally
      DOC: update Colab mention in distributed arrays doc
      CI: run the metal CI nightly & on PRs touching the config file
      Avoid 'from jax import config' imports
      [key reuse] handle reuse of closed-over constants
      [key reuse] refactor & unify key reuse rule registration
      [key reuse] call key reuse logic directly in dispatch
      Finalize deprecation of invalid JIT argument names & numbers
      DOC: document strict dtype promotion mode
      [array api] update to latest test repo commit
      DOC: add introduction to sharded computation
      DOC: update installation guide
      DOC: fix Mesh construction in sharding doc
      Pin sphinx version to avoid error in 7.3.0
      DOC: add stateful computation doc
      DOC: pin sphinx to >=7.3.2
      DOC: remove copy of thinking-in-jax from new tutorial flow
      DOC: one last readthrough of the new 101 tutorials
      Add jax.scipy.special.gammasgn
      Finalize deprecation of zero-dimensional inputs to jnp.nonzero
      jnp.select: lower to lax.select_n
      DOC: respond to mattjj comments
      pallas flash attention: explicitly use dtype
      Finalize deprecation of `arr.device_buffer` and `arr.device_buffers`
      API deprecation doc: explicitly list some dates
      Remove extraneous pure_callback_api wrapper
      sparse_nm_test: skip on incompatible GPUs
      DOC: improve documentation of where, nonzero, and related functions
      DOC: link docstrings to github source
      DOC: improve docs of transpose & matrix_transpose
      Finalize deprecation of jnp.where keyword arguments
      jax.scipy.fft: manually document functions to avoid scipy import
      Cache (most) calls to dtypes.issubdtype
      DOC: add manual documentation to jax.scipy.special functions.
      DOC: improve docs for reshape() and ravel()
      test: fix reshape signature test for NumPy 2.1
      Remove dependency on sphinx_autodoc_typehints
      Finalize deprecation of lax.linalg positional args
      jax.scipy.stats: manually document functions to avoid scipy import
      DOC: Improve docstrings for jax.scipy.signal
      DOC: Improve docstrings for jax.scipy.linalg
      DOC: Improve remaining jax.scipy docstrings
      Remove remaining top-level scipy imports
      Remove last scipy imports
      Refactor jax.numpy docstring tests
      DOC: replace old tutorials with new content
      jnp.delete: better docs
      Require arraylike input for several jax.numpy functions
      random_lax_test: fix kstest for newer NumPy
      jnp.linalg: improve API documentation
      jnp.linalg tensorinv & tensorsolve: improve implementation & docs
      CI: print numpy/scipy version in upstream job

James Lottes (1):
      Refactor QDWH to be more efficient when run batched under vmap.

Jamie Townsend (1):
      Correct a name in ScatterDimensionNumbers docstring

Jesse Rosenstock (1):
      DOC: Fix :ref: format

Jevin Jiang (4):
      [Pallas] Fix typo in semaphore_wait error messages.
      [Pallas][Mosaic] Expose semaphore read.
      [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
      [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.

Jieying Luo (9):
      [PJRT C API] Plumb plugin attributes from plugin to JAX python.
      Add a fallback when GetDefaultLayout is unimplemented for that backend.
      Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
      [PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin.
      Enable some more C API tests.
      Rename arg in build script to be more clear.
      Add get_device_ordinal to cuda plugin so that CUDA dependency can be removed from py_array (jaxlib).
      Add instructions to report issues to GitHub for unsupported features in PJRT C API.
      Fix cuda array interface with old jaxlib.

John QiangZhang (1):
      [jax] Fix jax_export issue with static args.

Junwhan Ahn (3):
      Optimize `_create_copy_plan` in array.py
      Optimize `_create_copy_plan` by iterating over only the shards that are needed for materialization
      Optimize `jax.device_put()` dispatch for 1:1 device-to-device transfers

Justin Fu (6):
      Fix minor typo in Pallas docs.
      Add proper handling of OOB array accesses in pallas interpret mode.
      Redefines pltpu.trace as an alias of jax.named_scope.
      [Mosaic] Add guard for absl-py in tpu_custom_call.py.
      Remove explicit pallas trace_start/trace_stop primitives. These are now automatically inserted with the usage of jax.named_scope.
      Reverts 7844bac5d220b41253495cacf719f61905f46925

Kevin Kiningham (1):
      Add type information to Pallas primatives.

Lianmin Zheng (1):
      Fix a typo in jax.jit docstring

Marvin Kim (2):
      [JAX] Fix typo in comment.
      [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.

Matteo Hessel (1):
      Add sparse_sigmoid to jax.nn

Matthew Johnson (18):
      small simplification to asymptotic complexity of make_jaxpr_effects
      [mutable-arrays] enable refs without cps, and not just at top level
      add a temporary config option to disable custom_vjp shape checking
      start adding EArray, a jax.Array analog that can contain extended dtypes
      alias jax.sharding.NamedSharding -> jax.NamedSharding
      [callbacks] io_callback batching rule accidentally called pure_callback
      [callbacks] allow unordered effects in batched while_loop if predicate is not batched
      simple fix to make_jaxpr docstring
      fix cache key typo np.ndarray -> np.arange(...).reshape
      fix
      relax a side-effects test that was erroneously checking against a canonical order
      quickstart tweaks (from jax-ml/jax#20819)
      tweak jit docstring
      skip pallas/gmm_test.py if we don't have hypothesis
      [omnitracing] pop frames, see if anything downstream breaks
      fix vmap-of-grad-of-shmap axis name reuse bug
      [omnitracing] partially un-regress dispatch time
      [shard_map] better fix for spmd_axis_name issues with shmap residuals

Meekail Zain (16):
      Update
      Update env name
      Update `from_dlpack` to match array API 2023
      Update `jnp.clip` to Array API 2023 standard
      Updated FAQ to include section on CUDA library initialization failures
      Clean up sparse test run conditions
      Add __array_namespace_info__ and corresponding utilities
      Add support for max_version, dl_device, copy kwargs in __dlpack__
      Expose existing functions in array API namespace
      Fixed hypot bug on nan/inf pairings, began deprecation of non-real values
      Add new unstack function to numpy/array_api namespaces
      Add new cumulative_sum function to numpy and array_api
      Add support for copy kwarg in astype to match Array API
      Expose tile function in array_api namespace
      Refactor array_api namespace, relying more directly on jax.numpy
      Refactored common upcast for integral-type accumulators

Mohammed Anany (1):
      Integrate Triton up to [8e0c7b42](https://github.com/openai/triton/commits/8e0c7b425ac149c43183de966ffa423fd46e4762)

Olli Lupton (3):
      cuInit before querying compute capability
      _version_from_git_tree: avoid git describe
      jet_test: use float32 matmul precision

Parker Schuh (1):
      Support auto in shard_map.

Paul Wohlhart (1):
      Use xla_client.Device in jax.numpy.

Pavel T (1):
      better unsupported indexing handling in lax_numpy.py

Pearu Peterson (4):
      Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities.
      Workaround mpmath 1.3 bugs in tan and tanh evaluation at infinities
      Workaround mpmath 1.3 issues in asin and asinh evaluation at infinities and on branch cuts.
      Workaround mpmath 1.3 issues in asinh evaluation at infinities

Peter Hawkins (8):
      Fix test failure in complex clip test.
      Install fork() warning during backend initialization, rather than jax import.
      Import etils.epath lazily.
      [jax2tf] Bump asinh test tolerance in graph and eager modes.
      Fix pytest failures from compilation cache test.
      Remove jax_triton as a BUILD dependency of pallas_test.py.
      Fix warnings in CI from compilation_cache_test.
      Compute source maps when pretty-printing jaxprs.

Rebecca Chen (1):
      Silence some pytype errors.

Roy Frostig (7):
      changelog: batching rule change for `rng_bit_generator`
      changelog: note doc change to use `jax.random.key` over `PRNGKey`
      add and populate `jax.extend.core.primitives`
      fix up `extend:core` build rule
      remove Threefry GPU kernel
      bump shard count for `random_lax_test`
      reintroduce the Threefry GPU kernel lowering, under a flag

Ruturaj4 (1):
      [ROCm]: fix tsl path

Sai-Suraj-27 (4):
      Updated all the pre-commit hooks versions.
      Updated ubuntu os version in readthedocs.yml file to the latest suggested version.
      Prefer raising of TypeError for invalid types instead of ValueError.
      Made the error messages when raising TypeError better.

Samuel (1):
      Minor typo fix in docstring `jax.lax.psum`

Selam Waktola (2):
      redundant phrase 'ever time' removed
      improve documentation for ix_

Sergei Lebedev (22):
      Fixed a few typos in the matmul example in "Pallas Design"
      Pallas now exclusively uses XLA for compiling kernels on GPU
      jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
      Do not run Pallas GPU tests on Windows
      The compiler_params= argument of pl.pallas_call on GPU now uses "triton" to refer to Triton-specific parameters, instead of the repetitive "triton_params"
      pallas_call now has only one way to pass compiler_params=
      _cast() now takes JAX dtypes
      lax.Precision now uses _missing_ to handle aliases
      Pallas TPU now accepts compiler parameters only via mosaic=...
      Bumped the minimum jaxlib to 0.4.23
      Import web_pdb lazily
      Import rich lazily
      Added int4 and uint4 to dtype-specific tests
      jax.debug.callback now requires a Callable[..., None]
      jax.debug.callback now passes arguments as jax.Arrays
      Removed unnecessary jax.tree.map calls from *_callback_impl functions
      Added a GPU-specific approximate tanh to Pallas
      Added elementwise_inline_asm to Pallas GPU
      Do not require a capsule to have a specific name in the CUDA plugin
      Bundle MLIR .pyi files with jaxlib
      Ported LuPivotsToPermutation to the typed XLA FFI
      Made has_side_effect= parameter of mlir.emit_python_callback keyword-only

Sergey Kozub (3):
      Add JAX API that provides sparse matmul support (2:4 structured sparsity)
      Use correct kWidth in sparse dots with int8 input (on Ampere)
      Disable sparse_nm_test_gpu_h100 because of flakiness

Sharad Vikram (1):
      Add dynamic grid support to emit_pipeline

Shuhan Ding (4):
      update along with lax_numpy_test
      enable conv1d test
      metal_plugin ci with jaxlib nightly
      fix jaxlib config name

Stephan Hoyer (2):
      Make linkcode_resolve() a bit more robust
      Recursively pull out __wrapped__ in linkcode_resolve()

Tomás Longeri (3):
      [Mosaic] Fix Python pipeline not working correctly when HLO passes are enabled.
      [Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n).
      [Mosaic] Expand support of vector.extract and vector.extract_strided_slice

Yash Katariya (32):
      Finish jax and jaxlib 0.4.26 release
      Expose `Layout(device_local_layout, sharding)` class allowing users to specify layouts of Arrays.
      Rename `layout.AUTO` to `DeviceLocalLayout.AUTO`
      Make device_local_layout and sharding optional in `Layout`. Also only accept `Layout` class to `_in_layouts` and `_out_layouts`.
      Remove the unused return from prepare_axis_resources
      Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted
      `device_local_layout` can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the `device_local_layout`.
      Add `Layout` support to `jax.jit`.
      Share lowering code between jit and aot jit path
      Remove deprecated code from JAX lowering and compilation
      Delete deprecated AOT layouts API.
      Remove `mesh_shape` from ShardingSpec since it's not in use anymore.
      Add kwargs support to in_shardings argument of jax.jit.
      Reverts a1c8207caea8bbc323bbcfb7735768822a59f5ce
      Don't do layout checks during compiled safe call on DCE'd args.
      Remove the sharding and layout checks for non-DCE'd arguments during AOT safe call.
      Use `_internal_device_list` in `_get_device` so that all places accessing `_get_device` get a speedup.
      Make sure we don't return GSPMDSharding in `compiled.input_shardings`
      Remove the dead code now that jax.Array is the only array we have
      Add support for loading checkpoints with a given layout to the array serialization library
      If callback returns a fully replicated global array, return it as is.
      Accept layout on `ShapeDtypeStruct` on the `sharding` argument. `DeviceLocalLayout.AUTO` is not allowed on SDS.
      Add layout support to `make_array_from_callback`.
      Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX
      Skip `test_spmd_preserves_input_sharding_vmap_grad` unless `xla_extension_version >= 258`
      Cache the `_check_sharding` check in device_put. If aval and sharding are the same, no need to check multiple times
      Merge some loops in device_put since it's trivial to do so
      Fix donation with kwargs. The problem is that pytrees sort dictionaries by default. So if we create the donation vector with original `kwargs` order, it won't match the aval order (which is created by sorting kwargs i.e. dict) and we end up donating the wrong input.
      Replace donation_vector's logic with `donation_vector_with_in_tree` which is now deleted
      Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove global_in_avals argument from lower_sharding_computation
      Add `specialize` on `jax.jit` so that we can delete the duplicate code in `jax.make_jaxpr`.
      Start jax and jaxlib 0.4.27 release

Yue Sheng (6):
      Fix token management for ordered side-effects.
      Async dispatch expensive computations on the JAX CPU backend.
      Update JAX official doc: point out that the device numbers are not in numerical order because of the underlying torus hardware topology.
      Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be `False` by default, as we observed abnormal memory usage increases.
      Make `core.Token` a non-trivial class which wraps a `jax.Array`. Currently, we use a singleton and empty `core.token` object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
      Unify token lowering in JAX to always use `stablehlo.token`.

Yunlong Liu (1):
      Adds meaningful function names for better debugging.

carlosgmartin (4):
      Add jax.nn.mish.
      Add where argument to logsumexp.
      Let initial=-jnp.inf by default for nn.softmax and nn.log_softmax.
      Let xs=None by default in lax.scan.

dependabot[bot] (3):
      Bump actions/download-artifact from 4.1.4 to 4.1.6
      Bump actions/upload-artifact from 4.3.1 to 4.3.3
      Bump actions/download-artifact from 4.1.6 to 4.1.7

jax authors (69):
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Make layout on Array a property instead of a cached_property.
      Update XLA dependency to use revision
      [Pallas TPU] Pallas while loop -> fori test.
      Update XLA dependency to use revision
      Jax persistent compilation cache user guide.
      [Pallas TPU] Convert pattern_match_while_to_fori_loop to return (Jaxpr, str) rather than throw exceptions.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      [Pallas TPU] Add missing Mosaic lowering rules for float comparisons, tests.
      Update XLA dependency to use revision
      [Pallas] Global Barrier bug fix.
      [Pallas TPU] Raise clearer NotImplementedError on vector -> scalar reductions.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      [XLA] Clear derived instruction's sharding only if shapes are incompatible.
      Update XLA dependency to use revision
      fix matrix dimension and block shape.
      Update XLA dependency to use revision
      [Mosaic] Support scf.while and scf.condition.
      Update XLA dependency to use revision
      Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d
      Update XLA dependency to use revision
      Document the fact that jax.clear_caches() doesn't affect the persistent cache.
      Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
      [Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Add some docstrings for remote DMAs and semaphore barriers.
      Return early from eigh for small matrices. This was accidentally removed cl/577222219.
      Update XLA dependency to use revision
      Only perform checks on slice sizes if they're static.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
      Update XLA dependency to use revision
      Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops.
      Improve performance of SVD when batched by avoiding several cond() constructs.
      Adds rewrite patterns for `arith.{cmpi,select}` and `tensor.splat` as sources to a vector.transfer_read op.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Use the `windows-2022` image for running the `windows_ci` workflow.
      Point the `windows_ci` workflow to the correct VC directory.
      Fix example in mosaic tpu dialect layout.h
      Try fixing the MSVC path for `windows_ci` workflow once more.
      Return to the `win-2019` image for the `windows_ci` workflow.
      [XLA:CPU] Enable constant host offloading
      Update XLA dependency to use revision
      [Pallas TPU] Print the exception when a lowering exception occurs.
      Allow multiple indexers when doing discharge or swap in pallas
      Update XLA dependency to use revision
      Change determination of cloud TPU to check for TPU chips.
      [Pallas TPU] Increase clarity of dot 2D shape enforcement error.
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Update XLA dependency to use revision
      Don't create temp directory when module is getting imported.
      Pass `bazel_options` directly to the Bazel command, instead of into .bazelrc.
      Add a config for using Clang on Windows.
      Add an experimental, Clang version of the Windows CI job.
      Fix that the insufficient output HBM buffer init would cause the <unk> token generated for quantized int8 model.
      Update XLA dependency to use revision

kaixih (1):
      Support BNTH input formats

rajasekharporeddy (9):
      Update jax.scipy.special.ndtri to return NaN for the values outside of the range [0, 1]
      Fix doc typos
      Fix Typos in docs
      DOC: Fix docstring typos in scipy special functions
      Fix doc Typos
      Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k
      Fix jax.scipy.stats.beta.logpdf to emulate scipy.stats.beta.logpdf
      Fix Typos and math rendering in jax.random docs
      Fix typos in docs and an error message