diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1bf227287..67f59eadc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,5 +1,4 @@ name: build - on: # continuous schedule: @@ -12,32 +11,24 @@ on: pull_request: branches: - main - permissions: contents: read actions: write # to cancel previous workflows - concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - jobs: - build-checkpoint: - name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest + multiprocess-checkpoint-benchmarks: + name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-g2-16-l4-1gpu-x4 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: working-directory: checkpoint strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.6.0" # keep in sync with minimum version in checkpoint/pyproject.toml - # TODO(b/401258175) Re-enable once JAX nightlies are fixed. - # - python-version: "3.13" - # jax-version: "nightly" + python-version: ["3.12"] + jax-version: ["0.5.0"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} @@ -45,23 +36,23 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. - # Currently in place to override remote orbax import due to flax dependency. run: | pip install -e . - pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip uninstall -y orbax if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib + pip install -U jax[k8s,cuda12] jaxlib elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ else - pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" + pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" fi - - name: Test with pytest - # TODO(yaning): Move these to an exclude target within pytest.ini. + pip install gcsfs + - name: Run benchmarks + env: + GCS_BUCKET_PATH: gs://benchmarks/benchmark-results/${{ github.run_id }} run: | - python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py + cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status @@ -80,154 +71,201 @@ jobs: "description": "'$status'", "context": "github-actions/build" }' - - build-export: - name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest - defaults: - run: - working-directory: export - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml - # TODO(b/401258175) Re-enable once JAX nightlies are fixed. - # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it. - # jax-version: "nightly" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - version: '3.x' - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") - - pip install . - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi - - name: Test with pytest - run: | - test_dir=$(mktemp -d) - cp orbax/export/conftest.py ${test_dir} - for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do - cp ${t} ${test_dir} - XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) - done - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' - - build-orbax-model: - name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest - defaults: - run: - working-directory: model - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml - # - python-version: "3.13" - # jax-version: "nightly" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - version: '3.x' - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - pip install tensorflow - - protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") - - pip install -e . - - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi - - name: Test with pytest - run: | - pytest orbax/experimental/model/core/python/*_test.py - - pytest orbax/experimental/model/tf2obm/*_test.py - - pytest orbax/experimental/model/jax2obm/ \ - --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ - --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ - --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + # build-checkpoint: + # name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + # runs-on: ubuntu-latest + # defaults: + # run: + # working-directory: checkpoint + # strategy: + # matrix: + # python-version: ["3.10", "3.11", "3.12"] + # jax-version: ["newest"] + # include: + # - python-version: "3.10" + # jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml + # # TODO(b/401258175) Re-enable once JAX nightlies are fixed. + # # - python-version: "3.13" + # # jax-version: "nightly" + # steps: + # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Install dependencies + # # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. + # # Currently in place to override remote orbax import due to flax dependency. + # run: | + # pip install -e . + # pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # pip uninstall -y orbax + # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + # pip install -U jax jaxlib + # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + # pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + # else + # pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" + # fi + # - name: Test with pytest + # # TODO(yaning): Move these to an exclude target within pytest.ini. + # run: | + # python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py + # # The below step just reports the success or failure of tests as a "commit status". + # # This is needed for copybara integration. + # - name: Report success or failure as github status + # if: always() + # shell: bash + # run: | + # status="${{ job.status }}" + # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + # curl -sS --request POST \ + # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + # --header 'content-type: application/json' \ + # --data '{ + # "state": "'$lowercase_status'", + # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + # "description": "'$status'", + # "context": "github-actions/build" + # }' + # build-export: + # name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + # runs-on: ubuntu-latest + # defaults: + # run: + # working-directory: export + # strategy: + # matrix: + # python-version: ["3.10", "3.11", "3.12"] + # jax-version: ["newest"] + # include: + # - python-version: "3.10" + # jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml + # # TODO(b/401258175) Re-enable once JAX nightlies are fixed. + # # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it. + # # jax-version: "nightly" + # steps: + # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Install Protoc + # uses: arduino/setup-protoc@v1 + # with: + # version: '3.x' + # repo-token: ${{ secrets.GITHUB_TOKEN }} + # - name: Extract branch name + # shell: bash + # run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + # id: extract_branch + # - name: Install dependencies + # run: | + # sudo apt-get update + # sudo apt-get install -y protobuf-compiler + # protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") + # pip install . + # pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + # pip install -U jax jaxlib + # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + # pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + # else + # pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + # fi + # - name: Test with pytest + # run: | + # test_dir=$(mktemp -d) + # cp orbax/export/conftest.py ${test_dir} + # for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do + # cp ${t} ${test_dir} + # XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) + # done + # # The below step just reports the success or failure of tests as a "commit status". + # # This is needed for copybara integration. + # - name: Report success or failure as github status + # if: always() + # shell: bash + # run: | + # status="${{ job.status }}" + # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + # curl -sS --request POST \ + # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + # --header 'content-type: application/json' \ + # --data '{ + # "state": "'$lowercase_status'", + # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + # "description": "'$status'", + # "context": "github-actions/build" + # }' + # build-orbax-model: + # name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + # runs-on: ubuntu-latest + # defaults: + # run: + # working-directory: model + # strategy: + # matrix: + # python-version: ["3.10", "3.11", "3.12"] + # jax-version: ["newest"] + # include: + # - python-version: "3.10" + # jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml + # # - python-version: "3.13" + # # jax-version: "nightly" + # steps: + # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Install Protoc + # uses: arduino/setup-protoc@v1 + # with: + # version: '3.x' + # repo-token: ${{ secrets.GITHUB_TOKEN }} + # - name: Extract branch name + # shell: bash + # run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + # id: extract_branch + # - name: Install dependencies + # run: | + # sudo apt-get update + # sudo apt-get install -y protobuf-compiler + # pip install tensorflow + # protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") + # pip install -e . + # pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + # pip install -U jax jaxlib + # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + # pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + # else + # pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + # fi + # - name: Test with pytest + # run: | + # pytest orbax/experimental/model/core/python/*_test.py + # pytest orbax/experimental/model/tf2obm/*_test.py + # pytest orbax/experimental/model/jax2obm/ \ + # --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ + # --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ + # --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py + # - name: Report success or failure as github status + # if: always() + # shell: bash + # run: | + # status="${{ job.status }}" + # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + # curl -sS --request POST \ + # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + # --header 'content-type: application/json' \ + # --data '{ + # "state": "'$lowercase_status'", + # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + # "description": "'$status'", + # "context": "github-actions/build" + # }'