# New TPU jobs to use updated runners. #5

# Cloud TPU CI
name: Cloud TPU Presubmit

# Intended to run for pull requests labeled "optional_ci_tpu" or via manual
# workflow dispatch. The trigger below fires on every "labeled" and
# "synchronize" event for PRs targeting main; the label check itself lives in
# the job-level `if:`, which is currently commented out.
on:
  pull_request:
    branches:
      - main
    types:
      - labeled
      - synchronize
  workflow_dispatch:

# Cancel any previous iterations if a new commit is pushed: runs are grouped
# per workflow and per ref, and a newer run cancels the in-progress one.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

# Only read access to the repository contents is granted to this workflow.
permissions:
  contents: read
jobs:
  # Builds JAX from source inside the container and runs the pytest suite
  # (single-accelerator tests in parallel, then multi-accelerator tests).
  cloud-tpu-test:
    # TODO: confirm final naming for optional label
    # if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
    name: "TPU v5e x 8 Presubmit"
    env:
      # Quoted so the value is an explicit string, consistent with the other
      # string-valued settings in this workflow.
      ENABLE_PJRT_COMPATIBILITY: "1"
    # TODO: Needs final runs-on value
    runs-on: arc-linux-x86-ct5lp-224-8tpu
    container:
      # TODO: Needs newer, light weight image
      image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
    # Hard upper bound for the whole job, in minutes.
    timeout-minutes: 120
    defaults:
      run:
        # -e stops a step at the first failing command; -x echoes each command.
        shell: bash -ex {0}
    steps:
      - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
      - name: Install JAX test requirements
        run: |
          pip install -U -r build/test-requirements.txt
      # TODO: build jax should be done on a step prior or we should just bazel test
      # NOTE(review): presumably this action pauses the job so someone can
      # connect to the runner for debugging — confirm against the action's docs.
      - name: Wait For Connection
        uses: google-ml-infra/jax-fork/actions/ci_connection@61e7d8d6c273b102e4a6271c1e84bd0a4febc8cb
        with:
          halt-dispatch-input: "1"
      - name: Build JAX
        run: |
          # Replace any preinstalled jaxlib with the wheel built from this checkout.
          pip uninstall -y jaxlib
          python3 build/build.py --use_clang
          pip install -e .
          ls -la dist/*.whl
          pip install dist/*.whl
          # Note the version it installs! Should be today's date
          pip install -U --no-index --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          # Print the versions actually in use so they show up in the job log.
          python3 -c 'import sys; print("python version:", sys.version)'
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          PY_COLORS: "1"
          # Passed to pytest as the worker count (-n) in the run script below.
          NUM_TESTS: "8"
          JAX_NUM_GENERATED_CASES: "25"
        run: |
          # Run single-accelerator tests in parallel
          mkdir results
          JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=$NUM_TESTS --tb=short \
            --junitxml=results/singlejunit.xml --maxfail=20 -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python3 -m pytest --tb=short --junitxml=results/multijunit.xml \
            --maxfail=20 -m "multiaccelerator" tests
      # - name: 'Upload Artifact'
      #   if: success() || failure()
      #   uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet:actions/upload-artifact@v4
      #   with:
      #     name: junit
      #     path: |
      #       results/singlejunit.xml
      #       results/multijunit.xml
      #     retention-days: 1