diff --git a/.bazelrc b/.bazelrc
index 9d5d9664939e..948d92c29c26 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -215,6 +215,8 @@ build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base
build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base
build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
+build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base
+build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
@@ -237,6 +239,8 @@ build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base
build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base
build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
+build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base
+build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml
index 5d46f8fbf0d8..0f90cd72e463 100644
--- a/.github/workflows/ci-build.yaml
+++ b/.github/workflows/ci-build.yaml
@@ -20,16 +20,15 @@ permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
jobs:
lint_and_typecheck:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- - name: Cancel previous
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -58,11 +57,6 @@ jobs:
prng-upgrade: 0
num_generated_cases: 1
steps:
- - name: Cancel previous
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -110,11 +104,6 @@ jobs:
matrix:
python-version: ['3.10']
steps:
- - name: Cancel previous
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -152,11 +141,6 @@ jobs:
matrix:
python-version: ['3.10']
steps:
- - name: Cancel previous
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -193,11 +177,6 @@ jobs:
enable-x64: 0
num_generated_cases: 10
steps:
- - name: Cancel previous
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml
index cdba39b3642a..cbe383f21ffe 100644
--- a/.github/workflows/jax-array-api.yml
+++ b/.github/workflows/jax-array-api.yml
@@ -9,6 +9,10 @@ on:
- '**workflows/jax-array-api.yml'
- '**experimental/array_api/**'
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
jobs:
build:
diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml
index 0c739619df1a..75f4bba1a367 100644
--- a/.github/workflows/metal_plugin_ci.yml
+++ b/.github/workflows/metal_plugin_ci.yml
@@ -11,6 +11,10 @@ on:
paths:
- '**workflows/metal_plugin_ci.yml'
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
jobs:
jax-metal-plugin-test:
diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml
index 61912ed8978e..367f8e05bf56 100644
--- a/.github/workflows/wheel_win_x64.yml
+++ b/.github/workflows/wheel_win_x64.yml
@@ -2,6 +2,10 @@ name: Wheel build - Windows CPU x86_64
on:
workflow_dispatch: # allows triggering the workflow run manually
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
env:
DISTUTILS_USE_SDK: 1
MSSdk: 1
@@ -13,16 +17,11 @@ jobs:
matrix:
os: [windows-2019-32core]
arch: [AMD64]
- pyver: ['3.10', '3.11', '3.12']
+ pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2']
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
runs-on: ${{ matrix.os }}
steps:
- - name: Cancel Previous Runs
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
-
- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
@@ -58,7 +57,7 @@ jobs:
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
+ python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
python -m pip install -e ${{ github.workspace }}
- python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml
index 42083f1d087d..194cac6fa79a 100644
--- a/.github/workflows/windows_ci.yml
+++ b/.github/workflows/windows_ci.yml
@@ -6,6 +6,10 @@ on:
pull_request:
types: [ labeled ] # allow force-windows-run label
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
env:
DISTUTILS_USE_SDK: 1
MSSdk: 1
@@ -23,10 +27,6 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- - name: Cancel Previous Runs
- uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
- with:
- access_token: ${{ github.token }}
- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 869b9dfdd196..ee782d04a02c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,48 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->
-## jax 0.4.32
+## jax 0.4.34
+
+* New Functionality
+ * This release includes wheels for Python 3.13. Free-threading mode is not yet
+ supported.
+
+* Deprecations
+ * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
+ arguments with `ndim != 1` are now deprecated, and in the future will result
+ in an error.
+
+* Deletion:
+  * `jax.xla_computation` is deleted. It has been 3 months since its
+    deprecation in the 0.4.30 JAX release.
+ Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
+ * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
+ `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
+    * You can also use the `.out_info` property of `jax.stages.Lowered` to get
+      the output information (such as the tree structure, shape, and dtype).
+ * For cross-backend lowering, you can replace
+ `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
+ `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
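+
+    For example, a minimal migration sketch (using a hypothetical function
+    `f`; the replacement call chains are the ones listed above):
+
+    ```python
+    import jax
+    import jax.numpy as jnp
+
+    def f(x):
+      return jnp.sin(x) ** 2
+
+    x = jnp.arange(4.0)
+
+    # Previously: jax.xla_computation(f)(x)
+    lowered = jax.jit(f).lower(x)
+    hlo = lowered.compiler_ir('hlo')  # HLO computation, as before
+    print(lowered.out_info)           # output tree structure, shape, and dtype
+
+    # Previously: jax.xla_computation(f, backend='tpu')(x)
+    tpu_hlo = jax.jit(f).trace(x).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
+    ```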
+  * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
+    The argument was only used by `xmap`, which was removed in 0.4.31.
+
+
+## jax 0.4.33 (September 16, 2024)
+
+This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
+release.
+
+A TPU-only data corruption bug was found in the version of libtpu pinned by
+JAX 0.4.32. It manifested only if multiple TPU slices were present in the
+same job, for example when training on multiple v5e slices. This release
+fixes that issue by pinning a version of `libtpu` that contains the fix.
+
+This release fixes an inaccurate result for F64 tanh on CPU (#23590).
+
+## jax 0.4.32 (September 11, 2024)
+
+Note: This release was yanked from PyPI because of a data corruption bug on TPU.
+See the 0.4.33 release notes for more details.
* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
@@ -18,6 +59,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
C++ and CUDA code from JAX.
* Changes
+  * The `jax_pmap_no_rank_reduction` flag is set to `True` by default.
+    * `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
+      instead).
+    * The per-shard shape (accessible via `jax_array.addressable_shards` or
+      `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code
+      that directly accesses shards accordingly. The rank of the per-shard shape
+      now matches that of the global shape, which is the same behavior as `jit`.
+      This avoids costly reshapes when passing results from `pmap` into `jit`.
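+
+    A small illustrative sketch of the new per-shard shapes (assumes `n`
+    local devices; the variable names are for illustration only):
+
+    ```python
+    import jax
+    import jax.numpy as jnp
+
+    n = jax.local_device_count()
+    y = jax.pmap(lambda x: x * 2)(jnp.arange(n * 3.0).reshape(n, 3))
+
+    print(y.shape)                      # global shape: (n, 3)
+    print(y.addressable_data(0).shape)  # per-shard shape: (1, 3), not (3,)
+
+    first = y[0:1]                      # keeps the rank; y[0] would add a reshape
+    ```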
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
@@ -65,9 +114,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.
-## jaxlib 0.4.32
+## jaxlib 0.4.32 (September 11, 2024)
+
+Note: This release was yanked from PyPI because of a data corruption bug on TPU.
+See the 0.4.33 release notes for more details.
* Breaking changes
+ * This release of jaxlib switched to a new version of the CPU backend, which
+ should compile faster and leverage parallelism better. If you experience
+ any problems due to this change, you can temporarily enable the old CPU
+ backend by setting the environment variable
+ `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
+ please file a JAX bug with instructions to reproduce.
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
diff --git a/WORKSPACE b/WORKSPACE
index 383adf810766..ed284acadf81 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -37,9 +37,10 @@ install_deps()
load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter")
custom_python_interpreter(
name = "python_dev",
- urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"],
- strip_prefix = "Python-{version}",
- version = "3.13.0a6",
+ urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"],
+ strip_prefix = "Python-{version_variant}",
+ version = "3.13.0",
+ version_variant = "3.13.0rc2",
)
load("@xla//:workspace4.bzl", "xla_workspace4")
diff --git a/build/requirements.in b/build/requirements.in
index add6b8577350..a8d81fa5c670 100644
--- a/build/requirements.in
+++ b/build/requirements.in
@@ -11,14 +11,17 @@ matplotlib; python_version>="3.11"
#
# build deps
#
-numpy~=2.0.0
+numpy~=2.0.0; python_version<="3.12"
+numpy~=2.1.0; python_version>="3.13"
#
# runtime deps
#
-scipy~=1.13.1
+scipy>=1.13.1
ml_dtypes>=0.4.0
opt_einsum
zstandard
etils[epath]
+# TODO(ybaturina): remove setuptools version
+setuptools<71.0.0
diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt
index 62b5e14e65b4..e2369a8001bb 100644
--- a/build/requirements_lock_3_13.txt
+++ b/build/requirements_lock_3_13.txt
@@ -1,52 +1,423 @@
#
-# This file is autogenerated by pip-compile with Python 3.12
+# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
-# bazel run //build:requirements_dev.update
+# bazel run //build:requirements.update
#
-absl-py==2.1.0
+absl-py==2.1.0 \
+ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \
+ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
# via -r build/test-requirements.txt
-attrs==23.2.0
+attrs==24.2.0 \
+ --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \
+ --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2
# via hypothesis
-build==1.2.1
+build==1.2.2 \
+ --hash=sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c \
+ --hash=sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613
# via -r build/test-requirements.txt
-cloudpickle==3.0.0
+cloudpickle==3.0.0 \
+ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \
+ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882
# via -r build/test-requirements.txt
-colorama==0.4.6
+colorama==0.4.6 \
+ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \
+ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6
# via -r build/test-requirements.txt
-contourpy==1.2.1
+contourpy==1.3.0 \
+ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \
+ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \
+ --hash=sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd \
+ --hash=sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad \
+ --hash=sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843 \
+ --hash=sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8 \
+ --hash=sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4 \
+ --hash=sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1 \
+ --hash=sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294 \
+ --hash=sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84 \
+ --hash=sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927 \
+ --hash=sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8 \
+ --hash=sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09 \
+ --hash=sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7 \
+ --hash=sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f \
+ --hash=sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab \
+ --hash=sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b \
+ --hash=sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3 \
+ --hash=sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223 \
+ --hash=sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973 \
+ --hash=sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087 \
+ --hash=sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081 \
+ --hash=sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc \
+ --hash=sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18 \
+ --hash=sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f \
+ --hash=sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d \
+ --hash=sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2 \
+ --hash=sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41 \
+ --hash=sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67 \
+ --hash=sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6 \
+ --hash=sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b \
+ --hash=sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2 \
+ --hash=sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c \
+ --hash=sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42 \
+ --hash=sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d \
+ --hash=sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4 \
+ --hash=sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5 \
+ --hash=sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49 \
+ --hash=sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b \
+ --hash=sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7 \
+ --hash=sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102 \
+ --hash=sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb \
+ --hash=sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7 \
+ --hash=sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e \
+ --hash=sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c \
+ --hash=sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8 \
+ --hash=sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35 \
+ --hash=sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b \
+ --hash=sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14 \
+ --hash=sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb \
+ --hash=sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589 \
+ --hash=sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c \
+ --hash=sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0 \
+ --hash=sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da \
+ --hash=sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800 \
+ --hash=sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6 \
+ --hash=sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66 \
+ --hash=sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca \
+ --hash=sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb \
+ --hash=sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c \
+ --hash=sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06 \
+ --hash=sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779 \
+ --hash=sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8 \
+ --hash=sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f \
+ --hash=sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c
# via matplotlib
-cycler==0.12.1
+cycler==0.12.1 \
+ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
+ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
# via matplotlib
-etils[epath,epy]==1.8.0
+etils[epath,epy]==1.9.4 \
+ --hash=sha256:4387e7a4911a3b5cc4b92b99a9211386d176b43bae1dac8e2fe345fc2cb95e4b \
+ --hash=sha256:fad950414f0a1ca58c70c70915b0014f9953dd9bcf8aa951a0f75ff9becbeb24
# via -r build/requirements.in
-execnet==2.1.1
+execnet==2.1.1 \
+ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \
+ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3
# via pytest-xdist
-flatbuffers==24.3.25
+filelock==3.16.0 \
+ --hash=sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec \
+ --hash=sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609
# via -r build/test-requirements.txt
-fonttools==4.51.0
+flatbuffers==24.3.25 \
+ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \
+ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4
+ # via -r build/test-requirements.txt
+fonttools==4.53.1 \
+ --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \
+ --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \
+ --hash=sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f \
+ --hash=sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d \
+ --hash=sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60 \
+ --hash=sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169 \
+ --hash=sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8 \
+ --hash=sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31 \
+ --hash=sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923 \
+ --hash=sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2 \
+ --hash=sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb \
+ --hash=sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab \
+ --hash=sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb \
+ --hash=sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a \
+ --hash=sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670 \
+ --hash=sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8 \
+ --hash=sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407 \
+ --hash=sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671 \
+ --hash=sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88 \
+ --hash=sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f \
+ --hash=sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f \
+ --hash=sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0 \
+ --hash=sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb \
+ --hash=sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2 \
+ --hash=sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d \
+ --hash=sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c \
+ --hash=sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3 \
+ --hash=sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719 \
+ --hash=sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749 \
+ --hash=sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4 \
+ --hash=sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f \
+ --hash=sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02 \
+ --hash=sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58 \
+ --hash=sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1 \
+ --hash=sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41 \
+ --hash=sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4 \
+ --hash=sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb \
+ --hash=sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb \
+ --hash=sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3 \
+ --hash=sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d \
+ --hash=sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d \
+ --hash=sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2
# via matplotlib
-fsspec==2024.3.1
+fsspec==2024.9.0 \
+ --hash=sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8 \
+ --hash=sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b
# via etils
-hypothesis==6.100.1
+hypothesis==6.112.1 \
+ --hash=sha256:93631b1498b20d2c205ed304cbd41d50e9c069d78a9c773c1324ca094c5e30ce \
+ --hash=sha256:b070d7a1bb9bd84706c31885c9aeddc138e2b36a9c112a91984f49501c567856
# via -r build/test-requirements.txt
-importlib-resources==6.4.0
+importlib-resources==6.4.5 \
+ --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \
+ --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717
# via etils
-iniconfig==2.0.0
+iniconfig==2.0.0 \
+ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \
+ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374
# via pytest
-kiwisolver==1.4.5
+kiwisolver==1.4.7 \
+ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \
+ --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \
+ --hash=sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5 \
+ --hash=sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0 \
+ --hash=sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d \
+ --hash=sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18 \
+ --hash=sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b \
+ --hash=sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258 \
+ --hash=sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95 \
+ --hash=sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e \
+ --hash=sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383 \
+ --hash=sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02 \
+ --hash=sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b \
+ --hash=sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523 \
+ --hash=sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee \
+ --hash=sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88 \
+ --hash=sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd \
+ --hash=sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb \
+ --hash=sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4 \
+ --hash=sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e \
+ --hash=sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c \
+ --hash=sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935 \
+ --hash=sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee \
+ --hash=sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e \
+ --hash=sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038 \
+ --hash=sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d \
+ --hash=sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b \
+ --hash=sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5 \
+ --hash=sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107 \
+ --hash=sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f \
+ --hash=sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2 \
+ --hash=sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17 \
+ --hash=sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb \
+ --hash=sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674 \
+ --hash=sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706 \
+ --hash=sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327 \
+ --hash=sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3 \
+ --hash=sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a \
+ --hash=sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2 \
+ --hash=sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f \
+ --hash=sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948 \
+ --hash=sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3 \
+ --hash=sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e \
+ --hash=sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545 \
+ --hash=sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc \
+ --hash=sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f \
+ --hash=sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650 \
+ --hash=sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a \
+ --hash=sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8 \
+ --hash=sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750 \
+ --hash=sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b \
+ --hash=sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34 \
+ --hash=sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225 \
+ --hash=sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51 \
+ --hash=sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c \
+ --hash=sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3 \
+ --hash=sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde \
+ --hash=sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599 \
+ --hash=sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c \
+ --hash=sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76 \
+ --hash=sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6 \
+ --hash=sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39 \
+ --hash=sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9 \
+ --hash=sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933 \
+ --hash=sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad \
+ --hash=sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520 \
+ --hash=sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1 \
+ --hash=sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503 \
+ --hash=sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b \
+ --hash=sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36 \
+ --hash=sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a \
+ --hash=sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643 \
+ --hash=sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60 \
+ --hash=sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483 \
+ --hash=sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf \
+ --hash=sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d \
+ --hash=sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6 \
+ --hash=sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644 \
+ --hash=sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2 \
+ --hash=sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9 \
+ --hash=sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2 \
+ --hash=sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640 \
+ --hash=sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade \
+ --hash=sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a \
+ --hash=sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c \
+ --hash=sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6 \
+ --hash=sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00 \
+ --hash=sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27 \
+ --hash=sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2 \
+ --hash=sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4 \
+ --hash=sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379 \
+ --hash=sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54 \
+ --hash=sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09 \
+ --hash=sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a \
+ --hash=sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c \
+ --hash=sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89 \
+ --hash=sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407 \
+ --hash=sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904 \
+ --hash=sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376 \
+ --hash=sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583 \
+ --hash=sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278 \
+ --hash=sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a \
+ --hash=sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d \
+ --hash=sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935 \
+ --hash=sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb \
+ --hash=sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895 \
+ --hash=sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b \
+ --hash=sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417 \
+ --hash=sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608 \
+ --hash=sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07 \
+ --hash=sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05 \
+ --hash=sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a \
+ --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \
+ --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052
# via matplotlib
-markdown-it-py==3.0.0
+markdown-it-py==3.0.0 \
+ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \
+ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb
# via rich
-matplotlib==3.8.3
+matplotlib==3.9.2 ; python_version >= "3.11" \
+ --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \
+ --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \
+ --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \
+ --hash=sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9 \
+ --hash=sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca \
+ --hash=sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64 \
+ --hash=sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e \
+ --hash=sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03 \
+ --hash=sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae \
+ --hash=sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa \
+ --hash=sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3 \
+ --hash=sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e \
+ --hash=sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a \
+ --hash=sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc \
+ --hash=sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea \
+ --hash=sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b \
+ --hash=sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e \
+ --hash=sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447 \
+ --hash=sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b \
+ --hash=sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92 \
+ --hash=sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb \
+ --hash=sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66 \
+ --hash=sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9 \
+ --hash=sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7 \
+ --hash=sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2 \
+ --hash=sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30 \
+ --hash=sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d \
+ --hash=sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7 \
+ --hash=sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4 \
+ --hash=sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41 \
+ --hash=sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2 \
+ --hash=sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556 \
+ --hash=sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f \
+ --hash=sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772 \
+ --hash=sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c \
+ --hash=sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a \
+ --hash=sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51 \
+ --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \
+ --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \
+ --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413
# via -r build/requirements.in
-mdurl==0.1.2
+mdurl==0.1.2 \
+ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
+ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
-ml-dtypes==0.4.0
+ml-dtypes==0.5.0 \
+ --hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \
+ --hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \
+ --hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \
+ --hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \
+ --hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \
+ --hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \
+ --hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \
+ --hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \
+ --hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \
+ --hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \
+ --hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \
+ --hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \
+ --hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \
+ --hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \
+ --hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \
+ --hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \
+ --hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \
+ --hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \
+ --hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \
+ --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \
+ --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599
# via -r build/requirements.in
-mpmath==1.3.0
+mpmath==1.4.0a1 \
+ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
+ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1
# via -r build/test-requirements.txt
-numpy==1.26.4
+numpy==2.1.1 ; python_version >= "3.13" \
+ --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \
+ --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \
+ --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \
+ --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \
+ --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \
+ --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \
+ --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \
+ --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \
+ --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \
+ --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \
+ --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \
+ --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \
+ --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \
+ --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \
+ --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \
+ --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \
+ --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \
+ --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \
+ --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \
+ --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \
+ --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \
+ --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \
+ --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \
+ --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \
+ --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \
+ --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \
+ --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \
+ --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \
+ --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \
+ --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \
+ --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \
+ --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \
+ --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \
+ --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \
+ --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \
+ --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \
+ --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \
+ --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \
+ --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \
+ --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \
+ --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \
+ --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \
+ --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \
+ --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \
+ --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \
+ --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \
+ --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \
+ --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \
+ --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \
+ --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \
+ --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \
+ --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \
+ --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b
# via
# -r build/requirements.in
# -r build/test-requirements.txt
@@ -55,52 +426,315 @@ numpy==1.26.4
# ml-dtypes
# opt-einsum
# scipy
-opt-einsum==3.3.0
+opt-einsum==3.3.0 \
+ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
+ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
# via -r build/requirements.in
-packaging==24.0
+packaging==24.1 \
+ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \
+ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124
# via
# build
# matplotlib
# pytest
-pillow==10.3.0
+pillow==10.4.0 \
+ --hash=sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885 \
+ --hash=sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea \
+ --hash=sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df \
+ --hash=sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5 \
+ --hash=sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c \
+ --hash=sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d \
+ --hash=sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd \
+ --hash=sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06 \
+ --hash=sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908 \
+ --hash=sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a \
+ --hash=sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be \
+ --hash=sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0 \
+ --hash=sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b \
+ --hash=sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80 \
+ --hash=sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a \
+ --hash=sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e \
+ --hash=sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9 \
+ --hash=sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696 \
+ --hash=sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b \
+ --hash=sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309 \
+ --hash=sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e \
+ --hash=sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab \
+ --hash=sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d \
+ --hash=sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060 \
+ --hash=sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d \
+ --hash=sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d \
+ --hash=sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4 \
+ --hash=sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3 \
+ --hash=sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6 \
+ --hash=sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb \
+ --hash=sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94 \
+ --hash=sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b \
+ --hash=sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496 \
+ --hash=sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0 \
+ --hash=sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319 \
+ --hash=sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b \
+ --hash=sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856 \
+ --hash=sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef \
+ --hash=sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680 \
+ --hash=sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b \
+ --hash=sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42 \
+ --hash=sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e \
+ --hash=sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597 \
+ --hash=sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a \
+ --hash=sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8 \
+ --hash=sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3 \
+ --hash=sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736 \
+ --hash=sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da \
+ --hash=sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126 \
+ --hash=sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd \
+ --hash=sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5 \
+ --hash=sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b \
+ --hash=sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026 \
+ --hash=sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b \
+ --hash=sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc \
+ --hash=sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46 \
+ --hash=sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2 \
+ --hash=sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c \
+ --hash=sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe \
+ --hash=sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984 \
+ --hash=sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a \
+ --hash=sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70 \
+ --hash=sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca \
+ --hash=sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b \
+ --hash=sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91 \
+ --hash=sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3 \
+ --hash=sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84 \
+ --hash=sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1 \
+ --hash=sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5 \
+ --hash=sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be \
+ --hash=sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f \
+ --hash=sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc \
+ --hash=sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9 \
+ --hash=sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e \
+ --hash=sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141 \
+ --hash=sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef \
+ --hash=sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22 \
+ --hash=sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27 \
+ --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \
+ --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1
# via
# -r build/test-requirements.txt
# matplotlib
-pluggy==1.4.0
+pluggy==1.5.0 \
+ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \
+ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669
# via pytest
-portpicker==1.6.0
+portpicker==1.6.0 \
+ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \
+ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa
# via -r build/test-requirements.txt
-psutil==5.9.8
+psutil==6.0.0 \
+ --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \
+ --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \
+ --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \
+ --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \
+ --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \
+ --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \
+ --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \
+ --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \
+ --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \
+ --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \
+ --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \
+ --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \
+ --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \
+ --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \
+ --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \
+ --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \
+ --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0
# via portpicker
-pygments==2.17.2
+pygments==2.18.0 \
+ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \
+ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a
# via rich
-pyparsing==3.1.2
+pyparsing==3.2.0b1 \
+ --hash=sha256:51e00c907f7b2ac2d2c35c4d431e944c525ddcfd58b09517f308f40d70e0ddca \
+ --hash=sha256:ecf0805530839936196a802cd6d6d65ffa9328eebdc8ee5b8f4b358be5f16666
# via matplotlib
-pyproject-hooks==1.0.0
+pyproject-hooks==1.1.0 \
+ --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \
+ --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2
# via build
-pytest==8.1.1
+pytest==8.3.3 \
+ --hash=sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181 \
+ --hash=sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2
# via pytest-xdist
-pytest-xdist==3.5.0
+pytest-xdist==3.6.1 \
+ --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \
+ --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d
# via -r build/test-requirements.txt
-python-dateutil==2.9.0.post0
+python-dateutil==2.9.0.post0 \
+ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \
+ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427
# via matplotlib
-rich==13.7.1
+rich==13.8.1 \
+ --hash=sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06 \
+ --hash=sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a
# via -r build/test-requirements.txt
-scipy==1.13.1
+scipy==1.14.1 \
+ --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \
+ --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \
+ --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \
+ --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \
+ --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \
+ --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \
+ --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \
+ --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \
+ --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \
+ --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \
+ --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \
+ --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \
+ --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \
+ --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \
+ --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \
+ --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \
+ --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \
+ --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \
+ --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \
+ --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \
+ --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \
+ --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \
+ --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \
+ --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \
+ --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \
+ --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \
+ --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \
+ --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \
+ --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \
+ --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \
+ --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \
+ --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \
+ --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2
# via -r build/requirements.in
-six==1.16.0
+six==1.16.0 \
+ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \
+ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254
# via python-dateutil
-sortedcontainers==2.4.0
+sortedcontainers==2.4.0 \
+ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
+ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
-typing-extensions==4.11.0
+typing-extensions==4.12.2 \
+ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
+ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
# via etils
-wheel==0.43.0
+wheel==0.44.0 \
+ --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \
+ --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49
# via -r build/test-requirements.txt
-zipp==3.18.1
+zipp==3.20.2 \
+ --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \
+ --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29
# via etils
-zstandard==0.22.0
+zstandard==0.23.0 \
+ --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
+ --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \
+ --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \
+ --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \
+ --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \
+ --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \
+ --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \
+ --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \
+ --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \
+ --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \
+ --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \
+ --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \
+ --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \
+ --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \
+ --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \
+ --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \
+ --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \
+ --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \
+ --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \
+ --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \
+ --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \
+ --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \
+ --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \
+ --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \
+ --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \
+ --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \
+ --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \
+ --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \
+ --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \
+ --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \
+ --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \
+ --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \
+ --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \
+ --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \
+ --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \
+ --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \
+ --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \
+ --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \
+ --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \
+ --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \
+ --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \
+ --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \
+ --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \
+ --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \
+ --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \
+ --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \
+ --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \
+ --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \
+ --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \
+ --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \
+ --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \
+ --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \
+ --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \
+ --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \
+ --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \
+ --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \
+ --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \
+ --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \
+ --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \
+ --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \
+ --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \
+ --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \
+ --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \
+ --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \
+ --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \
+ --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \
+ --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \
+ --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \
+ --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \
+ --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \
+ --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \
+ --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \
+ --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \
+ --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \
+ --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \
+ --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \
+ --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \
+ --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \
+ --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \
+ --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \
+ --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \
+ --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \
+ --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \
+ --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \
+ --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \
+ --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \
+ --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \
+ --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \
+ --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \
+ --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \
+ --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \
+ --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \
+ --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \
+ --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \
+ --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \
+ --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \
+ --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69
# via -r build/requirements.in
# The following packages are considered to be unsafe in a requirements file:
-setuptools==69.2.0
- # via -r build/test-requirements.txt
+setuptools==70.3.0 \
+ --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \
+ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc
+ # via
+ # -r build/requirements.in
+ # -r build/test-requirements.txt
diff --git a/build/test-requirements.txt b/build/test-requirements.txt
index 4f9d19e76ba2..0c9aa086f109 100644
--- a/build/test-requirements.txt
+++ b/build/test-requirements.txt
@@ -12,4 +12,5 @@ portpicker
pytest-xdist
wheel
rich
-setuptools
+# TODO(ybaturina): remove setuptools version
+setuptools<71.0.0
diff --git a/docs/_static/pallas/sparse/block_coo.svg b/docs/_static/pallas/sparse/block_coo.svg
new file mode 100644
index 000000000000..474dfcb64d7a
--- /dev/null
+++ b/docs/_static/pallas/sparse/block_coo.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/_static/pallas/sparse/prefetch_map.svg b/docs/_static/pallas/sparse/prefetch_map.svg
new file mode 100644
index 000000000000..08fdd2c1cf39
--- /dev/null
+++ b/docs/_static/pallas/sparse/prefetch_map.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/_static/pallas/sparse/sparse_matmul.svg b/docs/_static/pallas/sparse/sparse_matmul.svg
new file mode 100644
index 000000000000..06a24317cfe1
--- /dev/null
+++ b/docs/_static/pallas/sparse/sparse_matmul.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 061262a1410e..6d4bc5d87854 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -134,6 +134,7 @@ def _do_not_evaluate_in_jax(
'pallas/quickstart.md',
'pallas/tpu/pipelining.md',
'pallas/tpu/distributed.md',
+ 'pallas/tpu/sparse.md',
'pallas/tpu/matmul.md',
'jep/9407-type-promotion.md',
'autodidax.md',
@@ -224,6 +225,7 @@ def _do_not_evaluate_in_jax(
'pallas/quickstart.*',
'pallas/tpu/pipelining.*',
'pallas/tpu/distributed.*',
+ 'pallas/tpu/sparse.*',
'pallas/tpu/matmul.*',
'sharded-computation.*',
'distributed_data_loading.*'
diff --git a/docs/developer.md b/docs/developer.md
index 954cf7982a3a..40ad51e873ca 100644
--- a/docs/developer.md
+++ b/docs/developer.md
@@ -31,23 +31,33 @@ guidance on pip installation (e.g., for GPU and TPU support).
### Building `jaxlib` from source
+```{warning}
+While it should typically be possible to compile `jaxlib` from source using
+most modern compilers, the builds are only tested using clang. Pull requests
+are welcome to improve support for different toolchains, but other compilers
+are not actively supported.
+```
+
To build `jaxlib` from source, you must also install some prerequisites:
-- a C++ compiler (g++, clang, or MSVC)
+- A C++ compiler:
- On Ubuntu or Debian you can install the necessary prerequisites with:
+ As mentioned in the box above, it is best to use a recent version of clang
+ (at the time of writing, the version we test is 18), but other compilers (e.g.
+ g++ or MSVC) may work.
- ```
- sudo apt install g++ python python3-dev
- ```
+ On Ubuntu or Debian you can follow the instructions from the
+ [LLVM](https://apt.llvm.org/) documentation to install the latest stable
+ version of clang.
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
See below for Windows build instructions.
-- there is no need to install Python dependencies locally, as your system
- Python will be ignored during the build; please check
+- Python: for running the build helper script. Note that there is no need to
+ install Python dependencies locally, as your system Python will be ignored
+ during the build; please check
[Managing hermetic Python](#managing-hermetic-python) for details.
To build `jaxlib` for CPU or TPU, you can run:
@@ -86,7 +96,7 @@ the `build/build.py` script itself will be processed by your system Python
interpreter. By default, the wheel is written to the `dist/` subdirectory of the
current directory.
-* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN
+* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN
versions in the configuration options. Bazel will download them and use as
target dependencies.
@@ -108,6 +118,8 @@ current directory.
--bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"
```
+ Please see the full list of instructions in [XLA documentation](https://github.com/openxla/xla/blob/main/docs/hermetic_cuda.md).
+
* JAX versions prior v.0.4.32: you must have CUDA and CUDNN installed and
provide paths to them using configuration options.
@@ -257,8 +269,8 @@ together with their corresponding hashes are specified in
`build/requirements_lock_.txt` files (
e.g. `build/requirements_lock_3_12.txt` for `Python 3.12`).
-To update the lock files, make sure `build/requirements.in` contains the desired
-direct dependencies list and then execute the following command (which will call
+To update the lock files, make sure `build/requirements.in` contains the desired
+direct dependencies list and then execute the following command (which will call
[pip-compile](https://pypi.org/project/pip-tools/) under the hood):
```
@@ -380,7 +392,7 @@ sudo apt-get install libopenblas-dev -y
example for `Python 3.13` it should have something
like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the
`requirements` parameter must always be in `"major.minor"` version format, so
- even if you are building Python version `3.13.0rc1` the corresponding
+ even if you are building Python version `3.13.0rc1` the corresponding
`requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`,
**not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`.
@@ -695,7 +707,7 @@ using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync`
notebooks; for example:
```
-pip install jupytext==1.16.0
+pip install jupytext==1.16.4
jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```
diff --git a/docs/export/export.md b/docs/export/export.md
index 9e6597cef49b..0ca1a64800e0 100644
--- a/docs/export/export.md
+++ b/docs/export/export.md
@@ -732,10 +732,7 @@ that live in jaxlib):
from jax._src.lib import version as jaxlib_version
def my_lowering_rule(ctx: LoweringRuleContext, ...):
- lowering_parameters = ctx.module_context.lowering_parameters
- forward_compat_mode = (lowering_parameters.for_export and
- not lowering_parameters.export_ignore_forward_compatibility)
- if forward_compat_mode or jaxlib_version < (0, 4, 31):
+ if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31):
# this is the old lowering, using target T, while we
# are in forward compatibility mode for T, or we
# are in OSS and are using an old jaxlib.
diff --git a/docs/installation.md b/docs/installation.md
index 7a12f7c541a2..93df4a240a55 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -282,13 +282,13 @@ pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/j
- Google Cloud TPU:
```bash
-pip install -U --pre jax[tpu] jaxlib libtpu-nightly -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
- NVIDIA GPU (CUDA 12):
```bash
-pip install -U --pre jax[cuda12] jaxlib jax-cuda12-plugin jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```
- NVIDIA GPU (CUDA 12) legacy:
diff --git a/docs/jax.rst b/docs/jax.rst
index b2c4ba60739b..a8781d31a448 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`)
jit
disable_jit
ensure_compile_time_eval
- xla_computation
make_jaxpr
eval_shape
ShapeDtypeStruct
diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index c1ed1385bbbc..43ba3ebd6afb 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -11,15 +11,31 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->
-## Released with jax 0.4.32
+## Released with jax 0.4.34
* Changes
- * The kernel function is not allowed to close over constants. Instead, all the needed arrays
- must be passed as inputs, with proper block specs ({jax-issue}`#22746`).
+
+ * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
+ to be scalars. The restrictions on the arguments are backend-specific:
+ Non-scalar arguments are currently only supported on GPU, when using Triton.
* Deprecations
-* New functionality:
+* New functionality
+
+ * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,
+ a PyTree specifying backend-specific temporary objects needed by the
+ kernel, for example, buffers, synchronization primitives etc.
+
+## Released with jax 0.4.33 (September 16, 2024)
+
+## Released with jax 0.4.32 (September 11, 2024)
+
+* Changes
+ * The kernel function is not allowed to close over constants. Instead, all the needed arrays
+ must be passed as inputs, with proper block specs ({jax-issue}`#22746`).
+
+* New functionality
* Improved error messages for mistakes in the signature of the index map functions,
to include the name and source location of the index map.
@@ -44,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list
* Previously it was possible to import many APIs that are meant to be
private, as `jax.experimental.pallas.pallas`. This is not possible anymore.
-
-* Deprecations
-
-
* New Functionality
* Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`.
* Improved error messages for the {func}`jax.experimental.pallas.pallas_call`
@@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list
* Added checkify support for {func}`jax.experimental.pallas.pallas_call` in
interpret mode ({jax-issue}`#21862`).
* Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`).
-
-
-
-
diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md
new file mode 100644
index 000000000000..96370ee48625
--- /dev/null
+++ b/docs/pallas/async_note.md
@@ -0,0 +1,675 @@
+# Pallas Async Operations
+
+## Background + Motivation
+
+We’d like to expose APIs in Pallas to explicitly overlap computation and communication *across multiple kernels*.
+
+### XLA Async Decomposition
+
+As motivation, consider the following JAX pseudocode:
+
+```py
+def f(x):
+ y = ppermute(x)
+ z = x + 1
+ return y, z
+```
+
+In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by:
+
+1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future.
+2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`,
+
+resulting in the following program:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1 # happens at the same time as ppermute
+ y = ppermute_done(fut)
+ return y, z
+```
+
+### Async ops inside kernels
+
+Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas `ppermute`.
+
+```py
+def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
+ right_neighbor = ...
+ descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
+ descriptor.start()
+ descriptor.wait_send()
+ descriptor.wait_recv()
+
+def ppermute(x):
+ return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)
+```
+
+Currently, we cannot decompose `ppermute` into a `start/done` pair as XLA does, so instead we explicitly **fuse** the `x + 1` into the kernel.
+
+```py
+def add_one(x_ref, z_ref):
+ z_ref[...] = x_ref[...] + 1
+
+def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
+ right_neighbor = ...
+ descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
+ descriptor.start()
+
+ # Explicitly schedule inner kernel between start/wait
+ pltpu.emit_pipeline(add_one)(x_ref, z_ref)
+
+ descriptor.wait_send()
+ descriptor.wait_recv()
+
+def ppermute_and_add_one(x):
+ return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)
+
+```
+
+The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone.
+
+## How do we implement decomposed Pallas async operations (on TPU)?
+
+The main thing to figure out when implementing decomposed async operations in Pallas is what the `future` that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background.
+
+If we look at the Pallas code, we can see that we need a “descriptor” to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well, kinda. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` lets us wait until a device is done sending data to its neighbor, and `recv_sem` tracks the data transfer sent to a device from its neighbor. If we imagine writing a start kernel and a done kernel, all we’d need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores.
+
+We can do this by extending Pallas to support returning semaphores from kernels.
+
+```py
+def ppermute_start_kernel(
+ in_ref, send_sem, recv_sem, out_ref, *, axis_name,
+):
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
+ pltpu.semaphore_wait(barrier_sem, 1)
+ pltpu.make_async_remote_copy(
+ in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
+ ).start()
+
+def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
+ send_sem, recv_sem, out = pl.pallas_call(
+ functools.partial(ppermute_start_kernel, axis_name=axis_name),
+ out_shape=(
+ pltpu.SemaphoreType.DMA(()),
+ pltpu.SemaphoreType.DMA(()),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ),
+ )(x)
+ return send_sem, recv_sem, out
+```
+
+Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA.
+
+Another thing to notice is that we return the output buffer `out` from the start kernel *while it’s being actively copied into*.
+
+Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore.
+
+```py
+def ppermute_done_kernel(ref, send_sem, recv_sem, _):
+ pltpu.make_async_copy(ref, ref, send_sem).wait()
+ pltpu.make_async_copy(ref, ref, recv_sem).wait()
+
+def ppermute_done(send_sem, recv_sem, out) -> Array:
+ out = pl.pallas_call(
+ ppermute_done_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(
+ out.shape,
+ dtype=out.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={0:0}
+ )(out, send_sem, recv_sem)
+ return out
+```
+
+Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`.
+
+We now can implement the decomposed collective permute.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1 # happens at the same time as ppermute
+ y = ppermute_done(fut)
+ return y, z
+```
+
+***OR CAN WE?***
+
+## Why *doesn’t* this work?
+
+There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level.
+
+1. Scheduling - just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX.
+2. Lifetimes - XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x -> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory.
+3. Defensive copies - XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness.
+
+We will go over these issues one by one and suggest fixes.
+
+### Scheduling
+
+How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas-specific problem; if we had async ops implemented via some alternative mechanism, we’d still run into this.
+
+One way is to introduce an optimization barrier into the XLA program. An optimization barrier prevents XLA from moving ops across it.
+
+Here’s our original code:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+XLA could choose to execute `x + 1` in any of three places:
+
+```py
+def f(x):
+ z = x + 1
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y, z
+
+# OR
+
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+
+# OR
+
+def f(x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ z = x + 1
+ return y, z
+```
+
+To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`.
+
+We also introduce a dependency that forces the output value `y` to depend on `z`.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ x, fut = optimization_barrier((x, fut)) # x now depends on fut
+ z = x + 1
+ z, fut = optimization_barrier((z, fut)) # fut now depends on z
+ y = ppermute_done(fut)
+ return y, z
+```
+
+`optimization_barrier` is a good enough hammer for us to explicitly write out schedules.
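+
+For concreteness, here is a minimal sketch of this pattern written against
+`jax.lax.optimization_barrier` (the form in which JAX currently exposes this primitive);
+the `jnp.sin` call is just a stand-in for whatever compute we want pinned in place:
+
+```py
+import jax
+import jax.numpy as jnp
+
+def g(x):
+  y = jnp.sin(x)  # some other compute we want scheduled at this point
+  # Semantically the identity on the tuple, but each output now depends on
+  # both inputs, so XLA cannot compute the `x + 1` below before `y`.
+  x, y = jax.lax.optimization_barrier((x, y))
+  return x + 1, y
+```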
+
+### Lifetimes
+
+Let’s look at our original code again and assume the ops are happening in the correct order.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+Let’s look at which point in the program XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ # XLA can free x here!
+ y = ppermute_done(fut)
+ return y, z
+```
+
+If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1`, which means that if `x` is freed, the `ppermute` will be reading from garbage memory!
+
+How do we extend `x`’s lifetime to the `ppermute_done`? Well, we can introduce a data dependency! We need to modify our kernels a little bit to make this happen.
+
+First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel.
+
+```py
+def ppermute_start_kernel(
+ in_ref, send_sem, recv_sem, _, out_ref, *, axis_name,
+):
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
+ pltpu.semaphore_wait(barrier_sem, 1)
+ pltpu.make_async_remote_copy(
+ in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
+ ).start()
+
+def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
+ send_sem, recv_sem, x, out = pl.pallas_call(
+ functools.partial(ppermute_start_kernel, axis_name=axis_name),
+ out_shape=(
+ pltpu.SemaphoreType.DMA(()),
+ pltpu.SemaphoreType.DMA(()),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ),
+ input_output_aliases={0:2}
+ )(x)
+ return send_sem, recv_sem, x, out
+```
+
+We then have `ppermute_done` take in `x` and do nothing with it.
+
+```py
+def ppermute_done_kernel(x_ref, ref, send_sem, recv_sem, _):
+ del x_ref # Unused: only passed in to extend x's lifetime to this kernel.
+ pltpu.make_async_copy(ref, ref, send_sem).wait()
+ pltpu.make_async_copy(ref, ref, recv_sem).wait()
+
+def ppermute_done(send_sem, recv_sem, x, out) -> Array:
+ out = pl.pallas_call(
+ ppermute_done_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(
+ out.shape,
+ dtype=out.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={1:0}
+ )(x, out, send_sem, recv_sem)
+ return out
+
+```
+
+Now when we write
+
+```py
+def f(x):
+ *sems, x, out = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(*sems, x, out)
+ return y, z
+```
+
+XLA can no longer free `x` because it is an input to `ppermute_done`! This means that `x`’s lifetime is tied to the `ppermute` and this code is now correct.
+
+### Defensive copies
+
+XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input.
+
+#### Background
+
+Here’s a simple example. Let’s say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place.
+
+The following code would be legal.
+
+```py
+def f():
+ x = jnp.arange(...)
+ y = add_one_inplace(x)
+ return y
+```
+
+However, if `x` had a separate consumer as well, the program may not execute correctly.
+
+```py
+def f():
+ x = jnp.arange(...)
+ y = add_one_inplace(x)
+ return y, x * 2 # another x consumer!
+```
+
+This is because `x * 2` operates on the original `x` but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to make sure to read the original values of `x`, not the ones after we’ve incremented it by 1. XLA notices this and inserts a `copy` op (which is semantically the identity but the input and output buffers will be different).
+
+```py
+def f(x):
+ x2 = copy(x)
+ y = add_one_inplace(x2)
+ return y, x * 2
+```
+
+This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops.
+
+#### Copies with downstream ops
+
+Let’s revisit our example where we add 1 while `ppermute`ing.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+If we unpack the future into its components, we’ll see the aliasing patterns:
+
+```py
+def f(x):
+ *sems, x2, y = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, to XLA it looks just like our `add_one_inplace` example, where it conservatively assumes that `ppermute_start` mutated `x` and `x2` is the new aliased result. So when we do `z = x + 1`, we run into a consumer of the original buffer, and XLA introduces a copy!
+
+```py
+def f(x):
+ x2 = copy(x)
+ *sems, x2, y = ppermute_start(x2)
+ z = x + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove this copy, we’d need some mechanism to inform XLA we are just forwarding a value. However, in the absence of that we can rewrite our program a bit to explicitly use `x2` instead of `x`.
+
+```py
+def f(x):
+ *sems, x2, y = ppermute_start(x)
+ z = x2 + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+Now, XLA doesn’t see a separate consumer of `x`, so no copy is introduced. However, this comes with a major downside: it forces us to unpack the future coming from `ppermute_start`. It couples the lifetime problem to the copying problem.
+
+#### Loop aliasing
+
+Let’s consider a slightly more advanced example. Let’s implement a function that uses a `fori_loop` with `ppermute` to send values around a ring.
+
+```py
+def f(x):
+ def body(i, x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+One implementation detail of `fori_loop` is that the input and output buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let’s run our own “buffer assignment” by coloring each of the values in the program to determine how many unique buffers we need.
+
+First, we’ll unpack the `fut` tuple that has the aliased `x` and `out` buffers.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done(*sems, x, y)
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+Let’s now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start` and the `y` aliasing coming from `ppermute_done`.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done((*sems, x, y))
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+If you run the alias analysis, you’ll find that all of the buffers have been colored the same! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t receive into the same buffer we are sending from. We generally need an extra (i.e. a “double”) buffer to receive into, and then usually we will switch the send/recv buffers on the next iteration. In practice, XLA will observe the buffer re-use and defensively insert a copy.
+
+```py
+def f(x):
+ def body(i, x):
+ x = copy(x)
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done((*sems, x, y))
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+This copy means `x` and `y` are no longer aliased to each other and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid expensive copies each iteration? The answer is unrolling!
+
+We’ll manually unroll our code.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, x2 = ppermute_start(x)
+ x2 = ppermute_done((*sems, x, x2))
+
+ *sems, x2, y = ppermute_start(x2)
+ y = ppermute_done((*sems, x2, y))
+ return y
+ return fori_loop(0, 4, body, x)
+```
+
+Now if we run the same alias analysis, we’ll find that the buffers no longer alias each other and that we won’t need to insert defensive copies to be correct.
+
+Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`.
+
+```py
+def f(x):
+ def body(i, x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y
+ return fori_loop(0, 8, body, x, unroll=2)
+```
+
+That’s sufficient to implement this loop without extra copies!
+
+#### Passing futures across loop boundaries
+
+Let’s now look at an even more advanced example. We’ll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ x = ppermute_done(fut)
+ fut = ppermute_start(x)
+ return fut
+ fut = fori_loop(0, 7, body, fut)
+ return ppermute_done(fut)
+```
+
+In this example, rather than passing a value `x` from one loop iteration to the next, we are passing a future value.
+
+Let’s unpack the future again to see what’s happening.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ *sems, x, out = fut
+ x = ppermute_done((*sems, x, out))
+ (*sems, x, out) = ppermute_start(x)
+ return (*sems, x, out)
+ (*sems, x, out) = fori_loop(0, 7, body, fut)
+ return ppermute_done((*sems, x, out))
+```
+
+So we’re explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we’ll run into the same aliasing issue as in the previous section where `x` and `out` will be aliased to each other. XLA will introduce a copy.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ *sems, x, out = fut
+ out = copy(out)
+ x = ppermute_done((*sems, x, out))
+ (*sems, x, out) = ppermute_start(x)
+ return (*sems, x, out)
+ (*sems, x, out) = fori_loop(0, 7, body, fut)
+ return ppermute_done((*sems, x, out))
+```
+
+In this case, we inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into! Even if we insert a copy on `x`, we will also run into issues because then `x`’s lifetime will not extend to the `ppermute_done`. This is very, very bad! We will not only get copies, but we will also get incorrect results!
+
+The solution, as we observed before, is to avoid the copies by unrolling, which avoids aliasing all of the buffers to each other. So, if we do:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ x = ppermute_done(fut)
+ fut = ppermute_start(x)
+ return fut
+ fut = fori_loop(0, 7, body, fut, unroll=2)
+ return ppermute_done(fut)
+```
+
+our program should now be correct.
+
+### Putting it all together
+
+So we’ve come up with some rules of thumb:
+
+1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
+2. Use `unroll >= 2` when doing `ppermute`s in a loop body.
+
+Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result.
+
+```py
+def f(x):
+ out = jnp.zeros_like(x)
+ fut = (*sems, x, out) = ppermute_start(x)
+ out = out + x
+ def body(i, carry):
+ out, fut = carry
+ x = ppermute_done(fut)
+ fut = (*sems, x, out) = ppermute_start(x)
+ out = out + x
+ return out, fut
+ out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
+ return out, ppermute_done(fut)
+```
+
+Note that in this example, we don’t need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s.
+
+That’s it, we are done! This will be the official API for doing async ops in Pallas. Thank you everyone! Mission accomplished!
+
+***OR IS IT?***
+
+## Revenge of the State
+
+While it seems we have worked around copies and incorrectness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but it has many, many footguns and caveats. There are likely far more edge cases we will need to deal with, some of which require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative?
+
+Well, the answer may have been in front of us this whole time.
+
+Let’s run through this whole exercise one more time, *except*, let’s write the stateful version. This means each of our custom async ops now operates on `Ref`s instead of values.
+
+```py
+def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
+ ...
+
+def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
+ ...
+```
+
+Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute:
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ return y_ref[...]
+```
+
+It’s a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an “empty” `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don’t need to “sneak” the `x` value into the op like we did before.
+
+Another difference becomes more clear when we try adding an op between the `start/done`.
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ x_ref[...] += 1
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ return y_ref[...]
+```
+
+Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add an ordering constraint! `x_ref[...] += 1` mutates `x_ref`, so it can’t be moved with respect to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO.
+
+The final key difference is evident when we try our loop examples.
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ def body(i, _):
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ # Now switch to y_ref -> x_ref
+ fut = ppermute_start_stateful(y_ref, x_ref)
+ ppermute_done_stateful(*fut, y_ref, x_ref)
+ fori_loop(0, 8 // 2, body, None)
+ return x_ref[...]
+```
+
+Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in such a way that unrolls it! There is no way to write the version in XLA that requires copying because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn’t really make sense.
+
+To handle this without the manual unrolling, we’d create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching slots each iteration. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels.
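+
+As a rough sketch of that pattern (reusing the hypothetical `make_ref`,
+`ppermute_start_stateful`, and `ppermute_done_stateful` ops from above, and assuming
+`Ref`s support `.at[...]`-style slot selection):
+
+```py
+def f(x):
+  # Leading dimension of 2: the two slots alternate as send/recv buffers.
+  buf_ref = make_ref(jnp.stack([x, jnp.zeros_like(x)]))
+  def body(i, _):
+    send, recv = i % 2, 1 - i % 2
+    fut = ppermute_start_stateful(buf_ref.at[send], buf_ref.at[recv])
+    ppermute_done_stateful(*fut, buf_ref.at[send], buf_ref.at[recv])
+    return None
+  fori_loop(0, 8, body, None)
+  # After an even number of steps the result lands back in slot 0.
+  return buf_ref[0]
+```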
+
+The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away!
+
+1. Scheduling - stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` with respect to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints.
+2. Lifetimes - `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops.
+3. Defensive copies - Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies.
+
+Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO.
+
+## Conclusion
+
+We’ve gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven’t done yet outside of Pallas. It’s worth thinking whether we should educate users about stateful ops, or provide a more dangerous API. We also don’t know if everything we want to do is expressible via `Ref`s as well. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and it could automatically do things like double buffer loops with futures in them? That might be a viable alternative but the tradeoff would be giving more control to the compiler vs explicit control from the user.
diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst
index 467f375d0e43..5969349c962a 100644
--- a/docs/pallas/index.rst
+++ b/docs/pallas/index.rst
@@ -33,6 +33,13 @@ See also the :class:`jax.experimental.pallas` module API documentation.
tpu/index
.. toctree::
+ :caption: Design Notes
+ :maxdepth: 1
+
+ async_note
+
+.. toctree::
+ :caption: Other
:maxdepth: 1
CHANGELOG
diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst
index eba986c2cfe8..20abad5f610e 100644
--- a/docs/pallas/tpu/index.rst
+++ b/docs/pallas/tpu/index.rst
@@ -9,4 +9,6 @@ TPU specific documentation.
details
pipelining
matmul
+ sparse
distributed
+
diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb
new file mode 100644
index 000000000000..909103273e1e
--- /dev/null
+++ b/docs/pallas/tpu/sparse.ipynb
@@ -0,0 +1,724 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZHuzXqQ-9JUQ"
+ },
+ "source": [
+ "# Scalar Prefetch and Block-Sparse Computation\n",
+ "\n",
+ "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 56,
+ "status": "ok",
+ "timestamp": 1726001133029,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "ibeIs_6QFMAM",
+ "outputId": "d72edb91-4529-4650-c9e9-b96788608635"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on TPU v5 lite\n"
+ ]
+ }
+ ],
+ "source": [
+ "import functools\n",
+ "import timeit\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "from jax import numpy as jnp\n",
+ "from jax import lax\n",
+ "from jax.experimental import checkify\n",
+ "from jax.experimental import pallas as pl\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "\n",
+ "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n",
+ "print(\"Running on\", jax.devices()[0].device_kind)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FIDGpPTEIcOa"
+ },
+ "source": [
+ "## Dynamic Block Indexing with Scalar Prefetch\n",
+ "\n",
+ "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n",
+ "\n",
+ "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n",
+ "\n",
+ "```python\n",
+ "class PrefetchScalarGridSpec:\n",
+ " def __init__(self,\n",
+ " num_scalar_prefetch: int,\n",
+ " grid: tuple[int, ...],\n",
+ " in_specs: PyTree[BlockSpec],\n",
+ " out_specs: PyTree[BlockSpec],\n",
+ " scratch_shapes: tuple[MemorySpace, ...]):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel are always fixed and described below:\n",
+ "\n",
+ "- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:\n",
+ "```python\n",
+ "def index_map(*grid_indices, *prefetch_refs):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.\n",
+ "```python\n",
+ "def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.\n",
+ "```python\n",
+ "kernel = pl.pallas_call(...)\n",
+ "result = kernel(*prefetch_args, *input_args)\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pA8RmHEA2HN3"
+ },
+ "source": [
+ "## Example: Block Dynamic Slice with Scalar Prefetch\n",
+ "\n",
+ "Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of larger array based on user-specified indices:\n",
+ "\n",
+ "1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`\n",
+ "\n",
+ "2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.\n",
+ "\n",
+ "3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.\n",
+ "\n",
+ "Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 143,
+ "status": "ok",
+ "timestamp": 1726003877561,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "FWeTBlEYlCGD",
+ "outputId": "4b04a441-c97c-4d0d-d167-c60d4d31fd2e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Error |result - lax.dynamic_slice| = 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "def dynamic_slice_kernel(indices, x_ref, o_ref):\n",
+ " del indices\n",
+ " o_ref[...] = x_ref[...]\n",
+ "\n",
+ "@checkify.checkify\n",
+ "@functools.partial(jax.jit, static_argnums=(2,))\n",
+ "def block_dynamic_slice(x, starts, sizes):\n",
+ " grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=1,\n",
+ " grid=(1, 1),\n",
+ " in_specs=[pl.BlockSpec(\n",
+ " sizes,\n",
+ " lambda i, j, block_idx: (block_idx[0], block_idx[1]))],\n",
+ " out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),\n",
+ " )\n",
+ "\n",
+ " kernel = pl.pallas_call(\n",
+ " dynamic_slice_kernel,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),\n",
+ " )\n",
+ " # Checkify inserts a runtime assert that starts are divisible by block size.\n",
+ " checkify.check(starts[0] % sizes[0] == 0, \"Starts must be divisible by size.\")\n",
+ " checkify.check(starts[1] % sizes[1] == 0, \"Starts must be divisible by size.\")\n",
+ " block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])\n",
+ " return kernel(block_idx, x)\n",
+ "\n",
+ "shape = (512, 512)\n",
+ "x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)\n",
+ "err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))\n",
+ "err.throw()\n",
+ "ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))\n",
+ "diff = jnp.max(jnp.abs(result - ref))\n",
+ "print(\"Error |result - lax.dynamic_slice| =\", diff)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "K2dod4lkoifa"
+ },
+ "source": [
+ "## Sparse Kernels: Representing Sparse Data\n",
+ "\n",
+ "Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.\n",
+ "\n",
+ "The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.\n",
+ "\n",
+ "![block_coo](../../_static/pallas/sparse/block_coo.svg)\n",
+ "\n",
+ "We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1gLiSvgIYUEx"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):\n",
+ " \"\"\"Returns a sampled matrix and its block-sparse representation.\n",
+ "\n",
+ " Args:\n",
+ " key: RNG Key.\n",
+ " M: Major array dimension.\n",
+ " N: Minor array dimension.\n",
+ " blk_M: Block size along M dimension.\n",
+ " blk_N: Block size along N dimension.\n",
+ " p: Probability that a block will be non-zero.\n",
+ " dtype: dtype of the sampled matrix.\n",
+ "\n",
+ " Returns:\n",
+ " dense_mat: A (M, N) dense sampled array.\n",
+ " block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing\n",
+ " the non-zero blocks of the matrix.\n",
+ " indices_i: A (num_blocks,) array of block indices for the first axis.\n",
+ " indices_j: A (num_blocks,) array of block indices for the second axis.\n",
+ " \"\"\"\n",
+ " mask_key, blocks_key = jax.random.split(key)\n",
+ " num_blocks = (M // blk_M, N // blk_N)\n",
+ " # We first sample a block mask, denoting which blocks are nonzero.\n",
+ " block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)\n",
+ " num_blocks = jnp.sum(block_mask)\n",
+ " indices = jnp.where(block_mask)\n",
+ " # For each non-zero block, we sample a block of random values.\n",
+ " block_data = jax.random.uniform(blocks_key,\n",
+ " shape=(num_blocks, blk_M, blk_N),\n",
+ " dtype=dtype)\n",
+ " # For checking purposes, create the dense version of the sparse matrix.\n",
+ " dense_mat = jnp.zeros((M, N), dtype=dtype)\n",
+ " for blk in range(num_blocks):\n",
+ " idx_i = indices[0][blk]\n",
+ " idx_j = indices[1][blk]\n",
+ " slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)\n",
+ " slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)\n",
+ " dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])\n",
+ " return dense_mat, block_data, indices[0], indices[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eFyoZSTOH9Fk"
+ },
+ "source": [
+ "## Example: Sparse @ Dense Matrix Multiplication\n",
+ "\n",
+ "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n",
+ "\n",
+ "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n",
+ "\n",
+ "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n",
+ "\n",
+ "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct."
+ ]
+ },
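+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If your block indices come in an arbitrary order, a small preprocessing step (a sketch, assuming NumPy-style `indices_i`, `indices_k`, and `block_data` arrays like the ones produced by `generate_block_sparse_mat` above) can group them by row before calling the kernel:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "def group_blocks_by_row(indices_i, indices_k, block_data):\n",
+ "  # Sort primarily by the row block index, secondarily by the column block\n",
+ "  # index, so that all blocks belonging to one output row are consecutive.\n",
+ "  order = np.lexsort((np.asarray(indices_k), np.asarray(indices_i)))\n",
+ "  return indices_i[order], indices_k[order], block_data[order]\n",
+ "```\n",
+ "\n",
+ "In this tutorial, `jnp.where` already returns indices in row-major order, so the indices produced by `generate_block_sparse_mat` are grouped correctly without this step."
+ ]
+ },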
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 673,
+ "status": "ok",
+ "timestamp": 1725919879291,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "WfyV2WWhjsyA",
+ "outputId": "fa4d4fff-bc6b-4dc9-ac14-63276ca14131"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean |result - ref|: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "M = N = K = 16384\n",
+ "blk_M = blk_N = blk_K = 512\n",
+ "\n",
+ "\n",
+ "def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.\n",
+ " x_ref, y_ref, _, o_ref, # Kernel inputs.\n",
+ " accum_scratch,\n",
+ " ):\n",
+ " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
+ " del idxs_k_ref\n",
+ " blk_idx = pl.program_id(1)\n",
+ " is_start = blk_idx == 0\n",
+ " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
+ " @pl.when(is_start | changed_blocks)\n",
+ " def _():\n",
+ " accum_scratch[...] = jnp.zeros_like(accum_scratch)\n",
+ " accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)\n",
+ "\n",
+ " next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])\n",
+ " is_end = blk_idx == (num_blocks - 1)\n",
+ " @pl.when(is_end | next_block_change)\n",
+ " def _():\n",
+ " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
+ "\n",
+ "\n",
+ "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
+ " del j, blk_idxs_i, blk_idxs_k\n",
+ " return (blk_idx, 0, 0)\n",
+ "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
+ " del blk_idxs_i\n",
+ " return (blk_idxs_k[blk_idx], j)\n",
+ "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
+ " del blk_idxs_k\n",
+ " return (blk_idxs_i[blk_idx], j)\n",
+ "\n",
+ "(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(\n",
+ " jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)\n",
+ "num_blocks = X_blocks.shape[0]\n",
+ "Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n",
+ "zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)\n",
+ "out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=2,\n",
+ " # Note that while num_blocks is static here, Pallas does support\n",
+ " # dynamic grid sizes.\n",
+ " grid=(M // blk_M, num_blocks),\n",
+ " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
+ " pl.BlockSpec((blk_K, blk_N), y_map),\n",
+ " # Placeholder for a zeros-array used by input_output_aliases.\n",
+ " pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " ],\n",
+ " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n",
+ ")\n",
+ "kernel = pl.pallas_call(\n",
+ " dsd_kernel,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=out_shape,\n",
+ " # We use input-output aliases to zero-out o_ref for blocks that we never\n",
+ " # visit. By passing in an array of zeros we avoid having o_ref start with\n",
+ " # uninitialized values.\n",
+ " input_output_aliases={4: 0}, # Map zeros to o_ref.\n",
+ ")\n",
+ "args = (indices_i, indices_k, X_blocks, Y, zeros)\n",
+ "result = kernel(*args)\n",
+ "\n",
+ "ref = X_dense @ Y\n",
+ "diff = jnp.abs(ref - result)\n",
+ "print('mean |result - ref|:', jnp.mean(diff))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2KDgPKF2tUjq"
+ },
+ "source": [
+ "We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor.\n",
+ "\n",
+ "There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM:\n",
+ "- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.\n",
+ "- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 6576,
+ "status": "ok",
+ "timestamp": 1725919886762,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "CkzjqnekpZbx",
+ "outputId": "1ae9031e-705a-4d05-f8b9-d09623918300"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sparse Kernel: 8.136 ms (avg over 100 trials)\n",
+ "Reference: 46.953 ms (avg over 100 trials)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Benchmark Sparse Pallas kernel vs reference JAX implementation\n",
+ "\n",
+ "def benchmark(f, ntrials: int = 100):\n",
+ " def run(*args, **kwargs):\n",
+ " # Compile function first\n",
+ " jax.block_until_ready(f(*args, **kwargs))\n",
+ " # Time function\n",
+ " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n",
+ " number=ntrials)\n",
+ " time = result / ntrials\n",
+ " return time\n",
+ " return run\n",
+ "\n",
+ "\n",
+ "n_trials = 100\n",
+ "\n",
+ "pallas_impl = lambda *args: kernel(*args)\n",
+ "time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)\n",
+ "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n",
+ "\n",
+ "ref_impl = jax.jit(lambda x, y: x @ y)\n",
+ "time = benchmark(ref_impl, n_trials)(X_dense, Y)\n",
+ "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Q1KKd5vTCwnB"
+ },
+ "source": [
+ "## Sparse Access Patterns on Dense Data\n",
+ "\n",
+ "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n",
+ "\n",
+ "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n",
+ "\n",
+ "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n",
+ "\n",
+ "![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)\n",
+ "\n",
+ "*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*\n",
+ "\n",
+ "Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.\n",
+ "\n",
+ "```python\n",
+ "def mask_index_map(prefetch_map, i, j, ...):\n",
+ " next_nonzero_block = prefetch_map[i, j]\n",
+ " return (next_nonzero_block, 0, 0)\n",
+ "```\n",
+ "\n",
+ "We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ii7rzL5YIA8-"
+ },
+ "source": [
+ "## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ecjiqWfA2RlV"
+ },
+ "source": [
+ "In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.\n",
+ "\n",
+ "As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:\n",
+ "- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel.\n",
+ "- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.\n",
+ "- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.\n",
+ "- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.\n",
+ "- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "19zGcliL2SJy"
+ },
+ "outputs": [],
+ "source": [
+ "def sparsify_mask(mask: jax.Array,\n",
+ " block_shape: tuple[int, int]):\n",
+ " \"\"\"Preprocesses a mask into a sparse reprentation.\n",
+ "\n",
+ " Args:\n",
+ " mask: A boolean array of shape [M, N]\n",
+ " block_shape: The size of a single block.\n",
+ "\n",
+ " Returns:\n",
+ " block_mask: A block_shape array of booleans indicating whether a block\n",
+ " is all-zeros (0) or contains non-zero elements (1).\n",
+ " prefetch_mask: A block_shape array of integers indicating the index of the\n",
+ " next non-zero block.\n",
+ " mask_data: A (num_blocks, block_shape) array containing\n",
+ " the data for non-zero blocks of the mask.\n",
+ " \"\"\"\n",
+ " M, N = mask.shape\n",
+ " bm, bn = block_shape\n",
+ "\n",
+ " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n",
+ " mask_types_finder = []\n",
+ " mask_data = []\n",
+ " mask_type_idxs = []\n",
+ "\n",
+ " next_mask_type_idx = 0\n",
+ " prefetch_mask = jnp.zeros_like(block_mask)\n",
+ " next_i = (M // bm) - 1\n",
+ " next_j = (N // bn) - 1\n",
+ " prefetch_i = jnp.zeros_like(block_mask)\n",
+ " prefetch_j = jnp.zeros_like(block_mask)\n",
+ " for i in range(M // bm, -1, -1):\n",
+ " for j in range(N // bn, -1, -1):\n",
+ " mask_block = mask[i * bm :(i + 1) * bm,\n",
+ " j * bn :(j + 1) * bn]\n",
+ " is_nonzero = jnp.any(mask_block)\n",
+ " if is_nonzero:\n",
+ " try:\n",
+ " type_index = mask_types_finder.index(str(mask_block))\n",
+ " except ValueError:\n",
+ " type_index = len(mask_types_finder)\n",
+ " mask_types_finder.append(str(mask_block))\n",
+ " mask_data.append(mask_block)\n",
+ " next_mask_type_idx = type_index\n",
+ " next_i = i\n",
+ " next_j = j\n",
+ " else:\n",
+ " type_index = -1\n",
+ " mask_type_idxs.append(type_index)\n",
+ " block_mask = block_mask.at[i, j].set(is_nonzero)\n",
+ " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n",
+ " prefetch_i = prefetch_i.at[i, j].set(next_i)\n",
+ " prefetch_j = prefetch_j.at[i, j].set(next_j)\n",
+ " return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "w4b7ckKq67Xw"
+ },
+ "source": [
+ "In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 5374,
+ "status": "ok",
+ "timestamp": 1725919713252,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "4YQ9OmbTCSjT",
+ "outputId": "2d752609-34f2-4059-e8ba-4d80afe8cb26"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean |result - ref|: 1.0252e-05\n"
+ ]
+ }
+ ],
+ "source": [
+ "M = N = K = 16384\n",
+ "blk_M = blk_N = 512\n",
+ "blk_K = 1024\n",
+ "\n",
+ "def sparse_mask_matmul(\n",
+ " block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.\n",
+ " x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.\n",
+ " accum_scratch\n",
+ " ):\n",
+ " del prefetch_mask, prefetch_i, prefetch_j\n",
+ " i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)\n",
+ " should_compute = block_mask_ref[i, j] != 0\n",
+ " @pl.when(k == 0)\n",
+ " def _():\n",
+ " o_ref[...] = jnp.zeros_like(o_ref)\n",
+ " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n",
+ "\n",
+ " # We only compute the output for blocks with non-zero masks.\n",
+ " # Otherwise we skip the computation entirely.\n",
+ " @pl.when(should_compute)\n",
+ " def _():\n",
+ " result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)\n",
+ " accum_scratch[...] += result\n",
+ " @pl.when(k == pl.num_programs(2) - 1)\n",
+ " def _():\n",
+ " o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)\n",
+ "\n",
+ "X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)\n",
+ "Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n",
+ "mask = jnp.ones((M, N), dtype=jnp.int32)\n",
+ "mask = jnp.tril(mask)\n",
+ "block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))\n",
+ "\n",
+ "def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n",
+ " del prefetch_mask, prefetch_j\n",
+ " # Zero-out the k index if the mask is zero, to avoid constantly fetching\n",
+ " # new blocks in the inner loop for blocks we are skipping.\n",
+ " k_fetch = (block_mask[i, j] != 0) * k\n",
+ " return (prefetch_i[i, j], k_fetch)\n",
+ "\n",
+ "def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n",
+ " del prefetch_mask, prefetch_i\n",
+ " k_fetch = (block_mask[i, j] != 0) * k\n",
+ " return (k_fetch, prefetch_j[i, j])\n",
+ "\n",
+ "def mask_map(i, j, k, block_mask, prefetch_mask, *_):\n",
+ " del k, block_mask\n",
+ " return (prefetch_mask[i, j], 0, 0)\n",
+ "\n",
+ "def o_map(i, j, k, *_):\n",
+ " del k\n",
+ " return (i, j)\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=4,\n",
+ " grid=(M // blk_M, N // blk_N, K // blk_K),\n",
+ " in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),\n",
+ " pl.BlockSpec((blk_K, blk_N), y_map),\n",
+ " pl.BlockSpec((1, blk_M, blk_N), mask_map)],\n",
+ " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n",
+ ")\n",
+ "kernel = pl.pallas_call(\n",
+ " sparse_mask_matmul,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),\n",
+ ")\n",
+ "args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n",
+ "result = kernel(*args)\n",
+ "\n",
+ "ref = mask * (X @ Y)\n",
+ "diff = jnp.abs(ref - result)\n",
+ "print('mean |result - ref|:', jnp.mean(diff))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uutNGgjZGGhB"
+ },
+ "source": [
+ "Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.\n",
+ "\n",
+ "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n",
+ "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n",
+ "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 8877,
+ "status": "ok",
+ "timestamp": 1725917397452,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "MAT9JjGNvsx8",
+ "outputId": "a32d56fb-a71b-4007-c6a5-e5270dcaa6cf"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sparse Kernel: 28.648 ms (avg over 100 trials)\n",
+ "Reference: 49.988 ms (avg over 100 trials)\n"
+ ]
+ }
+ ],
+ "source": [
+ "n_trials = 100\n",
+ "\n",
+ "pallas_impl = lambda *args: kernel(*args)\n",
+ "time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n",
+ "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n",
+ "\n",
+ "ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))\n",
+ "time = benchmark(ref_impl, n_trials)(mask, X, Y)\n",
+ "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "formats": "ipynb,md:myst",
+ "main_language": "python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md
new file mode 100644
index 000000000000..23e14bb9bc0b
--- /dev/null
+++ b/docs/pallas/tpu/sparse.md
@@ -0,0 +1,567 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ main_language: python
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ name: python3
+---
+
++++ {"id": "ZHuzXqQ-9JUQ"}
+
+# Scalar Prefetch and Block-Sparse Computation
+
+In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 56
+ status: ok
+ timestamp: 1726001133029
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: ibeIs_6QFMAM
+outputId: d72edb91-4529-4650-c9e9-b96788608635
+---
+import functools
+import timeit
+import numpy as np
+import jax
+from jax import numpy as jnp
+from jax import lax
+from jax.experimental import checkify
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
+print("Running on", jax.devices()[0].device_kind)
+```
+
++++ {"id": "FIDGpPTEIcOa"}
+
+## Dynamic Block Indexing with Scalar Prefetch
+
+We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` of each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.
+
+To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:
+
+```python
+class PrefetchScalarGridSpec:
+ def __init__(self,
+ num_scalar_prefetch: int,
+ grid: tuple[int, ...],
+ in_specs: PyTree[BlockSpec],
+ out_specs: PyTree[BlockSpec],
+ scratch_shapes: tuple[MemorySpace, ...]):
+ ...
+```
+
+The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed into the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks, as they do not have a BlockSpec defined. Moreover, the order of arguments to both the `index_map` and the kernel is always fixed, as described below:
+
+- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:
+```python
+def index_map(*grid_indices, *prefetch_refs):
+ ...
+```
+
+- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.
+```python
+def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
+ ...
+```
+
+- When calling a kernel built with `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.
+```python
+kernel = pl.pallas_call(...)
+result = kernel(*prefetch_args, *input_args)
+```
+
++++ {"id": "pA8RmHEA2HN3"}
+
+## Example: Block Dynamic Slice with Scalar Prefetch
+
+Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:
+
+1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`
+
+2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.
+
+3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.
+
+Of course, this kernel is limited in that our slice sizes must fit inside a kernel block (which is bounded by VMEM size), and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size from the slice size and allow non-aligned start indices.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 143
+ status: ok
+ timestamp: 1726003877561
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: FWeTBlEYlCGD
+outputId: 4b04a441-c97c-4d0d-d167-c60d4d31fd2e
+---
+def dynamic_slice_kernel(indices, x_ref, o_ref):
+ del indices
+ o_ref[...] = x_ref[...]
+
+@checkify.checkify
+@functools.partial(jax.jit, static_argnums=(2,))
+def block_dynamic_slice(x, starts, sizes):
+ grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=1,
+ grid=(1, 1),
+ in_specs=[pl.BlockSpec(
+ sizes,
+ lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
+ out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
+ )
+
+ kernel = pl.pallas_call(
+ dynamic_slice_kernel,
+ grid_spec=grid_spec,
+ out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
+ )
+ # Checkify inserts a runtime assert that starts are divisible by block size.
+ checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
+ checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
+ block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
+ return kernel(block_idx, x)
+
+shape = (512, 512)
+x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
+err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
+err.throw()
+ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
+diff = jnp.max(jnp.abs(result - ref))
+print("Error |result - lax.dynamic_slice| =", diff)
+```
+
++++ {"id": "K2dod4lkoifa"}
+
+## Sparse Kernels: Representing Sparse Data
+
+Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.
+
+The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.
+
+![block_coo](../../_static/pallas/sparse/block_coo.svg)
+
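+As a small, hypothetical illustration of this layout (not the figure's exact values), a 4x4 matrix with 2x2 blocks whose upper-right block is all zeros could be stored as a list of `(block_index, block_data)` pairs:
+
+```python
+# A minimal sketch of the block-COO idea; the values are made up for
+# illustration, and only the all-zero upper-right block is omitted from storage.
+dense = jnp.array([[1, 2, 0, 0],
+                   [3, 4, 0, 0],
+                   [5, 6, 7, 8],
+                   [9, 1, 2, 3]])
+block_coo = [
+    ((0, 0), dense[0:2, 0:2]),  # Upper-left block.
+    ((1, 0), dense[2:4, 0:2]),  # Lower-left block.
+    ((1, 1), dense[2:4, 2:4]),  # Lower-right block.
+]
+# The (0, 1) block is all zeros, so it is not stored.
+```
+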
+We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis.
+
+```{code-cell}
+:id: 1gLiSvgIYUEx
+
+def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
+ """Returns a sampled matrix and its block-sparse representation.
+
+ Args:
+ key: RNG Key.
+ M: Major array dimension.
+ N: Minor array dimension.
+ blk_M: Block size along M dimension.
+ blk_N: Block size along N dimension.
+ p: Probability that a block will be non-zero.
+ dtype: dtype of the sampled matrix.
+
+ Returns:
+ dense_mat: A (M, N) dense sampled array.
+ block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
+ the non-zero blocks of the matrix.
+ indices_i: A (num_blocks,) array of block indices for the first axis.
+ indices_j: A (num_blocks,) array of block indices for the second axis.
+ """
+ mask_key, blocks_key = jax.random.split(key)
+ num_blocks = (M // blk_M, N // blk_N)
+ # We first sample a block mask, denoting which blocks are nonzero.
+ block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
+ num_blocks = jnp.sum(block_mask)
+ indices = jnp.where(block_mask)
+ # For each non-zero block, we sample a block of random values.
+ block_data = jax.random.uniform(blocks_key,
+ shape=(num_blocks, blk_M, blk_N),
+ dtype=dtype)
+ # For checking purposes, create the dense version of the sparse matrix.
+ dense_mat = jnp.zeros((M, N), dtype=dtype)
+ for blk in range(num_blocks):
+ idx_i = indices[0][blk]
+ idx_j = indices[1][blk]
+ slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
+ slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
+ dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
+ return dense_mat, block_data, indices[0], indices[1]
+```
+
++++ {"id": "eFyoZSTOH9Fk"}
+
+## Example: Sparse @ Dense Matrix Multiplication
+
+In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.
+
+We will structure our kernel grid with 2 loops: the outer loop over the columns of the RHS/output, and the inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate the result into the correct output block. One outer loop iteration computes the result for an entire column, as depicted in the following diagram:
+
+![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)
+
+It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel, for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and this is easy to do if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, so we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. Once we move on to a different output block, Pallas stores the finished block to HBM and assumes we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct.
+
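+If your block indices come in an arbitrary order, a small preprocessing step (a sketch, assuming NumPy-style `indices_i`, `indices_k`, and `block_data` arrays like the ones produced by `generate_block_sparse_mat` above) can group them by row before calling the kernel:
+
+```python
+import numpy as np
+
+def group_blocks_by_row(indices_i, indices_k, block_data):
+  # Sort primarily by the row block index, secondarily by the column block
+  # index, so that all blocks belonging to one output row are consecutive.
+  order = np.lexsort((np.asarray(indices_k), np.asarray(indices_i)))
+  return indices_i[order], indices_k[order], block_data[order]
+```
+
+In this tutorial, `jnp.where` already returns indices in row-major order, so the indices produced by `generate_block_sparse_mat` are grouped correctly without this step.
+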
+```{code-cell}
+---
+executionInfo:
+ elapsed: 673
+ status: ok
+ timestamp: 1725919879291
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: WfyV2WWhjsyA
+outputId: fa4d4fff-bc6b-4dc9-ac14-63276ca14131
+---
+M = N = K = 16384
+blk_M = blk_N = blk_K = 512
+
+
+def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
+ x_ref, y_ref, _, o_ref, # Kernel inputs.
+ accum_scratch,
+ ):
+ """A DSD (Dense = Sparse @ Dense) matmul kernel."""
+ del idxs_k_ref
+ blk_idx = pl.program_id(1)
+ is_start = blk_idx == 0
+ changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
+ @pl.when(is_start | changed_blocks)
+ def _():
+ accum_scratch[...] = jnp.zeros_like(accum_scratch)
+ accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)
+
+ next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
+ is_end = blk_idx == (num_blocks - 1)
+ @pl.when(is_end | next_block_change)
+ def _():
+ o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
+
+
+def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+ del j, blk_idxs_i, blk_idxs_k
+ return (blk_idx, 0, 0)
+def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+ del blk_idxs_i
+ return (blk_idxs_k[blk_idx], j)
+def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+ del blk_idxs_k
+ return (blk_idxs_i[blk_idx], j)
+
+(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
+ jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
+num_blocks = X_blocks.shape[0]
+Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
+zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
+out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=2,
+ # Note that while num_blocks is static here, Pallas does support
+ # dynamic grid sizes.
+ grid=(M // blk_M, num_blocks),
+ in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
+ pl.BlockSpec((blk_K, blk_N), y_map),
+ # Placeholder for a zeros-array used by input_output_aliases.
+ pl.BlockSpec((blk_M, blk_N), o_map),
+ ],
+ out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+ scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+ dsd_kernel,
+ grid_spec=grid_spec,
+ out_shape=out_shape,
+ # We use input-output aliases to zero-out o_ref for blocks that we never
+ # visit. By passing in an array of zeros we avoid having o_ref start with
+ # uninitialized values.
+ input_output_aliases={4: 0}, # Map zeros to o_ref.
+)
+args = (indices_i, indices_k, X_blocks, Y, zeros)
+result = kernel(*args)
+
+ref = X_dense @ Y
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "2KDgPKF2tUjq"}
+
+We can do a quick benchmark to compare the performance of our sparse kernel against a dense matmul in JAX. On a TPU v5e chip, this kernel achieves roughly a 6x speedup, compared to the theoretical 10x suggested by the sparsity factor.
+
+There are a few main tips for performance here, mostly centered on reducing the communication overhead between HBM and VMEM:
+- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.
+- Using larger block sizes also helps, since matrix multiplication is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument in practice is that smaller block sizes also allow the data to be represented more sparsely, so the block size is a parameter that should be selected carefully. A quick estimate of the arithmetic intensity per block is sketched below.
+
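+To make the block-size point concrete, here is a rough, back-of-the-envelope estimate of the arithmetic intensity of a single block matmul (a sketch assuming bfloat16 operands and counting only the two operand blocks and one output block; the exact HBM traffic depends on pipelining and accumulator reuse):
+
+```python
+blk = 512
+bytes_per_elem = 2  # bfloat16
+flops = 2 * blk * blk * blk                   # One (blk, blk) @ (blk, blk) matmul.
+bytes_moved = 3 * blk * blk * bytes_per_elem  # Two operand blocks plus one output block.
+print("FLOPs per byte:", flops / bytes_moved)  # ~170.7 for blk = 512.
+```
+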
+```{code-cell}
+---
+executionInfo:
+ elapsed: 6576
+ status: ok
+ timestamp: 1725919886762
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: CkzjqnekpZbx
+outputId: 1ae9031e-705a-4d05-f8b9-d09623918300
+---
+# Benchmark Sparse Pallas kernel vs reference JAX implementation
+
+def benchmark(f, ntrials: int = 100):
+ def run(*args, **kwargs):
+ # Compile function first
+ jax.block_until_ready(f(*args, **kwargs))
+ # Time function
+ result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
+ number=ntrials)
+ time = result / ntrials
+ return time
+ return run
+
+
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda x, y: x @ y)
+time = benchmark(ref_impl, n_trials)(X_dense, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
+
++++ {"id": "Q1KKd5vTCwnB"}
+
+## Sparse Access Patterns on Dense Data
+
+In our previous example we considered the case where the data itself is sparse. This manifested in the kernel structure as a grid dimension that was dynamic and looped over the number of nonzero blocks (`num_blocks`).
+
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid, as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
+
+The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline to instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contain the indices of the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.
+
+![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)
+
+*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*
+
+Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.
+
+```python
+def mask_index_map(i, j, ..., prefetch_map):
+ next_nonzero_block = prefetch_map[i, j]
+ return (next_nonzero_block, 0, 0)
+```
+
+We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps that point to the next non-zero block index in the grid. Our next example demonstrates how to use these prefetch maps.
+
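+To illustrate the construction itself, here is a small sketch (plain NumPy, independent of the `sparsify_mask` helper defined in the next example) that computes, for a single row of a block mask, the index of the nearest non-skipped block by scanning from right to left:
+
+```python
+import numpy as np
+
+def next_nonzero_indices(mask_row):
+  # For each position, record the index of the nearest non-zero block at or
+  # after it. Trailing all-zero positions fall back to the last index.
+  prefetch = np.zeros(len(mask_row), dtype=np.int32)
+  next_idx = len(mask_row) - 1
+  for i in reversed(range(len(mask_row))):
+    if mask_row[i]:
+      next_idx = i
+    prefetch[i] = next_idx
+  return prefetch
+
+print(next_nonzero_indices(np.array([1, 0, 0, 1, 0])))  # [0 3 3 3 4]
+```
+
+The 2D prefetch maps used below are built in the same spirit, with the reverse scan running over both block indices.
+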
++++ {"id": "ii7rzL5YIA8-"}
+
+## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask
+
++++ {"id": "ecjiqWfA2RlV"}
+
+In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.
+
+As we will be working with a sparse mask, we will begin by implementing a function that converts an `M x N` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:
+- A `block_mask` of shape `(num_M_blocks, num_N_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0, we can skip computing the block in the kernel.
+- A `prefetch_mask` array of shape `(num_M_blocks, num_N_blocks)` consisting of indices into `mask_data` for the next non-zero block.
+- A `prefetch_i` array of shape `(num_M_blocks, num_N_blocks)` consisting of the next non-masked `i` index of the mask.
+- A `prefetch_j` array of shape `(num_M_blocks, num_N_blocks)` consisting of the next non-masked `j` index of the mask.
+- A `mask_data` array of shape `(num_blocks, blk_M, blk_N)` containing the data for the non-zero blocks of the mask.
+
+```{code-cell}
+:id: 19zGcliL2SJy
+
+def sparsify_mask(mask: jax.Array,
+ block_shape: tuple[int, int]):
+ """Preprocesses a mask into a sparse reprentation.
+
+ Args:
+ mask: A boolean array of shape [M, N]
+ block_shape: The size of a single block.
+
+ Returns:
+    block_mask: A (M // bm, N // bn) array of booleans indicating whether
+      a block is all-zeros (0) or contains non-zero elements (1).
+    prefetch_mask: A (M // bm, N // bn) array of integers containing the
+      index into mask_data of the next non-zero block.
+    prefetch_i: A (M // bm, N // bn) array of integers containing the row
+      index of the next non-zero block.
+    prefetch_j: A (M // bm, N // bn) array of integers containing the column
+      index of the next non-zero block.
+    mask_data: A (num_blocks, bm, bn) array containing the data for the
+      non-zero blocks of the mask.
+ """
+ M, N = mask.shape
+ bm, bn = block_shape
+
+ block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
+ mask_types_finder = []
+ mask_data = []
+ mask_type_idxs = []
+
+ next_mask_type_idx = 0
+ prefetch_mask = jnp.zeros_like(block_mask)
+ next_i = (M // bm) - 1
+ next_j = (N // bn) - 1
+ prefetch_i = jnp.zeros_like(block_mask)
+ prefetch_j = jnp.zeros_like(block_mask)
+ for i in range(M // bm, -1, -1):
+ for j in range(N // bn, -1, -1):
+ mask_block = mask[i * bm :(i + 1) * bm,
+ j * bn :(j + 1) * bn]
+ is_nonzero = jnp.any(mask_block)
+ if is_nonzero:
+ try:
+ type_index = mask_types_finder.index(str(mask_block))
+ except ValueError:
+ type_index = len(mask_types_finder)
+ mask_types_finder.append(str(mask_block))
+ mask_data.append(mask_block)
+ next_mask_type_idx = type_index
+ next_i = i
+ next_j = j
+ else:
+ type_index = -1
+ mask_type_idxs.append(type_index)
+ block_mask = block_mask.at[i, j].set(is_nonzero)
+ prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
+ prefetch_i = prefetch_i.at[i, j].set(next_i)
+ prefetch_j = prefetch_j.at[i, j].set(next_j)
+ return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)
+```
+
++++ {"id": "w4b7ckKq67Xw"}
+
+In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials, with 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block is all zeros. If the mask is all zeros, we can skip computation and move on to the next block; otherwise we need to compute the matrix multiplication and then mask the result.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 5374
+ status: ok
+ timestamp: 1725919713252
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: 4YQ9OmbTCSjT
+outputId: 2d752609-34f2-4059-e8ba-4d80afe8cb26
+---
+M = N = K = 16384
+blk_M = blk_N = 512
+blk_K = 1024
+
+def sparse_mask_matmul(
+ block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
+ x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.
+ accum_scratch
+ ):
+ del prefetch_mask, prefetch_i, prefetch_j
+ i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+ should_compute = block_mask_ref[i, j] != 0
+ @pl.when(k == 0)
+ def _():
+ o_ref[...] = jnp.zeros_like(o_ref)
+ accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
+
+ # We only compute the output for blocks with non-zero masks.
+ # Otherwise we skip the computation entirely.
+ @pl.when(should_compute)
+ def _():
+ result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
+ accum_scratch[...] += result
+ @pl.when(k == pl.num_programs(2) - 1)
+ def _():
+ o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)
+
+X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
+Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
+mask = jnp.ones((M, N), dtype=jnp.int32)
+mask = jnp.tril(mask)
+block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))
+
+def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
+ del prefetch_mask, prefetch_j
+ # Zero-out the k index if the mask is zero, to avoid constantly fetching
+ # new blocks in the inner loop for blocks we are skipping.
+ k_fetch = (block_mask[i, j] != 0) * k
+ return (prefetch_i[i, j], k_fetch)
+
+def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
+ del prefetch_mask, prefetch_i
+ k_fetch = (block_mask[i, j] != 0) * k
+ return (k_fetch, prefetch_j[i, j])
+
+def mask_map(i, j, k, block_mask, prefetch_mask, *_):
+ del k, block_mask
+ return (prefetch_mask[i, j], 0, 0)
+
+def o_map(i, j, k, *_):
+ del k
+ return (i, j)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=4,
+ grid=(M // blk_M, N // blk_N, K // blk_K),
+ in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
+ pl.BlockSpec((blk_K, blk_N), y_map),
+ pl.BlockSpec((1, blk_M, blk_N), mask_map)],
+ out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+ scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+ sparse_mask_matmul,
+ grid_spec=grid_spec,
+ out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
+)
+args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+result = kernel(*args)
+
+ref = mask * (X @ Y)
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "uutNGgjZGGhB"}
+
+Now let's compare performance against a naive dense implementation. On TPU v5e, we achieve roughly a 1.8x speedup with the sparse kernel, compared to a theoretical best case of 2x from using a lower-triangular mask and only visiting half of the possible outputs.
+
+We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since the main reasons we fall short of it are:
+- We skip slightly less than half of the computation, since the blocks along the diagonal contain a mix of 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation. (A quick count of the skipped blocks is sketched after this list.)
+- The pipeline bubble also accounts for a smaller percentage of the overall runtime as the inputs become larger.
+
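+To make the first point concrete, here is a quick count (a sketch using this example's shapes) of how many output blocks a lower-triangular mask lets us skip:
+
+```python
+# With M = N = 16384 and blk_M = blk_N = 512 there are 32 x 32 output blocks.
+grid = 16384 // 512
+total_blocks = grid * grid              # 1024 output blocks in total.
+skipped = grid * (grid - 1) // 2        # 496 strictly upper-triangular blocks are all-zero.
+print("Fraction of blocks skipped:", skipped / total_blocks)  # 0.484375
+```
+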
+```{code-cell}
+---
+executionInfo:
+ elapsed: 8877
+ status: ok
+ timestamp: 1725917397452
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: MAT9JjGNvsx8
+outputId: a32d56fb-a71b-4007-c6a5-e5270dcaa6cf
+---
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
+time = benchmark(ref_impl, n_trials)(mask, X, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb
index 60bf4d41a7a6..cdfda63c6f13 100644
--- a/docs/sharded-computation.ipynb
+++ b/docs/sharded-computation.ipynb
@@ -60,7 +60,7 @@
"\n",
"Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n",
"\n",
- "How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
+ "How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
"\n",
"In the simplest cases, arrays are sharded on a single device, as demonstrated below:"
]
@@ -360,7 +360,7 @@
"\n",
"## 2. Semi-automated sharding with constraints\n",
"\n",
- "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
+ "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
"\n",
"For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:"
]
diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md
index ef4dc2d3288d..84516a557166 100644
--- a/docs/sharded-computation.md
+++ b/docs/sharded-computation.md
@@ -39,7 +39,7 @@ jax.devices()
Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.
-How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
+How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
@@ -133,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array
## 2. Semi-automated sharding with constraints
-If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.
+If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.
For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md
index 2eeffc30b255..4eb6e7a66cdd 100644
--- a/docs/stateful-computations.md
+++ b/docs/stateful-computations.md
@@ -144,7 +144,7 @@ This is because, like the strategy we just applied, object-oriented programming
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
-Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the :ref:`pseudorandom-numbers` section.
+Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the {ref}`pseudorandom-numbers` section.
Unlike Numpy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
diff --git a/jax/BUILD b/jax/BUILD
index 74072dc44644..c6d8fe25af59 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -848,6 +848,7 @@ pytype_strict_library(
],
deps = [
":core",
+ ":dtypes",
":effects",
":pretty_printer",
":tree_util",
diff --git a/jax/__init__.py b/jax/__init__.py
index 7e958b21c5dd..e2e302adb855 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -127,7 +127,6 @@
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
-from jax._src.api import xla_computation as _deprecated_xla_computation
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.sharding_impls import make_mesh as make_mesh
@@ -224,20 +223,18 @@
"jax.clear_backends is deprecated.",
_deprecated_clear_backends
),
- # Added Jun 16, 2024
+ # Remove after jax 0.4.35 release.
"xla_computation": (
- "jax.xla_computation is deprecated. Please use the AOT APIs; see "
+ "jax.xla_computation is deleted. Please use the AOT APIs; see "
"https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
"xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
- "CHANGELOG.md for 0.4.30 for more examples.",
- _deprecated_xla_computation
+ "CHANGELOG.md for 0.4.30 for more examples.", None
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.api import clear_backends as clear_backends
- from jax._src.api import xla_computation as xla_computation
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index fd30119882e7..8c7fe2f489d5 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -514,7 +514,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
prevent_cse=prevent_cse, differentiated=differentiated, policy=policy)
out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
out_tangents_ = iter(out_tangents_)
- out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
+ out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp
diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py
index 57e881c34f82..bd1427f59e01 100644
--- a/jax/_src/ad_util.py
+++ b/jax/_src/ad_util.py
@@ -20,7 +20,7 @@
from jax._src import core
from jax._src import traceback_util
from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
-from jax._src.tree_util import register_pytree_node
+from jax._src.tree_util import register_pytree_node, tree_map
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map
@@ -65,8 +65,8 @@ def __init__(self, aval: core.AbstractValue):
def __repr__(self) -> str:
return f'Zero({self.aval})'
@staticmethod
- def from_value(val: Any) -> Zero:
- return Zero(raise_to_shaped(get_aval(val)))
+ def from_primal_value(val: Any) -> Zero:
+ return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
@@ -82,6 +82,7 @@ def _stop_gradient_impl(x: T) -> T:
stop_gradient_p.def_abstract_eval(lambda x: x)
+# User-facing version of `Zero`
class SymbolicZero:
def __init__(self, aval: core.AbstractValue) -> None:
self.aval = aval
@@ -108,6 +109,19 @@ def __getattr__(self, name):
else:
return attr
+ @staticmethod
+ def from_primal_value(val: Any) -> SymbolicZero:
+ return SymbolicZero(get_aval(val).to_tangent_aval())
+
+def zero_from_primal(val, symbolic_zeros=False):
+ def f(x):
+ tangent_aval = get_aval(x).to_tangent_aval()
+ if symbolic_zeros:
+ return SymbolicZero(tangent_aval)
+ else:
+ return zeros_like_aval(tangent_aval)
+ return tree_map(f, val)
+
JaxTypeOrTracer = Any
def replace_internal_symbolic_zeros(
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 935995ec5cba..aae99a28bbea 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -46,7 +46,6 @@
from jax._src import config
from jax._src import core
from jax._src import dispatch
-from jax._src import effects
from jax._src import array
from jax._src import basearray
from jax._src import distributed
@@ -60,7 +59,7 @@
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
- argnums_partial_except, flatten_axes, donation_vector,
+ flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
@@ -73,13 +72,11 @@
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
-from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps,
- split_list)
+from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
-from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@@ -337,244 +334,6 @@ def disable_jit(disable: bool = True):
yield
-def xla_computation(fun: Callable,
- static_argnums: int | Iterable[int] = (),
- axis_env: Sequence[tuple[AxisName, int]] | None = None,
- in_parts=None, out_parts=None,
- backend: str | None = None,
- tuple_args: bool = False,
- instantiate_const_outputs: bool | None = None,
- return_shape: bool = False,
- donate_argnums: int | Iterable[int] = ()) -> Callable:
- """Creates a function that produces its XLA computation given example args.
-
- .. warning::
-
- This function is deprecated as of JAX v0.4.30, and will be removed in a future
- JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
- example, ``jax.xla_computation(fn)(*args)`` can be replaced with
- ``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
- See the `JAX 0.4.30 Change log`_ for more examples.
-
- Args:
- fun: Function from which to form XLA computations.
- static_argnums: See the :py:func:`jax.jit` docstring.
- axis_env: Optional, a sequence of pairs where the first element is an axis
- name and the second element is a positive integer representing the size of
- the mapped axis with that name. This parameter is useful when lowering
- functions that involve parallel communication collectives, and it
- specifies the axis name/size environment that would be set up by
- applications of :py:func:`jax.pmap`. See the examples below.
- in_parts: Optional, how each argument to ``fun`` should be partitioned or
- replicated. This is used to specify partitioned XLA computations, see
- ``sharded_jit`` for more info.
- out_parts: Optional, how each output of ``fun`` should be partitioned or
- replicated. This is used to specify partitioned XLA computations, see
- ``sharded_jit`` for more info.
- backend: This is an experimental feature and the API is likely to change.
- Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
- ``'tpu'``.
- tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
- XLA computation will have a single tuple argument that is unpacked into
- the specified function arguments. If `None`, tupling will be enabled when
- there are more than 100 arguments, since some platforms have limits on
- argument arity.
- instantiate_const_outputs: Deprecated argument, does nothing.
- return_shape: Optional boolean, defaults to ``False``. If ``True``, the
- wrapped function returns a pair where the first element is the XLA
- computation and the second element is a pytree with the same structure as
- the output of ``fun`` and where the leaves are objects with ``shape`` and
- ``dtype`` attributes representing the corresponding types of the output
- leaves.
- donate_argnums: Specify which arguments are "donated" to the computation.
- It is safe to donate arguments if you no longer need them once the
- computation has finished. In some cases XLA can make use of donated
- buffers to reduce the amount of memory needed to perform a computation,
- for example recycling one of your input buffers to store a result. You
- should not reuse buffers that you donate to a computation, JAX will raise
- an error if you try to.
-
- Returns:
- A wrapped version of ``fun`` that when applied to example arguments returns
- a built XLA Computation (see xla_client.py), from which representations of
- the unoptimized XLA HLO computation can be extracted using methods like
- ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
- ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
- wrapped function returns a pair where the first element is the XLA
- Computation and the second element is a pytree representing the structure,
- shapes, dtypes, and named shapes of the output of ``fun``.
-
- Concrete example arguments are not always necessary. For those arguments not
- indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
- attributes is acceptable (excepting namedtuples, which are treated as Python
- containers).
-
- For example:
-
- >>> import jax
- >>>
- >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
- >>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
- >>> print(c.as_hlo_text()) # doctest: +SKIP
- HloModule xla_computation_f.6
-
- ENTRY xla_computation_f.6 {
- constant.2 = pred[] constant(false)
- parameter.1 = f32[] parameter(0)
- cosine.3 = f32[] cosine(parameter.1)
- sine.4 = f32[] sine(cosine.3)
- ROOT tuple.5 = (f32[]) tuple(sine.4)
- }
-
-
-
-
- Alternatively, the assignment to ``c`` above could be written:
-
- >>> import types
- >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
- >>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP
-
-
- Here's an example that involves a parallel collective and axis name:
-
- >>> def f(x): return x - jax.lax.psum(x, 'i')
- >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
- >>> print(c.as_hlo_text()) # doctest: +SKIP
- HloModule jaxpr_computation.9
- primitive_computation.3 {
- parameter.4 = s32[] parameter(0)
- parameter.5 = s32[] parameter(1)
- ROOT add.6 = s32[] add(parameter.4, parameter.5)
- }
- ENTRY jaxpr_computation.9 {
- tuple.1 = () tuple()
- parameter.2 = s32[] parameter(0)
- all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
- ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
- }
-
-
-
- Notice the ``replica_groups`` that were generated. Here's an example that
- generates more interesting ``replica_groups``:
-
- >>> from jax import lax
- >>> def g(x):
- ... rowsum = lax.psum(x, 'i')
- ... colsum = lax.psum(x, 'j')
- ... allsum = lax.psum(x, ('i', 'j'))
- ... return rowsum, colsum, allsum
- ...
- >>> axis_env = [('i', 4), ('j', 2)]
- >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
- >>> print(c.as_hlo_text()) # doctest: +SKIP
- HloModule jaxpr_computation__1.19
- [removed uninteresting text here]
- ENTRY jaxpr_computation__1.19 {
- tuple.1 = () tuple()
- parameter.2 = f32[] parameter(0)
- all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
- all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
- all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
- ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
- }
-
- .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
- """
- if instantiate_const_outputs is not None:
- raise ValueError(
- "instantiate_const_outputs has been deprecated. Please use the ahead of"
- " time APIs. You can read more here:"
- " https://jax.readthedocs.io/en/latest/aot.html")
- if in_parts is not None:
- raise ValueError(
- "in_parts has been deprecated. Please use the ahead of time APIs. You"
- " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
- if out_parts is not None:
- raise ValueError(
- "out_parts has been deprecated. Please use the ahead of time APIs. You"
- " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
-
- check_callable(fun)
- static_argnums = _ensure_index_tuple(static_argnums)
- donate_argnums = _ensure_index_tuple(donate_argnums)
- donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
-
- fun_name = getattr(fun, "__name__", "unknown")
-
- platform = backend if backend is not None else xb.get_backend().platform
-
- def make_axis_env(nreps):
- if axis_env is None:
- return sharding_impls.AxisEnv(nreps, (), ())
- else:
- nreps = nreps * math.prod(size for name, size in axis_env)
- names, sizes = unzip2(axis_env)
- return sharding_impls.AxisEnv(nreps, names, sizes)
-
- @wraps(fun)
- @api_boundary
- def computation_maker(*args, **kwargs):
- if max(static_argnums + donate_argnums, default=-1) >= len(args):
- raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but "
- f"was called with only {len(args)} positional arguments.")
-
- f = lu.wrap_init(fun)
- f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
- args_flat, in_tree = tree_flatten((dyn_args, kwargs))
- if donate_argnums:
- donated_invars = donation_vector(donate_argnums, (), in_tree)
- else:
- donated_invars = (False,) * len(args_flat)
-
- jaxtree_fun, out_tree = flatten_fun(f, in_tree)
- avals = map(shaped_abstractify, args_flat)
- with ExitStack() as stack:
- for axis_name, size in axis_env or []:
- stack.enter_context(core.extend_axis_env(axis_name, size, None))
- jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
- jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
- if axis_env:
- jaxpr = core.remove_named_axis_effects(
- jaxpr, {axis_name for axis_name, _ in axis_env}
- )
- axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
- ordered_effects = list(
- effects.ordered_effects.filter_in(jaxpr.effects))
- lowering_result = mlir.lower_jaxpr_to_module(
- f"xla_computation_{fun_name}",
- core.ClosedJaxpr(jaxpr, consts),
- ordered_effects=ordered_effects,
- backend_or_name=backend,
- platforms=[platform],
- axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
- name_stack=source_info_util.new_name_stack(
- wrap_name(fun_name, "xla_computation")),
- donated_args=donated_invars,
- arg_shardings=None,
- result_shardings=None,
- lowering_parameters=mlir.LoweringParameters())
-
- m = mlir.module_to_bytecode(lowering_result.module)
- built = xc._xla.mlir.mlir_module_to_xla_computation(
- m, use_tuple_args=tuple_args, return_tuple=True)
- out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
- out_shape = tree_unflatten(out_tree(), out_shapes_flat)
- for out_aval in out_avals:
- if not isinstance(out_aval, ShapedArray):
- raise RuntimeError("As we want to propagate the weak_type, we need "
- "to get a ShapedArray, otherwise this "
- "information is lost")
-
- if return_shape:
- return built, out_shape
- else:
- return built
-
- return computation_maker
-
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,
@@ -2067,7 +1826,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
def fun(*tangents):
tangent_avals = list(map(core.get_aval, tangents))
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
- if not core.typecompat(primal_aval.at_least_vspace(), tangent_aval):
+ if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval):
raise ValueError("linearized function called on tangent values inconsistent with "
"the original primal values: "
f"got {tangent_aval} for primal aval {primal_aval}")
@@ -2110,7 +1869,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
f"got {in_tree}, but expected to match {in_tree_expected}")
for arg, aval in zip(args, out_primal_avals):
ct_aval = shaped_abstractify(arg)
- ct_aval_expected = aval.at_least_vspace()
+ ct_aval_expected = aval.to_tangent_aval()
if (not core.typecompat(ct_aval, ct_aval_expected) and
not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
raise ValueError(
@@ -2695,11 +2454,8 @@ class ShapeDtypeStruct:
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"]
- named_shape = {} # type: ignore
- def __init__(self, shape, dtype, named_shape=None, sharding=None,
- weak_type=False):
- del named_shape # ignored, vestigial
+ def __init__(self, shape, dtype, *, sharding=None, weak_type=False):
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
@@ -2970,7 +2726,8 @@ def clear_backends():
pjit._infer_params_cached.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
- pjit._cpp_pjit_cache.clear()
+ pjit._cpp_pjit_cache_fun_only.clear()
+ pjit._cpp_pjit_cache_explicit_attributes.clear()
xc._xla.PjitFunctionCache.clear_all()
@atexit.register
@@ -2999,7 +2756,8 @@ def clear_caches():
util.clear_all_weakref_lru_caches()
# Clear all C++ compiled executable caches for pjit
- pjit._cpp_pjit_cache.clear()
+ pjit._cpp_pjit_cache_fun_only.clear()
+ pjit._cpp_pjit_cache_explicit_attributes.clear()
pjit._infer_params_cached.cache_clear()
xc._xla.PjitFunctionCache.clear_all()
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 909f5acf0d43..9e5595aacca3 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -28,6 +28,7 @@
from jax._src import basearray
from jax._src import config
from jax._src import core
+from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
@@ -115,6 +116,16 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
np_value = fun(*args)
np_value.__setstate__(arr_state)
jnp_value = api.device_put(np_value)
+ # TODO(slebedev): Remove this branch after December 10th 2024.
+ if "named_shape" in aval_state:
+ deprecations.warn(
+ "jax-aval-named-shape",
+ "Pickled array contains an aval with a named_shape attribute. This is"
+ " deprecated and the code path supporting such avals will be removed."
+ " Please re-pickle the array.",
+ stacklevel=2,
+ )
+ del aval_state["named_shape"]
jnp_value.aval = jnp_value.aval.update(**aval_state)
return jnp_value
diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py
index 6fdf0c600b7d..9bce9d0e4308 100644
--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py
@@ -83,7 +83,8 @@ def get(module: ir.Module,
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
"""
entries = [
- ("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
+ ("computation",
+ lambda hash_obj: _hash_computation(hash_obj, module)),
("jax_lib version",
lambda hash_obj: hash_obj.update(
bytes(jaxlib_version_str.encode("utf-8")))),
@@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
)
+def _remove_custom_partitioning_ptr(m: ir.Module):
+ """
+ Removes custom_partitioning callback pointer from precompiled IR.
+ Python function pointers are not deterministic across executions.
+ """
+ def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
+ if (op.name == "stablehlo.custom_call" and
+ op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
+ op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
+ return ir.WalkResult.ADVANCE
+
+ m.operation.walk(_update_bc_attribute)
+ return m
+
+
def _serialize_ir(m: ir.Module) -> bytes:
output = io.BytesIO()
+ if config.remove_custom_partitioning_ptr_from_cache_key.value:
+ m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
+ m.operation.clone()))
m.operation.write_bytecode(file=output)
return output.getvalue()
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 1167914e51c9..e67f624fc32e 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -980,7 +980,7 @@ def jvp(*xs):
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
nz_out_tangents_ = iter(nz_out_tangents)
- out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
+ out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
if z else next(nz_out_tangents_)
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py
index 5b39994c7523..6033e1bbb928 100644
--- a/jax/_src/cloud_tpu_init.py
+++ b/jax/_src/cloud_tpu_init.py
@@ -80,8 +80,7 @@ def cloud_tpu_init() -> None:
os.environ['TPU_ML_PLATFORM'] = 'JAX'
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
- if hardware_utils.tpu_enhanced_barrier_supported():
- os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
+ os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
# this makes tensorstore serialization work better on TPU
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py
index 8117f871a969..b946dc0a2897 100644
--- a/jax/_src/compilation_cache.py
+++ b/jax/_src/compilation_cache.py
@@ -265,7 +265,9 @@ def put_executable_and_time(
cache.put(cache_key, executable_and_time)
-def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
+def get_cache_key(module: ir.Module,
+ devices: np.ndarray,
+ compile_options,
backend) -> str:
return cache_key.get(module, devices, compile_options, backend,
"zstandard" if zstandard is not None else "zlib")
diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 81457f1cbd07..108741b5f8fd 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -33,7 +33,6 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np
@@ -157,8 +156,7 @@ def get_compile_options(
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
- if xla_extension_version >= 280:
- build_options.use_shardy_partitioner = use_shardy_partitioner
+ build_options.use_shardy_partitioner = use_shardy_partitioner
if fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:
diff --git a/jax/_src/config.py b/jax/_src/config.py
index b2d1aa52ef2a..fe56ec68f6cb 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1347,6 +1347,16 @@ def _update_jax_memories_thread_local(val):
'size to grow indefinitely.'),
)
+remove_custom_partitioning_ptr_from_cache_key = bool_state(
+ name='jax_remove_custom_partitioning_ptr_from_cache_key',
+ default=False,
+ help=('If set to True, remove the custom partitioning pointer '
+ 'present in the precompiled stableHLO before hashing '
+ 'during cache key computation. This is a potentially '
+ 'unsafe flag to set and only users who are sure of '
+ 'what they are trying to achieve should set it.'),
+)
+
default_dtype_bits = enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
@@ -1368,6 +1378,15 @@ def _update_jax_memories_thread_local(val):
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_dtype_promotion=val))
+disallow_mesh_context_manager = bool_state(
+ name='jax_disallow_mesh_context_manager',
+ default=False,
+ help=(
+ 'If set to True, trying to use a mesh as a context manager will'
+ ' result in a RuntimeError.'
+ ),
+)
+
def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val
@@ -1710,10 +1729,8 @@ def _update_debug_log_modules(module_names_str: str | None):
pmap_no_rank_reduction = bool_state(
name='jax_pmap_no_rank_reduction',
- default=False,
- help=(
- "If True, pmap shards have a the same rank as their enclosing array."
- )
+ default=True,
+    help='If True, pmap shards have the same rank as their enclosing array.',
)
use_shardy_partitioner = bool_state(
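To illustrate the new `jax_remove_custom_partitioning_ptr_from_cache_key` option added above, a hedged sketch of how a user would opt in (the flag defaults to False):

```python
import jax

# Strip custom_partitioning callback pointers from the IR before it is hashed
# into the compilation-cache key.
jax.config.update("jax_remove_custom_partitioning_ptr_from_cache_key", True)
```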
diff --git a/jax/_src/core.py b/jax/_src/core.py
index ef3ace2e0e31..057a79925e2e 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -283,9 +283,9 @@ def manager(self):
def __repr__(self):
return (
- f"JaxprEqnContext(compute_type={self.compute_type},"
- f"threefry_partitionable={self.threefry_partitionable}),"
- f"xla_metadata={self.xla_metadata}"
+ f"JaxprEqnContext(compute_type={self.compute_type}, "
+ f"threefry_partitionable={self.threefry_partitionable}, "
+ f"xla_metadata={self.xla_metadata})"
)
@@ -343,8 +343,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
- xla_metadata_lib.current_xla_metadata(),
- )
+ xla_metadata_lib.current_xla_metadata())
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
@@ -1415,9 +1414,13 @@ def definitely_equal(x, y):
class AbstractValue:
__slots__: list[str] = []
- def at_least_vspace(self):
+ def to_tangent_aval(self):
raise NotImplementedError("must override")
+ # TODO(dougalm): deprecate this alias
+ def at_least_vspace(self):
+ return self.to_tangent_aval()
+
def __repr__(self):
try:
kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items())
@@ -1525,6 +1528,12 @@ def get_aval(x):
else:
return concrete_aval(x)
+def get_type(x):
+ aval = get_aval(x)
+ if isinstance(aval, ConcreteArray):
+ return raise_to_shaped(aval)
+ else:
+ return aval
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
@@ -1648,7 +1657,7 @@ def __repr__(self):
_oct = concretization_function_error(oct)
_index = concretization_function_error(operator.index)
- def at_least_vspace(self) -> AbstractValue:
+ def to_tangent_aval(self) -> AbstractValue:
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
@@ -1745,20 +1754,15 @@ def _invalid_shape_error(shape: Shape, context: str=""):
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding'] # inherits slots from parent
array_abstraction_level = 2
- named_shape = {} # type: ignore
- def __init__(self, shape, dtype, weak_type=False, named_shape=None,
- sharding=None):
- del named_shape # unused, vestigial
+ def __init__(self, shape, dtype, weak_type=False, sharding=None):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
if config.sharding_in_types.value:
self.sharding = sharding
- def update(self, shape=None, dtype=None, weak_type=None, named_shape=None,
- sharding=None):
- del named_shape # unused, vestigial
+ def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
if shape is None:
shape = self.shape
if dtype is None:
@@ -1792,7 +1796,7 @@ def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type,
getattr(self, 'sharding', None)))
- def at_least_vspace(self):
+ def to_tangent_aval(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
@@ -1951,7 +1955,7 @@ def join(self, other):
else:
raise TypeError(self, other)
- def at_least_vspace(self):
+ def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
@@ -2082,7 +2086,7 @@ def join(self, other):
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
- def at_least_vspace(self): return self
+ def to_tangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 64a37b782358..05ede08d219c 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -67,7 +67,7 @@ def _sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)
def _zeros_like_pytree(x):
- return tree_map(Zero.from_value, x)
+ return tree_map(Zero.from_primal_value, x)
_stop_gradient = partial(
tree_map,
@@ -327,24 +327,27 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
- # TODO(mattjj): compare primals' tangent types to tangent objects' types
- primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
- for x in primals_out]
+ primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
+ expected_tangent_avals_out = [
+ raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
+ for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
- if primal_avals_out != tangent_avals_out:
- if len(primal_avals_out) == 1:
- (av1,), (av2,) = primal_avals_out, tangent_avals_out
+ if expected_tangent_avals_out != tangent_avals_out:
+ if len(expected_tangent_avals_out) == 1:
+ (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out
msg = ("Custom JVP rule must produce primal and tangent outputs with "
- "equal shapes and dtypes, but got {} and {} respectively.")
- raise TypeError(msg.format(av1.str_short(), av2.str_short()))
+ "corresponding shapes and dtypes. Expected {} (tangent type of {}) but got {}.")
+ raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short()))
else:
msg = ("Custom JVP rule must produce primal and tangent outputs with "
- "equal shapes and dtypes, but got:\n{}")
+ "corresponding shapes and dtypes, but got:\n{}")
disagreements = (
- f" primal {av1.str_short()} for tangent {av2.str_short()}"
- for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
+ f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
+ for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
+ if av_et != av_t)
+
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, (out_tree, primal_avals)
@@ -392,7 +395,7 @@ def jvp(*xs):
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
nz_out_tangents_ = iter(nz_out_tangents)
- out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
+ out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
if z else next(nz_out_tangents_)
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
@@ -780,10 +783,10 @@ def append(x, d):
raise TypeError(msg.format(in_tree2, in_tree)) from None
results = []
for kp, a, ct in zip(keypaths, in_avals, cts_in_flat):
- if ct is zero or a != a.at_least_vspace():
- results.append(Zero(a.at_least_vspace()))
+ if ct is zero or a != a.to_tangent_aval():
+ results.append(Zero(a.to_tangent_aval()))
elif type(ct) is SymbolicZero:
- if not core.typecompat(a.at_least_vspace(), a_ := ct.aval):
+ if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
"that does not match the corresponding input tangent shape/dtype: "
f"at output{keystr(kp)} the SymbolicZero had shape/dtype "
@@ -794,7 +797,7 @@ def append(x, d):
raise ValueError(msg)
results.append(Zero(ct.aval))
else:
- if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
+ if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct))
and not (_temporary_dtype_exception(a, a_) or
_temporary_shape_exception(a, a_))):
msg = ("Custom VJP bwd rule must produce an output with the same "
@@ -908,16 +911,12 @@ def _custom_vjp_call_jaxpr_jvp(
_, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
- avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
+ avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
- # Cast float0 to zeros with the primal dtype because custom vjp rules don't
- # currently handle float0s
- args_dot = map(ad.replace_float0s, args, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
- tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
@@ -1039,7 +1038,7 @@ def fwd(*args, **kwargs):
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
- ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
+ ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
@@ -1153,7 +1152,7 @@ def _maybe_perturbed(x: Any) -> bool:
elif isinstance(x, pe.DynamicJaxprTracer):
# If x is a DynamicJaxprTracer then we're staging out; differentiation could
# happen later, but some types always have trivial tangents.
- vspace = x.aval.at_least_vspace()
+ vspace = x.aval.to_tangent_aval()
return not (vspace is core.abstract_token or
getattr(vspace, 'dtype', None) == dtypes.float0)
elif not isinstance(x, ad.JVPTracer):
@@ -1176,7 +1175,12 @@ def converted_fun(*args_hconsts):
args, hoisted_consts = split_list(args_hconsts, [num_args])
consts = merge(closure_consts, hoisted_consts)
all_args, in_tree2 = tree_flatten(tuple(args))
- assert in_tree == in_tree2
+ if in_tree != in_tree2:
+ msg = ("The inputs to the closure produced by closure_convert must have "
+ "the same Pytree structure as the example arguments passed when "
+ f"closure_convert was called. Expected {in_tree}, but got "
+ f"{in_tree2}")
+ raise TypeError(msg)
out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
return tree_unflatten(out_tree, out_flat)
@@ -1420,7 +1424,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd):
@fun.defjvp
def jvp(primals, tangents):
outs, residuals = fwd(*primals)
- tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs)
+ tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs)
tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types))
tan_fn.def_transpose(bwd)
return outs, tan_fn(tan_out_types, residuals, tangents)
diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py
index f6b0a81baf92..1efeed73cbc8 100644
--- a/jax/_src/debugger/core.py
+++ b/jax/_src/debugger/core.py
@@ -112,6 +112,11 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame:
# then we subtract it off from the `lineno` and don't need to subtract 1
# since both start and lineno are 1-indexed.
offset = frame_info.lineno - max(start, 1)
+ if offset >= len(source):
+        # Sometimes we don't get a valid source/offset pair; this seems to
+        # happen when code uses eval(). If that happens, give up.
+ source = []
+ offset = None
except OSError:
source = []
offset = None
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 3e7082ab10ec..3373496940e2 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -46,6 +46,7 @@
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
+from jax._src.api_util import shaped_abstractify
from jax._src.state import discharge as state_discharge
logger = logging.getLogger(__name__)
@@ -256,12 +257,29 @@ def debug_callback(callback: Callable[..., None], *args: Any,
raise TypeError("first argument to jax.debug.callback must be callable, "
f"but got an object of type {type(callback)}")
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
- effect = ordered_debug_effect if ordered else debug_effect
- def _flat_callback(*flat_args):
- args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
+ static_args, dyn_args = {}, []
+ for i, a in enumerate(flat_args):
+ try:
+ shaped_abstractify(a)
+ dyn_args.append(a)
+ except (AssertionError, TypeError):
+ static_args[i] = a
+
+ def _flat_callback(*dyn_args):
+ all_args = [None] * (len(static_args) + len(dyn_args))
+ di = iter(dyn_args)
+ for i in range(len(all_args)):
+ if i in static_args:
+ all_args[i] = static_args[i]
+ else:
+ all_args[i] = next(di)
+ assert next(di, None) is None
+ args, kwargs = tree_util.tree_unflatten(in_tree, all_args)
callback(*args, **kwargs)
return ()
- debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect)
+
+ effect = ordered_debug_effect if ordered else debug_effect
+ debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect)
class _DebugPrintFormatChecker(string.Formatter):
diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py
index 4e01c88afd1f..5f1d132bcbb3 100644
--- a/jax/_src/deprecations.py
+++ b/jax/_src/deprecations.py
@@ -121,6 +121,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
# Register a number of deprecations: we do this here to ensure they're
# always registered by the time `accelerate` and `is_acelerated` are called.
+register('jax-aval-named-shape')
register('jax-dlpack-import-legacy')
register("jax-numpy-astype-complex-to-real")
register("jax-numpy-array-none")
@@ -131,3 +132,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-numpy-quantile-interpolation')
+register('jax-numpy-trimzeros-not-1d-array')
diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py
index 81f4180a1c12..d76b80ad3a89 100644
--- a/jax/_src/dtypes.py
+++ b/jax/_src/dtypes.py
@@ -784,7 +784,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
uint2,
uint4,
]
- if np_dtype.kind not in "biufc" and not is_custom_dtype:
+ if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index d0159f7a4334..7f7773acbd39 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -1127,7 +1127,7 @@ def flattened_primal_fun_jax(*args_flat):
vjp_in_avals = list(
itertools.chain(in_avals,
- map(lambda a: a.at_least_vspace(), out_avals)))
+ map(lambda a: a.to_tangent_aval(), out_avals)))
if apply_jit:
assert device_assignment is not None
diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py
index 7ab5de297752..81ef07a71b19 100644
--- a/jax/_src/hardware_utils.py
+++ b/jax/_src/hardware_utils.py
@@ -32,13 +32,6 @@
'0x006f',
]
-_TPU_ENHANCED_BARRIER_SUPPORTED = [
- # TPU v2, v3
- '0x0027',
- # TPU v4
- '0x005e',
-]
-
_NVIDIA_GPU_DEVICES = [
'/dev/nvidia0',
'/dev/nvidiactl', # Docker/Kubernetes
@@ -62,12 +55,6 @@ def num_available_tpu_chips_and_device_id():
return num_chips, device_id
-def tpu_enhanced_barrier_supported() -> bool:
- """Returns if tpu_enhanced_barrier flag is supported on this TPU version."""
- _, device_id = num_available_tpu_chips_and_device_id()
- return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED
-
-
def has_visible_nvidia_gpu() -> bool:
"""True if there's a visible nvidia gpu available on device, False otherwise."""
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index f1b25cf96a95..f1f46a5c18f7 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -57,7 +57,7 @@ def _update_annotation(
# Implicit arguments never have tangents, so generate the tangent part of the
# type annotation from explicit arguments only.
explicit_avals = [aval for aval, explicit in orig_type if explicit]
- tan_types = [(aval.at_least_vspace(), True)
+ tan_types = [(aval.to_tangent_aval(), True)
for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz]
return lu.annotate(f, (*orig_type, *tan_types))
@@ -72,7 +72,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
@lu.transformation
def jvpfun(instantiate, transform_stack, primals, tangents):
- tangents = [Zero.from_value(t) if not isinstance(t, Zero)
+ tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
@@ -124,7 +124,7 @@ def linearize(traceable, *primals, **kwargs):
jvpfun, aux = jvp(traceable, has_aux=True)
in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
- + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
+ + tuple(pe.PartialVal.unknown(get_aval(p).to_tangent_aval())
for p in primals))
_, in_tree = tree_flatten(((primals, primals), {}))
jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
@@ -166,18 +166,6 @@ def unpair_pval(pval):
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)
-def replace_float0s(primal, tangent):
- if dtype(tangent) == float0:
- return zeros_like_jaxval(primal)
- else:
- return tangent
-
-def recast_to_float0(primal, tangent):
- if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
- return Zero(get_aval(primal).at_least_vspace())
- else:
- return tangent
-
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
# errors if you will)
@@ -203,7 +191,7 @@ def write_cotangent(prim, v, ct):
# assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)
def read_cotangent(v):
- return ct_env.pop(v, Zero(v.aval.at_least_vspace()))
+ return ct_env.pop(v, Zero(v.aval.to_tangent_aval()))
def read_primal(v):
if type(v) is Literal:
@@ -295,11 +283,11 @@ def nonzero_tangent_outputs(*args, **kwargs):
class JVPTrace(Trace):
def pure(self, val):
- tangent_zero = Zero(get_aval(val).at_least_vspace())
+ tangent_zero = Zero.from_primal_value(val)
return JVPTracer(self, val, tangent_zero)
def lift(self, val):
- tangent_zero = Zero(get_aval(val).at_least_vspace())
+ tangent_zero = Zero.from_primal_value(val)
return JVPTracer(self, val, tangent_zero)
def sublift(self, val):
@@ -343,7 +331,7 @@ def new_out_axes_thunk():
result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz),
*args, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
- tangent_out = [Zero(get_aval(p).at_least_vspace()) if t is None else t
+ tangent_out = [Zero.from_primal_value(p) if t is None else t
for p, t in zip(primal_out, tangent_out)]
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
@@ -374,13 +362,11 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros):
primals_in = map(core.full_lower, primals_in)
if not symbolic_zeros:
tangents_in = map(instantiate_zeros, tangents_in)
- tangents_in = map(replace_float0s, primals_in, tangents_in)
else:
tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
- tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
def post_process_custom_jvp_call(self, out_tracers, _):
@@ -398,14 +384,13 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
res_and_primals_out = fwd.call_wrapped(*fwd_in)
_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
- avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
+ avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
- tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
def post_process_custom_vjp_call(self, out_tracers, _):
@@ -505,8 +490,8 @@ def linear_jvp(primitive, primals, tangents, **params):
val_out = primitive.bind(*primals, **params)
if all(type(tangent) is Zero for tangent in tangents):
if primitive.multiple_results:
- return val_out, map(Zero.from_value, val_out)
- return val_out, Zero.from_value(val_out)
+ return val_out, map(Zero.from_primal_value, val_out)
+ return val_out, Zero.from_primal_value(val_out)
else:
tangents = map(instantiate_zeros, tangents)
return val_out, primitive.bind(*tangents, **params)
@@ -533,7 +518,7 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params):
val_out = primitive.bind(*primals, **params)
tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
if rule is not None and type(t) is not Zero]
- return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))
+ return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out))
def defjvp2(primitive, *jvprules):
assert isinstance(primitive, Primitive)
@@ -545,7 +530,7 @@ def standard_jvp2(jvprules, primitive, primals, tangents, **params):
tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents)
if rule is not None and type(t) is not Zero)
tangents_out = list(tangents_out)
- return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))
+ return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out))
def add_tangents(x, y):
if type(x) is Zero:
@@ -580,7 +565,7 @@ def defjvp_zero(primitive):
def zero_jvp(primitive, primals, tangents, **params):
r = primitive.bind(*primals, **params)
- return r, Zero.from_value(r)
+ return r, Zero.from_primal_value(r)
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
@@ -591,7 +576,7 @@ def instantiate_zeros(tangent):
@lu.transformation_with_aux
def traceable(in_tree, *primals_and_tangents):
primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
- tangents = [Zero(get_aval(p).at_least_vspace()) if t is None else t
+ tangents = [Zero.from_primal_value(p) if t is None else t
for p, t in zip(primals, tangents)]
primals_out, tangents_out = yield (primals, tangents), {}
tangents_out = [None if type(t) is Zero else t for t in tangents_out]
@@ -695,7 +680,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
- tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
+ tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
@@ -705,7 +690,7 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents):
num_primals = len(nonzeros)
primals = list(primals_and_nztangents[:num_primals])
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
- tangents = [next(nonzero_tangents) if nz else Zero.from_value(p)
+ tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p)
for p, nz in zip(primals, nonzeros)]
primals_out, tangents_out = yield (primals, tangents), {}
out_nonzeros = [type(t) is not Zero for t in tangents_out]
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 5bb3e204ced0..6bc6539b9262 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -168,11 +168,6 @@ def new_instantiated_literal(self, val) -> JaxprTracer:
def new_instantiated_const(self, val) -> JaxprTracer:
aval = get_aval(val)
- if isinstance(aval, DShapedArray):
- shape = [self.new_instantiated_const(d)
- if isinstance(d, Tracer) and d._trace.level < self.level else d
- for d in aval.shape]
- aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val))
def new_arg(self, pval: PartialVal) -> JaxprTracer:
@@ -258,15 +253,9 @@ def process_call(self, primitive, f, tracers, params):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
- if not config.dynamic_shapes.value:
- f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
- f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
- tuple(in_avals))
- else:
- if f.in_type is None:
- f = lu.annotate(f, tuple((a, True) for a in in_avals))
- f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns),
- f.in_type, False)
+ f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False)
+ f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns),
+ tuple(in_avals))
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
const_params = update_params(params, in_knowns, 0)
@@ -569,92 +558,6 @@ def partial_eval_wrapper_nounits(
out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env)
-@lu.transformation_with_aux
-def trace_to_subjaxpr_nounits_dyn(
- main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType,
- instantiate: bool | Sequence[bool],
- *in_consts: Any):
- trace = main.with_cur_sublevel()
- in_avals, which_explicit = unzip2(in_type)
-
- # To form input tracers from in_type, we need to first build ConstVar tracers
- # for all axis sizes, so that we can then use those tracers in the shapes of
- # avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly
- # type agreement checking via get_referent.
- in_consts_full: list[JaxprTracer | None] = [None] * len(in_type)
- in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns)
- for idx, (aval, explicit) in enumerate(in_type):
- if explicit and next(in_knowns_iter):
- constval = next(in_consts_iter)
- if isinstance(aval, DShapedArray):
- for i, d in enumerate(aval.shape):
- if isinstance(d, DBIdx):
- if in_consts_full[d.val] is None:
- in_consts_full[d.val] = \
- JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]),
- ConstVar(constval.shape[i]))
- assert core.same_referent(constval.shape[i], in_consts_full[d.val])
- shape = [in_consts_full[d.val] if type(d) is DBIdx else d
- for d in aval.shape]
- aval = aval.update(shape=tuple(shape))
- in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval),
- ConstVar(constval))
- # Check that we covered all axis sizes with ConstVar tracers.
- for idx, (aval, explicit) in enumerate(in_type):
- if not explicit: assert in_consts_full[idx] is not None
- if isinstance(aval, DShapedArray):
- assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None
- for d in aval.shape)
-
- # Next, build tracers for all unknown inputs, using the in_consts_full list
- # for axis size tracers when necessary.
- in_tracers = []
- in_knowns_iter = iter(in_knowns)
- for aval, explicit in in_type:
- if explicit and not next(in_knowns_iter):
- if isinstance(aval, DShapedArray):
- shape = [in_consts_full[d.val] if type(d) is DBIdx else d
- for d in aval.shape]
- aval = aval.update(shape=tuple(shape))
- tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding())
- in_tracers.append(tracer)
-
- # Merge in_consts and in_tracers and call wrapped fn with explicit arguments.
- in_args = merge_lists(in_knowns, in_tracers, in_consts)
- ans = yield in_args, {}
-
- # Instantiate outputs and build jaxpr.
- if isinstance(instantiate, bool):
- instantiate = [instantiate] * len(ans)
- out_tracers = map(trace.full_raise, map(core.full_lower, ans))
- out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t
- for inst, t in zip(instantiate, out_tracers)]
-
- # Collect known outputs.
- out_knowns: list[bool] = [t.is_known() for t in out_tracers]
- out_consts: list[Any] = [t.pval.get_known() for t in out_tracers
- if t.is_known()]
-
- # Build the jaxpr.
- out_tracers = [t for t in out_tracers if not t.is_known()]
- jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers)
- out_avals = [v.aval for v in jaxpr.outvars]
- idx_map = {v: InDBIdx(i)
- for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
- out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore
- if type(a) is DShapedArray else a, True) for a in out_avals]
-
- # Which residuals are just forwarded inputs? Check obj id, then prune.
- id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore
- if c is not None}
- fwds: list[int | None] = [id_map.get(id(c)) for c in res]
- res = tuple(c for c, fwd in zip(res, fwds) if fwd is None)
-
- del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \
- in_tracers, in_args, ans, out_tracers, out_avals
- yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env)
-
-
custom_partial_eval_rules: dict[Primitive, Callable] = {}
call_partial_eval_rules: dict[Primitive, Callable] = {}
call_param_updaters: dict[Primitive, Callable] = {}
@@ -1536,6 +1439,18 @@ def _prune_closed_jaxpr_outputs(
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: bool | Sequence[bool] = False,
) -> tuple[Jaxpr, list[bool]]:
+ """Runs dead-code elementation on a given jaxpr.
+
+ Args:
+ jaxpr: The jaxpr to DCE.
+ used_outputs: A list of bools indicating which outputs are used.
+ instantiate: A bool or a list of bools indicating which inputs should be
+ considered used, regardless of whether they are actually used in a jaxpr.
+ If a bool, the same value is used for all inputs.
+
+ Returns:
+ A tuple of ``(new_jaxpr, used_inputs)``.
+ """
if type(instantiate) is bool:
instantiate = (instantiate,) * len(jaxpr.invars)
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))
@@ -1545,7 +1460,7 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: bool | Sequence[bool] = False,
) -> tuple[Jaxpr, list[bool], list[bool]]:
jaxpr_ = convert_constvars_jaxpr(jaxpr)
- new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
+ new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs, instantiate)
used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
if sum(used_consts):
new_jaxpr = convert_invars_to_constvars(new_jaxpr, sum(used_consts))
@@ -2050,11 +1965,9 @@ def process_primitive(self, primitive, tracers, params):
def default_process_primitive(self, primitive, tracers, params):
avals = [t.aval for t in tracers]
out_avals, effects = primitive.abstract_eval(*avals, **params)
- # == serve as a "not xor" here.
- if not (isinstance(out_avals, (tuple,list)) == primitive.multiple_results):
- raise ValueError(f"{primitive}.abstract_eval() method should return"
- f" a tuple or a list if {primitive}.multiple_results"
- " is true. Otherwise it shouldn't.")
+ if isinstance(out_avals, (tuple, list)) != primitive.multiple_results:
+ raise ValueError(f"{primitive}.abstract_eval() method should return "
+ f"a tuple or a list iff {primitive}.multiple_results.")
out_avals = [out_avals] if not primitive.multiple_results else out_avals
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
@@ -2148,6 +2061,7 @@ def post_process_map(self, map_primitive, out_tracers, params):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
in_avals = [t.aval for t in tracers]
+ in_tangent_avals = [t.to_tangent_aval() for t in in_avals]
with core.new_sublevel():
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
@@ -2156,7 +2070,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
@_memoize
def jvp_jaxpr_thunk(*in_zeros):
for store in jvp.stores: store and store.reset()
- nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals)
+ nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals)
jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals))
in_avals_ = (*in_avals, *nz_tangent_avals)
jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
@@ -2818,8 +2732,7 @@ def inline_jaxpr_into_trace(
outvars = [Var('', v.aval) for v in eqn.outvars]
src_ = (src if not eqn.source_info.name_stack else
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
- trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive,
- eqn.params, eqn.effects, src_))
+ trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore
map(env.setdefault, eqn.outvars, outvars)
tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 882f71d58671..944e20fa7faa 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -22,6 +22,7 @@
from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
+import functools
import itertools as it
import logging
import math
@@ -89,6 +90,7 @@ class WeakRefList(list):
logger = logging.getLogger(__name__)
Index = Union[int, slice, tuple[Union[int, slice], ...]]
+PyTreeDef = tree_util.PyTreeDef
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
@@ -2905,6 +2907,34 @@ class MeshExecutableFastpathData(NamedTuple):
in_device_local_layouts: Sequence[DeviceLocalLayout | None]
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class JitGlobalCppCacheKeys:
+ donate_argnums: tuple[int, ...] | None = None
+ donate_argnames: tuple[str, ...] | None = None
+ device: xc.Device | None = None
+ backend: str | None = None
+ in_shardings_treedef: PyTreeDef | None = None
+ in_shardings_leaves: tuple[Any, ...] | None = None
+ out_shardings_treedef: PyTreeDef | None = None
+ out_shardings_leaves: tuple[Any, ...] | None = None
+ in_layouts_treedef: PyTreeDef | None = None
+ in_layouts_leaves: tuple[Any, ...] | None = None
+ out_layouts_treedef: PyTreeDef | None = None
+ out_layouts_leaves: tuple[Any, ...] | None = None
+ use_resource_env: bool = False
+
+ @functools.cached_property
+ def contains_explicit_attributes(self):
+ return (self.donate_argnums is not None or
+ self.donate_argnames is not None or
+ self.device is not None or
+ self.backend is not None or
+ any(not is_unspecified(i) for i in self.in_shardings_leaves) or
+ any(not is_unspecified(o) for o in self.out_shardings_leaves) or
+ any(i is not None for i in self.in_layouts_leaves) or
+ any(o is not None for o in self.out_layouts_leaves))
+
+
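An illustrative sketch of how the new dataclass distinguishes the "no explicit attributes" case that can share the global C++ jit cache (field values here are hypothetical, not taken from the change):

    import dataclasses  # already imported in this module; repeated for a standalone sketch

    keys = JitGlobalCppCacheKeys(
        in_shardings_treedef=None, in_shardings_leaves=(),
        out_shardings_treedef=None, out_shardings_leaves=(),
        in_layouts_treedef=None, in_layouts_leaves=(),
        out_layouts_treedef=None, out_layouts_leaves=())
    assert not keys.contains_explicit_attributes       # nothing explicit was requested

    keys_gpu = dataclasses.replace(keys, backend="gpu")  # "gpu" is an illustrative value
    assert keys_gpu.contains_explicit_attributes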
def reflatten_outputs_for_dispatch(out_tree, out_flat):
# We arrive at dispatch having flattened according to the default
# pytree registry, but we want to re-flatten according to our
@@ -3018,16 +3048,17 @@ def aot_cache_miss(*args, **kwargs):
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry
- return xc._xla.pjit(
- self.unsafe_call.name, None, aot_cache_miss, [], [], [],
- tree_util.dispatch_registry, cc_shard_arg)
+ if xla_extension_version >= 286:
+ return xc._xla.pjit(
+ self.unsafe_call.name, None, aot_cache_miss, [], [],
+ JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
+ else:
+ return xc._xla.pjit(
+ self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+ tree_util.dispatch_registry, cc_shard_arg)
-if xla_extension_version < 282:
- def cc_shard_arg(x, sharding):
- return shard_args([sharding], [None], [x])[0]
-else:
- def cc_shard_arg(x, sharding, layout): # type: ignore
- return shard_args([sharding], [layout], [x])[0]
+def cc_shard_arg(x, sharding, layout):
+ return shard_args([sharding], [layout], [x])[0]
def check_arg_avals_for_call(ref_avals, arg_avals,
diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py
index f2dbd8d4fa0e..0e037ec774b5 100644
--- a/jax/_src/lax/ann.py
+++ b/jax/_src/lax/ann.py
@@ -373,7 +373,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
reduction_input_size_override,
aggregate_to_topk)
if type(tangent) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
arg_shape = arg_out.shape
rank = len(arg_shape)
@@ -385,7 +385,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
idx = tuple(
arg_out if i == reduction_dimension else iotas[i] for i in range(rank))
tangent_out = tangent[idx]
- return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))
+ return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out))
approx_top_k_p = core.Primitive('approx_top_k')
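The rename from `Zero.from_value` to `Zero.from_primal_value` recurs throughout this change. As a sketch of what the helper produces (assuming the internal `jax._src.ad_util` import path), it builds a symbolic zero whose aval is the tangent aval of the given primal:

    import jax.numpy as jnp
    from jax._src import ad_util

    z = ad_util.Zero.from_primal_value(jnp.ones((2, 3)))
    print(type(z).__name__, z.aval.shape)  # Zero (2, 3)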
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index b96f9e8c6e40..4cb38d28c36f 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -434,7 +434,7 @@ def _cond_jvp(primals, tangents, branches):
out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp)
out_primals, out_tangents = split_list(out, [len(out_nz)])
out_tangents_iter = iter(out_tangents)
- out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+ out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index 61b9a24644ce..21b522b3d8bb 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -340,7 +340,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear,
# into outputs as well. We don't care about these in AD so we throw them out.
out_primals, out_tangents = split_list(out_flat, [len(primals)])
out_tangents_iter = iter(out_tangents)
- out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+ out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_primals, nonzero_tangents)]
return out_primals, out_tangents
ad.primitive_jvps[for_p] = _for_jvp
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 828728ebdbd2..41d809f8d688 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -547,7 +547,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
primals_out = carry + ys
tangents_out_iter = iter(carry_dot + ys_dot)
- tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
+ tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(primals_out, nonzeros_out)]
return primals_out, tangents_out
@@ -1518,7 +1518,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
out_carry, out_carry_dot = split_list(out, [num_carry])
out_tangents_iter = iter(out_carry_dot)
- out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+ out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_carry, nonzeros_out)]
return out_carry, out_tangents
diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py
index 21105e20aaf8..4e0f5086b121 100644
--- a/jax/_src/lax/control_flow/solves.py
+++ b/jax/_src/lax/control_flow/solves.py
@@ -316,7 +316,7 @@ def _tangent_linear_map(func, params, params_dot, *x):
this function computes ``∂A @ x``.
"""
assert any(type(p) is not ad_util.Zero for p in params_dot)
- zeros = _map(ad_util.Zero.from_value, x)
+ zeros = _map(ad_util.Zero.from_primal_value, x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent
@@ -352,7 +352,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
# split into x tangents and aux tangents (these become zero)
dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])
- daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)
+ daux_leaves = _map(ad_util.Zero.from_primal_value, daux_leaves)
x_dot = dx_leaves + daux_leaves
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 8d2c24d6e64c..48af9c64ffc9 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -62,6 +62,7 @@
standard_multi_result_abstract_eval, standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
+from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@@ -2014,7 +2015,15 @@ def _tan_impl(x):
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
-mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
+# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this
+# lowering is mostly supported, but it fails on export or with the PJRT plugin
+# because those modes target an older StableHLO version, and the
+# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't
+# included in the 0.4.33 release.
+if jaxlib_version <= (0, 4, 33):
+ mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
+else:
+ mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@@ -2300,7 +2309,7 @@ def _add_jvp(primals, tangents):
xdot, ydot = tangents
primal_out = add(x, y)
if type(xdot) is type(ydot) is ad_util.Zero:
- return primal_out, ad_util.Zero.from_value(primal_out)
+ return primal_out, ad_util.Zero.from_primal_value(primal_out)
if type(xdot) is ad_util.Zero:
return primal_out, _maybe_broadcast(primal_out.shape, ydot)
elif type(ydot) is ad_util.Zero:
@@ -2331,7 +2340,7 @@ def _sub_jvp(primals, tangents):
xdot, ydot = tangents
primal_out = sub(x, y)
if type(xdot) is type(ydot) is ad_util.Zero:
- return primal_out, ad_util.Zero.from_value(primal_out)
+ return primal_out, ad_util.Zero.from_primal_value(primal_out)
if type(xdot) is ad_util.Zero:
return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot))
elif type(ydot) is ad_util.Zero:
@@ -3355,7 +3364,7 @@ def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions
y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
if type(operand_dot) is ad_util.Zero:
- y_dot = ad_util.Zero.from_value(y)
+ y_dot = ad_util.Zero.from_primal_value(y)
else:
y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
@@ -4525,7 +4534,7 @@ def _top_k_jvp(primals, tangents, *, k):
tangent, = tangents
primals_out = top_k(operand, k)
if type(tangent) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(primals_out[0])
+ tangent_out = ad_util.Zero.from_primal_value(primals_out[0])
else:
_, k_idxs = primals_out
idx_shape = k_idxs.shape
@@ -4544,7 +4553,7 @@ def _top_k_jvp(primals, tangents, *, k):
collapsed_slice_dims=tuple(range(rank)),
start_index_map=tuple(range(rank)))
tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes)
- return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1]))
+ return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1]))
def _top_k_batch_rule(batched_args, batch_dims, *, k):
operand, = batched_args
@@ -4580,7 +4589,7 @@ def _top_k_lower(ctx, operand, k):
def _stop_gradient_jvp_rule(primals, tangents):
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
x, = primals
- return stop_gradient(x), ad_util.Zero.from_value(x)
+ return stop_gradient(x), ad_util.Zero.from_primal_value(x)
def _stop_gradient_batch_rule(batched_args, batch_dims):
x, = batched_args
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 0cc0e774af53..ec0a075dae1b 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand):
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
- # TODO(b/344892332): Remove the check after the compatibility period.
- if jaxlib_version < (0, 4, 31):
- ctx_arg = ()
- else:
- ctx_arg = (ctx,)
+ ctx_arg = (ctx,)
result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand,
lower=True, a_shape_vals=op_shape_vals)
@@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):
def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
# TODO(b/360781533): Remove guard after 3 week forward compatibility period.
- if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
+ if ctx.is_forward_compat():
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
@@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
- # TODO(b/344892332): Remove the conditional after the compatibility period.
- ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ ctx_args = (ctx,)
w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
@@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering(
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
cpu_args = []
if platform == "cpu":
- # TODO(b/344892332): Remove the conditional after the compatibility period.
- ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ ctx_args = (ctx,)
cpu_args.extend(ctx_args)
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
a_shape_vals=op_shape_vals, lower=lower)
@@ -1493,8 +1487,8 @@ def _lu_jvp_rule(primals, tangents):
l_dot = l @ _tril(lau, -1)
u_dot = _triu(lau) @ u
lu_dot = l_dot + u_dot
- return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
- ad_util.Zero.from_value(permutation))
+ return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_primal_value(pivots),
+ ad_util.Zero.from_primal_value(permutation))
def _lu_batching_rule(batched_args, batch_dims):
@@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str,
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
m = operand_aval.shape[-2]
- # TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum
- # version and the forward compat flag after the 3 week compatibility window.
- if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
+ # TODO(b/357034884): Remove the forward compat gate after the 3 week
+ # compatibility window.
+ if ctx.is_forward_compat():
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for lu on CPU and GPU is "
@@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
- # TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
- (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+ (ctx,) if platform == "cpu" else ()
)
a_out, taus, *maybe_info_geqrf = geqrf_impl(
*ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
@@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
f"on GPU is not implemented; b/261671778; {a_aval.shape}")
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
else:
- # TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
- (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+ (ctx,) if platform == "cpu" else ()
)
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
@@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
- # TODO(b/344892332): Remove the conditional after the compatibility period.
- ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ ctx_args = (ctx,)
s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 2a3a63e89a35..5ed1945ecb96 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1362,7 +1362,7 @@ def _dynamic_update_slice_jvp(primals, tangents):
g_operand, g_update = tangents[:2]
val_out = dynamic_update_slice_p.bind(operand, update, *start_indices)
if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
g_operand = ad.instantiate_zeros(g_operand)
g_update = ad.instantiate_zeros(g_update)
@@ -2000,7 +2000,7 @@ def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
@@ -2180,7 +2180,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
unique_indices=unique_indices, mode=mode)
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
@@ -2294,7 +2294,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
update_consts=update_consts, dimension_numbers=dnums,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
- return val_out, ad_util.Zero.from_value(val_out)
+ return val_out, ad_util.Zero.from_primal_value(val_out)
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
@@ -2384,7 +2384,7 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
if not unique_indices:
- raise NotImplementedError("scatter transpose is only implemented where"
+ raise NotImplementedError("scatter transpose is only implemented where "
"unique_indices=True")
assert not ad.is_undefined_primal(indices)
if ad.is_undefined_primal(updates):
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index dd8e664a095a..089a77de2949 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -707,7 +707,7 @@ def _select_and_scatter_add_jvp(
padding)
del g_operand
if type(g_source) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
tangent_out = _select_and_scatter_add(
g_source, operand, select_prim, window_dimensions,
@@ -952,7 +952,7 @@ def _select_and_gather_add_jvp(
padding, base_dilation, window_dilation)
del g_operand
if type(g_source) is ad_util.Zero:
- tangent_out = ad_util.Zero.from_value(val_out)
+ tangent_out = ad_util.Zero.from_primal_value(val_out)
else:
tangent_out = _select_and_gather_add(
g_source, operand, select_prim, window_dimensions,
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index b30286b36a76..20234b678172 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -217,6 +217,8 @@ def __setattr__(self, name, value):
super().__setattr__(name, value)
def __enter__(self):
+ if jax_config.disallow_mesh_context_manager.value:
+ raise RuntimeError("Mesh context manager is disabled.")
new_env = thread_resources.stack[-1].with_mesh(self)
thread_resources.stack.append(new_env)
thread_resources.env = new_env
diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index a5b5aaf31799..c1f4831e5ec0 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -785,6 +785,14 @@ def _get_causal_mask(T, S):
mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return mask[None, None, :, :]
+def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
+ query_pos = jnp.array(range(T))
+ key_pos = jnp.array(range(S))
+ left_window, right_window = local_window_size
+ left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
+ right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
+ return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
+
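An illustrative check (not part of the change itself): a symmetric one-token window produces a tridiagonal band of allowed attention positions.

    mask = _get_window_mask(5, 5, (1, 1))[0, 0]
    # True exactly where |query_pos - key_pos| <= 1:
    # [[ True  True False False False]
    #  [ True  True  True False False]
    #  [False  True  True  True False]
    #  [False False  True  True  True]
    #  [False False False  True  True]]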
def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_mask = True
kv_mask = True
@@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen):
mask = q_indices < q_seqlen[:, None]
return mask[:, :, None, None]
-def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
+def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+ local_window_size):
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
return logits
@@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)
+ if local_window_size is not None:
+ mask = _get_window_mask(T, S, local_window_size)
+ combined_mask = jnp.logical_and(combined_mask, mask)
+
if q_seqlen is not None or kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)
@@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
return padded_logits
def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
- scale, q_seqlen, kv_seqlen):
+ scale, q_seqlen, kv_seqlen, local_window_size):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
@@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
if bias is not None:
logits = (logits + bias).astype(logits.dtype)
- padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
+ padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+ local_window_size)
# Softmax and it is always carried out in fp32.
padded_logits = padded_logits.astype(jnp.float32)
@@ -857,7 +871,8 @@ def _dot_product_attention_xla(
is_causal: bool,
scale: float,
q_seqlen: Array | None,
- kv_seqlen: Array | None):
+ kv_seqlen: Array | None,
+ local_window_size: tuple[int, int] | None):
B, T, N, H = query.shape
_, S, K, _ = key.shape
@@ -875,11 +890,13 @@ def _reshape_to_grouped(t):
return t
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
- vmapped_fn = jax.vmap(_dot_product_attention_core,
- in_axes=(3, None, None, 2, 2, None, None, None, None),
- out_axes=3)
+ vmapped_fn = jax.vmap(
+ _dot_product_attention_core,
+ in_axes=(3, None, None, 2, 2, None, None, None, None, None),
+ out_axes=3,
+ )
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
- q_seqlen, kv_seqlen)
+ q_seqlen, kv_seqlen, local_window_size)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded
@@ -894,6 +911,7 @@ def dot_product_attention(
is_causal: bool = False,
query_seq_lengths: ArrayLike | None = None,
key_value_seq_lengths: ArrayLike | None = None,
+ local_window_size: int | tuple[int, int] | None = None,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
r"""Scaled dot product attention function.
@@ -943,6 +961,12 @@ def dot_product_attention(
:code:`(B)`
key_value_seq_lengths: `int32` array of sequence lengths for key and value;
shape :code:`(B)`
+ local_window_size: Window sizes that restrict self attention to each
+ token's local window. If set, this specifies the (left_window_size,
+ right_window_size) for each token. E.g., if local_window_size == (3, 2)
+ and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
+ to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as
+ a symmetric window (window_size, window_size).
implementation: A string to control which implementation backend to use.
Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
@@ -969,6 +993,8 @@ def _ensure_4d(t):
query_seq_lengths = jnp.asarray(query_seq_lengths)
if key_value_seq_lengths is not None:
key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
+ if isinstance(local_window_size, int):
+ local_window_size = (local_window_size, local_window_size)
def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
dtype: DType | None, name: str) -> None:
@@ -1003,6 +1029,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
+ local_window_size=local_window_size,
)
case 'cudnn':
use_padding = (
@@ -1022,9 +1049,21 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
mask_type = MaskType.CAUSAL
elif use_padding:
mask_type = MaskType.PADDING
+ # cuDNN supports only a left window with an exclusive boundary when the
+ # causal mask is enabled.
+ sliding_window = None
+ if local_window_size is not None:
+ l_window, r_window = local_window_size
+ if r_window == 0 or mask_type == MaskType.CAUSAL:
+ sliding_window = l_window + 1
+ else:
+ raise ValueError(f"cuDNN doesn't support a right window ({r_window}) "
+ "when a causal mask is not used.")
+
out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
- key_value_seq_lengths, scale=scale_val, mask_type=mask_type
+ key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
+ sliding_window_length=sliding_window,
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
@@ -1033,6 +1072,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
+ local_window_size=local_window_size,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
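A usage sketch of the new parameter on the default (XLA) path; shapes and dtypes below are illustrative only:

    import jax
    import jax.numpy as jnp

    B, T, N, H = 1, 8, 2, 16
    q = k = v = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)
    # Each query token attends to 2 tokens on its left and 1 on its right.
    out = jax.nn.dot_product_attention(q, k, v, local_window_size=(2, 1))
    print(out.shape)  # (1, 8, 2, 16)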
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 806022c8a34b..a1601e9201fe 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1283,11 +1283,68 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
return ufuncs.degrees(result) if deg else result
-@util.implements(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
prepend: ArrayLike | None = None,
append: ArrayLike | None = None) -> Array:
+ """Calculate n-th order difference between array elements along a given axis.
+
+ JAX implementation of :func:`numpy.diff`.
+
+ The first order difference is computed by ``a[i+1] - a[i]``, and the n-th order
+ difference is computed ``n`` times recursively.
+
+ Args:
+ a: input array. Must have ``a.ndim >= 1``.
+ n: int, optional, default=1. Order of the difference. Specifies the number
+ of times the difference is computed. If n=0, no difference is computed and
+ input is returned as is.
+ axis: int, optional, default=-1. Specifies the axis along which the difference
+ is computed. By default the difference is computed along the last axis.
+ prepend: scalar or array, optional, default=None. Specifies the values to be
+ prepended along ``axis`` before computing the difference.
+ append: scalar or array, optional, default=None. Specifies the values to be
+ appended along ``axis`` before computing the difference.
+
+ Returns:
+ An array containing the n-th order difference between the elements of ``a``.
+
+ See also:
+ - :func:`jax.numpy.ediff1d`: Computes the differences between consecutive
+ elements of an array.
+ - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of
+ the array along a given axis.
+ - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array.
+
+ Examples:
+ ``jnp.diff`` computes the first order difference along ``axis``, by default.
+
+ >>> a = jnp.array([[1, 5, 2, 9],
+ ... [3, 8, 7, 4]])
+ >>> jnp.diff(a)
+ Array([[ 4, -3, 7],
+ [ 5, -1, -3]], dtype=int32)
+
+ When ``n = 2``, second order difference is computed along ``axis``.
+
+ >>> jnp.diff(a, n=2)
+ Array([[-7, 10],
+ [-6, -2]], dtype=int32)
+
+ When ``prepend = 2``, it is prepended to ``a`` along ``axis`` before computing
+ the difference.
+
+ >>> jnp.diff(a, prepend=2)
+ Array([[-1, 4, -3, 7],
+ [ 1, 5, -1, -3]], dtype=int32)
+
+ When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis``
+ before computing the difference.
+
+ >>> jnp.diff(a, append=jnp.array([[3],[1]]))
+ Array([[ 4, -3, 7, -6],
+ [ 5, -1, -3, -3]], dtype=int32)
+ """
util.check_arraylike("diff", a)
arr = asarray(a)
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
@@ -1337,16 +1394,58 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
return arr
-_EDIFF1D_DOC = """\
-Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
-issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
-loses precision.
-"""
-@util.implements(np.ediff1d, lax_description=_EDIFF1D_DOC)
@jit
def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
to_begin: ArrayLike | None = None) -> Array:
+ """Compute the differences of the elements of the flattened array.
+
+ JAX implementation of :func:`numpy.ediff1d`.
+
+ Args:
+ ary: input array or scalar.
+ to_end: scalar or array, optional, default=None. Specifies the numbers to
+ append to the resulting array.
+ to_begin: scalar or array, optional, default=None. Specifies the numbers to
+ prepend to the resulting array.
+
+ Returns:
+ An array containing the differences between the elements of the input array.
+
+ Note:
+ Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will
+ not issue an error if casting ``to_end`` or ``to_begin`` to the type of
+ ``ary`` loses precision.
+
+ See also:
+ - :func:`jax.numpy.diff`: Computes the n-th order difference between elements
+ of the array along a given axis.
+ - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of
+ the array along a given axis.
+ - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array.
+
+ Examples:
+ >>> a = jnp.array([2, 3, 5, 9, 1, 4])
+ >>> jnp.ediff1d(a)
+ Array([ 1, 2, 4, -8, 3], dtype=int32)
+ >>> jnp.ediff1d(a, to_begin=-10)
+ Array([-10, 1, 2, 4, -8, 3], dtype=int32)
+ >>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
+ Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32)
+ >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
+ Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32)
+
+ For array with ``ndim > 1``, the differences are computed after flattening
+ the input array.
+
+ >>> a1 = jnp.array([[2, -1, 4, 7],
+ ... [3, 5, -6, 9]])
+ >>> jnp.ediff1d(a1)
+ Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)
+ >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
+ >>> jnp.ediff1d(a2)
+ Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)
+ """
util.check_arraylike("ediff1d", ary)
arr = ravel(ary)
result = lax.sub(arr[1:], arr[:-1])
@@ -1774,9 +1873,40 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
for s, i in safe_zip(shape, out_indices))
-@util.implements(np.resize)
+
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
+ """Return a new array with specified shape.
+
+ JAX implementation of :func:`numpy.resize`.
+
+ Args:
+ a: input array or scalar.
+ new_shape: int or tuple of ints. Specifies the shape of the resized array.
+
+ Returns:
+ A resized array with the specified shape. The elements of ``a`` are repeated
+ in the resized array if the resized array is larger than the original array.
+
+ See also:
+ - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array.
+ - :func:`jax.numpy.repeat`: Constructs an array from repeated elements.
+
+ Examples:
+ >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
+ >>> jnp.resize(x, (3, 3))
+ Array([[1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]], dtype=int32)
+ >>> jnp.resize(x, (3, 4))
+ Array([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9, 1, 2, 3]], dtype=int32)
+ >>> jnp.resize(4, (3, 2))
+ Array([[4, 4],
+ [4, 4],
+ [4, 4]], dtype=int32, weak_type=True)
+ """
util.check_arraylike("resize", a)
new_shape = _ensure_index_tuple(new_shape)
@@ -4174,10 +4304,44 @@ def atleast_1d(x: ArrayLike, /) -> Array:
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
-@util.implements(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
- # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
+ """Convert inputs to arrays with at least 1 dimension.
+
+ JAX implementation of :func:`numpy.atleast_1d`.
+
+ Args:
+ zero or more arraylike arguments.
+
+ Returns:
+ an array or list of arrays corresponding to the input values. Arrays
+ of shape ``()`` are converted to shape ``(1,)``, and arrays with other
+ shapes are returned unchanged.
+
+ See also:
+ - :func:`jax.numpy.asarray`
+ - :func:`jax.numpy.atleast_2d`
+ - :func:`jax.numpy.atleast_3d`
+
+ Examples:
+ Scalar arguments are converted to 1D, length-1 arrays:
+
+ >>> x = jnp.float32(1.0)
+ >>> jnp.atleast_1d(x)
+ Array([1.], dtype=float32)
+
+ Higher dimensional inputs are returned unchanged:
+
+ >>> y = jnp.arange(4)
+ >>> jnp.atleast_1d(y)
+ Array([0, 1, 2, 3], dtype=int32)
+
+ Multiple arguments can be passed to the function at once, in which
+ case a list of results is returned:
+
+ >>> jnp.atleast_1d(x, y)
+ [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
+ """
util.check_arraylike("atleast_1d", *arys, emit_warning=True)
if len(arys) == 1:
return array(arys[0], copy=False, ndmin=1)
@@ -4194,9 +4358,52 @@ def atleast_2d(x: ArrayLike, /) -> Array:
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
-@util.implements(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
+ """Convert inputs to arrays with at least 2 dimensions.
+
+ JAX implementation of :func:`numpy.atleast_2d`.
+
+ Args:
+ zero or more arraylike arguments.
+
+ Returns:
+ an array or list of arrays corresponding to the input values. Arrays
+ of shape ``()`` are converted to shape ``(1, 1)``, 1D arrays of shape
+ ``(N,)`` are converted to shape ``(1, N)``, and arrays of all other
+ shapes are returned unchanged.
+
+ See also:
+ - :func:`jax.numpy.asarray`
+ - :func:`jax.numpy.atleast_1d`
+ - :func:`jax.numpy.atleast_3d`
+
+ Examples:
+ Scalar arguments are converted to 2D, size-1 arrays:
+
+ >>> x = jnp.float32(1.0)
+ >>> jnp.atleast_2d(x)
+ Array([[1.]], dtype=float32)
+
+ One-dimensional arguments have a unit dimension prepended to the shape:
+
+ >>> y = jnp.arange(4)
+ >>> jnp.atleast_2d(y)
+ Array([[0, 1, 2, 3]], dtype=int32)
+
+ Higher dimensional inputs are returned unchanged:
+
+ >>> z = jnp.ones((2, 3))
+ >>> jnp.atleast_2d(z)
+ Array([[1., 1., 1.],
+ [1., 1., 1.]], dtype=float32)
+
+ Multiple arguments can be passed to the function at once, in which
+ case a list of results is returned:
+
+ >>> jnp.atleast_2d(x, y)
+ [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
+ """
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("atleast_2d", *arys, emit_warning=True)
if len(arys) == 1:
@@ -4214,9 +4421,58 @@ def atleast_3d(x: ArrayLike, /) -> Array:
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
-@util.implements(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
+ """Convert inputs to arrays with at least 3 dimensions.
+
+ JAX implementation of :func:`numpy.atleast_3d`.
+
+ Args:
+ zero or more arraylike arguments.
+
+ Returns:
+ an array or list of arrays corresponding to the input values. Arrays
+ of shape ``()`` are converted to shape ``(1, 1, 1)``, 1D arrays of
+ shape ``(N,)`` are converted to shape ``(1, N, 1)``, 2D arrays of
+ shape ``(M, N)`` are converted to shape ``(M, N, 1)``, and arrays
+ of all other shapes are returned unchanged.
+
+ See also:
+ - :func:`jax.numpy.asarray`
+ - :func:`jax.numpy.atleast_1d`
+ - :func:`jax.numpy.atleast_2d`
+
+ Examples:
+ Scalar arguments are converted to 3D, size-1 arrays:
+
+ >>> x = jnp.float32(1.0)
+ >>> jnp.atleast_3d(x)
+ Array([[[1.]]], dtype=float32)
+
+ 1D arrays have a unit dimension prepended and appended:
+
+ >>> y = jnp.arange(4)
+ >>> jnp.atleast_3d(y).shape
+ (1, 4, 1)
+
+ 2D arrays have a unit dimension appended:
+
+ >>> z = jnp.ones((2, 3))
+ >>> jnp.atleast_3d(z).shape
+ (2, 3, 1)
+
+ Multiple arguments can be passed to the function at once, in which
+ case a list of results is returned:
+
+ >>> x3, y3 = jnp.atleast_3d(x, y)
+ >>> print(x3)
+ [[[1.]]]
+ >>> print(y3)
+ [[[0]
+ [1]
+ [2]
+ [3]]]
+ """
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("atleast_3d", *arys, emit_warning=True)
if len(arys) == 1:
@@ -4241,14 +4497,6 @@ def _supports_buffer_protocol(obj):
return True
-_ARRAY_DOC = """
-This function will create arrays on JAX's default device. For control of the
-device placement of data, see :func:`jax.device_put`. More information is
-available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at
-https://jax.readthedocs.io/en/latest/faq.html).
-"""
-
-
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
@@ -5064,9 +5312,50 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
# General np.from* style functions mostly delegate to numpy.
-@util.implements(np.frombuffer)
def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
count: int = -1, offset: int = 0) -> Array:
+ r"""Convert a buffer into a 1-D JAX array.
+
+ JAX implementation of :func:`numpy.frombuffer`.
+
+ Args:
+ buffer: an object containing the data. It must be either a bytes object with
+ a length that is an integer multiple of the dtype element size, or
+ it must be an object exporting the `Python buffer interface`_.
+ dtype: optional. Desired data type for the array. Default is ``float64``.
+ This specifies the dtype used to parse the buffer, but note that after parsing,
+ 64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64``
+ flag is set to ``False``.
+ count: optional integer specifying the number of items to read from the buffer.
+ If -1 (default), all items from the buffer are read.
+ offset: optional integer specifying the number of bytes to skip at the beginning
+ of the buffer. Default is 0.
+
+ Returns:
+ A 1-D JAX array representing the interpreted data from the buffer.
+
+ See also:
+ - :func:`jax.numpy.fromstring`: convert a string of text into a 1-D JAX array.
+
+ Examples:
+ Using a bytes buffer:
+
+ >>> buf = b"\x00\x01\x02\x03\x04"
+ >>> jnp.frombuffer(buf, dtype=jnp.uint8)
+ Array([0, 1, 2, 3, 4], dtype=uint8)
+ >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
+ Array([1, 2, 3, 4], dtype=uint8)
+
+ Constructing a JAX array via the Python buffer interface, using Python's
+ built-in :mod:`array` module.
+
+ >>> from array import array
+ >>> pybuffer = array('i', [0, 1, 2, 3, 4])
+ >>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
+ Array([0, 1, 2, 3, 4], dtype=int32)
+
+ .. _Python buffer interface: https://docs.python.org/3/c-api/buffer.html
+ """
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
@@ -5175,8 +5464,31 @@ def fromfunction(function: Callable[..., Array], shape: Any,
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
-@util.implements(np.fromstring)
def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array:
+ """Convert a string of text into a 1-D JAX array.
+
+ JAX implementation of :func:`numpy.fromstring`.
+
+ Args:
+ string: input string containing the data.
+ dtype: optional. Desired data type for the array. Default is ``float``.
+ count: optional integer specifying the number of items to read from the string.
+ If -1 (default), all items are read.
+ sep: the string used to separate values in the input string.
+
+ Returns:
+ A 1-D JAX array containing the parsed data from the input string.
+
+ See also:
+ - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
+ that implements the buffer interface.
+
+ Examples:
+ >>> jnp.fromstring("1 2 3", dtype=int, sep=" ")
+ Array([1, 2, 3], dtype=int32)
+ >>> jnp.fromstring("0.1, 0.2, 0.3", dtype=float, count=2, sep=",")
+ Array([0.1, 0.2], dtype=float32)
+ """
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
@@ -6706,7 +7018,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array:
return res
-def trim_zeros(filt, trim='fb'):
+def trim_zeros(filt: ArrayLike, trim: str = 'fb') -> Array:
"""Trim leading and/or trailing zeros of the input array.
JAX implementation of :func:`numpy.trim_zeros`.
@@ -6728,14 +7040,26 @@ def trim_zeros(filt, trim='fb'):
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
"""
- filt = core.concrete_or_error(asarray, filt,
- "Error arose in the `filt` argument of trim_zeros()")
- nz = (filt == 0)
+ # Non-array inputs are deprecated 2024-09-11
+ util.check_arraylike("trim_zeros", filt, emit_warning=True)
+ core.concrete_or_error(None, filt,
+ "Error arose in the `filt` argument of trim_zeros()")
+ filt_arr = jax.numpy.asarray(filt)
+ del filt
+ if filt_arr.ndim != 1:
+ # Added on 2024-09-11
+ if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"):
+ raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.")
+ warnings.warn(
+ "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated; such "
+ "arrays are currently accepted, but in the future this will raise an error.",
+ DeprecationWarning, stacklevel=2)
+ nz = (filt_arr == 0)
if reductions.all(nz):
- return empty(0, _dtype(filt))
- start = argmin(nz) if 'f' in trim.lower() else 0
- end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
- return filt[start:len(filt) - end]
+ return empty(0, filt_arr.dtype)
+ start: Array | int = argmin(nz) if 'f' in trim.lower() else 0
+ end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0
+ return filt_arr[start:len(filt_arr) - end]
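An illustrative check of the new input-validation path, assuming the deprecation has not yet been accelerated to an error:

    import warnings
    import jax.numpy as jnp

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      jnp.trim_zeros(jnp.zeros((2, 2)))   # ndim != 1 -> DeprecationWarning for now
    assert any(issubclass(x.category, DeprecationWarning) for x in w)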
def trim_zeros_tol(filt, tol, trim='fb'):
@@ -7085,20 +7409,17 @@ def dot(a: ArrayLike, b: ArrayLike, *,
batch_dims = ((), ())
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
- # TODO(jakevdp): lower this case to dot_general as well?
- # Currently, doing so causes issues in remat tests due to #16805
- if preferred_element_type is not None:
- a = a.astype(preferred_element_type)
- b = b.astype(preferred_element_type)
- result = lax.mul(a, b)
+ contract_dims: tuple[tuple[int, ...], tuple[int, ...]] = ((), ())
else:
if b_ndim == 1:
contract_dims = ((a_ndim - 1,), (0,))
else:
contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
- result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
- precision=precision, preferred_element_type=preferred_element_type)
- return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
+ result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),
+ precision=precision,
+ preferred_element_type=preferred_element_type)
+ return lax_internal._convert_element_type(result, preferred_element_type,
+ output_weak_type)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
@@ -8119,11 +8440,51 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array:
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)
-@util.implements(np.vander)
@partial(jit, static_argnames=('N', 'increasing'))
def vander(
x: ArrayLike, N: int | None = None, increasing: bool = False
) -> Array:
+ """Generate a Vandermonde matrix.
+
+ JAX implementation of :func:`numpy.vander`.
+
+ Args:
+ x: input array. Must have ``x.ndim == 1``.
+ N: int, optional, default=None. Specifies the number of columns of the
+ output matrix. If not specified, ``N = len(x)``.
+ increasing: bool, optional, default=False. Specifies the order of the powers
+ of the columns. If ``True``, the powers increase from left to right,
+ :math:`[x^0, x^1, ..., x^{(N-1)}]`. By default, the powers decrease from left to
+ right :math:`[x^{(N-1)}, ..., x^1, x^0]`.
+
+ Returns:
+ An array of shape ``[len(x), N]`` containing the generated Vandermonde matrix.
+
+ Examples:
+ >>> x = jnp.array([1, 2, 3, 4])
+ >>> jnp.vander(x)
+ Array([[ 1, 1, 1, 1],
+ [ 8, 4, 2, 1],
+ [27, 9, 3, 1],
+ [64, 16, 4, 1]], dtype=int32)
+
+ If ``N = 2``, generates a Vandermonde matrix with ``2`` columns.
+
+ >>> jnp.vander(x, N=2)
+ Array([[1, 1],
+ [2, 1],
+ [3, 1],
+ [4, 1]], dtype=int32)
+
+ Generates the Vandermonde matrix in increasing order of powers when
+ ``increasing=True``.
+
+ >>> jnp.vander(x, increasing=True)
+ Array([[ 1, 1, 1, 1],
+ [ 1, 2, 4, 8],
+ [ 1, 3, 9, 27],
+ [ 1, 4, 16, 64]], dtype=int32)
+ """
util.check_arraylike("vander", x)
x = asarray(x)
if x.ndim != 1:
@@ -8207,9 +8568,41 @@ def argwhere(
return result.reshape(result.shape[0], ndim(a))
-@util.implements(np.argmax, skip_params=['out'])
def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array:
+ """Return the index of the maximum value of an array.
+
+ JAX implementation of :func:`numpy.argmax`.
+
+ Args:
+ a: input array
+ axis: optional integer specifying the axis along which to find the maximum
+ value. If ``axis`` is not specified, ``a`` will be flattened.
+ out: unused by JAX
+ keepdims: if True, then return an array with the same number of dimensions
+ as ``a``.
+
+ Returns:
+ an array containing the index of the maximum value along the specified axis.
+
+ See also:
+ - :func:`jax.numpy.argmin`: return the index of the minimum value.
+ - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values.
+
+ Examples:
+ >>> x = jnp.array([1, 3, 5, 4, 2])
+ >>> jnp.argmax(x)
+ Array(2, dtype=int32)
+
+ >>> x = jnp.array([[1, 3, 2],
+ ... [5, 4, 1]])
+ >>> jnp.argmax(x, axis=1)
+ Array([1, 0], dtype=int32)
+
+ >>> jnp.argmax(x, axis=1, keepdims=True)
+ Array([[1],
+ [0]], dtype=int32)
+ """
util.check_arraylike("argmax", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
@@ -8229,9 +8622,42 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_))
return expand_dims(result, dims) if keepdims else result
-@util.implements(np.argmin, skip_params=['out'])
+
def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array:
+ """Return the index of the minimum value of an array.
+
+ JAX implementation of :func:`numpy.argmin`.
+
+ Args:
+ a: input array
+ axis: optional integer specifying the axis along which to find the minimum
+ value. If ``axis`` is not specified, ``a`` will be flattened.
+ out: unused by JAX
+ keepdims: if True, then return an array with the same number of dimensions
+ as ``a``.
+
+ Returns:
+ an array containing the index of the minimum value along the specified axis.
+
+ See also:
+ - :func:`jax.numpy.argmax`: return the index of the maximum value.
+ - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values.
+
+ Examples:
+ >>> x = jnp.array([1, 3, 5, 4, 2])
+ >>> jnp.argmin(x)
+ Array(0, dtype=int32)
+
+ >>> x = jnp.array([[1, 3, 2],
+ ... [5, 4, 1]])
+ >>> jnp.argmin(x, axis=1)
+ Array([0, 2], dtype=int32)
+
+ >>> jnp.argmin(x, axis=1, keepdims=True)
+ Array([[0],
+ [2]], dtype=int32)
+ """
util.check_arraylike("argmin", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
@@ -8252,19 +8678,57 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
return expand_dims(result, dims) if keepdims else result
-_NANARG_DOC = """\
-Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise
-an error.
-"""
-
-
-@util.implements(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out'])
def nanargmax(
a: ArrayLike,
axis: int | None = None,
out: None = None,
keepdims: bool | None = None,
) -> Array:
+ """Return the index of the maximum value of an array, ignoring NaNs.
+
+ JAX implementation of :func:`numpy.nanargmax`.
+
+ Args:
+ a: input array
+ axis: optional integer specifying the axis along which to find the maximum
+ value. If ``axis`` is not specified, ``a`` will be flattened.
+ out: unused by JAX
+ keepdims: if True, then return an array with the same number of dimensions
+ as ``a``.
+
+ Returns:
+ an array containing the index of the maximum value along the specified axis.
+
+ Note:
+ In the case of an axis with all-NaN values, the returned index will be -1.
+ This differs from the behavior of :func:`numpy.nanargmax`, which raises an error.
+
+ See also:
+ - :func:`jax.numpy.argmax`: return the index of the maximum value.
+ - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values.
+
+ Examples:
+ >>> x = jnp.array([1, 3, 5, 4, jnp.nan])
+
+ Using a standard :func:`~jax.numpy.argmax` leads to potentially unexpected results:
+
+ >>> jnp.argmax(x)
+ Array(4, dtype=int32)
+
+ Using ``nanargmax`` returns the index of the maximum non-NaN value.
+
+ >>> jnp.nanargmax(x)
+ Array(2, dtype=int32)
+
+ >>> x = jnp.array([[1, 3, jnp.nan],
+ ... [5, 4, jnp.nan]])
+ >>> jnp.nanargmax(x, axis=1)
+ Array([1, 0], dtype=int32)
+
+ >>> jnp.nanargmax(x, axis=1, keepdims=True)
+ Array([[1],
+ [0]], dtype=int32)
+ """
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.")
return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims))
@@ -8281,13 +8745,50 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
-@util.implements(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out'])
def nanargmin(
a: ArrayLike,
axis: int | None = None,
out: None = None,
keepdims: bool | None = None,
) -> Array:
+ """Return the index of the minimum value of an array, ignoring NaNs.
+
+ JAX implementation of :func:`numpy.nanargmin`.
+
+ Args:
+ a: input array
+ axis: optional integer specifying the axis along which to find the minimum
+ value. If ``axis`` is not specified, ``a`` will be flattened.
+ out: unused by JAX
+ keepdims: if True, then return an array with the same number of dimensions
+ as ``a``.
+
+ Returns:
+ an array containing the index of the minimum value along the specified axis.
+
+ Note:
+ In the case of an axis with all-NaN values, the returned index will be -1.
+ This differs from the behavior of :func:`numpy.nanargmin`, which raises an error.
+
+ See also:
+ - :func:`jax.numpy.argmin`: return the index of the minimum value.
+ - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values.
+
+ Examples:
+ >>> x = jnp.array([jnp.nan, 3, 5, 4, 2])
+ >>> jnp.nanargmin(x)
+ Array(4, dtype=int32)
+
+ >>> x = jnp.array([[1, 3, jnp.nan],
+ ... [5, 4, jnp.nan]])
+ >>> jnp.nanargmin(x, axis=1)
+ Array([0, 1], dtype=int32)
+
+ >>> jnp.nanargmin(x, axis=1, keepdims=True)
+ Array([[0],
+ [1]], dtype=int32)
+ """
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.")
return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims))
@@ -8367,11 +8868,40 @@ def sort(
return lax.rev(result, dimensions=[dimension]) if descending else result
-@util.implements(np.sort_complex)
@jit
def sort_complex(a: ArrayLike) -> Array:
+ """Return a sorted copy of complex array.
+
+ JAX implementation of :func:`numpy.sort_complex`.
+
+ Complex numbers are sorted lexicographically, meaning by their real part
+ first, and then by their imaginary part if real parts are equal.
+
+ Args:
+ a: input array. If dtype is not complex, the array will be upcast to complex.
+
+ Returns:
+ A sorted array of the same shape and complex dtype as the input. If ``a``
+ is multi-dimensional, it is sorted along the last axis.
+
+ See also:
+ - :func:`jax.numpy.sort`: Return a sorted copy of an array.
+
+ Examples:
+ >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
+ >>> jnp.sort_complex(a)
+ Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
+
+ Multi-dimensional arrays are sorted along the last axis:
+
+ >>> a = jnp.array([[5, 3, 4],
+ ... [6, 9, 2]])
+ >>> jnp.sort_complex(a)
+ Array([[3.+0.j, 4.+0.j, 5.+0.j],
+ [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
+ """
util.check_arraylike("sort_complex", a)
- a = lax.sort(asarray(a), dimension=0)
+ a = lax.sort(asarray(a))
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
@util.implements(np.lexsort)
@@ -10214,7 +10744,7 @@ def body_fun(state, _):
def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64
def _rank(x):
- idx = lax.iota(working_dtype, len(x))
+ idx = lax.iota(working_dtype, x.shape[0])
return zeros_like(idx).at[argsort(x)].set(idx)
query_flat = query.ravel()
if side == 'left':
@@ -10307,8 +10837,8 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
a, v = util.promote_dtypes(a, v)
if sorter is not None:
a = a[sorter]
- dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
- if len(a) == 0:
+ dtype = int32 if a.shape[0] <= np.iinfo(np.int32).max else int64
+ if a.shape[0] == 0:
return zeros_like(v, dtype=dtype)
impl = {
'scan': partial(_searchsorted_via_scan, False),
@@ -10318,9 +10848,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
}[method]
return impl(asarray(a), asarray(v), side, dtype) # type: ignore
-@util.implements(np.digitize)
-@partial(jit, static_argnames=('right',))
-def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
+
+@partial(jit, static_argnames=('right', 'method'))
+def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
+ *, method: str | None = None) -> Array:
+ """Convert an array to bin indices.
+
+ JAX implementation of :func:`numpy.digitize`.
+
+ Args:
+ x: array of values to digitize.
+ bins: 1D array of bin edges. Must be monotonically increasing or decreasing.
+ right: if True, the intervals include the right bin edges. If False (default),
+ the intervals include the left bin edges.
+ method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`.
+ See that function for available options.
+
+ Returns:
+ An integer array of the same shape as ``x`` indicating the bin number that
+ the values are in.
+
+ See also:
+ - :func:`jax.numpy.searchsorted`: find insertion indices for values in a
+ sorted array.
+ - :func:`jax.numpy.histogram`: compute frequency of array values within
+ specified bins.
+
+ Examples:
+ >>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
+ >>> bins = jnp.array([1, 2, 3])
+ >>> jnp.digitize(x, bins)
+ Array([1, 2, 2, 1, 3, 3], dtype=int32)
+ >>> jnp.digitize(x, bins, right=True)
+ Array([0, 1, 2, 1, 2, 3], dtype=int32)
+
+ ``digitize`` supports reverse-ordered bins as well:
+
+ >>> bins = jnp.array([3, 2, 1])
+ >>> jnp.digitize(x, bins)
+ Array([2, 1, 1, 2, 0, 0], dtype=int32)
+ """
util.check_arraylike("digitize", x, bins)
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
bins_arr = asarray(bins)
@@ -10329,10 +10896,11 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
if bins_arr.shape[0] == 0:
return zeros_like(x, dtype=int32)
side = 'right' if not right else 'left'
+ kwds: dict[str, str] = {} if method is None else {'method': method}
return where(
bins_arr[-1] >= bins_arr[0],
- searchsorted(bins_arr, x, side=side),
- len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side)
+ searchsorted(bins_arr, x, side=side, **kwds),
+ bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
)
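A short usage sketch of the new `method` keyword, which is simply forwarded to `searchsorted` and does not change the result:

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
    bins = jnp.array([1, 2, 3])
    print(jnp.digitize(x, bins, method='sort'))  # Array([1, 2, 2, 1, 3, 3], dtype=int32)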
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index c3b38e57ab76..b45b3370fe53 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -1005,24 +1005,186 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array],
lax_op(x.real, y.real))
return lax_op(x, y)
-@implements(np.greater_equal, module='numpy')
@partial(jit, inline=True)
def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
+ """Return element-wise truth value of ``x >= y``.
+
+ JAX implementation of :obj:`numpy.greater_equal`.
+
+ Args:
+ x: input array or scalar.
+ y: input array or scalar. ``x`` and ``y`` must either have same shape or be
+ broadcast compatible.
+
+ Returns:
+ An array containing boolean values. ``True`` if the elements of ``x >= y``,
+ and ``False`` otherwise.
+
+ See also:
+ - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``.
+ - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``.
+ - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``.
+
+ Examples:
+ Scalar inputs:
+
+ >>> jnp.greater_equal(4, 7)
+ Array(False, dtype=bool, weak_type=True)
+
+ Inputs with same shape:
+
+ >>> x = jnp.array([2, 5, -1])
+ >>> y = jnp.array([-6, 4, 3])
+ >>> jnp.greater_equal(x, y)
+ Array([ True, True, False], dtype=bool)
+
+ Inputs with broadcast compatibility:
+
+ >>> x1 = jnp.array([[3, -1, 4],
+ ... [5, 9, -6]])
+ >>> y1 = jnp.array([-1, 4, 2])
+ >>> jnp.greater_equal(x1, y1)
+ Array([[ True, False, True],
+ [ True, True, False]], dtype=bool)
+ """
return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y))
-@implements(np.greater, module='numpy')
+
@partial(jit, inline=True)
def greater(x: ArrayLike, y: ArrayLike, /) -> Array:
+ """Return element-wise truth value of ``x > y``.
+
+ JAX implementation of :obj:`numpy.greater`.
+
+ Args:
+ x: input array or scalar.
+    y: input array or scalar. ``x`` and ``y`` must either have the same shape
+      or be broadcast compatible.
+
+  Returns:
+    An array containing boolean values: ``True`` where ``x > y``, and
+    ``False`` otherwise.
+
+ See also:
+ - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``.
+ - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of
+ ``x >= y``.
+ - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``.
+
+ Examples:
+ Scalar inputs:
+
+ >>> jnp.greater(5, 2)
+ Array(True, dtype=bool, weak_type=True)
+
+ Inputs with same shape:
+
+ >>> x = jnp.array([5, 9, -2])
+ >>> y = jnp.array([4, -1, 6])
+ >>> jnp.greater(x, y)
+ Array([ True, True, False], dtype=bool)
+
+ Inputs with broadcast compatibility:
+
+ >>> x1 = jnp.array([[5, -6, 7],
+ ... [-2, 5, 9]])
+ >>> y1 = jnp.array([-4, 3, 10])
+ >>> jnp.greater(x1, y1)
+ Array([[ True, False, False],
+ [ True, True, False]], dtype=bool)
+ """
return _complex_comparison(lax.gt, *promote_args("greater", x, y))
-@implements(np.less_equal, module='numpy')
+
@partial(jit, inline=True)
def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
+ """Return element-wise truth value of ``x <= y``.
+
+ JAX implementation of :obj:`numpy.less_equal`.
+
+ Args:
+ x: input array or scalar.
+    y: input array or scalar. ``x`` and ``y`` must either have the same shape
+      or be broadcast compatible.
+
+  Returns:
+    An array containing boolean values: ``True`` where ``x <= y``, and
+    ``False`` otherwise.
+
+ See also:
+ - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of
+ ``x >= y``.
+ - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``.
+ - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``.
+
+ Examples:
+ Scalar inputs:
+
+ >>> jnp.less_equal(6, -2)
+ Array(False, dtype=bool, weak_type=True)
+
+ Inputs with same shape:
+
+ >>> x = jnp.array([-4, 1, 7])
+ >>> y = jnp.array([2, -3, 8])
+ >>> jnp.less_equal(x, y)
+ Array([ True, False, True], dtype=bool)
+
+ Inputs with broadcast compatibility:
+
+ >>> x1 = jnp.array([2, -5, 9])
+ >>> y1 = jnp.array([[1, -6, 5],
+ ... [-2, 4, -6]])
+ >>> jnp.less_equal(x1, y1)
+ Array([[False, False, False],
+ [False, True, False]], dtype=bool)
+ """
return _complex_comparison(lax.le, *promote_args("less_equal", x, y))
-@implements(np.less, module='numpy')
+
@partial(jit, inline=True)
def less(x: ArrayLike, y: ArrayLike, /) -> Array:
+ """Return element-wise truth value of ``x < y``.
+
+ JAX implementation of :obj:`numpy.less`.
+
+ Args:
+ x: input array or scalar.
+    y: input array or scalar. ``x`` and ``y`` must either have the same shape
+      or be broadcast compatible.
+
+  Returns:
+    An array containing boolean values: ``True`` where ``x < y``, and
+    ``False`` otherwise.
+
+ See also:
+ - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``.
+ - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of
+ ``x >= y``.
+ - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``.
+
+ Examples:
+ Scalar inputs:
+
+ >>> jnp.less(3, 7)
+ Array(True, dtype=bool, weak_type=True)
+
+ Inputs with same shape:
+
+ >>> x = jnp.array([5, 9, -3])
+ >>> y = jnp.array([1, 6, 4])
+ >>> jnp.less(x, y)
+ Array([False, False, True], dtype=bool)
+
+ Inputs with broadcast compatibility:
+
+ >>> x1 = jnp.array([[2, -4, 6, -8],
+ ... [-1, 5, -3, 7]])
+ >>> y1 = jnp.array([0, 3, -5, 9])
+ >>> jnp.less(x1, y1)
+ Array([[False, True, False, True],
+ [ True, False, False, True]], dtype=bool)
+ """
return _complex_comparison(lax.lt, *promote_args("less", x, y))
# Array API aliases
@@ -1423,8 +1585,61 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod
-@implements(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Calculate element-wise ``x1`` raised to the power ``x2``.
+
+ JAX implementation of :obj:`numpy.power`.
+
+ Args:
+ x1: scalar or array. Specifies the bases.
+    x2: scalar or array. Specifies the exponents. ``x1`` and ``x2`` should
+      either have the same shape or be broadcast compatible.
+
+  Returns:
+    An array containing ``x1`` raised to the power ``x2``, element-wise, with
+    the same dtype as the inputs.
+
+ Note:
+ - When ``x2`` is a concrete integer scalar, ``jnp.power`` lowers to
+ :func:`jax.lax.integer_pow`.
+ - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to
+ :func:`jax.lax.pow`.
+    - ``jnp.power`` raises a ``TypeError`` when an integer base is raised to a
+      negative integer power.
+    - ``jnp.power`` returns ``nan`` when a negative base is raised to a
+      non-integer power.
+
+ See also:
+ - :func:`jax.lax.pow`: Computes element-wise power, :math:`x^y`.
+ - :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where
+ :math:`y` is a fixed integer.
+ - :func:`jax.numpy.float_power`: Computes the first array raised to the power
+ of second array, element-wise, by promoting to the inexact dtype.
+ - :func:`jax.numpy.pow`: Computes the first array raised to the power of second
+ array, element-wise.
+
+ Examples:
+ Inputs with scalar integers:
+
+ >>> jnp.power(4, 3)
+ Array(64, dtype=int32, weak_type=True)
+
+ Inputs with same shape:
+
+ >>> x1 = jnp.array([2, 4, 5])
+ >>> x2 = jnp.array([3, 0.5, 2])
+ >>> jnp.power(x1, x2)
+ Array([ 8., 2., 25.], dtype=float32)
+
+ Inputs with broadcast compatibility:
+
+ >>> x3 = jnp.array([-2, 3, 1])
+ >>> x4 = jnp.array([[4, 1, 6],
+ ... [1.3, 3, 5]])
+ >>> jnp.power(x3, x4)
+ Array([[16., 3., 1.],
+ [nan, 27., 1.]], dtype=float32)
+ """
check_arraylike("power", x1, x2)
check_no_float0s("power", x1, x2)
@@ -1454,8 +1669,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# Handle cases #2 and #3 under a jit:
return _power(x1, x2)
-# Array API alias
-pow = power
+def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Alias of :func:`jax.numpy.power`."""
+ return power(x1, x2)
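The Note in the new `power` docstring distinguishes two lowering paths; a quick way to observe them (illustrative only, the exact jaxpr text varies across JAX versions):

```python
import jax
import jax.numpy as jnp

# Concrete integer exponent: lowers through lax.integer_pow.
print(jax.make_jaxpr(lambda x: jnp.power(x, 3))(2.0))
# { lambda ; a:f32[]. let b:f32[] = integer_pow[y=3] a in (b,) }

# Traced exponent: lowers through lax.pow.
print(jax.make_jaxpr(lambda x, y: jnp.power(x, y))(2.0, 3.0))
# { lambda ; a:f32[] b:f32[]. let c:f32[] = pow a b in (c,) }
```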
@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
@@ -1528,11 +1744,39 @@ def _wrap_between(x, _a):
return lax.sub(rem, a)
-@custom_jvp
-@implements(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Compute ``log2(2**x1 + 2**x2)`` while avoiding overflow.
+
+ JAX implementation of :obj:`numpy.logaddexp2`.
+
+ Args:
+ x1: input array or scalar.
+    x2: input array or scalar. ``x1`` and ``x2`` should either have the same
+      shape or be broadcast compatible.
+
+ Returns:
+    An array containing the result, :math:`\log_2(2^{x_1} + 2^{x_2})`, element-wise.
+
+ See also:
+ - :func:`jax.numpy.logaddexp`: Computes ``log(exp(x1) + exp(x2))``, element-wise.
+ - :func:`jax.numpy.log2`: Calculates the base-2 logarithm of ``x`` element-wise.
+
+ Examples:
+ >>> x1 = jnp.array([[3, -1, 4],
+ ... [8, 5, -2]])
+ >>> x2 = jnp.array([2, 3, -5])
+ >>> result1 = jnp.logaddexp2(x1, x2)
+ >>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
+ >>> jnp.allclose(result1, result2)
+ Array(True, dtype=bool)
+ """
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
+ return _logaddexp2(x1, x2)
+
+
+@custom_jvp
+def _logaddexp2(x1, x2):
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
@@ -1546,7 +1790,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
-@logaddexp2.defjvp
+@_logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
@@ -1559,7 +1803,7 @@ def _logaddexp2_jvp(primals, tangents):
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
- """Calculates the base-2 logarithm of x element-wise
+ """Calculates the base-2 logarithm of ``x`` element-wise.
JAX implementation of :obj:`numpy.log2`.
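Moving the `@custom_jvp` onto the private `_logaddexp2` lets the public wrapper carry the docstring while keeping the same numerics and gradient. A small check (illustrative; values rounded):

```python
import jax
import jax.numpy as jnp

x1, x2 = 8.0, 5.0
print(jnp.logaddexp2(x1, x2))                 # ~8.169925, computed without overflow
print(jnp.log2(jnp.exp2(x1) + jnp.exp2(x2)))  # same value via the naive formula
print(jax.grad(jnp.logaddexp2)(x1, x2))       # 2**x1 / (2**x1 + 2**x2) ~= 0.888889
```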
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 8dbf37587f8f..1a956de1f7a9 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -31,6 +31,7 @@
from jax._src import config
from jax._src import core as jax_core
from jax._src import deprecations
+from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import state
@@ -114,24 +115,115 @@ def from_pallas_call(pallas_call_name: str | None,
" ".join(src_info_parts[1:]))
-# Pytrees of jax.ShapeDtypeStruct
-ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
-
split_list = util.split_list
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
+class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
+ __slots__ = ["memory_space"]
+
+ def __init__(self, shape, dtype, weak_type=False, sharding=None,
+ memory_space=None):
+ super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding)
+ self.memory_space = memory_space
+
+ def __eq__(self, other):
+ return super().__eq__(other) and self.memory_space == other.memory_space
+
+ def __hash__(self):
+ return hash((
+ self.shape,
+ self.dtype,
+ self.weak_type,
+ getattr(self, "sharding", None),
+ self.memory_space,
+ ))
+
+ def at_least_vspace(self):
+ """Vector space method needed for AD."""
+ raise NotImplementedError
+
+ def join(self, other):
+ raise NotImplementedError
+
+ def str_short(self, short_dtypes=False):
+ dt_str = (
+ jax_core._short_dtype_name(self.dtype)
+ if short_dtypes
+ else self.dtype.name
+ )
+ dt_str = dt_str.replace("void", "float0")
+ shapestr = ",".join(map(str, self.shape))
+ if hasattr(self, "sharding"):
+ sharding_str = f"{dt_str}[{shapestr}]({self.sharding})"
+ else:
+ sharding_str = ""
+ memoryspace_str = (
+ "" if self.memory_space is None else f"<{self.memory_space}>"
+ )
+ return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}"
+
+ def update(
+ self,
+ shape=None,
+ dtype=None,
+ weak_type=None,
+ sharding=None,
+ memory_space=None,
+ ):
+ if shape is None:
+ shape = self.shape
+ if dtype is None:
+ dtype = self.dtype
+ if weak_type is None:
+ weak_type = self.weak_type
+ if sharding is None:
+ sharding = getattr(self, "sharding", None)
+ if memory_space is None:
+ memory_space = self.memory_space
+ return ShapedArrayWithMemorySpace(
+ shape, dtype, weak_type, sharding=sharding, memory_space=memory_space
+ )
+mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types
+
+
+@dataclasses.dataclass(frozen=True)
+class MemoryRef:
+ """Like jax.ShapeDtypeStruct but with memory spaces."""
+ shape: tuple[int, ...]
+ dtype: jnp.dtype
+ # TODO(b/368122763): Unify memory space types across backends
+ memory_space: Any
+
+ def get_array_aval(self) -> jax_core.ShapedArray:
+ dtype = self.dtype
+ if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)):
+ dtype = jnp.dtype(dtype)
+ return ShapedArrayWithMemorySpace(
+ self.shape, dtype, memory_space=self.memory_space
+ )
+
+ def get_ref_aval(self) -> AbstractMemoryRef:
+ # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
+ # try to apply JAX ops to it.
+ return AbstractMemoryRef(
+ jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
+
+
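With `MemoryRef` hoisted into the generic Pallas core, a backend memory-space object such as `pltpu.VMEM` now produces one of these records, and the two helpers return the array-level and ref-level avals respectively. A hedged sketch (these are internal helpers; names follow this diff and may change):

```python
import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

vmem_ref = pltpu.VMEM((8, 128), jnp.float32)  # a pallas_core.MemoryRef after this change
print(vmem_ref.get_array_aval())  # ShapedArrayWithMemorySpace tagged with the VMEM space
print(vmem_ref.get_ref_aval())    # AbstractMemoryRef wrapping a plain ShapedArray
```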
class AbstractMemoryRef(state.AbstractRef):
__slots__ = ["inner_aval", "memory_space"]
- def __init__(self, inner_aval: jax_core.AbstractValue,
- memory_space: Any):
+ inner_aval: jax_core.ShapedArray
- assert isinstance(
- inner_aval, jax_core.ShapedArray
- ), f"Illegal ref, got {type(inner_aval)}"
+ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any):
+ if isinstance(inner_aval, ShapedArrayWithMemorySpace):
+ if inner_aval.memory_space is not None:
+ assert inner_aval.memory_space == memory_space, (
+ f"Mismatched memory spaces: {inner_aval.memory_space=},"
+ f" {memory_space=}"
+ )
self.inner_aval = inner_aval
self.memory_space = memory_space
@@ -148,9 +240,9 @@ def update(self, inner_aval=None, memory_space=None):
memory_space = self.memory_space if memory_space is None else memory_space
return AbstractMemoryRef(inner_aval, memory_space)
- def at_least_vspace(self):
+ def to_tangent_aval(self):
return AbstractMemoryRef(
- self.inner_aval.at_least_vspace(), self.memory_space)
+ self.inner_aval.to_tangent_aval(), self.memory_space)
def __eq__(self, other):
return (type(self) is type(other) and self.inner_aval == other.inner_aval
@@ -161,7 +253,7 @@ def __hash__(self):
class MemorySpace(enum.Enum):
- """ Logical, device-agnostic memory spaces.
+ """Logical, device-agnostic memory spaces.
Each memory space will be translated to a device-specific memory
type during lowering.
@@ -430,11 +522,10 @@ def __repr__(self):
BlockSpecTree = Any
-class MemrefTransform(Protocol):
- """Represents a transformation applied to a Memref on load or store."""
+class MemoryRefTransform(Protocol):
+ """Transforms a memory reference on load or store."""
def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
- """Returns the transformed aval given an input aval."""
raise NotImplementedError("Abstract evaluation not implemented.")
@@ -451,9 +542,7 @@ class BlockMapping:
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
- transforms: Sequence[MemrefTransform] = dataclasses.field(
- default_factory=tuple
- )
+ transforms: Sequence[MemoryRefTransform] = ()
def check_invariants(self) -> None:
if not config.enable_checks.value: return
@@ -665,9 +754,10 @@ def slice_scratch_ops(self):
@property
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *inputs."""
- index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
- ia.inner_aval.dtype)
- for ia in self.index_map_avals[len(self.grid):])
+ index_shapes = (
+ jax.ShapeDtypeStruct(ia.shape, ia.dtype)
+ for ia in self.index_map_avals[len(self.grid) :]
+ )
inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[:self.num_inputs])
@@ -734,7 +824,18 @@ def _convert_block_spec_to_block_mapping(
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
-@dataclasses.dataclass(init=False)
+
+class ScratchShape(Protocol):
+ def get_array_aval(self) -> jax_core.AbstractValue:
+ ...
+ def get_ref_aval(self) -> state.AbstractRef:
+ ...
+
+
+ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]]
+
+
+@dataclasses.dataclass(init=False, kw_only=True)
class GridSpec:
"""Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
@@ -747,12 +848,14 @@ class GridSpec:
grid_names: tuple[Hashable, ...] | None
in_specs: BlockSpecTree
out_specs: BlockSpecTree
+ scratch_shapes: ScratchShapeTree = ()
def __init__(
self,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
+ scratch_shapes: ScratchShapeTree = (),
):
# Be more lenient for in/out_specs
if isinstance(in_specs, list):
@@ -764,6 +867,7 @@ def __init__(
self.in_specs = in_specs
self.out_specs = out_specs
+ self.scratch_shapes = tuple(scratch_shapes)
grid_names = None
if isinstance(grid, int):
@@ -779,9 +883,6 @@ def __init__(
self.grid = grid # type: ignore
self.grid_names = grid_names
- def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
- assert False # Not needed in GridSpec
-
def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec
@@ -826,12 +927,10 @@ def get_grid_mapping(
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()
-
- scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
- if scratch_shapes:
+ if grid_spec.scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
- scratch_shapes)
- flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
+ grid_spec.scratch_shapes)
+ flat_scratch_avals = map(lambda s: s.get_ref_aval(), flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
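Because `scratch_shapes` now lives on the base `GridSpec` and scratch avals are obtained through the `ScratchShape` protocol (`get_ref_aval`), any backend object implementing that protocol can be listed directly. A hedged TPU-flavoured sketch (public `pltpu` names assumed as in current Pallas):

```python
import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    grid=(8,),
    scratch_shapes=[
        pltpu.VMEM((128, 128), jnp.float32),  # scratch buffer in VMEM
        pltpu.SemaphoreType.DMA,              # a DMA semaphore
    ],
)
```

Both entries implement `get_ref_aval`, so `get_grid_mapping` can flatten them without the backend-specific `_make_scratch_aval` hook that this change deletes.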
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 61b1dc435e72..76166ae61963 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -19,7 +19,7 @@
import dataclasses
import enum
import functools
-from typing import Any, ClassVar, Hashable, Literal
+from typing import Any, ClassVar, Literal
import jax
from jax._src import core as jax_core
@@ -39,6 +39,7 @@
BlockSpecTree = pallas_core.BlockSpecTree
GridMapping = pallas_core.GridMapping
NoBlockSpec = pallas_core.NoBlockSpec
+ScratchShapeTree = pallas_core.ScratchShapeTree
AbstractMemoryRef = pallas_core.AbstractMemoryRef
no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
@@ -89,7 +90,7 @@ def __str__(self) -> str:
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
# A convenience function for constructing MemoryRef types.
- return MemoryRef(shape, dtype, self)
+ return pallas_core.MemoryRef(shape, dtype, self)
class semaphore_dtype(dtypes.extended): pass
class semaphore(semaphore_dtype): pass
@@ -101,6 +102,10 @@ class AbstractSemaphoreTyRules:
def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE)
+ @staticmethod
+ def physical_element_aval(_) -> jax_core.ShapedArray:
+ return jax_core.ShapedArray((), jnp.int32)
+
class AbstractSemaphoreTy(dtypes.ExtendedDType):
name: str
_rules = AbstractSemaphoreTyRules
@@ -143,10 +148,13 @@ def __call__(self, shape: tuple[int, ...]):
dtype = SemaphoreTy()
if pallas_core.is_interpret_mode():
dtype = pallas_core.SEMAPHORE_INTERPRET_DTYPE
- return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
+ return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
- def get_aval(self) -> AbstractMemoryRef:
- return self(()).get_aval()
+ def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace:
+ return self(()).get_array_aval()
+
+ def get_ref_aval(self) -> AbstractMemoryRef:
+ return self(()).get_ref_aval()
@dataclasses.dataclass(frozen=True)
class AbstractSemaphore(jax_core.AbstractValue):
@@ -162,26 +170,9 @@ def join(self, other):
jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval
-@dataclasses.dataclass(frozen=True)
-class MemoryRef:
- """Like jax.ShapeDtypeStruct but with memory spaces."""
- shape: tuple[int, ...]
- dtype: jnp.dtype
- memory_space: TPUMemorySpace = TPUMemorySpace.ANY
-
- def get_aval(self) -> AbstractMemoryRef:
- return AbstractMemoryRef(
- jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
-
-
-@dataclasses.dataclass(init=False, unsafe_hash=True)
+@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
- grid: TupleGrid
- grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
- in_specs: pallas_core.BlockSpecTree
- out_specs: pallas_core.BlockSpecTree
- scratch_shapes: tuple[Any, ...]
def __init__(
self,
@@ -189,9 +180,9 @@ def __init__(
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
- scratch_shapes: Any | Sequence[Any] = ()
+ scratch_shapes: ScratchShapeTree = ()
):
- super().__init__(grid, in_specs, out_specs)
+ super().__init__(grid, in_specs, out_specs, scratch_shapes)
self.num_scalar_prefetch = num_scalar_prefetch
self.scratch_shapes = tuple(scratch_shapes)
@@ -199,14 +190,6 @@ def _make_scalar_ref_aval(self, aval):
return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
- def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
- if isinstance(obj, MemoryRef):
- return obj.get_aval()
- if isinstance(obj, SemaphoreType):
- return obj.get_aval()
- raise ValueError(f"No registered conversion for {type(obj)}. "
- "Only VMEM and SemaphoreType are supported.")
-
@dataclasses.dataclass(frozen=True)
class TensorCore:
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index bd897deb3d1f..775f0c1f8256 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -40,7 +40,6 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
-from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
@@ -48,8 +47,8 @@
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
-from jax._src.pallas import pallas_call
from jax._src.pallas import core as pallas_core
+from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
@@ -58,6 +57,9 @@
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
+from jax._src.state.types import RefBitcaster
+from jax._src.state.utils import dtype_bitwidth
+from jax._src.typing import DTypeLike
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
@@ -422,24 +424,23 @@ class MeshInfo:
axis_names: list[str]
mesh_strides: tuple[int, ...]
-def lower_jaxpr_to_module(
+
+def _check_block_mappings(
+ block_mappings: tuple[pallas_core.BlockMapping, ...],
lowering_context: mlir.LoweringRuleContext,
- ctx: ir.Context,
- grid_mapping: pallas_core.GridMapping,
- jaxpr: jax_core.Jaxpr,
- *,
- dimension_semantics: tuple[str | None, ...] | None,
name_and_src_info: pallas_core.NameAndSrcInfo,
- mesh: mesh_lib.Mesh | None = None,
- for_verification: bool = False,
-) -> tuple[Module, tuple[Any, ...]]:
- for bm in grid_mapping.block_mappings:
+) -> None:
+ del lowering_context # originally needed for forward compat
+ for bm in block_mappings:
rank = len(bm.block_shape)
# TODO(necula): add tests for SMEM blocks with trivial windowing
# We support scalars too
if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and
bm.has_trivial_window()):
continue
+ if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE:
+ continue
+
def err_details():
return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
"has block shape "
@@ -448,20 +449,10 @@ def err_details():
f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in "
f"memory space {bm.block_aval.memory_space}."
"\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec")
- if lowering_context.is_forward_compat() or jaxlib_version < (0, 4, 32):
- # TODO(b/356116061): Remove the old rank condition
- if rank < 2:
- raise ValueError(
- "The Pallas TPU lowering currently supports only blocks of "
- "rank >= 2 for blocks, except those in the SMEM memory space "
- "having the same block shape as the array shape and a "
- "trivial index_map (returning all 0s). " + err_details())
- else:
- if rank < 1:
- raise ValueError(
- "The Pallas TPU lowering currently supports only blocks of "
- "rank >= 1. " + err_details())
-
+ if rank < 1:
+ raise ValueError(
+ "The Pallas TPU lowering currently supports only blocks of "
+ "rank >= 1. " + err_details())
if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and
not bm.has_trivial_window()):
@@ -476,42 +467,42 @@ def err_details():
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
else:
bs1, as1 = 1, 1
- if lowering_context.is_forward_compat():
- # TODO(b/356116061): Remove the old divisibility condition
- # With shape polymorphism block_shape is static, but the array shape may
- # be symbolic. Write the divisibility comparisons to defer inequality
- # comparisons on dimensions as much as possible.
+
+ if rank >= 2:
evenly_divisible = (
- (bs0 % 128 == 0 or (bs0 == as0 and as0 < 128)) and
- (bs1 % 8 == 0 or (bs1 == as1 and as1 < 8))
+ (bs0 == as0 or bs0 % 128 == 0) and
+ (bs1 == as1 or bs1 % 8 == 0)
)
- if not evenly_divisible:
- raise ValueError(
- "The Pallas TPU lowering currently requires that the last two "
- "dimensions of your block shape are divisible by 8 and 128 "
- "respectively, if the respective dimensions of the overall array "
- "are larger than the respective factors. If array dimensions are "
- "smaller, the block should span the full array dimension. "
- + err_details())
else:
- if rank >= 2:
- evenly_divisible = (
- (bs0 == as0 or bs0 % 128 == 0) and
- (bs1 == as1 or bs1 % 8 == 0)
- )
- else:
- assert rank == 1
- # TODO(necula): test this for bool. What should it do?
- tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype))
- evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
+ assert rank == 1
+ # TODO(necula): test this for bool. What should it do?
+ tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype))
+ evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
if not evenly_divisible:
raise ValueError(
- "The Pallas TPU lowering currently requires that the last two "
- "dimensions of your block shape are divisible by 8 and 128 "
- "respectively, or be equal to the respective dimensions of the "
- "overall array. "
- + err_details())
+ "The Pallas TPU lowering currently requires that the last two "
+ "dimensions of your block shape are divisible by 8 and 128 "
+ "respectively, or be equal to the respective dimensions of the "
+ "overall array. "
+ + err_details()
+ )
+
+
+def lower_jaxpr_to_module(
+ lowering_context: mlir.LoweringRuleContext,
+ ctx: ir.Context,
+ grid_mapping: pallas_core.GridMapping,
+ jaxpr: jax_core.Jaxpr,
+ *,
+ dimension_semantics: tuple[str | None, ...] | None,
+ name_and_src_info: pallas_core.NameAndSrcInfo,
+ mesh: mesh_lib.Mesh | None = None,
+ for_verification: bool = False,
+) -> tuple[Module, tuple[Any, ...]]:
+ # Verify that we have legal block mappings to catch errors early.
+ _check_block_mappings(grid_mapping.block_mappings, lowering_context,
+ name_and_src_info)
mosaic_grid_mapping = MosaicGridMapping(
jaxpr, grid_mapping, dimension_semantics, mesh)
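The block-shape constraints formerly inlined in `lower_jaxpr_to_module` (and gated on forward compatibility) are now consolidated in `_check_block_mappings` as a single rule. Restated in isolation for 2D-or-higher blocks (illustrative only, not the lowering code itself):

```python
# A block passes if each of its last two dims is either a multiple of the
# (8, 128) TPU tile or spans the whole corresponding array dimension.
def block_ok_2d(block_shape, array_shape):
  bs0, as0 = block_shape[-1], array_shape[-1]  # lanes
  bs1, as1 = block_shape[-2], array_shape[-2]  # sublanes
  return (bs0 == as0 or bs0 % 128 == 0) and (bs1 == as1 or bs1 % 8 == 0)

assert block_ok_2d((16, 256), (1024, 1024))   # 16 % 8 == 0 and 256 % 128 == 0
assert block_ok_2d((3, 5), (3, 5))            # block spans the full array
assert not block_ok_2d((4, 100), (1024, 1024))
```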
@@ -985,11 +976,12 @@ def _indexer_to_start_size_stride(
)
-def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
- indexer: NDIndexer,
- ref_block_shape: tuple[int | pallas_core.Mapped, ...]
- ) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...],
- tuple[int | pallas_core.Mapped, ...]]:
+def _slice_memref(
+ ref: ir.Value,
+ indexer: NDIndexer,
+ ref_dtype: DTypeLike,
+ ref_block_shape: tuple[int | pallas_core.Mapped, ...],
+) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]:
assert ref_block_shape is not None
target_shape = indexer.get_indexer_shape()
starts, sizes, strides, squeeze_dims, ref_block_shape = (
@@ -1006,26 +998,79 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
static_sizes = tuple(s if not isinstance(s, ir.Value)
else ir_dynamic_size for s in sizes)
target_ref_ty = ir.MemRefType.get(
- static_sizes, _dtype_to_ir_type(ref_aval.dtype),
- memory_space=ref.type.memory_space)
+ static_sizes,
+ _dtype_to_ir_type(ref_dtype),
+ memory_space=ref.type.memory_space,
+ )
out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result
if any(squeeze_dims):
# We need to squeeze out some dimensions
static_sizes = tuple(s if not isinstance(s, ir.Value)
else ir_dynamic_size for s in target_shape)
squeezed_ref_ty = ir.MemRefType.get(
- static_sizes, _dtype_to_ir_type(ref_aval.dtype),
- memory_space=ref.type.memory_space)
+ static_sizes,
+ _dtype_to_ir_type(ref_dtype),
+ memory_space=ref.type.memory_space,
+ )
out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result
return out, ref_block_shape
-def _index_ref(ref, ref_aval, ref_block_shape, indexers):
- for indexer in indexers:
- ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
- ref_block_shape)
+def _bitcast_memref(
+ ref: ir.Value,
+ bitcaster: RefBitcaster,
+ ref_dtype: DTypeLike,
+ ref_block_shape: tuple[int | pallas_core.Mapped, ...],
+) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
+ src_bitwidth = dtype_bitwidth(ref_dtype)
+ dst_bitwidth = dtype_bitwidth(bitcaster.dtype)
+ if src_bitwidth != dst_bitwidth:
+ if len(ref_block_shape) < 2:
+ raise NotImplementedError(
+          "Bitcasting a 1D ref with a bitwidth change is not supported."
+      )
+    if ref_block_shape[-2] is pallas_core.mapped:
+      raise NotImplementedError(
+          "Bitcasting a ref whose second-minormost dimension is squeezed is"
+          " not supported when the bitwidth changes."
+ )
+ new_ref_dtype = bitcaster.dtype
+ target_ref_ty = ir.MemRefType.get(
+ bitcaster.shape,
+ _dtype_to_ir_type(new_ref_dtype),
+ memory_space=ref.type.memory_space,
+ )
+ new_ref_block_shape = list(ref_block_shape)
+ if (
+ len(new_ref_block_shape) >= 2
+ and new_ref_block_shape[-2] is not pallas_core.mapped
+ ):
+ new_ref_block_shape[-2] = (
+ new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth
+ )
+ return (
+ tpu.memref_bitcast(target_ref_ty, ref),
+ new_ref_dtype,
+ tuple(new_ref_block_shape),
+ )
+
+
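When the bitwidth changes, `_bitcast_memref` rescales the second-minormost block dimension by `src_bitwidth // dst_bitwidth`. A worked example of that arithmetic (plain Python, illustrative only):

```python
# Bitcasting an int32 block view to bfloat16 doubles the sublane dimension.
src_bitwidth, dst_bitwidth = 32, 16
ref_block_shape = (8, 128)  # (sublanes, lanes) of the int32 block
new_sublanes = ref_block_shape[-2] * src_bitwidth // dst_bitwidth
assert new_sublanes == 16   # the bfloat16 view has block shape (16, 128)
```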
+def _transform_ref(ref, ref_dtype, ref_block_shape, transforms):
+ for transform in transforms:
+ match transform:
+ case NDIndexer():
+ ref, ref_block_shape = _slice_memref(
+ ref, transform, ref_dtype, ref_block_shape
+ )
+ case RefBitcaster():
+ ref, ref_dtype, ref_block_shape = _bitcast_memref(
+ ref, transform, ref_dtype, ref_block_shape
+ )
+ case _:
+ raise NotImplementedError(f"Unsupported transform: {transform}")
return ref, ref_block_shape
+
@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
"""A container class for PRNG key data.
@@ -1044,21 +1089,21 @@ class KeyScalarBundle:
scalars: list[ir.OpResult]
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
- ref, indexers, mask, _ = args_tree.unflatten(args_flat)
- ref_aval, indexers_avals, _, _ = args_tree.unflatten(ctx.avals_in)
- (*slice_indexers, idx) = indexers
+ ref, transforms, mask, _ = args_tree.unflatten(args_flat)
+ ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in)
+ (*prev_transforms, idx) = transforms
# Select last aval, which is the one that will be used for the load.
- (*_, idx_aval) = indexers_avals
+ (*_, idx_aval) = transforms_avals
if mask is not None:
raise NotImplementedError
ref_block_shape, *_ = ctx.block_shapes
- ref, ref_block_shape = _index_ref(
- ref, ref_aval, ref_block_shape, slice_indexers)
+ ref, ref_block_shape = _transform_ref(
+ ref, ref_aval.dtype, ref_block_shape, prev_transforms
+ )
ref_type = ir.MemRefType(ref.type)
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space"
- ref_aval, *_ = ctx.avals_in
(aval_out,) = ctx.avals_out
if isinstance(aval_out.dtype, prng.KeyTy):
if not is_smem_load:
@@ -1092,7 +1137,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
raise ValueError(
"Loads are only allowed on VMEM and SMEM references." + extra
)
- load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
+ load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype)
if need_stride:
load_val = tpu.StridedLoadOp(
aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts, strides
@@ -1187,17 +1232,18 @@ def _maybe_cast_store_to_memref_type(
def _masked_swap_lowering_rule(
ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
- ref, indexers, val, mask = args_tree.unflatten(args_flat)
- ref_aval, indexers_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in)
- (*slice_indexers, idx) = indexers
- (*_, idx_aval) = indexers_avals
+ ref, transforms, val, mask = args_tree.unflatten(args_flat)
+ ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in)
+ (*prev_transforms, idx) = transforms
+ (*_, idx_aval) = transforms_avals
if mask is not None:
raise NotImplementedError
ref_block_shape, *_ = ctx.block_shapes
- ref, ref_block_shape = _index_ref(
- ref, ref_aval, ref_block_shape, slice_indexers)
+ ref, ref_block_shape = _transform_ref(
+ ref, ref_aval.dtype, ref_block_shape, prev_transforms
+ )
ref_type = ir.MemRefType(ref.type)
is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space"
@@ -2486,11 +2532,9 @@ def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d):
def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
- (x_aval,) = ctx.avals_in
- if x_aval.dtype == jnp.float32:
- return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x)
- else:
- raise NotImplementedError
+ return lower_fun(
+ pallas_utils.erf_inv_lowering_helper, multiple_results=False,
+ )(ctx, x)
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
@@ -2581,8 +2625,8 @@ def _semaphore_read_lowering_rule(
args_tree,
):
sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
- sem, indexers = tree_util.tree_unflatten(args_tree, args)
- sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+ sem, transforms = tree_util.tree_unflatten(args_tree, args)
+ sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
return tpu.SemaphoreReadOp(sem).result
@@ -2595,8 +2639,10 @@ def _semaphore_signal_lowering_rule(
device_id_type: tpu_primitives.DeviceIdType,
):
sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
- sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args)
- sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+ sem, transforms, value, device_id, core_index = tree_util.tree_unflatten(
+ args_tree, args
+ )
+ sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
if device_id is not None:
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
return tpu.SemaphoreSignalOp(
@@ -2610,8 +2656,8 @@ def _semaphore_signal_lowering_rule(
def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
- sem, indexers, value = tree_util.tree_unflatten(args_tree, args)
- sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+ sem, transforms, value = tree_util.tree_unflatten(args_tree, args)
+ sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms)
return tpu.SemaphoreWaitOp(sem, value).results
lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule
@@ -2619,13 +2665,13 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
(
src_ref,
- src_indexers,
+ src_transforms,
dst_ref,
- dst_indexers,
+ dst_transforms,
sem,
- sem_indexers,
+ sem_transforms,
src_sem,
- src_sem_indexers,
+ src_sem_transforms,
device_id,
) = tree_util.tree_unflatten(tree, args)
(src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = (
@@ -2635,16 +2681,17 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
raise NotImplementedError("DMAs with bool dtypes are not supported.")
block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
- src_ref, _ = _index_ref(
- src_ref, src_ref_aval, src_ref_block_shape, src_indexers
+ src_ref, _ = _transform_ref(
+ src_ref, src_ref_aval.dtype, src_ref_block_shape, src_transforms
)
if src_sem is not None:
- src_sem, _ = _index_ref(
- src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers)
- dst_ref, _ = _index_ref(
- dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers
+ src_sem, _ = _transform_ref(
+ src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms
+ )
+ dst_ref, _ = _transform_ref(
+ dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms
)
- sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+ sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
if device_id is not None:
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem,
@@ -2655,14 +2702,12 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
del device_id_type
- sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, args)
+ sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args)
sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
ref_block_shape = block_shapes[2]
- ref, _ = _index_ref(
- ref, ref_aval, ref_block_shape, indexers
- )
- sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+ ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
+ sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
return tpu.WaitDMAOp(sem, ref).results
lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule
@@ -2706,6 +2751,9 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int):
def _debug_print_rule(
ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
+ if any(aval.shape for aval in ctx.avals_in):
+ raise NotImplementedError("Only scalar values are supported")
+
primitives.check_debug_print_format(fmt, *args)
if has_placeholders:
if not all(
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 71091af27ca3..b09d36a9d3b2 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -30,19 +30,24 @@
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
+from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.mosaic import verification
+from jax._src import tpu_custom_call
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax.experimental.pallas import tpu as pltpu
-def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
+def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
"""Casts boolean values to integers.
We perform this cast because Mosaic does not directly support bool values
for Memrefs. Instead, we load bools as integers and cast them to bools
after loading from a memref inside of the kernel.
"""
+ assert isinstance(
+ x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray)
+ ), type(x)
if isinstance(x, jax.Array):
if dtypes.issubdtype(x.dtype, jax.numpy.bool_):
return x.astype(lowering.BOOL_MEMREF_TYPE)
@@ -63,6 +68,41 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
)
+def _get_memory_space_from_aval(
+ out_aval: jax_core.AbstractValue,
+) -> tpu_custom_call.MemorySpace | None:
+ if not isinstance(out_aval, jax_core.ShapedArray):
+ raise ValueError('Memory spaces not defined for non-ShapedArrays')
+ if not isinstance(out_aval, core.ShapedArrayWithMemorySpace):
+ # If we are passed a regular old ShapedArray, we don't constrain the
+ # memory space
+ return None
+ # If we are passed an aval with an explicit memory space tag, we use it
+ # to constrain the memory space.
+ match out_aval.memory_space:
+ case None:
+ return None
+ case tpu_core.TPUMemorySpace.ANY:
+ return None
+ case tpu_core.TPUMemorySpace.VMEM:
+ return tpu_custom_call.MemorySpace.VMEM
+ case tpu_core.TPUMemorySpace.SEMAPHORE:
+ return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
+ return None
+
+
+def _get_memory_spaces_from_avals(
+ out_avals: tuple[jax_core.AbstractValue, ...],
+) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None:
+ output_memory_spaces = None
+ if any(
+ isinstance(out_aval, core.ShapedArrayWithMemorySpace)
+ for out_aval in out_avals
+ ):
+ output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals))
+ return output_memory_spaces
+
+
def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext,
*in_nodes,
@@ -74,6 +114,7 @@ def pallas_call_tpu_lowering_rule(
interpret: bool,
compiler_params: dict[str, Any],
cost_estimate: core.CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
@@ -129,9 +170,6 @@ def lower_module(for_verification: bool):
(a[0] + num_dyn_bounds + num_extra_args, a[1])
for a in input_output_aliases
)
- out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape,
- bm.array_shape_dtype.dtype)
- for bm in grid_mapping.block_mappings_output]
if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size
@@ -174,7 +212,7 @@ def lower_module(for_verification: bool):
def _maybe_cast_inputs(*args):
args = [_maybe_cast_to_int(x) for x in args]
return args
- kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] # type: ignore
+ kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in]
kernel_out_avals = [_maybe_cast_to_int(x) for x in out_avals]
cast_ctx = ctx.replace(avals_out=kernel_in_avals)
in_nodes = mlir.lower_fun(_maybe_cast_inputs)(cast_ctx, *in_nodes)
@@ -182,6 +220,7 @@ def _maybe_cast_inputs(*args):
# Dynamic grid bounds have to go at the front.
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
+ output_memory_spaces = _get_memory_spaces_from_avals(out_avals)
if cost_estimate is not None:
mosaic_cost_estimate = pltpu.CostEstimate(
flops=cost_estimate.flops,
@@ -208,7 +247,7 @@ def _maybe_cast_inputs(*args):
device_type=mosaic_params.get("device_type"),
internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
collective_id=mosaic_params.get("collective_id", None),
- output_memory_spaces=None, # TODO(apaszke,sharadmv): Implement this.
+ output_memory_spaces=output_memory_spaces,
)
_maybe_cast_to_bool = lambda x, aval: x.astype(
jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x
diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py
index fca9ee471e6a..005e4acdd106 100644
--- a/jax/_src/pallas/mosaic/pipeline.py
+++ b/jax/_src/pallas/mosaic/pipeline.py
@@ -40,7 +40,7 @@
SMEM = tpu_core.TPUMemorySpace.SMEM
VMEM = tpu_core.TPUMemorySpace.VMEM
DMA = tpu_core.SemaphoreType.DMA
-REF = tpu_core.MemoryRef
+REF = pallas_core.MemoryRef
SemaphoreType = tpu_core.SemaphoreType
SemaphoreTuple = jax.Array
ArrayRef = Union[REF, jax.Array]
@@ -189,7 +189,7 @@ class BufferedRef:
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
- vmem_ref: a double-buffer to hold a working buffer and a dirty buffer used
+ window_ref: a double-buffer to hold a working buffer and a dirty buffer used
to copy into and out of. In the case of a BufferedRef targeting a VMEM
reference, this simply points to the existing ref.
accum_ref: accumulating buffer used by accumulator BufferedRefs.
@@ -210,7 +210,7 @@ class BufferedRef:
spec: pl.BlockSpec # static metadata
dtype: Any # static metadata
buffer_type: BufferType # static metadata
- vmem_ref: REF | None
+ window_ref: REF | None
accum_ref: REF | None
current_slot: ArrayRef | None
next_slot: ArrayRef | None
@@ -218,9 +218,17 @@ class BufferedRef:
sem_sends: SemaphoreTuple | None
def tree_flatten(self):
- return ((self.vmem_ref, self.accum_ref, self.current_slot,
- self.next_slot, self.sem_recvs, self.sem_sends),
- (self.spec, self.dtype, self.buffer_type))
+ return (
+ (
+ self.window_ref,
+ self.accum_ref,
+ self.current_slot,
+ self.next_slot,
+ self.sem_recvs,
+ self.sem_sends,
+ ),
+ (self.spec, self.dtype, self.buffer_type),
+ )
@classmethod
def tree_unflatten(cls, meta, data):
@@ -252,7 +260,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
spec=spec,
dtype=dtype,
buffer_type=buffer_type,
- vmem_ref=None, # to be bound to existing ref by the pipeline routine
+ window_ref=None, # to be bound to existing ref by the pipeline routine
accum_ref=accum_ref,
current_slot=None,
next_slot=None,
@@ -260,11 +268,12 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
sem_sends=None,
)
else:
+ memory_space = SMEM if spec.memory_space == SMEM else VMEM
return cls(
spec=spec,
dtype=dtype,
buffer_type=buffer_type,
- vmem_ref=VMEM((2,) + block_shape, dtype),
+ window_ref=memory_space((2,) + block_shape, dtype),
accum_ref=accum_ref,
current_slot=SMEM((1,), jnp.int32),
next_slot=SMEM((1,), jnp.int32),
@@ -313,9 +322,9 @@ def current_ref(self):
buffer_slice = tuple(
0 if x is None else slice(None) for x in self.block_shape)
if self.memory_space == VMEM:
- return self.vmem_ref.at[buffer_slice]
+ return self.window_ref.at[buffer_slice]
else:
- return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)]
+ return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
@property
def is_input(self):
@@ -341,11 +350,12 @@ def is_accumulator(self):
def is_input_output(self):
return self.buffer_type == BufferType.INPUT_OUTPUT
- def bind_existing_ref(self, vmem_ref, indices):
+ def bind_existing_ref(self, window_ref, indices):
"""For handling VMEM references, the pipeline aliases the existing ref."""
if self.memory_space == VMEM:
return dataclasses.replace(
- self, vmem_ref=vmem_ref.at[self.compute_slice(indices)])
+ self, window_ref=window_ref.at[self.compute_slice(indices)]
+ )
return self
def compute_slice(self, grid_indices):
@@ -432,8 +442,9 @@ def copy_in(self, src_ref, grid_indices):
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[src_slice],
- self.vmem_ref.at[next_slot].at[dst_slice],
- self.sem_recvs.at[next_slot]).start()
+ self.window_ref.at[next_slot].at[dst_slice],
+ self.sem_recvs.at[next_slot],
+ ).start()
def copy_out(self, dst_ref, grid_indices):
"""Starts copy of HBM dma slice from the current slot."""
@@ -444,9 +455,10 @@ def copy_out(self, dst_ref, grid_indices):
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
- self.vmem_ref.at[slot].at[src_slice],
+ self.window_ref.at[slot].at[src_slice],
dst_ref.at[dst_slice],
- self.sem_sends.at[slot]).start()
+ self.sem_sends.at[slot],
+ ).start()
def wait_in(self, src_ref, grid_indices):
"""Waits for input copy to finish."""
@@ -456,9 +468,12 @@ def wait_in(self, src_ref, grid_indices):
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
current_slot = self.current_slot[0]
tpu_primitives.make_async_copy(
- src_ref.at[src_slice], # nb: doesn't matter
- self.vmem_ref.at[current_slot].at[dst_slice], # only dst shape is important
- self.sem_recvs.at[current_slot]).wait()
+ src_ref.at[src_slice], # nb: doesn't matter
+ self.window_ref.at[current_slot].at[
+ dst_slice
+ ], # only dst shape is important
+ self.sem_recvs.at[current_slot],
+ ).wait()
def wait_out(self, dst_ref, grid_indices):
"""Waits for output copy to finish."""
@@ -468,9 +483,10 @@ def wait_out(self, dst_ref, grid_indices):
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
- self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
- dst_ref.at[dst_slice], # only dst shape is important
- self.sem_sends.at[prev_slot]).wait()
+ self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
+ dst_ref.at[dst_slice], # only dst shape is important
+ self.sem_sends.at[prev_slot],
+ ).wait()
# Accumulator methods
#
@@ -498,14 +514,14 @@ def accumulate(self):
assert self.is_accumulator
if self.accum_ref is not None:
accum_dtype = jnp.float32
- if self.vmem_ref.dtype == jnp.int32:
+ if self.window_ref.dtype == jnp.int32:
accum_dtype = jnp.int32
# TODO(levskaya): we could generalize init and reduction functions,
# could it ever be useful to support more generic monoids?
self.current_ref[...] = (
- self.current_ref[...].astype(accum_dtype) +
- self.accum_ref[...].astype(accum_dtype)
- ).astype(self.vmem_ref.dtype)
+ self.current_ref[...].astype(accum_dtype)
+ + self.accum_ref[...].astype(accum_dtype)
+ ).astype(self.window_ref.dtype)
# Helper to tree map over BufferedRefs as leaves.
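The rename from `vmem_ref` to `window_ref` reflects that the double-buffered window may now live in SMEM as well as VMEM. The scheduling idea is unchanged: while compute reads the current slot, the next block is DMA'd into the other slot. A schematic plain-Python version of that loop (just the shape of the schedule, not the Pallas code):

```python
def pipelined(blocks, compute):
  window = [None, None]            # stands in for window_ref's leading dim of 2
  current, out = 0, []
  window[current] = blocks[0]      # prologue: copy_in + wait_in of block 0
  for k in range(len(blocks)):
    nxt = 1 - current
    if k + 1 < len(blocks):
      window[nxt] = blocks[k + 1]  # copy_in of the next block (async DMA on TPU)
    out.append(compute(window[current]))  # compute overlaps the in-flight copy
    current = nxt                  # swap slots, like current_slot / next_slot
  return out

print(pipelined([1, 2, 3, 4], lambda b: b * b))  # [1, 4, 9, 16]
```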
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index 348820907ed0..aab214a2d700 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -33,6 +33,7 @@
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
+from jax._src.state.types import Transform
from jax._src.typing import DTypeLike
import jax.numpy as jnp
@@ -164,17 +165,21 @@ class DeviceIdType(enum.Enum):
LOGICAL = "logical"
-def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types=None):
+def check_sem_avals(
+ sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None
+):
if allowed_semaphore_types is None:
- allowed_semaphore_types = {tpu_core.semaphore,
- tpu_core.barrier_semaphore,
- # For interpret mode.
- pl_core.SEMAPHORE_INTERPRET_DTYPE}
+ allowed_semaphore_types = {
+ tpu_core.semaphore,
+ tpu_core.barrier_semaphore,
+ # For interpret mode.
+ pl_core.SEMAPHORE_INTERPRET_DTYPE,
+ }
if not isinstance(sem_aval, state.AbstractRef):
raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}")
sem_shape = sem_aval.shape
- if sem_indexers_avals:
- sem_shape = sem_indexers_avals[-1].get_indexer_shape()
+ if sem_transforms_avals:
+ sem_shape = sem_transforms_avals[-1].get_indexer_shape()
if sem_shape:
raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}")
sem_dtype = sem_aval.dtype
@@ -187,10 +192,11 @@ def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types=
f" {allowed_semaphore_types}."
)
-def _index_semaphore(ref_value, indexers, ref_aval):
+
+def _transform_semaphore(ref_value, transforms, ref_aval):
"""Helper function for indexing into a semaphore during state_discharge."""
if ref_value.shape == ref_aval.shape:
- return state_discharge.index_array(ref_value, indexers)
+ return state_discharge.transform_array(ref_value, transforms)
elif len(ref_value.shape) == 0:
return ref_value
else:
@@ -199,13 +205,14 @@ def _index_semaphore(ref_value, indexers, ref_aval):
f" {ref_aval.shape}"
)
+
semaphore_read_p = jax_core.Primitive("semaphore_read")
semaphore_read_p.multiple_results = False
def semaphore_read(sem_or_view):
- ref, indexers = _get_ref_and_indexers(sem_or_view)
- args = [ref, indexers]
+ ref, transforms = _get_ref_and_transforms(sem_or_view)
+ args = [ref, transforms]
flat_args, args_tree = tree_util.tree_flatten(args)
return semaphore_read_p.bind(*flat_args, args_tree=args_tree)
@@ -214,10 +221,10 @@ def _semaphore_read_abstract_eval(
*avals,
args_tree,
):
- sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals)
+ sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals)
check_sem_avals(
sem_aval,
- sem_indexers_avals,
+ sem_transforms_avals,
"read",
allowed_semaphore_types={
tpu_core.dma_semaphore,
@@ -233,8 +240,8 @@ def _semaphore_read_discharge_rule(in_avals,
*flat_args,
args_tree):
del out_avals
- [ref, indexers] = args_tree.unflatten(flat_args)
- sem_value = _index_semaphore(ref, indexers, in_avals[0])
+ [ref, transforms] = args_tree.unflatten(flat_args)
+ sem_value = _transform_semaphore(ref, transforms, in_avals[0])
sem_value = sem_value.astype(jnp.int32)
return (None,) * len(in_avals), sem_value
state_discharge.register_discharge_rule(semaphore_read_p)(
@@ -254,9 +261,9 @@ def semaphore_signal(
device_id_type: DeviceIdType = DeviceIdType.MESH,
core_index: int | jax.Array | None = None,
):
- ref, indexers = _get_ref_and_indexers(sem_or_view)
+ ref, transforms = _get_ref_and_transforms(sem_or_view)
inc = jnp.asarray(inc, dtype=jnp.int32)
- args = [ref, indexers, inc, device_id, core_index]
+ args = [ref, transforms, inc, device_id, core_index]
flat_args, args_tree = tree_util.tree_flatten(args)
semaphore_signal_p.bind(
*flat_args,
@@ -272,10 +279,14 @@ def _semaphore_signal_abstract_eval(
device_id_type: DeviceIdType,
):
del device_id_type
- sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = (
- tree_util.tree_unflatten(args_tree, avals)
- )
- check_sem_avals(sem_aval, sem_indexers_avals, "signal")
+ (
+ sem_aval,
+ sem_transforms_avals,
+ value_aval,
+ device_id_avals,
+ core_index_aval,
+ ) = tree_util.tree_unflatten(args_tree, avals)
+ check_sem_avals(sem_aval, sem_transforms_avals, "signal")
if value_aval.dtype != jnp.dtype("int32"):
raise ValueError("Must signal an int32 value.")
if device_id_avals is not None:
@@ -294,16 +305,16 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn,
tree = eqn.params["args_tree"]
(
sem,
- sem_indexers,
+ sem_transforms,
value,
device_ids,
_,
) = tree_util.tree_unflatten(tree, invars)
out = pp.concat([
- pp.text('semaphore_signal'),
- pp.text(' '),
- sp.pp_ref_indexers(context, sem, sem_indexers),
- pp.text(' '),
+ pp.text("semaphore_signal"),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, sem, sem_transforms),
+ pp.text(" "),
pp.text(jax_core.pp_var(value, context)),
])
if device_ids is not None:
@@ -325,15 +336,15 @@ def _semaphore_signal_discharge_rule(in_avals,
args_tree,
device_id_type):
del out_avals, device_id_type
- [ref, indexers, inc, device_id, core_index] = args_tree.unflatten(flat_args)
+ [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args)
if device_id is not None:
raise NotImplementedError("Remote signal not implemented.")
if core_index is not None:
raise NotImplementedError("Multiple core support not implemented.")
- sem_value = _index_semaphore(ref, indexers, in_avals[0])
+ sem_value = _transform_semaphore(ref, transforms, in_avals[0])
inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE)
- _, new_sem_value = state_discharge.index_swap_array(
- ref, indexers, sem_value + inc
+ _, new_sem_value = state_discharge.transform_swap_array(
+ ref, transforms, sem_value + inc
)
return (new_sem_value,) + (None,) * (len(in_avals) - 1), ()
state_discharge.register_discharge_rule(semaphore_signal_p)(
@@ -345,16 +356,18 @@ def _semaphore_signal_discharge_rule(in_avals,
semaphore_wait_p.multiple_results = True
def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
- ref, indexers = _get_ref_and_indexers(sem_or_view)
+ ref, transforms = _get_ref_and_transforms(sem_or_view)
dec = jnp.asarray(dec, dtype=jnp.int32)
- args = [ref, indexers, dec]
+ args = [ref, transforms, dec]
flat_args, args_tree = tree_util.tree_flatten(args)
semaphore_wait_p.bind(*flat_args, args_tree=args_tree)
@semaphore_wait_p.def_abstract_eval
def _semaphore_wait_abstract_eval(*avals, args_tree):
- sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals)
- check_sem_avals(sem_aval, sem_indexers_avals, "wait")
+ sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten(
+ args_tree, avals
+ )
+ check_sem_avals(sem_aval, sem_transforms_avals, "wait")
if value_aval.dtype != jnp.dtype("int32"):
raise ValueError("Must wait an int32 value.")
return []
@@ -367,14 +380,14 @@ def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn,
tree = eqn.params["args_tree"]
(
sem,
- sem_indexers,
+ sem_transforms,
value,
) = tree_util.tree_unflatten(tree, invars)
return pp.concat([
- pp.text('semaphore_wait'),
- pp.text(' '),
- sp.pp_ref_indexers(context, sem, sem_indexers),
- pp.text(' '),
+ pp.text("semaphore_wait"),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, sem, sem_transforms),
+ pp.text(" "),
pp.text(jax_core.pp_var(value, context)),
])
jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn
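Although the plumbing switches from "indexers" to general "transforms", the user-facing semaphore API is unchanged. A hedged usage sketch (assumes `pallas_call` accepts `scratch_shapes`, as in recent Pallas releases):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref, sem):
  pltpu.semaphore_signal(sem, 1)  # increment the scratch semaphore
  pltpu.semaphore_wait(sem, 1)    # decrement; blocks until the value reaches 1
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.SemaphoreType.REGULAR],
)(x)
```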
@@ -384,11 +397,11 @@ def _semaphore_wait_discharge_rule(in_avals,
*flat_args,
args_tree):
del out_avals
- [ref, indexers, dec] = args_tree.unflatten(flat_args)
- sem_value = _index_semaphore(ref, indexers, in_avals[0])
+ [ref, transforms, dec] = args_tree.unflatten(flat_args)
+ sem_value = _transform_semaphore(ref, transforms, in_avals[0])
dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE)
- _, new_sem_value = state_discharge.index_swap_array(
- ref, indexers, sem_value -dec
+ _, new_sem_value = state_discharge.transform_swap_array(
+ ref, transforms, sem_value - dec
)
return (new_sem_value,) + (None,) * (len(in_avals) - 1), ()
state_discharge.register_discharge_rule(semaphore_wait_p)(
@@ -399,13 +412,13 @@ def _semaphore_wait_discharge_rule(in_avals,
@dataclasses.dataclass
class AsyncCopyDescriptor:
src_ref: Any
- src_indexers: tuple[indexing.NDIndexer, ...]
+ src_transforms: tuple[Transform, ...]
dst_ref: Any
- dst_indexers: tuple[indexing.NDIndexer, ...]
+ dst_transforms: tuple[Transform, ...]
dst_sem: int | jax.Array
- dst_sem_indexers: tuple[indexing.NDIndexer, ...]
+ dst_sem_transforms: tuple[Transform, ...]
src_sem: int | jax.Array | None
- src_sem_indexers: tuple[indexing.NDIndexer, ...] | None
+ src_sem_transforms: tuple[Transform, ...] | None
device_id: int | jax.Array | None
device_id_type: DeviceIdType = DeviceIdType.MESH
@@ -421,13 +434,13 @@ def is_remote(self):
def start(self):
flat_args, tree = tree_util.tree_flatten((
self.src_ref,
- self.src_indexers,
+ self.src_transforms,
self.dst_ref,
- self.dst_indexers,
+ self.dst_transforms,
self.dst_sem,
- self.dst_sem_indexers,
+ self.dst_sem_transforms,
self.src_sem,
- self.src_sem_indexers,
+ self.src_sem_transforms,
self.device_id,
))
dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)
@@ -438,9 +451,12 @@ def wait(self):
self.wait_recv()
def wait_recv(self):
- wait_args, tree = tree_util.tree_flatten(
- (self.dst_sem, self.dst_sem_indexers, self.dst_ref, self.dst_indexers)
- )
+ wait_args, tree = tree_util.tree_flatten((
+ self.dst_sem,
+ self.dst_sem_transforms,
+ self.dst_ref,
+ self.dst_transforms,
+ ))
dma_wait_p.bind(
*wait_args, tree=tree, device_id_type=self.device_id_type
)
@@ -448,9 +464,12 @@ def wait_recv(self):
def wait_send(self):
if not self.is_remote:
raise ValueError("Cannot `wait_send` on a local copy.")
- wait_args, tree = tree_util.tree_flatten(
- (self.src_sem, self.src_sem_indexers, self.src_ref, self.src_indexers)
- )
+ wait_args, tree = tree_util.tree_flatten((
+ self.src_sem,
+ self.src_sem_transforms,
+ self.src_ref,
+ self.src_transforms,
+ ))
dma_wait_p.bind(
*wait_args, tree=tree, device_id_type=self.device_id_type
)
@@ -463,32 +482,32 @@ def wait_send(self):
def _dma_start_abstract_eval(*args, tree, device_id_type):
(
src_ref_aval,
- src_indexers_avals,
+ src_transforms_avals,
dst_ref_aval,
- dst_indexers_avals,
+ dst_transforms_avals,
dst_sem_aval,
- dst_sem_indexers_avals,
+ dst_sem_transforms_avals,
src_sem_aval,
- src_sem_indexers_avals,
+ src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, args)
dst_sem_shape = dst_sem_aval.shape
- if dst_sem_indexers_avals:
- dst_sem_shape = dst_sem_indexers_avals[-1].get_indexer_shape()
+ if dst_sem_transforms_avals:
+ dst_sem_shape = dst_sem_transforms_avals[-1].get_indexer_shape()
if dst_sem_shape:
raise ValueError(
f"Cannot signal on a non-()-shaped semaphore: {dst_sem_shape}"
)
if src_sem_aval is not None:
src_sem_shape = src_sem_aval.shape
- if src_sem_indexers_avals:
- src_sem_shape = src_sem_indexers_avals[-1].get_indexer_shape()
+ if src_sem_transforms_avals:
+ src_sem_shape = src_sem_transforms_avals[-1].get_indexer_shape()
if src_sem_shape:
raise ValueError(
f"Cannot signal on a non-()-shaped semaphore: {src_sem_shape}"
)
- n_src_indexers = len(tree_util.tree_leaves(src_indexers_avals))
- return [], {state.ReadEffect(0), state.WriteEffect(n_src_indexers + 1)}
+ n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals))
+ return [], {state.ReadEffect(0), state.WriteEffect(n_src_transforms + 1)}
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
context: jax_core.JaxprPpContext,
@@ -497,27 +516,27 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
tree = eqn.params["tree"]
(
src_ref,
- src_indexers,
+ src_transforms,
dst_ref,
- dst_indexers,
+ dst_transforms,
dst_sem,
- dst_sem_indexers,
+ dst_sem_transforms,
src_sem,
- src_sem_indexers,
+ src_sem_transforms,
device_id,
) = tree_util.tree_unflatten(tree, invars)
- del src_sem_indexers
+ del src_sem_transforms
# TODO(sharadmv): pretty print source semaphores and device id
if src_sem or device_id:
return jax_core._pp_eqn(eqn, context, settings)
return pp.concat([
- pp.text('dma_start'),
- pp.text(' '),
- sp.pp_ref_indexers(context, src_ref, src_indexers),
- pp.text(' -> '),
- sp.pp_ref_indexers(context, dst_ref, dst_indexers),
- pp.text(' '),
- sp.pp_ref_indexers(context, dst_sem, dst_sem_indexers),
+ pp.text("dma_start"),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, src_ref, src_transforms),
+ pp.text(" -> "),
+ sp.pp_ref_transforms(context, dst_ref, dst_transforms),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, dst_sem, dst_sem_transforms),
])
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
@@ -526,24 +545,24 @@ def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
- src_indexers,
+ src_transforms,
dst_ref,
- dst_indexers,
+ dst_transforms,
dst_sem,
- dst_sem_indexers,
+ dst_sem_transforms,
src_sem,
- src_sem_indexers,
+ src_sem_transforms,
device_id,
) = tree_util.tree_unflatten(tree, args)
(
_,
- src_indexers_avals,
+ src_transforms_avals,
_,
- dst_indexers_avals,
+ dst_transforms_avals,
dst_sem_aval,
- dst_sem_indexers_avals,
+ dst_sem_transforms_avals,
src_sem_aval,
- src_sem_indexers_avals,
+ src_sem_transforms_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals
@@ -551,14 +570,14 @@ def dma_start_discharge_rule(in_avals, out_avals,
if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
- assert src_sem_indexers is None
+ assert src_sem_transforms is None
- num_src_sem_indexers = len(tree_util.tree_leaves(src_sem_indexers_avals))
- num_dst_sem_indexers = len(tree_util.tree_leaves(dst_sem_indexers_avals))
- num_src_index_vals = len(tree_util.tree_leaves(src_indexers_avals))
- num_dst_index_vals = len(tree_util.tree_leaves(dst_indexers_avals))
+ num_src_sem_transforms = len(tree_util.tree_leaves(src_sem_transforms_avals))
+ num_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
+ num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
+ num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))
- updates = state_discharge.index_array(src_ref, src_indexers)
+ updates = state_discharge.transform_array(src_ref, src_transforms)
local_src = updates
if is_remote:
@@ -602,44 +621,52 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_updates, index, axis=0, keepdims=False)
# Handle asymmetrical indexing when devices do not share the same
- # dst_indexer.
- global_dst_indexers = tree_util.tree_map(
- lambda x: jax.lax.all_gather(x, shard_axis), dst_indexers)
- dst_indexers = tree_util.tree_map(
+ # dst_transform.
+ global_dst_transforms = tree_util.tree_map(
+ lambda x: jax.lax.all_gather(x, shard_axis), dst_transforms
+ )
+ dst_transforms = tree_util.tree_map(
lambda x: jax.lax.dynamic_index_in_dim(
- x, index, axis=0, keepdims=False), global_dst_indexers)
+ x, index, axis=0, keepdims=False
+ ),
+ global_dst_transforms,
+ )
- _, new_dst = state_discharge.index_swap_array(
- dst_ref, dst_indexers, updates
+ _, new_dst = state_discharge.transform_swap_array(
+ dst_ref, dst_transforms, updates
)
# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
- dst_sem_value = _index_semaphore(dst_sem, dst_sem_indexers, dst_sem_aval)
- _, new_dst_sem = state_discharge.index_swap_array(
- dst_sem, dst_sem_indexers, dst_sem_value + recv_size
+ dst_sem_value = _transform_semaphore(
+ dst_sem, dst_sem_transforms, dst_sem_aval
+ )
+ _, new_dst_sem = state_discharge.transform_swap_array(
+ dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
- src_sem_value = _index_semaphore(src_sem, src_sem_indexers, src_sem_aval)
- _, new_src_sem = state_discharge.index_swap_array(
- src_sem, src_sem_indexers, src_sem_value + send_size
+ src_sem_value = _transform_semaphore(
+ src_sem, src_sem_transforms, src_sem_aval
+ )
+ _, new_src_sem = state_discharge.transform_swap_array(
+ src_sem, src_sem_transforms, src_sem_value + send_size
)
else:
new_src_sem = None
new_vals = (None,) # src_val
- new_vals += (None,) * num_src_index_vals
+ new_vals += (None,) * num_src_transform_vals
new_vals += (new_dst,) # dst_val
- new_vals += (None,) * num_dst_index_vals
+ new_vals += (None,) * num_dst_transform_vals
new_vals += (new_dst_sem,) # dst_sem
- new_vals += (None,) * num_dst_sem_indexers
+ new_vals += (None,) * num_dst_sem_transforms
if is_remote:
new_vals += (new_src_sem,) # src_sem
- new_vals += (None,) * num_src_sem_indexers
+ new_vals += (None,) * num_src_sem_transforms
new_vals += (None,) # device_id
assert (len(new_vals) ==
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
@@ -662,13 +689,13 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
del settings
invars = eqn.invars
tree = eqn.params["tree"]
- sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, invars)
+ sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars)
return pp.concat([
- pp.text('dma_wait'),
- pp.text(' '),
- sp.pp_ref_indexers(context, ref, indexers),
- pp.text(' '),
- sp.pp_ref_indexers(context, sem, sem_indexers),
+ pp.text("dma_wait"),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, ref, transforms),
+ pp.text(" "),
+ sp.pp_ref_transforms(context, sem, sem_transforms),
])
jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
@@ -676,42 +703,53 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
del out_avals, device_id_type
- (sem, sem_indexers, ref, ref_indexers) = tree_util.tree_unflatten(tree, args)
+ (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten(
+ tree, args
+ )
(
sem_aval,
- sem_indexers_avals,
+ sem_transforms_avals,
_,
- ref_indexers_avals,
+ ref_transforms_avals,
) = tree_util.tree_unflatten(tree, in_avals)
- num_sem_indexers = len(tree_util.tree_leaves(sem_indexers_avals))
- num_indexers = len(tree_util.tree_leaves(ref_indexers_avals))
- updates = state_discharge.index_array(ref, ref_indexers)
+ num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals))
+ num_transforms = len(tree_util.tree_leaves(ref_transforms_avals))
+ updates = state_discharge.transform_array(ref, ref_transforms)
copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
- sem_value = _index_semaphore(sem, sem_indexers, sem_aval)
- _, new_sem = state_discharge.index_swap_array(
- sem, sem_indexers, sem_value - copy_size
+ sem_value = _transform_semaphore(sem, sem_transforms, sem_aval)
+ _, new_sem = state_discharge.transform_swap_array(
+ sem, sem_transforms, sem_value - copy_size
)
new_vals = (new_sem,) # sem
- new_vals += (None,) * num_sem_indexers
+ new_vals += (None,) * num_sem_transforms
new_vals += (None,) # ref
- new_vals += (None,) * num_indexers
+ new_vals += (None,) * num_transforms
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
-def _get_ref_and_indexers(ref):
- if isinstance(ref, state.RefView):
- return ref.ref, ref.indexers
+def _get_ref_and_transforms(ref):
+ if isinstance(ref, state.TransformedRef):
+ return ref.ref, ref.transforms
return ref, ()
def make_async_copy(src_ref, dst_ref, sem):
"""Issues a DMA copying from src_ref to dst_ref."""
- src_ref, src_indexers = _get_ref_and_indexers(src_ref)
- dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
- sem, sem_indexers = _get_ref_and_indexers(sem)
- return AsyncCopyDescriptor(src_ref, src_indexers, dst_ref, dst_indexers,
- sem, sem_indexers, None, None, None,
- DeviceIdType.MESH)
+ src_ref, src_transforms = _get_ref_and_transforms(src_ref)
+ dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref)
+ sem, sem_transforms = _get_ref_and_transforms(sem)
+ return AsyncCopyDescriptor(
+ src_ref,
+ src_transforms,
+ dst_ref,
+ dst_transforms,
+ sem,
+ sem_transforms,
+ None,
+ None,
+ None,
+ DeviceIdType.MESH,
+ )
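As an illustration of the descriptor built here (not part of the patch), a local-copy sketch assuming the helper is re-exported as `pltpu.make_async_copy` and that `pltpu.VMEM`/`pltpu.SemaphoreType.DMA` scratch shapes are available:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_hbm_ref, o_ref, scratch_ref, dma_sem):
  # Start an async HBM -> VMEM copy and block until it completes.
  copy = pltpu.make_async_copy(x_hbm_ref, scratch_ref, dma_sem)
  copy.start()
  copy.wait()
  o_ref[...] = scratch_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA],
)(x)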
def async_copy(src_ref, dst_ref, sem):
"""Issues a DMA copying from src_ref to dst_ref."""
@@ -739,13 +777,22 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
Returns:
An AsyncCopyDescriptor.
"""
- src_ref, src_indexers = _get_ref_and_indexers(src_ref)
- send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem)
- dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
- recv_sem, recv_sem_indexers = _get_ref_and_indexers(recv_sem)
+ src_ref, src_transforms = _get_ref_and_transforms(src_ref)
+ send_sem, send_sem_transforms = _get_ref_and_transforms(send_sem)
+ dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref)
+ recv_sem, recv_sem_transforms = _get_ref_and_transforms(recv_sem)
return AsyncCopyDescriptor(
- src_ref, src_indexers, dst_ref, dst_indexers, recv_sem, recv_sem_indexers,
- send_sem, send_sem_indexers, device_id, device_id_type=device_id_type)
+ src_ref,
+ src_transforms,
+ dst_ref,
+ dst_transforms,
+ recv_sem,
+ recv_sem_transforms,
+ send_sem,
+ send_sem_transforms,
+ device_id,
+ device_id_type=device_id_type,
+ )
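And a remote-copy sketch (illustrative only) showing how the send/receive semaphores and the `wait_send`/`wait_recv` methods fit together; it assumes the kernel runs under `shard_map` over a 1-D mesh axis named "x" and that the helpers are re-exported as `pltpu.*`:

import jax
from jax.experimental.pallas import tpu as pltpu

def rdma_kernel(x_ref, o_ref, send_sem, recv_sem):
  # Push our block to the next device along the "x" mesh axis.
  num_devices = jax.lax.psum(1, "x")
  neighbor = (jax.lax.axis_index("x") + 1) % num_devices
  rdma = pltpu.make_async_remote_copy(
      src_ref=x_ref,
      dst_ref=o_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=neighbor,
      device_id_type=pltpu.DeviceIdType.LOGICAL,
  )
  rdma.start()
  rdma.wait_send()  # our outgoing buffer may be reused
  rdma.wait_recv()  # the incoming block has landed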
def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
device_id_type: DeviceIdType = DeviceIdType.MESH):
diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py
index df186d46373a..bae87226c664 100644
--- a/jax/_src/pallas/mosaic/verification.py
+++ b/jax/_src/pallas/mosaic/verification.py
@@ -550,13 +550,17 @@ def _pretend_abstract_eval(*_, **params):
def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree):
if ctx.lowering_context.for_verification:
- (base_read_refs, indexers) = tree_util.tree_unflatten(tree, flat_args)
+ (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args)
read_ref_avals, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
block_shapes, _ = tree_util.tree_unflatten(tree, ctx.block_shapes)
read_refs = [
lowering._index_ref(ref, aval, block_shape, indexer)[0]
for ref, aval, block_shape, indexer in zip(
- base_read_refs, read_ref_avals, block_shapes, indexers, strict=True,
+ base_read_refs,
+ read_ref_avals,
+ block_shapes,
+ transforms,
+ strict=True,
)
]
ir.Operation.create("verification.pretend", operands=read_refs)
@@ -565,8 +569,10 @@ def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree):
lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore
def pretend(read_refs):
- refs, indexers = unzip2(primitives._get_ref_and_indexers(r) for r in read_refs)
- flat_args, tree = tree_util.tree_flatten((refs, indexers))
+ refs, transforms = unzip2(
+ primitives._get_ref_and_transforms(r) for r in read_refs
+ )
+ flat_args, tree = tree_util.tree_flatten((refs, transforms))
return pretend_p.bind(*flat_args, tree=tree)
diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD
index c3e8fc8b83de..171ff0439085 100644
--- a/jax/_src/pallas/mosaic_gpu/BUILD
+++ b/jax/_src/pallas/mosaic_gpu/BUILD
@@ -34,6 +34,7 @@ py_library(
deps = [
":core",
":pallas_call_registration",
+ ":primitives",
],
)
@@ -72,8 +73,22 @@ pytype_strict_library(
deps = [
"//jax",
"//jax:core",
+ "//jax:dtypes",
"//jax:mosaic_gpu",
"//jax:tree_util",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
+
+pytype_strict_library(
+ name = "primitives",
+ srcs = ["primitives.py"],
+ deps = [
+ ":core",
+ ":lowering",
+ "//jax",
+ "//jax:core",
+ "//jax:mosaic_gpu",
+ "//jax/_src/pallas",
+ ],
+)
diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py
index 862a661e24b9..1bd512834ce5 100644
--- a/jax/_src/pallas/mosaic_gpu/__init__.py
+++ b/jax/_src/pallas/mosaic_gpu/__init__.py
@@ -11,3 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+# TODO(slebedev): Move these imports to ``jax.experimental.pallas``.
+
+from jax._src.pallas.mosaic_gpu.core import Barrier
+from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
+from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
+from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
+from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
+from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
+from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
+from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
+
+GMEM = GPUMemorySpace.GMEM
+SMEM = GPUMemorySpace.SMEM
+REGS = GPUMemorySpace.REGS
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 1a0c489af47d..dc698b8747d9 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -17,9 +17,10 @@
from collections.abc import Sequence
import dataclasses
import enum
-from typing import ClassVar, Literal
-from jax import core as jax_core
-from jax._src import core
+from typing import Any, ClassVar, Literal, Protocol
+
+from jax._src import core as jax_core
+from jax._src import dtypes
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax.experimental.mosaic import gpu as mosaic_gpu
@@ -29,11 +30,13 @@
AbstractMemoryRef = pallas_core.AbstractMemoryRef
-@dataclasses.dataclass(frozen=True)
+@dataclasses.dataclass(frozen=True, kw_only=True)
class GPUCompilerParams(pallas_core.CompilerParams):
"""Mosaic GPU compiler parameters.
Attributes:
+ approx_math: If True, the compiler is allowed to use approximate
+ implementations of some math operations, e.g. ``exp``. Defaults to False.
dimension_semantics: A list of dimension semantics for each grid
dimension of the kernel. Either "parallel" for dimensions that can
execute in any order, or "sequential" for dimensions that must be
@@ -42,6 +45,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
meaning no pipelining is done.
"""
PLATFORM: ClassVar[str] = "mosaic_gpu"
+ approx_math: bool = False
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
num_stages: int = 1
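With `kw_only=True`, these parameters must now be passed by keyword; a construction sketch (illustrative, not part of the patch), using the package whose `__init__.py` diff appears above:

from jax._src.pallas import mosaic_gpu as plgpu  # re-exports GPUCompilerParams

params = plgpu.GPUCompilerParams(
    approx_math=True,                   # allow approximate exp, rsqrt, ...
    dimension_semantics=("parallel",),  # one entry per grid dimension
    num_stages=2,                       # pipeline depth for sequential axes
)
# Typically threaded through pallas_call(..., compiler_params=params).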
@@ -56,11 +60,16 @@ def __str__(self) -> str:
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
# A convenience function for constructing MemoryRef types.
- return MemoryRef(shape, dtype, self)
+ return pallas_core.MemoryRef(shape, dtype, memory_space=self)
+
+
+class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
+ def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform:
+ ...
-class TilingTransform(pallas_core.MemrefTransform):
- """Represents a tiling transformation for Memrefs.
+class TilingTransform(MemoryRefTransform):
+ """Represents a tiling transformation for memory refs.
A tiling of (X, Y) on an array of shape (M, N) will result in a transformed
shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a
@@ -73,7 +82,7 @@ def __init__(self, tiling: tuple[int, ...]):
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
- block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error
+ block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
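A concrete instance of the shape arithmetic described in the TilingTransform docstring above (illustrative only): tiling a (256, 256) block with (64, 32) yields (256 // 64, 256 // 32, 64, 32).

block_shape = (256, 256)
tiling = (64, 32)
tiled_shape = (
    block_shape[0] // tiling[0],  # 4 tiles along the first axis
    block_shape[1] // tiling[1],  # 8 tiles along the second axis
    *tiling,
)
assert tiled_shape == (4, 8, 64, 32)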
@@ -110,9 +119,9 @@ class GPUBlockSpec(pallas_core.BlockSpec):
def to_block_mapping(
self,
origin: pallas_core.OriginStr,
- array_aval: core.ShapedArray,
+ array_aval: jax_core.ShapedArray,
*,
- index_map_avals: Sequence[core.AbstractValue],
+ index_map_avals: Sequence[jax_core.AbstractValue],
index_map_tree: tree_util.PyTreeDef,
grid: pallas_core.GridMappingGrid,
mapped_dims: tuple[int, ...],
@@ -125,7 +134,7 @@ def to_block_mapping(
grid=grid,
mapped_dims=mapped_dims,
)
- transforms: tuple[pallas_core.MemrefTransform, ...] = ()
+ transforms: tuple[pallas_core.MemoryRefTransform, ...] = ()
if self.tiling is not None:
transforms += (TilingTransform(self.tiling),)
return GPUBlockMapping(
@@ -141,20 +150,33 @@ def to_block_mapping(
)
-# TODO(b/354568887): Cosolidate this with TPU's MemoryRef.
+GMEM = GPUMemorySpace.GMEM
+SMEM = GPUMemorySpace.SMEM
+REGS = GPUMemorySpace.REGS
+
+
+class barrier_dtype(dtypes.extended):
+ pass
+
+
@dataclasses.dataclass(frozen=True)
-class MemoryRef:
- """Like jax.ShapeDtypeStruct but with memory spaces."""
+class BarrierType(dtypes.ExtendedDType):
+ type: ClassVar[Any] = barrier_dtype
+ name: ClassVar[str] = "barrier"
- shape: tuple[int, ...]
- dtype: jnp.dtype
- memory_space: GPUMemorySpace
+ num_arrivals: int
- def get_aval(self) -> AbstractMemoryRef:
- return AbstractMemoryRef(
- jax_core.ShapedArray(self.shape, self.dtype), self.memory_space
- )
+ def __str__(self):
+ return self.name
-GMEM = GPUMemorySpace.GMEM
-SMEM = GPUMemorySpace.SMEM
-REGS = GPUMemorySpace.REGS
+
+@dataclasses.dataclass(frozen=True)
+class Barrier:
+ num_arrivals: int
+ num_barriers: int = 1
+
+ def get_ref_aval(self) -> AbstractMemoryRef:
+ aval = jax_core.ShapedArray(
+ [self.num_barriers], BarrierType(self.num_arrivals)
+ )
+ return AbstractMemoryRef(aval, SMEM)
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 87a147c91c81..5eaf6e5233cb 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -19,22 +19,24 @@
from collections.abc import Sequence
import dataclasses
import functools
+import itertools as it
import math
from typing import Any, cast
import jax
+from jax import lax
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
-from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
+from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.state import primitives as sp
from jax.experimental.mosaic import gpu as mosaic_gpu
@@ -96,8 +98,9 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
class ModuleContext:
name: str
grid_mapping: pallas_core.GridMapping
+ approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
- smem_used_bytes: int
+ smem_used_bytes: int = 0
# TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
def scratch_view(
@@ -156,7 +159,8 @@ def stack_free_smem(self, bytes: int):
@dataclasses.dataclass(frozen=True)
class LoweringRuleContext:
- module_context: ModuleContext
+ module_ctx: ModuleContext
+ launch_ctx: mosaic_gpu.LaunchContext
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
@@ -175,10 +179,13 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name
def _eval_index_map(
- ctx: ModuleContext, idx: ir.Value, block_mapping: pallas_core.BlockMapping
+ module_ctx: ModuleContext,
+ launch_ctx: mosaic_gpu.LaunchContext,
+ idx: ir.Value,
+ block_mapping: pallas_core.BlockMapping,
) -> Sequence[ir.Value]:
block_indices = lower_jaxpr_to_mosaic_gpu(
- ctx, block_mapping.index_map_jaxpr.jaxpr, idx
+ module_ctx, launch_ctx, block_mapping.index_map_jaxpr.jaxpr, idx
)
result = []
for i, b in zip(block_indices, block_mapping.block_shape):
@@ -231,17 +238,16 @@ def lower_jaxpr_to_module(
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
block = (128,) + (1,) * (len(grid) - 1)
-
- num_inputs = grid_mapping.num_inputs
params = compiler_params.get("mosaic_gpu", {})
+ approx_math = params.get("approx_math", False)
num_stages = params.get("num_stages", 1)
- dimension_semantics = params.get(
- "dimension_semantics", ["parallel"] * len(grid_mapping.grid)
- )
- if len(dimension_semantics) != len(grid_mapping.grid):
+ dimension_semantics = params.get("dimension_semantics")
+ if dimension_semantics is None:
+ dimension_semantics = ["parallel"] * len(grid_mapping.grid)
+ elif len(dimension_semantics) != len(grid_mapping.grid):
raise ValueError(
- "dimension_semantics must have an entrey for each grid dimension:"
- f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})."
+ "dimension_semantics must have an entry for each grid dimension:"
+ f" {len(dimension_semantics)=}, but len(grid) is {grid_mapping.grid})."
)
sequential_axes = tuple(
i for i, s in enumerate(dimension_semantics) if s == "sequential"
@@ -249,39 +255,58 @@ def lower_jaxpr_to_module(
assert all(grid[axis] for axis in sequential_axes)
assert all(block[axis] == 1 for axis in sequential_axes)
+ in_in_smem, out_in_smem = util.split_list(
+ [
+ bm.block_aval.memory_space in (None, gpu_core.SMEM)
+ for bm in block_mappings
+ ],
+ [grid_mapping.num_inputs],
+ )
+
in_structs_gmem = [*grid_mapping.in_shapes]
in_block_shapes = [
bm.block_shape
- for bm in grid_mapping.block_mappings[:num_inputs]
+ for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
in_structs_smem = [
- jax.ShapeDtypeStruct(
- [num_stages,
- *bm.ref_aval.inner_aval.shape], # pytype: disable=attribute-error
- bm.ref_aval.inner_aval.dtype) # pytype: disable=attribute-error
- for bm in block_mappings[:num_inputs]
+ jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype)
+ if in_smem
+ else None
+ for bm, in_smem in zip(
+ block_mappings[: grid_mapping.num_inputs], in_in_smem
+ )
]
in_gmem_transforms = [
- bm.transforms for bm in grid_mapping.block_mappings[:num_inputs]
+ cast(gpu_core.MemoryRefTransform, bm.transforms)
+
+ for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs]
]
- _get_swizzle = (
+ in_swizzles = map(
lambda bm: bm.swizzle
if isinstance(bm, gpu_core.GPUBlockMapping)
- else None
+ else None,
+ grid_mapping.block_mappings[: grid_mapping.num_inputs],
)
- in_swizzles = map(_get_swizzle, grid_mapping.block_mappings[:num_inputs])
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
out_structs_smem = [
jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype)
- for bm, s in zip(
+ if in_smem
+ else None
+ for bm, in_smem, s in zip(
block_mappings[grid_mapping.num_inputs :],
+ out_in_smem,
grid_mapping.out_shapes,
)
]
def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value):
- *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers
+ *buffers_gmem, (
+ buffers_smem,
+ *scratch_buffers_smem,
+ runtime_smem,
+ barriers,
+ ) = buffers
assert len(buffers_gmem) == len(buffers_smem)
in_buffers_gmem, out_buffers_gmem = util.split_list(
buffers_gmem, [grid_mapping.num_inputs]
@@ -289,19 +314,31 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value):
in_buffers_smem, out_buffers_smem = util.split_list(
buffers_smem, [grid_mapping.num_inputs]
)
+ barriers, *extra_barriers = barriers
module_ctx = ModuleContext(
- name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0
+ name_and_src_info.name, grid_mapping, approx_math, runtime_smem
)
program_ids = map(_program_id, range(len(grid_mapping.grid)))
start_indices = map(
- functools.partial(_eval_index_map, module_ctx, program_ids),
+ partial(_eval_index_map, module_ctx, launch_ctx, program_ids),
block_mappings,
)
in_start_indices, out_start_indices = util.split_list(
start_indices, [grid_mapping.num_inputs]
)
+ # Precompute the total number of bytes transferred from GMEM to SMEM,
+ # so that we can do a single arrive instruction for all of the inputs.
+ in_transfer_bytes = 0
+ for in_smem, b_smem in zip(in_in_smem, in_buffers_smem):
+ if not in_smem:
+ continue
+ b_smem_type = ir.MemRefType(b_smem.type)
+ in_transfer_bytes += math.prod(b_smem_type.shape[1:]) * mgpu.bytewidth(
+ b_smem_type.element_type
+ )
+
def gmem_slice(
start_indices: Sequence[ir.Value],
step: ir.Value,
@@ -320,6 +357,9 @@ def gmem_slice(
)
def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
+ if not in_in_smem[idx]:
+ return
+
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
gmem_transforms = (x.to_gpu_transform() for x in in_gmem_transforms[idx])
launch_ctx.async_copy(
@@ -333,11 +373,14 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
barrier=barriers[slot],
gmem_transform=tuple(gmem_transforms),
swizzle=in_swizzles[idx],
- arrive=True,
+ arrive=False, # The caller must do ``arrive_expect_tx`` manually!
uniform=False,
)
def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
+ if not out_in_smem[idx]:
+ return
+
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
launch_ctx.async_copy(
src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
@@ -362,6 +405,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
if any(
b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
+ if b_smem
):
raise ValueError(
"Array dimensions along the sequential axis must be divisible by"
@@ -370,6 +414,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
num_steps, *rest = {
b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis]
for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem)
+ if b_smem
}
if rest:
raise ValueError(
@@ -382,6 +427,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
with mgpu.single_thread():
for slot in range(min(num_stages, num_steps)):
+ barriers[slot].arrive_expect_tx(in_transfer_bytes)
for idx in range(grid_mapping.num_inputs):
fetch(idx, _as_index(slot), _as_index(slot))
@@ -392,11 +438,15 @@ def _(step, _):
# Only wait if async copies were issued.
barriers[slot].wait()
- _ = lower_jaxpr_to_mosaic_gpu(
- module_ctx,
- jaxpr,
- [mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem],
- )
+ args = [
+ mgpu.memref_slice(buffers_smem[idx], slot)
+ if in_smem
+ else buffers_gmem[idx]
+ for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem))
+ ]
+ args.extend(scratch_buffers_smem)
+ args.extend(extra_barriers)
+ _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args)
mgpu.commit_shared()
with mgpu.single_thread():
@@ -410,17 +460,40 @@ def _(step, _):
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
for idx in range(grid_mapping.num_inputs):
fetch(idx, next_step, slot)
+ barriers[slot].arrive_expect_tx(in_transfer_bytes)
return ()
launch_ctx.await_async_copy(0)
+ scratch_avals = [
+ var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops]
+ ]
+ if not all(
+ isinstance(aval, pallas_core.AbstractMemoryRef)
+ and aval.memory_space is gpu_core.SMEM
+ for aval in scratch_avals
+ ):
+ raise TypeError(
+ f"All scratch operands must be in SMEM, but got: {scratch_avals}"
+ )
+ extra_barriers = [
+ mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
+ for aval in scratch_avals
+ if isinstance(aval.dtype, gpu_core.BarrierType)
+ ]
+ extra_smem_scratch = [
+ jax.ShapeDtypeStruct(aval.shape, aval.dtype)
+ for aval in scratch_avals
+ if not isinstance(aval.dtype, gpu_core.BarrierType)
+ ]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr)
- extra_smem_scratch = [
+ extra_smem_scratch.append(
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
- ]
+ )
+
module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel(
body,
grid=grid,
@@ -429,12 +502,11 @@ def _(step, _):
in_shapes=in_structs_gmem,
out_shape=out_structs_gmem,
smem_scratch_shape=(
- *in_structs_smem,
- *out_structs_smem,
+ (*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
- mgpu.Barrier(
- arrival_count=len(in_structs_gmem),
- num_barriers=num_stages,
+ (
+ mgpu.Barrier(arrival_count=1, num_barriers=num_stages),
+ *extra_barriers,
),
),
module_name=name_and_src_info.name,
@@ -455,7 +527,8 @@ def deco(fn):
def lower_jaxpr_to_mosaic_gpu(
- ctx: ModuleContext,
+ module_ctx: ModuleContext,
+ launch_ctx: mosaic_gpu.LaunchContext,
jaxpr: jax_core.Jaxpr,
args: Sequence[ir.Value],
consts=(),
@@ -480,7 +553,8 @@ def write_env(var: jax_core.Var, val):
)
rule = mosaic_lowering_rules[eqn.primitive]
rule_ctx = LoweringRuleContext(
- ctx,
+ module_ctx,
+ launch_ctx,
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
)
@@ -547,7 +621,9 @@ def _swap_lowering_rule(
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
if jaxpr.consts:
raise NotImplementedError
- return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, args)
+ return lower_jaxpr_to_mosaic_gpu(
+ ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args
+ )
@register_lowering_rule(lax.broadcast_in_dim_p)
@@ -560,7 +636,8 @@ def _broadcast_in_dim_lowering_rule(
):
if broadcast_dimensions:
raise NotImplementedError
- return _ensure_fa(x, ctx.avals_in[0]).broadcast(shape)
+ [x_aval] = ctx.avals_in
+ return _ensure_fa(x, x_aval.dtype).broadcast(shape)
@register_lowering_rule(lax.convert_element_type_p)
@@ -568,7 +645,8 @@ def _convert_element_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
del weak_type, sharding
- return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype))
+ [x_aval] = ctx.avals_in
+ return _ensure_fa(x, x_aval.dtype).astype(mlir.dtype_to_ir_type(new_dtype))
def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
@@ -586,7 +664,8 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
@register_lowering_rule(lax.integer_pow_p)
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
- x = _ensure_fa(x, *ctx.avals_in)
+ [x_aval] = ctx.avals_in
+ x = _ensure_fa(x, x_aval.dtype)
if y == 2:
return x * x
return NotImplementedError
@@ -594,7 +673,8 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
@register_lowering_rule(lax.rsqrt_p)
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
- return _ensure_fa(x, *ctx.avals_in).rsqrt()
+ [x_aval] = ctx.avals_in
+ return _ensure_fa(x, x_aval.dtype).rsqrt(ctx.module_ctx.approx_math)
@register_lowering_rule(lax.reduce_sum_p)
@@ -602,7 +682,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
if axes != (0,):
raise NotImplementedError("No support for axes other than 0 yet")
[x_aval] = ctx.avals_in
- _, [scratch] = ctx.module_context.scratch_view(
+ _, [scratch] = ctx.module_ctx.scratch_view(
[jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)]
)
return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ())
@@ -615,8 +695,9 @@ def _debug_print_lowering_rule(
fmt,
has_placeholders: bool,
):
- del ctx
- del has_placeholders
+ del has_placeholders # Unused.
+ if any(aval.shape for aval in ctx.avals_in):
+ raise NotImplementedError("Only scalar values are supported")
primitives.check_debug_print_format(fmt, *args)
mgpu.debug_print(fmt, *args)
return ()
@@ -627,16 +708,91 @@ def _run_scoped_lowering_rule(
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
):
in_avals = [v.aval.inner_aval for v in jaxpr.invars]
- bytes_allocated, input_refs = ctx.module_context.scratch_view(
- [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) for aval in in_avals]
- )
+ bytes_allocated, input_refs = ctx.module_ctx.scratch_view([
+ jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)
+ for aval in in_avals
+ ])
outs = lower_jaxpr_to_mosaic_gpu(
- ctx.module_context, jaxpr, input_refs, consts
+ ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)
- ctx.module_context.stack_free_smem(bytes_allocated)
+ ctx.module_ctx.stack_free_smem(bytes_allocated)
return outs
+def _lower_jaxpr_to_for_loop(
+ ctx: LoweringRuleContext,
+ jaxpr: jax_core.Jaxpr,
+ start: ir.Value,
+ length: ir.Value,
+ consts,
+ *args,
+ has_loop_index: bool,
+):
+
+ @mgpu.fori(length, [*args])
+ def loop(loop_index, body_args):
+ if has_loop_index:
+ loop_index = arith_dialect.addi(loop_index, start)
+ jaxpr_args = [*consts, loop_index, *body_args]
+ else:
+ jaxpr_args = [*consts, *body_args]
+ return lower_jaxpr_to_mosaic_gpu(
+ ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args
+ )
+
+ return loop.results
+
+
+@register_lowering_rule(lax.scan_p)
+def _scan_lowering_rule(
+ ctx: LoweringRuleContext,
+ *args,
+ jaxpr: jax_core.ClosedJaxpr,
+ linear: tuple[bool, ...],
+ length: int,
+ reverse: bool,
+ unroll: bool | int,
+ num_consts: int,
+ num_carry: int,
+ _split_transpose: bool,
+):
+ # Can only handle fori_loop-like scans.
+ if (
+ (num_extensive := len(args) - num_consts - num_carry)
+ or reverse
+ or unroll != 1
+ ):
+ raise NotImplementedError
+ del linear, num_extensive, reverse, unroll
+
+ jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
+ if jaxpr_consts:
+ raise NotImplementedError
+ del jaxpr_consts
+
+ jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop(
+ jaxpr, num_consts, num_carry
+ )
+ consts, args = util.split_list(args, [num_consts])
+ _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts])
+ if has_loop_index:
+ start, *args = args
+ index_aval, *_arg_avals = arg_avals
+ start = _ensure_ir_value(start, index_aval)
+ length = _ir_constant(length, start.type)
+ else:
+ start = _i32_constant(0)
+ length = _i32_constant(length)
+ for_out = _lower_jaxpr_to_for_loop(
+ ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index
+ )
+ if has_loop_index:
+ # Need to return the final loop index value if the outer scan expects
+ # it as an output.
+ return [length, *for_out]
+ return for_out
+
+
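Since the new rule only accepts fori_loop-shaped scans, the intended entry point is `jax.lax.fori_loop` inside a kernel body. A minimal, hedged sketch (kernel only; shapes and the surrounding `pallas_call` omitted), under the assumption that the loop carry stays a register value:

import jax
import jax.numpy as jnp

def kernel(x_ref, o_ref):
  def body(i, acc):
    return acc + x_ref[...]
  # fori_loop traces to a scan with no extensive outputs, which the rule
  # above pattern-matches into an mgpu.fori loop.
  o_ref[...] = jax.lax.fori_loop(0, 4, body, jnp.zeros_like(x_ref[...]))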
def _bcast(
x: ir.Value,
y: ir.Value,
@@ -644,22 +800,16 @@ def _bcast(
y_aval: jax_core.ShapedArray,
out_aval: jax_core.ShapedArray,
) -> ir.Value:
- if isinstance(x, (np.ndarray, np.number, int, float)):
+ if not isinstance(x, mgpu.FragmentedArray):
x_dtype = x_aval.dtype
if x_aval.weak_type:
x_dtype = y_aval.dtype
- x = mgpu.FragmentedArray.splat(
- _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), ()
- )
- if isinstance(y, (np.ndarray, np.number, int, float)):
+ x = _ensure_fa(x, x_dtype)
+ if not isinstance(y, mgpu.FragmentedArray):
y_dtype = y_aval.dtype
if y_aval.weak_type:
y_dtype = x_aval.dtype
- y = mgpu.FragmentedArray.splat(
- _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), ()
- )
- assert isinstance(x, mgpu.FragmentedArray)
- assert isinstance(y, mgpu.FragmentedArray)
+ y = _ensure_fa(y, y_dtype)
if x_aval.shape != out_aval.shape:
x = x.broadcast(out_aval.shape)
if y_aval.shape != out_aval.shape:
@@ -667,17 +817,25 @@ def _bcast(
return x, y
-def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray:
+def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
if isinstance(x, mgpu.FragmentedArray):
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return mgpu.FragmentedArray.splat(
- _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), ()
+ _ir_constant(x, mlir.dtype_to_ir_type(dtype)), ()
)
elif isinstance(x, ir.Value):
- if isinstance(x.type, (ir.IntegerType, ir.FloatType)):
+ if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
return mgpu.FragmentedArray.splat(x, ())
- raise NotImplementedError
+ raise NotImplementedError(f"Unsupported type: {type(x)}")
+
+
+def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value:
+ if isinstance(x, ir.Value):
+ return x
+ elif isinstance(x, (np.number, np.ndarray, int, float)):
+ return _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype))
+ raise NotImplementedError(f"Unsupported type: {type(x)}")
def _ir_constant(v: object, t: ir.Type) -> ir.Value:
@@ -691,8 +849,12 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:
raise NotImplementedError(f"Unsupported constant: {v!r}")
-def _i32_constant(v: object) -> ir.Value:
- return _ir_constant(v, ir.IntegerType.get_signless(32))
+def _i32_constant(v: int) -> ir.Value:
+ return arith_dialect.constant(ir.IntegerType.get_signless(32), v)
+
+
+def _i64_constant(v: int) -> ir.Value:
+ return arith_dialect.constant(ir.IntegerType.get_signless(64), v)
def _as_index(v: int | ir.Value) -> ir.Value:
diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
index 5b46caf1553a..5b09cad176a6 100644
--- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -37,8 +37,9 @@ def pallas_call_lowering(
grid_mapping: pallas_core.GridMapping,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
- del interpret
+ del interpret, out_avals
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Mosaic GPU backend"
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
new file mode 100644
index 000000000000..e96574612bfa
--- /dev/null
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -0,0 +1,105 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPU-specific Pallas primitives."""
+
+from __future__ import annotations
+
+from jax._src import core as jax_core
+from jax._src import state
+from jax._src.pallas import core as pallas_core
+from jax._src.pallas.mosaic_gpu import core as gpu_core
+from jax._src.pallas.mosaic_gpu import lowering
+
+
+async_copy_p = jax_core.Primitive("async_copy")
+async_copy_p.multiple_results = True
+
+
+@async_copy_p.def_effectful_abstract_eval
+def _async_copy_abstract_eval(*avals):
+ del avals # Unused.
+ return (), {state.ReadEffect(0), state.WriteEffect(1)}
+
+
+@lowering.register_lowering_rule(async_copy_p)
+def _async_copy_lowering_rule(
+ ctx: lowering.LoweringRuleContext, src, dst, barrier=None
+):
+ ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, barrier=barrier)
+ return ()
+
+
+def async_copy_smem_to_gmem(
+ src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef
+) -> None:
+ if src.memory_space is not gpu_core.SMEM:
+ raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
+ if dst.memory_space is not gpu_core.GMEM:
+ raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
+ async_copy_p.bind(src, dst)
+ return None
+
+
+def async_copy_gmem_to_smem(
+ src: pallas_core.AbstractMemoryRef,
+ dst: pallas_core.AbstractMemoryRef,
+ *,
+ barrier: pallas_core.AbstractMemoryRef,
+) -> None:
+ if src.memory_space is not gpu_core.GMEM:
+ raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
+ if dst.memory_space is not gpu_core.SMEM:
+ raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
+ async_copy_p.bind(src, dst, barrier)
+ return None
+
+
+class WaitEffect(jax_core.Effect):
+ ...
+
+
+wait_effect = WaitEffect()
+
+
+wait_p = jax_core.Primitive("wait")
+wait_p.multiple_results = True
+
+
+@wait_p.def_effectful_abstract_eval
+def _wait_abstract_eval(*avals, **params):
+ del avals, params # Unused.
+ return (), {wait_effect}
+
+
+@lowering.register_lowering_rule(wait_p)
+def _wait_lowering_rule(
+ ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None,
+):
+ if barrier is not None:
+ barrier.wait()
+ else:
+ assert allow_groups is not None
+ ctx.launch_ctx.await_async_copy(allow_groups=allow_groups)
+ return ()
+
+
+def wait_smem_to_gmem(allow_groups: int) -> None:
+ """Waits until there are no more than the given number of SMEM->GMEM copies in flight."""
+ wait_p.bind(allow_groups=allow_groups)
+
+
+def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None:
+ """Waits on the given barrier."""
+ wait_p.bind(barrier)
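Putting the new primitives together with `Barrier`, `GMEM`/`SMEM`, and the `scratch_shapes=` argument added later in this patch, a round-trip copy kernel might look roughly like this (a sketch; `plgpu` is a local alias for the package whose `__init__.py` diff appears above):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax._src.pallas import mosaic_gpu as plgpu

def copy_kernel(x_gmem, o_gmem, scratch, barrier):
  # GMEM -> SMEM, completion tracked by the scratch barrier.
  plgpu.async_copy_gmem_to_smem(x_gmem, scratch, barrier=barrier)
  plgpu.wait_barrier(barrier)
  # SMEM -> GMEM, then wait until no copies remain in flight.
  plgpu.async_copy_smem_to_gmem(scratch, o_gmem)
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(256, dtype=jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=1)],
)(x)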
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index b69fb03f0951..1c10d2bda9e9 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -62,6 +62,7 @@
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
+ScratchShapeTree = pallas_core.ScratchShapeTree
CostEstimate = pallas_core.CostEstimate
# See the docstring for GridMapping for the calling convention
@@ -167,8 +168,12 @@ def _get_next_indices(grid, indices):
next_indices.append(jnp.where(carry, 0, i))
return tuple(reversed(next_indices))
-def _pallas_call_impl(*args, **kwargs):
- assert False # We always jit a pallas call, we only need the lowering rule
+def _pallas_call_impl(*args, **params):
+ # Call the lowering path
+ @partial(jax.jit, inline=True)
+ def _jit_run(*args):
+ return pallas_call_p.bind(*args, **params)
+ return _jit_run(*args)
def _pallas_call_impl_interpret(
@@ -180,8 +185,9 @@ def _pallas_call_impl_interpret(
grid_mapping: GridMapping,
compiler_params: Any,
cost_estimate: CostEstimate,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
- del compiler_params, cost_estimate
+ del compiler_params, cost_estimate, out_avals
# If we're in interpret mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore
@@ -323,10 +329,20 @@ def body(carry):
pallas_call_p.def_impl(_pallas_call_impl)
-def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
- return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape,
- bm.array_shape_dtype.dtype)
- for bm in grid_mapping.block_mappings_output)
+
+def _pallas_call_abstract_eval(
+ *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
+):
+ del avals
+ # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
+ return [
+ jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
+ if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
+ else a
+ for a in out_avals
+ ]
+
+
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
@@ -342,6 +358,7 @@ def _pallas_call_jvp_rule(
interpret,
compiler_params: Any,
cost_estimate: CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
@@ -405,6 +422,7 @@ def _pallas_call_jvp_rule(
input_output_aliases=(),
compiler_params=compiler_params,
cost_estimate=jvp_cost_estimate,
+ out_avals=(*out_avals, *out_avals)
)
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
return out_primals, out_tangents
@@ -538,6 +556,7 @@ def _batch_with_explicit_loop(
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
"""Batch the pallas_call by calling it in loop over the batch size.
@@ -604,6 +623,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
interpret=interpret,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
+ out_avals=out_avals,
)
for i, batch_out_array in enumerate(batch_out):
state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -642,6 +662,7 @@ def _pallas_call_batching_rule(
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
@@ -684,6 +705,7 @@ def get_size(i, x, d):
interpret=interpret,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
+ out_avals=out_avals,
)
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
@@ -716,6 +738,7 @@ def get_size(i, x, d):
interpret=interpret,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
+ out_avals=out_avals,
)
else:
pass # No dynamic grid dimensions
@@ -749,6 +772,7 @@ def get_size(i, x, d):
interpret=interpret,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
+ out_avals=out_avals,
)
if not dims:
@@ -922,7 +946,11 @@ def g():
assert ragged_axis_length is not None
args = (ragged_axis_length, *args)
-
+ assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
+ batched_out_avals = tuple(
+ aval.update(shape=tuple_insert(aval.shape, 0, axis_size))
+ for aval in out_avals
+ )
out = pallas_call_p.bind(
*dynamic_grid_args,
*args,
@@ -936,6 +964,7 @@ def g():
interpret=interpret,
compiler_params=compiler_params,
cost_estimate=batched_cost_estimate,
+ out_avals=batched_out_avals,
)
return out, (0,) * len(out)
@@ -965,6 +994,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
interpret: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
+ out_avals: tuple[jax_core.AbstractValue, ...],
**kwargs):
# We implement the checkify rule in 4 steps:
# 1) First, trace the kernel body to get the expected error shapes.
@@ -1091,11 +1121,13 @@ def _ensure_2d_error_shape(arg):
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
new_vals_in = [*scalars, *err_vals, *args]
+ new_out_avals = (*shaped_err_avals, *out_avals)
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
jaxpr=final_jaxpr,
interpret=interpret,
grid_mapping=grid_mapping_with_error,
input_output_aliases=input_output_aliases_with_error,
+ out_avals=new_out_avals,
**kwargs)
errors, results = split_list(result, [num_err_vals])
# TODO(b/350593266): Remove line below once we support ()-shaped scalars.
@@ -1224,6 +1256,17 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):
)
jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
+def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
+ match out_shape:
+ case jax.ShapeDtypeStruct():
+ return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
+ case pallas_core.MemoryRef():
+ return out_shape.get_array_aval()
+ case _:
+ if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")):
+ raise ValueError(f"Invalid out_shape type: {type(out_shape)}")
+ return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
+
def pallas_call(
kernel: Callable[..., None],
@@ -1233,6 +1276,7 @@ def pallas_call(
grid: TupleGrid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
+ scratch_shapes: ScratchShapeTree = (),
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
@@ -1250,8 +1294,9 @@ def pallas_call(
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
and dtypes of the outputs.
- grid_spec: An alternative way to specify ``grid``, ``in_specs``, and
- ``out_specs``. If given, those other parameters must not be also given.
+ grid_spec: An alternative way to specify ``grid``, ``in_specs``,
+ ``out_specs`` and ``scratch_shapes``. If given, those other parameters
+ must not be also given.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.
@@ -1265,6 +1310,9 @@ def pallas_call(
The default value for ``out_specs`` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
See details at :ref:`pallas_blockspec`.
+ scratch_shapes: a PyTree of backend-specific temporary objects required
+ by the kernel, such as temporary buffers, synchronization primitives,
+ etc.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
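Scratch operands are appended to the kernel signature after the outputs, following the existing TPU convention; a TPU-flavored sketch (names chosen for illustration, not part of the patch):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref, acc_ref, dma_sem):  # inputs, then outputs, then scratch
  acc_ref[...] = x_ref[...]
  o_ref[...] = acc_ref[...]

x = jnp.zeros((128, 128), jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA],
)(x)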
@@ -1305,7 +1353,7 @@ def pallas_call(
}
if grid_spec is None:
- grid_spec = GridSpec(grid, in_specs, out_specs)
+ grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
else:
if grid:
raise ValueError(
@@ -1319,6 +1367,10 @@ def pallas_call(
raise ValueError(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
+ if scratch_shapes:
+ raise ValueError(
+ "If `grid_spec` is specified, then `scratch_shapes` must "
+ f"be `()`. It is {scratch_shapes}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
@@ -1328,17 +1380,15 @@ def pallas_call(
out_shape = tuple(out_shape)
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
- flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore
- for x in flat_out_shapes]
- @jax.jit
+ @partial(jax.jit, inline=True)
def wrapped(*args):
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
- flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
+ flat_out_avals = tuple(_convert_out_shape_to_aval(v)
for v in flat_out_shapes)
kernel_fun_sig = api_util.fun_signature(kernel)
@@ -1393,6 +1443,7 @@ def wrapped(*args):
*dynamic_grid_bounds,
*index_args,
*rest_args,
+ out_avals=flat_out_avals,
jaxpr=jaxpr,
name_and_src_info=name_and_src_info,
debug=debug,
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index e41a8cf59975..89b6c6e14acd 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -177,8 +177,10 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None,
atomic_type: AtomicOpType):
- x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw")
- args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
+ x_ref, transforms = sp.get_ref_and_transforms(
+ x_ref_or_view, idx, "atomic_rmw"
+ )
+ args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask))
return atomic_rmw_p.bind(
*args_flat, args_tree=args_tree, atomic_type=atomic_type
)
@@ -379,7 +381,7 @@ def _load_pp_rule(eqn, context, settings):
result = [
lhs,
pp.text(' <- '),
- sp.pp_ref_indexers(context, x, indexers)
+ sp.pp_ref_transforms(context, x, indexers)
]
if mask is not None:
result += [
@@ -529,7 +531,7 @@ def _swap_pp_rule(eqn, context, settings):
# Pretty prints `_ = swap x v i` as `x[i] <- v`
y, = eqn.outvars
x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars)
- x_i = sp.pp_ref_indexers(context, x, indexers)
+ x_i = sp.pp_ref_transforms(context, x, indexers)
if isinstance(y, jax_core.DropVar):
return pp.concat([
x_i,
@@ -638,8 +640,10 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
eviction_policy: TO BE DOCUMENTED.
volatile: TO BE DOCUMENTED.
"""
- x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load")
- args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other))
+ x_ref, transforms = sp.get_ref_and_transforms(x_ref_or_view, idx, "load")
+ args_flat, args_tree = tree_util.tree_flatten(
+ (x_ref, transforms, mask, other)
+ )
return load_p.bind(
*args_flat,
args_tree=args_tree,
@@ -657,8 +661,10 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
Returns:
The value stored in the ref prior to the swap.
"""
- x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name)
- args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask))
+ x_ref, transforms = sp.get_ref_and_transforms(
+ x_ref_or_view, idx, _function_name
+ )
+ args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask))
return swap_p.bind(
*args_flat, args_tree=args_tree, eviction_policy=eviction_policy
)
@@ -708,7 +714,7 @@ class PrintEffect(effects.Effect):
def debug_print(fmt: str, *args: jax.typing.ArrayLike):
- """Prints scalar values from inside a Pallas kernel.
+ """Prints values from inside a Pallas kernel.
Args:
fmt: A format string to be included in the output. The restrictions on the
@@ -718,11 +724,11 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike):
(``{...}``), since it is always printed before any of the values.
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
contain a placeholder for each value to be printed. Format specs and
- conversions are not supported.
+ conversions are not supported. All values must be scalars.
* In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
integers. If there are no placeholders, the values are printed after
- the format string.
- *args: The scalar values to print.
+ the format string. All values must be scalars.
+ *args: The values to print.
""" # fmt: skip
has_placeholders = False
if fmt:
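A minimal usage sketch of the relaxed `debug_print` contract above (illustrative only, not part of this patch; the kernel is hypothetical and assumes the public `jax.experimental.pallas` API):

# Illustrative sketch, not part of the patch.
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # With a placeholder: on TPU the printed value must be a 32-bit integer,
  # so the program id is a safe choice.
  pl.debug_print("program_id = {}", pl.program_id(0))
  # Without placeholders: any values are printed after the format string.
  pl.debug_print("copying block")
  o_ref[...] = x_ref[...]
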
@@ -765,9 +771,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):
@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool):
- del fmt, has_placeholders
- if any(aval.shape for aval in avals):
- raise ValueError("Only scalar values are supported")
+ del avals, fmt, has_placeholders # Unused.
return [], {debug_print_effect}
@@ -824,7 +828,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
- avals = [t.get_aval() for t in flat_types]
+ avals = [t.get_ref_aval() for t in flat_types]
# Turn the function into a jaxpr. The body of run_scoped may have
# effects (IO) on constvars (i.e. variables inherited from the
# parent scope). Jax can't reason about effects to references that
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 5e495f4bef3e..ac28bd21a3dc 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module(
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
+ if jaxpr.invars[grid_mapping.slice_scratch_ops]:
+ raise NotImplementedError(
+ "scratch memory not implemented in the Triton backend"
+ )
with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
@@ -1202,7 +1206,14 @@ def debug_print_lowering_rule(
"pl.debug_print() does not support placeholders when lowering to Triton"
)
- tt_dialect.print_(f" {fmt} ", hex=False, args=args)
+ tt_dialect.print_(
+ f" {fmt} ",
+ hex=False,
+ args=args,
+ is_signed=ir.DenseI32ArrayAttr.get([
+ jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in
+ ]),
+ )
return ()
@@ -1344,7 +1355,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
register_lowering(lax.erf_inv_p)(
- lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)
+ lower_fun(pallas_utils.erf_inv_lowering_helper, multiple_results=False)
)
@@ -2354,10 +2365,8 @@ def _lower_jaxpr_to_for_loop(
else:
jaxpr_args = [*consts, *for_body_args]
all_out = lower_jaxpr_to_triton_ir(
- ctx.context,
- jaxpr,
- ctx.block_infos,
- *jaxpr_args)
+ ctx.context, jaxpr, ctx.block_infos, *jaxpr_args
+ )
scf_dialect.yield_(all_out)
return list(for_op.results_)
@@ -2394,11 +2403,9 @@ def _scan_lowering_rule(
args = map(_ensure_ir_value, args, ctx.avals_in)
consts, args = util.split_list(args, [num_consts])
if has_loop_index:
- lb, *args = args
- lower_bound = lb
- ub = _add(lb, _ir_constant(length, lb.type))
- upper_bound = ub
- bound_type = ub.type
+ lower_bound, *args = args
+ upper_bound = _add(lower_bound, _ir_constant(length, lower_bound.type))
+ bound_type = lower_bound.type
else:
lower_bound = _i32_constant(0)
upper_bound = _i32_constant(length)
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 5ee7077dcc1f..67b0bd326616 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -49,8 +49,9 @@ def pallas_call_lowering(
grid_mapping: pallas_core.GridMapping,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
+ out_avals: tuple[jax_core.AbstractValue, ...],
):
- del interpret
+ del interpret, out_avals
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py
index cfca0769d13d..e485537216ca 100644
--- a/jax/_src/pallas/utils.py
+++ b/jax/_src/pallas/utils.py
@@ -72,7 +72,7 @@ def next_power_of_2(x: int) -> int:
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
- if isinstance(dtype, jnp.integer):
+ if jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
return np.dtype(dtype).itemsize * 8
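A tiny check of the corrected integer test above (illustrative, not part of this patch; assumes `jnp.int4` is available via ml_dtypes):

# Illustrative check, not part of the patch: the old isinstance() test never
# matched, so integer dtypes always fell through to the itemsize path and
# reported 8 bits for int4.
import jax.numpy as jnp
from jax._src.pallas import utils as pallas_utils

assert pallas_utils.dtype_bitwidth(jnp.int4) == 4
assert pallas_utils.dtype_bitwidth(jnp.float32) == 32
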
@@ -186,7 +186,7 @@ def pattern_match_while_to_fori_loop(
# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802
-def erf_inv_32_lowering_helper(x):
+def _erf_inv_32_lowering_helper(x):
k_degree = 9
w_lt_5_constants = [
2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
@@ -212,6 +212,83 @@ def erf_inv_32_lowering_helper(x):
return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
+# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L696-L802
+def _erf_inv_64_lowering_helper(x):
+ w_lt_625_constants = [
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356
+ ]
+
+ w_lt_16_constants = [
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635,
+ ]
+
+ w_gt_16_constants = [
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221,
+  ]  # TODO: consider materializing these as a jnp.float64 array.
+
+ w = -jnp.log1p(x * -x)
+ w_lt_625 = w < 6.25
+ w_lt_16 = w < 16.0
+
+ def get_coefficient(i):
+ c = w_lt_625_constants[i]
+ if i < 19:
+ c = jnp.where(w_lt_625, c, w_lt_16_constants[i])
+ if i < 17:
+ c = jnp.where(w_lt_16, c, w_gt_16_constants[i])
+ return c
+
+ select2 = jnp.where(w_lt_16, 3.25, 5.0)
+ select2_result = jnp.sqrt(w) - select2
+ w = jnp.where(w_lt_625, w - 3.125, select2_result)
+
+ p = get_coefficient(0)
+ for i in range(1, 17):
+ p = get_coefficient(i) + p * w
+ for i in range(17, 19):
+ p = jnp.where(w_lt_16, get_coefficient(i) + p * w, p)
+ for i in range(19, 23):
+ p = jnp.where(w_lt_625, get_coefficient(i) + p * w, p)
+
+ return jnp.where(jnp.abs(x) == 1.0, np.inf * x, p * x)
+
+
+def erf_inv_lowering_helper(x):
+ if x.dtype == jnp.float32:
+ return _erf_inv_32_lowering_helper(x)
+ if x.dtype == jnp.float64:
+ return _erf_inv_64_lowering_helper(x)
+ raise NotImplementedError(f"erf_inv_lowering_helper not implemented for {x.dtype}")
+
+
def sign_lowering_helper(x):
if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
return (x != 0).astype(x.dtype)
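As a non-authoritative sketch (not part of this patch), the new dtype dispatch can be exercised directly; the import path and the comparison against `jax.scipy.special.erfinv` are assumptions:

# Illustrative sketch, not part of the patch. Requires x64 mode for the
# float64 branch.
import jax
import jax.numpy as jnp
from jax._src.pallas import utils as pallas_utils

jax.config.update("jax_enable_x64", True)

x32 = jnp.linspace(-0.9, 0.9, 5, dtype=jnp.float32)
x64 = jnp.linspace(-0.9, 0.9, 5, dtype=jnp.float64)
y32 = pallas_utils.erf_inv_lowering_helper(x32)  # 32-bit polynomial branch
y64 = pallas_utils.erf_inv_lowering_helper(x64)  # 64-bit polynomial branch
# Both should track jax.scipy.special.erfinv closely on (-1, 1).
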
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index ed5b825c62b4..0abaa3fd0139 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -19,7 +19,6 @@
import dataclasses
from functools import partial
import inspect
-import itertools as it
import logging
import operator as op
import weakref
@@ -63,6 +62,7 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
@@ -165,7 +165,6 @@ class PjitInfo(NamedTuple):
keep_unused: bool
inline: bool
abstracted_axes: Any | None
- has_explicit_sharding: bool
use_resource_env: bool # False for jit, True for pjit
# Hash and compare PjitInfo by identity when used as a cache key.
@@ -312,14 +311,39 @@ def _cpp_pjit_evict_fn(self):
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
-_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)
+# This cache is only used for jit calls that pass nothing but fun. For example: jax.jit(f)
+_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
+# This cache is used for jit calls that pass extra arguments besides fun. For
+# example: jax.jit(f, donate_argnums=...) or jax.jit(f, out_shardings=...),
+# etc. We don't reuse the cache above because its capacity could fill up very
+# quickly with all the jitted functions inside JAX itself, which might evict a
+# user's train_step, for example.
+_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
-def _get_cpp_global_cache(pjit_has_explicit_sharding):
- if pjit_has_explicit_sharding:
- return xc._xla.PjitFunctionCache()
- else:
- return _cpp_pjit_cache
+
+if xla_extension_version < 286:
+ def _get_cpp_global_cache(pjit_has_explicit_sharding):
+ if pjit_has_explicit_sharding:
+ return xc._xla.PjitFunctionCache()
+ else:
+ return _cpp_pjit_cache_fun_only
+
+ def _pjit_explicit_sharding_and_layout(
+ in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
+ device, backend) -> bool:
+ return (device is not None or
+ backend is not None or
+ any(not is_unspecified(i) for i in in_shardings_flat) or
+ any(not is_unspecified(o) for o in out_shardings_flat) or
+ any(i is not None for i in in_layouts_flat) or
+ any(o is not None for o in out_layouts_flat))
+else:
+ def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
+ if contains_explicit_attributes:
+ return _cpp_pjit_cache_explicit_attributes
+ else:
+ return _cpp_pjit_cache_fun_only
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
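A hedged sketch of the cache selection described above (illustrative only; `f` is a hypothetical function and the comments assume xla_extension_version >= 286):

# Illustrative sketch, not part of the patch.
import jax

def f(x):
  return x + 1

jax.jit(f)                    # fun only -> _cpp_pjit_cache_fun_only
jax.jit(f, donate_argnums=0)  # extra attributes -> _cpp_pjit_cache_explicit_attributes
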
@@ -340,11 +364,35 @@ def cache_miss(*args, **kwargs):
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
- cpp_pjit_f = xc._xla.pjit(
- fun_name(fun),
- fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
- jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
- _get_cpp_global_cache(jit_info.has_explicit_sharding))
+ if xla_extension_version >= 286:
+ cache_key = pxla.JitGlobalCppCacheKeys(
+ donate_argnums=jit_info.donate_argnums,
+ donate_argnames=jit_info.donate_argnames,
+ device=jit_info.device, backend=jit_info.backend,
+ in_shardings_treedef=jit_info.in_shardings_treedef,
+ in_shardings_leaves=jit_info.in_shardings_leaves,
+ out_shardings_treedef=jit_info.out_shardings_treedef,
+ out_shardings_leaves=jit_info.out_shardings_leaves,
+ in_layouts_treedef=jit_info.in_layouts_treedef,
+ in_layouts_leaves=jit_info.in_layouts_leaves,
+ out_layouts_treedef=jit_info.out_layouts_treedef,
+ out_layouts_leaves=jit_info.out_layouts_leaves,
+ use_resource_env=jit_info.use_resource_env)
+ cpp_pjit_f = xc._xla.pjit(
+ fun_name(fun), fun, cache_miss, jit_info.static_argnums,
+ jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
+ pxla.cc_shard_arg,
+ _get_cpp_global_cache(cache_key.contains_explicit_attributes))
+ else:
+ has_explicit_sharding = _pjit_explicit_sharding_and_layout(
+ jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
+ jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
+ jit_info.device, jit_info.backend)
+ cpp_pjit_f = xc._xla.pjit(
+ fun_name(fun), fun, cache_miss, jit_info.static_argnums,
+ jit_info.static_argnames, jit_info.donate_argnums,
+ tree_util.dispatch_registry, pxla.cc_shard_arg,
+ _get_cpp_global_cache(has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
@@ -352,17 +400,6 @@ def cache_miss(*args, **kwargs):
return cpp_pjitted_f
-def _pjit_explicit_sharding_and_layout(
- in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
- device, backend) -> bool:
- return (device is not None or
- backend is not None or
- any(not is_unspecified(i) for i in in_shardings_flat) or
- any(not is_unspecified(o) for o in out_shardings_flat) or
- any(i is not None for i in in_layouts_flat) or
- any(o is not None for o in out_layouts_flat))
-
-
def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
@@ -446,10 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)
- has_explicit_sharding = _pjit_explicit_sharding_and_layout(
- in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
- out_layouts_leaves, device, backend)
-
return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
@@ -467,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
- has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)
@@ -475,16 +507,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@api_boundary
def lower(*args, **kwargs):
- traced = trace(*args, **kwargs)
- try:
- return traced.lower()
- except pxla.DeviceAssignmentMismatchError as e:
- fails, = e.args
- fun_name = getattr(fun, '__qualname__',
- getattr(fun, '__name__', str(fun)))
- msg = _device_assignment_mismatch_error(
- fun_name, fails, traced._args_flat, 'jit', traced._arg_names)
- raise ValueError(msg) from None
+ return trace(*args, **kwargs).lower()
@api_boundary
def eval_shape(*args, **kwargs):
@@ -504,7 +527,7 @@ def trace(*args, **kwargs) -> stages.Traced:
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
pgle_profiler=None)
return stages.Traced(
- p.params['jaxpr'], args_info, p.params["name"],p.out_tree,
+ p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
lower_callable, args_flat, p.arg_names, p.num_consts)
wrapped = _cpp_pjit(fun, jit_info)
@@ -1018,8 +1041,8 @@ def _create_sharding_for_array(mesh, x, name, api_name):
' then the mesh context manager is not required.')
# A nice user error is raised in prepare_axis_resources.
assert x is None or isinstance(x, ParsedPartitionSpec), x
- return (pxla.create_mesh_pspec_sharding(mesh, x)
- if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))
+ return (pxla.create_mesh_pspec_sharding(mesh, x) if x is None else
+ pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x))
def _create_sharding_with_device_backend(device, backend):
@@ -1494,11 +1517,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
return tuple(resolved_in_layouts)
-def _resolve_in_shardings(
- args, pjit_in_shardings: Sequence[PjitSharding],
- out_shardings: Sequence[PjitSharding],
- pjit_mesh: pxla.Mesh | None,
- check_device_assignment: bool = True) -> Sequence[PjitSharding]:
+def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
+ ) -> Sequence[PjitSharding]:
# If True, means that device or backend is set by the user on pjit and it
# has the same semantics as device_put i.e. doesn't matter which device the
# arg is on, reshard it to the device mentioned. So don't do any of the
@@ -1521,18 +1541,6 @@ def _resolve_in_shardings(
if getattr(a, '_committed', True):
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
- # Check if the device_assignment across inputs, outputs and arguments is the
- # same.
- if check_device_assignment:
- pxla._get_and_check_device_assignment(
- it.chain(
- util.stable_unique(committed_arg_shardings),
- ((i, pxla.MismatchType.IN_SHARDING, None)
- for i in util.stable_unique(pjit_in_shardings)),
- ((o, pxla.MismatchType.OUT_SHARDING, None)
- for o in util.stable_unique(out_shardings))),
- (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
-
resolved_in_shardings = []
for arg, pjit_in_s in zip(args, pjit_in_shardings):
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
@@ -1602,9 +1610,7 @@ def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
lowering_platforms, lowering_parameters, pgle_profiler):
- in_shardings = _resolve_in_shardings(
- args, in_shardings, out_shardings,
- resource_env.physical_mesh if resource_env is not None else None)
+ in_shardings = _resolve_in_shardings(args, in_shardings)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
jaxpr.in_avals)
lowered = _pjit_lower(
@@ -1733,13 +1739,27 @@ def call_impl_cache_miss(*args_, **kwargs_):
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
- donated_argnums = [i for i, d in enumerate(donated_invars) if d]
- has_explicit_sharding = _pjit_explicit_sharding_and_layout(
- in_shardings, out_shardings, in_layouts, out_layouts, None, None)
- return xc._xla.pjit(
- name, f, call_impl_cache_miss, [], [], donated_argnums,
- tree_util.dispatch_registry, pxla.cc_shard_arg,
- _get_cpp_global_cache(has_explicit_sharding))(*args)
+ donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
+ if xla_extension_version >= 286:
+ cache_key = pxla.JitGlobalCppCacheKeys(
+ donate_argnums=donated_argnums, donate_argnames=None,
+ device=None, backend=None,
+ in_shardings_treedef=None, in_shardings_leaves=in_shardings,
+ out_shardings_treedef=None, out_shardings_leaves=out_shardings,
+ in_layouts_treedef=None, in_layouts_leaves=in_layouts,
+ out_layouts_treedef=None, out_layouts_leaves=out_layouts,
+ use_resource_env=resource_env is not None)
+ return xc._xla.pjit(
+ name, f, call_impl_cache_miss, [], [], cache_key,
+ tree_util.dispatch_registry, pxla.cc_shard_arg,
+ _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
+ else:
+ has_explicit_sharding = _pjit_explicit_sharding_and_layout(
+ in_shardings, out_shardings, in_layouts, out_layouts, None, None)
+ return xc._xla.pjit(
+ name, f, call_impl_cache_miss, [], [], donated_argnums,
+ tree_util.dispatch_registry, pxla.cc_shard_arg,
+ _get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)
@@ -1780,13 +1800,11 @@ def pjit_staging_rule(trace, *args, **params):
params['jaxpr'], params['out_shardings'], params['out_layouts'])
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)
-
if (params["inline"] and
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
-
if config.dynamic_shapes.value:
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
# shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py
index 46bd873bd029..debd37dde64f 100644
--- a/jax/_src/scipy/spatial/transform.py
+++ b/jax/_src/scipy/spatial/transform.py
@@ -167,12 +167,12 @@ def as_rotvec(self, degrees: bool = False) -> jax.Array:
"""Represent as rotation vectors."""
return _as_rotvec(self.quat, degrees)
- def as_quat(self, canonical: bool=False) -> jax.Array:
+ def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> jax.Array:
"""Represent as quaternions."""
- if canonical:
- return _make_canonical(self.quat)
- else:
- return self.quat
+ quat = _make_canonical(self.quat) if canonical else self.quat
+ if scalar_first:
+ return jnp.roll(quat, shift=1, axis=-1)
+ return quat
def inv(self):
"""Invert this rotation."""
diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py
index 3401edd9e112..837aa011f165 100644
--- a/jax/_src/scipy/special.py
+++ b/jax/_src/scipy/special.py
@@ -558,7 +558,7 @@ def entr(x: ArrayLike) -> Array:
\mathrm{entr}(x) = \begin{cases}
-x\log(x) & x > 0 \\
0 & x = 0\\
- -\infty & x > 0
+ -\infty & \mathrm{otherwise}
\end{cases}
Args:
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index add297b6a351..310ff38b7247 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -18,7 +18,6 @@
from collections import OrderedDict
from collections.abc import Mapping, Sequence
import dataclasses
-import enum
import functools
import itertools
import math
@@ -955,43 +954,20 @@ def get_array_mapping(
cast(ArrayMapping, get_array_mapping(p)))
-class SpecSync(enum.IntEnum):
- """Encodes how much out of sync the real value of partitions is compared to the user specified one.
-
- We use this to make sure we don't show garbage modified values while claiming
- that the users have specified them like that.
- """
- OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
- DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
- IN_SYNC = 2 # Entirely in sync
-
class ParsedPartitionSpec:
- __slots__ = ('unsafe_user_spec', 'partitions', 'sync')
+ __slots__ = ('_user_spec', 'partitions')
- def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
- self.unsafe_user_spec = user_spec
+ def __init__(self, user_spec, partitions):
+ self._user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
self.partitions = tuple(partitions)
- self.sync = sync
-
- @property
- def user_spec(self):
- return self.unsynced_user_spec(SpecSync.IN_SYNC)
def get_partition_spec(self) -> PartitionSpec:
- if self.sync < SpecSync.IN_SYNC:
- return get_single_pspec(self)
+ if isinstance(self._user_spec, PartitionSpec):
+ return self._user_spec
else:
- if isinstance(self.unsafe_user_spec, PartitionSpec):
- return self.unsafe_user_spec
- else:
- return get_single_pspec(self)
-
- def unsynced_user_spec(self, min_sync):
- if self.sync < min_sync:
- raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
- return self.unsafe_user_spec
+ return get_single_pspec(self)
def insert_axis_partitions(self, dim, val):
parts = self.partitions
@@ -999,8 +975,7 @@ def insert_axis_partitions(self, dim, val):
if too_short > 0:
parts += ((),) * too_short
new_partitions = util.tuple_insert(parts, dim, val)
- new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
- return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
+ return ParsedPartitionSpec(None, new_partitions)
@classmethod
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
@@ -1027,11 +1002,12 @@ def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
return cls(new_entry, axis_specs)
def __hash__(self):
- return hash((self.partitions, self.sync))
+ return hash(self.partitions)
def __eq__(self, other):
- return (self.partitions == other.partitions and
- self.sync == other.sync)
+ if not isinstance(other, ParsedPartitionSpec):
+ return False
+ return self.partitions == other.partitions
def __len__(self):
return len(self.partitions)
@@ -1043,58 +1019,19 @@ def __iter__(self):
return iter(self.partitions)
def __repr__(self):
- return (f"ParsedPartitionSpec(partitions={self.partitions}, "
- f"unsafe_user_spec={self.unsafe_user_spec}, "
- f"sync={self.sync})")
-
-class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
- """ParsedPartitionSpecs that are canonicalized.
-
- ParsedPartitionSpecs may contain trailing empty tuples, that make them
- semantically different in general, and yet in some situations we prefer
- to regard them as equivalent. For example, partitions of () and ((),)
- cannot be always considered equivalent, since the first one is a valid
- spec for a scalar value, while the second is not! However, when either of
- those are applied to a 2D array, they both mean that the array is fully
- replicated.
-
- So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
- partitions.
- """
-
- def __init__(self, parsed_pspec: ParsedPartitionSpec):
- partitions = list(parsed_pspec.partitions)
- while partitions and partitions[-1] == ():
- partitions.pop()
-
- super().__init__(parsed_pspec.unsafe_user_spec, partitions,
- parsed_pspec.sync)
-
- def __repr__(self):
- return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
- f"unsafe_user_spec={self.unsafe_user_spec}, "
- f"sync={self.sync})")
+ return f"ParsedPartitionSpec(partitions={self.partitions})"
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
- # This split exists because you can pass `_parsed_pspec` that has been
- # modified from the original. For example: Adding extra dimension to
- # axis_resources for vmap handlers. In such cases you need to preserve the
- # `sync` attribute of parsed pspecs.
- # PartitionSpec is inferred from the parsed pspec in this case.
- # TODO(yaskatariya): Remove this and replace this with a normalized
- # representation of Parsed Pspec
if parsed_pspec is None:
parsed_pspec = prepare_axis_resources(
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
-
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec
-def prepare_axis_resources(axis_resources,
- arg_name,
+def prepare_axis_resources(axis_resources, arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_util.tree_flatten(
@@ -1131,9 +1068,11 @@ def _check_unique_resources(axis_resources, arg_name):
if resource_counts.most_common(1)[0][1] > 1:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
if multiple_uses:
- raise ValueError(f"A single {arg_name} specification can map every mesh axis "
- f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
- f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
+ raise ValueError(
+ f'A single {arg_name} specification can map every mesh axis to at'
+ ' most one positional dimension, but'
+ f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
+ f' for {mesh_lib.show_axes(multiple_uses)}')
# Axis environments
@@ -1312,8 +1251,7 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
out.extend(parse_flatten_op_sharding(s, mesh))
return out
elif hlo_sharding.is_replicated():
- return [CanonicalizedParsedPartitionSpec(
- ParsedPartitionSpec(PartitionSpec(), ()))]
+ return [ParsedPartitionSpec(PartitionSpec(), ())]
elif hlo_sharding.is_tiled():
mesh_shape = mesh.shape
mesh_axis_order = unflatten_array(
@@ -1337,8 +1275,9 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
)
if hlo_sharding.replicate_on_last_tile_dim():
partitions = partitions[:-1]
- return [CanonicalizedParsedPartitionSpec(
- ParsedPartitionSpec('', partitions))]
+ while partitions and partitions[-1] == ():
+ partitions.pop()
+ return [ParsedPartitionSpec(None, partitions)]
else:
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index b924072fc044..3a2c375b64db 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -734,12 +734,22 @@ def out_info(self):
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
+ from jax._src.interpreters import pxla
+ from jax._src import pjit
+
if _private_parameters is None:
_private_parameters = mlir.LoweringParameters()
new_callable = functools.partial(
self._lower_callable, lowering_platforms=lowering_platforms,
lowering_parameters=_private_parameters)
- return Lowered(new_callable(), self.args_info, self._out_tree)
+ try:
+ lowering = new_callable()
+ except pxla.DeviceAssignmentMismatchError as e:
+ fails, = e.args
+ msg = pjit._device_assignment_mismatch_error(
+ self.fun_name, fails, self._args_flat, 'jit', self._arg_names)
+ raise ValueError(msg) from None
+ return Lowered(lowering, self.args_info, self._out_tree)
@runtime_checkable
diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py
index 0041b2506061..2f1c88be495b 100644
--- a/jax/_src/state/__init__.py
+++ b/jax/_src/state/__init__.py
@@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state."""
-from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
- AccumEffect, StateEffect, RefEffect,
- get_ref_state_effects, shaped_array_ref,
- RefView)
+from jax._src.state.types import (
+ AbstractRef,
+ AccumEffect,
+ ReadEffect,
+ RefEffect,
+ StateEffect,
+ Transform,
+ TransformedRef,
+ WriteEffect,
+ get_ref_state_effects,
+ shaped_array_ref,
+)
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index 4795af054280..7970440d29a6 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -20,10 +20,8 @@
import operator
from typing import Any, Protocol, TypeVar
-import numpy as np
-
-from jax._src import api_util
from jax._src import ad_util
+from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
@@ -35,12 +33,20 @@
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state import indexing
-from jax._src.state.types import AbstractRef, RefEffect
-from jax._src.state.primitives import get_p, swap_p, addupdate_p
-from jax._src.state.utils import hoist_consts_to_refs
+from jax._src.state.primitives import addupdate_p, get_p, swap_p
+from jax._src.state.types import AbstractRef, RefBitcaster, RefEffect
+from jax._src.state.utils import bitcast, hoist_consts_to_refs
from jax._src.typing import Array
-from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache,
- partition_list, merge_lists, split_dict)
+from jax._src.util import (
+ merge_lists,
+ partition_list,
+ safe_map,
+ safe_zip,
+ split_dict,
+ split_list,
+ weakref_lru_cache,
+)
+import numpy as np
## JAX utilities
@@ -169,7 +175,7 @@ def _maybe_convert_to_slice(
return None
start = i.start
- end = i.start + i.size * i.stride
+ end = i.start + (i.size - 1) * i.stride + 1
stride = i.stride
# cannot convert to static `slice` if `start` or `end` is dynamic
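A quick arithmetic check of the corrected exclusive end bound (illustrative, not part of this patch):

# Illustrative check, not part of the patch: with start=2, size=3, stride=4
# the indices touched are 2, 6, 10, so the tight exclusive end is
# 2 + (3 - 1) * 4 + 1 = 11; the old start + size * stride gave 14, which can
# reach past the end of the dimension.
start, size, stride = 2, 3, 4
assert list(range(start, start + (size - 1) * stride + 1, stride)) == [2, 6, 10]
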
@@ -264,73 +270,95 @@ def _prepend_scatter(x, indexer, val, *, add=False):
return x[None].at[(0, *indexer)].add(val)[0]
return x[None].at[(0, *indexer)].set(val)[0]
+def _bitcast_array(x, bitcaster: RefBitcaster):
+ return bitcast(x, bitcaster.dtype)
-def index_array(x, indexers):
- if indexers is None:
- indexers = []
+def _index_array(x, indexer):
+ if _is_trivial_indexer(indexer):
+ return x
+ # Try the three APIs in the following order: `lax.slice`,
+ # `lax.dynamic_slice` and gather
+ if maybe_slice := _maybe_convert_to_slice(indexer):
+ x = lax_slicing.slice(x, *zip(*maybe_slice))
+ # If everything in the indexer is a slice or ()-shaped, we can also
+ # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+ # We need to squeeze out the 1-sized slices at the end.
+ elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+ starts, sizes, squeeze_dims = maybe_slice
+ y = lax_slicing.dynamic_slice(x, starts, sizes)
+ x = lax.squeeze(y, squeeze_dims)
+ else:
+ indexer = _convert_to_array_indexer(indexer)
+ x = x[None][(np.array(0, "int32"), *indexer)]
+ return x
+
+
+def transform_array(x, transforms):
+ if transforms is None:
+ transforms = []
result = x
- for indexer in indexers:
- if _is_trivial_indexer(indexer):
+ for transform in transforms:
+ if transform is None:
continue
- if indexer is None:
- continue
-
- # Try the three APIs in the following order: `lax.slice`,
- # `lax.dynamic_slice` and gather
- if maybe_slice := _maybe_convert_to_slice(indexer):
- result = lax_slicing.slice(result, *zip(*maybe_slice))
- # If everything in the indexer is a slice or ()-shaped, we can also
- # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
- # We need to squeeze out the 1-sized slices at the end.
- elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
- starts, sizes, squeeze_dims = maybe_slice
- y = lax_slicing.dynamic_slice(result, starts, sizes)
- result = lax.squeeze(y, squeeze_dims)
+ if isinstance(transform, indexing.NDIndexer):
+ result = _index_array(result, transform)
+ elif isinstance(transform, RefBitcaster):
+ result = _bitcast_array(result, transform)
else:
- indexer = _convert_to_array_indexer(indexer)
- result = result[None][(np.array(0, "int32"), *indexer)]
+ raise NotImplementedError(f"Unsupported transform: {transform}")
return result
-def index_swap_array(x, indexers, val):
- if indexers is None:
- indexers = []
+def transform_swap_array(x, transforms, val):
+ if transforms is None:
+ transforms = []
result = x
result_val = val
# Compute updated "val" (result).
_results = [x]
- for indexer in indexers:
- if _is_trivial_indexer(indexer):
- _results.append(None)
- continue
- # If everything in the indexer is a slice or ()-shaped, we can also
- # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
- # We need to squeeze out the 1-sized slices at the end.
- if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
- starts, sizes, squeeze_dims = maybe_slice
- result_old = lax_slicing.dynamic_slice(result, starts, sizes)
- result = lax.squeeze(result_old, squeeze_dims)
+ for transform in transforms:
+ if isinstance(transform, indexing.NDIndexer):
+ indexer = transform
+ if _is_trivial_indexer(indexer):
+ _results.append(None)
+ continue
+ # If everything in the indexer is a slice or ()-shaped, we can also
+ # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
+ # We need to squeeze out the 1-sized slices at the end.
+ if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+ starts, sizes, squeeze_dims = maybe_slice
+ result_old = lax_slicing.dynamic_slice(result, starts, sizes)
+ result = lax.squeeze(result_old, squeeze_dims)
+ else:
+ indexer = _convert_to_array_indexer(indexer)
+ result = _prepend_gather(result, indexer)
+ _results.append(result)
+ elif isinstance(transform, RefBitcaster):
+ _results.append(_bitcast_array(result, transform))
else:
- indexer = _convert_to_array_indexer(indexer)
- result = _prepend_gather(result, indexer)
- _results.append(result)
+ raise NotImplementedError(f"Unsupported transform: {transform}")
# Compute updated "x" (result_val)
- for i, indexer in reversed(list(enumerate(indexers))):
- if _is_trivial_indexer(indexer):
- continue
- if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
- starts, _, squeeze_dims = maybe_slice
- result_val = lax.expand_dims(result_val, squeeze_dims)
- result_val = lax_slicing.dynamic_update_slice(
- _results[i], result_val, starts)
+ for i, transform in reversed(list(enumerate(transforms))):
+ if isinstance(transform, indexing.NDIndexer):
+ indexer = transform
+ if _is_trivial_indexer(indexer):
+ continue
+ if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
+ starts, _, squeeze_dims = maybe_slice
+ result_val = lax.expand_dims(result_val, squeeze_dims)
+ result_val = lax_slicing.dynamic_update_slice(
+ _results[i], result_val, starts
+ )
+ else:
+ indexer = _convert_to_array_indexer(indexer)
+ result_val = _prepend_scatter(_results[i], indexer, result_val)
else:
- indexer = _convert_to_array_indexer(indexer)
- result_val = _prepend_scatter(_results[i], indexer, result_val)
+ raise NotImplementedError(f"Unsupported transform: {transform}")
return result, result_val
def _get_discharge(x, idx, tree):
- indexers = tree_util.tree_unflatten(tree, idx)
- return index_array(x, indexers)
+ transforms = tree_util.tree_unflatten(tree, idx)
+ return transform_array(x, transforms)
@register_discharge_rule(swap_p)
def _swap_discharge_rule(
@@ -342,8 +370,8 @@ def _swap_discharge_rule(
return (x_new, None) + (None,) * len(idx), z
def _swap_discharge(x, val, idx, tree):
- indexers = tree_util.tree_unflatten(tree, idx)
- return index_swap_array(x, indexers, val)
+ transforms = tree_util.tree_unflatten(tree, idx)
+ return transform_swap_array(x, transforms, val)
@register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule(
@@ -355,10 +383,10 @@ def _addupdate_discharge_rule(
return (ans, None) + (None,) * len(idx), []
def _addupdate_discharge(x, val, idx, tree):
- indexers = tree_util.tree_unflatten(tree, idx)
- if len(indexers) > 1:
+ transforms = tree_util.tree_unflatten(tree, idx)
+ if len(transforms) > 1:
raise NotImplementedError("Only single indexer is supported.")
- indexer = indexers[0]
+ indexer = transforms[0]
if _is_trivial_indexer(indexer):
return x + val
# If everything in the indexer is a slice or ()-shaped, we can also
@@ -462,7 +490,7 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
len(primals)])
del out_consts
out_tangents_iter = iter(out_tangents)
- out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
+ out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p)
for p, nz in zip(out_primals, nonzero_tangents)]
return out_primals, out_tangents
ad.primitive_jvps[run_state_p] = _run_state_jvp
@@ -488,8 +516,7 @@ def eval_jaxpr(*refs):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
- return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error
- for a in res_ref_avals]
+ return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index 750d3239a019..988f362290f0 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -18,9 +18,6 @@
import types
from typing import Any, Union
-import numpy as np
-
-
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
@@ -28,14 +25,22 @@
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
-from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
+from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
-from jax._src.typing import Array
from jax._src.state import indexing
-from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect,
- AccumEffect)
+from jax._src.state.types import (
+ AbstractRef,
+ AccumEffect,
+ ReadEffect,
+ RefBitcaster,
+ Transform,
+ TransformedRef,
+ WriteEffect,
+)
+from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip
+import numpy as np
## General utilities
@@ -59,29 +64,29 @@
Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...]
-def get_ref_and_indexers(
+def get_ref_and_transforms(
ref_or_view: Any, idx: Indexer | None, function_name: str
-) -> tuple[Any, tuple[indexing.NDIndexer, ...]]:
- if isinstance(ref_or_view, RefView):
- ref, indexers = ref_or_view.ref, ref_or_view.indexers
+) -> tuple[Any, tuple[Transform, ...]]:
+ if isinstance(ref_or_view, TransformedRef):
+ ref, transforms = ref_or_view.ref, ref_or_view.transforms
else:
- ref, indexers = ref_or_view, ()
+ ref, transforms = ref_or_view, ()
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
return ref, ()
if idx is None:
- return ref, indexers
+ return ref, transforms
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
- return ref, (*indexers, nd_indexer)
+ return ref, (*transforms, nd_indexer)
def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
- ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_get")
- flat_indexers, tree = tree_util.tree_flatten(indexers)
- return get_p.bind(ref, *flat_indexers, tree=tree)
+ ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get")
+ flat_transforms, tree = tree_util.tree_flatten(transforms)
+ return get_p.bind(ref, *flat_transforms, tree=tree)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
@@ -102,14 +107,22 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
-def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array,
- _function_name: str = "ref_swap") -> Array:
+
+def ref_swap(
+ ref_or_view: AbstractRef | TransformedRef,
+ idx: Indexer | None,
+ value: Array,
+ _function_name: str = "ref_swap",
+) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
- ref, indexers = get_ref_and_indexers(ref_or_view, idx, _function_name)
- flat_indexers, tree = tree_util.tree_flatten(indexers)
- return swap_p.bind(ref, value, *flat_indexers, tree=tree)
+ ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name)
+ flat_transforms, tree = tree_util.tree_flatten(transforms)
+ return swap_p.bind(ref, value, *flat_transforms, tree=tree)
-def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array) -> None:
+
+def ref_set(
+ ref_or_view: AbstractRef | TransformedRef, idx: Indexer | None, value: Array
+) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
@@ -130,34 +143,50 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
- ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_addupdate")
- flat_indexers, tree = tree_util.tree_flatten(indexers)
- return addupdate_p.bind(ref, x, *flat_indexers, tree=tree)
+ ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate")
+ flat_transforms, tree = tree_util.tree_flatten(transforms)
+ return addupdate_p.bind(ref, x, *flat_transforms, tree=tree)
## get/set/addupdate abstract evaluation rules
-def _shape_after_indexing(
- shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...]
+def _shape_after_transforming(
+ shape: tuple[int | Array, ...], transforms: tuple[Transform, ...]
) -> tuple[int | Array, ...]:
- for indexer in indexers:
- # Run some simple checks that all the indexers have consistent shapes
- if not indexer.is_dynamic_size:
- assert indexer.shape == shape, (indexer.shape, shape)
- shape = indexer.get_indexer_shape()
+ for transform in transforms:
+ match transform:
+ case indexing.NDIndexer():
+ # Run some simple checks that all the indexers have consistent shapes
+ if not transform.is_dynamic_size:
+ assert transform.shape == shape, (transform.shape, shape)
+ shape = transform.get_indexer_shape()
+ case RefBitcaster():
+ shape = transform.shape
+ case _:
+ raise ValueError(f"Unsupported transform: {transform}")
return shape
+def _dtype_after_transforming(
+ dtype: Any, transforms: tuple[Transform, ...]
+) -> Any:
+ for transform in reversed(transforms):
+ if isinstance(transform, RefBitcaster):
+ return transform.dtype
+ return dtype
+
+
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
- indexers = tree_util.tree_unflatten(tree, args)
+ transforms = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
- out_shape = _shape_after_indexing(ref_aval.shape, indexers)
- out_aval = ref_aval.inner_aval.update(shape=out_shape)
+ out_shape = _shape_after_transforming(ref_aval.shape, transforms)
+ out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
+ out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype)
else:
- if indexers:
+ if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
@@ -166,27 +195,30 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*args: Any, tree):
- indexers = tree_util.tree_unflatten(tree, args)
+ transforms = tree_util.tree_unflatten(tree, args)
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
- expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers)
+ expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms)
+ expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
if expected_out_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. "
- f"Indices: {indexers}. ")
- if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type:
- raise ValueError("Invalid dtype for `swap`. "
- f"Ref dtype: {ref_aval.dtype}. "
- f"Value dtype: {val_aval.dtype}. ")
- out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype)
+ f"Transforms: {transforms}. ")
+ if expected_out_dtype != val_aval.dtype and not val_aval.weak_type:
+ raise ValueError(
+ "Invalid dtype for `swap`. "
+ f"Ref dtype: {expected_out_dtype}. "
+ f"Value dtype: {val_aval.dtype}. "
+ )
+ out_aval = core.ShapedArray(expected_out_shape, expected_out_dtype)
else:
- if indexers:
+ if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
@@ -196,26 +228,29 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*args: Any, tree):
- indexers = tree_util.tree_unflatten(tree, args)
+ transforms = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
val_aval = core.raise_to_shaped(val_aval)
- slice_shape = _shape_after_indexing(ref_aval.shape, indexers)
+ out_shape = _shape_after_transforming(ref_aval.shape, transforms)
+ out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
assert isinstance(val_aval, core.ShapedArray)
- if slice_shape != val_aval.shape:
- raise ValueError("Invalid shape for `addupdate`. "
- f"Ref shape: {ref_aval.shape}. "
- f"Slice shape: {slice_shape}. "
- f"Value shape: {val_aval.shape}. "
- f"Indices: {indexers}. ")
- if ref_aval.dtype != val_aval.dtype:
+ if out_shape != val_aval.shape:
+ raise ValueError(
+ "Invalid shape for `addupdate`. "
+ f"Ref shape: {ref_aval.shape}. "
+ f"Expected shape: {out_shape}. "
+ f"Value shape: {val_aval.shape}. "
+ f"Transforms: {transforms}. "
+ )
+ if out_dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
else:
- # Check that the indexers are valid
- if indexers:
+ # Check that the transforms are valid
+ if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
@@ -261,52 +296,73 @@ def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
indices.append(core.pp_var(idx, context)) # type: ignore
return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")])
-def _pp_indexers(
- context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...],
+
+def pp_bitcaster(
+ context: core.JaxprPpContext, bitcaster: RefBitcaster
+) -> pp.Doc:
+ del context
+ return pp.text(
+ f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]"
+ )
+
+
+def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc:
+ match transform:
+ case indexing.NDIndexer():
+ return pp_indexer(context, transform)
+ case RefBitcaster():
+ return pp_bitcaster(context, transform)
+ case _:
+ raise ValueError(f"Unsupported transform: {transform}")
+
+
+def _pp_transforms(
+ context: core.JaxprPpContext,
+ transforms: tuple[Transform, ...],
):
- if not indexers:
+ if not transforms:
return pp.text("[...]")
return pp.concat(
- [pp_indexer(context, indexer) for indexer in indexers]
+ [pp_transform(context, transform) for transform in transforms]
)
-def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers):
+
+def pp_ref_transforms(context: core.JaxprPpContext, ref, transforms):
return pp_ref_var(
pp.concat([
pp.text(core.pp_var(ref, context)),
- _pp_indexers(context, indexers),
+ _pp_transforms(context, transforms),
])
)
+
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
# Pretty prints `a = get x i` as `x[i] <- a`
y, = eqn.outvars
x, *flat_idx = eqn.invars
- indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
+ transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
- return pp.concat([
- lhs,
- pp.text(' <- '),
- pp_ref_indexers(context, x, indexers)
- ])
+ return pp.concat(
+ [lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)]
+ )
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
y, = eqn.outvars
x, v, *flat_idx = eqn.invars
- indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
+ transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return pp.concat([
- pp_ref_indexers(context, x, indexers),
- pp.text(' <- '),
- pp.text(core.pp_var(v, context))
- ])
+ pp_ref_transforms(context, x, transforms),
+ pp.text(" <- "),
+ pp.text(core.pp_var(v, context)),
+ ])
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
- x_i = pp_ref_indexers(context, x, indexers)
+ x_i = pp_ref_transforms(context, x, transforms)
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
x_i, pp.text(', '),
@@ -318,11 +374,12 @@ def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
# pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars
x, v, *flat_idx = eqn.invars
- indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
+ transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
return pp.concat([
- pp_ref_indexers(context, x, indexers),
- pp.text(' += '),
- pp.text(core.pp_var(v, context))])
+ pp_ref_transforms(context, x, transforms),
+ pp.text(" += "),
+ pp.text(core.pp_var(v, context)),
+ ])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index a71d671c5345..e64d6258a808 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -21,11 +21,13 @@
from typing import Any, Union
from jax._src import core
+from jax._src import dtypes
from jax._src import effects
from jax._src import pretty_printer as pp
+from jax._src import tree_util
from jax._src.state import indexing
-from jax._src.util import safe_map, safe_zip
from jax._src.typing import Array
+from jax._src.util import safe_map, safe_zip
## JAX utilities
@@ -72,7 +74,39 @@ class AccumEffect(RefEffect):
StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
+
# ## `Ref`s
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass(frozen=True)
+class RefBitcaster:
+ dtype: dtypes.DType
+ shape: tuple[int, ...]
+
+ @classmethod
+ def from_ref_new_dtype(cls, ref_or_view: Any, dtype) -> RefBitcaster:
+ if isinstance(ref_or_view, TransformedRef):
+ if ref_or_view.is_dynamic_size:
+ raise NotImplementedError(
+ "Bitcast ref with dynamic size is not supported."
+ )
+ from jax._src.state.utils import eval_bitcast_shape # pytype: disable=import-error
+ dtype = dtypes.dtype(dtype)
+ return cls(dtype, eval_bitcast_shape(ref_or_view, dtype))
+
+ @property
+ def is_dynamic_size(self):
+ return False
+
+ def tree_flatten(self):
+ return (), (self.dtype, self.shape)
+
+ @classmethod
+ def tree_unflatten(cls, metadata, arrays):
+ assert not arrays
+ return cls(*metadata)
+
+
+Transform = indexing.NDIndexer | RefBitcaster
@dataclasses.dataclass
class RefIndexer:
@@ -82,37 +116,47 @@ def __getitem__(self, slc):
if not isinstance(slc, tuple):
slc = (slc,)
indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape)
- if isinstance(self.ref_or_view, RefView):
+ if isinstance(self.ref_or_view, TransformedRef):
view = self.ref_or_view
- return RefView(view.ref, (*view.indexers, indexer))
- return RefView(self.ref_or_view, (indexer,))
+ return TransformedRef(view.ref, (*view.transforms, indexer))
+ return TransformedRef(self.ref_or_view, (indexer,))
-Indexer = Any
@dataclasses.dataclass
-class RefView:
+class TransformedRef:
ref: Any
- indexers: tuple[indexing.NDIndexer, ...]
+ transforms: tuple[Transform, ...]
@property
def is_dynamic_size(self):
- return self.indexers[-1].is_dynamic_size
+ return self.transforms[-1].is_dynamic_size
@property
def shape(self) -> tuple[int | Array, ...]:
assert (
- len(self.indexers) > 0
- ), "Should not be able to create a trivial RefView"
- return self.indexers[-1].get_indexer_shape()
+ len(self.transforms) > 0
+ ), "Should not be able to create a trivial TransformedRef"
+ if isinstance(self.transforms[-1], indexing.NDIndexer):
+ return self.transforms[-1].get_indexer_shape()
+ return self.transforms[-1].shape
@property
def dtype(self):
+ for transform in reversed(self.transforms):
+ if isinstance(transform, RefBitcaster):
+ return transform.dtype
return self.ref.dtype
@property
def at(self) -> RefIndexer:
return RefIndexer(self)
+ def bitcast(self, dtype):
+ return TransformedRef(
+ self.ref,
+ (*self.transforms, RefBitcaster.from_ref_new_dtype(self, dtype)),
+ )
+
def __getattr__(self, name):
return getattr(self.ref, name)
@@ -152,20 +196,30 @@ def join(self, other):
@property
def shape(self):
- if not isinstance(self.inner_aval, core.ShapedArray):
- raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
- return self.inner_aval.shape
+ try:
+ return self.inner_aval.shape # pytype: disable=attribute-error
+ except AttributeError:
+ raise AttributeError(
+ f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`."
+ ) from None
@property
def dtype(self):
- if not isinstance(self.inner_aval, core.UnshapedArray):
- raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
- return self.inner_aval.dtype
+ try:
+ return self.inner_aval.dtype # pytype: disable=attribute-error
+ except AttributeError:
+ raise AttributeError(
+ f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`."
+ ) from None
@core.aval_property
def at(self):
return RefIndexer(self)
+ @core.aval_method
+ def bitcast(self, dtype):
+ return TransformedRef(self, (RefBitcaster.from_ref_new_dtype(self, dtype),))
+
@core.aval_method
@staticmethod
def get(tracer, idx=()):
@@ -189,8 +243,8 @@ def _setitem(self, tracer, idx, value) -> None:
def __repr__(self) -> str:
return f'Ref{{{self.inner_aval.str_short()}}}'
- def at_least_vspace(self):
- return AbstractRef(self.inner_aval.at_least_vspace())
+ def to_tangent_aval(self):
+ return AbstractRef(self.inner_aval.to_tangent_aval())
def __eq__(self, other):
return (type(self) is type(other) and self.inner_aval == other.inner_aval)
diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py
index 33fced775fad..909e84c3a6e3 100644
--- a/jax/_src/state/utils.py
+++ b/jax/_src/state/utils.py
@@ -13,14 +13,18 @@
# limitations under the License.
"""Utilities for tracing stateful functions."""
+from functools import partial
from typing import Callable
-from jax._src.interpreters import partial_eval as pe
+import jax
from jax._src import core
+from jax._src import dtypes
from jax._src import linear_util as lu
+from jax._src.interpreters import partial_eval as pe
from jax._src.state import AbstractRef
-from jax._src.util import split_list, safe_map, safe_zip
from jax._src.state.primitives import ref_get
+from jax._src.typing import DTypeLike
+from jax._src.util import safe_map, safe_zip, split_list
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@@ -79,3 +83,41 @@ def val_to_ref_aval(x) -> AbstractRef:
if type(aval) is not core.ShapedArray:
raise TypeError(f"can't make ref from {x}")
return AbstractRef(aval)
+
+
+def dtype_bitwidth(dtype: DTypeLike) -> int:
+ if dtypes.isdtype(dtype, "integral"):
+ return dtypes.iinfo(dtype).bits
+ return dtypes.dtype(dtype).itemsize * 8
+
+
+def bitcast(x, dtype: DTypeLike):
+ x_bitwidth = dtype_bitwidth(x.dtype)
+ y_bitwidth = dtype_bitwidth(dtype)
+ shape = list(x.shape)
+ if x_bitwidth != y_bitwidth:
+ if len(shape) < 2:
+ raise NotImplementedError(
+ "Bitcast 1D ref with bitwidth change is not supported."
+ )
+ # Note: this is only valid on TPU.
+ if shape[-2] * x_bitwidth % y_bitwidth != 0:
+ raise ValueError(
+ "Expected input and output shapes are the same after multiplying"
+ " the second-minor dimension by the bitwidths."
+ )
+ shape[-2] = shape[-2] * x_bitwidth // y_bitwidth
+ if x_bitwidth < y_bitwidth:
+ ratio = y_bitwidth // x_bitwidth
+ x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes(
+ -1, -2
+ )
+ y = jax.lax.bitcast_convert_type(x, dtype)
+ if x_bitwidth > y_bitwidth:
+ y = y.swapaxes(-1, -2).reshape(shape)
+ return y
+
+
+def eval_bitcast_shape(x, dtype: DTypeLike):
+ f = partial(bitcast, dtype=dtype)
+ return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape
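A minimal sketch of how the shape rule in the `bitcast` helper above plays out, importing it from the private module touched by this hunk purely for illustration: narrowing the element type grows the second-minor dimension by the bitwidth ratio, and widening shrinks it back.

    import jax.numpy as jnp
    from jax._src.state.utils import bitcast

    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
    y = bitcast(x, jnp.int16)   # 32 -> 16 bits: second-minor dimension doubles
    assert y.shape == (16, 128)
    z = bitcast(y, jnp.int32)   # 16 -> 32 bits: second-minor dimension halves back
    assert z.shape == (8, 128)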
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 870268e99384..5afcd5e3a718 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -2074,6 +2074,51 @@ def arccosh(self, x):
return ctx.make_mpc((inf._mpf_, imag._mpf_))
return ctx.acosh(x)
+ def arctan(self, x):
+ ctx = x.context
+
+ if isinstance(x, ctx.mpc):
+ # Workaround mpmath 1.3 bug in atan(+-inf+-infj) evaluation
+ # (see mpmath/mpmath#775 with the fix).
+ # TODO(pearu): remove the if-block below once mpmath 1.4 or
+ # newer is the required test dependency.
+ pi = ctx.pi
+ zero = ctx.zero
+ if ctx.isinf(x.real) or ctx.isinf(x.imag):
+ if x.real < 0:
+ return ctx.make_mpc(((-pi / 2)._mpf_, zero._mpf_))
+ return ctx.make_mpc(((pi / 2)._mpf_, zero._mpf_))
+
+ # On the branch cut, mpmath.mp.atan returns a different value than
+ # mpmath.fp.atan and numpy.arctan (see mpmath/mpmath#865).
+ # The following if-block ensures compatibility with
+ # numpy.arctan.
+ if x.real == 0 and x.imag < -1:
+ return (-ctx.atan(x)).conjugate()
+ return ctx.atan(x)
+
+ def arctanh(self, x):
+ ctx = x.context
+
+ if isinstance(x, ctx.mpc):
+ # Workaround mpmath 1.3 bug in atanh(+-inf+-infj) evaluation
+ # (see mpmath/mpmath#775 with the fix).
+ # TODO(pearu): remove the if-block below once mpmath 1.4 or
+ # newer is the required test dependency.
+ pi = ctx.pi
+ zero = ctx.zero
+ if ctx.isinf(x.real) or ctx.isinf(x.imag):
+ if x.imag < 0:
+ return ctx.make_mpc((zero._mpf_, (-pi / 2)._mpf_))
+ return ctx.make_mpc((zero._mpf_, (pi / 2)._mpf_))
+
+ # On the branch cut, mpmath.mp.atanh returns a different value
+ # than mpmath.fp.atanh and numpy.arctanh. The following
+ # if-block ensures compatibility with numpy.arctanh.
+ if x.imag == 0 and x.real > 1:
+ return ctx.atanh(x).conjugate()
+ return ctx.atanh(x)
+
def normalize(self, exact, reference, value):
"""Normalize reference and value using precision defined by the
difference of exact and reference.
diff --git a/jax/core.py b/jax/core.py
index 1f433d6f5c29..9857fcf88c02 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -85,6 +85,7 @@
full_lower as full_lower,
gensym as gensym,
get_aval as get_aval,
+ get_type as get_type,
get_referent as get_referent,
is_constant_dim as is_constant_dim,
is_constant_shape as is_constant_shape,
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index 96dc8898fd8e..8e517f5d4610 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -34,5 +34,6 @@
)
from jax._src.ad_util import (
- SymbolicZero as SymbolicZero
+ SymbolicZero as SymbolicZero,
+ zero_from_primal as zero_from_primal
)
diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py
index 8176465c1470..62da0f231d50 100644
--- a/jax/experimental/attrs.py
+++ b/jax/experimental/attrs.py
@@ -169,7 +169,7 @@ def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
def _linearize(traceable: lu.WrappedFun, *primals):
jvpfun, attrs = _split_attrs(_jvp(traceable))
in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
- + tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace())
+ + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval())
for p in primals))
_, in_tree = tree_flatten((primals, primals))
jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree)
@@ -211,7 +211,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
- attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace()
+ attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval()
for o, a in attrs_out]
f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
attrs, attrs_out)
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 43e9813d7fac..63c3299c5904 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -536,6 +536,8 @@ def power3_with_cotangents(x):
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
+from jax._src.lib import xla_extension_version
+from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import numpy as np
@@ -1085,7 +1087,6 @@ def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
finally:
builder.clear_sharding()
-
def _outside_call_translation_rule(ctx,
avals_in,
avals_out,
@@ -1185,8 +1186,123 @@ def _outside_call_translation_rule(ctx,
f"identity = {identity}")
return results + [next_token, next_itoken]
+if xla_extension_version < 287:
+ xla.register_translation(outside_call_p, _outside_call_translation_rule)
+
+
+def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext,
+ *args_op,
+ identity,
+ device_index,
+ flat_results_aval=(),
+ **params):
+ # We expect the current tokens at the end, inserted by _rewrite_jaxpr.
+ current_token = args_op[-2]
+ current_itoken = args_op[-1]
+
+ args_to_outfeed = args_op[:-2]
+ # Some platforms refuse to infeed empty arrays. We generate constants
+ # instead.
+ non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
+ flat_results_aval))
+ need_callback_results_on_device = (not identity and
+ len(non_empty_flat_results_aval) > 0)
+ send_infeed = need_callback_results_on_device
+ generated_infeed = False # Keep track if we emitted an infeed op
+ for platform in ctx.module_context.platforms:
+ _raise_if_using_outfeed_with_pjrt_c_api(
+ xb.get_backend(platform)
+ )
+ callback_id = _register_callback(
+ functools.partial(
+ _outside_call_run_callback,
+ send_infeed=send_infeed,
+ identity=identity,
+ flat_results_aval=flat_results_aval,
+ **params))
-xla.register_translation(outside_call_p, _outside_call_translation_rule)
+ outfeed_sharding = xla_client.OpSharding()
+ outfeed_sharding.type = xla_client.OpSharding.Type.MAXIMAL
+ outfeed_sharding.tile_assignment_dimensions = [1]
+ outfeed_sharding.tile_assignment_devices = [device_index]
+
+ # next_token = _callback_handler_data.receiver.add_outfeed(
+ # comp, current_token, callback_id, args_to_outfeed, device_index)
+
+ xla_shapes = util.flatten(
+ xla.aval_to_xla_shapes(aval) for aval in ctx.avals_in[:-2])
+ _callback_handler_data.receiver.register_outfeed(callback_id, xla_shapes)
+ outfeed_header_start = 271828 # Must match kOutfeedHeaderStart in C++
+ header = mlir.ir_constant(np.array([outfeed_header_start, callback_id],
+ dtype=np.uint32))
+ header_outfeed = hlo.OutfeedOp([header], current_token,
+ outfeed_config=ir.StringAttr.get(''))
+ mlir.set_sharding(header_outfeed, outfeed_sharding)
+ next_token, = header_outfeed.results
+ data_outfeed = hlo.OutfeedOp(args_to_outfeed, next_token,
+ outfeed_config=ir.StringAttr.get(''))
+ mlir.set_sharding(data_outfeed, outfeed_sharding)
+ next_token, = data_outfeed.results
+
+
+ if identity:
+ results = list(args_to_outfeed)
+ next_itoken = current_itoken
+ else:
+ empty_results = [
+ mlir.ir_constant(np.zeros(aval.shape, aval.dtype))
+ for aval in flat_results_aval
+ if _aval_is_empty(aval)
+ ]
+ if non_empty_flat_results_aval:
+ assert need_callback_results_on_device
+ after_outfeed_itoken = hlo.AfterAllOp([current_itoken, next_token])
+ # We shard the infeed as AssignedDevice(device_index). This must match the
+ # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
+ # this kind of sharding, we use a custom translation for infeed.
+ array_sharding_proto = xla_client.OpSharding()
+ array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
+ array_sharding_proto.tile_assignment_dimensions = [1]
+ array_sharding_proto.tile_assignment_devices = [device_index]
+
+ token_sharding_proto = xla_client.OpSharding()
+ token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
+ infeed_sharding_proto = xla.tuple_sharding_proto(
+ [array_sharding_proto] * len(non_empty_flat_results_aval) +
+ [token_sharding_proto])
+
+ output_types = map(mlir.aval_to_ir_types, non_empty_flat_results_aval)
+ flat_output_types = util.flatten(output_types)
+
+ layouts = ir.ArrayAttr.get([
+ ir.ArrayAttr.get(
+ [mlir.i64_attr(i)
+ for i in range(len(aval.shape) - 1, -1, -1)])
+ for aval in non_empty_flat_results_aval
+ ])
+ infeed = hlo.InfeedOp(flat_output_types + [hlo.TokenType.get()],
+ after_outfeed_itoken,
+ infeed_config=ir.StringAttr.get(''),
+ layout=layouts)
+ mlir.set_sharding(infeed, infeed_sharding_proto)
+ non_empty_results = list(infeed.results[:-1])
+ next_itoken = infeed.results[-1]
+ generated_infeed = True
+ results = [
+ empty_results.pop(0)
+ if _aval_is_empty(result_aval) else non_empty_results.pop(0)
+ for result_aval in flat_results_aval
+ ]
+ else:
+ results = empty_results
+ next_itoken = current_itoken
+
+ assert generated_infeed == send_infeed, (
+ f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
+ assert identity or len(results) == len(flat_results_aval), (
+ f"got {len(results)} but expected {len(flat_results_aval)}. "
+ f"identity = {identity}")
+ return results + [next_token, next_itoken]
def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
@@ -1202,23 +1318,32 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
platform = ctx.module_context.platforms[0]
use_outfeed = _use_outfeed(platform)
if use_outfeed:
- # Fall back to XLA path if we are using the outfeed
- # TODO(sharadmv): update to use MLIR for this path as well and delete
- # XLA lowering
- return mlir.xla_fallback_lowering(outside_call_p)(
- ctx,
- *args,
- has_token=has_token,
- identity=identity,
- flat_results_aval=flat_results_aval,
- device_index=device_index,
- **params)
+ if xla_extension_version < 287:
+ return mlir.xla_fallback_lowering(outside_call_p)(
+ ctx,
+ *args,
+ has_token=has_token,
+ identity=identity,
+ device_index=device_index,
+ flat_results_aval=flat_results_aval,
+ **params,
+ )
+ else:
+ return _outside_call_outfeed_lowering(
+ ctx, *args,
+ has_token=has_token,
+ identity=identity,
+ flat_results_aval=flat_results_aval,
+ device_index=device_index,
+ **params,
+ )
else:
# TODO(necula): It seems that on CPU, with custom call, the device_index
# does not work, and the callback is always run on device_index=0
if (device_index != 0 and "cpu" in ctx.module_context.platforms):
raise ValueError(
"The device_index feature on CPU works only when using outfeed.")
+
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
current_token = args[-2]
@@ -1280,7 +1405,10 @@ def wrapped_callback(*args):
f"identity = {identity}")
return list(results) + [next_token, next_itoken]
-mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
+if xla_extension_version < 287:
+ mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
+else:
+ mlir.register_lowering(outside_call_p, _outside_call_lowering)
def _outside_call_run_callback(
arrays, device, *,
@@ -1766,7 +1894,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn],
id_p.multiple_results = True
id_p.def_impl(lambda *args: args)
id_p.def_abstract_eval(lambda *args: args)
-xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
+mlir.register_lowering(id_p, lambda ctx, *args: args)
dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 2e2941fca5b1..0e263844b18e 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -27,6 +27,7 @@
import tempfile
import time
from typing import Any, Generic, TypeVar
+import weakref
import jax
from jax._src import config
@@ -800,6 +801,21 @@ def main(token_ptr, buffers):
return module, out_shape, unwrap_output_tuple
+def _declare_runtime_functions():
+ """Declares the runtime functions that can be used by the generated code."""
+ ptr_ty = ir.Type.parse("!llvm.ptr")
+ i64 = ir.IntegerType.get_signless(64)
+ arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
+ init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
+ func.FuncOp(
+ "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
+ )
+ memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
+ func.FuncOp(
+ "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
+ )
+
+
def as_gpu_kernel(
body,
grid: tuple[int, int, int],
@@ -867,16 +883,97 @@ def kernel(*args):
return kernel
-def _declare_runtime_functions():
- """Declares the runtime functions that can be used by the generated code."""
- ptr_ty = ir.Type.parse("!llvm.ptr")
- i64 = ir.IntegerType.get_signless(64)
- arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
- init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
- func.FuncOp(
- "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
- )
- memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
- func.FuncOp(
- "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
+def as_torch_gpu_kernel(
+ body,
+ grid: tuple[int, int, int],
+ block: tuple[int, int, int],
+ in_shape,
+ out_shape,
+ smem_scratch_shape: ShapeTree | Union[ShapeTree],
+ prof_spec: profiler.ProfilerSpec | None = None,
+ cluster: tuple[int, int, int] = (1, 1, 1),
+ module_name: str = "unknown",
+):
+ try:
+ import torch
+ except ImportError:
+ raise RuntimeError("as_torch_gpu_kernel requires PyTorch")
+ torch.cuda.init() # Make sure CUDA context is set up.
+
+ if isinstance(in_shape, list):
+ in_shape = tuple(in_shape)
+ elif not isinstance(in_shape, tuple):
+ in_shape = (in_shape,)
+
+ flat_out_types, out_treedef = jax.tree.flatten(out_shape)
+ expected_arg_treedef = jax.tree.structure(in_shape)
+
+ module, out_shape, unwrap_output_tuple = (
+ _lower_as_gpu_kernel(
+ body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
+ module_name, prof_spec
+ )
)
+
+ # Get our hands on the compilation and unload functions
+ try:
+ import jax_plugins.xla_cuda12 as cuda_plugin
+ except ImportError:
+ raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
+ "that use backend plugins")
+ dll = ctypes.CDLL(cuda_plugin._get_library_path())
+ compile_func = dll.MosaicGpuCompile
+ compile_func.argtypes = [ctypes.c_void_p]
+ compile_func.restype = ctypes.POINTER(ctypes.c_void_p)
+ unload_func = dll.MosaicGpuUnload
+ unload_func.argtypes = [compile_func.restype]
+ unload_func.restype = None
+
+ module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
+ compiled = compile_func(ctypes.c_char_p(module_asm))
+ if compiled is None:
+ raise RuntimeError("Failed to compile the module")
+ ctx, launch_ptr = compiled[0], compiled[1]
+ ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx))
+ launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)
+
+ def as_torch_dtype(dtype):
+ # torch contains NumPy-compatible dtypes in its top namespace
+ return getattr(torch, np.dtype(dtype).name)
+
+ def apply(*args):
+ flat_args, arg_treedef = jax.tree.flatten(args)
+ if arg_treedef != expected_arg_treedef:
+ raise ValueError(
+ f"Invalid argument structure: expected {expected_arg_treedef}, got"
+ f" {arg_treedef}, ({args=})"
+ )
+
+ # Construct a device pointer list like in the XLA calling convention
+ buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))()
+ i = -1 # Define i in case there are no args
+ device = 'cuda'
+ for i, arg in enumerate(flat_args):
+ buffers[i] = arg.data_ptr()
+ device = arg.device
+ flat_outs = []
+ for i, t in enumerate(flat_out_types, i + 1):
+ out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device)
+ flat_outs.append(out)
+ buffers[i] = out.data_ptr()
+ # Allocate another buffer for args of the host-side program. This is sadly
+ # the default MLIR calling convention.
+ args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
+ args_ptr[0] = ctx_ptr_ptr
+ args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_)
+ args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)),
+ ctypes.POINTER(ctypes.c_void_p))
+ launch(args_ptr)
+ return jax.tree.unflatten(out_treedef, flat_outs)
+
+ # Unload the compiled code when the Python function is destroyed.
+ def unload(_):
+ unload_func(compiled)
+ apply.destructor = weakref.ref(apply, unload)
+
+ return apply
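A minimal, self-contained sketch of the weakref-callback pattern that `as_torch_gpu_kernel` uses above to unload the compiled module once the returned callable is garbage collected; the class and names below are illustrative stand-ins, not part of the patch.

    import weakref

    class FakeCompiled:
        """Stand-in for the handle returned by the compile function."""
        def unload(self):
            print("unloading compiled kernel")

    def make_apply():
        compiled = FakeCompiled()
        def apply():
            return "launch"
        # Hang the weakref off the callable itself, as `apply.destructor` does
        # above; its callback fires when `apply` is collected.
        apply.destructor = weakref.ref(apply, lambda _: compiled.unload())
        return apply

    f = make_apply()
    f()
    del f  # on CPython this collects `apply` and runs the unload callback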
diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py
index 52d403cd0131..775b7c2ea898 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul.py
@@ -132,7 +132,7 @@ def build_kernel(
if stages < 2:
raise ValueError(f"Need at least 2 stages, but got {stages=}")
if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2:
- raise ValueError("Transpose only supported for only happen for 16bit types")
+ raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})")
if swizzle not in {32, 64, 128}:
raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}")
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 7e0d43a05551..502373bdc91e 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -245,7 +245,9 @@ def _pointwise(self, op, *other):
other_arrs = []
for o in other:
if not isinstance(o, FragmentedArray):
- if not isinstance(o, ir.Value):
+ if isinstance(o, (float, int)):
+ o = utils.c(o, self.mlir_dtype)
+ elif not isinstance(o, ir.Value):
raise NotImplementedError(o)
o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout)
@@ -267,6 +269,14 @@ def _pointwise(self, op, *other):
new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs))
return FragmentedArray(_registers=new_regs, _layout=self.layout)
+ def __neg__(self):
+ if ir.FloatType.isinstance(self.mlir_dtype):
+ return self._pointwise(arith.negf)
+ elif ir.IntegerType.isinstance(self.mlir_dtype):
+ return self._pointwise(arith.negsi)
+ else:
+ raise NotImplementedError(self.mlir_dtype)
+
def __add__(self, other):
if ir.FloatType.isinstance(self.mlir_dtype):
return self._pointwise(arith.addf, other)
@@ -484,6 +494,8 @@ def astype(self, new_dtype: ir.Type):
convert = arith.sitofp
elif from_float and to_integer:
convert = arith.fptosi
+ else:
+ raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
new_registers = np.empty_like(self.registers)
match self.layout:
case WGMMAFragLayout():
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 546411c82c4c..30b8ca5cfb14 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -154,9 +154,8 @@ def fori(bound, carrys):
flat_carrys, carry_treedef = jax.tree.flatten(carrys)
def wrapper(f):
- index = ir.IndexType.get()
- c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
- c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1))
+ c0 = arith.constant(bound.type, 0)
+ c1 = arith.constant(bound.type, 1)
for_op = scf.ForOp(c0, bound, c1, flat_carrys)
with ir.InsertionPoint(for_op.body):
i = for_op.induction_variable
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py
index 554bf2641769..56003ea7af5d 100644
--- a/jax/experimental/multihost_utils.py
+++ b/jax/experimental/multihost_utils.py
@@ -90,19 +90,17 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
+# Identity function is at the top level so that `process_allgather` doesn't
+# recompile on every invocation.
def _identity_fn(x):
return x
-@lru_cache(maxsize=128)
-def _jitted_identity_fn(sharding):
- return jax.jit(_identity_fn, out_shardings=sharding)
-
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
- out = _jitted_identity_fn(reps)(inp)
+ out = jax.jit(_identity_fn, out_shardings=reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
@@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled):
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
- out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)
+ out = jax.jit(_identity_fn,
+ out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)
return np.asarray(out.addressable_data(0))
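The change above leans on `jax.jit`'s own caching instead of an explicit `lru_cache`: because `_identity_fn` is a module-level function, re-wrapping it in `jax.jit` on every call reuses the tracing and compilation caches. A rough sketch of the general behaviour this relies on (illustrative only):

    import jax
    import jax.numpy as jnp

    def _identity(x):
        return x

    x = jnp.arange(4)
    # Re-wrapping the same top-level function hits jit's caches on repeat calls...
    for _ in range(3):
        jax.jit(_identity)(x)
    # ...whereas a fresh lambda each iteration is a new function object and retraces.
    for _ in range(3):
        jax.jit(lambda v: v)(x)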
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index 832f7b7d1184..c81b509d70cf 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -23,11 +23,11 @@
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import CompilerParams
from jax._src.pallas.core import CostEstimate
+from jax._src.pallas.core import GridSpec
from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
-from jax._src.pallas.core import GridSpec
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py
index e7fa25a3fc0d..8a1a223ae36e 100644
--- a/jax/experimental/pallas/tpu.py
+++ b/jax/experimental/pallas/tpu.py
@@ -68,3 +68,4 @@
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM
+SEMAPHORE = TPUMemorySpace.SEMAPHORE
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index bf331bbb913f..fabd45ca069a 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -166,7 +166,7 @@ def wrapped(*args):
raise e('shard_map in_specs') from None
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
if s is not None)
- fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat)
+ fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False)
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
@@ -1405,7 +1405,7 @@ def new_out_names_thunk():
f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind(f_jvp, *args, **params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
- tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t
+ tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t
for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py
index 9eafa0db0fc2..d200577c2416 100644
--- a/jax/experimental/sparse/bcoo.py
+++ b/jax/experimental/sparse/bcoo.py
@@ -332,11 +332,11 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype
data, indices = primals_out
if type(Mdot) is ad.Zero:
- data_dot = ad.Zero.from_value(data)
+ data_dot = ad.Zero.from_primal_value(data)
else:
data_dot = _bcoo_extract(indices, Mdot)
- tangents_out = (data_dot, ad.Zero.from_value(indices))
+ tangents_out = (data_dot, ad.Zero.from_primal_value(indices))
return primals_out, tangents_out
@@ -571,7 +571,7 @@ def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo
data_dot, _ = tangents
primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo)
data_dot_out, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo)
- return primals_out, (data_dot_out, ad.Zero.from_value(indices))
+ return primals_out, (data_dot_out, ad.Zero.from_primal_value(indices))
def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: SparseInfo):
data_ct, indices_ct = ct
@@ -1277,7 +1277,7 @@ def _bcoo_spdot_general_jvp(primals, tangents, **kwds):
data_dot_out += _bcoo_spdot_general(lhs_data_dot, lhs_indices, rhs_data, rhs_indices, **kwds)[0]
if type(rhs_data_dot) is not ad.Zero:
data_dot_out += _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data_dot, rhs_indices, **kwds)[0]
- return primals_out, [data_dot_out, ad.Zero.from_value(primals_out[1])]
+ return primals_out, [data_dot_out, ad.Zero.from_primal_value(primals_out[1])]
# TODO(JVP): transpose rule
batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule
@@ -1358,8 +1358,8 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
permute = nfold_vmap(lambda d, p: d[p], props.n_batch)
data_out = permute(data, perm)
- indices_dot_out = ad.Zero.from_value(indices)
- data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm)
+ indices_dot_out = ad.Zero.from_primal_value(indices)
+ data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
_bcoo_sort_indices_hlo = mlir.lower_fun(
@@ -1544,8 +1544,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
permute = lambda x, i, y: x
permute = nfold_vmap(permute, props.n_batch)
data_out = permute(data_out, mapping, data)
- indices_dot_out = ad.Zero.from_value(indices_out)
- data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot)
+ indices_dot_out = ad.Zero.from_primal_value(indices_out)
+ data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
_bcoo_sum_duplicates_hlo = mlir.lower_fun(
diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py
index 7f3ebb43c0ec..7275d6bb20aa 100644
--- a/jax/experimental/sparse/bcsr.py
+++ b/jax/experimental/sparse/bcsr.py
@@ -272,11 +272,11 @@ def _bcsr_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype
data, indices, indptr = primals_out
if type(Mdot) is ad.Zero:
- data_dot = ad.Zero.from_value(data)
+ data_dot = ad.Zero.from_primal_value(data)
else:
data_dot = bcsr_extract(indices, indptr, Mdot)
- tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr))
+ tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr))
return primals_out, tangents_out
diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py
index 8863478df4d3..c65bc87235d6 100644
--- a/jax/experimental/sparse/coo.py
+++ b/jax/experimental/sparse/coo.py
@@ -348,11 +348,11 @@ def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
data, row, col = primals_out
if type(Mdot) is ad.Zero:
- data_dot = ad.Zero.from_value(data)
+ data_dot = ad.Zero.from_primal_value(data)
else:
data_dot = _coo_extract(row, col, Mdot)
- tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col))
+ tangents_out = (data_dot, ad.Zero.from_primal_value(row), ad.Zero.from_primal_value(col))
return primals_out, tangents_out
diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py
index c1178943c02a..89d08f109d68 100644
--- a/jax/experimental/sparse/csr.py
+++ b/jax/experimental/sparse/csr.py
@@ -380,11 +380,11 @@ def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
data, indices, indptr = primals_out
if type(Mdot) is ad.Zero:
- data_dot = ad.Zero.from_value(data)
+ data_dot = ad.Zero.from_primal_value(data)
else:
data_dot = _csr_extract(indices, indptr, Mdot)
- tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr))
+ tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr))
return primals_out, tangents_out
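The `from_value` to `from_primal_value` renames in the sparse rules above make explicit that the symbolic zero is built from a primal, so its aval is the primal's tangent type. A rough sketch of what that means for integer index buffers, assuming the public `jax.interpreters.ad` alias:

    import jax.numpy as jnp
    from jax.interpreters import ad

    indices = jnp.zeros((5, 2), dtype=jnp.int32)
    zero_tan = ad.Zero.from_primal_value(indices)
    # Integer primals have float0 tangents, so the symbolic zero's aval does too.
    print(zero_tan.aval)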
diff --git a/jax/extend/BUILD b/jax/extend/BUILD
index babe0c8b10d2..59958c1da389 100644
--- a/jax/extend/BUILD
+++ b/jax/extend/BUILD
@@ -80,3 +80,9 @@ pytype_strict_library(
srcs = ["ffi.py"],
deps = ["//jax"],
)
+
+pytype_strict_library(
+ name = "ifrt_programs",
+ srcs = ["ifrt_programs.py"],
+ deps = ["//jax/_src/lib"],
+)
diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py
new file mode 100644
index 000000000000..d5fb9245af91
--- /dev/null
+++ b/jax/extend/ifrt_programs.py
@@ -0,0 +1,22 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: import as is required for names to be exported.
+# See PEP 484 & https://github.com/google/jax/issues/7570
+
+from jax._src.lib import xla_extension as _xe
+
+ifrt_programs = _xe.ifrt_programs
+
+del _xe
diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py
index 6663df3ac473..6bfc3473ff50 100644
--- a/jax/interpreters/ad.py
+++ b/jax/interpreters/ad.py
@@ -59,9 +59,7 @@
primitive_jvps as primitive_jvps,
primitive_transposes as primitive_transposes,
rearrange_binders as rearrange_binders,
- recast_to_float0 as recast_to_float0,
reducing_transposes as reducing_transposes,
- replace_float0s as replace_float0s,
standard_jvp as standard_jvp,
standard_jvp2 as standard_jvp2,
traceable as traceable,
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 706f5a2fe253..3c63948bee63 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -91,7 +91,6 @@
trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic,
trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2,
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
- trace_to_subjaxpr_nounits_dyn as trace_to_subjaxpr_nounits_dyn,
trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd,
tracers_to_jaxpr as tracers_to_jaxpr,
trivial_ctx as trivial_ctx,
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index d5b66c1b3b32..c23f659bd3f9 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -300,7 +300,8 @@ def diagonal(
def diff(a: ArrayLike, n: int = ..., axis: int = ...,
prepend: ArrayLike | None = ...,
append: ArrayLike | None = ...) -> Array: ...
-def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ...
+def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *,
+ method: str | None = ...) -> Array: ...
divide = true_divide
def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ...
def dot(
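A small usage sketch of the widened `digitize` signature above, assuming the new keyword selects the `jnp.searchsorted` strategy used internally (illustrative only):

    import jax.numpy as jnp

    x = jnp.array([0.2, 6.4, 3.0, 1.6])
    bins = jnp.array([0.0, 1.0, 2.5, 4.0, 10.0])
    print(jnp.digitize(x, bins))                        # default strategy
    print(jnp.digitize(x, bins, method="compare_all"))  # explicit search method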
diff --git a/jax/version.py b/jax/version.py
index cc690e02cb46..c6e4b3ad11ec 100644
--- a/jax/version.py
+++ b/jax/version.py
@@ -21,7 +21,7 @@
import pathlib
import subprocess
-_version = "0.4.32"
+_version = "0.4.34"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
@@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):
__version__ = _get_version_string()
-_minimum_jaxlib_version = "0.4.31"
+_minimum_jaxlib_version = "0.4.33"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
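A tiny illustration of why the version strings above are compared as integer tuples (as `_version_as_tuple` in the surrounding context does) rather than as raw strings:

    def _version_as_tuple(version_str):
        return tuple(int(i) for i in version_str.split(".") if i.isdigit())

    assert _version_as_tuple("0.4.9") < _version_as_tuple("0.4.34")
    assert not ("0.4.9" < "0.4.34")  # lexicographic comparison gets this wrong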
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 5cf85f3697c7..34e40d12d5be 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -227,6 +227,22 @@ cc_library(
],
)
+cc_library(
+ name = "cusolver_interface",
+ srcs = ["//jaxlib/gpu:solver_interface.cc"],
+ hdrs = ["//jaxlib/gpu:solver_interface.h"],
+ deps = [
+ ":cuda_gpu_kernel_helpers",
+ ":cuda_vendor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@xla//xla/tsl/cuda:cublas",
+ "@xla//xla/tsl/cuda:cudart",
+ "@xla//xla/tsl/cuda:cusolver",
+ ],
+)
+
cc_library(
name = "cusolver_kernels_ffi",
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
@@ -237,6 +253,7 @@ cc_library(
":cuda_make_batch_pointers",
":cuda_solver_handle_pool",
":cuda_vendor",
+ ":cusolver_interface",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h
index fba57d11b9f2..47505020f3b8 100644
--- a/jaxlib/ffi_helpers.h
+++ b/jaxlib/ffi_helpers.h
@@ -62,35 +62,15 @@ namespace jax {
FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y)
// All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN
-// where the LHS is wrapped in parentheses.
-#define FFI_ASSIGN_OR_RETURN_EAT(...)
-#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__
-#define FFI_ASSIGN_OR_RETURN_EMPTY()
-
-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \
- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1))
-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \
- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args
-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty
-
-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \
- FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__)
-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \
- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__)
-
-#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then
-#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else
-#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \
- FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else)
-
-#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \
- FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__)
-
-#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \
- FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \
- FFI_ASSIGN_OR_RETURN_REM, \
- FFI_ASSIGN_OR_RETURN_EMPTY()) \
- __VA_ARGS__
+// where the LHS is wrapped in parentheses. See a more detailed discussion at
+// https://stackoverflow.com/a/62984543
+#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \
+ FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X)
+#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__
+#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \
+ FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__)
+#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__
+#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY
template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD
index 8c4144974b4a..048ea23a9cff 100644
--- a/jaxlib/gpu/BUILD
+++ b/jaxlib/gpu/BUILD
@@ -53,6 +53,8 @@ exports_files(srcs = [
"solver.cc",
"solver_handle_pool.cc",
"solver_handle_pool.h",
+ "solver_interface.cc",
+ "solver_interface.h",
"solver_kernels.cc",
"solver_kernels.h",
"solver_kernels_ffi.cc",
diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc
new file mode 100644
index 000000000000..3c8282ec603a
--- /dev/null
+++ b/jaxlib/gpu/solver_interface.cc
@@ -0,0 +1,237 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/solver_interface.h"
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+namespace solver {
+
+// LU decomposition: getrf
+
+#define JAX_GPU_DEFINE_GETRF(Type, Name) \
+ template <> \
+ absl::StatusOr<int> GetrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
+ int n) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
+ Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \
+ Type *workspace, int lwork, int *ipiv, int *info) { \
+ return JAX_AS_STATUS( \
+ Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
+ }
+
+JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf);
+JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf);
+JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf);
+JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf);
+#undef JAX_GPU_DEFINE_GETRF
+
+#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \
+ template <> \
+ absl::Status GetrfBatched(gpublasHandle_t handle, int n, Type **a, \
+ int lda, int *ipiv, int *info, int batch) { \
+ return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \
+ }
+
+JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched);
+JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched);
+JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched);
+JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched);
+#undef JAX_GPU_DEFINE_GETRF_BATCHED
+
+// QR decomposition: geqrf
+
+#define JAX_GPU_DEFINE_GEQRF(Type, Name) \
+ template <> \
+ absl::StatusOr<int> GeqrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
+ int n) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
+ Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status Geqrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \
+ Type *tau, Type *workspace, int lwork, int *info) { \
+ return JAX_AS_STATUS( \
+ Name(handle, m, n, a, m, tau, workspace, lwork, info)); \
+ }
+
+JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf);
+JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf);
+JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf);
+JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf);
+#undef JAX_GPU_DEFINE_GEQRF
+
+#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \
+ template <> \
+ absl::Status GeqrfBatched(gpublasHandle_t handle, int m, int n, \
+ Type **a, Type **tau, int *info, \
+ int batch) { \
+ return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \
+ }
+
+JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched);
+JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched);
+JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched);
+JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched);
+#undef JAX_GPU_DEFINE_GEQRF_BATCHED
+
+// Householder transformations: orgqr
+
+#define JAX_GPU_DEFINE_ORGQR(Type, Name) \
+ template <> \
+ absl::StatusOr<int> OrgqrBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
+ int n, int k) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
+ handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status Orgqr(gpusolverDnHandle_t handle, int m, int n, int k, \
+ Type *a, Type *tau, Type *workspace, int lwork, \
+ int *info) { \
+ return JAX_AS_STATUS( \
+ Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
+ }
+
+JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr);
+JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr);
+JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
+JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
+#undef JAX_GPU_DEFINE_ORGQR
+
+// Symmetric (Hermitian) eigendecomposition:
+// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
+// * QR algorithm: syevd/heevd
+
+#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \
+ template <> \
+ absl::StatusOr<int> SyevjBufferSize<Type>( \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
+ Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
+ /*w=*/nullptr, &lwork, params))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status Syevj( \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
+ Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \
+ return JAX_AS_STATUS( \
+ Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \
+ }
+
+JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj);
+JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj);
+JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj);
+JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj);
+#undef JAX_GPU_DEFINE_SYEVJ
+
+#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \
+ template <> \
+ absl::StatusOr<int> SyevjBatchedBufferSize<Type>( \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
+ Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
+ /*w=*/nullptr, &lwork, params, batch))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status SyevjBatched( \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
+ Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \
+ int batch) { \
+ return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \
+ lwork, info, params, batch)); \
+ }
+
+JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched);
+JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched);
+JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched);
+JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched);
+#undef JAX_GPU_DEFINE_SYEVJ_BATCHED
+
+#define JAX_GPU_DEFINE_SYEVD(Type, Name) \
+ template <> \
+ absl::StatusOr<int> SyevdBufferSize<Type>(gpusolverDnHandle_t handle, \
+ gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n) { \
+ int lwork; \
+ JAX_RETURN_IF_ERROR( \
+ JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \
+ /*lda=*/n, /*w=*/nullptr, &lwork))); \
+ return lwork; \
+ } \
+ \
+ template <> \
+ absl::Status Syevd(gpusolverDnHandle_t handle, \
+ gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
+ int n, Type *a, RealType<Type>::value *w, \
+ Type *workspace, int lwork, int *info) { \
+ return JAX_AS_STATUS( \
+ Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
+ }
+
+JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd);
+JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd);
+JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd);
+JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd);
+#undef JAX_GPU_DEFINE_SYEVD
+
+// Symmetric rank-k update: syrk
+
+#define JAX_GPU_DEFINE_SYRK(Type, Name) \
+ template <> \
+ absl::Status Syrk(gpublasHandle_t handle, gpublasFillMode_t uplo, \
+ gpublasOperation_t trans, int n, int k, \
+ const Type *alpha, const Type *a, const Type *beta, \
+ Type *c) { \
+ int lda = trans == GPUBLAS_OP_N ? n : k; \
+ return JAX_AS_STATUS( \
+ Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \
+ }
+
+JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk);
+JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk);
+JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
+JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
+#undef JAX_GPU_DEFINE_SYRK
+
+} // namespace solver
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h
new file mode 100644
index 000000000000..5072be98489f
--- /dev/null
+++ b/jaxlib/gpu/solver_interface.h
@@ -0,0 +1,174 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines a standard interface to the GPU linear algebra libraries.
+
+#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_
+#define JAXLIB_GPU_SOLVER_INTERFACE_H_
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+namespace solver {
+
+template <typename T>
+struct RealType {
+ using value = T;
+};
+
+template <>
+struct RealType<gpuComplex> {
+ using value = float;
+};
+
+template <>
+struct RealType<gpuDoubleComplex> {
+ using value = double;
+};
+
+#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \
+ template <typename T> \
+ ReturnType FunctionName( \
+ JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType<T>::value)) { \
+ return absl::UnimplementedError(absl::StrFormat( \
+ #FunctionName " not implemented for type %s", typeid(T).name())); \
+ } \
+ template <> \
+ ReturnType FunctionName<float>( \
+ JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \
+ template <> \
+ ReturnType FunctionName<double>( \
+ JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \
+ template <> \
+ ReturnType FunctionName<gpuComplex>( \
+ JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \
+ template <> \
+ ReturnType FunctionName<gpuDoubleComplex>( \
+ JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double))
+
+// LU decomposition: getrf
+
+#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GetrfBufferSize);
+#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \
+ int lwork, int *ipiv, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf);
+#undef JAX_GPU_SOLVER_Getrf_ARGS
+
+#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \
+ gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \
+ int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched);
+#undef JAX_GPU_SOLVER_GetrfBatched_ARGS
+
+// QR decomposition: geqrf
+
+#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GeqrfBufferSize);
+#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \
+ Type *workspace, int lwork, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf);
+#undef JAX_GPU_SOLVER_Geqrf_ARGS
+
+#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \
+ gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \
+ int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched);
+#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS
+
+// Householder transformations: orgqr
+
+#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n, int k
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, OrgqrBufferSize);
+#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \
+ Type *workspace, int lwork, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr);
+#undef JAX_GPU_SOLVER_Orgqr_ARGS
+
+// Symmetric (Hermitian) eigendecomposition:
+// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
+// * QR algorithm: syevd/heevd
+
+#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBufferSize);
+#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
+ int lwork, int *info, gpuSyevjInfo_t params
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj);
+#undef JAX_GPU_SOLVER_Syevj_ARGS
+
+#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBatchedBufferSize);
+#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
+ int lwork, int *info, gpuSyevjInfo_t params, int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched);
+#undef JAX_GPU_SOLVER_SyevjBatched_ARGS
+
+#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevdBufferSize);
+#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \
+ gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
+ gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
+ int lwork, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
+#undef JAX_GPU_SOLVER_Syevd_ARGS
+
+// Symmetric rank-k update: syrk
+
+#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) \
+ gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \
+ int n, int k, const Type *alpha, const Type *a, const Type *beta, \
+ Type *c
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
+#undef JAX_GPU_SOLVER_Syrk_ARGS
+
+#undef JAX_GPU_SOLVER_EXPAND_DEFINITION
+
+} // namespace solver
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
+
+#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_
diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc
index 3c74b85192ad..e3f63234f538 100644
--- a/jaxlib/gpu/solver_kernels_ffi.cc
+++ b/jaxlib/gpu/solver_kernels_ffi.cc
@@ -29,9 +29,13 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
+#include "jaxlib/gpu/solver_interface.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
+#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
+
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm);
namespace jax {
@@ -39,7 +43,6 @@ namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
-namespace {
template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
int64_t size,
@@ -53,22 +56,6 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
return static_cast<T*>(maybe_workspace.value());
}
-template <typename T>
-struct RealType {
- using Type = T;
-};
-
-template <>
-struct RealType<gpuComplex> {
- using Type = float;
-};
-
-template <>
-struct RealType<gpuDoubleComplex> {
- using Type = double;
-};
-} // namespace
-
#define SOLVER_DISPATCH_IMPL(impl, ...) \
if (dataType == ffi::F32) { \
return impl<float>(__VA_ARGS__); \
@@ -93,33 +80,6 @@ struct RealType {
// LU decomposition: getrf
-namespace {
-#define GETRF_KERNEL_IMPL(type, name) \
- template <> \
- struct GetrfKernel<type> { \
- static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \
- int n) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
- type* workspace, int lwork, int* ipiv, \
- int* info) { \
- return JAX_AS_STATUS( \
- name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
- } \
- }
-
-template <typename T>
-struct GetrfKernel;
-GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf);
-GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf);
-GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf);
-GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
-#undef GETRF_KERNEL_IMPL
-
template <typename T>
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -131,7 +91,7 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
- GetrfKernel<T>::BufferSize(handle.get(), m, n));
+ solver::GetrfBufferSize<T>(handle.get(), m, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace(scratch, lwork, "getrf"));
@@ -140,13 +100,13 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int ipiv_step = std::min(m, n);
for (auto i = 0; i < batch; ++i) {
- FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
+ FFI_RETURN_IF_ERROR_STATUS(solver::Getrf(
handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
out_data += m * n;
ipiv_data += ipiv_step;
@@ -155,23 +115,6 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
return ffi::Error::Success();
}
-#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
- template <> \
- struct GetrfBatchedKernel<type> { \
- static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
- int* ipiv, int* info, int batch) { \
- return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
- } \
- }
-
-template <typename T>
-struct GetrfBatchedKernel;
-GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
-GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
-GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
-GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
-#undef GETRF_BATCHED_KERNEL_IMPL
-
template <typename T>
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
@@ -188,15 +131,15 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
sizeof(T) * n * n);
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
- FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
+ FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched(
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
return ffi::Error::Success();
@@ -228,7 +171,6 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in getrf", absl::FormatStreamed(dataType)));
}
-} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
ffi::Ffi::Bind()
@@ -242,33 +184,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
// QR decomposition: geqrf
-namespace {
-#define GEQRF_KERNEL_IMPL(type, name) \
- template <> \
- struct GeqrfKernel<type> { \
- static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
- int n) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
- type* tau, type* workspace, int lwork, \
- int* info) { \
- return JAX_AS_STATUS( \
- name(handle, m, n, a, m, tau, workspace, lwork, info)); \
- } \
- }
-
-template <typename T>
-struct GeqrfKernel;
-GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf);
-GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf);
-GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf);
-GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf);
-#undef GEQRF_KERNEL_IMPL
-
template <typename T>
ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -279,7 +194,7 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
- GeqrfKernel<T>::BufferSize(handle.get(), m, n));
+ solver::GeqrfBufferSize<T>(handle.get(), m, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "geqrf"));
@@ -292,14 +207,14 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = static_cast<T*>(out->untyped_data());
auto tau_data = static_cast<T*>(tau->untyped_data());
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int out_step = m * n;
int tau_step = std::min(m, n);
for (auto i = 0; i < batch; ++i) {
- FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
+ FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf(
handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
out_data += out_step;
tau_data += tau_step;
@@ -307,23 +222,6 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
return ffi::Error::Success();
}
-#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \
- template <> \
- struct GeqrfBatchedKernel<type> { \
- static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \
- type** tau, int* info, int batch) { \
- return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
- } \
- }
-
-template <typename T>
-struct GeqrfBatchedKernel;
-GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched);
-GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched);
-GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched);
-GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched);
-#undef GEQRF_BATCHED_KERNEL_IMPL
-
template <typename T>
ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -341,21 +239,21 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = out->untyped_data();
auto tau_data = tau->untyped_data();
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
sizeof(T) * m * n);
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch,
sizeof(T) * std::min(m, n));
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
// We ignore the output value of `info` because it is only used for shape
// checking.
int info;
- FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel<T>::Run(
+ FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched(
handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch));
return ffi::Error::Success();
@@ -385,7 +283,6 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType)));
}
-} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
ffi::Ffi::Bind()
@@ -398,34 +295,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
// Householder transformations: orgqr
-namespace {
-#define ORGQR_KERNEL_IMPL(type, name) \
- template <> \
- struct OrgqrKernel<type> { \
- static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
- int n, int k) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \
- /*tau=*/nullptr, &lwork))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
- type* a, type* tau, type* workspace, int lwork, \
- int* info) { \
- return JAX_AS_STATUS( \
- name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
- } \
- }
-
-template <typename T>
-struct OrgqrKernel;
-ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
-ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
-ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
-ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
-#undef ORGQR_KERNEL_IMPL
-
template <typename T>
ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -437,7 +306,7 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
- OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
+ solver::OrgqrBufferSize<T>(handle.get(), m, n, k));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "orgqr"));
@@ -450,13 +319,13 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
auto tau_data = static_cast<T*>(tau.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int out_step = m * n;
for (auto i = 0; i < batch; ++i) {
- FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
+ FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr(
handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
out_data += out_step;
tau_data += k;
@@ -492,7 +361,6 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType)));
}
-} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
ffi::Ffi::Bind()
@@ -510,98 +378,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
// dispatches dynamically to both syevd and syevj depending on the problem
// size and the algorithm selected by the user via the `algorithm` attribute.
-namespace {
-#define SYEVJ_KERNEL_IMPL(type, name) \
- template <> \
- struct SyevjKernel<type> { \
- static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, \
- gpusolverFillMode_t uplo, int n, \
- gpuSyevjInfo_t params) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
- /*w=*/nullptr, &lwork, params))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
- int n, type* a, RealType<type>::Type* w, \
- type* workspace, int lwork, int* info, \
- gpuSyevjInfo_t params) { \
- return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
- lwork, info, params)); \
- } \
- }
-
-template <typename T>
-struct SyevjKernel;
-SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj);
-SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj);
-SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj);
-SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj);
-#undef SYEVJ_KERNEL_IMPL
-
-#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \
- template <> \
- struct SyevjBatchedKernel<type> { \
- static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, \
- gpusolverFillMode_t uplo, int n, \
- gpuSyevjInfo_t params, int batch) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
- /*w=*/nullptr, &lwork, params, batch))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
- int n, type* a, RealType<type>::Type* w, \
- type* workspace, int lwork, int* info, \
- gpuSyevjInfo_t params, int batch) { \
- return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
- lwork, info, params, batch)); \
- } \
- }
-
-template <typename T>
-struct SyevjBatchedKernel;
-SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched);
-SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched);
-SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched);
-SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched);
-#undef SYEVJ_BATCHED_KERNEL_IMPL
-
-#define SYEVD_KERNEL_IMPL(type, name) \
- template <> \
- struct SyevdKernel<type> { \
- static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, \
- gpusolverFillMode_t uplo, int n) { \
- int lwork; \
- JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
- name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
- /*w=*/nullptr, &lwork))); \
- return lwork; \
- } \
- static absl::Status Run(gpusolverDnHandle_t handle, \
- gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
- int n, type* a, RealType<type>::Type* w, \
- type* workspace, int lwork, int* info) { \
- return JAX_AS_STATUS( \
- name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
- } \
- }
-
-template <typename T>
-struct SyevdKernel;
-SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd);
-SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd);
-SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd);
-SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd);
-#undef SYEVD_KERNEL_IMPL
-
template <typename T>
ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm,
@@ -618,49 +394,48 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
- auto w_data = static_cast<typename RealType<T>::Type*>(w->untyped_data());
+ auto w_data = static_cast<typename solver::RealType<T>::value*>(w->untyped_data());
auto info_data = info->typed_data();
if (a_data != out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+ out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
if (algorithm == SyevdAlgorithm::kJacobi ||
(algorithm == SyevdAlgorithm::kDefault && size <= 32)) {
gpuSyevjInfo_t params;
- FFI_RETURN_IF_ERROR_STATUS(
- JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(&params)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params));
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
if (batch == 1) {
- FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel<T>::BufferSize(
+ FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
handle.get(), jobz, uplo, n, params));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj"));
- FFI_RETURN_IF_ERROR_STATUS(
- SyevjKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
- workspace, lwork, info_data, params));
+ FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n,
+ out_data, w_data, workspace,
+ lwork, info_data, params));
} else {
FFI_ASSIGN_OR_RETURN(
- int lwork, SyevjBatchedKernel<T>::BufferSize(handle.get(), jobz, uplo,
+ int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
n, params, batch));
FFI_ASSIGN_OR_RETURN(
auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
- FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel<T>::Run(
- handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork,
- info_data, params, batch));
+ FFI_RETURN_IF_ERROR_STATUS(
+ solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data,
+ workspace, lwork, info_data, params, batch));
}
} else {
FFI_ASSIGN_OR_RETURN(
- int lwork, SyevdKernel<T>::BufferSize(handle.get(), jobz, uplo, n));
+ int lwork, solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevd"));
int out_step = n * n;
for (auto i = 0; i < batch; ++i) {
- FFI_RETURN_IF_ERROR_STATUS(
- SyevdKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
- workspace, lwork, info_data));
+ FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n,
+ out_data, w_data, workspace,
+ lwork, info_data));
out_data += out_step;
w_data += n;
++info_data;
@@ -695,7 +470,6 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
}
-} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
ffi::Ffi::Bind()
@@ -709,110 +483,83 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
.Ret<ffi::Buffer<ffi::S32>>() // info
);
-#define SYRK_KERNEL_IMPL(type, fn) \
- template <> \
- struct SyrkKernel<type> { \
- static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \
- std::int64_t k, bool transpose, \
- const type* alpha, const type* beta, \
- const type* a_matrix, type* c_matrix) { \
- gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \
- gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \
- int lda = transpose ? n : k; \
- return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \
- alpha, a_matrix, lda, beta, \
- c_matrix, n)); \
- } \
- }
-
-template <typename T>
-struct SyrkKernel;
-
-SYRK_KERNEL_IMPL(float, gpublasSsyrk);
-SYRK_KERNEL_IMPL(double, gpublasDsyrk);
-SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk);
-SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk);
-#undef SYRK_KERNEL_IMPL
+// Symmetric rank-k update: syrk
template <typename T>
-ffi::Error SyrkImpl(gpuStream_t stream,
- ffi::AnyBuffer a_matrix,
- ffi::AnyBuffer c_matrix,
- bool transpose,
- ffi::AnyBuffer alpha,
- ffi::AnyBuffer beta,
- ffi::Result<ffi::AnyBuffer> c_matrix_out) {
+ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
+ ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
+ ffi::AnyBuffer beta, ffi::Result<ffi::AnyBuffer> c_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
- SplitBatch2D(a_matrix.dimensions()));
- FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]),
- SplitBatch2D(c_matrix.dimensions()));
- FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]),
- SplitBatch2D(c_matrix_out->dimensions()));
- if (batch != batch_c || batch != batch_out) {
- return ffi::Error(ffi::ErrorCode::kInvalidArgument,
- "a_matrix, c_matrix and c_matrix_out must have the same "
- "batch size.");
+ SplitBatch2D(a.dimensions()));
+ if (alpha.element_count() != 1 || beta.element_count() != 1) {
+ return ffi::Error::InvalidArgument(
+ "The alpha and beta inputs to syrk must be scalars");
}
- int n = transpose ? cols : rows;
- int k = transpose ? rows : cols;
-
+ auto size = transpose ? cols : rows;
FFI_RETURN_IF_ERROR(
- CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk"));
+ CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
FFI_RETURN_IF_ERROR(
- CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk"));
+ CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));
+
+ FFI_ASSIGN_OR_RETURN(auto n,
+ MaybeCastNoOverflow<int>(transpose ? cols : rows));
+ FFI_ASSIGN_OR_RETURN(auto k,
+ MaybeCastNoOverflow<int>(transpose ? rows : cols));
+ gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER;
+ gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T;
- const T* a_data = static_cast<const T*>(a_matrix.untyped_data());
- T* c_data = static_cast<T*>(c_matrix.untyped_data());
- T* c_out_data = static_cast<T*>(c_matrix_out->untyped_data());
+ const T* a_data = static_cast<const T*>(a.untyped_data());
+ T* c_data = static_cast<T*>(c_in.untyped_data());
+ T* c_out_data = static_cast<T*>(c_out->untyped_data());
// with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV
T host_alpha;
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- &host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
- stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(),
+ sizeof(T), gpuMemcpyDeviceToHost,
+ stream));
T host_beta;
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- &host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
- stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(),
+ sizeof(T), gpuMemcpyDeviceToHost,
+ stream));
if (c_data != c_out_data) {
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
- c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice,
- stream)));
+ JAX_FFI_RETURN_IF_GPU_ERROR(
+ gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
+ gpuMemcpyDeviceToDevice, stream));
}
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
for (int i = 0; i < batch; ++i) {
- FFI_RETURN_IF_ERROR_STATUS(SyrkKernel<T>::Run(
- handle.get(), n, k, transpose, &host_alpha, &host_beta,
- a_data + i * k * n, c_out_data + i * n * n));
+ FFI_RETURN_IF_ERROR_STATUS(solver::Syrk(handle.get(), uplo, trans, n, k,
+ &host_alpha, a_data, &host_beta,
+ c_out_data));
+ a_data += k * n;
+ c_out_data += n * n;
}
return ffi::Error::Success();
}
-ffi::Error SyrkDispatch(
- gpuStream_t stream,
- ffi::AnyBuffer a_matrix,
- ffi::AnyBuffer c_matrix,
- bool transpose,
- ffi::AnyBuffer alpha,
- ffi::AnyBuffer beta,
- ffi::Result<ffi::AnyBuffer> c_matrix_out) {
- auto dataType = a_matrix.element_type();
- SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose,
- alpha, beta, c_matrix_out);
- return ffi::Error::InvalidArgument("Unsupported element type for Syrk");
+ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
+ ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
+ ffi::AnyBuffer beta,
+ ffi::Result<ffi::AnyBuffer> c_out) {
+ auto dataType = a.element_type();
+ SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta,
+ c_out);
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
- .Arg<ffi::AnyBuffer>() // a_matrix
- .Arg<ffi::AnyBuffer>() // c_matrix
.Attr<bool>("transpose") // transpose
- .Arg<ffi::AnyBuffer>() // alpha
- .Arg<ffi::AnyBuffer>() // beta
- .Ret<ffi::AnyBuffer>()); // c_matrix_out
+ .Arg<ffi::AnyBuffer>() // a
+ .Arg<ffi::AnyBuffer>() // c_in
+ .Arg<ffi::AnyBuffer>() // alpha
+ .Arg<ffi::AnyBuffer>() // beta
+ .Ret<ffi::AnyBuffer>() // c_out
+);
#undef SOLVER_DISPATCH_IMPL
#undef SOLVER_BLAS_DISPATCH_IMPL
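
Note: the `solver::` wrappers now called above (e.g. `solver::GetrfBufferSize<T>`, `solver::Getrf`, `solver::Syrk`) live in the new `solver_interface` target that the ROCm `hipsolver_interface` rule below points at. Its declarations are not part of this excerpt, so the following is only a minimal sketch, under the assumption that the interface keeps one templated entry point per routine and specializes it per element type; everything other than the call-site names and `JAX_AS_STATUS` is illustrative.

```cpp
// Hypothetical sketch of the solver_interface layer assumed by the call sites
// above; the real declarations are in jaxlib/gpu/solver_interface.h, which
// this excerpt does not show.
#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace solver {

// Maps complex element types to their real scalar type (used for eigenvalues).
template <typename T>
struct RealType {
  using value = T;
};

// Generic declarations; one explicit specialization is defined per dtype.
template <typename T>
absl::StatusOr<int> GetrfBufferSize(gpusolverDnHandle_t handle, int m, int n);

template <typename T>
absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, T* a,
                   T* workspace, int lwork, int* ipiv, int* info);

// Example specialization forwarding to the single-precision solver symbol.
template <>
inline absl::Status Getrf<float>(gpusolverDnHandle_t handle, int m, int n,
                                 float* a, float* workspace, int lwork,
                                 int* ipiv, int* info) {
  return JAX_AS_STATUS(gpusolverDnSgetrf(handle, m, n, a, /*lda=*/m, workspace,
                                         lwork, ipiv, info));
}

}  // namespace solver
```

Centralizing the per-type dispatch this way is what lets `GetrfImpl` and the other FFI kernels above shrink to a single templated body with no local macros.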
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index bc1d30893537..d80db4e1394e 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -322,13 +322,21 @@ LogicalResult MemRefBitcastOp::verify() {
auto src_dim_size = src_ty.getDimSize(i);
auto tgt_dim_size = tgt_ty.getDimSize(i);
if (i == src_ty.getRank() - 2) {
- src_dim_size *= src_bitwidth;
- tgt_dim_size *= tgt_bitwidth;
- }
- if (src_dim_size != tgt_dim_size) {
- return emitOpError(
- "Expected the same dim size on the 2nd minormost dim: ")
- << src_dim_size << " vs " << tgt_dim_size;
+ auto src_bits = src_dim_size * src_bitwidth;
+ auto tgt_bits = tgt_dim_size * tgt_bitwidth;
+ if (src_bits != tgt_bits) {
+ return emitOpError(
+ "Expected the same number of bits on the 2nd minormost "
+ "dim: (")
+ << src_dim_size << " * " << src_bitwidth << ") vs ("
+ << tgt_dim_size << " * " << tgt_bitwidth << ")";
+ ;
+ }
+ } else {
+ if (src_dim_size != tgt_dim_size) {
+ return emitOpError("Expected the same dim size on dim ")
+ << i << ": " << src_dim_size << " vs " << tgt_dim_size;
+ }
}
}
// Source and target attributes may be different before propagation is done by
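
To make the relaxed verifier concrete (shapes chosen purely for illustration): a bitcast from `memref<8x128xi32>` to `memref<16x128xi16>` is now accepted, since the second-minormost dimension carries the same number of bits on both sides (8 × 32 = 256 = 16 × 16), while a target of `memref<8x128xi16>` is rejected (8 × 32 ≠ 8 × 16); any mismatch on the remaining dimensions still fails with the per-dimension error added in the `else` branch.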
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc
index 37112666f542..569038500067 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc
@@ -78,6 +78,14 @@ LogicalResult specializeMemorySpace(TypedValue value,
updateResultFrom(op, op.getInput().getType());
continue;
}
+ if (auto op = dyn_cast(some_op)) {
+ updateResultFrom(op, op.getInput().getType());
+ continue;
+ }
+ if (auto op = dyn_cast(some_op)) {
+ updateResultFrom(op, op.getInput().getType());
+ continue;
+ }
if (auto op = dyn_cast(some_op)) {
updateResultFrom(op, op.getOperand().getType());
continue;
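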
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index 2e5723b184a8..103f9f78c32f 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -377,10 +377,40 @@ GetKernelCache() {
return std::make_pair(&context_cache, &mutex);
}
+
+absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
+ mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
+ InitContext(&context);
+ mlir::ParserConfig parse_config(&context);
+ auto module_op =
+ mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
+ if (!module_op) {
+ return absl::InternalError("Failed to parse module");
+ }
+ auto maybe_engine = Compile(*module_op);
+ if (!maybe_engine.ok()) {
+ return maybe_engine.status();
+ }
+ mlir::ExecutionEngine* execution_engine = maybe_engine->get();
+ auto main = execution_engine->lookupPacked("_mlir_ciface_main");
+ auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
+ if (!init || !main) {
+ return absl::InternalError("Failed to retrieve kernel function");
+ }
+ void* module_ptr = nullptr;
+ void* kernel_ptr = nullptr;
+ void** module_ptr_ptr = &module_ptr;
+ void** kernel_ptr_ptr = &kernel_ptr;
+ void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
+ reinterpret_cast(*init)(init_args);
+ return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
+ reinterpret_cast<MosaicHostFunc*>(*main));
+}
+
// Each compiled kernel has a unique init func, and each kernel is used from
// a single HLO module. So it should be safe to not include the CUDA context
// in the key.
-absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
+absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
CacheKey key, const char* module) {
auto cache_and_mutex = GetKernelCache();
auto* cache = cache_and_mutex.first;
@@ -397,33 +427,11 @@ absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
absl::MutexLock lock(mutex);
// We released the reader lock, another thread might have initialized it.
if (cache->find(key) == cache->end()) {
- mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
- InitContext(&context);
- mlir::ParserConfig parse_config(&context);
- auto module_op =
- mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
- if (!module_op) {
- return absl::InternalError("Failed to parse module");
- }
- auto maybe_engine = Compile(*module_op);
- if (!maybe_engine.ok()) {
- return maybe_engine.status();
+ auto compiled = CompileAndInit(module);
+ if (!compiled.ok()) {
+ return compiled.status();
}
- mlir::ExecutionEngine* execution_engine = maybe_engine->get();
- auto main = execution_engine->lookupPacked("_mlir_ciface_main");
- auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
- if (!init || !main) {
- return absl::InternalError("Failed to retrieve kernel function");
- }
- void* module_ptr = nullptr;
- void* kernel_ptr = nullptr;
- void** module_ptr_ptr = &module_ptr;
- void** kernel_ptr_ptr = &kernel_ptr;
- void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
- reinterpret_cast(*init)(init_args);
- cache->insert_or_assign(
- key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
- reinterpret_cast<MosaicHostFunc*>(*main)));
+ cache->insert_or_assign(key, std::move(*compiled));
}
return cache->at(key).GetHostLaunch();
}
@@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
abort();
}
CacheKey key(hash, reinterpret_cast(ctx));
- auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash));
+ auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
if (!ctx_and_kernel.ok()) {
XlaCustomCallStatusSetFailure(status,
ctx_and_kernel.status().message().data(),
@@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
"CUDA");
} // namespace
+
+extern "C" {
+
+__attribute__((visibility("default")))
+void** MosaicGpuCompile(const char* module) {
+ auto compiled = CompileAndInit(module);
+ if (!compiled.ok()) {
+ return nullptr;
+ }
+ auto [ctx, launch] = compiled->GetHostLaunch();
+ auto tuple_ptr = std::unique_ptr<void*[]>(new void*[3]);
+ if (!tuple_ptr) {
+ return nullptr;
+ }
+ tuple_ptr.get()[0] = ctx;
+ tuple_ptr.get()[1] = reinterpret_cast<void*>(launch);
+ tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled));
+ if (!tuple_ptr.get()[2]) {
+ return nullptr;
+ }
+ return tuple_ptr.release();
+}
+
+__attribute__((visibility("default")))
+void MosaicGpuUnload(void** tuple_ptr) {
+ delete reinterpret_cast<CompiledKernel*>(tuple_ptr[2]);
+ delete[] tuple_ptr;
+}
+
+} // extern "C"
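
The new `MosaicGpuCompile`/`MosaicGpuUnload` entry points are the symbols exported by the `gpu_version_script.lds` added later in this patch. As a rough host-side usage sketch (not part of the patch; the plugin path, the module text, and the launch calling convention are all assumptions), they could be consumed through `dlopen`:

```cpp
// Minimal sketch of loading the exported C entry points from the GPU plugin.
#include <dlfcn.h>
#include <cstdio>

using MosaicGpuCompileFn = void** (*)(const char*);
using MosaicGpuUnloadFn = void (*)(void**);

int main() {
  void* plugin = dlopen("pjrt_c_api_gpu_plugin.so", RTLD_NOW | RTLD_LOCAL);
  if (!plugin) { std::fprintf(stderr, "%s\n", dlerror()); return 1; }

  auto compile = reinterpret_cast<MosaicGpuCompileFn>(
      dlsym(plugin, "MosaicGpuCompile"));
  auto unload = reinterpret_cast<MosaicGpuUnloadFn>(
      dlsym(plugin, "MosaicGpuUnload"));
  if (!compile || !unload) return 1;

  // `module_text` would hold the serialized Mosaic GPU MLIR module.
  const char* module_text = "...";
  void** kernel = compile(module_text);  // {ctx, launch, CompiledKernel*}
  if (!kernel) return 1;                 // compilation or init failed

  // kernel[0] is the init context and kernel[1] the host launch function;
  // the launch calling convention is not shown in this diff, so it is not
  // exercised here.
  unload(kernel);  // deletes the CompiledKernel and frees the array
  dlclose(plugin);
  return 0;
}
```

`MosaicGpuCompile` returns `nullptr` on any failure, and `MosaicGpuUnload` both destroys the owning `CompiledKernel` in slot 2 and frees the three-element array, so the ctx/launch pointers must not be used after unloading.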
diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD
index ce856ae5f83d..5987415224c7 100644
--- a/jaxlib/rocm/BUILD
+++ b/jaxlib/rocm/BUILD
@@ -168,6 +168,21 @@ cc_library(
],
)
+cc_library(
+ name = "hipsolver_interface",
+ srcs = ["//jaxlib/gpu:solver_interface.cc"],
+ hdrs = ["//jaxlib/gpu:solver_interface.h"],
+ deps = [
+ ":hip_gpu_kernel_helpers",
+ ":hip_vendor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@local_config_rocm//rocm:hipblas",
+ "@local_config_rocm//rocm:hipsolver",
+ ],
+)
+
cc_library(
name = "hipsolver_kernels_ffi",
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
@@ -178,6 +193,7 @@ cc_library(
":hip_make_batch_pointers",
":hip_solver_handle_pool",
":hip_vendor",
+ ":hipsolver_interface",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel
index 8463cba08c5f..4642af12011d 100644
--- a/jaxlib/tools/BUILD.bazel
+++ b/jaxlib/tools/BUILD.bazel
@@ -64,11 +64,12 @@ py_test(
cc_binary(
name = "pjrt_c_api_gpu_plugin.so",
linkopts = [
- "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)",
+ "-Wl,--version-script,$(location :gpu_version_script.lds)",
"-Wl,--no-undefined",
],
linkshared = True,
deps = [
+ ":gpu_version_script.lds",
"@xla//xla/pjrt/c:pjrt_c_api_gpu",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
"@xla//xla/service:gpu_plugin",
diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py
index 28d2806a7da9..ced0b76c344c 100644
--- a/jaxlib/tools/build_gpu_kernels_wheel.py
+++ b/jaxlib/tools/build_gpu_kernels_wheel.py
@@ -74,7 +74,7 @@ def write_setup_cfg(sources_path, cpu):
license_files = LICENSE.txt
[bdist_wheel]
-plat-name={tag}
+plat_name={tag}
""")
diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py
index 73cb8a9e020d..0e2bba0c74d0 100644
--- a/jaxlib/tools/build_gpu_plugin_wheel.py
+++ b/jaxlib/tools/build_gpu_plugin_wheel.py
@@ -80,7 +80,7 @@ def write_setup_cfg(sources_path, cpu):
license_files = LICENSE.txt
[bdist_wheel]
-plat-name={tag}
+plat_name={tag}
python-tag=py3
"""
)
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index 48aab847f3fb..6305b0c24aa8 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -164,7 +164,7 @@ def write_setup_cfg(sources_path, cpu):
license_files = LICENSE.txt
[bdist_wheel]
-plat-name={tag}
+plat_name={tag}
"""
)
diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds
new file mode 100644
index 000000000000..8e46b2c590b2
--- /dev/null
+++ b/jaxlib/tools/gpu_version_script.lds
@@ -0,0 +1,11 @@
+VERS_1.0 {
+ global:
+ extern "C" {
+ GetPjrtApi;
+ MosaicGpuCompile;
+ MosaicGpuUnload;
+ };
+
+ local:
+ *;
+};
diff --git a/setup.py b/setup.py
index 08ce8dbcb4ed..81eef74e0049 100644
--- a/setup.py
+++ b/setup.py
@@ -19,10 +19,10 @@
project_name = 'jax'
-_current_jaxlib_version = '0.4.31'
+_current_jaxlib_version = '0.4.33'
# The following should be updated after each new jaxlib release.
-_latest_jaxlib_version_on_pypi = '0.4.31'
-_libtpu_version = '0.1.dev20240729'
+_latest_jaxlib_version_on_pypi = '0.4.33'
+_libtpu_version = '0.1.dev20240916'
def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
diff --git a/tests/BUILD b/tests/BUILD
index d1fb4dcc7cde..4635a48cede1 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1187,7 +1187,10 @@ jax_test(
shard_count = {
"tpu": 5,
},
- tags = ["noasan"], # Times out
+ tags = [
+ "noasan", # Times out.
+ "notsan", # TODO(b/309111150): Re-enable after rolling forward cl/666056414.
+ ],
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",
diff --git a/tests/api_test.py b/tests/api_test.py
index 1a119846be9c..1deaa4c08dc8 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -50,7 +50,6 @@
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
-from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge
@@ -60,7 +59,6 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
-from jax._src.lib import xla_client
from jax._src.lib import xla_extension
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
@@ -2717,26 +2715,6 @@ def __init__(self, *args, **kwargs):
out_shape = api.eval_shape(lambda x: x, x) # doesn't crash
self.assertEqual(out_shape.shape, (3,))
- def test_eval_shape_names(self):
- raise unittest.SkipTest("named shape are deprecated")
-
- def fun(x, y):
- return lax.psum(x, 'i') + y
-
- class MyArgArray:
- def __init__(self, shape, dtype, named_shape):
- self.shape = shape
- self.dtype = jnp.dtype(dtype)
- self.named_shape = named_shape
-
- x = MyArgArray((3, 2), jnp.float32, {'i': 10})
- y = MyArgArray((3, 2), jnp.float32, {'j': 5})
- with core.extend_axis_env('i', 10, None):
- with core.extend_axis_env('j', 5, None):
- out_shape = api.eval_shape(fun, x, y)
-
- self.assertEqual(out_shape.named_shape, {'j': 5})
-
def test_issue_871(self):
T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
x = jnp.array([1, 2, 3])
@@ -2904,74 +2882,6 @@ def test_jacfwd_of_complex_errors(self):
r"sub-dtype of np.floating\), but got complex.*"),
lambda: dfn(3. + 1j))
- def test_xla_computation(self):
- # these tests basically check the examples in the xla_computation docstring
-
- def e(x):
- return jnp.sin(jnp.cos(x))
- c = api.xla_computation(e)(2.)
- self.assertIn('cosine', c.as_hlo_text())
- self.assertIn('sine', c.as_hlo_text())
-
- def f(x):
- return x - lax.psum(x, 'i')
- axis_env = [('i', 4)]
- c = api.xla_computation(f, axis_env=axis_env)(2)
- self.assertIn('all-reduce', c.as_hlo_text())
- self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())
-
- def g(x):
- rowsum = lax.psum(x, 'i')
- colsum = lax.psum(x, 'j')
- allsum = lax.psum(x, ('i', 'j'))
- return rowsum, colsum, allsum
- axis_env = [('i', 4), ('j', 2)]
- c = api.xla_computation(g, axis_env=axis_env)(5.)
- self.assertIn('all-reduce', c.as_hlo_text())
- self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
- self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
- self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())
-
- def h(x):
- rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
- colsum = lax.psum(x, 'j')
- return rowsum, colsum
- axis_env = [('i', 4), ('j', 2)]
- c = api.xla_computation(h, axis_env=axis_env)(5.)
- self.assertIn('all-reduce', c.as_hlo_text())
- self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
- self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
-
- def test_xla_computation_args(self):
- def foo(x, y, z):
- return x + y + z
-
- c = api.xla_computation(foo)(1., 2., 3.)
- self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
-
- c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
- param_shapes = c.program_shape().parameter_shapes()
- self.assertEqual(len(param_shapes), 1)
- self.assertEqual(param_shapes[0].xla_element_type(),
- xla_client.PrimitiveType.TUPLE)
-
- def test_xla_computation_duck_typing(self):
- def foo(x, y, z):
- return x + y + z
-
- x = jax.ShapeDtypeStruct((), np.float32)
- y = jax.ShapeDtypeStruct((), np.float32)
- z = jax.ShapeDtypeStruct((), np.float32)
-
- c = api.xla_computation(foo)(x, y, z)
- self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
-
- c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
- param_shapes = c.program_shape().parameter_shapes()
- self.assertEqual(len(param_shapes), 1)
- self.assertEqual(param_shapes[0].xla_element_type(),
- xla_client.PrimitiveType.TUPLE)
-
def test_compiler_ir(self):
# TODO(phawkins): merge these tests with the `xla_computation` tests.
def e(x):
@@ -2983,72 +2893,6 @@ def e(x):
self.assertIn("stablehlo.cosine", stablehlo)
self.assertIn("stablehlo.sine", stablehlo)
- def test_staging_out_multi_replica(self):
- def f(x):
- return api.pmap(jnp.mean)(x)
- xla_comp = api.xla_computation(f)
- xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash
-
- def test_xla_computation_instantiate_constant_outputs(self):
- def f():
- return jnp.zeros((3, 4))
-
- xla_comp = api.xla_computation(f)()
- out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
- self.assertEqual(out_shape.dimensions(), (3, 4))
-
- def test_xla_computation_static_argnums(self):
- def f(x, y):
- return x + y
-
- xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
- hlo_text = xla_comp.as_hlo_text()
- self.assertIn("constant(3)", hlo_text)
- # The static arguments should be removed from the function being compiled,
- # thus the function should have only a single argument.
- self.assertIn("parameter(0)", hlo_text)
- self.assertNotIn("parameter(1)", hlo_text)
-
- def test_xla_computation_return_shape(self):
- _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
- return_shape=True)(np.int32(1))
- expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
- api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
- self.assertEqual(shape_tree, expected)
-
- def test_xla_computation_psum_constant(self):
- f = lambda: jax.lax.psum(1, "i")
- api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
-
- @jtu.ignore_warning(message="Some donated buffers were not usable")
- def test_xla_computation_donate_argnums(self):
- api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash
-
- def test_xla_computation_lower_fun_axis_env(self):
- axis_name = 'i'
- def fn(x):
- y = lax.all_gather(
- x, axis_name=axis_name)
- return y * lax.axis_index(axis_name).astype(jnp.float32)
-
- input_x = jnp.ones((5,6,4), dtype=jnp.float32)
- axis_env = [(axis_name, jax.local_device_count())]
- _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
-
- @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
- def test_xla_computation_axis_env(self):
- is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
- xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
-
- def fn(x):
- z = x * jax.lax.axis_index('i').astype(jnp.float32)
- def inner_fn(carry, a):
- return carry + a, ()
- return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
-
- x = jnp.ones((5, 6, 4), dtype=jnp.float32)
- _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
-
def test_concurrent_device_get_and_put(self):
def f(x):
for _ in range(100):
@@ -3678,7 +3522,7 @@ def f(x):
return x + y + y
x = np.array([1, 2], dtype=np.float32)
- hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
+ hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n')
hlo_lines = {s.strip() for s in hlo_lines}
self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
@@ -3805,11 +3649,6 @@ def g(x):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
- def test_xla_computation_zeros_doesnt_device_put(self):
- with jtu.count_device_put() as count:
- api.xla_computation(lambda: jnp.zeros(3))()
- self.assertEqual(count[0], 0)
-
def test_join_concrete_arrays_with_omnistaging(self):
# https://github.com/google/jax/issues/4622
x = jnp.array([1., 2., 3.])
@@ -5532,13 +5371,12 @@ def f(x):
x, _ = g(x)
return x
- c = api.xla_computation(f)(2.)
- self.assertNotIn('while', c.as_hlo_text())
- self.assertNotIn('conditional', c.as_hlo_text())
- self.assertNotIn('opt-barrier', c.as_hlo_text())
+ text = jax.jit(f).lower(2.).as_text('hlo')
+ self.assertNotIn('while', text)
+ self.assertNotIn('conditional', text)
+ self.assertNotIn('opt-barrier', text)
- c = api.xla_computation(grad(f))(2.)
- text = c.as_hlo_text()
+ text = jax.jit(grad(f)).lower(2.).as_text('hlo')
self.assertTrue('while' in text or 'conditional' in text
or 'opt-barrier' in text)
@@ -5557,13 +5395,13 @@ def f(x):
x, _ = g(x)
return x
- c = api.xla_computation(f)(2.)
- self.assertNotIn('while', c.as_hlo_text())
- self.assertNotIn('conditional', c.as_hlo_text())
+ text = jax.jit(f).lower(2.).as_text('hlo')
+ self.assertNotIn('while', text)
+ self.assertNotIn('conditional', text)
- c = api.xla_computation(grad(f))(2.)
- self.assertNotIn('while', c.as_hlo_text())
- self.assertNotIn('conditional', c.as_hlo_text())
+ text = jax.jit(grad(f)).lower(2.).as_text('hlo')
+ self.assertNotIn('while', text)
+ self.assertNotIn('conditional', text)
@parameterized.named_parameters(
{"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,
@@ -6608,49 +6446,6 @@ def f(x):
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
self.assertIn('psum', str(jaxpr))
- def test_make_jaxpr_named(self):
- raise unittest.SkipTest("named shape are deprecated")
- def f(x):
- return x - lax.psum(x, 'i')
-
- x = api.ShapeDtypeStruct(
- shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10})
- jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
- named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
- self.assertEqual(named_shapes, [{'i': 10}, {}])
-
- @parameterized.parameters(True, False)
- def test_vjp_reduce_axes_jaxpr(self, gy_batched):
- raise unittest.SkipTest("reduce_axes autodiff is removed")
- def f(w, x):
- return jnp.sin(jnp.dot(x, w))
-
- w = api.ShapeDtypeStruct(
- shape=(3, 4), dtype=jnp.float32, named_shape={})
- x = api.ShapeDtypeStruct(
- shape=(3,), dtype=jnp.float32, named_shape={'batch': 2})
- gy = api.ShapeDtypeStruct(
- shape=(4,), dtype=jnp.float32,
- named_shape={'batch': 2} if gy_batched else {})
-
- # per-example
- jaxpr, shapes = api.make_jaxpr(
- lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)],
- return_shape=True)(w, x, gy)
- expected = (api.ShapeDtypeStruct(
- shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x)
- self.assertEqual(shapes, expected)
- self.assertNotIn('psum', str(jaxpr))
-
- # reduced
- jaxpr, shapes = api.make_jaxpr(
- lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy),
- axis_env=[('batch', 2)],
- return_shape=True)(w, x, gy)
- expected = (w, x)
- self.assertEqual(shapes, expected)
- self.assertIn('psum', str(jaxpr))
-
def test_weak_type_jit_invariance(self):
y = jnp.broadcast_to(3., (3,))
self.assertTrue(y.aval.weak_type)
@@ -6679,7 +6474,7 @@ def test_elide_trivial_broadcasts(self):
self.assertLen(jaxpr.jaxpr.eqns, 0)
def test_convert_element_type_literal_constant_folding(self):
- # this convert_elemnt_type is nontrivial, but because it's on a scalar we
+ # this convert_element_type is nontrivial, but because it's on a scalar we
# constant-fold it
cet = partial(lax.convert_element_type, new_dtype='float16')
jaxpr = api.make_jaxpr(lambda: cet(3.))()
@@ -7408,10 +7203,11 @@ def foo_jvp(primals, tangents):
TypeError,
re.escape(
"Custom JVP rule must produce primal and tangent outputs "
- "with equal shapes and dtypes, but got float32[] and float32[1] "
- "respectively."),
+ "with corresponding shapes and dtypes. "
+ "Expected float32[] (tangent type of float32[]) but got float32[1]."),
lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),)))
+
def test_jvp_rule_doesnt_return_pair_error_message(self):
# https://github.com/google/jax/issues/2516
@@ -7741,12 +7537,13 @@ def g_jvp(primals, tangents):
self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32'))
def test_float0(self):
+ scalar_float0 = jnp.zeros((), dtype=float0)
@jax.custom_jvp
def f(x, y):
return x, y
def f_jvp(primals, _):
- # we need a defined (non-float0) tangent to trigger the rule
- return primals, (2., 1)
+ x, y = primals
+ return (x, y), (2., custom_derivatives_public.zero_from_primal(y))
f.defjvp(f_jvp)
primals = (2., 3)
@@ -7756,12 +7553,13 @@ def f_jvp(primals, _):
(primals, expected_tangents))
def test_float0_initial_style(self):
+ scalar_float0 = jnp.zeros((), dtype=float0)
@jax.custom_jvp
def f(x, y):
return x, y
def f_jvp(primals, _):
x, y = primals
- return (x, y), (2., 1)
+ return (x, y), (2., custom_derivatives_public.zero_from_primal(y))
f.defjvp(f_jvp)
def foo(x, y):
@@ -7769,8 +7567,9 @@ def foo(x, y):
return out
primals = (2., 3)
- tangents = (np.ones(()), np.zeros((), float0),)
- expected_tangents = (2., np.zeros((), float0))
+ tangents = (np.ones(()), scalar_float0)
+ expected_tangents = (2., scalar_float0)
+
self.assertAllClose(api.jvp(foo, primals, tangents),
(primals, expected_tangents))
@@ -8935,7 +8734,7 @@ def f(x):
def f_fwd(x):
return x, (2., x)
def f_rev(*_):
- return ((2., 1),)
+ return ((2., jnp.zeros(shape=(), dtype=float0)),)
f.defvjp(f_fwd, f_rev)
def foo(x, y):
@@ -9193,6 +8992,19 @@ def closure(x):
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
+ def test_closure_convert_pytree_mismatch(self):
+ # See https://github.com/google/jax/issues/23588
+ def f(x, z):
+ return z * x
+
+ x, z = 2.0, 3.0
+ _, vjp = api.vjp(f, x, z)
+ vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x)
+ vjp_pure(x, *vjp_aux_args)
+ with self.assertRaisesRegex(
+ TypeError, "The inputs to the closure produced by closure_convert"):
+ vjp_pure(x, vjp_aux_args)
+
def test_float0_cotangents_automatically_handled(self):
@jax.custom_vjp
def f(x, y):
@@ -9862,12 +9674,12 @@ def __call__(self, *args):
# an option of inferring output types.
def custom_transpose(example_out):
if isinstance(example_out, Callable):
- out_type = core.get_aval(0.).at_least_vspace()
+ out_type = core.get_aval(0.).to_tangent_aval()
return _custom_transpose(out_type, example_out)
return partial(
_custom_transpose,
jax.tree.map(
- lambda x: core.get_aval(x).at_least_vspace(), example_out))
+ lambda x: core.get_aval(x).to_tangent_aval(), example_out))
class CustomTransposeTest(jtu.JaxTestCase):
@@ -10966,25 +10778,6 @@ def test_pmap_nested_donate_ignored(self):
class NamedCallTest(jtu.JaxTestCase):
- @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
- def test_default_name(self):
- is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
- xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
-
- @api.named_call
- def my_test_function(x):
- return x**2
-
- @jax.jit
- def f(x):
- return my_test_function(x)
-
- c = xla_computation(f)(2)
- print_opts = xla_client._xla.HloPrintOptions.short_parsable()
- print_opts.print_metadata = True
- hlo_text = c.as_hlo_module().to_string(print_opts)
- self.assertIn("my_test_function", hlo_text)
-
def test_non_jaxtype_arg(self):
# For the test to fail without the invalid JaxType filter we need to pass
# in a valid JaxType that forces the invalid Jaxtype to be raised to an
diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py
index 508dbacc2a98..00925c5f7dfc 100644
--- a/tests/cache_key_test.py
+++ b/tests/cache_key_test.py
@@ -14,8 +14,10 @@
import hashlib
import os
+import re
import sys
import unittest
+from typing import cast as type_cast
import numpy as np
@@ -29,6 +31,11 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
+from jax._src.lib.mlir import ir
+from jax._src.mesh import Mesh
+from jax._src.partition_spec import PartitionSpec as P
+from jax._src.sharding_impls import NamedSharding
+from jax._src.custom_partitioning import custom_partitioning
config.parse_flags_with_absl()
@@ -155,6 +162,49 @@ def test_different_computations(self):
cache_key.get(computation2, devices, compile_options, backend),
)
+ def test_custom_partitioning_ptr_removal(self):
+ def _partition(mesh, arg_shapes, result_shape):
+ arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
+ result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
+ return mesh, jax.numpy.add, result_shardings, arg_shardings
+
+ def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+ return NamedSharding(mesh, arg_shapes[0].sharding.spec)
+
+ @custom_partitioning
+ def _cp_add(x, y):
+ return jax.numpy.add(x, y)
+
+ _cp_add.def_partition(
+ infer_sharding_from_operands=_infer_sharding_from_operands,
+ partition=_partition)
+
+ devices = np.asarray(jax.devices())
+ with Mesh(devices, ('x',)) as m:
+ computation = jax.jit(
+ _cp_add,
+ in_shardings=(NamedSharding(m, P('x')),
+ NamedSharding(m, P('x'))),
+ out_shardings=NamedSharding(m, P('x'))
+ ).lower(
+ jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+ jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+ ).compiler_ir()
+ pattern = (
+ r'stablehlo\.custom_call @CustomSPMDPartitioning\('
+ r'(.*?)\) \{'
+ r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
+ r'\}'
+ )
+ with config.remove_custom_partitioning_ptr_from_cache_key(True):
+ with computation.context:
+ updated_module = cache_key._remove_custom_partitioning_ptr(
+ type_cast(ir.Module, computation.operation.clone()))
+ bcs = [match[2] for
+ match in re.findall(pattern, str(updated_module), re.DOTALL)]
+ for bc in bcs:
+ self.assertEqual(bc, "REMOVED")
+
def test_different_device_assignment(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py
index 75ff39630705..4b1182e16b5a 100644
--- a/tests/custom_object_test.py
+++ b/tests/custom_object_test.py
@@ -68,20 +68,17 @@ def __repr__(self):
class AbstractSparseArray(core.ShapedArray):
__slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
- def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
- named_shape=None):
+ def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False):
super().__init__(shape, dtypes.canonicalize_dtype(dtype))
- named_shape = {} if named_shape is None else named_shape
self.index_dtype = index_dtype
self.nnz = nnz
- self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
- weak_type, named_shape)
+ self.data_aval = core.ShapedArray(
+ (nnz,), dtypes.canonicalize_dtype(dtype), weak_type)
self.indices_aval = core.ShapedArray(
- (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
- named_shape=named_shape)
+ (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype))
def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
- weak_type=None, named_shape=None):
+ weak_type=None):
if shape is None:
shape = self.shape
if dtype is None:
@@ -92,10 +89,7 @@ def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
nnz = self.nnz
if weak_type is None:
weak_type = self.weak_type
- if named_shape is None:
- named_shape = self.named_shape
- return AbstractSparseArray(
- shape, dtype, index_dtype, nnz, weak_type, named_shape)
+ return AbstractSparseArray(shape, dtype, index_dtype, nnz, weak_type)
def strip_weak_type(self):
return self
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 273c12f1b13c..5532fdf0303f 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -80,6 +80,18 @@ def f(x):
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")
+ def test_static_args(self):
+ @jax.jit
+ def f(arr):
+ jax.debug.print("arr {array}, dtype: {dtype}, arr {array2}",
+ array=arr, dtype=arr.dtype, array2=arr)
+ arr = jnp.array([1, 2, 3], dtype=jnp.float32)
+ with jtu.capture_stdout() as output:
+ f(arr)
+ jax.effects_barrier()
+ self.assertEqual(
+ output(), "arr [1. 2. 3.], dtype: float32, arr [1. 2. 3.]\n")
+
def test_debug_print_works_with_named_format_strings(self):
def f(x):
debug_print('x: {x}', x=x)
diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py
index 4e7898d57fe0..103357ac18ac 100644
--- a/tests/export_back_compat_test.py
+++ b/tests/export_back_compat_test.py
@@ -66,7 +66,6 @@
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
-from jax._src.lib import version as jaxlib_version
config.parse_flags_with_absl()
@@ -190,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 31)
self.run_one_test(func, data, rtol=rtol, atol=atol)
- if has_xla_ffi_support:
- with config.export_ignore_forward_compatibility(True):
- # FFI Kernel test
- data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
- self.run_one_test(func, data, rtol=rtol, atol=atol)
+ with config.export_ignore_forward_compatibility(True):
+ # FFI Kernel test
+ data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
+ self.run_one_test(func, data, rtol=rtol, atol=atol)
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -258,14 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
- if has_xla_ffi_support:
- with config.export_ignore_forward_compatibility(True):
- # FFI Kernel test
- data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
- self.run_one_test(func, data, rtol=rtol, atol=atol,
- check_results=check_eig_results)
+ with config.export_ignore_forward_compatibility(True):
+ # FFI Kernel test
+ data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
+ self.run_one_test(func, data, rtol=rtol, atol=atol,
+ check_results=check_eig_results)
@staticmethod
def eigh_input(shape, dtype):
@@ -316,14 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
- if has_xla_ffi_support:
- # FFI Kernel test
- with config.export_ignore_forward_compatibility(True):
- data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
- self.run_one_test(func, data, rtol=rtol, atol=atol,
- check_results=partial(self.check_eigh_results, operand))
+ # FFI Kernel test
+ with config.export_ignore_forward_compatibility(True):
+ data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
+ self.run_one_test(func, data, rtol=rtol, atol=atol,
+ check_results=partial(self.check_eigh_results, operand))
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{variant}",
@@ -385,8 +375,6 @@ def test_cuda_lu_pivots_to_permutation(self):
def test_cuda_lu_lapack_getrf(self, dtype_name:str):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
- if jaxlib_version < (0, 4, 32):
- self.skipTest("Not implemented in older versions of jaxlib")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (3, 4)
@@ -416,15 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
self.run_one_test(func, data, rtol=rtol)
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
- if has_xla_ffi_support:
- with config.export_ignore_forward_compatibility(True):
- # FFI Kernel test
- data = self.load_testdata(
- cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
- )
- self.run_one_test(func, data, rtol=rtol)
+ with config.export_ignore_forward_compatibility(True):
+ # FFI Kernel test
+ data = self.load_testdata(
+ cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
+ )
+ self.run_one_test(func, data, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
@@ -502,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
- if has_xla_ffi_support:
- with config.export_ignore_forward_compatibility(True):
- # FFI Kernel test
- data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
- self.run_one_test(func, data, rtol=rtol, atol=atol,
- check_results=partial(self.check_lu_results, operand,
+ with config.export_ignore_forward_compatibility(True):
+ # FFI Kernel test
+ data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
+ self.run_one_test(func, data, rtol=rtol, atol=atol,
+ check_results=partial(self.check_lu_results, operand,
dtype=dtype))
def check_svd_results(self, input, res_run, res_exp,
@@ -629,16 +611,13 @@ def func(input):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
input))
- # TODO(b/344892332): Remove the check after the compatibility period.
- has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
- if has_xla_ffi_support:
- with config.export_ignore_forward_compatibility(True):
- # FFI Kernel test
- data = self.load_testdata(
- cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
- )
- self.run_one_test(func, data, rtol=rtol, atol=atol,
- check_results=partial(self.check_svd_results, input))
+ with config.export_ignore_forward_compatibility(True):
+ # FFI Kernel test
+ data = self.load_testdata(
+ cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
+ )
+ self.run_one_test(func, data, rtol=rtol, atol=atol,
+ check_results=partial(self.check_svd_results, input))
@jtu.parameterized_filterable(
kwargs=[
diff --git a/tests/export_test.py b/tests/export_test.py
index b269aef28d79..d5884b7e6b16 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -473,7 +473,8 @@ def f(xi, xf):
# Native JAX 1st order vjp
(f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf)
- f_outi_ct = np.ones(f_outi.shape, dtype=f_outi.dtype)
+ f_outi_ct = np.ones(f_outi.shape,
+ dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype))
f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype)
xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct))
diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py
index e75e8e7d735f..f34b8211eb33 100644
--- a/tests/filecheck/math.filecheck.py
+++ b/tests/filecheck/math.filecheck.py
@@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3)
print_ir(jnp.bfloat16(0))(lax.sqrt)
# CHECK-LABEL: TEST: tan float16[]
- # CHECK: chlo.tan
+ # CHECK: hlo.tan
# CHECK-SAME: tensor
print_ir(np.float16(0))(lax.tan)
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 1323196feda0..ddf42a28e2ba 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -1478,6 +1478,12 @@ def testTrimZeros(self, a_shape, dtype, trim):
jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
+ def testTrimZerosNotOneDArray(self):
+ # TODO: make this an error after the deprecation period.
+ with self.assertWarnsRegex(DeprecationWarning,
+ r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):
+ jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]]))
+
@jtu.sample_product(
rank=(1, 2),
dtype=default_dtypes,
@@ -2128,6 +2134,9 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
+ @jtu.ignore_warning(
+ category=RuntimeWarning, message='invalid value encountered in cast'
+ )
def testUniqueEqualNan(self, dtype, equal_nan):
shape = (20,)
rng = jtu.rand_some_nan(self.rng())
@@ -2800,6 +2809,23 @@ def testDigitize(self, xshape, binshape, right, reverse, dtype):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
+ @jtu.sample_product(
+ xshape=[(20,), (5, 4)],
+ binshape=[(0,), (1,), (5,)],
+ right=[True, False],
+ method=['scan', 'scan_unrolled', 'sort', 'compare_all'],
+ reverse=[True, False],
+ dtype=default_dtypes,
+ )
+ def testDigitizeMethod(self, xshape, binshape, right, method, reverse, dtype):
+ order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
+ rng = jtu.rand_default(self.rng())
+ args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
+ np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32')
+ jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right, method=method)
+ self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+ self._CompileAndCheck(jnp_fun, args_maker)
+
@jtu.sample_product(
dtypes=[
[np.float32],
@@ -4272,14 +4298,8 @@ def testSortStableDescending(self):
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)
- @jtu.sample_product(
- [dict(shape=shape, axis=axis)
- for shape in one_dim_array_shapes
- for axis in [None]
- ],
- dtype=all_dtypes,
- )
- def testSortComplex(self, dtype, shape, axis):
+ @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
+ def testSortComplex(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker,
@@ -6288,7 +6308,8 @@ def test_lax_numpy_docstrings(self):
unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
- 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'round_']
+ 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'pow',
+ 'round_']
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']
for name in dir(jnp):
diff --git a/tests/lax_test.py b/tests/lax_test.py
index ce30131953af..3d31bcb7d555 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -3798,7 +3798,8 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind):
size_im = 11
atol = None
- if name in {"arccos", "arcsin", "arcsinh", "arccosh"}:
+ if (name in {"arccos", "arcsin", "arcsinh", "arccosh"}
+ or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)):
# TODO(pearu): eliminate this if-block when a fix to mpmath#787
# becomes available
extra_prec_multiplier = 20
@@ -3954,21 +3955,21 @@ def regions_with_inaccuracies_keep(*to_keep):
elif name == 'arccos':
regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real')
- elif name == 'arctan':
+ elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31):
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
- 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag')
+ 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag')
if dtype == np.complex128:
- regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
+ regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real')
- elif name == 'arctanh':
+ elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31):
regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
elif name in {'cos', 'sin'}:
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan',
- 'arcsinh', 'arcsin', 'arccosh'}:
+ 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}:
regions_with_inaccuracies.clear()
else:
assert 0 # unreachable
diff --git a/tests/layout_test.py b/tests/layout_test.py
index 2f240195f22d..1d18179ccfee 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -15,7 +15,6 @@
import contextlib
import math
from functools import partial
-import unittest
from absl.testing import absltest
import numpy as np
@@ -47,8 +46,9 @@ def setUp(self):
super().setUp()
def test_auto_layout(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
+    # Remove this condition once the minimum xla_extension_version is >= 285.
+ if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
+ self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape1 = (128, 128)
shape2 = (128, 128)
@@ -114,8 +114,9 @@ def init(x, y):
self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
def test_default_layout(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
+    # Remove this condition once the minimum xla_extension_version is >= 285.
+ if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
+ self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (4, 4, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -155,8 +156,9 @@ def f(x):
out_shardings=DLL.AUTO).lower(sds).compile()
def test_in_layouts_out_layouts(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
+    # Remove this condition once the minimum xla_extension_version is >= 285.
+ if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
+ self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (8, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -181,8 +183,9 @@ def f(x):
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
def test_sharding_and_layouts(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
+    # Remove this condition once the minimum xla_extension_version is >= 285.
+ if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
+ self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (4, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -246,6 +249,10 @@ def f(x, y):
def test_aot_layout_mismatch(self):
if jtu.test_device_matches(["gpu"]):
+ # The test fails on GPU because the compilation with both input and
+ # output set to auto layout is underspecified. The GPU compiler chooses
+ # the default layout as the input layout and that choice does not
+ # raise an exception.
self.skipTest("This test does not work on GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (256, 4, 2)
@@ -416,8 +423,6 @@ def f(x):
self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -472,8 +477,9 @@ def test_incompatible_aval_error_device_put(self):
jax.device_put(inp, l)
def test_concrete_layout_in_shardings(self):
- if jtu.test_device_matches(["gpu"]):
- self.skipTest("This test does not work on GPU backend.")
+    # Remove this condition once the minimum xla_extension_version is >= 285.
+ if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
+ self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (16, 128)
@@ -482,7 +488,9 @@ def test_concrete_layout_in_shardings(self):
custom_dll = DLL(major_to_minor=(0, 1))
- @partial(jax.jit, in_shardings=Layout(custom_dll, s))
+ @partial(jax.jit,
+ in_shardings=Layout(custom_dll, s),
+ out_shardings=Layout(DLL.AUTO))
def f(x):
return x.T
@@ -502,8 +510,6 @@ def g(x):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)
- @unittest.skipIf(xla_extension_version < 282,
- "Requires xla_extension_version >= 282")
def test_in_layouts_jit_jnp_input(self):
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 4dcdeb19e1ef..446e10abd097 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -16,7 +16,6 @@
from functools import partial
import itertools
-import unittest
import numpy as np
import scipy
@@ -2194,9 +2193,6 @@ def testHilbert(self, n):
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
- @unittest.skipIf(
- jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32"
- )
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
rng = jtu.rand_default(self.rng())
batch_size = 10
diff --git a/tests/memories_test.py b/tests/memories_test.py
index 68aecfdf669f..3e0f444a1e66 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -742,6 +742,29 @@ def h(x):
self.assertArraysEqual(out2, inp * 6)
self.assertEqual(out2.sharding.memory_kind, 'pinned_host')
+ def test_compute_on_basic_inline(self):
+ @compute_on('device_host')
+ @jax.jit
+ def g(x):
+ return x * 2
+
+ @functools.partial(jax.jit, inline=True)
+ def h(x):
+ y = g(x)
+ return y * 3
+
+ @jax.jit
+ def f(x):
+ return h(x)
+
+ inp = jnp.arange(8)
+ out = f(inp)
+ self.assertArraysEqual(out, inp * 6)
+
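+    # The host-compute annotation from compute_on should survive h being
+    # inlined into f.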
+ lowered_text = f.lower(jnp.arange(8)).as_text('hlo')
+ self.assertRegex(lowered_text,
+ 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}')
+
def test_compute_on_reduction(self):
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')
diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD
index abab212d8618..6e5c94982d47 100644
--- a/tests/mosaic/BUILD
+++ b/tests/mosaic/BUILD
@@ -47,13 +47,16 @@ jax_test(
name = "gpu_test",
srcs = ["gpu_test.py"],
config_tags_overrides = {
- "gpu_h100_2gpu": {
+ "gpu_h100": {
"ondemand": False, # Include in presubmit.
},
},
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
- enable_configs = ["gpu_h100_2gpu"],
+ enable_configs = [
+ "gpu_h100",
+ "gpu_h100_2gpu",
+ ],
shard_count = 4,
deps = [
"//jax:mosaic_gpu",
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index ec9a7cd8b64e..1a29bbb5736d 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -19,6 +19,7 @@
import itertools
import math
import operator
+import unittest
from absl.testing import absltest, parameterized
import jax
@@ -1243,6 +1244,8 @@ def kernel(ctx, dst, _):
(lambda x: mgpu.FragmentedArray.cos(x, approx=True), np.cos, True),
(lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt, False),
(lambda x: mgpu.FragmentedArray.rsqrt(x, approx=True), jax.lax.rsqrt, True),
+ (lambda x: -x, jax.lax.neg, False),
+ (lambda x: x + 42.0, lambda x: x + 42.0, False),
),
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
@@ -1387,5 +1390,28 @@ def kernel(ctx, src, dst, _):
jax.block_until_ready(f(xd))
+class TorchTest(TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ try:
+ import torch
+ except ImportError:
+ raise unittest.SkipTest("Test requires PyTorch")
+ cls.torch = torch
+
+ def test_basic(self):
+ def kernel(ctx, i_gmem, o_gmem, _):
+ x = mgpu.FragmentedArray.load_strided(i_gmem)
+ (x + x).store_untiled(o_gmem)
+
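+    # as_torch_gpu_kernel wraps the Mosaic kernel so it can be called directly
+    # on PyTorch CUDA tensors; ty describes the input/output shape and dtype.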
+ ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
+ x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
+ f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
+ y = f(x)
+ np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
+ del y # Make sure the destructor runs successfully.
+
+
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 3722db42671c..be07de184e60 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -38,11 +38,11 @@
config.parse_flags_with_absl()
-def _is_required_cudnn_version_satisfied():
+def _is_required_cudnn_version_satisfied(min_cudnn_version):
return (
jtu.is_cuda_compute_capability_at_least("8.0") and
cuda_versions is not None and
- cuda_versions.cudnn_get_version() >= 8904
+ cuda_versions.cudnn_get_version() >= min_cudnn_version
)
def _check_cudnn_backend(fn, *args, **kwargs):
@@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
- if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+ if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -102,13 +102,15 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
@parameterized.product(
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
- ('custom', 'padding'), ('bias', 'causal')],
+ ('custom', 'padding'), ('bias', 'causal'),
+ ('causal', 'sliding_window')],
)
def testDotProductAttentionMask(self, mask_mode):
- if not _is_required_cudnn_version_satisfied():
- raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
+ min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
+ if not _is_required_cudnn_version_satisfied(min_cudnn_version):
+ raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16
B, S, T, N, H = 2, 128, 128, 4, 32
@@ -119,6 +121,7 @@ def testDotProductAttentionMask(self, mask_mode):
grad = random.normal(keys[3], (B, T, N, H), dtype)
bias, mask = None, None
q_seqlen, kv_seqlen = None, None
+ window_size = None
is_causal = 'causal' in mask_mode
if 'padding' in mask_mode:
@@ -130,6 +133,8 @@ def testDotProductAttentionMask(self, mask_mode):
mask = custom_mask[None, None, :, :]
if 'bias' in mask_mode:
bias = random.normal(keys[4], (1, N, T, S), dtype)
+ if 'sliding_window' in mask_mode:
+ window_size = (3, 2) if is_causal else (3, 0)
sdpa = nn.dot_product_attention
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
@@ -141,9 +146,11 @@ def testDotProductAttentionMask(self, mask_mode):
# Convert the kargs to positional args for the jax.vjp.
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+ local_window_size=window_size,
)
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+ local_window_size=window_size,
)
out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 9b8167527b92..6804d91675c1 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -98,13 +98,13 @@ jax_test(
disable_configs = [
"gpu",
"gpu_x32",
- "gpu_a100",
"gpu_p100",
"gpu_p100_x32",
- "gpu_h100",
],
enable_configs = [
+ "gpu_a100",
"gpu_a100_x32",
+ "gpu_h100",
"gpu_h100_x32",
],
shard_count = {
@@ -422,6 +422,20 @@ jax_test(
] + py_deps("hypothesis"),
)
+jax_test(
+ name = "tpu_pallas_async_test",
+ srcs = ["tpu_pallas_async_test.py"],
+ disable_backends = [
+ "cpu",
+ "gpu",
+ ],
+ tags = [
+ ],
+ deps = [
+ "//jax:pallas_tpu",
+ ],
+)
+
jax_test(
name = "tpu_pallas_mesh_test",
srcs = ["tpu_pallas_mesh_test.py"],
diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py
index 59e28db6d9e2..d49b83fe160b 100644
--- a/tests/pallas/indexing_test.py
+++ b/tests/pallas/indexing_test.py
@@ -647,27 +647,37 @@ class IndexerOpsInterpretTest(IndexerOpsTest):
# TODO(ayx): Fix all test cases here
_ADVANCED_INDEXER_TEST_CASES = [
- ((8, 2), lambda arr, a, b, c, d: arr[2]),
- ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
- ((16, 3), lambda arr, a, b, c, d: arr[a, a]),
- ((16, 16), lambda arr, a, b, c, d: arr[::4, ::4]),
+ # integer
+ ((3, 2), lambda arr, a, b, c, d: arr[2]),
+ # slice
+ ((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]),
((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]),
- ((16, 3), lambda arr, a, b, c, d: arr[a, :]),
- # ((16, 3), lambda arr, a, b, c, d: arr[:, a]),
- ((16, 3), lambda arr, a, b, c, d: arr[a, ::4]),
- # ((16, 3), lambda arr, a, b, c, d: arr[::4, a]),
+ ((8, 2), lambda arr, a, b, c, d: arr[1::3, :]),
+ # array
+ ((4, 3), lambda arr, a, b, c, d: arr[a]),
+ ((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]),
+ # integer + 1-D array
+ ((4, 3), lambda arr, a, b, c, d: arr[2, a]),
+ ((4, 3), lambda arr, a, b, c, d: arr[a, 2]),
+ # slice + 1-D array
+ ((4, 3), lambda arr, a, b, c, d: arr[a, :]),
+ # ((4, 3), lambda arr, a, b, c, d: arr[:, a]),
+ ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]),
+ # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]),
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]),
# ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]),
- # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, ::2, ::2]),
- # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, ::2, ::2]),
- # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
- # ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
+ ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
+ ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
# ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]),
- # ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, c]),
- ((8, 6, 4), lambda arr, a, b, c, d: arr[a]),
- ((6, 8, 4), lambda arr, a, b, c, d: arr[c, c]),
- ((6, 8, 4), lambda arr, a, b, c, d: arr[c, ::3]),
- # ((8, 6, 4), lambda arr, a, b, c, d: arr[::3, c]),
+ ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
+ ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]),
+ # slice + array w/ broadcasting
+    ((8, 8, 3, 6), lambda arr, a, b, c, d:
+        arr[b[:, None], ::4, a[None], a[:, None]]),
+ # integer + slice + 1-D array
+ ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]),
+ ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]),
+ # boolean
# ((6, 2), lambda arr, a, b, c, d: arr[d]),
# ((8, 6), lambda arr, a, b, c, d: arr[::4, d]),
]
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index e90033be151d..17ef26c7f9b3 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -19,7 +19,7 @@
import jax
from jax._src import config
from jax._src import test_util as jtu
-import jax._src.pallas.mosaic_gpu.core as plgpu
+import jax._src.pallas.mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@@ -78,6 +78,23 @@ def kernel(x_ref, o_ref):
x = jnp.arange(128 * 2).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
+ def test_add_one_grid_with_scratch(self):
+
+ @functools.partial(
+ pl.pallas_call,
+ out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
+ in_specs=[pl.BlockSpec((128,), lambda *i: i)],
+ out_specs=pl.BlockSpec((128,), lambda *i: i),
+ scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
+ grid=2,
+ )
+ def kernel(x_ref, o_ref, scratch_ref):
+ scratch_ref[...] = x_ref[...] + 1
+ o_ref[...] = scratch_ref[...]
+
+ x = jnp.arange(256).astype(jnp.float32)
+ np.testing.assert_array_equal(kernel(x), x + 1.0)
+
@parameterized.product(num_stages=[1, 2, 3])
def test_add_one_grid_pipelined(self, num_stages):
@@ -98,6 +115,42 @@ def kernel(x_ref, o_ref):
x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
+ def test_add_one_with_async_copy_smem_to_gmem(self):
+ @functools.partial(
+ pl.pallas_call,
+ out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
+ out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+ scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
+ )
+ def kernel(x_ref, o_ref_gmem, scratch_ref):
+ scratch_ref[...] = x_ref[...] + 1
+ plgpu.async_copy_smem_to_gmem(scratch_ref, o_ref_gmem)
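+      # Wait for the SMEM->GMEM copy to complete before the kernel finishes.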
+ plgpu.wait_smem_to_gmem(0)
+
+ x = jnp.arange(128).astype(jnp.float32)
+ np.testing.assert_array_equal(kernel(x), x + 1.0)
+
+ def test_add_one_with_async_copy_gmem_to_smem(self):
+
+ @functools.partial(
+ pl.pallas_call,
+ out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
+ in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+ scratch_shapes=[
+ plgpu.SMEM((128,), jnp.float32),
+ plgpu.Barrier(num_arrivals=1),
+ ],
+ )
+ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
+ plgpu.async_copy_gmem_to_smem(
+ x_ref_gmem, scratch_ref, barrier=barrier_ref
+ )
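+      # The barrier is signaled once the GMEM->SMEM copy has arrived; wait for
+      # it before reading the scratch buffer.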
+ plgpu.wait_barrier(barrier_ref)
+ o_ref[...] = scratch_ref[...] + 1
+
+ x = jnp.arange(128).astype(jnp.float32)
+ np.testing.assert_array_equal(kernel(x), x + 1.0)
+
def test_add_doubled_sum(self):
@functools.partial(
pl.pallas_call,
@@ -109,6 +162,19 @@ def kernel(x_ref, o_ref):
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + x.sum()*2)
+ @parameterized.parameters(False, True)
+ def test_rsqrt(self, approx_math):
+ @functools.partial(
+ pl.pallas_call,
+ out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
+ compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
+ )
+ def kernel(x_ref, o_ref):
+ o_ref[...] = jax.lax.rsqrt(x_ref[...])
+
+ x = jnp.arange(128).astype(jnp.float32)
+ np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))
+
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
def test_layer_norm(self, input_factor):
eps = 1e-5
@@ -245,6 +311,18 @@ def kernel(x_ref, o_ref):
result = kernel(x)
self.assertEqual(result.shape, (4, 2, 64, 64))
+ def test_fori_loop(self):
+ @functools.partial(
+ pl.pallas_call,
+ out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+ )
+ def kernel(x_ref, o_ref):
+ # Equivalent to x_ref[...] + 2 + 3.
+ o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])
+
+ x = jnp.arange(256).astype(jnp.float32)
+ np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0)
+
if __name__ == "__main__":
absltest.main()
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 627ee0e8a227..63c3148e8108 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -31,6 +31,7 @@
from jax import lax
from jax import random
from jax._src import config
+from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
from jax._src import test_util as jtu
@@ -59,6 +60,10 @@
jtu.setup_hypothesis(max_examples=50)
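+# Default int/float dtypes for the current X64 mode: int32/float32 under x32,
+# int64/float64 when jax_enable_x64 is enabled.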
+intx = dtypes.canonicalize_dtype(jnp.int64)
+floatx = dtypes.canonicalize_dtype(jnp.float64)
+
+
def smem_on_tpu():
if jtu.test_device_matches(["tpu"]):
return pltpu.SMEM
@@ -245,8 +250,6 @@ class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
def setUp(self):
- if jax.config.x64_enabled:
- self.skipTest("Only works in 32-bit")
if not self.INTERPRET:
if jtu.device_under_test() == "cpu":
self.skipTest("Only interpret mode supported on CPU")
@@ -263,11 +266,6 @@ def pallas_call(cls, *args, **kwargs):
class OpsTest(PallasBaseTest):
- def setUp(self):
- super().setUp()
- if jax.config.x64_enabled:
- self.skipTest("Only works in 32-bit")
-
@parameterized.named_parameters(
(fn.__name__, fn, dtype) for fn, dtype in [
(lax.pow, jnp.float32),
@@ -340,7 +338,7 @@ def kernel(x_ref, y_ref, o_ref):
result = self.pallas_call(
kernel,
- out_shape=jax.ShapeDtypeStruct([1, 128], jnp.int32),
+ out_shape=jax.ShapeDtypeStruct([1, 128], intx),
in_specs=[
pl.BlockSpec(memory_space=smem_on_tpu()),
pl.BlockSpec(memory_space=smem_on_tpu()),
@@ -435,13 +433,15 @@ def kernel(x_ref, ones_ref, o_ref):
float_value = jnp.where(reduced_as_bool, 1.0, 0.0)
o_ref[0, 0] = float_value[0, 0]
- if input_type == 'all_true':
+ if input_type == "all_true":
x = jnp.ones((8, 128), dtype=jnp.float32)
- elif input_type == 'all_false':
+ elif input_type == "all_false":
x = jnp.zeros((8, 128), dtype=jnp.float32)
- elif input_type == 'one_false':
+ elif input_type == "one_false":
x = jnp.ones((8, 128), dtype=jnp.float32)
x = x.at[0, 0].set(0.0)
+ else:
+ raise ValueError(f"Unknown input type: {input_type}")
ones = jnp.ones_like(x)
result = self.pallas_call(
@@ -451,7 +451,7 @@ def kernel(x_ref, ones_ref, o_ref):
pl.BlockSpec((8, 128), lambda *_: (0, 0)),
],
out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=smem_on_tpu()),
- out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32),
+ out_shape=jax.ShapeDtypeStruct([1, 1], floatx),
grid=(1,),
)(x, ones)
np.testing.assert_array_equal(result[0, 0], float(expected_result))
@@ -473,7 +473,7 @@ def kernel(x_ref, o_ref):
pl.BlockSpec((8, 128), lambda *_: (0, 0)),
],
out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
- out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32),
+ out_shape=jax.ShapeDtypeStruct([1, 1], floatx),
grid=(1,),
)(x)
@@ -702,22 +702,28 @@ def kernel(x_ref, o_ref):
np.testing.assert_array_equal(out, expected)
@parameterized.product(
- dtype=[jnp.float32],
- value=[-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4],
+ dtype=[jnp.float32, jnp.float64],
+ value=[-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4],
)
def test_erf_inv(self, dtype, value):
+ if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64:
+ self.skipTest("float64 is not supported on TPU")
+
@functools.partial(
self.pallas_call,
- # TODO(ayx): add float64 support for `erf_inv`
- out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+ out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = lax.erf_inv(x_ref[...])
- x = jnp.full((8, 128), value, dtype=dtype)
- out = kernel(x)
- expected = lax.erf_inv(x)
- np.testing.assert_array_equal(out, expected)
+ with contextlib.ExitStack() as stack:
+ if jnp.dtype(dtype).itemsize == 8:
+ stack.enter_context(config.enable_x64(True))
+
+ x = jnp.full((8, 128), value, dtype=dtype)
+ out = kernel(x)
+ expected = lax.erf_inv(x)
+ np.testing.assert_array_equal(out, expected)
class OpsInterpretTest(OpsTest):
@@ -746,6 +752,8 @@ class OpsExtraTest(PallasBaseTest):
def setUp(self):
super().setUp()
+ if jax.config.x64_enabled:
+ self.skipTest("Only works in 32-bit")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
# TODO: most tests fail on TPU in non-interpret mode
self.skipTest("On TPU the test works only in interpret mode")
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 1d3316760fe8..5ee30ba3382a 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -33,7 +33,6 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
-from jax._src.lib import version as jaxlib_version
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
@@ -371,17 +370,10 @@ def copy_kernel(x_ref, o_ref):
test_context = contextlib.nullcontext()
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
- if jaxlib_version < (0, 4, 32):
- # TODO(b/356116061): Remove the old rank condition
- if rank < 2:
- test_context = self.assertRaisesRegex(
- ValueError,
- "TPU lowering currently supports only blocks of rank >= 2")
- else:
- if rank < 1:
- test_context = self.assertRaisesRegex(
- ValueError,
- "TPU lowering currently supports only blocks of rank >= 1")
+ if rank < 1:
+ test_context = self.assertRaisesRegex(
+ ValueError,
+ "TPU lowering currently supports only blocks of rank >= 1")
if rank >= 1:
bs0, as0 = block_shape[-1], shape[-1]
diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py
index a34c2b2f2f61..1d57dc164294 100644
--- a/tests/pallas/tpu_ops_test.py
+++ b/tests/pallas/tpu_ops_test.py
@@ -16,15 +16,15 @@
import sys
import unittest
-import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
-
import jax
from jax import lax
-import jax.numpy as jnp
from jax._src import test_util as jtu
+from jax._src.pallas import utils as pallas_utils
from jax.experimental import pallas as pl
+import jax.numpy as jnp
+import numpy as np
if sys.platform != "win32":
from jax.experimental.pallas import tpu as pltpu
@@ -67,28 +67,29 @@ def pallas_call(cls, *args, **kwargs):
class OpsTest(PallasBaseTest):
- @parameterized.product(from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES)
- def test_bitcast(self, from_dtype, to_dtype):
- # TODO(jevinjiang): remove this after 2nd minor large tiling is enabled.
- if (not jtu.is_device_tpu_at_least(version=5)) and (
- from_dtype in (jnp.int8, jnp.int16) or to_dtype in (jnp.int8, jnp.int16)
- ):
- self.skipTest(
- "Not implemented: packing and unpacking int8, int16 are not supported"
- " on < TPUv5"
- )
+ @parameterized.product(
+ from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True]
+ )
+ def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast):
+ if not jtu.is_device_tpu_at_least(version=4):
+ self.skipTest("Run on TPUv4+ to have expected memory layout")
if from_dtype == to_dtype:
self.skipTest("No bitcast needed")
if from_dtype == jnp.bool_ or to_dtype == jnp.bool_:
self.skipTest("Bitcasting with bool is not supported")
def kernel(x_ref, y_ref):
- y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype)
-
- m, n = 32, 256
- shape = (m, n)
- out_shape = (m * from_dtype.dtype.itemsize // to_dtype.dtype.itemsize, n)
- inp = np.arange(np.prod(shape), dtype=from_dtype).reshape(shape)
+ if is_ref_bitcast:
+ y_ref[...] = x_ref.bitcast(to_dtype)[...]
+ else:
+ y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype)
+
+ m, n = 1, 256
+ in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype)
+ out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype)
+ in_shape = (m * in_packing, n)
+ out_shape = (m * out_packing, n)
+ inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape)
out = self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype),
diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py
new file mode 100644
index 000000000000..4f9d591dbea4
--- /dev/null
+++ b/tests/pallas/tpu_pallas_async_test.py
@@ -0,0 +1,759 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test TPU-specific uses of Pallas async APIs."""
+
+import functools
+from typing import Any
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+from jax._src import test_util as jtu
+from jax.experimental import pallas as pl
+from jax.experimental import shard_map
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+import numpy as np
+
+
+jax.config.parse_flags_with_absl()
+P = jax.sharding.PartitionSpec
+partial = functools.partial
+
+Future = Any
+
+
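+# The helpers below split an async copy into a "start" function that issues the
+# DMA and returns the aliased input plus a Future (a tuple of the destination
+# buffer and the DMA semaphore(s)), and a "done" function that waits on the
+# semaphore(s) and returns the destination buffer.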
+def make_async_copy(target_memory_space=None):
+ if target_memory_space is None:
+ target_memory_space = pltpu.ANY
+ @jax.named_call
+ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]:
+
+ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, sem):
+ del aliased_x_ref
+ pltpu.make_async_copy(x_ref, o_ref, sem).start()
+
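+    # Aliasing input 0 to output 0 returns x from the call unchanged, keeping
+    # the source buffer alive until the matching copy_done waits on the DMA
+    # semaphore.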
+ x, out, sem = pl.pallas_call(
+ copy_start_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x
+ target_memory_space(x.shape, x.dtype), # out
+ pltpu.SemaphoreType.DMA(()),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=target_memory_space),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ),
+ input_output_aliases={0: 0},
+ )(x)
+ return x, (out, sem)
+
+ @jax.named_call
+ def copy_done(x: jax.Array, future: Future) -> jax.Array:
+ out, sem = future
+
+ def copy_done_kernel(x_ref, o_ref, sem, aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(x_ref, o_ref, sem).wait()
+
+ out = pl.pallas_call(
+ copy_done_kernel,
+ out_shape=target_memory_space(x.shape, x.dtype), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=target_memory_space),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=target_memory_space),
+ input_output_aliases={1: 0},
+ )(x, out, sem)
+ return out
+
+ return copy_start, copy_done
+
+
+def make_async_slice(index: int):
+
+ def async_slice_start_kernel(x_ref, aliased_x_ref, o_ref, sem):
+ del aliased_x_ref
+ pltpu.make_async_copy(x_ref.at[index], o_ref, sem).start()
+
+ def async_slice_done_kernel(x_ref, o_ref, sem, aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(x_ref.at[index], o_ref, sem).wait()
+
+ @jax.named_call
+ def async_slice_start(x: jax.Array) -> tuple[jax.Array, Future]:
+
+ x, out, sem = pl.pallas_call(
+ async_slice_start_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x
+ jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out
+ pltpu.SemaphoreType.DMA(()),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ),
+ input_output_aliases={0: 0},
+ )(x)
+ return x, (out, sem)
+
+ @jax.named_call
+ def async_slice_done(
+ x: jax.Array, future: Future
+ ) -> tuple[jax.Array, Future]:
+ out, sem = future
+ out = pl.pallas_call(
+ async_slice_done_kernel,
+ out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)),
+ input_output_aliases={1: 0},
+ )(x, out, sem)
+ return out
+
+ return async_slice_start, async_slice_done
+
+
+def make_async_dynamic_slice(index: jax.Array):
+
+ def async_dslice_start_kernel(index_ref, x_ref, aliased_x_ref, o_ref, sem):
+ del aliased_x_ref
+ pltpu.make_async_copy(x_ref.at[index_ref[0]], o_ref, sem).start()
+
+ def async_dslice_done_kernel(x_ref, o_ref, sem, aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(x_ref.at[0], o_ref, sem).wait()
+
+ @jax.named_call
+ def async_dslice_start(x: jax.Array) -> tuple[jax.Array, Future]:
+
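+    # The dynamic index is passed as a scalar-prefetch operand (hence the
+    # index[None] argument below), which makes it available in SMEM for
+    # addressing x_ref inside the kernel.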
+ x, out, sem = pl.pallas_call(
+ async_dslice_start_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x
+ jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out
+ pltpu.SemaphoreType.DMA(()),
+ ),
+ grid_spec=pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=1,
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ),
+ ),
+ input_output_aliases={1: 0},
+ )(index[None], x)
+ return x, (out, sem)
+
+ @jax.named_call
+ def async_dslice_done(
+ x: jax.Array, future: Future
+ ) -> tuple[jax.Array, Future]:
+ out, sem = future
+ out = pl.pallas_call(
+ async_dslice_done_kernel,
+ out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)),
+ input_output_aliases={1: 0},
+ )(x, out, sem)
+ return out
+
+ return async_dslice_start, async_dslice_done
+
+
+class PallasCallAsyncCopyTest(parameterized.TestCase):
+ # TODO(b/368123537): add more tests
+
+ def setUp(self):
+ super().setUp()
+ if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs only guaranteed to work on TPU v4+')
+
+ def test_basic_async_copy(self):
+ @jax.jit
+ def f(x):
+ copy_start, copy_done = make_async_copy()
+ x, fut = copy_start(x)
+ y = copy_done(x, fut)
+ return y
+
+ x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+ def test_multiple_async_copy(self):
+ @jax.jit
+ def f(x):
+ copy_start, copy_done = make_async_copy()
+ x, fut = copy_start(x)
+ x2, fut2 = copy_start(x)
+ y = copy_done(x, fut)
+ y2 = copy_done(x2, fut2)
+ return y, y2
+
+ x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+ y, y2 = f(x)
+ np.testing.assert_array_equal(y, x)
+ np.testing.assert_array_equal(y2, x)
+
+ def test_async_slice(self):
+ @jax.jit
+ def f(x):
+ async_slice_start, async_slice_done = make_async_slice(2)
+ x, fut = async_slice_start(x)
+ y = async_slice_done(x, fut)
+ return y
+
+ x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x[2])
+
+ def test_async_dynamic_slice(self):
+ @jax.jit
+ def f(x, i):
+ async_slice_start, async_slice_done = make_async_dynamic_slice(i)
+ x, fut = async_slice_start(x)
+ y = async_slice_done(x, fut)
+ return y
+
+ x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+ y = f(x, 2)
+ np.testing.assert_array_equal(y, x[2])
+
+ def test_multi_async_dynamic_slice(self):
+ @jax.jit
+ def f(x, i, j):
+ async_slice_start, async_slice_done = make_async_dynamic_slice(i)
+ async_slice_start2, async_slice_done2 = make_async_dynamic_slice(j)
+ x, fut = async_slice_start(x)
+ x2, fut2 = async_slice_start2(x)
+ y = async_slice_done(x, fut)
+ y2 = async_slice_done2(x2, fut2)
+ return y, y2
+
+ x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32)
+ y, y2 = f(x, 2, 3)
+ np.testing.assert_array_equal(y, x[2])
+ np.testing.assert_array_equal(y2, x[3])
+
+ def test_basic_async_copy_into_vmem(self):
+ @jax.jit
+ def f(x):
+ copy_start, copy_done = make_async_copy(pltpu.VMEM)
+ x, fut = copy_start(x)
+ y = copy_done(x, fut)
+ return y
+
+ x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+ def test_multiple_async_copy_into_vmem(self):
+ @jax.jit
+ def f(x):
+ copy_start, copy_done = make_async_copy(pltpu.VMEM)
+ x1, fut = copy_start(x)
+ x2, fut2 = copy_start(x)
+ y = copy_done(x1, fut)
+ y2 = copy_done(x2, fut2)
+ return y, y2
+
+ x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32)
+ y, y2 = f(x)
+ np.testing.assert_array_equal(y, x)
+ np.testing.assert_array_equal(y2, x)
+
+ def test_copy_in_a_loop(self):
+
+ @jax.jit
+ def f(x):
+ def body(_, carry):
+ x = carry
+ copy_start, copy_done = make_async_copy()
+ x, fut = copy_start(x)
+ y = copy_done(x, fut)
+ return y
+ x = jax.lax.fori_loop(0, x.shape[0], body, x)
+ return x
+
+ x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+ def test_staggered_copy_in_a_loop(self):
+
+ @jax.jit
+ def f(x):
+ copy_start, copy_done = make_async_copy()
+ x, fut = copy_start(x)
+ def body(_, carry):
+ x, fut = carry
+ y = copy_done(x, fut)
+ y, fut = copy_start(y)
+ return y, fut
+      # We *must* use unroll >= 2 here because of aliasing constraints. XLA
+      # will introduce copies of the active buffer with unroll=1.
+ y, fut = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut), unroll=2)
+ x = copy_done(y, fut)
+ return x
+
+ x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+ def test_full_copy_in_a_loop(self):
+
+ @jax.jit
+ def f(x):
+ y = jnp.zeros_like(x)
+ def body(i, carry):
+ x, ys = carry
+ copy_start, copy_done = make_async_dynamic_slice(i)
+ x, fut = copy_start(x)
+ y = copy_done(x, fut)
+ ys = ys.at[i].set(y)
+ return x, ys
+ _, y = jax.lax.fori_loop(0, x.shape[0], body, (x, y))
+ return y
+
+ x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+ def test_staggered_full_copy_in_a_loop(self):
+
+ @jax.jit
+ def f(x):
+ y = jnp.zeros_like(x)
+ copy_start, _ = make_async_dynamic_slice(jnp.array(0))
+ x, fut = copy_start(x)
+ def body(i, carry):
+ x, fut, ys = carry
+ _, copy_done = make_async_dynamic_slice(i)
+ y = copy_done(x, fut)
+ copy_start, _ = make_async_dynamic_slice(i + 1)
+ ys = ys.at[i].set(y)
+ x, fut = copy_start(x)
+ return x, fut, ys
+ # We can use unroll=1 here because we have the ys.at[i].set(y) in the
+ # middle
+ x, fut, ys = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut, y),
+ unroll=1)
+ _, copy_done = make_async_dynamic_slice(x.shape[0] - 1)
+ y = copy_done(x, fut)
+ ys = ys.at[x.shape[0] - 1].set(y)
+ return ys
+
+ x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32)
+ y = f(x)
+ np.testing.assert_array_equal(y, x)
+
+
+def make_async_remote_copy(axis_name: str, direction: str = 'right',
+ target_memory_space=None):
+ if target_memory_space is None:
+ target_memory_space = pltpu.ANY
+ @jax.named_call
+ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]:
+
+ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem):
+ del aliased_x_ref
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) + 1, axis_size
+ )
+ if direction == 'right':
+ src_neighbor = left_neighbor
+ dst_neighbor = right_neighbor
+ else:
+ src_neighbor = right_neighbor
+ dst_neighbor = left_neighbor
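+      # Handshake with the source neighbor on the barrier semaphore so both
+      # devices have entered the kernel before the remote DMA starts.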
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0)
+ pltpu.semaphore_wait(barrier_sem, 1)
+ pltpu.make_async_remote_copy(
+ x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor,
+ ).start()
+
+ x, out, send_sem, recv_sem = pl.pallas_call(
+ copy_start_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x
+ target_memory_space(x.shape, x.dtype), # out
+ pltpu.SemaphoreType.DMA(()), # send_sem
+ pltpu.SemaphoreType.DMA(()), # recv_sem
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=target_memory_space),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ),
+ input_output_aliases={0: 0},
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+ )(x)
+ return x, (out, send_sem, recv_sem)
+
+ @jax.named_call
+ def send_done(x: jax.Array, future: Future) -> jax.Array:
+ _, send_sem, _ = future
+
+ def send_done_kernel(x_ref, send_sem, aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(x_ref, x_ref, send_sem).wait()
+
+ x = pl.pallas_call(
+ send_done_kernel,
+ out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={0: 0},
+ )(x, send_sem)
+ return x
+
+ @jax.named_call
+ def recv_done(x: jax.Array, future: Future) -> jax.Array:
+ out, _, recv_sem = future
+
+    def recv_done_kernel(x_ref, o_ref, recv_sem, aliased_o_ref):
+      del aliased_o_ref
+      pltpu.make_async_copy(x_ref, o_ref, recv_sem).wait()
+
+    out = pl.pallas_call(
+        recv_done_kernel,
+ out_shape=target_memory_space(x.shape, x.dtype), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=target_memory_space),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=target_memory_space),
+ input_output_aliases={1: 0},
+ )(x, out, recv_sem)
+ return out
+
+ return copy_start, send_done, recv_done
+
+
+def make_bidi_collective_permute(axis_name: str):
+ @jax.named_call
+ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]:
+
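+    # The first half of the rows is sent to the right neighbor and the second
+    # half to the left neighbor, each half with its own (send, recv) semaphore
+    # pair.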
+ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems):
+ del aliased_x_ref
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) + 1, axis_size
+ )
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0)
+ pltpu.semaphore_signal(
+ barrier_sem, device_id=right_neighbor, core_index=0
+ )
+ pltpu.semaphore_wait(barrier_sem, 2)
+ assert x.shape[0] % 2 == 0, x.shape
+ pltpu.make_async_remote_copy(
+ x_ref.at[pl.ds(0, x.shape[0] // 2)],
+ o_ref.at[pl.ds(0, x.shape[0] // 2)],
+ right_sems[0],
+ right_sems[1],
+ device_id=right_neighbor,
+ ).start()
+ pltpu.make_async_remote_copy(
+ x_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)],
+ o_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)],
+ left_sems[0],
+ left_sems[1],
+ device_id=left_neighbor,
+ ).start()
+
+ x, out, left_sems, right_sems = pl.pallas_call(
+ copy_start_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x
+ pltpu.ANY(x.shape, x.dtype), # out
+ (pltpu.SemaphoreType.DMA(()),) * 2, # left_sems
+ (pltpu.SemaphoreType.DMA(()),) * 2, # right_sems
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2,
+ (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2,
+ ),
+ input_output_aliases={0: 0},
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+ )(x)
+ return x, (out, left_sems, right_sems)
+
+ @jax.named_call
+ def send_done(x: jax.Array, future: Future) -> jax.Array:
+ _, (send_left_sem, _), (send_right_sem, _) = future
+
+ def send_done_kernel(x_ref, send_left_sem, send_right_sem, aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(
+ x_ref.at[x_ref.shape[0] // 2 :],
+ x_ref.at[x_ref.shape[0] // 2 :],
+ send_left_sem,
+ ).wait()
+ pltpu.make_async_copy(
+ x_ref.at[x_ref.shape[0] // 2 :],
+ x_ref.at[x_ref.shape[0] // 2 :],
+ send_right_sem,
+ ).wait()
+
+ x = pl.pallas_call(
+ send_done_kernel,
+ out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={0: 0},
+ )(x, send_left_sem, send_right_sem)
+ return x
+
+ @jax.named_call
+ def recv_done(x: jax.Array, future: Future) -> jax.Array:
+ out, (_, recv_left_sem), (_, recv_right_sem) = future
+
+ def recv_done_kernel(o_ref, x_ref, recv_left_sem, recv_right_sem,
+ aliased_o_ref):
+ del aliased_o_ref
+ pltpu.make_async_copy(
+ x_ref.at[o_ref.shape[0] // 2 :],
+ o_ref.at[o_ref.shape[0] // 2 :],
+ recv_left_sem,
+ ).wait()
+ pltpu.make_async_copy(
+ x_ref.at[o_ref.shape[0] // 2 :],
+ o_ref.at[o_ref.shape[0] // 2 :],
+ recv_right_sem,
+ ).wait()
+
+ out = pl.pallas_call(
+ recv_done_kernel,
+ out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={0: 0},
+ )(out, x, recv_left_sem, recv_right_sem)
+ return out
+ return copy_start, send_done, recv_done
+
+
+class PallasCallRemoteAsyncCopyTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs only guaranteed to work on TPU v4+')
+ if jax.device_count() < 2:
+      self.skipTest('Test requires at least 2 devices')
+
+ def test_basic_remote_copy(self):
+
+ mesh = jax.make_mesh((jax.device_count(),), ('x',))
+
+ @jax.jit
+ @partial(
+ shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'),
+ check_rep=False,
+ )
+ def f(x):
+ copy_start, send_done, recv_done = make_async_remote_copy('x')
+ x, fut = copy_start(x)
+ x = send_done(x, fut)
+ y = recv_done(x, fut)
+ return y
+
+ x = jax.random.normal(
+ jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32
+ )
+ y = f(x)
+ expected = jnp.roll(x, shift=1, axis=0)
+ np.testing.assert_array_equal(y, expected)
+
+ def test_multi_remote_copy(self):
+
+ mesh = jax.make_mesh((jax.device_count(),), ('x',))
+
+ @jax.jit
+ @partial(
+ shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'),
+ check_rep=False,
+ )
+ def f(x):
+ copy_start, send_done, recv_done = make_async_remote_copy(
+ 'x', direction='right'
+ )
+ copy_start2, send_done2, recv_done2 = make_async_remote_copy(
+ 'x', direction='left'
+ )
+ x, fut = copy_start(x)
+ x, fut2 = copy_start2(x)
+ x = send_done(x, fut)
+ x = send_done2(x, fut2)
+ y = recv_done(x, fut)
+ y2 = recv_done2(x, fut2)
+ return y, y2
+
+ x = jax.random.normal(
+ jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32
+ )
+ y, y2 = f(x)
+ y_expected = jnp.roll(x, shift=1, axis=0)
+ y2_expected = jnp.roll(x, shift=-1, axis=0)
+ np.testing.assert_array_equal(y, y_expected)
+ np.testing.assert_array_equal(y2, y2_expected)
+
+ def test_basic_collective_permute_loop(self):
+
+ mesh = jax.make_mesh((jax.device_count(),), ('x',))
+
+ @jax.jit
+ @partial(
+ shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'),
+ check_rep=False,
+ )
+ def f(x):
+ copy_start, send_done, recv_done = make_async_remote_copy('x')
+ def body(_, x):
+ x, fut = copy_start(x)
+ x = send_done(x, fut)
+ y = recv_done(x, fut)
+ return y
+ # Send all the way around except for one step
+ return jax.lax.fori_loop(0, jax.device_count() - 1, body, x)
+ x = jax.random.normal(
+ jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32
+ )
+ y = f(x)
+ expected = jnp.roll(x, shift=-1, axis=0)
+ np.testing.assert_array_equal(y, expected)
+
+ def test_staggered_collective_permute_loop(self):
+
+ mesh = jax.make_mesh((jax.device_count(),), ('x',))
+
+ @jax.jit
+ @partial(
+ shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'),
+ check_rep=False,
+ )
+ def f(x):
+ assert x.shape[0] == 1
+ copy_start, send_done, recv_done = make_async_remote_copy('x')
+ x, fut = copy_start(x)
+ def body(_, carry):
+ x, fut = carry
+ x = send_done(x, fut)
+ y = recv_done(x, fut)
+ y, fut = copy_start(y)
+ return y, fut
+ # Send all the way around except for one step
+ x, fut = jax.lax.fori_loop(0, jax.device_count() - 2, body, (x, fut),
+ unroll=2)
+ x = send_done(x, fut)
+ y = recv_done(x, fut)
+ return y
+
+ n_devices = jax.device_count()
+ x = jax.random.normal(
+ jax.random.key(0), (n_devices, 8, 128), dtype=jnp.float32
+ )
+ y = f(x)
+ expected = jnp.roll(x, shift=-1, axis=0)
+ np.testing.assert_array_equal(y, expected)
+
+ def test_bidi_collective_permute_loop(self):
+ mesh = jax.make_mesh((jax.device_count(),), ('x',))
+
+ @jax.jit
+ @partial(
+ shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'),
+ check_rep=False,
+ )
+ def f(x):
+ assert x.shape[0] == 1
+ x = x[0]
+ copy_start, send_done, recv_done = make_bidi_collective_permute('x')
+ def body(_, x):
+ x, fut = copy_start(x)
+ x = send_done(x, fut)
+ y = recv_done(x, fut)
+ return y
+ # Send all the way around except for one step
+ y = jax.lax.fori_loop(0, jax.device_count() - 1, body, x)
+ return y[None]
+ x = jax.random.normal(
+ jax.random.key(0), (jax.device_count(), 16, 128), dtype=jnp.float32
+ )
+ y = f(x)
+ expected = jnp.concatenate([
+ jnp.roll(x[:, :8], axis=0, shift=-1),
+ jnp.roll(x[:, 8:], axis=0, shift=1),
+ ], axis=1)
+ np.testing.assert_array_equal(y, expected)
+
+
+if __name__ == "__main__":
+ absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index 83ca6a5787cc..e100a5a39e49 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -31,6 +31,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
+from jax._src.state import utils as state_utils
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
@@ -1926,6 +1927,100 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
np.testing.assert_array_equal(out, expected)
+class PallasCallRefTransformTest(PallasBaseTest):
+
+ @parameterized.product(slice_first=[True, False])
+ def test_dma_bitcasted_ref(self, slice_first):
+ if not jtu.is_device_tpu_at_least(4):
+ self.skipTest('DMAs not supported on TPU generations <= 3')
+
+ def kernel(x_hbm_ref, y_hbm_ref):
+ def body(sem):
+ ref = (
+ x_hbm_ref.at[:8, :, :128].bitcast(jnp.int16)
+ if slice_first
+ else x_hbm_ref.bitcast(jnp.int16).at[:8, :, :128]
+ )
+ pltpu.async_copy(ref, y_hbm_ref.at[...], sem).wait()
+
+ pl.run_scoped(body, pltpu.SemaphoreType.DMA)
+
+ x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 1, 256))
+ y = self.pallas_call(
+ kernel,
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16),
+ )(x)
+ expected = (
+ state_utils.bitcast(x[:8, :, :128], jnp.int16)
+ if slice_first
+ else state_utils.bitcast(x, jnp.int16)[:8, :, :128]
+ )
+ np.testing.assert_array_equal(y, expected)
+
+ @parameterized.product(slice_first=[True, False])
+ def test_load_bitcasted_ref(self, slice_first: bool):
+ def kernel(x_ref, y_ref):
+ ref = (
+ x_ref.at[:8, :128].bitcast(jnp.int16)
+ if slice_first
+ else x_ref.bitcast(jnp.int16).at[:16, :128]
+ )
+ y_ref[...] = ref[...]
+
+ x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256))
+ y = self.pallas_call(
+ kernel,
+ out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int16),
+ )(x)
+ expected = (
+ state_utils.bitcast(x[:8, :128], jnp.int16)
+ if slice_first
+ else state_utils.bitcast(x, jnp.int16)[:16, :128]
+ )
+ np.testing.assert_array_equal(y, expected)
+
+ @parameterized.product(slice_first=[True, False])
+ def test_store_bitcasted_ref(self, slice_first):
+ def kernel(x_ref, y_ref):
+ ref = (
+ y_ref.at[:8, :128].bitcast(jnp.bfloat16)
+ if slice_first
+ else y_ref.bitcast(jnp.bfloat16).at[:16, :128]
+ )
+ ref[...] = x_ref[...]
+
+ x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128))
+ y = self.pallas_call(
+ kernel,
+ out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+ )(x)
+ expected = state_utils.bitcast(x, jnp.int32)
+ np.testing.assert_array_equal(y[:8, :128], expected)
+
+ def test_multiple_ref_transforms(self):
+
+ def kernel(x_ref, y_ref):
+ ref = (
+ x_ref.at[:8, :256]
+ .bitcast(jnp.int16)
+ .bitcast(jnp.float16)
+ .at[:, :128]
+ .bitcast(jnp.int32)
+ )
+ y_ref[...] = ref[...]
+
+ x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256))
+ y = self.pallas_call(
+ kernel,
+ out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
+ )(x)
+ np.testing.assert_array_equal(y, x[:8, :128])
+
+
class PallasCallPrintTest(PallasBaseTest):
def test_debug_print(self):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index dbb867ab9a39..57106948f7d3 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -56,8 +56,8 @@
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension
+from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2
config.parse_flags_with_absl()
@@ -653,18 +653,16 @@ def testAutodiff(self, mesh, resources):
@jtu.with_mesh([('x', 2), ('y', 1)])
def testAutodiffCache(self):
- f = pjit(
- lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
- )
+ f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None)
x = jnp.arange(16, dtype=jnp.float32)
- jax.grad(f)(x) # Warm up the cache.
- before = pjit_lib._pjit_lower_cached.cache_info()
- jax.grad(f)(x)
- after = pjit_lib._pjit_lower_cached.cache_info()
- # One hit for the forward pass, one hit for backward.
- self.assertEqual(after.hits, before.hits + 2)
- self.assertEqual(after.misses, before.misses)
+ jax.grad(f)(x) # Warm up the cache.
+ with jtu.count_pjit_cpp_cache_miss() as count:
+ jax.grad(f)(x)
+ if xla_extension_version >= 286:
+ self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
+ else:
+ self.assertEqual(count[0], 2)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
@@ -4433,8 +4431,6 @@ def f(x):
"Compiled object called with input sharding.*does not match"):
compiled(cpu_arr)
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_different_devices_wsc_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
@@ -4463,8 +4459,6 @@ def f(x):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@@ -4484,8 +4478,6 @@ def f(x):
self.assertArraysEqual(out_eager, np_inp * 2)
self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_wsc_sds_abstract_mesh(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P())
@@ -4499,8 +4491,6 @@ def f(x):
sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
f.eval_shape(sds) # doesn't crash
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_wsc_vmap_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
@@ -4517,8 +4507,6 @@ def f(x):
out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
@@ -4542,6 +4530,20 @@ def test_wsc_abstract_mesh_errors(self):
' match the mesh shape of the target sharding.*'):
with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))
+ @unittest.skipIf(xla_extension_version < 286,
+ "Requires xla_extension_version >= 286")
+ def test_global_jit_cpp_cache_hit_out_shardings(self):
+ mesh = jtu.create_mesh((2,), 'x')
+ s = NamedSharding(mesh, P('x'))
+
+ def f(x):
+ return x * 2
+
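+    # The first call misses the C++ pjit cache; the second call with an
+    # identical out_shardings should hit it, so exactly one miss is expected.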
+ with jtu.count_pjit_cpp_cache_miss() as count:
+ jax.jit(f, out_shardings=s)(np.arange(8))
+ jax.jit(f, out_shardings=s)(np.arange(8))
+ self.assertEqual(count[0], 1)
+
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -5206,11 +5208,6 @@ def test_get_partition_spec(self):
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
P('x', 'y'))
- out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
- P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
- self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
- P('x', 'y'))
-
def test_mesh_with_list_devices(self):
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
self.assertIsInstance(mesh.devices, np.ndarray)
diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py
index c02653dd171b..540136b33870 100644
--- a/tests/scipy_spatial_test.py
+++ b/tests/scipy_spatial_test.py
@@ -132,6 +132,20 @@ def testRotationAsQuatCanonical(self, shape, dtype):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
+ @jtu.sample_product(
+ dtype=float_dtypes,
+ shape=[(4,), (num_samples, 4)],
+ )
+ def testRotationAsQuatScalarFirst(self, shape, dtype):
+ if scipy_version < (1, 14, 0):
+ self.skipTest("Scipy 1.14.0 added the `scalar_first` arg.")
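+    # scalar_first=True returns quaternions in (w, x, y, z) order rather than
+    # the default (x, y, z, w).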
+ rng = jtu.rand_default(self.rng())
+ args_maker = lambda: (rng(shape, dtype),)
+ jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(scalar_first=True)
+ np_fn = lambda q: osp_Rotation.from_quat(q).as_quat(scalar_first=True).astype(dtype)
+ self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
+ self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
+
@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples, 4)],
diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py
index 323d44b542d6..27199c874332 100644
--- a/tests/shape_poly_test.py
+++ b/tests/shape_poly_test.py
@@ -2843,11 +2843,6 @@ def test_vmap_error(self):
((2, 3, 8, 4), "b1, b2, ..."),
((2, 3, 4, 5), "b1, b2, m, n"),
]
- # TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version.
- # jaxlib versions before 0.4.32 require a static shape for the non-batch
- # dimensions because these are used for computing the "permuation_size"
- # which is passed to lu_pivots_to_permutation.
- if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n")
],
[
# The random primitive tests, with threefry (both partitionable and
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index e9c23b3e5f0d..3d9b567e2ef4 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -45,7 +45,6 @@
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp
-from jax._src.lib import xla_extension_version
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map
@@ -777,8 +776,6 @@ def with_capture(y_slice):
# is over an axis of size 2. This is a problem at the moment.
jax.make_jaxpr(mapped)(x, y).jaxpr
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_shard_map_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@@ -803,8 +800,6 @@ def f(x):
self.assertArraysEqual(out2, np_inp)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_different_devices_shmap_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
@@ -835,8 +830,6 @@ def f(x):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
- @unittest.skipIf(xla_extension_version < 281,
- 'Requires xla_extension_version >= 281')
def test_shmap_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
@@ -2164,6 +2157,19 @@ def f(x, y):
with config.disable_vmap_shmap_error():
_ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)
+ def test_in_spec_none_hashability(self):
+ mesh = jtu.create_mesh((2,), ('i',))
+
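+    # A is intentionally unhashable: an argument whose in_spec is None should
+    # not need to be hashed, so the call below must not crash.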
+ class A:
+ def __hash__(self):
+ raise Exception
+
+ @partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=())
+ def f(a):
+ return ()
+
+ f(A()) # don't crash
+
class FunSpec(NamedTuple):
name: str
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 8f4accca508c..72bad324e0f0 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
-XLA_COMMIT = "720b2c53346660e95abbed7cf3309a8b85e979f9"
-XLA_SHA256 = "a93bb0414c33025f6cb775c374d5695c84055f2bd88d6ea826d51d99612baaef"
+XLA_COMMIT = "a0cb79873742367204ad1386e9ca4fd815b3f860"
+XLA_SHA256 = "bcedc70cf3cdcc94159313365b15eb49e25e0d8a9d4713c290ead5a507d2b366"
def repo():
tf_http_archive(