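# NOTE: the following lines are web-page residue captured together with this
# workflow file. They are kept as comments so the YAML below still parses.
# Skip to content
#
# [Pallas] Add a cost estimator for Pallas/JAX functions. #65849
#
# [Pallas] Add a cost estimator for Pallas/JAX functions.
#
# [Pallas] Add a cost estimator for Pallas/JAX functions. #65849
#
# Workflow file for this run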

name: CI
# Python versions exercised by the jobs in this workflow:
# - 3.10 : build matrix (x64, custom PRNG and partitionable threefry enabled)
# - 3.10 : documentation snippet tests and documentation rendering
# - 3.10 : jax2tf_test
# - 3.11 : lint_and_typecheck (pre-commit)
# - 3.12 : FFI example
# - 3.13 : build matrix (x64, custom PRNG and partitionable threefry disabled)
#
# `on` is the trigger key for GitHub's loader; a generic YAML 1.1 parser
# would read it as the boolean `true`, so linters may flag it (truthy rule).
on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch.
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
# Token scopes granted to every job in this workflow:
#   actions: write  -> to cancel previous workflows
#   contents: read  -> to fetch code
permissions:
  actions: write
  contents: read
# Runs are grouped by workflow name plus the PR head branch (pull_request
# events) or the pushed ref (push events, where head_ref is empty). A newer
# run in the same group cancels the one still in progress.
concurrency:
  group: "${{ format('{0}-{1}', github.workflow, github.head_ref || github.ref) }}"
  cancel-in-progress: true
jobs:
  # Runs every pre-commit hook over the whole repository, caching the
  # hook environments between runs.
  lint_and_typecheck:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set up Python 3.11
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          # Quoted so YAML keeps the version a string (unquoted 3.10 -> 3.1).
          python-version: "3.11"
      - run: python -m pip install pre-commit
      # Pinned to the same actions/cache release as the other jobs here.
      - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ~/.cache/pre-commit
          # Invalidate when the interpreter, hook config or setup.py changes.
          key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
      - run: pre-commit run --show-diff-on-failure --color=always --all-files
# Main test job: runs tests/ and examples/ inside an Ubuntu 20.04 container
# for the oldest and newest Python versions in the matrix.
  build:
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64 }})"
    runs-on: linux-x86-n2-32
    container:
      image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
    timeout-minutes: 60
    strategy:
      matrix:
        # Test the oldest and newest supported Python versions here.
        include:
          - name-prefix: "with 3.10"
            python-version: "3.10"
            enable-x64: 1
            # prng-upgrade drives both JAX_ENABLE_CUSTOM_PRNG and
            # JAX_THREEFRY_PARTITIONABLE in the "Run tests" step.
            prng-upgrade: 1
            num_generated_cases: 1
          - name-prefix: "with 3.13"
            python-version: "3.13"
            enable-x64: 0
            prng-upgrade: 0
            num_generated_cases: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      # Installs libssl-dev into the container. apt-get is used instead of
      # apt because apt's CLI is not guaranteed stable for scripted use, and
      # the noninteractive frontend keeps debconf from waiting on a prompt.
      - name: Image Setup
        env:
          DEBIAN_FRONTEND: noninteractive
        run: |
          apt-get update
          apt-get install -y libssl-dev
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
      - name: pip cache
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
      - name: Install dependencies
        run: |
          # Quote the extras spec so the shell never glob-expands `[...]`.
          pip install '.[minimum-jaxlib]' -r build/test-requirements.txt
      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
          JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
          JAX_ENABLE_CHECKS: true
          JAX_SKIP_SLOW_TESTS: true
          PY_COLORS: 1
        run: |
          pip install -e .
          # Log the effective configuration to make failures reproducible.
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
          echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n auto --tb=short --maxfail=20 tests examples
# Doctest job: executes the code snippets in the docs/ sources and in the
# docstrings of the jax package.
  documentation:
    name: Documentation - test code snippets
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        python-version:
          - "3.10"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      # Upgrade pip/wheel and expose pip's cache directory to later steps.
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
      - name: pip cache
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
      - name: Install dependencies
        run: pip install -r docs/requirements.txt
      - name: Test documentation
        env:
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
          JAX_TRACEBACK_FILTERING: "off"
          JAX_ARRAY: "1"
          PY_COLORS: "1"
        run: |
          # Pass 1: doctests embedded in the Markdown / reST sources under docs/.
          pytest -n auto --tb=short \
            --doctest-glob='*.md' --doctest-glob='*.rst' \
            docs \
            --doctest-continue-on-failure \
            --ignore=docs/multi_process.md \
            --ignore=docs/jax.experimental.array_api.rst
          # Pass 2: doctests in jax module docstrings, skipping the excluded paths.
          pytest -n auto --tb=short --doctest-modules jax \
            --ignore=jax/config.py \
            --ignore=jax/experimental/jax2tf \
            --ignore=jax/_src/lib/mlir \
            --ignore=jax/_src/lib/triton.py \
            --ignore=jax/_src/lib/mosaic_gpu.py \
            --ignore=jax/interpreters/mlir.py \
            --ignore=jax/experimental/array_serialization \
            --ignore=jax/collect_profile.py \
            --ignore=jax/_src/tpu_custom_call.py \
            --ignore=jax/experimental/mosaic \
            --ignore=jax/experimental/pallas \
            --ignore=jax/_src/pallas \
            --ignore=jax/experimental/array_api \
            --ignore=jax/lib/xla_extension.py
# Render job: builds the HTML docs without executing notebooks; -W turns
# Sphinx warnings into a failing build, --keep-going reports all of them.
  documentation_render:
    name: Documentation - render documentation
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        python-version:
          - "3.10"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      # Upgrade pip/wheel and expose pip's cache directory to later steps.
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
      # Shares its cache key with the doctest job: both install the same
      # docs/requirements.txt.
      - name: pip cache
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
      - name: Install dependencies
        run: pip install -r docs/requirements.txt
      - name: Render documentation
        run: >-
          sphinx-build --color -W --keep-going -b html
          -D nb_execution_mode=off
          docs docs/build/html
# jax2tf job: installs TensorFlow next to the minimum jaxlib and runs the
# jax2tf conversion test module.
  jax2tf_test:
    name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64 }})"
    runs-on: ${{ matrix.os }}
    timeout-minutes: 30
    strategy:
      matrix:
        # Test the oldest supported Python version here.
        include:
          - python-version: "3.10"
            os: ubuntu-latest
            enable-x64: 0
            num_generated_cases: 10
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
      - name: pip cache
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
      - name: Install dependencies
        run: |
          # Quote the extras spec so the shell never glob-expands `[...]`.
          pip install '.[minimum-jaxlib]' tensorflow -r build/test-requirements.txt
      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          JAX_ENABLE_CHECKS: true
          JAX_SKIP_SLOW_TESTS: true
          PY_COLORS: 1
        run: |
          pip install -e .
          # Log the effective configuration to make failures reproducible.
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
# FFI example job: builds examples/ffi with GCC and CUDA support inside a
# tensorflow/build container on a GPU runner, then runs its tests on CPU
# and on GPU.
  ffi:
    name: FFI example
    runs-on: linux-x86-g2-16-l4-1gpu
    container:
      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set up Python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          # Quoted so YAML keeps the version a string, matching the other jobs.
          python-version: "3.12"
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
      - name: pip cache
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
      # Requirement specs containing `[...]` are quoted so the shell never
      # glob-expands them against files in the working directory.
      - name: Install JAX
        run: pip install '.[cuda12]'
      - name: Build and install example project
        run: python -m pip install -v './examples/ffi[test]'
        env:
          # We test building using GCC instead of clang. All other JAX builds use
          # clang, but it is useful to make sure that FFI users can compile using
          # a different toolchain. The compiler is set explicitly rather than
          # relying on whichever compiler is the default inside this job's
          # container image.
          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
      - name: Run CPU tests
        run: python -m pytest examples/ffi/tests
        env:
          JAX_PLATFORM_NAME: cpu
      - name: Run GPU tests
        run: python -m pytest examples/ffi/tests