Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,60 @@ jobs:
"context": "github-actions/build"
}'

# Runs the multihost checkpoint test on a GPU runner inside the ML build
# container, then reports the job outcome as a commit status.
build-checkpoint-multihost:
  name: "build-checkpoint-multihost (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
  runs-on: linux-g2-16-l4-1gpu-x4
  container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
  defaults:
    run:
      # Every `run:` step below executes from the `checkpoint` directory.
      working-directory: checkpoint
  strategy:
    matrix:
      # Versions are quoted so YAML keeps them as strings (3.10 would
      # otherwise parse as the float 3.1).
      python-version: ["3.12"]
      # Accepts "newest", "nightly", or a minimum version specifier; see the
      # branches in the "Install dependencies" step.
      jax-version: ["newest"]
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`.
      # Currently in place to override remote orbax import due to flax dependency.
      # The script relies on `[[ ... ]]`, which is bash-only, so the shell is
      # declared explicitly.
      shell: bash
      env:
        # Passed through the environment instead of being spliced into the
        # script text, so the value is never re-parsed as shell code.
        JAX_VERSION: ${{ matrix.jax-version }}
      run: |
        pip install -e .
        # Extras are quoted so the shell cannot glob-expand the brackets.
        pip install -e ".[testing,gcs]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
        pip uninstall -y orbax
        if [[ "${JAX_VERSION}" == "newest" ]]; then
          pip install -U "jax[k8s,cuda12]" jaxlib
        elif [[ "${JAX_VERSION}" == "nightly" ]]; then
          pip install -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
        else
          # Any other matrix value is treated as a minimum version.
          pip install "jax[k8s,cuda12]>=${JAX_VERSION}" "jaxlib>=${JAX_VERSION}"
        fi
    - name: Test with pytest
      # TODO(yaning): Move these to an exclude target within pytest.ini.
      run: |
        python -m pytest orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py
    # The below step just reports the success or failure of tests as a "commit status".
    # This is needed for copybara integration.
    - name: Report success or failure as github status
      # Run even when an earlier step failed or the job was cancelled.
      if: always()
      shell: bash
      env:
        # Expressions are resolved here rather than inside the script so the
        # token and URLs reach the shell as plain data.
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        JOB_STATUS: ${{ job.status }}
        STATUS_URL: "https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }}"
        TARGET_URL: "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
      run: |
        status="${JOB_STATUS}"
        state=$(echo "$status" | tr '[:upper:]' '[:lower:]')
        # The commit status API only accepts error/failure/pending/success,
        # while job.status may also be "cancelled". Report anything that is
        # not a plain success or failure as "error" so the request is valid.
        case "$state" in
          success|failure) ;;
          *) state="error" ;;
        esac
        # The description keeps the raw job status for readers of the check.
        curl -sS --request POST \
          --url "$STATUS_URL" \
          --header "authorization: Bearer ${GITHUB_TOKEN}" \
          --header 'content-type: application/json' \
          --data '{
            "state": "'"$state"'",
            "target_url": "'"$TARGET_URL"'",
            "description": "'"$status"'",
            "context": "github-actions/build"
          }'

build-export:
name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: ubuntu-latest
Expand Down
Loading