Skip to content

Commit

Permalink
Merge branch 'main' into jax-docs-advanced-tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Sep 19, 2024
2 parents 65ce1bf + 815dc3b commit 038d7ad
Show file tree
Hide file tree
Showing 160 changed files with 8,092 additions and 2,505 deletions.
4 changes: 4 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base
build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base
build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base
build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
Expand All @@ -237,6 +239,8 @@ build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base
build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base
build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base
build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
Expand Down
29 changes: 4 additions & 25 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
lint_and_typecheck:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
Expand Down Expand Up @@ -58,11 +57,6 @@ jobs:
prng-upgrade: 0
num_generated_cases: 1
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
Expand Down Expand Up @@ -110,11 +104,6 @@ jobs:
matrix:
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
Expand Down Expand Up @@ -152,11 +141,6 @@ jobs:
matrix:
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
Expand Down Expand Up @@ -193,11 +177,6 @@ jobs:
enable-x64: 0
num_generated_cases: 10
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ on:
- '**workflows/jax-array-api.yml'
- '**experimental/array_api/**'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
build:

Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/metal_plugin_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ on:
paths:
- '**workflows/metal_plugin_ci.yml'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
jax-metal-plugin-test:

Expand Down
13 changes: 6 additions & 7 deletions .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ name: Wheel build - Windows CPU x86_64
on:
workflow_dispatch: # allows triggering the workflow run manually

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

env:
DISTUTILS_USE_SDK: 1
MSSdk: 1
Expand All @@ -13,16 +17,11 @@ jobs:
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.10', '3.11', '3.12']
pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2']
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
runs-on: ${{ matrix.os }}

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}

- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade

Expand Down Expand Up @@ -58,7 +57,7 @@ jobs:
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
python -m pip install -e ${{ github.workspace }}
python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
8 changes: 4 additions & 4 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ on:
pull_request:
types: [ labeled ] # allow force-windows-run label

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

env:
DISTUTILS_USE_SDK: 1
MSSdk: 1
Expand All @@ -23,10 +27,6 @@ jobs:
runs-on: ${{ matrix.os }}

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}

- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
Expand Down
62 changes: 60 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,63 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.32
## jax 0.4.34

* New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not yet
supported.

* Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
arguments with `ndim != 1` are now deprecated, and in the future will result
in an error.

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
in 0.4.30 JAX release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.


## jax 0.4.33 (September 16, 2024)

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.

This release fixes an inaccurate result for F64 tanh on CPU (#23590).

## jax 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
C++ and CUDA code from JAX.

* Changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
* array[0] on a pmap result now introduces a reshape (use array[0:1]
instead).
* The per-shard shape (accessable via jax_array.addressable_shards or
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
Expand Down Expand Up @@ -65,9 +114,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.

## jaxlib 0.4.32
## jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
should compile faster and leverage parallelism better. If you experience
any problems due to this change, you can temporarily enable the old CPU
backend by setting the environment variable
`XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
please file a JAX bug with instructions to reproduce.
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
Expand Down
7 changes: 4 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ install_deps()
load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter")
custom_python_interpreter(
name = "python_dev",
urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"],
strip_prefix = "Python-{version}",
version = "3.13.0a6",
urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"],
strip_prefix = "Python-{version_variant}",
version = "3.13.0",
version_variant = "3.13.0rc2",
)

load("@xla//:workspace4.bzl", "xla_workspace4")
Expand Down
7 changes: 5 additions & 2 deletions build/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ matplotlib; python_version>="3.11"
#
# build deps
#
numpy~=2.0.0
numpy~=2.0.0; python_version<="3.12"
numpy~=2.1.0; python_version>="3.13"

#
# runtime deps
#
scipy~=1.13.1
scipy>=1.13.1

ml_dtypes>=0.4.0
opt_einsum
zstandard
etils[epath]
# TODO(ybaturina): remove setuptools version
setuptools<71.0.0
Loading

0 comments on commit 038d7ad

Please sign in to comment.