-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DOC: replace old tutorials with new content #20819
Conversation
fe5ce96
to
7897fda
Compare
7897fda
to
3f6ec60
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how to attach this to the right place, but: there's a line at the top of the quickstart that says "a la NumPy", can we change that to "à la NumPy"? I don't want to offend the French.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Apologies again, I'm not sure how to attach this comment to the right place due to how the file move is represented in GitHub.)
In quickstart.md, there's a line that says "the jax.jacobian() transformation can be used to compute gradients of vector-valued functions". It refers to the gradients of vector-valued functions.
IMO we should reserve "gradient" only to apply to scalar-valued functions. That's consistent with Wikipedia, the jax.grad
restriction, and my understanding of usual conventions.
How about we rephrase to something like:
"Beyond scalar-valued functions, we can use jax.jacobian
to compute the full Jacobian matrix of vector-valued functions."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In quickstart.md there's a sentence that says
For more advanced autodiff operations, you can use jax..jacrev() for reverse-mode vector-Jacobian products, and jax.jacfwd() for forward-mode Jacobian-vector products
But actually 'vector-Jacobian products' and 'jacobian-vector products' are describing jax.vjp
and jax.jvp
, not jax.jacrev
and jax.jacfwd
. Let's change it to:
For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() or
jax.linearize
for forward-mode Jacobian-vector products
But then the text that just follows needs to be revised too, since it refers to jacfwd
and jacrev
... we could define those inline with vmap
as ~1 liners each, or add a sentence introducing jacfwd
and jacrev
separately. I kind of lean towards defining them inline. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In key-concepts.md, there's a line "JAX arrays are never constructed directly" that wasn't immediately clear to me what it meant. How about we change it to something like "We don't call the jax.Array
object constructor directly, like jax.Array(...)
, but rather ..."?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In key-concepts.md I noticed some "jax" instead of "JAX", but we can fix those later.
I also noticed that "The magic behind transformations is the notion of a Tracers." should end with the word "Tracer" rather than "Tracers" I think (or the "a" should be deleted).
Maybe I'll try attaching comments to the deleted side of the moved files... EDIT nope couldn't do it, e.g. nowhere to attach comments to key-concepts.md line-by-line afaict. |
Docs are such a great opportunity to check over the fundamentals. In key-concepts.md, the Tracing section reminded me that I had a recent experiment trying to improve how Tracers are printed. The This branch makes some small tweaks, so that when we run this test file: import jax
import jax.numpy as jnp
jax.config.update('jax_platforms', 'cpu')
cool = []
def f(x):
cool.append(repr(x))
return x.sum()
def do(expr):
eval(expr)
print(f'{expr}\n\n{cool.pop()}')
print('\n===\n')
jax.clear_caches()
x = jnp.ones((3, 3))
do('jax.jit(f)(x)')
do('jax.grad(f)(x)')
do('jax.grad(lambda x: jax.grad(f)(x).sum())(x)')
do('jax.grad(jax.jit(f))(x)')
do('jax.jit(jax.grad(f))(x)') instead of getting this like we do at HEAD:
We get this:
I think for this line in the docs we'd get Landing this is outside the scope of this PR review, but I wanted to jot this down for posterity so we don't forget. LMK if you have any thoughts. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In key-concepts.md there's a line "short for JAX eXPRession" that should be "short for JAX exPRession" to make the capitals work out.
How about we change that sentence from "A jaxpr (short for JAX eXPRession) represents a list of core units of computation called primitives that represent the effect of a computation." to something like "A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations."
In jit-compilation.md, I think this paragraph might make it sound like
The trouble is that in the example, we'll leak a tracer, and all heck will break loose! Maybe we can change it to something like: Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. (If you want debug printing, use I'm not sure about all those forward pointers but they may be useful. I've liked the other forward pointers in here so far! |
Nevermind, that's exactly what is already said in debugging.md. So a forward pointer covers that. |
In automatic-differentiation.md, we write "automatic differentiation (autodiff)" twice (once in each of the first two sections). The second time could just say "autodiff". |
It previously said something like this, but I found it confusing because the sentence mentions I'm fine with switching back, but that re-introduces the strangeness of the example not using the functions mentioned in the test. What do you suggest? |
I'm going to close this PR, becuase it's about moving camera-ready content to the right place. The volume of content-related comments is telling me it's not actually camera-ready! It will be cleanest to address those separately in-place, and then move the content once it's ready. |
Yeah tricky call. Do you think it's okay to use |
I don't think we should confuse things by adding |
… 0.4.27 Adam Paszke (20): Fix Pallas' while_loop lowering to properly account for the loop length Initial commit for Mosaic GPU Increase sharding to avoid timeouts [Mosaic:GPU] Update lowering to match upstream changes in the LLVM dialect Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch [Mosaic] Always use 32-bit selects while retiling [Mosaic] Add support for concatenating arrays of packed types (<32 bits) Update GPU and NVGPU MLIR bindings to match upstream MLIR changes [Mosaic GPU] Implement a simple profilng tool using CUDA events [Mosaic GPU] Add the first example: pipelined matmul [Mosaic GPU] Disable matmul tests in internal CI [Mosaic GPU] Use a custom TMA descriptor initialization method Fix imports in Mosaic GPU examples [Mosaic GPU] Use the profiler to compute approximate matmul TFLOPs [Mosaic GPU] Stop using the MLIR CUDA runtime [Mosaic GPU] Add the flash attention example [Mosaic GPU] Only call kernel initializer from inside a custom call [Mosaic GPU] Fix the diagnostic dump infrastructure [Mosaic] Add support for remote DMAs and semaphores in megacore mode [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args Benjamin Bastian (2): Remove outdated section in cloud_tpu_colabs README Update outdated comment Pmap_Cookbook Blake Hechtman (1): [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts. Chase Roberts (1): Add more imports to jax extend Chi Zeng (1): Let caller switch implementation of reduction after import Chris Jones (1): Allow replacing jaxpr `debug_info` with `None`. Christian Sigg (1): Switch `llo` and `tpu` dialects to MLIR properties. Clemens Giuliani (3): Add a common flag for the collectives implementations on cpu. try remove MpiCollectives from type annotation ignore type Dan Zheng (1): [jaxpr.rst] Remove extraneous 'let' in Jaxpr grammar. Daniel Ng (1): Move dtype settings out of metadata field into the root of Tensorstore spec David Dunleavy (2): Move `tsl/BUILD`, `tsl.bzl`, and `tsl.default.bzl` to XLA Update references to TSL `config_settings` to their new home in XLA David Hall (3): set tensorstore settings in cloud_tpu that make serialization more robust revert the reverters have been reverted Dinghua Li (2): Introduce an "inline_seq_dim" mode to paged attention kernel. The mode will fuse kernel instances along the sequence dimension into one kernel, hence reducing the number of kernels. Fix incorrect sequence length in batch megacore mode and enable megacore tests which were incorrectly disabled before. Dougal (2): Add a zeros rule for mutable arrays and test it using a custom vjp. Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file. Enrique Piqueras (2): Add builtin cc dataclass pytree node for performance. Fix jax.tree_util.register_dataclass in older JAX versions. Frederic Bastien (1): Current status + build script fixes George Necula (6): [callback] Add a flag to implement host_callback in terms of io_callback. [jax2tf] Adjust tolerance for asinh test. Remove old ducc_fft custom call. Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1 Fix mlir.merge_mlir_modules to properly remember the inlined symbols [export] Add backwards compatibility test for Pallas call on GPUs. Henning Becker (1): Move CUDA specific functions from asm_compiler to cuda_asm_compiler target Jackson Stokes (1): [XLA:TPU] Support scan loops for parameter input and output streaming in host offloading. Jake VanderPlas (66): README: improve TPU accuracy discussion test: work around issue with kstest in scipy>1.12 Remove support for complex jnp.floor_divide Restore upstream-nightly github action ufunc: fix implements wrapper for at [array api] allow Python scalar arguments to functions BUG: fix sign of beta() CI: pin actions dependencies to most recent release CI: avoid deprecated ruff configurations Fix typos in comments Finalize the deprecation of the arr.device() method test: fix testClipStaticBounds complex warning Make complex_arr.astype(bool) follow NumPy's semantics remove test of deprecated jax.random.shuffle API softmax: deprecate initial argument & always set to -inf internally DOC: update Colab mention in distributed arrays doc CI: run the metal CI nightly & on PRs touching the config file Avoid 'from jax import config' imports [key reuse] handle reuse of closed-over constants [key reuse] refactor & unify key reuse rule registration [key reuse] call key reuse logic directly in dispatch Finalize deprecation of invalid JIT argument names & numbers DOC: document strict dtype promotion mode [array api] update to latest test repo commit DOC: add introduction to sharded computation DOC: update installation guide DOC: fix Mesh construction in sharding doc Pin sphinx version to avoid error in 7.3.0 DOC: add stateful computation doc DOC: pin sphinx to >=7.3.2 DOC: remove copy of thinking-in-jax from new tutorial flow DOC: one last readthrough of the new 101 tutorials Add jax.scipy.special.gammasgn Finalize deprecation of zero-dimensional inputs to jnp.nonzero jnp.select: lower to lax.select_n DOC: respond to mattjj comments pallas flash attention: explicitly use dtype Finalize deprecation of `arr.device_buffer` and `arr.device_buffers` API deprecation doc: explicitly list some dates Remove extraneous pure_callback_api wrapper sparse_nm_test: skip on incompatible GPUs DOC: improve documentation of where, nonzero, and related functions DOC: link docstrings to github source DOC: improve docs of transpose & matrix_transpose Finalize deprecation of jnp.where keyword arguments jax.scipy.fft: manually document functions to avoid scipy import Cache (most) calls to dtypes.issubdtype DOC: add manual documentation to jax.scipy.special functions. DOC: improve docs for reshape() and ravel() test: fix reshape signature test for NumPy 2.1 Remove dependency on sphinx_autodoc_typehints Finalize deprecation of lax.linalg positional args jax.scipy.stats: manually document functions to avoid scipy import DOC: Improve docstrings for jax.scipy.signal DOC: Improve docstrings for jax.scipy.linalg DOC: Improve remaining jax.scipy docstrings Remove remaining top-level scipy imports Remove last scipy imports Refactor jax.numpy docstring tests DOC: replace old tutorials with new content jnp.delete: better docs Require arraylike input for several jax.numpy functions random_lax_test: fix kstest for newer NumPy jnp.linalg: improve API documentation jnp.linalg tensorinv & tensorsolve: improve implementation & docs CI: print numpy/scipy version in upstream job James Lottes (1): Refactor QDWH to be more efficient when run batched under vmap. Jamie Townsend (1): Correct a name in ScatterDimensionNumbers docstring Jesse Rosenstock (1): DOC: Fix :ref: format Jevin Jiang (4): [Pallas] Fix typo in semaphore_wait error messages. [Pallas][Mosaic] Expose semaphore read. [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type. [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim. Jieying Luo (9): [PJRT C API] Plumb plugin attributes from plugin to JAX python. Add a fallback when GetDefaultLayout is unimplemented for that backend. Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target. [PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin. Enable some more C API tests. Rename arg in build script to be more clear. Add get_device_ordinal to cuda plugin so that CUDA dependency can be removed from py_array (jaxlib). Add instructions to report issues to GitHub for unsupported features in PJRT C API. Fix cuda array interface with old jaxlib. John QiangZhang (1): [jax] Fix jax_export issue with static args. Junwhan Ahn (3): Optimize `_create_copy_plan` in array.py Optimize `_create_copy_plan` by iterating over only the shards that are needed for materialization Optimize `jax.device_put()` dispatch for 1:1 device-to-device transfers Justin Fu (6): Fix minor typo in Pallas docs. Add proper handling of OOB array accesses in pallas interpret mode. Redefines pltpu.trace as an alias of jax.named_scope. [Mosaic] Add guard for absl-py in tpu_custom_call.py. Remove explicit pallas trace_start/trace_stop primitives. These are now automatically inserted with the usage of jax.named_scope. Reverts 7844bac5d220b41253495cacf719f61905f46925 Kevin Kiningham (1): Add type information to Pallas primatives. Lianmin Zheng (1): Fix a typo in jax.jit docstring Marvin Kim (2): [JAX] Fix typo in comment. [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched. Matteo Hessel (1): Add sparse_sigmoid to jax.nn Matthew Johnson (18): small simplification to asymptotic complexity of make_jaxpr_effects [mutable-arrays] enable refs without cps, and not just at top level add a temporary config option to disable custom_vjp shape checking start adding EArray, a jax.Array analog that can contain extended dtypes alias jax.sharding.NamedSharding -> jax.NamedSharding [callbacks] io_callback batching rule accidentally called pure_callback [callbacks] allow unordered effects in batched while_loop if predicate is not batched simple fix to make_jaxpr docstring fix cache key typo np.ndarray -> np.arange(...).reshape fix relax a side-effects test that was erroneously checking against a canonical order quickstart tweaks (from jax-ml/jax#20819) tweak jit docstring skip pallas/gmm_test.py if we don't have hypothesis [omnitracing] pop frames, see if anything downstream breaks fix vmap-of-grad-of-shmap axis name reuse bug [omnitracing] partially un-regress dispatch time [shard_map] better fix for spmd_axis_name issues with shmap residuals Meekail Zain (16): Update Update env name Update `from_dlpack` to match array API 2023 Update `jnp.clip` to Array API 2023 standard Updated FAQ to include section on CUDA library initialization failures Clean up sparse test run conditions Add __array_namespace_info__ and corresponding utilities Add support for max_version, dl_device, copy kwargs in __dlpack__ Expose existing functions in array API namespace Fixed hypot bug on nan/inf pairings, began deprecation of non-real values Add new unstack function to numpy/array_api namespaces Add new cumulative_sum function to numpy and array_api Add support for copy kwarg in astype to match Array API Expose tile function in array_api namespace Refactor array_api namespace, relying more directly on jax.numpy Refactored common upcast for integral-type accumulators Mohammed Anany (1): Integrate Triton up to [8e0c7b42](https://github.com/openai/triton/commits/8e0c7b425ac149c43183de966ffa423fd46e4762) Olli Lupton (3): cuInit before querying compute capability _version_from_git_tree: avoid git describe jet_test: use float32 matmul precision Parker Schuh (1): Support auto in shard_map. Paul Wohlhart (1): Use xla_client.Device in jax.numpy. Pavel T (1): better unsupported indexing handling in lax_numpy.py Pearu Peterson (4): Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities. Workaround mpmath 1.3 bugs in tan and tanh evaluation at infinities Workaround mpmath 1.3 issues in asin and asinh evaluation at infinities and on branch cuts. Workaround mpmath 1.3 issues in asinh evaluation at infinities Peter Hawkins (8): Fix test failure in complex clip test. Install fork() warning during backend initialization, rather than jax import. Import etils.epath lazily. [jax2tf] Bump asinh test tolerance in graph and eager modes. Fix pytest failures from compilation cache test. Remove jax_triton as a BUILD dependency of pallas_test.py. Fix warnings in CI from compilation_cache_test. Compute source maps when pretty-printing jaxprs. Rebecca Chen (1): Silence some pytype errors. Roy Frostig (7): changelog: batching rule change for `rng_bit_generator` changelog: note doc change to use `jax.random.key` over `PRNGKey` add and populate `jax.extend.core.primitives` fix up `extend:core` build rule remove Threefry GPU kernel bump shard count for `random_lax_test` reintroduce the Threefry GPU kernel lowering, under a flag Ruturaj4 (1): [ROCm]: fix tsl path Sai-Suraj-27 (4): Updated all the pre-commit hooks versions. Updated ubuntu os version in readthedocs.yml file to the latest suggested version. Prefer raising of TypeError for invalid types instead of ValueError. Made the error messages when raising TypeError better. Samuel (1): Minor typo fix in docstring `jax.lax.psum` Selam Waktola (2): redundant phrase 'ever time' removed improve documentation for ix_ Sergei Lebedev (22): Fixed a few typos in the matmul example in "Pallas Design" Pallas now exclusively uses XLA for compiling kernels on GPU jax.pure_callback and jax.experimental.io_callback now use jax.Arrays Do not run Pallas GPU tests on Windows The compiler_params= argument of pl.pallas_call on GPU now uses "triton" to refer to Triton-specific parameters, instead of the repetitive "triton_params" pallas_call now has only one way to pass compiler_params= _cast() now takes JAX dtypes lax.Precision now uses _missing_ to handle aliases Pallas TPU now accepts compiler parameters only via mosaic=... Bumped the minimum jaxlib to 0.4.23 Import web_pdb lazily Import rich lazily Added int4 and uint4 to dtype-specific tests jax.debug.callback now requires a Callable[..., None] jax.debug.callback now passes arguments as jax.Arrays Removed unnecessary jax.tree.map calls from *_callback_impl functions Added a GPU-specific approximate tanh to Pallas Added elementwise_inline_asm to Pallas GPU Do not require a capsule to have a specific name in the CUDA plugin Bundle MLIR .pyi files with jaxlib Ported LuPivotsToPermutation to the typed XLA FFI Made has_side_effect= parameter of mlir.emit_python_callback keyword-only Sergey Kozub (3): Add JAX API that provides sparse matmul support (2:4 structured sparsity) Use correct kWidth in sparse dots with int8 input (on Ampere) Disable sparse_nm_test_gpu_h100 because of flakiness Sharad Vikram (1): Add dynamic grid support to emit_pipeline Shuhan Ding (4): update along with lax_numpy_test enable conv1d test metal_plugin ci with jaxlib nightly fix jaxlib config name Stephan Hoyer (2): Make linkcode_resolve() a bit more robust Recursively pull out __wrapped__ in linkcode_resolve() Tomás Longeri (3): [Mosaic] Fix Python pipeline not working correctly when HLO passes are enabled. [Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n). [Mosaic] Expand support of vector.extract and vector.extract_strided_slice Yash Katariya (32): Finish jax and jaxlib 0.4.26 release Expose `Layout(device_local_layout, sharding)` class allowing users to specify layouts of Arrays. Rename `layout.AUTO` to `DeviceLocalLayout.AUTO` Make device_local_layout and sharding optional in `Layout`. Also only accept `Layout` class to `_in_layouts` and `_out_layouts`. Remove the unused return from prepare_axis_resources Resolve a TODO now that in_shardings are chosen by XLA for inputs that don't have sharding specified or are uncommitted `device_local_layout` can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the `device_local_layout`. Add `Layout` support to `jax.jit`. Share lowering code between jit and aot jit path Remove deprecated code from JAX lowering and compilation Delete deprecated AOT layouts API. Remove `mesh_shape` from ShardingSpec since it's not in use anymore. Add kwargs support to in_shardings argument of jax.jit. Reverts a1c8207caea8bbc323bbcfb7735768822a59f5ce Don't do layout checks during compiled safe call on DCE'd args. Remove the sharding and layout checks for non-DCE'd arguments during AOT safe call. Use `_internal_device_list` in `_get_device` so that all places accessing `_get_device` get a speedup. Make sure we don't return GSPMDSharding in `compiled.input_shardings` Remove the dead code now that jax.Array is the only array we have Add support for loading checkpoints with a given layout to the array serialization library If callback returns a fully replicated global array, return it as is. Accept layout on `ShapeDtypeStruct` on the `sharding` argument. `DeviceLocalLayout.AUTO` is not allowed on SDS. Add layout support to `make_array_from_callback`. Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX Skip `test_spmd_preserves_input_sharding_vmap_grad` unless `xla_extension_version >= 258` Cache the `_check_sharding` check in device_put. If aval and sharding are the same, no need to check multiple times Merge some loops in device_put since it's trivial to do so Fix donation with kwargs. The problem is that pytrees sort dictionaries by default. So if we create the donation vector with original `kwargs` order, it won't match the aval order (which is created by sorting kwargs i.e. dict) and we end up donating the wrong input. Replace donation_vector's logic with `donation_vector_with_in_tree` which is now deleted Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove global_in_avals argument from lower_sharding_computation Add `specialize` on `jax.jit` so that we can delete the duplicate code in `jax.make_jaxpr`. Start jax and jaxlib 0.4.27 release Yue Sheng (6): Fix token management for ordered side-effects. Async dispatch expensive computations on the JAX CPU backend. Update JAX official doc: point out that the device numbers are not in numerical order because of the underlying torus hardware topology. Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be `False` by default, as we observed abnormal memory usage increases. Make `core.Token` a non-trivial class which wraps a `jax.Array`. Currently, we use a singleton and empty `core.token` object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency. Unify token lowering in JAX to always use `stablehlo.token`. Yunlong Liu (1): Adds meaningful function names for better debugging. carlosgmartin (4): Add jax.nn.mish. Add where argument to logsumexp. Let initial=-jnp.inf by default for nn.softmax and nn.log_softmax. Let xs=None by default in lax.scan. dependabot[bot] (3): Bump actions/download-artifact from 4.1.4 to 4.1.6 Bump actions/upload-artifact from 4.3.1 to 4.3.3 Bump actions/download-artifact from 4.1.6 to 4.1.7 jax authors (69): Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Make layout on Array a property instead of a cached_property. Update XLA dependency to use revision [Pallas TPU] Pallas while loop -> fori test. Update XLA dependency to use revision Jax persistent compilation cache user guide. [Pallas TPU] Convert pattern_match_while_to_fori_loop to return (Jaxpr, str) rather than throw exceptions. Update XLA dependency to use revision Update XLA dependency to use revision [Pallas TPU] Add missing Mosaic lowering rules for float comparisons, tests. Update XLA dependency to use revision [Pallas] Global Barrier bug fix. [Pallas TPU] Raise clearer NotImplementedError on vector -> scalar reductions. Update XLA dependency to use revision Update XLA dependency to use revision [XLA] Clear derived instruction's sharding only if shapes are incompatible. Update XLA dependency to use revision fix matrix dimension and block shape. Update XLA dependency to use revision [Mosaic] Support scf.while and scf.condition. Update XLA dependency to use revision Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d Update XLA dependency to use revision Document the fact that jax.clear_caches() doesn't affect the persistent cache. Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero. [Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic. Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Add some docstrings for remote DMAs and semaphore barriers. Return early from eigh for small matrices. This was accidentally removed cl/577222219. Update XLA dependency to use revision Only perform checks on slice sizes if they're static. Update XLA dependency to use revision Update XLA dependency to use revision Reverts 6bfbb4593a42fced91ba50de47271af425c74c20 Update XLA dependency to use revision Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops. Improve performance of SVD when batched by avoiding several cond() constructs. Adds rewrite patterns for `arith.{cmpi,select}` and `tensor.splat` as sources to a vector.transfer_read op. Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Use the `windows-2022` image for running the `windows_ci` workflow. Point the `windows_ci` workflow to the correct VC directory. Fix example in mosaic tpu dialect layout.h Try fixing the MSVC path for `windows_ci` workflow once more. Return to the `win-2019` image for the `windows_ci` workflow. [XLA:CPU] Enable constant host offloading Update XLA dependency to use revision [Pallas TPU] Print the exception when a lowering exception occurs. Allow multiple indexers when doing discharge or swap in pallas Update XLA dependency to use revision Change determination of cloud TPU to check for TPU chips. [Pallas TPU] Increase clarity of dot 2D shape enforcement error. Update XLA dependency to use revision Update XLA dependency to use revision Update XLA dependency to use revision Don't create temp directory when module is getting imported. Pass `bazel_options` directly to the Bazel command, instead of into .bazelrc. Add a config for using Clang on Windows. Add an experimental, Clang version of the Windows CI job. Fix that the insufficient output HBM buffer init would cause the <unk> token generated for quantized int8 model. Update XLA dependency to use revision kaixih (1): Support BNTH input formats rajasekharporeddy (9): Update jax.scipy.special.ndtri to return NaN for the values outside of the range [0, 1] Fix doc typos Fix Typos in docs DOC: Fix docstring typos in scipy special functions Fix doc Typos Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k Fix jax.scipy.stats.beta.logpdf to emulate scipy.stats.beta.logpdf Fix Typos and math rendering in jax.random docs Fix typos in docs and an error message
This is the first part of the restructuring in #18585
The main change here is that the older
jax-101
tutorials are replaced by newer, updated discussions, with redirects from the old URLs.Preview here: https://jax--20819.org.readthedocs.build/en/20819/
Find the main new content listed here: https://jax--20819.org.readthedocs.build/en/20819/tutorials.html, and linked in the side-bar.
I've added the new
sphinxext-rediraffe
extension to add automatic redirects for the old URLs to the replacement content. Some examples of the new redirects in action: