diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0553236c..a07f5bed 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,5 @@ -**Description** +# Description + Include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. @@ -6,16 +7,18 @@ this change. Fixes # (issue) If this is a hotfix to a released version, please specify it. -**How Has This Been Tested?** +## How has this been tested? + Please describe the tests that you ran to verify your changes. Please also note any relevant details for your test configuration (e.g. compiler, OS). Include enough information so someone can reproduce your tests. -**Checklist:** +## Checklist + - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas -- [ ] I have made corresponding changes to the documentation +- [ ] I have made corresponding changes to the documentation (e.g. add new modules to docs/docstrings/) - [ ] My changes generate no new warnings - [ ] Any dependent changes have been merged and published in downstream modules - [ ] New check tests, if applicable, are included diff --git a/.github/workflows/create-cache.yaml b/.github/workflows/create-cache.yaml new file mode 100644 index 00000000..6d823662 --- /dev/null +++ b/.github/workflows/create-cache.yaml @@ -0,0 +1,52 @@ +name: "Create GHA cache" + +# GitHub puts the following restrictions on cache sharing. PRs can access +# +# - caches that were created by the PR / earlier runs of the PR +# - caches that were created on the target branch +# +# To get effective cache sharing between PRs, we create caches on the `develop` +# branch (which is where almost all PRs merge into). + +on: + push: + branches: [develop] + +# Cancel running jobs if there's a newer push +concurrency: + group: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + # GitHub Actions cache of the pre-commit environment + pre-commit: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - uses: actions/cache@v4 + id: cache + with: + path: ~/.cache/pre-commit + key: pre-commit_${{ env.pythonLocation }}_${{ hashFiles('.pre-commit-config.yaml') }} + lookup-only: true # don't actually download the cache + + - name: Populate pre-commit environment (if not cached) + if: steps.cache.outputs.cache-hit != 'true' + run: | + pip install pre-commit + pre-commit install --install-hooks + + # GitHub Actions cache of pyFV3 test data + pyFV3_test_data: + uses: NOAA-GFDL/pyFV3/.github/workflows/create_cache.yml@develop + + # GitHub Actions cache of pySHiELD test data + pySHiELD_test_data: + uses: NOAA-GFDL/pySHiELD/.github/workflows/create_cache.yml@develop diff --git a/.github/workflows/docs_build.yaml b/.github/workflows/docs_build.yaml new file mode 100644 index 00000000..3b0f1ed1 --- /dev/null +++ b/.github/workflows/docs_build.yaml @@ -0,0 +1,34 @@ +name: "Build docs" + +# This workflow builds the docs to catch us when modules are moved/deleted +# and the auto-generated docstrings can't build anymore. This does not catch +# us when new modules are added. + +# Run these these whenever ... +on: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + +# cancel running jobs if theres a newer push +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install mkdocs + run: pip install mkdocs-material mkdocstrings[python] + + - name: Build docs + run: mkdocs build diff --git a/.github/workflows/docs_deploy.yaml b/.github/workflows/docs_deploy.yaml new file mode 100644 index 00000000..1e6659d3 --- /dev/null +++ b/.github/workflows/docs_deploy.yaml @@ -0,0 +1,35 @@ +name: "NDSL documentation" + +# Documentation gets automatically deployed upon merge to develop +on: + push: + branches: + - develop + +# Security: Restrict permissions of this workflow to whats needed. +permissions: + contents: write + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email github-actions[bot]@users.noreply.github.com + + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: 3.11 + + - name: Install dependencies + run: pip install mkdocs-material mkdocstrings[python] + + - name: Deploy docs to GitHub Pages + run: mkdocs gh-deploy --force diff --git a/.github/workflows/fv3_translate_tests.yaml b/.github/workflows/fv3_translate_tests.yaml index 9d030d72..f28e5ad1 100644 --- a/.github/workflows/fv3_translate_tests.yaml +++ b/.github/workflows/fv3_translate_tests.yaml @@ -1,10 +1,16 @@ name: "FV3 translate tests" + +# Run these these whenever ... on: - pull_request: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + push: + branches: + - main # ... when merging into the main branch jobs: fv3_translate_tests: uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop with: - component_trigger: true + component_trigger: true component_name: NDSL diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index ba5863f3..40fb1e63 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -1,6 +1,12 @@ name: "Lint" + +# Run these these whenever ... on: - pull_request: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + push: + branches: + - main # ... when merging into the main branch # cancel running jobs if theres a newer push concurrency: @@ -11,18 +17,23 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: 'recursive' + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' - - name: Setup Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: '3.11' + # Only restore (don't save) caches on PRs. New caches created from PRs won't be + # accessible from other PRs, see workflows/create-cache.yaml. + - uses: actions/cache/restore@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit_${{ env.pythonLocation }}_${{ hashFiles('.pre-commit-config.yaml') }} - - name: Install pre-commit - run: pip install pre-commit + - name: Install pre-commit + run: pip install pre-commit - - name: Run lint via pre-commit - run: pre-commit run --all-files + - name: Run lint via pre-commit + run: pre-commit run --all-files diff --git a/.github/workflows/pace_tests.yaml b/.github/workflows/pace_tests.yaml index 3837beec..ea3d40b3 100644 --- a/.github/workflows/pace_tests.yaml +++ b/.github/workflows/pace_tests.yaml @@ -1,10 +1,16 @@ name: "pace main tests" + +# Run these these whenever ... on: - pull_request: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + push: + branches: + - main # ... when merging into the main branch jobs: pace_main_tests: uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop with: - component_trigger: true + component_trigger: true component_name: NDSL diff --git a/.github/workflows/shield_tests.yaml b/.github/workflows/shield_tests.yaml index fcc86f20..53ba510b 100644 --- a/.github/workflows/shield_tests.yaml +++ b/.github/workflows/shield_tests.yaml @@ -1,10 +1,16 @@ name: "SHiELD Translate tests" + +# Run these these whenever ... on: - pull_request: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + push: + branches: + - main # ... when merging into the main branch jobs: shield_translate_tests: uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop with: - component_trigger: true + component_trigger: true component_name: NDSL diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 7bde1bf0..bd90f66c 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -1,6 +1,12 @@ name: "NDSL unit tests" + +# Run these these whenever ... on: - pull_request: + pull_request: # ... a PR is opened / updated + merge_group: # ... the PR is added to the merge queue + push: + branches: + - main # ... when merging into the main branch # cancel running jobs if theres a newer push concurrency: @@ -10,28 +16,31 @@ concurrency: jobs: ndsl_unit_tests: runs-on: ubuntu-latest - container: - image: ghcr.io/noaa-gfdl/miniforge:mpich steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: 'recursive' + - name: Checkout repository + uses: actions/checkout@v5 + with: + submodules: 'recursive' + + - name: Setup Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' - - name: Install Python packages - run: pip3 install .[test] + - name: Install mpi (MPICH flavor) + run: pip3 install mpich - - name: prepare input eta files - run: python tests/grid/generate_eta_files.py + - name: Install Python packages + run: pip3 install .[test] - - name: Run serial-cpu tests - run: coverage run --rcfile=setup.cfg -m pytest tests + - name: Run serial-cpu tests + run: coverage run --rcfile=pyproject.toml -m pytest tests - - name: Run parallel-cpu tests - run: mpiexec -np 6 --oversubscribe coverage run --rcfile=setup.cfg -m mpi4py -m pytest tests/mpi + - name: Run parallel-cpu tests + run: mpiexec -np 6 coverage run --rcfile=pyproject.toml -m mpi4py -m pytest tests/mpi - - name: Output code coverage - run: | - coverage combine - coverage report + - name: Output code coverage + run: | + coverage combine + coverage report diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8df5e5ae..da67e90a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,50 +2,51 @@ default_language_version: python: python3 repos: -- repo: https://github.com/psf/black - rev: 20.8b1 + # Use mypyc-compiled black, which is about 2x faster + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.1.0 hooks: - - id: black - additional_dependencies: ["click==8.0.4"] + - id: black-jupyter -- repo: https://github.com/pre-commit/mirrors-isort - rev: v5.4.2 + - repo: https://github.com/pre-commit/mirrors-isort + rev: v5.10.1 hooks: - - id: isort - args: ["--profile", "black"] + - id: isort + args: ["--profile", "black"] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.18.2 hooks: - - id: mypy - name: mypy-ndsl - args: [--config-file, setup.cfg] - additional_dependencies: [types-PyYAML] - files: ndsl - exclude: | - (?x)^( - ndsl/ndsl/gt4py_utils.py | - )$ -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + - id: mypy + name: mypy-ndsl + args: ["--config-file", "pyproject.toml"] + additional_dependencies: [types-PyYAML] + files: ndsl + exclude: | + (?x)^( + ndsl/ndsl/gt4py_utils.py | + )$ + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/pycqa/flake8 + rev: 7.3.0 hooks: - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + - id: flake8 + name: flake8 + language_version: python3 + additional_dependencies: [Flake8-pyproject, flake8-bugbear] + + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.15.0 hooks: - - id: flake8 - name: flake8 - language_version: python3 - args: [--config, setup.cfg] - exclude: | - (?x)^( - .*/__init__.py | - )$ - - id: flake8 - name: flake8 __init__.py files - files: "__init__.py" - # ignore unused import error in __init__.py files - args: ["--ignore=F401,E203", --config, setup.cfg] + - id: pretty-format-toml + args: [--autofix, --indent, "2"] + - id: pretty-format-yaml + args: [--autofix, --preserve-quotes, --indent, "2", --offset, "2"] diff --git a/README.md b/README.md index 6a387192..b256eda7 100644 --- a/README.md +++ b/README.md @@ -3,57 +3,104 @@ # NOAA/NASA Domain Specific Language middleware -NDSL is a middleware for climate and weather modelling developed jointly by NOAA and NASA. The middleware brings together [GT4Py](https://github.com/GridTools/gt4py/) (the `cartesian` flavor), ETH CSCS's stencil DSL, and [DaCE](https://github.com/spcl/dace/), ETH SPCL's data flow framework, both developed for high-performance and portability. On top of those pillars, NDSL deploys a series of optimized APIs for common operations (Halo exchange, domain decomposition, MPI...), a set of bespoke optimizations for the models targeted by the middleware and tools to port existing models. +NDSL is a middleware for climate and weather modelling developed jointly by NOAA and NASA. The middleware brings together [GT4Py](https://github.com/GridTools/gt4py/) (the `cartesian` flavor), ETH CSCS's stencil DSL, and [DaCe](https://github.com/spcl/dace/), ETH SPCL's data flow framework, both developed for high-performance and portability. On top of those pillars, NDSL deploys a series of optimized APIs for common operations (Halo exchange, domain decomposition, MPI, ...), a set of bespoke optimizations for the models targeted by the middleware and tools to port existing models. -## Battery-included for FV-based models +## Batteries-included for FV-based models -Historically NDSL was developed to port the FV3 dynamical core on the cube-sphere. Therefore, the middleware ships with ready-to-execute specialization for models based on cube-sphere grid and FV-based model in particular. +Historically, NDSL was developed to port the FV3 dynamical core on the cubed-sphere. Therefore, the middleware ships with ready-to-execute specialization for models based on cubed-sphere grids and FV-based models in particular. ## Quickstart -Recommended Python is `3.11.x` all other dependencies will be pulled during install. +Currently, NDSL requires Python version `3.11.x`, a GNU compiler and MPI installed. All other dependencies installed during package installation. We recommend using virtual (or conda) environment. -NDSL submodules `gt4py` and `dace` to point to vetted versions, use `git clone --recurse-submodule`. +```shell +# We have submodules for GT4Py and DaCe. Don't forget to pull them +git clone --recurse-submodules git@github.com:NOAA-GFDL/NDSL.git -NDSL is __NOT__ available on `pypi`. Installation of the package has to be local, via `pip install ./NDSL` (`-e` supported). The packages has a few options: +cd NDSL/ -- `ndsl[test]`: installs the test packages (based on `pytest`) -- `ndsl[demos]`: installs extra requirements to run [NDSL examples](./examples/NDSL/) -- `ndsl[docs]`: installs extra requirements to build the docs -- `ndsl[develop]`: installs tools for development, docs, and tests. +# We strongly recommend using a virtual environment (or conda) +python -m venv .venv/ +source ./venv/bin/activate -Tests are available via: +# Choose pip install -e .[dev] if you'd like to contribute +pip install .[demos] +``` + +Now, checkout [examples/NDSL](./examples/NDSL/) and ran through the Jupyter notebooks. Note that you have to install NDSL locally, as it is not available on `pypi`. + +## The slightly longer version -- `pytest -x test`: running CPU serial tests (GPU as well if `cupy` is installed) -- `mpirun -np 6 pytest -x test/mpi`: running CPU parallel tests (GPU as well if `cupy` is installed) +NDSL is under active development and may only work with specific setups. This is what we know works for us. -## Requirements & supported compilers +### Requirements and supported compilers -For CPU backends: +The run the CPU backends you will need: -- 3.11.x >= Python < 3.12.x -- Compilers: - - GNU 11.2+ +- Python: 3.11.x +- CXX compiler: GNU 11.2+ +- Libraries: MPI -For GPU backends (the above plus): +To run the GPU backends, you'll need: +- Python: 3.11.x +- CXX compiler: GNU 11.2+ +- Libraries: MPI compiled with CUDA support - CUDA 11.2+ - Python package: - - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) -- Libraries: - - MPI compiled with cuda support + - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) + +A simple way to install MPI is using pre-built wheels, e.g. + +```shell +# See "quickstart" above how to setup a virtual environment +cd NDSL/ +source ./venv/bin/activate + +# Install MPI into your virtual environment +pip install openmpi +``` + +A note on the compiler: NDSL currently only works with the GNU compiler. Using `clang` will result in errors related to undefined OpenMP flags. For MacOS users, we know that `gcc` version 14 from homebrew works. + +### Installation options + +See [quickstart](#quickstart) above on how to pull and setup a virtual environment. The packages has a few options: + +- `ndsl[test]`: extra dependencies to run tests (based on `pytest`) +- `ndsl[demos]`: extra dependencies to run [NDSL examples](./examples/NDSL/) +- `ndsl[docs]`: extra dependencies to build the docs +- `ndsl[dev]`: installs tools for development, docs, and tests. + +### Running tests + +Tests are available via `pytest` (don't forget to install the `test` or `dev` extras). + +To run serial tests on CPU (GPU tests also run if `cupy` is available) + +```bash +pytest tests/ +``` + +To run parallel tests on CPU (GPU tests also run if `cupy` is available) + +```bash +mpirun -np 6 pytest tests/mpi +``` ## Development ### Code/contribution guidelines -TBD +1. Code quality is enforced by `pre-commit` (which is part of the "dev" extra). Run `pre-commit install` to install the pre-commit hooks locally or make sure to run `pre-commit run -a` before submitting a pull request. +2. While we don't strictly enforce type hints, we add them on new code. +3. Pull requests have to merged as "squash merge" to keep the `git` history clean. ### Documentation We are using [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), which allows us to write the docs in Markdown files and optionally serve it as a static site. -To view the documentation, install NDSL with the `docs` or `develop` extras. Then just run +To view the documentation, install NDSL with the `docs` or `dev` extras. Then run the following: ```bash mkdocs serve @@ -66,7 +113,7 @@ Contributing to the documentation is straight forward: 3. [Optional] Start the development server and look how your changes are rendered. 4. Submit a pull request with your changes. -## Point of Contacts +## Points of contact - NOAA: Rusty Benson: rusty.benson -at- noaa.gov - NASA: Florian Deconinck florian.g.deconinck -at- nasa.gov diff --git a/docs/docstrings/checkpointer/base.md b/docs/docstrings/checkpointer/base.md new file mode 100644 index 00000000..b3c282ae --- /dev/null +++ b/docs/docstrings/checkpointer/base.md @@ -0,0 +1,3 @@ +# base + +::: checkpointer.base diff --git a/docs/docstrings/checkpointer/null.md b/docs/docstrings/checkpointer/null.md new file mode 100644 index 00000000..1395606c --- /dev/null +++ b/docs/docstrings/checkpointer/null.md @@ -0,0 +1,3 @@ +# null + +::: checkpointer.null diff --git a/docs/docstrings/checkpointer/snapshots.md b/docs/docstrings/checkpointer/snapshots.md new file mode 100644 index 00000000..41b1ebdc --- /dev/null +++ b/docs/docstrings/checkpointer/snapshots.md @@ -0,0 +1,3 @@ +# snapshots + +::: checkpointer.snapshots diff --git a/docs/docstrings/checkpointer/thresholds.md b/docs/docstrings/checkpointer/thresholds.md new file mode 100644 index 00000000..f03ac634 --- /dev/null +++ b/docs/docstrings/checkpointer/thresholds.md @@ -0,0 +1,3 @@ +# thresholds + +::: checkpointer.thresholds diff --git a/docs/docstrings/checkpointer/validation.md b/docs/docstrings/checkpointer/validation.md new file mode 100644 index 00000000..cc950f52 --- /dev/null +++ b/docs/docstrings/checkpointer/validation.md @@ -0,0 +1,3 @@ +# validation + +::: checkpointer.validation diff --git a/docs/docstrings/comm/_boundary_utils.md b/docs/docstrings/comm/_boundary_utils.md new file mode 100644 index 00000000..6983ba26 --- /dev/null +++ b/docs/docstrings/comm/_boundary_utils.md @@ -0,0 +1,3 @@ +# _boundary_utils + +::: comm._boundary_utils diff --git a/docs/docstrings/comm/boundary.md b/docs/docstrings/comm/boundary.md new file mode 100644 index 00000000..3f456751 --- /dev/null +++ b/docs/docstrings/comm/boundary.md @@ -0,0 +1,3 @@ +# boundary + +::: comm.boundary diff --git a/docs/docstrings/comm/caching_comm.md b/docs/docstrings/comm/caching_comm.md new file mode 100644 index 00000000..6d59b3de --- /dev/null +++ b/docs/docstrings/comm/caching_comm.md @@ -0,0 +1,3 @@ +# caching_comm + +::: comm.caching_comm diff --git a/docs/docstrings/comm/comm_abc.md b/docs/docstrings/comm/comm_abc.md new file mode 100644 index 00000000..30ce5a15 --- /dev/null +++ b/docs/docstrings/comm/comm_abc.md @@ -0,0 +1,3 @@ +# comm_abc + +::: comm.comm_abc diff --git a/docs/docstrings/comm/communicator.md b/docs/docstrings/comm/communicator.md new file mode 100644 index 00000000..466ad3e7 --- /dev/null +++ b/docs/docstrings/comm/communicator.md @@ -0,0 +1,3 @@ +# communicator + +::: comm.communicator diff --git a/docs/docstrings/comm/decomposition.md b/docs/docstrings/comm/decomposition.md new file mode 100644 index 00000000..687d0a74 --- /dev/null +++ b/docs/docstrings/comm/decomposition.md @@ -0,0 +1,3 @@ +# decomposition + +::: comm.decomposition diff --git a/docs/docstrings/comm/local_comm.md b/docs/docstrings/comm/local_comm.md new file mode 100644 index 00000000..a7999ec4 --- /dev/null +++ b/docs/docstrings/comm/local_comm.md @@ -0,0 +1,3 @@ +# local_comm + +::: comm.local_comm diff --git a/docs/docstrings/comm/mpi.md b/docs/docstrings/comm/mpi.md new file mode 100644 index 00000000..596e76a9 --- /dev/null +++ b/docs/docstrings/comm/mpi.md @@ -0,0 +1,3 @@ +# mpi + +::: comm.mpi diff --git a/docs/docstrings/comm/null_comm.md b/docs/docstrings/comm/null_comm.md new file mode 100644 index 00000000..2c4639d8 --- /dev/null +++ b/docs/docstrings/comm/null_comm.md @@ -0,0 +1,3 @@ +# null_comm + +::: comm.null_comm diff --git a/docs/docstrings/comm/partitioner.md b/docs/docstrings/comm/partitioner.md new file mode 100644 index 00000000..e33fdd35 --- /dev/null +++ b/docs/docstrings/comm/partitioner.md @@ -0,0 +1,3 @@ +# partitioner + +::: comm.partitioner diff --git a/docs/docstrings/debug/config.md b/docs/docstrings/debug/config.md new file mode 100644 index 00000000..06a0cf27 --- /dev/null +++ b/docs/docstrings/debug/config.md @@ -0,0 +1,3 @@ +# config + +::: debug.config diff --git a/docs/docstrings/debug/debugger.md b/docs/docstrings/debug/debugger.md new file mode 100644 index 00000000..9e609b85 --- /dev/null +++ b/docs/docstrings/debug/debugger.md @@ -0,0 +1,3 @@ +# debugger + +::: debug.debugger diff --git a/docs/docstrings/debug/tooling.md b/docs/docstrings/debug/tooling.md new file mode 100644 index 00000000..dc7f1e37 --- /dev/null +++ b/docs/docstrings/debug/tooling.md @@ -0,0 +1,3 @@ +# tooling + +::: debug.tooling diff --git a/docs/docstrings/dsl/gt4py_utils.md b/docs/docstrings/dsl/gt4py_utils.md new file mode 100644 index 00000000..c995eabd --- /dev/null +++ b/docs/docstrings/dsl/gt4py_utils.md @@ -0,0 +1,3 @@ +# gt4py_utils + +::: dsl.gt4py_utils diff --git a/docs/docstrings/dsl/stencil.md b/docs/docstrings/dsl/stencil.md new file mode 100644 index 00000000..2bcc0488 --- /dev/null +++ b/docs/docstrings/dsl/stencil.md @@ -0,0 +1,3 @@ +# stencil + +::: dsl.stencil diff --git a/docs/docstrings/dsl/stencil_config.md b/docs/docstrings/dsl/stencil_config.md new file mode 100644 index 00000000..20e1bfc4 --- /dev/null +++ b/docs/docstrings/dsl/stencil_config.md @@ -0,0 +1,3 @@ +# stencil_config + +::: dsl.stencil_config diff --git a/docs/docstrings/dsl/typing.md b/docs/docstrings/dsl/typing.md new file mode 100644 index 00000000..1f227afb --- /dev/null +++ b/docs/docstrings/dsl/typing.md @@ -0,0 +1,3 @@ +# typing + +::: dsl.typing diff --git a/docs/docstrings/grid/eta.md b/docs/docstrings/grid/eta.md new file mode 100644 index 00000000..b7b46bc8 --- /dev/null +++ b/docs/docstrings/grid/eta.md @@ -0,0 +1,3 @@ +# eta + +::: grid.eta diff --git a/docs/docstrings/grid/generation.md b/docs/docstrings/grid/generation.md new file mode 100644 index 00000000..6b963e6d --- /dev/null +++ b/docs/docstrings/grid/generation.md @@ -0,0 +1,3 @@ +# generation + +::: grid.generation diff --git a/docs/docstrings/grid/geometry.md b/docs/docstrings/grid/geometry.md new file mode 100644 index 00000000..bfdf031f --- /dev/null +++ b/docs/docstrings/grid/geometry.md @@ -0,0 +1,3 @@ +# geometry + +::: grid.geometry diff --git a/docs/docstrings/grid/global_setup.md b/docs/docstrings/grid/global_setup.md new file mode 100644 index 00000000..bc8525bb --- /dev/null +++ b/docs/docstrings/grid/global_setup.md @@ -0,0 +1,3 @@ +# global_setup + +::: grid.global_setup diff --git a/docs/docstrings/grid/gnomonic.md b/docs/docstrings/grid/gnomonic.md new file mode 100644 index 00000000..a42f21de --- /dev/null +++ b/docs/docstrings/grid/gnomonic.md @@ -0,0 +1,3 @@ +# gnomonic + +::: grid.gnomonic diff --git a/docs/docstrings/grid/helper.md b/docs/docstrings/grid/helper.md new file mode 100644 index 00000000..264de3eb --- /dev/null +++ b/docs/docstrings/grid/helper.md @@ -0,0 +1,3 @@ +# helper + +::: grid.helper diff --git a/docs/docstrings/grid/mirror.md b/docs/docstrings/grid/mirror.md new file mode 100644 index 00000000..aee66f44 --- /dev/null +++ b/docs/docstrings/grid/mirror.md @@ -0,0 +1,3 @@ +# mirror + +::: grid.mirror diff --git a/docs/docstrings/grid/stretch_transformation.md b/docs/docstrings/grid/stretch_transformation.md new file mode 100644 index 00000000..f2d07ac4 --- /dev/null +++ b/docs/docstrings/grid/stretch_transformation.md @@ -0,0 +1,3 @@ +# stretch_transformation + +::: grid.stretch_transformation diff --git a/docs/docstrings/halo/cuda_kernels.md b/docs/docstrings/halo/cuda_kernels.md new file mode 100644 index 00000000..4bc9132e --- /dev/null +++ b/docs/docstrings/halo/cuda_kernels.md @@ -0,0 +1,3 @@ +# cuda_kernels + +::: halo.cuda_kernels diff --git a/docs/docstrings/halo/data_transformer.md b/docs/docstrings/halo/data_transformer.md new file mode 100644 index 00000000..1bec674d --- /dev/null +++ b/docs/docstrings/halo/data_transformer.md @@ -0,0 +1,3 @@ +# data_transformer + +::: halo.data_transformer diff --git a/docs/docstrings/halo/rotate.md b/docs/docstrings/halo/rotate.md new file mode 100644 index 00000000..e4b9bf53 --- /dev/null +++ b/docs/docstrings/halo/rotate.md @@ -0,0 +1,3 @@ +# rotate + +::: halo.rotate diff --git a/docs/docstrings/halo/updater.md b/docs/docstrings/halo/updater.md new file mode 100644 index 00000000..e37eb28e --- /dev/null +++ b/docs/docstrings/halo/updater.md @@ -0,0 +1,3 @@ +# updater + +::: halo.updater diff --git a/docs/docstrings/index.md b/docs/docstrings/index.md new file mode 100644 index 00000000..125a1787 --- /dev/null +++ b/docs/docstrings/index.md @@ -0,0 +1,19 @@ +# Welcome to NDSL Docstrings + +This documentation is generated from the docstrings of the various NDSL classes and functions. It's organized by the various NDSL paths. + +- Top level +- checkpointer +- comm +- debug +- dsl +- grid +- halo +- initialization +- monitor +- performance +- quantity +- restart +- stencils +- testing +- viz diff --git a/docs/docstrings/initialization/allocator.md b/docs/docstrings/initialization/allocator.md new file mode 100644 index 00000000..05f78a36 --- /dev/null +++ b/docs/docstrings/initialization/allocator.md @@ -0,0 +1,3 @@ +# allocator + +::: initialization.allocator diff --git a/docs/docstrings/initialization/grid_sizer.md b/docs/docstrings/initialization/grid_sizer.md new file mode 100644 index 00000000..eedf0e21 --- /dev/null +++ b/docs/docstrings/initialization/grid_sizer.md @@ -0,0 +1,3 @@ +# grid_sizer + +::: initialization.grid_sizer diff --git a/docs/docstrings/initialization/subtile_grid_sizer.md b/docs/docstrings/initialization/subtile_grid_sizer.md new file mode 100644 index 00000000..8219fa92 --- /dev/null +++ b/docs/docstrings/initialization/subtile_grid_sizer.md @@ -0,0 +1,3 @@ +# subtile_grid_sizer + +::: initialization.subtile_grid_sizer diff --git a/docs/docstrings/monitor/convert.md b/docs/docstrings/monitor/convert.md new file mode 100644 index 00000000..56d65989 --- /dev/null +++ b/docs/docstrings/monitor/convert.md @@ -0,0 +1,3 @@ +# convert + +::: monitor.convert diff --git a/docs/docstrings/monitor/netcdf_monitor.md b/docs/docstrings/monitor/netcdf_monitor.md new file mode 100644 index 00000000..7d7ba68d --- /dev/null +++ b/docs/docstrings/monitor/netcdf_monitor.md @@ -0,0 +1,3 @@ +# netcdf_monitor + +::: monitor.netcdf_monitor diff --git a/docs/docstrings/monitor/protocol.md b/docs/docstrings/monitor/protocol.md new file mode 100644 index 00000000..ef4e2701 --- /dev/null +++ b/docs/docstrings/monitor/protocol.md @@ -0,0 +1,3 @@ +# protocol + +::: monitor.protocol diff --git a/docs/docstrings/monitor/zarr_monitor.md b/docs/docstrings/monitor/zarr_monitor.md new file mode 100644 index 00000000..01de7de1 --- /dev/null +++ b/docs/docstrings/monitor/zarr_monitor.md @@ -0,0 +1,3 @@ +# zarr_monitor + +::: monitor.zarr_monitor diff --git a/docs/docstrings/performance/collector.md b/docs/docstrings/performance/collector.md new file mode 100644 index 00000000..885c1194 --- /dev/null +++ b/docs/docstrings/performance/collector.md @@ -0,0 +1,3 @@ +# collector + +::: performance.collector diff --git a/docs/docstrings/performance/config.md b/docs/docstrings/performance/config.md new file mode 100644 index 00000000..e699d903 --- /dev/null +++ b/docs/docstrings/performance/config.md @@ -0,0 +1,3 @@ +# config + +::: performance.config diff --git a/docs/docstrings/performance/profiler.md b/docs/docstrings/performance/profiler.md new file mode 100644 index 00000000..2b4fe0aa --- /dev/null +++ b/docs/docstrings/performance/profiler.md @@ -0,0 +1,3 @@ +# profiler + +::: performance.profiler diff --git a/docs/docstrings/performance/report.md b/docs/docstrings/performance/report.md new file mode 100644 index 00000000..74b8744b --- /dev/null +++ b/docs/docstrings/performance/report.md @@ -0,0 +1,3 @@ +# report + +::: performance.report diff --git a/docs/docstrings/performance/timer.md b/docs/docstrings/performance/timer.md new file mode 100644 index 00000000..e80aa18b --- /dev/null +++ b/docs/docstrings/performance/timer.md @@ -0,0 +1,3 @@ +# timer + +::: performance.timer diff --git a/docs/docstrings/performance/tools.md b/docs/docstrings/performance/tools.md new file mode 100644 index 00000000..58521bad --- /dev/null +++ b/docs/docstrings/performance/tools.md @@ -0,0 +1,3 @@ +# tools + +::: performance.tools diff --git a/docs/docstrings/quantity/bounds.md b/docs/docstrings/quantity/bounds.md new file mode 100644 index 00000000..11ee5471 --- /dev/null +++ b/docs/docstrings/quantity/bounds.md @@ -0,0 +1,3 @@ +# bounds + +::: quantity.bounds diff --git a/docs/docstrings/quantity/field_bundle.md b/docs/docstrings/quantity/field_bundle.md new file mode 100644 index 00000000..e49f9af7 --- /dev/null +++ b/docs/docstrings/quantity/field_bundle.md @@ -0,0 +1,3 @@ +# field_bundle + +::: quantity.field_bundle diff --git a/docs/docstrings/quantity/metadata.md b/docs/docstrings/quantity/metadata.md new file mode 100644 index 00000000..de6af4c1 --- /dev/null +++ b/docs/docstrings/quantity/metadata.md @@ -0,0 +1,3 @@ +# metadata + +::: quantity.metadata diff --git a/docs/docstrings/quantity/quantity.md b/docs/docstrings/quantity/quantity.md new file mode 100644 index 00000000..8a466c37 --- /dev/null +++ b/docs/docstrings/quantity/quantity.md @@ -0,0 +1,3 @@ +# quantity + +::: quantity.quantity diff --git a/docs/docstrings/quantity/state.md b/docs/docstrings/quantity/state.md new file mode 100644 index 00000000..6b33e884 --- /dev/null +++ b/docs/docstrings/quantity/state.md @@ -0,0 +1,3 @@ +# state + +::: quantity.state diff --git a/docs/docstrings/restart/_legacy_restart.md b/docs/docstrings/restart/_legacy_restart.md new file mode 100644 index 00000000..720d990e --- /dev/null +++ b/docs/docstrings/restart/_legacy_restart.md @@ -0,0 +1,3 @@ +# _legacy_restart + +::: restart._legacy_restart diff --git a/docs/docstrings/restart/_properties.md b/docs/docstrings/restart/_properties.md new file mode 100644 index 00000000..c5ce0c66 --- /dev/null +++ b/docs/docstrings/restart/_properties.md @@ -0,0 +1,3 @@ +# _properties + +:: restart._properties diff --git a/docs/docstrings/stencils/basic_operations.md b/docs/docstrings/stencils/basic_operations.md new file mode 100644 index 00000000..7f21e3c0 --- /dev/null +++ b/docs/docstrings/stencils/basic_operations.md @@ -0,0 +1,3 @@ +# basic_operations + +::: stencils.basic_operations diff --git a/docs/docstrings/stencils/c2l_ord.md b/docs/docstrings/stencils/c2l_ord.md new file mode 100644 index 00000000..ebacdf15 --- /dev/null +++ b/docs/docstrings/stencils/c2l_ord.md @@ -0,0 +1,3 @@ +# c2l_ord + +::: stencils.c2l_ord diff --git a/docs/docstrings/stencils/corners.md b/docs/docstrings/stencils/corners.md new file mode 100644 index 00000000..2932d1a5 --- /dev/null +++ b/docs/docstrings/stencils/corners.md @@ -0,0 +1,3 @@ +# corners + +::: stencils.corners diff --git a/docs/docstrings/stencils/tridiag.md b/docs/docstrings/stencils/tridiag.md new file mode 100644 index 00000000..22988c33 --- /dev/null +++ b/docs/docstrings/stencils/tridiag.md @@ -0,0 +1,3 @@ +# tridiag + +::: stencils.tridiag diff --git a/docs/docstrings/testing/comparison.md b/docs/docstrings/testing/comparison.md new file mode 100644 index 00000000..48d4661f --- /dev/null +++ b/docs/docstrings/testing/comparison.md @@ -0,0 +1,3 @@ +# comparison + +::: testing.comparison diff --git a/docs/docstrings/testing/dummy_comm.md b/docs/docstrings/testing/dummy_comm.md new file mode 100644 index 00000000..538c01d2 --- /dev/null +++ b/docs/docstrings/testing/dummy_comm.md @@ -0,0 +1,3 @@ +# dummy_comm + +::: testing.dummy_comm diff --git a/docs/docstrings/testing/perturbation.md b/docs/docstrings/testing/perturbation.md new file mode 100644 index 00000000..9c9bd456 --- /dev/null +++ b/docs/docstrings/testing/perturbation.md @@ -0,0 +1,3 @@ +# perturbation + +::: testing.perturbation diff --git a/docs/docstrings/top/boilerplate.md b/docs/docstrings/top/boilerplate.md new file mode 100644 index 00000000..4d3fb7b7 --- /dev/null +++ b/docs/docstrings/top/boilerplate.md @@ -0,0 +1,3 @@ +# boilerplate + +::: boilerplate diff --git a/docs/docstrings/top/buffer.md b/docs/docstrings/top/buffer.md new file mode 100644 index 00000000..9eed514e --- /dev/null +++ b/docs/docstrings/top/buffer.md @@ -0,0 +1,3 @@ +# buffer + +::: buffer diff --git a/docs/docstrings/top/constants.md b/docs/docstrings/top/constants.md new file mode 100644 index 00000000..68d4eafb --- /dev/null +++ b/docs/docstrings/top/constants.md @@ -0,0 +1,3 @@ +# constants + +::: constants diff --git a/docs/docstrings/top/exceptions.md b/docs/docstrings/top/exceptions.md new file mode 100644 index 00000000..8803fee0 --- /dev/null +++ b/docs/docstrings/top/exceptions.md @@ -0,0 +1,3 @@ +# exceptions + +::: exceptions diff --git a/docs/docstrings/top/filesystem.md b/docs/docstrings/top/filesystem.md new file mode 100644 index 00000000..9f05cc3a --- /dev/null +++ b/docs/docstrings/top/filesystem.md @@ -0,0 +1,3 @@ +# filesystem + +::: filesystem diff --git a/docs/docstrings/top/io.md b/docs/docstrings/top/io.md new file mode 100644 index 00000000..33391c1b --- /dev/null +++ b/docs/docstrings/top/io.md @@ -0,0 +1,3 @@ +# io + +::: io diff --git a/docs/docstrings/top/logging.md b/docs/docstrings/top/logging.md new file mode 100644 index 00000000..8ecbd707 --- /dev/null +++ b/docs/docstrings/top/logging.md @@ -0,0 +1,3 @@ +# logging + +::: logging diff --git a/docs/docstrings/top/namelist.md b/docs/docstrings/top/namelist.md new file mode 100644 index 00000000..c129312a --- /dev/null +++ b/docs/docstrings/top/namelist.md @@ -0,0 +1,3 @@ +# namelist + +::: namelist diff --git a/docs/docstrings/top/optional_imports.md b/docs/docstrings/top/optional_imports.md new file mode 100644 index 00000000..2810878d --- /dev/null +++ b/docs/docstrings/top/optional_imports.md @@ -0,0 +1,3 @@ +# optional_imports + +::: optional_imports diff --git a/docs/docstrings/top/types.md b/docs/docstrings/top/types.md new file mode 100644 index 00000000..49eca211 --- /dev/null +++ b/docs/docstrings/top/types.md @@ -0,0 +1,3 @@ +# types + +::: types diff --git a/docs/docstrings/top/typing.md b/docs/docstrings/top/typing.md new file mode 100644 index 00000000..32f1ab4d --- /dev/null +++ b/docs/docstrings/top/typing.md @@ -0,0 +1,3 @@ +# typing + +::: typing diff --git a/docs/docstrings/top/units.md b/docs/docstrings/top/units.md new file mode 100644 index 00000000..5d2efcf8 --- /dev/null +++ b/docs/docstrings/top/units.md @@ -0,0 +1,3 @@ +# units + +::: units diff --git a/docs/docstrings/top/utils.md b/docs/docstrings/top/utils.md new file mode 100644 index 00000000..b2ee3805 --- /dev/null +++ b/docs/docstrings/top/utils.md @@ -0,0 +1,3 @@ +# utils + +::: utils diff --git a/docs/docstrings/viz/cube_sphere.md b/docs/docstrings/viz/cube_sphere.md new file mode 100644 index 00000000..15206dbe --- /dev/null +++ b/docs/docstrings/viz/cube_sphere.md @@ -0,0 +1,3 @@ +# cube_sphere + +::: viz.cube_sphere diff --git a/docs/index.md b/docs/index.md index 650a2846..779a8781 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,102 +1,13 @@ # NDSL Documentation -NDSL allows atmospheric scientists to write focus on what matters in model development and hides away the complexities of coding for a super computer. +NDSL is a middleware for climate and weather modelling developed jointly by NOAA and NASA. It allows atmospheric scientists to focus on what matters in model development and essentially decouples performance engineering from model development. -## Quick Start +## Portable performance -Python `3.11.x` is required for NDSL and all its third party dependencies for installation. +NDSL brings together [GT4Py](https://github.com/GridTools/gt4py/) and [DaCe](https://github.com/spcl/dace/), two libraries developed for high-performance and portability. On top of those pillars, NDSL deploys a series of optimized APIs for common operations, e.g. halo exchange or domain decomposition, and tools to port existing models. -NDSL submodules `gt4py` and `dace` to point to vetted versions, use `git clone --recurse-submodule` to update the git submodules. +## Batteries-included for FV-based models -NDSL is **NOT** available on `pypi`. Installation of the package has to be local, via `pip install ./NDSL` (`-e` supported). The packages have a few options: +Historically, NDSL was developed to port the FV3 dynamical core on the cubed-sphere. Therefore, the middleware ships with ready-to-execute specialization for models based on cubed-sphere grids and FV-based models in particular. -- `ndsl[test]`: installs the test packages (based on `pytest`) -- `ndsl[develop]`: installs tools for development and tests. - -NDSL uses pytest for its unit tests, the tests are available via: - -- `pytest -x test`: running CPU serial tests (GPU as well if `cupy` is installed) -- `mpirun -np 6 pytest -x test/mpi`: running CPU parallel tests (GPU as well if `cupy` is installed) - -## Requirements & supported compilers - -For CPU backends: - -- 3.11.x >= Python < 3.12.x -- Compilers: - - GNU 11.2+ - -For GPU backends (the above plus): - -- CUDA 11.2+ -- Python package: - - `cupy` (latest with proper driver support [see install notes](https://docs.cupy.dev/en/stable/install.html)) -- Libraries: - - MPI compiled with cuda support - -## NDSL installation and testing - -NDSL is not available at `pypi`, it uses - -```bash -pip install NDSL -``` - -to install NDSL locally. - -NDSL has a few options: - -- `ndsl[test]`: installs the test packages (based on `pytest`) -- `ndsl[develop]`: installs tools for development and tests. - -Tests are available via: - -- `pytest -x test`: running CPU serial tests (GPU as well if `cupy` is installed) -- `mpirun -np 6 pytest -x test/mpi`: running CPU parallel tests (GPU as well if `cupy` is installed) - -## Configurations for Pace - -Configurations for Pace to use NDSL with different backend: - -- FV3_DACEMODE=Python[Build|BuildAndRun|Run] controls the full program optimizer behavior - - - Python: default, use stencil only, no full program optimization - - - Build: will build the program then exit. This _build no matter what_. (backend must be `dace:gpu` or `dace:cpu`) - - - BuildAndRun: same as above but after build the program will keep executing (backend must be `dace:gpu` or `dace:cpu`) - - - Run: load pre-compiled program and execute, fail if the .so is not present (_no hash check!_) (backend must be `dace:gpu` or `dace:cpu`) - -- PACE_FLOAT_PRECISION=64 control the floating point precision throughout the program. - -Install Pace with different NDSL backend: - -- Shell scripts to install Pace using NDSL backend on specific machines such as Gaea can be found in `examples/build_scripts/`. -- When cloning Pace you will need to update the repository's submodules as well: - -```bash -git clone --recursive https://github.com/ai2cm/pace.git -``` - - or if you have already cloned the repository: - -```bash -git submodule update --init --recursive -``` - -- Pace requires GCC > 9.2, MPI, and Python 3.8 on your system, and CUDA is required to run with a GPU backend. -- We recommend creating a python `venv` or conda environment specifically for Pace. - -```bash -python3 -m venv venv_name -source venv_name/bin/activate -``` - -- Inside of your pace `venv` or conda environment pip install the Python requirements, GT4Py, and Pace: - -```bash -pip3 install -r requirements_dev.txt -c constraints.txt -``` - -- There are also separate requirements files which can be installed for linting (`requirements_lint.txt`) and building documentation (`requirements_docs.txt`). +Next: get [up and running](./quickstart.md). diff --git a/docs/porting/translate/index.md b/docs/porting/translate/index.md index 1aa083bd..3b1b3876 100644 --- a/docs/porting/translate/index.md +++ b/docs/porting/translate/index.md @@ -27,6 +27,12 @@ You can create an overwrite file to manually set the threshold in you data direc ![image1.png](../../images/translate/image1.png) +### Testing with Custom Data + +The default netcdf names can be overwritten by setting `self.override_input_netcdf_name` and/or `self.override_output_netcdf_name` +for the input and output netcdf, respectively. These new names should be the entire file name (including, but not requiring, +-In or -Out suffixes), excluding the `.nc` file tag. These files **must** be located in the same folder as the default files. + ### Overwriting Arguments to your compute function The compute_func will be called automatically in the test. If your names in the netcdf are matching the `kwargs` of your function directly, no further action required: diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 00000000..a0125181 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,36 @@ +# Quickstart + +Alright - let's get you up an running! + +NDSL requires Python version `3.11` and a GNU compiler. We strongly recommend using a conda or virtual environment. + +```shell +# We have submodules for GT4Py and DaCe. Don't forget to pull them +git clone --recurse-submodules git@github.com:NOAA-GFDL/NDSL.git + +cd NDSL/ + +# We strongly recommend using conda or a virtual environment +python -m venv .venv/ +source ./venv/bin/activate + +# [optional] Install MPI if you don't have a system installation. +pip install openmpi + +# Finally, install NDSL +pip install .[demos] +``` + +Now you can run through the Jupyter notebooks in `examples/NDSL` :rocket:. + +Read on in the [user manual](./user/index.md). + +!!! note "Supported compilers" + + NDSL currently only works with the GNU compiler. Using `clang` will result in errors related to undefined OpenMP flags. + + For MacOS users, we know that `gcc` version 14 from homebrew works. + +!!! question "Why cloning the repository?" + + We are cloning the repository because NDSL is not available on `pypi`. diff --git a/docs/user/index.md b/docs/user/index.md index 292d3953..46ddc3eb 100644 --- a/docs/user/index.md +++ b/docs/user/index.md @@ -1,3 +1,51 @@ # Usage documentation This part of the documentation is geared towards users of NDSL. + +## Up and running + +See our [quickstart guide](../quickstart.md) on how to get up and running. + +## Configuration + +NDSL tries to have sensible defaults. In cases you want tweak something, here are some pointers: + +### Literal precision (float/int) + +Unspecified integer and floating point literals (e.g. `42` and `3.1415`) default to 64-bit precision. This can be changed with the environment variable `PACE_FLOAT_PRECISION`. + +For mixed precision code, you can specify the "hard coded" precision with type hints and casts, e.g. + +```python +with computation(PARALLEL), interval(...): + # Either 32-bit or 64-bit depending on `PACE_FLOAT_PRECISION` + my_int = 42 + my_float = 3.1415 + + # Always 32-bit + my_int32: int32 = 42 + my_float32: float32 = 3.1415 + + # Explicit 64-bit cast within otherwise unspecified calculation + factor = 0.5 * float64(3.1415 + 2.71828) +``` + +### Full program optimizer + +The behavior of the full program optimizer is controlled by `FV3_DACEMODE`. Valid values are: + +`Python` + +: The default. Disables full program optimization and only accelerates stencil code. + +`Build` + +: Build the program, then exit. This mode is only available for backends `dace:gpu` and `dace:cpu`. + +`BuildAndRun` + +: Build the program, then run it immediately. This mode is only available for backends `dace:gpu` and `dace:cpu`. + +`Run` + +: Load a pre-compiled program and run it. Fails if the pre-compiled program can not be found. This mode is only available for backends `dace:gpu` and `dace:cpu`. diff --git a/examples/Fortran_serialization/02_read_serialized_data_python.ipynb b/examples/Fortran_serialization/02_read_serialized_data_python.ipynb index c8265202..b584f733 100644 --- a/examples/Fortran_serialization/02_read_serialized_data_python.ipynb +++ b/examples/Fortran_serialization/02_read_serialized_data_python.ipynb @@ -57,23 +57,24 @@ ], "source": [ "import sys\n", + "\n", "# Appends the Serialbox python path to PYTHONPATH. If needed, change to appropriate path containing serialbox installation\n", - "sys.path.append('/home/ckung/Documents/Code/SMT-Nebulae/sw_stack_path/install/serialbox/python')\n", + "sys.path.append(\"/path/to/your/serialbox/python\")\n", "import serialbox as ser\n", "import numpy as np\n", "\n", "# If needed, change the path in second parameter of ser.Serializer to appropriate path that contains Fortran data via Serialbox from 01.ipynb\n", - "serializer = ser.Serializer(ser.OpenModeKind.Read,\"./Fortran/sb/\",\"FILLQ2ZERO_InOut\")\n", + "serializer = ser.Serializer(ser.OpenModeKind.Read, \"./Fortran/sb/\", \"FILLQ2ZERO_InOut\")\n", "\n", "savepoints = serializer.savepoint_list()\n", "\n", "Qin_out = serializer.read(\"q_in\", savepoints[0])\n", - "mass = serializer.read(\"m_in\", savepoints[0])\n", - "fq_out = serializer.read(\"fq_in\", savepoints[0])\n", + "mass = serializer.read(\"m_in\", savepoints[0])\n", + "fq_out = serializer.read(\"fq_in\", savepoints[0])\n", "\n", - "print('Sum of Qin_out = ', sum(sum(sum(Qin_out))))\n", - "print('Sum of mass = ', sum(sum(sum(mass))))\n", - "print('Sum of fq_out = ', sum(sum(fq_out)))" + "print(\"Sum of Qin_out = \", sum(sum(sum(Qin_out))))\n", + "print(\"Sum of mass = \", sum(sum(sum(mass))))\n", + "print(\"Sum of fq_out = \", sum(sum(fq_out)))" ] }, { @@ -105,30 +106,31 @@ " JM = Q.shape[1]\n", " LM = Q.shape[2]\n", "\n", - " TPW = np.sum(Q*MASS,2)\n", + " TPW = np.sum(Q * MASS, 2)\n", " for J in range(JM):\n", " for I in range(IM):\n", - " NEGTPW = 0.\n", + " NEGTPW = 0.0\n", " for L in range(LM):\n", - " if(Q[I,J,L] < 0.0):\n", - " NEGTPW = NEGTPW + (Q[I,J,L]*MASS[I,J,L])\n", - " Q[I,J,L] = 0.0\n", + " if Q[I, J, L] < 0.0:\n", + " NEGTPW = NEGTPW + (Q[I, J, L] * MASS[I, J, L])\n", + " Q[I, J, L] = 0.0\n", " for L in range(LM):\n", - " if(Q[I,J,L] >= 0.0):\n", - " Q[I,J,L] = Q[I,J,L]*(1.0 + NEGTPW/(TPW[I,J]-NEGTPW))\n", - " FILLQ[I,J] = -NEGTPW\n", - " \n", - "fillq2zero1(Qin_out,mass,fq_out)\n", + " if Q[I, J, L] >= 0.0:\n", + " Q[I, J, L] = Q[I, J, L] * (1.0 + NEGTPW / (TPW[I, J] - NEGTPW))\n", + " FILLQ[I, J] = -NEGTPW\n", + "\n", + "\n", + "fillq2zero1(Qin_out, mass, fq_out)\n", "\n", - "print('Sum of Qin_out = ', sum(sum(sum(Qin_out))))\n", - "print('Sum of fq_out = ', sum(sum(fq_out)))\n", + "print(\"Sum of Qin_out = \", sum(sum(sum(Qin_out))))\n", + "print(\"Sum of fq_out = \", sum(sum(fq_out)))\n", "\n", "Qin_out_ref = serializer.read(\"q_out\", savepoints[0])\n", - "mass_ref = serializer.read(\"m_out\", savepoints[0])\n", - "fq_out_ref = serializer.read(\"fq_out\", savepoints[0])\n", + "mass_ref = serializer.read(\"m_out\", savepoints[0])\n", + "fq_out_ref = serializer.read(\"fq_out\", savepoints[0])\n", "\n", - "print(np.allclose(Qin_out,Qin_out_ref))\n", - "print(np.allclose(fq_out,fq_out_ref))" + "print(np.allclose(Qin_out, Qin_out_ref))\n", + "print(np.allclose(fq_out, fq_out_ref))" ] }, { @@ -194,25 +196,27 @@ ], "source": [ "# If needed, change the path in second parameter of ser.Serializer to appropriate path that contains Fortran data via Serialbox from 01.ipynb\n", - "serializer = ser.Serializer(ser.OpenModeKind.Read,\"./Fortran_ts/sb/\",\"FILLQ2ZERO_InOut\")\n", + "serializer = ser.Serializer(\n", + " ser.OpenModeKind.Read, \"./Fortran_ts/sb/\", \"FILLQ2ZERO_InOut\"\n", + ")\n", "\n", "savepoints = serializer.savepoint_list()\n", "\n", "for currentSavepoint in savepoints:\n", " Qin_out = serializer.read(\"q_in\", currentSavepoint)\n", - " mass = serializer.read(\"m_in\", currentSavepoint)\n", - " fq_out = serializer.read(\"fq_in\", currentSavepoint)\n", + " mass = serializer.read(\"m_in\", currentSavepoint)\n", + " fq_out = serializer.read(\"fq_in\", currentSavepoint)\n", "\n", - " fillq2zero1(Qin_out,mass,fq_out)\n", + " fillq2zero1(Qin_out, mass, fq_out)\n", "\n", " Qin_out_ref = serializer.read(\"q_out\", currentSavepoint)\n", - " mass_ref = serializer.read(\"m_out\", currentSavepoint)\n", - " fq_out_ref = serializer.read(\"fq_out\", currentSavepoint)\n", + " mass_ref = serializer.read(\"m_out\", currentSavepoint)\n", + " fq_out_ref = serializer.read(\"fq_out\", currentSavepoint)\n", "\n", - " print('Current savepoint = ', currentSavepoint)\n", - " print('SUM(Qin_out) = ', sum(sum(sum(Qin_out))))\n", - " print(np.allclose(Qin_out,Qin_out_ref))\n", - " print(np.allclose(fq_out,fq_out_ref))" + " print(\"Current savepoint = \", currentSavepoint)\n", + " print(\"SUM(Qin_out) = \", sum(sum(sum(Qin_out))))\n", + " print(np.allclose(Qin_out, Qin_out_ref))\n", + " print(np.allclose(fq_out, fq_out_ref))" ] } ], diff --git a/examples/NDSL/01_gt4py_basics.ipynb b/examples/NDSL/01_gt4py_basics.ipynb index fa3797d7..98b43922 100644 --- a/examples/NDSL/01_gt4py_basics.ipynb +++ b/examples/NDSL/01_gt4py_basics.ipynb @@ -4,34 +4,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# **GT4Py Tutorial : Stencil Basics**\n", + "# GT4Py Tutorial: Stencil Basics\n", "\n", - "## **Introduction**\n", + "## Introduction\n", "\n", "This notebook will show how to create a simple GT4Py stencil that copies data from one variable to another.\n", "\n", - "### **Notebook Requirements**\n", + "### Notebook Requirements\n", "\n", - "- Python v3.11.x to v3.12.x\n", + "- Python v3.11.x\n", "- [NOAA/NASA Domain Specific Language Middleware](https://github.com/NOAA-GFDL/NDSL)\n", - "- `ipykernel==6.1.0`\n", - "- [`ipython_genutils`](https://pypi.org/project/ipython_genutils/)\n", "\n", - "### **Quick GT4Py (Cartesian version) Overview**\n", + "### Quick GT4Py (Cartesian version) Overview\n", "\n", - "GT4Py is a Domain Specific Language (DSL) in Python that enables a developer to write stencil computations. Compared to simply running under Python, GT4Py achieves performance when the Python code is translated and compiled into a lower level language such as C++ and CUDA, which enables the codebase to execute on a multitude of architectures. In this notebook, we will cover the basics of creating GT4Py stencils and demonstrate several intracies of the DSL. Additional information about GT4Py can be found at the [GT4Py site](https://gridtools.github.io/gt4py/latest/index.html). One small note is that this tutorial covers and uses the Cartesian version of GT4Py and not the unstructured version.\n", + "GT4Py is a Domain Specific Language (DSL) in Python that enables a developer to write stencil computations. Compared to simply running under Python, GT4Py achieves performance when the Python code is transformed and compiled into a lower level language such as C++ and CUDA. This code transformation capability also enables GT4Py-based code to execute on multiple architectures. In this notebook, we will cover the basics of creating GT4Py stencils and demonstrate several intricacies of the DSL. Additional information about GT4Py can be found in the [GT4Py documentation](https://gridtools.github.io/gt4py/latest/index.html). One small note is that this tutorial covers and uses the Cartesian version of GT4Py and not the unstructured (\"next\") version.\n", "\n", - "### **GT4Py Parallel/Execution Model**\n", + "### GT4Py Parallel/Execution Model\n", "\n", - "Within a 3-dimensional domain, GT4Py considers computations in two parts. If we assume an `(I,J,K)` coordinate system as a reference, GT4Py separates computations in the Horizontal (`IJ`) spatial plane and Vertical (`K`) spatial interval. In the Horizontal spatial plane, computations are implicitly executed in parallel, which also means that there is no assumed calculation order within the plane. In the Vertical spatial interval, comptuations are specified by an iteration policy that will be discussed later through examples.\n", + "Within a 3-dimensional domain, GT4Py considers computations in two parts. If we assume an `(I,J,K)` coordinate system as the reference, GT4Py separates computations in the horizontal (`IJ`) spatial plane and vertical (`K`) spatial dimension. In the horizontal spatial plane, computations are implicitly executed in parallel, which also means that there is no assumed calculation order within the plane. In the vertical spatial dimension, computations are specified by an iteration policy that will be discussed later through examples.\n", "\n", - "Another quick note is that the computations are executed sequentially in the order they appear in code.\n", + "Another quick note is that the stencil computations are executed sequentially in the order they appear in code.\n", "\n", - "## **Tutorial**\n", + "## Tutorial\n", "\n", - "### **Copy Stencil example**\n", + "### Copy Stencil example\n", "\n", - "To demonstrate how to implement a GT4Py stencil, we'll step through an example that copies the values of one array into another array. First, we import several packages. " + "To demonstrate how to implement a GT4Py stencil, we'll step through a code example that copies the values of one array into another array. First, we import several packages." ] }, { @@ -50,7 +48,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As we walk through the example, we'll highlight different terms and such from the imported packages. Let's first define, in GT4Py terms, two arrays of size 5 by 5 by 2 (dimensionally `I` by `J` by `K`). These arrays are defined using a `Quantity` object, an NDSL data container for physical quantities. More detailed information about the `Quantity` object and its arguments can be found from the [`Quantity` docstring](https://github.com/NOAA-GFDL/NDSL/blob/develop/ndsl/quantity.py#L270). To make debugging easier, the `numpy` backend will be used." + "As we walk through the example, we'll highlight different terms and such from the imported packages. Let's first define, in GT4Py terms, two arrays of size 5 by 5 by 2 (dimensionally `I` by `J` by `K`). These arrays are defined using a `Quantity` object, an NDSL data container for physical quantities. More detailed information about the `Quantity` object and its arguments can be found from the [`Quantity` documentation](https://noaa-gfdl.github.io/NDSL/docstrings/quantity/quantity/#quantity.quantity.Quantity.__init__). To make debugging easier, we'll use the `numpy` backend." ] }, { @@ -59,7 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "backend = 'numpy'\n", + "backend = \"numpy\"\n", "\n", "nx = 5\n", "ny = 5\n", @@ -67,25 +65,22 @@ "\n", "shape = (nx, ny, nz)\n", "\n", - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", - "arr = np.indices(shape,dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", "\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend)" + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We will next create a simple GT4Py stencil that copies values from one input to another. A stencil will look like a Python subroutine or function except that it uses specific GT4Py functionalities." + "Next, we will next create a simple GT4Py stencil that copies values from one argument to another. A stencil will look like a Python function except that it uses specific GT4Py functionalities." ] }, { @@ -95,8 +90,7 @@ "outputs": [], "source": [ "@stencil(backend=backend)\n", - "def copy_stencil(input_field: FloatField,\n", - " output_field: FloatField):\n", + "def copy_stencil(input_field: FloatField, output_field: FloatField) -> None:\n", " with computation(PARALLEL), interval(...):\n", " output_field = input_field" ] @@ -105,18 +99,44 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As mentioned before, GT4Py (cartesian version) was designed for stencil-based computation. Since stencil calculations generally are localized computations, GT4Py stencils are written using variables and the variable's relative location if it's an array. If there are no indices in brackets next to a GT4Py type (such as `FloatField`), it's implied to be at the [0] (for 1-dimension), [0,0] (for 2-dimension), or [0,0,0] (for 3-dimension) location. For the simple example `copy_stencil`, the value of `input_field` simply gets copied to `output_field` at every point in the domain of interest.\n", + "As mentioned before, GT4Py (the cartesian version) was designed for stencil-based computation. Since stencil calculations generally are localized computations, GT4Py stencils are written using variables and the variable's **relative location** if the variable is an array. For example, a stencil-based [Laplacian](https://en.wikipedia.org/wiki/Discrete_Laplace_operator#Finite_differences) calculation can be implemented as follows." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def laplacian(input_field: FloatField, output_field: FloatField):\n", + " with computation(PARALLEL), interval(...):\n", + " output_field = (\n", + " -4.0 * input_field[0, 0, 0]\n", + " + input_field[1, 0, 0]\n", + " + input_field[-1, 0, 0]\n", + " + input_field[0, 1, 0]\n", + " + input_field[0, -1, 0]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the `laplacian` stencil, every value of `output_field` at a particular location is a function of `input_field` at that same location (`input_field[0,0,0]`), the adjacent `input_field` values in the `I`-direction (`input_field[1,0,0]` and `input_field[-1,0,0]`), and the adjacent `input_field` values in the `J`-direction (`input_field[0,1,0]` and `input_field[0,-1,0]`)\n", "\n", - "We see that this stencil does not contain any explicit loops. As mentioned above in the notebook, GT4Py has a particular computation policy that implicitly executes in parallel within an `IJ` plane and is user defined in the `K` interval. This execution policy in the `K` interval is dictated by the `computation` and `interval` keywords. \n", + "If there are no indices in brackets next to a GT4Py type (such as `FloatField`), it's implied to be at the `[0]` (for 1-dimension), `[0,0]` (for 2-dimension), or `[0,0,0]` (for 3-dimension) location. Note again that these locations are **relative**. In the `laplacian` stencil, `output_field` has an implied `[0,0,0]` location, and `input_field[0,0,0]` could have been written simply as `input_field`.\n", "\n", - "- `with computation(PARALLEL)` means that there's no order preference to executing the `K` interval. This also means that the `K` interval can be computed in parallel to potentially gain performance if computational resources are available.\n", + "We see that both, `copy_stencil` and the `laplacian` stencil, do not contain any explicit loops. As mentioned above, GT4Py has a particular computation policy that implicitly executes in parallel within the `IJ` plane and is user defined in the `K` dimension. This execution policy in the `K` dimension is dictated by the `computation` and `interval` keywords.\n", + "- `with computation(PARALLEL)` means that there's no order preference to executing the `K` dimension. This effectively means that the `K` dimension can be computed in parallel to potentially gain performance if computational resources are available.\n", "\n", - "- `interval(...)` means that the entire `K` interval is executed. Instead of `(...)`, more specific intervals can be specified using a tuple of two integers. For example... \n", + "- `interval(...)` means that the \"entire\" `K` dimension range is executed (as dictated by the keywords `origin` and `domain`, which will be covered [later in the notebook](#setting-domain-subsets-in-a-stencil-call)). Instead of `(...)`, more specific intervals can be specified using a tuple of two integers. For example...\n", "\n", - " - `interval(0,2)` : The interval `K` = 0 to 1 is executed.\n", - " - `interval(0,-1)` : The interval `K` = 0 to N-2 (where N is the size of `K`) is executed.\n", + " - `interval(0, 2)` : The interval `K` = 0 to 1 is executed.\n", + " - `interval(0, -1)` : The interval `K` = 0 to N-2 (where N is the size of `K`) is executed.\n", + " - Note: The actual intervals for the above two `interval` are also dictated by [`origin` and `domain`](#setting-domain-subsets-in-a-stencil-call)\n", "\n", - "The decorator `@stencil(backend=backend)` (Note: `stencil` comes from the package `ndsl.dsl.gt4py`) converts `copy_stencil` to use the specified `backend` to \"compile\" the stencil. `stencil` can also be a function call to create a stencil object." + "The decorator `@stencil(backend=backend)` (Note: `stencil` comes from the package `ndsl.dsl.gt4py`) converts `copy_stencil` to use the specified `backend` to \"compile\" the stencil. `stencil` can also be a function call to create a stencil object." ] }, { @@ -132,9 +152,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that the input and output parameters to `copy_stencil` are of type `FloatField`, which can essentially be thought of as a 3-dimensional NumPy array of `float` types.\n", + "Note that the input and output parameters to `copy_stencil` and `laplacian` are of type `FloatField`, which can essentially be thought of as a 3-dimensional NumPy array of `float` types.\n", "\n", - "`plot_field_at_kN` plots the values within the `IJ` plane at `K = 0` if no integer is specified or at `K` equal to the integer that is specified as an argument. As we can see in the plots below, `copy_stencil` copies the values from `qty_in` into `qty_out`." + "`plot_field_at_kN` plots the values within the `IJ` plane at `K = 0` if no integer is specified or at `K` equal to the integer that is specified as an argument. As we can see in the plots below, `copy_stencil` copies the values from `qty_in` into `qty_out`." ] }, { @@ -159,15 +179,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Setting domain subsets in a stencil call**\n", + "### Setting domain subsets in a stencil call\n", + "\n", + "GT4Py also allows a subset to be specified from a stencil call and executed in a fashion similar to using `interval(...)` in the `K` dimension. This is done by setting the stencil call's `origin` and `domain` argument.\n", "\n", - "GT4Py also allows a subset to be specified from a stencil call and executed in a fashion similar to using `interval(...)` in the K interval. This is done by setting the stencil call's `origin` and `domain` argument.\n", + "- `origin`: This specifies the \"starting\" coordinate to perform computations.\n", "\n", - "- `origin` : This specifies the \"starting\" coordinate to perform computations. \n", + "- `domain`: This specifies the range of the stencil computation based on `origin` as the \"starting\" coordinate.\n", "\n", - "- `domain` : This specifies the range of the stencil computation based on `origin` as the \"starting\" coordinate (Note: May need to check whether this affects `interval()`)\n", + "- Note: May need to check whether `domain` and `origin` affect `interval()`\n", "\n", - "If these two parameters are not set, the stencil call by default will iterate over the entire input domain. The following demonstrates the effect of specifying different `origin` and `domain`." + "If these two parameters are not set, the stencil call by default will iterate over the entire input domain. The following demonstrates the effect of specifying different `origin` and `domain`." ] }, { @@ -176,11 +198,9 @@ "metadata": {}, "outputs": [], "source": [ - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "print(\"Plotting values of qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", @@ -191,11 +211,9 @@ "print(\"Plotting qty_out at K = 0 based on `copy_stencil` with origin=(1, 0, 0)\")\n", "qty_out.plot_k_level(0)\n", "\n", - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "print(\"Resetting qty_out to zero...\")\n", "print(\"Plotting values of qty_out at K = 0\")\n", @@ -205,11 +223,9 @@ "print(\"Plotting qty_out at K = 0 based on `copy_stencil` with origin=(0, 1, 0)\")\n", "qty_out.plot_k_level(0)\n", "\n", - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "print(\"Resetting qty_out to zero...\")\n", "print(\"Plotting values of qty_out at K = 0\")\n", @@ -221,11 +237,9 @@ "print(\"Plotting qty_out at K = 1 based on `copy_stencil` with origin=(0, 0, 1)\")\n", "qty_out.plot_k_level(1)\n", "\n", - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "print(\"Resetting qty_out to zero...\")\n", "print(\"Plotting values of qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", @@ -236,17 +250,17 @@ "print(\"Plotting qty_out at K = 0 based on `copy_stencil` with domain=(2, 2, nz)\")\n", "qty_out.plot_k_level(0)\n", "\n", - "qty_out = Quantity(data=np.zeros([nx, ny, nz]),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "print(\"Resetting qty_out to zero...\")\n", "print(\"Plotting values of qty_out at K = 0\")\n", "qty_out.plot_k_level(0)\n", "print(\"Executing `copy_stencil` with origin=(2, 2, 0), domain=(2, 2, nz)\")\n", "copy_stencil(qty_in, qty_out, origin=(2, 2, 0), domain=(2, 2, nz))\n", - "print(\"Plotting qty_out at K = 0 based on `copy_stencil` with origin=(2, 2, 0), domain=(2, 2, nz)\")\n", + "print(\n", + " \"Plotting qty_out at K = 0 based on `copy_stencil` with origin=(2, 2, 0), domain=(2, 2, nz)\"\n", + ")\n", "qty_out.plot_k_level(0)" ] }, @@ -254,13 +268,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **`FORWARD` and `BACKWARD` `computation` keywords and Offset Indexing within a stencil call**\n", + "### `FORWARD` and `BACKWARD` `computation` keywords and Offset Indexing within a stencil call\n", "\n", - "Besides `PARALLEL`, the developer can specify `FORWARD` or `BACKWARD` as the iteration policy in `K` for a stencil. Essentially, the `FORWARD` policy has `K` iterating consecutively starting from the lowest vertical index to the highest, while the `BACKWARD` policy performs the reverse.\n", + "Besides `PARALLEL`, the developer can specify `FORWARD` or `BACKWARD` as the iteration policy in `K` for a stencil. Essentially, the `FORWARD` policy has `K` iterating consecutively starting from the lowest vertical index to the highest, while the `BACKWARD` policy performs the reverse.\n", "\n", - "An array-based stencil variable can also have an integer dimensional offset if the array variable is on the right hand side of the `=` for the computation. When a computation is performed at a particular point, an offset variable's coordinate is based on that particular point plus (or minus) the offset in the offset dimension.\n", + "An array-based stencil variable can also have an integer dimensional offset if the array variable is on the right hand side of the `=` for the computation. When a computation is performed at a particular point, an offset variable's coordinate is based on that particular point plus (or minus) the offset in the offset dimension.\n", "\n", - "The following examples demonstrate the use of these two iteration policies and also offset indexing in the `K` dimension. Note that offsets can also be applied to the `I` or `J` dimension." + "The following examples demonstrate the use of these two iteration policies and also offset indexing in the `K` dimension. Note that offsets can also be applied to the `I` or `J` dimension." ] }, { @@ -275,22 +289,18 @@ "ny = 5\n", "nz = 5\n", "nhalo = 1\n", - "backend=\"numpy\"\n", + "backend = \"numpy\"\n", "\n", "shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n", "\n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", - "arr = np.indices(shape,dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n", "\n", "print(\"Plotting values of qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", @@ -299,36 +309,46 @@ "print(\"Plotting values of qty_in at K = 2\")\n", "qty_in.plot_k_level(2)\n", "\n", + "\n", "@stencil(backend=backend)\n", "def mult_upward(qty_in: FloatField, qty_out: FloatField):\n", " with computation(FORWARD), interval(...):\n", - " qty_out = qty_in[0,0,-1] * 2.0\n", + " qty_out = qty_in[0, 0, -1] * 2.0\n", + "\n", "\n", "print(\"Executing 'mult_upward' with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\")\n", - "mult_upward(qty_in, qty_out, origin=(nhalo,nhalo,1), domain=(nx,ny,2))\n", - "print(\"Plotting values of qty_out at K = 0 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\")\n", + "mult_upward(qty_in, qty_out, origin=(nhalo, nhalo, 1), domain=(nx, ny, 2))\n", + "print(\n", + " \"Plotting values of qty_out at K = 0 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\"\n", + ")\n", "qty_out.plot_k_level(0)\n", - "print(\"Plotting values of qty_out at K = 1 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\")\n", + "print(\n", + " \"Plotting values of qty_out at K = 1 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\"\n", + ")\n", "qty_out.plot_k_level(1)\n", - "print(\"Plotting values of qty_out at K = 2 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\")\n", + "print(\n", + " \"Plotting values of qty_out at K = 2 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\"\n", + ")\n", "qty_out.plot_k_level(2)\n", - "print(\"Plotting values of qty_out at K = 3 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\")\n", + "print(\n", + " \"Plotting values of qty_out at K = 3 with origin=(nhalo, nhalo, 1), domain=(nx, ny, 2)\"\n", + ")\n", "qty_out.plot_k_level(3)\n", "\n", + "\n", "@stencil(backend=backend)\n", "def copy_downward(qty_in: FloatField, qty_out: FloatField):\n", " with computation(BACKWARD), interval(...):\n", - " qty_out = qty_in[0,0,1]\n", + " qty_out = qty_in[0, 0, 1]\n", + "\n", "\n", "print(\"Resetting qty_out to zeros\")\n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "print(\"Executing 'copy_downward' with origin=(1, 1, 0), domain=(nx, ny, nz-1)\")\n", - "copy_downward(qty_in, qty_out, origin=(1, 1, 0), domain=(nx, ny, nz-1))\n", + "copy_downward(qty_in, qty_out, origin=(1, 1, 0), domain=(nx, ny, nz - 1))\n", "print(\"***\")\n", "print(\"Plotting values of qty_out at K = 0\")\n", "qty_out.plot_k_level(0)\n", @@ -342,7 +362,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Regarding offsets, GT4Py does not allow offsets to variables in the left hand side of the `=`. Uncomment and execute the below code to see the error `Assignment to non-zero offsets is not supported.`." + "Regarding offsets, GT4Py does only allow offsets to variables in the left hand side of the `=` in the `K` dimension. Uncomment and execute the below code to see the error message \"Assignment to non-zero offsets is not supported in IJ\"." ] }, { @@ -361,9 +381,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Limits to offset : Cannot set offset outside of usable domain**\n", + "### Limits to offset : Cannot set offset outside of usable domain\n", "\n", - "Note that there are limits to the offsets that can be applied in the stencil. An error will result if the specified shift results attemps to read data that is not available or allocated. In the example below, a shift of -2 in the `J` axis will shift `field_in` out of its possible range in `J`." + "Note that there are limits to the offsets that can be applied in the stencil. An error will result if the specified shift results attempts to read data that is not available or allocated. In the example below, a shift of -2 in the `J` axis will shift `field_in` out of its possible range in `J`." ] }, { @@ -376,34 +396,34 @@ "ny = 5\n", "nz = 5\n", "nhalo = 1\n", - "backend=\"numpy\"\n", + "backend = \"numpy\"\n", "\n", "shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n", "\n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", + "\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n", "\n", - "arr = np.indices(shape,dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", "\n", "@stencil(backend=backend)\n", "def copy_stencil_offset(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", - " field_out = field_in[0,-2,0]\n", + " field_out = field_in[0, -2, 0]\n", + "\n", "\n", "print(\"Executing 'copy_stencil' with origin=(nhalo, nhalo, 0), domain=(nx, ny, nz)\")\n", "copy_stencil(qty_in, qty_out, origin=(nhalo, nhalo, 0), domain=(nx, ny, nz))\n", "print(\"Executing 'copy_stencil' where qty_out is copied back to qty_in\")\n", "copy_stencil(qty_out, qty_in)\n", "qty_in.plot_k_level(0)\n", - "print(\"Executing 'copy_stencil_offset' where origin=(nhalo, nhalo, 0), domain=(nx, ny, nz)\")\n", + "print(\n", + " \"Executing 'copy_stencil_offset' where origin=(nhalo, nhalo, 0), domain=(nx, ny, nz)\"\n", + ")\n", "copy_stencil_offset(qty_in, qty_out, origin=(nhalo, nhalo, 0), domain=(nx, ny, nz))\n", "qty_out.plot_k_level(0)" ] @@ -412,9 +432,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **`if/else` statements**\n", + "### `if/else` statements\n", "\n", - "GT4Py allows for `if/else` statements to exist within a stencil. The following simple example shows a stencil `stencil_if_zero` modifing values of `in_out_field` depending on its initial value." + "GT4Py allows for `if/else` statements to exist within a stencil. The following simple example shows a stencil `stencil_if_zero` modifying values of `in_out_field` depending on its initial value." ] }, { @@ -423,18 +443,14 @@ "metadata": {}, "outputs": [], "source": [ - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", - "\n", - "arr = np.indices(shape,dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", + "\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n", "\n", "print(\"Plotting values of qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", @@ -442,11 +458,16 @@ "qty_out.plot_k_level(0)\n", "print(\"Running copy_stencil with origin=(nhalo, nhalo, 0), domain=(nx, ny, 5)\")\n", "copy_stencil(qty_in, qty_out, origin=(nhalo, nhalo, 0), domain=(nx, ny, 5))\n", - "print(\"Plotting values of qty_out at K = 0 based on running copy_stencil with origin=(nhalo, nhalo, 0), domain=(nx, ny, 5)\")\n", + "print(\n", + " \"Plotting values of qty_out at K = 0 based on running copy_stencil with origin=(nhalo, nhalo, 0), domain=(nx, ny, 5)\"\n", + ")\n", "qty_out.plot_k_level(0)\n", - "print(\"Plotting values of qty_out at K = 1 based on running copy_stencil with origin=(nhalo, nhalo, 0), domain=(nx, ny, 5)\")\n", + "print(\n", + " \"Plotting values of qty_out at K = 1 based on running copy_stencil with origin=(nhalo, nhalo, 0), domain=(nx, ny, 5)\"\n", + ")\n", "qty_out.plot_k_level(1)\n", "\n", + "\n", "@stencil(backend=backend)\n", "def stencil_if_zero(in_out_field: FloatField):\n", " with computation(PARALLEL), interval(...):\n", @@ -454,6 +475,8 @@ " in_out_field = 30\n", " else:\n", " in_out_field = 10\n", + "\n", + "\n", "print(\"Running 'stencil_if_zero' on qty_out\")\n", "stencil_if_zero(qty_out)\n", "print(\"Plotting values of qty_out at K = 0 based on running stencil_if_zero\")\n", @@ -466,9 +489,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Function calls**\n", + "### Function calls\n", "\n", - "GT4Py also has the capability to create functions in order to better organize code. The main difference between a GT4Py function call and a GT4Py stencil is that a function does not (and cannot) contain the keywords `computation` and `interval`. However, array index referencing within a GT4py function is the same as in a GT4Py stencil.\n", + "GT4Py also has the capability to create functions in order to better organize code. The main difference between a GT4Py function call and a GT4Py stencil is that a function does not (and cannot) contain the keywords `computation` and `interval`. However, array index referencing within a GT4py function is the same as in a GT4Py stencil.\n", "\n", "GT4Py functions can be created by using the decorator `function` (Note: `function` originates from the package `ndsl.dsl.gt4py`)." ] @@ -481,35 +504,34 @@ "source": [ "from ndsl.dsl.gt4py import function\n", "\n", + "\n", "@function\n", "def plus_one(field: FloatField):\n", - " return field[0, 0, 0] + 1\n", + " return field[0, 0, 0] + 1\n", + "\n", "\n", "@stencil(backend=backend)\n", "def field_plus_one(source: FloatField, target: FloatField):\n", - " with computation(PARALLEL), interval(...):\n", - " target = plus_one(source)\n", + " with computation(PARALLEL), interval(...):\n", + " target = plus_one(source)\n", + "\n", "\n", "nx = 5\n", "ny = 5\n", "nz = 5\n", "nhalo = 1\n", - "backend=\"numpy\"\n", + "backend = \"numpy\"\n", "\n", "shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n", "\n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", - "\n", - "arr = np.indices(shape, dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", + "\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n", "\n", "print(\"Plotting values of qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", diff --git a/examples/NDSL/02_NDSL_basics.ipynb b/examples/NDSL/02_NDSL_basics.ipynb index 1c572ce1..b9d53dca 100644 --- a/examples/NDSL/02_NDSL_basics.ipynb +++ b/examples/NDSL/02_NDSL_basics.ipynb @@ -4,19 +4,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# **NDSL Basics** #\n", + "# NDSL Basics\n", "\n", - "### **Introduction**\n", - "After establishing the basics of using GT4Py, we'll take a look at developing an object-oriented coding approach with the NDSL middleware. Much of the object-oriented work comes from the development of [Pace](https://github.com/NOAA-GFDL/pace), the implementation of the FV3GFS / SHiELD atmospheric model using GT4Py and [DaCe](https://github.com/spcl/dace). The `StencilFactory` object will be introduced and demoed." + "### Introduction\n", + "After establishing the basics of using GT4Py, we will develop an object-oriented coding approach with the NDSL middleware. Much of the object-oriented work comes from the development of [Pace](https://github.com/NOAA-GFDL/pace), the implementation of the FV3GFS / SHiELD atmospheric model using GT4Py and [DaCe](https://github.com/spcl/dace). The `StencilFactory` object will be introduced and demoed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### **Creating the `StencilFactory` object**\n", + "### Creating the `StencilFactory` object\n", "\n", - "The `StencilFactory` object enables the sharing of stencil properties across multiple stencils as well as \"build and execute\" the stencil. To help ease the introduction, the [`boilerplate` module](./boilerplate.py) contains a function `get_one_tile_factory` that takes the domain size, halo size, and backend of interest and returns a `StencilFactory` object. For more details about the objects needed to create the `StencilFactory`, the reader can view the [`get_one_tile_factory`](./boilerplate.py#get_one_tile_factory) function." + "The `StencilFactory` object enables the sharing of stencil properties across multiple stencils as well as \"build and execute\" the stencil. To help ease the introduction, the [`boilerplate` module](./basic_boilerplate.py) contains a function `get_one_tile_factory` that takes the domain size, halo size, and backend of interest and returns a `StencilFactory` object. For more details about the objects needed to create the `StencilFactory`, the reader can view the [`get_one_tile_factory`](./basic_boilerplate.py#get_one_tile_factory) function." ] }, { @@ -40,13 +40,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Creating the Copy stencil**\n", + "### Creating the Copy stencil\n", "\n", - "The `NDSL` and `gt4py` module contain key terms that will be used to create the stencil. Many terms are covered in the [GT4Py basic tutorial](./01_gt4py_basics.ipynb) notebook, but we'll briefly recap.\n", + "The `NDSL` and `gt4py` module contain key terms that will be used to create the stencil. Many terms are covered in the [GT4Py basic tutorial](./01_gt4py_basics.ipynb) notebook, but we'll briefly recap.\n", "\n", "- `FloatField` : This type can generally can be thought of as a `gt4py` 3-dimensional `numpy` array of floating point values.\n", "\n", - "- `computation(PARALLEL)` : This keyword combination means that there is no assumed order to perform calculations in the `K` (3rd) dimension of a `gt4py` storage. `PARALLEL` can be replaced by `FORWARD` or `BACKWARD` for serialized calculations in the `K` dimension.\n", + "- `computation(PARALLEL)` : This keyword combination means that there is no assumed order to perform calculations in the `K` (3rd) dimension of a `gt4py` storage. `PARALLEL` can be replaced by `FORWARD` or `BACKWARD` for serialized calculations in the `K` dimension.\n", "\n", "- `interval(...)` : This keyword specifies the range of computation in the `K` dimension.\n", "\n", @@ -62,6 +62,7 @@ "from ndsl.dsl.gt4py import PARALLEL, computation, interval\n", "from ndsl.dsl.typing import FloatField\n", "\n", + "\n", "def copy_field_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", " field_out = field_in" @@ -71,16 +72,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that a decorator does not surround this stencil as shown before in the [basic tutorial](./01_gt4py_basics.ipynb). Instead, we'll use the `StencilFactory` to \"initiate\" the stencil." + "Note that a decorator does not surround this stencil as shown before in the [basic tutorial](./01_gt4py_basics.ipynb). Instead, we'll use the `StencilFactory` to \"initiate\" the stencil." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### **Creating a class that performs a stencil computation**\n", + "### Creating a class that performs a stencil computation\n", "\n", - "Using the `StencilFactory` object created earlier, the code will now create a class `CopyField` that takes `copy_field_stencil` and defines the computation domain from the parameters `origin` and `domain` within `__init__`. `origin` indicates the \"starting\" point of the stencil calculation, and `domain` indicates the extent of the stencil calculation in the three dimensions. Note that when creating `stencil_factory`, a 6 by 6 by 1 sized domain surrounded with a halo layer of size 1 was defined. Thus, whenever a `CopyField` object is created, it will perform calculations within the 6 by 6 by 1 domain (specified by `domain=grid_indexing.domain_compute()`), and the `origin` will start at the `[0,0,0]` location of the 6 by 6 by 1 grid (specified by `origin=grid_indexing.origin_compute()`)." + "Using the `StencilFactory` object created earlier, the code will now create a class `CopyField` that takes `copy_field_stencil` and defines the computation domain from the parameters `origin` and `domain` within `__init__`. `origin` indicates the \"starting\" point of the stencil calculation, and `domain` indicates the extent of the stencil calculation in the three dimensions. Note that when creating `stencil_factory`, a 6 by 6 by 1 sized domain surrounded with a halo layer of size 1 was defined. Thus, whenever a `CopyField` object is created, it will perform calculations within the 6 by 6 by 1 domain (specified by `domain=grid_indexing.domain_compute()`), and the `origin` will start at the `[0,0,0]` location of the 6 by 6 by 1 grid (specified by `origin=grid_indexing.origin_compute()`)." ] }, { @@ -93,19 +94,19 @@ " def __init__(self, stencil_factory: StencilFactory):\n", " grid_indexing = stencil_factory.grid_indexing\n", " self._copy_field = stencil_factory.from_origin_domain(\n", - " copy_field_stencil, # <-- gt4py stencil function wrapped into NDSL\n", + " copy_field_stencil, # <-- gt4py stencil function wrapped into NDSL\n", " origin=grid_indexing.origin_compute(),\n", " domain=grid_indexing.domain_compute(),\n", " )\n", "\n", - " def __call__( # <-- Runtime path\n", + " def __call__( # <-- Runtime path\n", " self,\n", " field_in: FloatField,\n", " field_out: FloatField,\n", " ):\n", " self._copy_field(field_in, field_out)\n", - " \n", - " \n", + "\n", + "\n", "copy_field = CopyField(stencil_factory)" ] }, @@ -113,9 +114,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Allocating Data in `NDSL`**\n", + "### Allocating Data in `NDSL`\n", "\n", - "The next code section will create arrays using `Quantity`. For more information about `Quantity`, see the [GT4Py Basic tutorial](./01_gt4py_basics.ipynb#Copy_Stencil_example)." + "The next code section will create arrays using `Quantity`. For more information about `Quantity`, see the [GT4Py Basic tutorial](./01_gt4py_basics.ipynb#Copy_Stencil_example)." ] }, { @@ -131,21 +132,16 @@ "size = (nx + 2 * nhalo) * (ny + 2 * nhalo) * nz\n", "shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n", "\n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", - "\n", - "\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "\n", - "arr = np.indices(shape,dtype=float).sum(axis=0) # Value of each entry is sum of the I and J index at each point\n", + "arr = np.indices(shape, dtype=float).sum(\n", + " axis=0\n", + ") # Value of each entry is sum of the I and J index at each point\n", "\n", - "qty_in = Quantity(data=arr,\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend)\n", + "qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n", "\n", "print(\"Plotting qty_in at K = 0\")\n", "qty_in.plot_k_level(0)\n", @@ -157,7 +153,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Calling `copy_field` stencil**\n", + "### Calling `copy_field` stencil\n", "\n", "The code will call `copy_field` to execute `copy_field_stencil` using the previously defined `Quantity` data containers and plot the result at `K = 0`." ] @@ -178,20 +174,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "From the plot, we see that the copy is only applied to the inner 6 by 6 area and not the entire domain. The stencil in this case only applies in this \"domain\" and not the \"halo\" region surrounding the domain." + "From the plot, we see that the copy is only applied to the inner 6 by 6 area and not the entire domain. The stencil in this case only applies in this \"domain\" and not the \"halo\" region surrounding the domain." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### **Applying a J offset**\n", + "### Applying a J offset\n", "\n", - "The next example will create a stencil that takes a `Quantity` as an input, shift the input by 1 in the `-J` direction, and write it to an output `Quantity`. This stencil is defined in `copy_field_offset_stencil`.\n", + "The next example will create a stencil that takes a `Quantity` as an input, shift the input by 1 in the `-J` direction, and write it to an output `Quantity`. This stencil is defined in `copy_field_offset_stencil`.\n", "\n", - "Note that in `copy_field_offset_stencil`, the shift in the `J` dimension is performed by referencing the `J` object from `ndsl.dsl.gt4py` for simplicity. This reference will apply the shift in `J` to the entire input domain. Another way to perform the shift without referencing the `J` object is to write `[0,-1,0]` (assuming that the variable being modified is 3-dimensional) instead of `[J-1]`.\n", + "Note that in `copy_field_offset_stencil`, the shift in the `J` dimension is performed by referencing the `J` object from `ndsl.dsl.gt4py` for simplicity. This reference will apply the shift in `J` to the entire input domain. Another way to perform the shift without referencing the `J` object is to write `[0,-1,0]` (assuming that the variable being modified is 3-dimensional) instead of `[J-1]`.\n", "\n", - "With the stencil in place, a class `CopyFieldOffset` is defined using the `StencilFactory` object and `copy_field_offset_stencil`. The class is instantiated and demonstrated to shift `qty_in` by 1 in the J-dimension and write to `qty_out`." + "With the stencil in place, a class `CopyFieldOffset` is defined using the `StencilFactory` object and `copy_field_offset_stencil`. The class is instantiated and demonstrated to shift `qty_in` by 1 in the J-dimension and write to `qty_out`." ] }, { @@ -202,10 +198,12 @@ "source": [ "from ndsl.dsl.gt4py import J\n", "\n", + "\n", "def copy_field_offset_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", - " field_out = field_in[J-1]\n", - " \n", + " field_out = field_in[J - 1]\n", + "\n", + "\n", "class CopyFieldOffset:\n", " def __init__(self, stencil_factory: StencilFactory):\n", " grid_indexing = stencil_factory.grid_indexing\n", @@ -221,14 +219,13 @@ " field_out: FloatField,\n", " ):\n", " self._copy_field_offset(field_in, field_out)\n", - " \n", + "\n", + "\n", "copy_field_offset = CopyFieldOffset(stencil_factory)\n", - " \n", - "qty_out = Quantity(data=np.zeros(shape),\n", - " dims=[\"I\", \"J\", \"K\"],\n", - " units=\"m\",\n", - " gt4py_backend=backend\n", - " )\n", + "\n", + "qty_out = Quantity(\n", + " data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n", + ")\n", "\n", "print(\"Initialize qty_out to zeros\")" ] @@ -249,7 +246,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Limits to offset : Cannot set offset outside of usable domain**\n", + "### Limits to offset : Cannot set offset outside of usable domain\n", "\n", "Note that when the copy offset by `-1` in the `j`-direction is performed, the 'halo' region at `J = 8` is copied over due to the `J` shift. This means that there are limits to the shift amount since choosing a large shift amount may result in accessing a data region that does not exist. The following example shows this by trying to perform a shift by `-2` in the `j`-direction." ] @@ -262,8 +259,9 @@ "source": [ "def copy_field_offset_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", - " field_out = field_in[J-2]\n", - " \n", + " field_out = field_in[J - 2]\n", + "\n", + "\n", "class CopyFieldOffset:\n", " def __init__(self, stencil_factory: StencilFactory):\n", " grid_indexing = stencil_factory.grid_indexing\n", @@ -279,7 +277,8 @@ " field_out: FloatField,\n", " ):\n", " self._copy_field_offset(field_in, field_out)\n", - " \n", + "\n", + "\n", "copy_field_offset = CopyFieldOffset(stencil_factory)\n", "\n", "copy_field_offset(qty_in, qty_out)" @@ -289,9 +288,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Example demonstrating error when writing to offset outputs**\n", + "### Example demonstrating error when writing to offset outputs\n", "\n", - "While offsets can be applied to all input `Quantity` variables in a stencil, output `Quantity` variables cannot have such offsets. When an offset is applied to an output stencil calculation, the error `GTScriptSyntaxError: Assignment to non-zero offsets is not supported.` will be displayed." + "While offsets can be applied to all input `Quantity` variables in a stencil, output `Quantity` variables cannot have such offsets. When an offset is applied to an output stencil calculation, the error `GTScriptSyntaxError: Assignment to non-zero offsets is not supported in IJ." ] }, { @@ -302,8 +301,9 @@ "source": [ "def copy_field_offset_output_stencil(field_in: FloatField, field_out: FloatField):\n", " with computation(PARALLEL), interval(...):\n", - " field_out[0,1,0] = field_in\n", - " \n", + " field_out[0, 1, 0] = field_in\n", + "\n", + "\n", "class CopyFieldOffsetOutput:\n", " def __init__(self, stencil_factory: StencilFactory):\n", " grid_indexing = stencil_factory.grid_indexing\n", @@ -319,7 +319,8 @@ " field_out: FloatField,\n", " ):\n", " self._copy_field_offset_output(field_in, field_out)\n", - " \n", + "\n", + "\n", "copy_field_offset_output = CopyFieldOffsetOutput(stencil_factory)" ] } diff --git a/examples/NDSL/03_orchestration_basics.ipynb b/examples/NDSL/03_orchestration_basics.ipynb index 48f0d702..93f818d5 100644 --- a/examples/NDSL/03_orchestration_basics.ipynb +++ b/examples/NDSL/03_orchestration_basics.ipynb @@ -4,17 +4,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# **NDSL Orchestration Basics**\n", + "# NDSL Orchestration Basics\n", "\n", - "### **Introduction**\n", + "### Introduction\n", "\n", - "When writing code using NDSL, there will be moments where an algorithm or code pattern does not match the stencil paradigm, and shoehorning the algorithm into the paradigm increases development difficulty. For these moments, we have a capability called orchestration that enables developers to use regular Python for non-stencil algorithms alongside stencil-based code via [DaCe](https://github.com/spcl/dace). DaCe also will attempt to find optimizations before output C++ code.\n", + "When writing code using NDSL, there will be moments where an algorithm or code pattern does not match the stencil paradigm, and shoehorning the algorithm into the paradigm increases development difficulty. For these moments, we have a capability called orchestration that enables developers to use regular Python for non-stencil algorithms alongside stencil-based code via [DaCe](https://github.com/spcl/dace). DaCe also will attempt to find optimizations before output C++ code.\n", "\n", "In this example, we will explore how to orchestrate a codebase using NDSL.\n", "\n", "### **Orchestration Example**\n", "\n", - "We'll step through a simple example that will orchestrate a codebase containing stencils and Python code. First we'll import the necessary packages." + "We'll step through a simple example that will orchestrate a codebase containing stencils and Python code. First we'll import the necessary packages." ] }, { @@ -44,7 +44,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next we'll define a simple stencil that sums the values around a point and applies a weight factor to that sum. Note that unlike [previous](./01_gt4py_basics.ipynb#Copy_Stencil_example) examples, we are not using the `@stencil` decorator since this stencil will be referenced within a `StencilFactory` function call." + "Next we'll define a simple stencil that sums the values around a point and applies a weight factor to that sum. Note that unlike [previous](./01_gt4py_basics.ipynb#Copy_Stencil_example) examples, we are not using the `@stencil` decorator since this stencil will be referenced within a `StencilFactory` function call." ] }, { @@ -54,9 +54,9 @@ "outputs": [], "source": [ "def localsum_stencil(\n", - " field: FloatField, # type: ignore\n", + " field: FloatField, # type: ignore\n", " result: FloatField, # type: ignore\n", - " weight: Float, # type: ignore\n", + " weight: Float, # type: ignore\n", "):\n", " with computation(PARALLEL), interval(...):\n", " result = weight * (\n", @@ -68,7 +68,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We'll define an object that enables the orchestration and combines both stencils and regular Python codes. The orchestration occurs with the `orchestrate` call in the `__init__` definition. Within `__call__`, there's a combination of both stencil and regular python codes." + "We'll define an object that enables the orchestration and combines both stencils and regular Python codes. The orchestration occurs with the `orchestrate` call in the `__init__` definition. Within `__call__`, there's a combination of both stencil and regular python codes." ] }, { @@ -99,15 +99,15 @@ "\n", " def __call__(self, in_field: FloatField, out_result: FloatField) -> None:\n", " self._local_sum(in_field, out_result, 2.0) # GT4Py Stencil\n", - " tmp_field = out_result[:, :, :] + 2 # Regular Python code\n", - " self._local_sum(tmp_field, out_result, 2.0) # GT4Py Stencil" + " tmp_field = out_result[:, :, :] + 2 # Regular Python code\n", + " self._local_sum(tmp_field, out_result, 2.0) # GT4Py Stencil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we'll create a simple driver that defines the domain and halo size, specifies the backend (`dace:cpu` in order to use DaCe), and uses the boilerplate code to create a stencil and quantity factory objects. These objects help define the computational domain used for this particular example. After defining quantities (`in_field` and `out_field`) to hold the appropriate values and creating an object `local_sum` for our combined stencil/Python calculation, `local_sum` is called to perform the computation. In the output, we can see DaCe orchestrating the code. " + "Next, we'll create a simple driver that defines the domain and halo size, specifies the backend (`dace:cpu` in order to use DaCe), and uses the boilerplate code to create a stencil and quantity factory objects. These objects help define the computational domain used for this particular example. After defining quantities (`in_field` and `out_field`) to hold the appropriate values and creating an object `local_sum` for our combined stencil/Python calculation, `local_sum` is called to perform the computation. In the output, we can see DaCe orchestrating the code." ] }, { diff --git a/examples/mpi/zarr_monitor.py b/examples/mpi/zarr_monitor.py index 0c089af4..ea748f54 100644 --- a/examples/mpi/zarr_monitor.py +++ b/examples/mpi/zarr_monitor.py @@ -19,7 +19,7 @@ def get_example_state(time): - sizer = SubtileGridSizer(nx=48, ny=48, nz=70, n_halo=3, extra_dim_lengths={}) + sizer = SubtileGridSizer(nx=48, ny=48, nz=70, n_halo=3, data_dimensions={}) allocator = QuantityFactory(sizer, np) air_temperature = allocator.zeros([X_DIM, Y_DIM, Z_DIM], units="degK") air_temperature.view[:] = np.random.randn(*air_temperature.extent) @@ -39,7 +39,7 @@ def get_example_state(time): time = cftime.DatetimeJulian(2020, 1, 1) timestep = timedelta(hours=1) - for i in range(10): + for _i in range(10): state = get_example_state(time) monitor.store(state) time += timestep diff --git a/external/dace b/external/dace index 13402cbf..1033dfcf 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 13402cbfeeb6969cbd3915acfb7a30bdb543071b +Subproject commit 1033dfcf9d118856d82c6ee8d6f6cfacec662335 diff --git a/external/gt4py b/external/gt4py index 45324c88..e140f707 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 45324c88e57b5e8dfc974efa70fa2f2e5e10677f +Subproject commit e140f70731b723c519239e027237fb6281f4733b diff --git a/mkdocs.yml b/mkdocs.yml index 09916f21..1ca8682c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,23 +9,119 @@ theme: nav: - Home: index.md + - Quickstart: quickstart.md - User documentation: user/index.md - Porting: - - General Concepts: porting/index.md - - Testing Infrastructure: porting/translate/index.md + - General Concepts: porting/index.md + - Testing Infrastructure: porting/translate/index.md - Under the hood: - - Technical Documentation: dev/index.md - - DaCe: dev/dace.md - - GT4Py: dev/gt4py.md - + - Technical Documentation: dev/index.md + - DaCe: dev/dace.md + - GT4Py: dev/gt4py.md + - NDSL Docstrings: + - General Info: docstrings/index.md + - Top: + - "boilerplate": docstrings/top/boilerplate.md + - "buffer": docstrings/top/buffer.md + - "constants": docstrings/top/constants.md + - "exceptions": docstrings/top/exceptions.md + - "filesystem": docstrings/top/filesystem.md + - "io": docstrings/top/io.md + - "logging": docstrings/top/logging.md + - "namelist": docstrings/top/namelist.md + - "optional_imports": docstrings/top/optional_imports.md + - "types": docstrings/top/types.md + - "typing": docstrings/top/typing.md + - "units": docstrings/top/units.md + - "utils": docstrings/top/utils.md + - checkpointer: + - "base": docstrings/checkpointer/base.md + - "null": docstrings/checkpointer/null.md + - "snapshots": docstrings/checkpointer/snapshots.md + - "thresholds": docstrings/checkpointer/thresholds.md + - "validation": docstrings/checkpointer/validation.md + - comm: + - "_boundary_utils": docstrings/comm/_boundary_utils.md + - "boundary": docstrings/comm/boundary.md + - "caching_comm": docstrings/comm/caching_comm.md + - "comm_abc": docstrings/comm/comm_abc.md + - "communicator": docstrings/comm/communicator.md + - "decomposition": docstrings/comm/decomposition.md + - "local_comm": docstrings/comm/local_comm.md + - "mpi": docstrings/comm/mpi.md + - "null_comm": docstrings/comm/null_comm.md + - "partitioner": docstrings/comm/partitioner.md + - debug: + - "config": docstrings/debug/config.md + - "debugger": docstrings/debug/debugger.md + - "tooling": docstrings/debug/tooling.md + - dsl: + - "gt4py_utils": docstrings/dsl/gt4py_utils.md + - "stencil": docstrings/dsl/stencil.md + - "stencil_config": docstrings/dsl/stencil_config.md + - "typing": docstrings/dsl/typing.md + - grid: + - "eta": docstrings/grid/eta.md + - "generation": docstrings/grid/generation.md + - "geometry": docstrings/grid/geometry.md + - "global_setup": docstrings/grid/global_setup.md + - "gnomonic": docstrings/grid/gnomonic.md + - "helper": docstrings/grid/helper.md + - "mirror": docstrings/grid/mirror.md + - "stretch_transformation": docstrings/grid/stretch_transformation.md + - halo: + - "cuda_kernels": docstrings/halo/cuda_kernels.md + - "data_transformer": docstrings/halo/data_transformer.md + - "rotate": docstrings/halo/rotate.md + - "updater": docstrings/halo/updater.md + - initialization: + - "allocator": docstrings/initialization/allocator.md + - "grid_sizer": docstrings/initialization/grid_sizer.md + - "subtile_grid_sizer": docstrings/initialization/subtile_grid_sizer.md + - monitor: + - "convert": docstrings/monitor/convert.md + - "netcdf_monitor": docstrings/monitor/netcdf_monitor.md + - "protocol": docstrings/monitor/protocol.md + - "zarr_monitor": docstrings/monitor/zarr_monitor.md + - performance: + - "collector": docstrings/performance/collector.md + - "config": docstrings/performance/config.md + - "profiler": docstrings/performance/profiler.md + - "report": docstrings/performance/report.md + - "timer": docstrings/performance/timer.md + - "tools": docstrings/performance/tools.md + - quantity: + - "bounds": docstrings/quantity/bounds.md + - "field_bundle": docstrings/quantity/field_bundle.md + - "metadata": docstrings/quantity/metadata.md + - "quantity": docstrings/quantity/quantity.md + - "state": docstrings/quantity/state.md + - restart: + - "_legacy_restart": docstrings/restart/_legacy_restart.md + - "_properties": docstrings/restart/_properties.md + - stencils: + - "basic_operations": docstrings/stencils/basic_operations.md + - "c2l_ord": docstrings/stencils/c2l_ord.md + - "corners": docstrings/stencils/corners.md + - "tridiag": docstrings/stencils/tridiag.md + - testing: + - "comparison": docstrings/testing/comparison.md + - "dummy_comm": docstrings/testing/dummy_comm.md + - "perturbation": docstrings/testing/perturbation.md + - viz: + - "cube_sphere": docstrings/viz/cube_sphere.md markdown_extensions: # simple glossary file - abbr # support for colored notes / warnings / tips / examples - admonition + # support for "definition lists" (
) + - def_list # support for footnotes - footnotes + # support for emojis + - pymdownx.emoji # support for syntax highlighting - pymdownx.highlight: anchor_linenums: true @@ -39,14 +135,20 @@ markdown_extensions: - pymdownx.superfences: custom_fences: # support for mermaid graphs - - name: mermaid - class: mermaid - format: python/name:pymdownx.superfences.fence_code_format + - name: mermaid + class: mermaid + format: python/name:pymdownx.superfences.fence_code_format # image inclusion plugins: # add search box to the header, configuration in theme - search + - mkdocstrings: + handlers: + python: + paths: [./ndsl] # Adjust this path to where your Python modules are + options: + show_source: false watch: # reload when the glossary file is updated diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 2cac006f..00000000 --- a/mypy.ini +++ /dev/null @@ -1,55 +0,0 @@ -[mypy] -ignore_missing_imports = True - -# untyped vcm packages -[mypy-fv3viz] -ignore_missing_imports = True - -[mypy-report] -ignore_missing_imports = True - -[mypy-loaders] -ignore_missing_imports = True - -# External Libraries -[mypy-mappm] -ignore_missing_imports = True - -[mypy-gcsfs] -ignore_missing_imports = True - -[mypy-xgcm] -ignore_missing_imports = True - -[mypy-google.*] -ignore_missing_imports = True - -[mypy-numpy] -ignore_missing_imports = True - -[mypy-fsspec] -ignore_missing_imports = True - -[mypy-dask.*] -ignore_missing_imports = True - -[mypy-scipy.*] -ignore_missing_imports = True - -[mypy-skimage.*] -ignore_missing_imports = True - -[mypy-apache_beam.*] -ignore_missing_imports = True - -[mypy-intake] -ignore_missing_imports = True - -[mypy-joblib] -ignore_missing_imports = True - -[mypy-sklearn.*] -ignore_missing_imports = True - -[mypy-toolz] -ignore_missing_imports = True diff --git a/ndsl/__init__.py b/ndsl/__init__.py index f6bb6117..c7730282 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,4 +1,5 @@ from . import dsl # isort:skip +from .logging import ndsl_log # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm @@ -6,7 +7,7 @@ from .comm.partitioner import CubedSpherePartitioner, TilePartitioner from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath -from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG +from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration from .dsl.dace.orchestration import orchestrate, orchestrate_function from .dsl.dace.utils import ( ArrayReport, @@ -15,21 +16,77 @@ StorageReport, ) from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from .initialization.allocator import QuantityFactory -from .initialization.sizer import GridSizer, SubtileGridSizer -from .logging import ndsl_log +from .initialization import GridSizer, QuantityFactory, SubtileGridSizer from .monitor.netcdf_monitor import NetCDFMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity +from .quantity import Local, Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator from .utils import MetaEnumStr + + +__all__ = [ + "dsl", + "CubedSphereCommunicator", + "TileCommunicator", + "LocalComm", + "MPIComm", + "NullComm", + "CubedSpherePartitioner", + "TilePartitioner", + "ConstantVersions", + "FV3CodePath", + "DaceConfig", + "DaCeOrchestration", + "orchestrate", + "orchestrate_function", + "ArrayReport", + "DaCeProgress", + "MaxBandwidthBenchmarkProgram", + "StorageReport", + "WrappedHaloUpdater", + "FrozenStencil", + "GridIndexing", + "StencilFactory", + "TimingCollector", + "CompilationConfig", + "RunMode", + "StencilConfig", + "OutOfBoundsError", + "HaloExchangeSpec", + "HaloUpdater", + "HaloUpdateRequest", + "VectorInterfaceHaloUpdater", + "QuantityFactory", + "GridSizer", + "SubtileGridSizer", + "ndsl_log", + "NetCDFMonitor", + "Namelist", + "NullPerformanceCollector", + "PerformanceCollector", + "NullProfiler", + "Profiler", + "Experiment", + "Report", + "TimeReport", + "Quantity", + "FieldBundle", + "FieldBundleType", + "DummyComm", + "Allocator", + "MetaEnumStr", + "State", + "NDSLRuntime", + "Local", +] diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index 41b34820..d17d6349 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -1,7 +1,3 @@ -from typing import Tuple - -import numpy as np - from ndsl import ( CompilationConfig, DaceConfig, @@ -16,18 +12,17 @@ TileCommunicator, TilePartitioner, ) -from ndsl.optional_imports import cupy as cp def _get_factories( nx: int, ny: int, nz: int, - nhalo, + nhalo: int, backend: str, orchestration: DaCeOrchestration, topology: str, -) -> Tuple[StencilFactory, QuantityFactory]: +) -> tuple[StencilFactory, QuantityFactory]: """Build a Stencil & Quantity factory for a combination of options. Dev Note: We don't expose this function because we want the boilerplate to remain @@ -65,7 +60,6 @@ def _get_factories( ny_tile=ny, nz=nz, n_halo=nhalo, - extra_dim_lengths={}, layout=partitioner.layout, tile_partitioner=partitioner, ) @@ -75,31 +69,34 @@ def _get_factories( grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) - quantity_factory = QuantityFactory( - sizer, cp if stencil_config.is_gpu_backend else np - ) + quantity_factory = QuantityFactory.from_backend(sizer, backend) return stencil_factory, quantity_factory def get_factories_single_tile_orchestrated( - nx, ny, nz, nhalo, on_cpu: bool = True -) -> Tuple[StencilFactory, QuantityFactory]: - """Build a Stencil & Quantity factory for orchestrated CPU, on a single tile topology.""" + nx: int, ny: int, nz: int, nhalo: int, backend: str = "dace:cpu" +) -> tuple[StencilFactory, QuantityFactory]: + """Build the pair of (StencilFactory, QuantityFactory) for orchestrated code on a single tile topology.""" + + if backend is not None and not backend.startswith("dace"): + raise ValueError("Only `dace:*` backends can be orchestrated.") + return _get_factories( nx=nx, ny=ny, nz=nz, nhalo=nhalo, - backend="dace:cpu" if on_cpu else "dace:gpu", + backend=backend, orchestration=DaCeOrchestration.BuildAndRun, topology="tile", ) def get_factories_single_tile( - nx, ny, nz, nhalo, backend: str = "numpy" -) -> Tuple[StencilFactory, QuantityFactory]: + nx: int, ny: int, nz: int, nhalo: int, backend: str = "numpy" +) -> tuple[StencilFactory, QuantityFactory]: + """Build the pair of (StencilFactory, QuantityFactory) for stencils on a single tile topology.""" return _get_factories( nx=nx, ny=ny, diff --git a/ndsl/buffer.py b/ndsl/buffer.py index 05cd6434..7bb30bb3 100644 --- a/ndsl/buffer.py +++ b/ndsl/buffer.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import contextlib -from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple +from collections.abc import Callable, Generator, Iterable import numpy as np from numpy.lib.index_tricks import IndexExpression @@ -14,8 +16,8 @@ ) -BufferKey = Tuple[Callable, Iterable[int], type] -BUFFER_CACHE: Dict[BufferKey, List["Buffer"]] = {} +BufferKey = tuple[Callable, Iterable[int], type] +BUFFER_CACHE: dict[BufferKey, list["Buffer"]] = {} class Buffer: @@ -40,7 +42,7 @@ def __init__(self, key: BufferKey, array: np.ndarray): @classmethod def pop_from_cache( cls, allocator: Allocator, shape: Iterable[int], dtype: type - ) -> "Buffer": + ) -> Buffer: """Retrieve or insert then retrieve of buffer from cache. Args: @@ -61,7 +63,7 @@ def pop_from_cache( return cls(key, array) @staticmethod - def push_to_cache(buffer: "Buffer"): + def push_to_cache(buffer: Buffer) -> None: """Push the buffer back into the cache. Args: @@ -69,7 +71,7 @@ def push_to_cache(buffer: "Buffer"): """ BUFFER_CACHE[buffer._key].append(buffer) - def finalize_memory_transfer(self): + def finalize_memory_transfer(self) -> None: """Finalize any memory transfer""" device_synchronize() @@ -78,7 +80,7 @@ def assign_to( destination_array: np.ndarray, buffer_slice: IndexExpression = np.index_exp[:], buffer_reshape: IndexExpression = None, - ): + ) -> None: """Assign internal array to destination_array. Args: @@ -94,7 +96,7 @@ def assign_to( def assign_from( self, source_array: np.ndarray, buffer_slice: IndexExpression = np.index_exp[:] - ): + ) -> None: """Assign source_array to internal array. Args: @@ -129,7 +131,7 @@ def array_buffer( def send_buffer( allocator: Callable, array: np.ndarray, - timer: Optional[Timer] = None, + timer: Timer | None = None, ) -> np.ndarray: """A context manager ensuring that `array` is contiguous in a context where it is being sent as data, copying into a recycled buffer array if necessary. @@ -163,7 +165,7 @@ def send_buffer( def recv_buffer( allocator: Callable, array: np.ndarray, - timer: Optional[Timer] = None, + timer: Timer | None = None, ) -> np.ndarray: """A context manager ensuring that array is contiguous in a context where it is being used to receive data, using a recycled buffer array and then copying the diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index 8fee4dc1..22ff634e 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -7,3 +7,14 @@ ThresholdCalibrationCheckpointer, ) from .validation import ValidationCheckpointer + + +__all__ = [ + "NullCheckpointer", + "SnapshotCheckpointer", + "InsufficientTrialsError", + "SavepointThresholds", + "Threshold", + "ThresholdCalibrationCheckpointer", + "ValidationCheckpointer", +] diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py index 8218bbfe..f51fb0a2 100644 --- a/ndsl/checkpointer/base.py +++ b/ndsl/checkpointer/base.py @@ -1,7 +1,16 @@ import abc +from typing import TypeAlias + +import numpy as np + +from ndsl import Quantity + + +SavepointName: TypeAlias = str +VariableName: TypeAlias = str +ArrayLike: TypeAlias = Quantity | np.ndarray class Checkpointer(abc.ABC): @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... + def __call__(self, savepoint_name: SavepointName, **kwargs: ArrayLike) -> None: ... diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index fbc78755..f3e1eb5f 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,6 +1,6 @@ -from ndsl.checkpointer.base import Checkpointer +from ndsl.checkpointer.base import ArrayLike, Checkpointer, SavepointName class NullCheckpointer(Checkpointer): - def __call__(self, savepoint_name, **kwargs): + def __call__(self, savepoint_name: SavepointName, **kwargs: ArrayLike) -> None: pass diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 1447c5fb..403bd9e8 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -3,11 +3,13 @@ import numpy as np import xarray as xr -from ndsl.checkpointer.base import Checkpointer +from ndsl.checkpointer.base import ArrayLike, Checkpointer, SavepointName, VariableName from ndsl.optional_imports import cupy as cp -def make_dims(savepoint_dim, label, data_list): +def make_dims( + savepoint_dim: str, label: str, data_list: list[np.ndarray] +) -> tuple[list[str], np.ndarray]: """ Helper which defines dimension names for an xarray variable. @@ -22,16 +24,25 @@ def make_dims(savepoint_dim, label, data_list): class _Snapshots: - def __init__(self): - self._savepoints = collections.defaultdict(list) - self._arrays = collections.defaultdict(list) + def __init__(self) -> None: + self._savepoints: dict[VariableName, list[SavepointName]] = ( + collections.defaultdict(list) + ) + self._arrays: dict[VariableName, list[np.ndarray]] = collections.defaultdict( + list + ) - def store(self, savepoint_name: str, variable_name: str, python_data): + def store( + self, + savepoint_name: SavepointName, + variable_name: VariableName, + python_data: np.ndarray, + ) -> None: self._savepoints[variable_name].append(savepoint_name) self._arrays[variable_name].append(python_data) @property - def dataset(self) -> "xr.Dataset": + def dataset(self) -> xr.Dataset: data_vars = {} for variable_name, savepoint_list in self._savepoints.items(): savepoint_dim = f"sp_{variable_name}" @@ -48,18 +59,18 @@ class SnapshotCheckpointer(Checkpointer): of variables between checkpointer calls. """ - def __init__(self, rank: int): + def __init__(self, rank: int) -> None: self._rank = rank self._snapshots = _Snapshots() - def __call__(self, savepoint_name, **kwargs): + def __call__(self, savepoint_name: SavepointName, **kwargs: ArrayLike) -> None: for name, value in kwargs.items(): array_data = np.copy(value.data) self._snapshots.store(savepoint_name, name, array_data) @property - def dataset(self) -> "xr.Dataset": + def dataset(self) -> xr.Dataset: return self._snapshots.dataset - def cleanup(self): + def cleanup(self) -> None: self.dataset.to_netcdf(f"comparison_rank{self._rank}.nc") diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index fbf0e956..b177ab32 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -1,19 +1,16 @@ +from __future__ import annotations + import collections import contextlib import dataclasses -from typing import Dict, List, Mapping, Union +from collections.abc import Mapping import numpy as np -from ndsl.checkpointer.base import Checkpointer +from ndsl.checkpointer.base import ArrayLike, Checkpointer, SavepointName, VariableName from ndsl.quantity import Quantity -SavepointName = str -VariableName = str -ArrayLike = Union[Quantity, np.ndarray] - - class InsufficientTrialsError(Exception): pass @@ -23,7 +20,7 @@ class Threshold: relative: float absolute: float - def merge(self, other: "Threshold") -> "Threshold": + def merge(self, other: Threshold) -> Threshold: """ Provide a threshold which is always satisfied if both input thresholds are satisfied. @@ -38,7 +35,7 @@ def merge(self, other: "Threshold") -> "Threshold": @dataclasses.dataclass class SavepointThresholds: - savepoints: Dict[SavepointName, List[Dict[VariableName, Threshold]]] + savepoints: dict[SavepointName, list[dict[VariableName, Threshold]]] def cast_to_ndarray(array: ArrayLike) -> np.ndarray: @@ -68,19 +65,19 @@ def __init__(self, factor: float = 1.0): # we keep dictionaries (over savepoint name) of lists (over call count) # of dictionaries (over variable name) of numpy arrays self._minimums: Mapping[ - SavepointName, List[Mapping[VariableName, np.ndarray]] + SavepointName, list[Mapping[VariableName, np.ndarray]] ] = collections.defaultdict(list) self._maximums: Mapping[ - SavepointName, List[Mapping[VariableName, np.ndarray]] + SavepointName, list[Mapping[VariableName, np.ndarray]] ] = collections.defaultdict(list) self._factor = factor self._abs_sums: Mapping[ - SavepointName, List[Mapping[VariableName, np.ndarray]] + SavepointName, list[Mapping[VariableName, np.ndarray]] ] = collections.defaultdict(list) self._n_trials = 0 self._n_calls: Mapping[SavepointName, int] = collections.defaultdict(int) - def __call__(self, savepoint_name, **kwargs): + def __call__(self, savepoint_name: SavepointName, **kwargs: ArrayLike) -> None: """ Record values for a savepoint. @@ -98,19 +95,19 @@ def __call__(self, savepoint_name, **kwargs): ) self._abs_sums[savepoint_name].append(collections.defaultdict(lambda: 0.0)) for varname, array in kwargs.items(): - array: np.ndarray = cast_to_ndarray(array) - self._minimums[savepoint_name][i_call][varname] = np.minimum( + array = cast_to_ndarray(array) + self._minimums[savepoint_name][i_call][varname] = np.minimum( # type: ignore[index] self._minimums[savepoint_name][i_call][varname], array ) - self._maximums[savepoint_name][i_call][varname] = np.maximum( + self._maximums[savepoint_name][i_call][varname] = np.maximum( # type: ignore[index] self._maximums[savepoint_name][i_call][varname], array ) - self._abs_sums[savepoint_name][i_call][varname] += np.abs(array) + self._abs_sums[savepoint_name][i_call][varname] += np.abs(array) # type: ignore[index] - self._n_calls[savepoint_name] += 1 + self._n_calls[savepoint_name] += 1 # type: ignore[index] @contextlib.contextmanager - def trial(self): + def trial(self): # type: ignore[no-untyped-def] """ Context manager for a trial. @@ -126,14 +123,12 @@ def trial(self): self._n_trials += 1 @property - def thresholds( - self, - ) -> SavepointThresholds: + def thresholds(self) -> SavepointThresholds: if self._n_trials < 2: raise InsufficientTrialsError( "at least 2 trials required to generate thresholds" ) - savepoints: Dict[SavepointName, List[Dict[VariableName, Threshold]]] = {} + savepoints: dict[SavepointName, list[dict[VariableName, Threshold]]] = {} for savepoint_name in self._minimums: savepoints[savepoint_name] = [] for i_call in range(self._n_calls[savepoint_name]): diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 00da3d13..ae9e5405 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -1,22 +1,17 @@ import collections import contextlib import os.path -from typing import MutableMapping, Tuple +from collections.abc import MutableMapping import numpy as np import xarray as xr -from ndsl.checkpointer.base import Checkpointer -from ndsl.checkpointer.thresholds import ( - ArrayLike, - SavepointName, - SavepointThresholds, - cast_to_ndarray, -) +from ndsl.checkpointer.base import ArrayLike, Checkpointer, SavepointName +from ndsl.checkpointer.thresholds import SavepointThresholds, cast_to_ndarray def _clip_pace_array_to_target( - array: np.ndarray, target_shape: Tuple[int, ...] + array: np.ndarray, target_shape: tuple[int, ...] ) -> np.ndarray: """ Clip an array from pace to align it to a target shape from target serialized data. @@ -36,7 +31,9 @@ def _clip_pace_array_to_target( return _remove_symmetric_halos(array, target_shape) -def _remove_buffer_if_needed(array: np.ndarray, target_shape: Tuple[int, ...]): +def _remove_buffer_if_needed( + array: np.ndarray, target_shape: tuple[int, ...] +) -> np.ndarray: selection = [] # both arrays are assumed to have the same staggering and an even number of # halo points for each dimension, so any odd difference in points must be @@ -51,7 +48,9 @@ def _remove_buffer_if_needed(array: np.ndarray, target_shape: Tuple[int, ...]): return array[tuple(selection)] -def _remove_symmetric_halos(array: np.ndarray, target_shape: Tuple[int, ...]): +def _remove_symmetric_halos( + array: np.ndarray, target_shape: tuple[int, ...] +) -> np.ndarray: selection = [] for array_len, target_len in zip(array.shape, target_shape): n_halo_clip = (array_len - target_len) // 2 @@ -86,7 +85,7 @@ def __init__( self._n_calls: MutableMapping[SavepointName, int] = collections.defaultdict(int) @contextlib.contextmanager - def trial(self): + def trial(self): # type: ignore[no-untyped-def] """ Context manager for a trial. diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index c4a58658..580f881a 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -6,3 +6,14 @@ CachingRequestWriter, ) from .comm_abc import Comm, Request + + +__all__ = [ + "CachingCommData", + "CachingCommReader", + "CachingCommWriter", + "CachingRequestReader", + "CachingRequestWriter", + "Comm", + "Request", +] diff --git a/ndsl/comm/_boundary_utils.py b/ndsl/comm/_boundary_utils.py index 01e241fc..fbc6fa8b 100644 --- a/ndsl/comm/_boundary_utils.py +++ b/ndsl/comm/_boundary_utils.py @@ -1,8 +1,6 @@ import functools -from typing import Union import ndsl.constants as constants -from ndsl.exceptions import OutOfBoundsError def shift_boundary_slice_tuple(dims, origin, extent, boundary_type, slice_tuple): @@ -83,19 +81,19 @@ def get_boundary_slice(dims, origin, extent, shape, boundary_type, n_halo, inter else: start, stop = edge_index, edge_index + n_points if start < 0: - raise OutOfBoundsError( - f"boundary slice extends past start of domain on dimension {dim}" + raise IndexError( + f"Boundary slice extends past start of domain on dimension {dim}." ) elif stop > shape_1d: - raise OutOfBoundsError( - f"boundary slice extends past end of domain on dimension {dim}" + raise IndexError( + f"Boundary slice extends past end of domain on dimension {dim}." ) else: boundary_slice.append(slice(start, stop)) return tuple(boundary_slice) -def boundary_at_start_of_dim(boundary: int, dim: str) -> Union[bool, None]: +def boundary_at_start_of_dim(boundary: int, dim: str) -> bool | None: """ Return True if boundary is at the start of the dimension, False if at the end, None if the boundary does not align with the dimension. diff --git a/ndsl/comm/boundary.py b/ndsl/comm/boundary.py index f6016f6e..a8f7d276 100644 --- a/ndsl/comm/boundary.py +++ b/ndsl/comm/boundary.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Tuple +from typing import Any from ndsl.comm._boundary_utils import get_boundary_slice from ndsl.quantity import Quantity, QuantityHaloSpec @@ -18,7 +18,7 @@ class Boundary: orientation of the axes in from_rank to the orientation of the axes in to_rank. """ - def send_view(self, quantity: Quantity, n_points: int): + def send_view(self, quantity: Quantity, n_points: int) -> Any: """Return a sliced view of points which should be sent at this boundary. Args: @@ -27,7 +27,7 @@ def send_view(self, quantity: Quantity, n_points: int): """ return self._view(quantity, n_points, interior=True) - def recv_view(self, quantity: Quantity, n_points: int): + def recv_view(self, quantity: Quantity, n_points: int) -> Any: """Return a sliced view of points which should be received at this boundary. Args: @@ -36,7 +36,7 @@ def recv_view(self, quantity: Quantity, n_points: int): """ return self._view(quantity, n_points, interior=False) - def send_slice(self, specification: QuantityHaloSpec) -> Tuple[slice]: + def send_slice(self, specification: QuantityHaloSpec) -> tuple[slice]: """Return the index slices which should be sent at this boundary. Args: @@ -48,7 +48,7 @@ def send_slice(self, specification: QuantityHaloSpec) -> Tuple[slice]: """ return self._slice(specification, interior=True) - def recv_slice(self, specification: QuantityHaloSpec) -> Tuple[slice]: + def recv_slice(self, specification: QuantityHaloSpec) -> tuple[slice]: """Return the index slices which should be received at this boundary. Args: @@ -60,7 +60,7 @@ def recv_slice(self, specification: QuantityHaloSpec) -> Tuple[slice]: """ return self._slice(specification, interior=False) - def _slice(self, specification: QuantityHaloSpec, interior: bool) -> Tuple[slice]: + def _slice(self, specification: QuantityHaloSpec, interior: bool) -> tuple[slice]: """Returns a tuple of slices (one per dimensions) indexing the data to be exchange. Args: @@ -69,9 +69,9 @@ def _slice(self, specification: QuantityHaloSpec, interior: bool) -> Tuple[slice Return: A tuple of slices (one per dimensions) """ - raise NotImplementedError() + raise NotImplementedError("Boundary._slice()") - def _view(self, quantity: Quantity, n_points: int, interior: bool): + def _view(self, quantity: Quantity, n_points: int, interior: bool) -> Any: """Return a sliced view of points in the given quantity at this boundary. Args: @@ -80,7 +80,7 @@ def _view(self, quantity: Quantity, n_points: int, interior: bool): interior: if True, give points inside the computational domain (default), otherwise give points in the halo """ - raise NotImplementedError() + raise NotImplementedError("Boundary._view()") @dataclasses.dataclass @@ -89,7 +89,7 @@ class SimpleBoundary(Boundary): boundary_type: int - def _view(self, quantity: Quantity, n_points: int, interior: bool): + def _view(self, quantity: Quantity, n_points: int, interior: bool) -> Any: boundary_slice = get_boundary_slice( quantity.dims, quantity.origin, @@ -101,7 +101,7 @@ def _view(self, quantity: Quantity, n_points: int, interior: bool): ) return quantity.data[tuple(boundary_slice)] - def _slice(self, specification: QuantityHaloSpec, interior: bool) -> Tuple[slice]: + def _slice(self, specification: QuantityHaloSpec, interior: bool) -> tuple[slice]: return get_boundary_slice( specification.dims, specification.origin, diff --git a/ndsl/comm/caching_comm.py b/ndsl/comm/caching_comm.py index 42f92ea2..4de86eb9 100644 --- a/ndsl/comm/caching_comm.py +++ b/ndsl/comm/caching_comm.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import copy import dataclasses import pickle -from typing import Any, BinaryIO, List, Optional, TypeVar +from typing import Any, BinaryIO, TypeVar import numpy as np @@ -12,27 +14,29 @@ class CachingRequestWriter(Request): - def __init__(self, req: Request, buffer: np.ndarray, buffer_list: List[np.ndarray]): + def __init__( + self, req: Request, buffer: np.ndarray, buffer_list: list[np.ndarray] + ) -> None: self._req = req self._buffer = buffer self._buffer_list = buffer_list - def wait(self): + def wait(self) -> None: self._req.wait() self._buffer_list.append(copy.deepcopy(self._buffer)) class CachingRequestReader(Request): - def __init__(self, recvbuf, data): + def __init__(self, recvbuf: Any, data: Any) -> None: self._recvbuf = recvbuf self._data = data - def wait(self): + def wait(self) -> None: self._recvbuf[:] = self._data class NullRequest(Request): - def wait(self): + def wait(self) -> None: pass @@ -47,51 +51,51 @@ class CachingCommData: rank: int size: int - bcast_objects: List[Any] = dataclasses.field(default_factory=list) - received_buffers: List[np.ndarray] = dataclasses.field(default_factory=list) - generic_obj_buffers: List[Any] = dataclasses.field(default_factory=list) - split_data: List["CachingCommData"] = dataclasses.field(default_factory=list) + bcast_objects: list[Any] = dataclasses.field(default_factory=list) + received_buffers: list[np.ndarray] = dataclasses.field(default_factory=list) + generic_obj_buffers: list[Any] = dataclasses.field(default_factory=list) + split_data: list[CachingCommData] = dataclasses.field(default_factory=list) - def __post_init__(self): + def __post_init__(self) -> None: self._i_bcast = 0 self._i_buffers = 0 self._i_split = 0 self._i_generic_obj = 0 - def get_bcast(self): + def get_bcast(self) -> Any: return_value = self.bcast_objects[self._i_bcast] self._i_bcast += 1 return return_value - def get_buffer(self): + def get_buffer(self) -> np.ndarray: return_value = self.received_buffers[self._i_buffers] self._i_buffers += 1 return return_value - def get_generic_obj(self): + def get_generic_obj(self) -> Any: return_value = self.generic_obj_buffers[self._i_generic_obj] self._i_generic_obj += 1 return return_value - def get_split(self): + def get_split(self) -> CachingCommData: return_value = self.split_data[self._i_split] self._i_split += 1 return return_value - def dump(self, file: BinaryIO): + def dump(self, file: BinaryIO) -> None: pickle.dump(self, file) @classmethod - def load(self, file: BinaryIO) -> "CachingCommData": + def load(self, file: BinaryIO) -> CachingCommData: return pickle.load(file) -class CachingCommReader(Comm): +class CachingCommReader(Comm[T]): """ mpi4py Comm-like object which replays stored communications. """ - def __init__(self, data: CachingCommData): + def __init__(self, data: CachingCommData) -> None: """ Initialize a CachingCommReader. @@ -109,63 +113,68 @@ def Get_rank(self) -> int: def Get_size(self) -> int: return self._data.size - def bcast(self, value: Optional[T], root=0) -> T: + def bcast(self, value: T | None, root: int = 0) -> T | None: return self._data.get_bcast() - def barrier(self): + def barrier(self) -> None: pass - def Barrier(self): + def Barrier(self) -> None: pass - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] recvbuf[:] = self._data.get_buffer() - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] if recvbuf is not None: recvbuf[:] = self._data.get_buffer() - def allgather(self, sendobj): + def allgather(self, sendobj: T) -> list[T]: raise NotImplementedError("allgather not yet implemented for CachingCommReader") - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] pass - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return NullRequest() - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] recvbuf[:] = self._data.get_buffer() - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return CachingRequestReader(recvbuf, self._data.get_buffer()) - def sendrecv(self, sendbuf, dest, **kwargs): - raise NotImplementedError() + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("CachingCommReader.sendrecv") - def Split(self, color, key) -> "CachingCommReader": + def Split(self, color, key) -> CachingCommReader: # type: ignore[no-untyped-def] new_data = self._data.get_split() return CachingCommReader(data=new_data) - def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: return self._data.get_generic_obj() - def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: raise NotImplementedError("CachingCommReader.Allreduce") + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + raise NotImplementedError("CachingCommReader.Allreduce_inplace") + @classmethod - def load(cls, file: BinaryIO) -> "CachingCommReader": + def load(cls, file: BinaryIO) -> CachingCommReader: data = CachingCommData.load(file) return cls(data) -class CachingCommWriter(Comm): +class CachingCommWriter(Comm[T]): """ Wrapper around a mpi4py Comm object which can be serialized and then loaded as a CachingCommReader. """ - def __init__(self, comm: Comm): + def __init__(self, comm: Comm[T]) -> None: """ Args: comm: underlying mpi4py comm-like object @@ -182,60 +191,65 @@ def Get_rank(self) -> int: def Get_size(self) -> int: return self._comm.Get_size() - def bcast(self, value: Optional[T], root=0) -> T: + def bcast(self, value: T | None, root: int = 0) -> T | None: result = self._comm.bcast(value=value, root=root) self._data.bcast_objects.append(copy.deepcopy(result)) return result - def barrier(self): + def barrier(self) -> None: return self._comm.barrier() - def Barrier(self): + def Barrier(self) -> None: pass - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): + def Scatter(self, sendbuf, recvbuf, root=0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Scatter(sendbuf=sendbuf, recvbuf=recvbuf, root=root, **kwargs) self._data.received_buffers.append(copy.deepcopy(recvbuf)) - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): + def Gather(self, sendbuf, recvbuf, root=0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Gather(sendbuf=sendbuf, recvbuf=recvbuf, root=root, **kwargs) self._data.received_buffers.append(copy.deepcopy(recvbuf)) - def allgather(self, sendobj): + def allgather(self, sendobj: T) -> list[T]: raise NotImplementedError("allgather not yet implemented for CachingCommReader") - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Send(sendbuf=sendbuf, dest=dest, tag=tag, **kwargs) - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return self._comm.Isend(sendbuf, dest, tag=tag, **kwargs) - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Recv(recvbuf=recvbuf, source=source, tag=tag, **kwargs) self._data.received_buffers.append(copy.deepcopy(recvbuf)) - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] req = self._comm.Irecv(recvbuf, source, tag=tag, **kwargs) return CachingRequestWriter( req=req, buffer=recvbuf, buffer_list=self._data.received_buffers ) - def sendrecv(self, sendbuf, dest, **kwargs): - raise NotImplementedError() + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("CachingCommWriter.sendrecv") - def Split(self, color, key) -> "CachingCommWriter": + def Split(self, color, key) -> CachingCommWriter: # type: ignore[no-untyped-def] new_comm = self._comm.Split(color=color, key=key) new_wrapper = CachingCommWriter(new_comm) self._data.split_data.append(new_wrapper._data) return new_wrapper - def dump(self, file: BinaryIO): + def dump(self, file: BinaryIO) -> None: self._data.dump(file) - def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: result = self._comm.allreduce(sendobj, op) self._data.generic_obj_buffers.append(copy.deepcopy(result)) return result - def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: raise NotImplementedError("CachingCommWriter.Allreduce") + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + raise NotImplementedError("CachingCommWriter.Allreduce_inplace") diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index 45596f1e..f1dde524 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import abc import enum -from typing import List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -27,74 +29,59 @@ class ReductionOperator(enum.Enum): class Request(abc.ABC): @abc.abstractmethod - def wait(self): - ... + def wait(self) -> None: ... -class Comm(abc.ABC): +class Comm(abc.ABC, Generic[T]): @abc.abstractmethod - def Get_rank(self) -> int: - ... + def Get_rank(self) -> int: ... @abc.abstractmethod - def Get_size(self) -> int: - ... + def Get_size(self) -> int: ... @abc.abstractmethod - def bcast(self, value: Optional[T], root=0) -> T: - ... + def bcast(self, value: T | None, root: int = 0) -> T | None: ... @abc.abstractmethod - def barrier(self): - ... + def barrier(self) -> None: ... @abc.abstractmethod - def Barrier(self): - ... + def Barrier(self) -> None: ... @abc.abstractmethod - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): - ... + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): ... # type: ignore[no-untyped-def] @abc.abstractmethod - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): - ... + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): ... # type: ignore[no-untyped-def] @abc.abstractmethod - def allgather(self, sendobj: T) -> List[T]: - ... + def allgather(self, sendobj: T) -> list[T]: ... @abc.abstractmethod - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): - ... + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): ... # type: ignore[no-untyped-def] @abc.abstractmethod - def sendrecv(self, sendbuf, dest, **kwargs): - ... + def sendrecv(self, sendbuf, dest, **kwargs: dict): ... # type: ignore[no-untyped-def] @abc.abstractmethod - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: - ... + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: ... # type: ignore[no-untyped-def] @abc.abstractmethod - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): - ... + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): ... # type: ignore[no-untyped-def] @abc.abstractmethod - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: - ... + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: ... # type: ignore[no-untyped-def] @abc.abstractmethod - def Split(self, color, key) -> "Comm": - ... + def Split(self, color, key) -> Comm: ... # type: ignore[no-untyped-def] @abc.abstractmethod - def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: - ... + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: ... @abc.abstractmethod - def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: - ... + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ... - def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: - ... + @abc.abstractmethod + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: ... diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index c523affb..6eee4514 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Self, cast import numpy as np @@ -16,7 +19,7 @@ from ndsl.types import NumpyModule -def to_numpy(array, dtype=None) -> np.ndarray: +def to_numpy(array, dtype=None) -> np.ndarray: # type: ignore[no-untyped-def] """ Input array can be a numpy array or a cupy array. Returns numpy array. """ @@ -43,19 +46,20 @@ class Communicator(abc.ABC): def __init__( self, comm: CommABC, - partitioner, + partitioner: Partitioner, force_cpu: bool = False, - timer: Optional[Timer] = None, + timer: Timer | None = None, ): self.comm = comm self.partitioner: Partitioner = partitioner self._force_cpu = force_cpu - self._boundaries: Optional[Mapping[int, Boundary]] = None + self._boundaries: Mapping[int, Boundary] | None = None self._last_halo_tag = 0 self.timer: Timer = timer if timer is not None else NullTimer() - @abc.abstractproperty - def tile(self) -> "TileCommunicator": + @property + @abc.abstractmethod + def tile(self) -> TileCommunicator: pass @classmethod @@ -63,10 +67,10 @@ def tile(self) -> "TileCommunicator": def from_layout( cls, comm: CommABC, - layout: Tuple[int, int], + layout: tuple[int, int], force_cpu: bool = False, - timer: Optional[Timer] = None, - ): + timer: Timer | None = None, + ) -> Self: pass @property @@ -89,13 +93,13 @@ def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: return module @staticmethod - def _device_synchronize(): + def _device_synchronize() -> None: """Wait for all work that could be in-flight to finish.""" # this is a method so we can profile it separately from other device syncs device_synchronize() def _create_all_reduce_quantity( - self, input_metadata: QuantityMetadata, input_data + self, input_metadata: QuantityMetadata, input_data: Any ) -> Quantity: """Create a Quantity for all_reduce data and metadata""" all_reduce_quantity = Quantity( @@ -113,49 +117,52 @@ def all_reduce( self, input_quantity: Quantity, op: ReductionOperator, - output_quantity: Quantity = None, - ): + output_quantity: Quantity | None = None, + ) -> Quantity: reduced_quantity_data = self.comm.allreduce(input_quantity.data, op) if output_quantity is None: - all_reduce_quantity = self._create_all_reduce_quantity( + return self._create_all_reduce_quantity( input_quantity.metadata, reduced_quantity_data ) - return all_reduce_quantity - else: - if output_quantity.data.shape != input_quantity.data.shape: - raise TypeError("Shapes not matching") - input_quantity.metadata.duplicate_metadata(output_quantity.metadata) + if output_quantity.data.shape != input_quantity.data.shape: + raise TypeError("Shapes not matching") + + input_quantity.metadata.duplicate_metadata(output_quantity.metadata) - output_quantity.data = reduced_quantity_data + output_quantity.data = reduced_quantity_data + return output_quantity def all_reduce_per_element( self, input_quantity: Quantity, output_quantity: Quantity, op: ReductionOperator, - ): + ) -> None: self.comm.Allreduce(input_quantity.data, output_quantity.data, op) def all_reduce_per_element_in_place( self, quantity: Quantity, op: ReductionOperator - ): + ) -> None: + # Note that device_synchronization is Cupy/Cuda specific + # at the moment. + device_synchronize() self.comm.Allreduce_inplace(quantity.data, op) - def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): # type: ignore[no-untyped-def] with send_buffer(numpy_module.zeros, sendbuf) as send: with recv_buffer(numpy_module.zeros, recvbuf) as recv: self.comm.Scatter(send, recv, **kwargs) - def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): + def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): # type: ignore[no-untyped-def] with send_buffer(numpy_module.zeros, sendbuf) as send: with recv_buffer(numpy_module.zeros, recvbuf) as recv: self.comm.Gather(send, recv, **kwargs) def scatter( self, - send_quantity: Optional[Quantity] = None, - recv_quantity: Optional[Quantity] = None, + send_quantity: Quantity | None = None, + recv_quantity: Quantity | None = None, ) -> Quantity: """Transfer subtile regions of a full-tile quantity from the tile root rank to all subtiles. @@ -170,9 +177,11 @@ def scatter( raise TypeError("send_quantity is a required argument on the root rank") if self.rank == constants.ROOT_RANK: send_quantity = cast(Quantity, send_quantity) - metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) + metadata: QuantityMetadata = self.comm.bcast( + send_quantity.metadata, root=constants.ROOT_RANK + ) # type: ignore[assignment] else: - metadata = self.comm.bcast(None, root=constants.ROOT_RANK) + metadata = self.comm.bcast(None, root=constants.ROOT_RANK) # type: ignore[assignment] shape = self.partitioner.subtile_extent(metadata, self.rank) if recv_quantity is None: recv_quantity = self._get_scatter_recv_quantity(shape, metadata) @@ -214,7 +223,7 @@ def _get_gather_recv_quantity( ) -> Quantity: """Initialize a Quantity for use when receiving global data during gather""" recv_quantity = Quantity( - send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), + send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), # type: ignore dims=send_metadata.dims, units=send_metadata.units, origin=tuple([0 for dim in send_metadata.dims]), @@ -229,7 +238,7 @@ def _get_scatter_recv_quantity( ) -> Quantity: """Initialize a Quantity for use when receiving subtile data during scatter""" recv_quantity = Quantity( - send_metadata.np.zeros(shape, dtype=send_metadata.dtype), + send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore dims=send_metadata.dims, units=send_metadata.units, gt4py_backend=send_metadata.gt4py_backend, @@ -238,8 +247,8 @@ def _get_scatter_recv_quantity( return recv_quantity def gather( - self, send_quantity: Quantity, recv_quantity: Quantity = None - ) -> Optional[Quantity]: + self, send_quantity: Quantity, recv_quantity: Quantity | None = None + ) -> Quantity | None: """Transfer subtile regions of a full-tile quantity from each rank to the tile root rank. @@ -250,7 +259,7 @@ def gather( Returns: recv_quantity: quantity if on root rank, otherwise None """ - result: Optional[Quantity] + result: Quantity | None if self.rank == constants.ROOT_RANK: with array_buffer( send_quantity.np.zeros, @@ -291,7 +300,7 @@ def gather( result = None return result - def gather_state(self, send_state=None, recv_state=None, transfer_type=None): + def gather_state(self, send_state=None, recv_state=None, transfer_type=None): # type: ignore[no-untyped-def] """Transfer a state dictionary from subtile ranks to the tile root rank. 'time' is assumed to be the same on all ranks, and its value will be set @@ -329,7 +338,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): del gather_quantity return recv_state - def scatter_state(self, send_state=None, recv_state=None): + def scatter_state(self, send_state=None, recv_state=None): # type: ignore[no-untyped-def] """Transfer a state dictionary from the tile root rank to all subtiles. Args: @@ -341,13 +350,13 @@ def scatter_state(self, send_state=None, recv_state=None): rank_state: the state corresponding to this rank's subdomain """ - def scatter_root(): + def scatter_root() -> None: if send_state is None: raise TypeError("send_state is a required argument on the root rank") name_list = list(send_state.keys()) while "time" in name_list: name_list.remove("time") - name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) + name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) # type: ignore[assignment] array_list = [send_state[name] for name in name_list] for name, array in zip(name_list, array_list): if name in recv_state: @@ -358,9 +367,9 @@ def scatter_root(): send_state.get("time", None), root=constants.ROOT_RANK ) - def scatter_client(): + def scatter_client() -> None: name_list = self.comm.bcast(None, root=constants.ROOT_RANK) - for name in name_list: + for name in name_list: # type: ignore if name in recv_state: self.scatter(recv_quantity=recv_state[name]) else: @@ -377,7 +386,7 @@ def scatter_client(): recv_state.pop("time") return recv_state - def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): + def halo_update(self, quantity: Quantity | list[Quantity], n_points: int) -> None: """Perform a halo update on a quantity or quantities Args: @@ -393,7 +402,7 @@ def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): halo_updater.wait() def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int + self, quantity: Quantity | list[Quantity], n_points: int ) -> HaloUpdater: """Start an asynchronous halo update on a quantity. @@ -431,10 +440,10 @@ def start_halo_update( def vector_halo_update( self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], + x_quantity: Quantity | list[Quantity], + y_quantity: Quantity | list[Quantity], n_points: int, - ): + ) -> None: """Perform a halo update of a horizontal vector quantity or quantities. Assumes the x and y dimension indices are the same between the two quantities. @@ -460,8 +469,8 @@ def vector_halo_update( def start_vector_halo_update( self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], + x_quantity: Quantity | list[Quantity], + y_quantity: Quantity | list[Quantity], n_points: int, ) -> HaloUpdater: """Start an asynchronous halo update of a horizontal vector quantity. @@ -518,7 +527,9 @@ def start_vector_halo_update( halo_updater.start(x_quantities, y_quantities) return halo_updater - def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): + def synchronize_vector_interfaces( + self, x_quantity: Quantity, y_quantity: Quantity + ) -> None: """ Synchronize shared points at the edges of a vector interface variable. @@ -567,7 +578,9 @@ def start_synchronize_vector_interfaces( req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) return req - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + def get_scalar_halo_updater( + self, specifications: list[QuantityHaloSpec] + ) -> HaloUpdater: if len(specifications) == 0: raise RuntimeError("Cannot create updater with specifications list") if specifications[0].n_points == 0: @@ -583,9 +596,9 @@ def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): def get_vector_halo_updater( self, - specifications_x: List[QuantityHaloSpec], - specifications_y: List[QuantityHaloSpec], - ): + specifications_x: list[QuantityHaloSpec], + specifications_y: list[QuantityHaloSpec], + ) -> HaloUpdater: if len(specifications_x) == 0 and len(specifications_y) == 0: raise RuntimeError("Cannot create updater with empty specifications list") if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: @@ -616,7 +629,7 @@ def boundaries(self) -> Mapping[int, Boundary]: return self._boundaries -def bcast_metadata_list(comm, quantity_list): +def bcast_metadata_list(comm: CommABC, quantity_list: list[Quantity]): # type: ignore[no-untyped-def] is_root = comm.Get_rank() == constants.ROOT_RANK if is_root: metadata_list = [] @@ -627,7 +640,7 @@ def bcast_metadata_list(comm, quantity_list): return comm.bcast(metadata_list, root=constants.ROOT_RANK) -def bcast_metadata(comm, array): +def bcast_metadata(comm: CommABC, array: Quantity): # type: ignore[no-untyped-def] return bcast_metadata_list(comm, [array])[0] @@ -636,11 +649,11 @@ class TileCommunicator(Communicator): def __init__( self, - comm, + comm: CommABC, partitioner: TilePartitioner, force_cpu: bool = False, - timer: Optional[Timer] = None, - ): + timer: Timer | None = None, + ) -> None: """Initialize a TileCommunicator. Args: @@ -657,20 +670,20 @@ def __init__( @classmethod def from_layout( cls, - comm, - layout: Tuple[int, int], + comm: CommABC, + layout: tuple[int, int], force_cpu: bool = False, - timer: Optional[Timer] = None, - ) -> "TileCommunicator": + timer: Timer | None = None, + ) -> TileCommunicator: partitioner = TilePartitioner(layout=layout) return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer) @property - def tile(self): + def tile(self) -> TileCommunicator: return self def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int + self, quantity: Quantity | list[Quantity], n_points: int ) -> HaloUpdater: """Start an asynchronous halo update on a quantity. @@ -692,8 +705,8 @@ def start_halo_update( def start_vector_halo_update( self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], + x_quantity: Quantity | list[Quantity], + y_quantity: Quantity | list[Quantity], n_points: int, ) -> HaloUpdater: """Start an asynchronous halo update of a horizontal vector quantity. @@ -759,7 +772,7 @@ def __init__( comm: CommABC, partitioner: CubedSpherePartitioner, force_cpu: bool = False, - timer: Optional[Timer] = None, + timer: Timer | None = None, ): """Initialize a CubedSphereCommunicator. @@ -780,7 +793,7 @@ def __init__( f"comm object with only {comm.Get_size()} ranks, are we running " "with mpi and the correct number of ranks?" ) - self._tile_communicator: Optional[TileCommunicator] = None + self._tile_communicator: TileCommunicator | None = None self._force_cpu = force_cpu super(CubedSphereCommunicator, self).__init__( comm, partitioner, force_cpu, timer @@ -790,11 +803,11 @@ def __init__( @classmethod def from_layout( cls, - comm, - layout: Tuple[int, int], + comm: CommABC, + layout: tuple[int, int], force_cpu: bool = False, - timer: Optional[Timer] = None, - ) -> "CubedSphereCommunicator": + timer: Timer | None = None, + ) -> CubedSphereCommunicator: partitioner = CubedSpherePartitioner(tile=TilePartitioner(layout=layout)) return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer) @@ -805,7 +818,7 @@ def tile(self) -> TileCommunicator: self._initialize_tile_communicator() return cast(TileCommunicator, self._tile_communicator) - def _initialize_tile_communicator(self): + def _initialize_tile_communicator(self) -> None: tile_comm = self.comm.Split( color=self.partitioner.tile_index(self.rank), key=self.rank ) @@ -823,7 +836,7 @@ def _get_gather_recv_quantity( # needs to change the quantity dimensions since we add a "tile" dimension, # unlike for tile scatter/gather which retains the same dimensions recv_quantity = Quantity( - metadata.np.zeros(global_extent, dtype=metadata.dtype), + metadata.np.zeros(global_extent, dtype=metadata.dtype), # type: ignore dims=(constants.TILE_DIM,) + metadata.dims, units=metadata.units, origin=(0,) + tuple([0 for dim in metadata.dims]), @@ -845,7 +858,7 @@ def _get_scatter_recv_quantity( # needs to change the quantity dimensions since we remove a "tile" dimension, # unlike for tile scatter/gather which retains the same dimensions recv_quantity = Quantity( - metadata.np.zeros(shape, dtype=metadata.dtype), + metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore dims=metadata.dims[1:], units=metadata.units, gt4py_backend=metadata.gt4py_backend, diff --git a/ndsl/comm/decomposition.py b/ndsl/comm/decomposition.py index 48c216e8..89781e0e 100644 --- a/ndsl/comm/decomposition.py +++ b/ndsl/comm/decomposition.py @@ -1,9 +1,10 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING from gt4py.cartesian import config as gt_config +from mpi4py import MPI if TYPE_CHECKING: @@ -23,7 +24,9 @@ def determine_rank_is_compiling(rank: int, size: int) -> bool: return rank < (size / 6) -def block_waiting_for_compilation(comm, compilation_config: CompilationConfig) -> None: +def block_waiting_for_compilation( + comm: MPI.Comm, compilation_config: CompilationConfig +) -> None: """block moving on until an ok is received from the compiling rank Args: @@ -35,7 +38,7 @@ def block_waiting_for_compilation(comm, compilation_config: CompilationConfig) - _ = comm.recv(source=compiling_rank) -def unblock_waiting_tiles(comm) -> None: +def unblock_waiting_tiles(comm: MPI.Comm) -> None: """sends a message to all the ranks waiting for compilation to finish Args: @@ -55,14 +58,14 @@ def check_cached_path_exists(cache_filepath: str) -> None: raise RuntimeError(f"Error: Could not find caches for rank at {cache_filepath}") -def build_cache_path(config: CompilationConfig) -> Tuple[str, str]: +def build_cache_path(config: CompilationConfig) -> tuple[str, str]: """generate the GT-Cache path from the config Args: config (CompilationConfig): stencil-config object at post-init state Returns: - Tuple[str, str]: path and individual rank string + tuple[str, str]: path and individual rank string """ if config.size == 1: target_rank_str = "" @@ -76,7 +79,7 @@ def build_cache_path(config: CompilationConfig) -> Tuple[str, str]: return path, target_rank_str -def set_distributed_caches(config: CompilationConfig): +def set_distributed_caches(config: CompilationConfig) -> None: """In Run mode, check required file then point current rank cache to source cache""" # Check that we have all the file we need to early out in case diff --git a/ndsl/comm/local_comm.py b/ndsl/comm/local_comm.py index 1ae10177..c52e99bc 100644 --- a/ndsl/comm/local_comm.py +++ b/ndsl/comm/local_comm.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import copy -from typing import Any +from typing import Any, TypeVar -from ndsl.comm.comm_abc import Comm +from ndsl.comm.comm_abc import Comm, ReductionOperator from ndsl.logging import ndsl_log from ndsl.utils import ensure_contiguous, safe_assign_array +T = TypeVar("T") + + class ConcurrencyError(Exception): """Exception to denote that a rank cannot proceed because it is waiting on a call from another rank.""" @@ -14,40 +19,40 @@ class ConcurrencyError(Exception): class AsyncResult: - def __init__(self, result): + def __init__(self, result) -> None: # type: ignore[no-untyped-def] self._result = result - def wait(self): + def wait(self): # type: ignore[no-untyped-def] return self._result() -class LocalComm(Comm): - def __init__(self, rank, total_ranks, buffer_dict): +class LocalComm(Comm[T]): + def __init__(self, rank: int, total_ranks: int, buffer_dict: dict) -> None: self.rank = rank self.total_ranks = total_ranks self._buffer = buffer_dict - self._i_buffer = {} + self._i_buffer: dict = {} @property - def _split_comms(self): + def _split_comms(self) -> dict: self._buffer["split_comms"] = self._buffer.get("split_comms", {}) return self._buffer["split_comms"] @property - def _split_buffers(self): + def _split_buffers(self) -> dict: self._buffer["split_buffers"] = self._buffer.get("split_buffers", {}) return self._buffer["split_buffers"] - def __repr__(self): + def __repr__(self) -> str: return f"LocalComm(rank={self.rank}, total_ranks={self.total_ranks})" - def Get_rank(self): + def Get_rank(self) -> int: return self.rank - def Get_size(self): + def Get_size(self) -> int: return self.total_ranks - def _get_buffer(self, buffer_type, in_value): + def _get_buffer(self, buffer_type: str, in_value: T | None) -> T: i_buffer = self._i_buffer.get(buffer_type, 0) self._i_buffer[buffer_type] = i_buffer + 1 if buffer_type not in self._buffer: @@ -56,7 +61,7 @@ def _get_buffer(self, buffer_type, in_value): self._buffer[buffer_type].append(in_value) return self._buffer[buffer_type][i_buffer] - def _get_send_recv(self, from_rank, tag: int): + def _get_send_recv(self, from_rank, tag: int): # type: ignore[no-untyped-def] key = (from_rank, self.rank, tag) if "send_recv" not in self._buffer: raise ConcurrencyError( @@ -70,31 +75,31 @@ def _get_send_recv(self, from_rank, tag: int): return_value = self._buffer["send_recv"][key].pop(0) return return_value - def _put_send_recv(self, value, to_rank, tag: int): + def _put_send_recv(self, value, to_rank, tag: int) -> None: # type: ignore[no-untyped-def] key = (self.rank, to_rank, tag) self._buffer["send_recv"] = self._buffer.get("send_recv", {}) self._buffer["send_recv"][key] = self._buffer["send_recv"].get(key, []) self._buffer["send_recv"][key].append(copy.deepcopy(value)) @property - def _bcast_buffer(self): + def _bcast_buffer(self) -> list: if "bcast" not in self._buffer: self._buffer["bcast"] = [] return self._buffer["bcast"] @property - def _scatter_buffer(self): + def _scatter_buffer(self) -> list: if "scatter" not in self._buffer: self._buffer["scatter"] = [] return self._buffer["scatter"] @property - def _gather_buffer(self): + def _gather_buffer(self) -> list: if "gather" not in self._buffer: self._buffer["gather"] = [None for i in range(self.total_ranks)] return self._buffer["gather"] - def bcast(self, value, root=0): + def bcast(self, value: T | None, root: int = 0) -> T | None: if root != 0: raise NotImplementedError( "LocalComm assumes ranks are called in order, so root must be " @@ -104,13 +109,13 @@ def bcast(self, value, root=0): ndsl_log.debug(f"bcast {value} to rank {self.rank}") return value - def Barrier(self): + def Barrier(self) -> None: return - def barrier(self): + def barrier(self) -> None: return - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] ensure_contiguous(sendbuf) ensure_contiguous(recvbuf) if root != 0: @@ -124,7 +129,7 @@ def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): sendbuf = self._get_buffer("scatter", None) safe_assign_array(recvbuf, sendbuf[self.rank]) - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] ensure_contiguous(sendbuf) ensure_contiguous(recvbuf) gather_buffer = self._gather_buffer @@ -141,47 +146,47 @@ def Gather(self, sendbuf, recvbuf, root=0, **kwargs): for i, sendbuf in enumerate(gather_buffer): safe_assign_array(recvbuf[i, :], sendbuf) - def allgather(self, sendobj): + def allgather(self, sendobj: T) -> list[T]: raise NotImplementedError( "cannot implement allgather on local comm due to its inherent parallelism" ) - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] ensure_contiguous(sendbuf) self._put_send_recv(sendbuf, dest, tag) - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs): + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] result = self.Send(sendbuf, dest, tag) - def send(): + def send(): # type: ignore[no-untyped-def] return result return AsyncResult(send) - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] ensure_contiguous(recvbuf) safe_assign_array(recvbuf, self._get_send_recv(source, tag)) - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs): - def receive(): + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + def receive(): # type: ignore[no-untyped-def] return self.Recv(recvbuf, source, tag) return AsyncResult(receive) - def sendrecv(self, sendbuf, dest, **kwargs): + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] raise NotImplementedError( "sendrecv fundamentally cannot be written for LocalComm, " "as it requires synchronicity" ) - def Split(self, color, key): + def Split(self, color, key) -> LocalComm: # type: ignore[no-untyped-def] # key argument is ignored, assumes we're calling the ranks from least to # greatest when mocking Split self._split_comms[color] = self._split_comms.get(color, []) self._split_buffers[color] = self._split_buffers.get(color, {}) rank = len(self._split_comms[color]) total_ranks = rank + 1 - new_comm = LocalComm( + new_comm: LocalComm = LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=self._split_buffers[color] ) for comm in self._split_comms[color]: @@ -189,14 +194,20 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op=None, recvobj=None) -> Any: + def allreduce(self, sendobj, op=None, recvobj=None) -> Any: # type: ignore[no-untyped-def] raise NotImplementedError( "allreduce fundamentally cannot be written for LocalComm, " "as it requires synchronicity" ) - def Allreduce(self, sendobj, recvobj, op) -> Any: + def Allreduce(self, sendobj, recvobj, op) -> Any: # type: ignore[no-untyped-def] raise NotImplementedError( "Allreduce fundamentally cannot be written for LocalComm, " "as it requires synchronicity" ) + + def Allreduce_inplace(self, obj: Any, op: ReductionOperator) -> Any: + raise NotImplementedError( + "Allreduce_inplace fundamentally cannot be written for LocalComm, " + "as it requires synchronicity" + ) diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 0b4a5540..4b427914 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -1,4 +1,11 @@ -from typing import Dict, List, Optional, TypeVar, cast +"""Wrapper around mpi4py. + +This module defines a light-weight wrapper around mpi4py. It is the only +place where we directly import from mpi4py. This allows to potentially +swap mpi4py in the future. +""" + +from typing import TypeVar, cast from mpi4py import MPI @@ -9,7 +16,7 @@ class MPIComm(Comm): - _op_mapping: Dict[ReductionOperator, MPI.Op] = { + _op_mapping: dict[ReductionOperator, MPI.Op] = { ReductionOperator.OP_NULL: MPI.OP_NULL, ReductionOperator.MAX: MPI.MAX, ReductionOperator.MIN: MPI.MIN, @@ -27,7 +34,7 @@ class MPIComm(Comm): ReductionOperator.NO_OP: MPI.NO_OP, } - def __init__(self): + def __init__(self) -> None: if MPI is None: raise RuntimeError("MPI not available") self._comm: Comm = cast(Comm, MPI.COMM_WORLD) @@ -38,47 +45,49 @@ def Get_rank(self) -> int: def Get_size(self) -> int: return self._comm.Get_size() - def bcast(self, value: Optional[T], root=0) -> T: + def bcast(self, value: T | None, root: int = 0) -> T | None: return self._comm.bcast(value, root=root) - def barrier(self): + def barrier(self) -> None: self._comm.barrier() - def Barrier(self): + def Barrier(self) -> None: pass - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Scatter(sendbuf, recvbuf, root=root, **kwargs) - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Gather(sendbuf, recvbuf, root=root, **kwargs) - def allgather(self, sendobj: T) -> List[T]: + def allgather(self, sendobj: T) -> list[T]: return self._comm.allgather(sendobj) - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Send(sendbuf, dest, tag=tag, **kwargs) - def sendrecv(self, sendbuf, dest, **kwargs): + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] return self._comm.sendrecv(sendbuf, dest, **kwargs) - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return self._comm.Isend(sendbuf, dest, tag=tag, **kwargs) - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] self._comm.Recv(recvbuf, source, tag=tag, **kwargs) - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs) - def Split(self, color, key) -> "Comm": + def Split(self, color, key) -> Comm: # type: ignore[no-untyped-def] return self._comm.Split(color, key) - def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: return self._comm.allreduce(sendobj, self._op_mapping[op]) - def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T: - return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op]) + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + return self._comm.Allreduce(sendobj, recvobj, self._op_mapping[op]) def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T: return self._comm.Allreduce(MPI.IN_PLACE, recvobj, self._op_mapping[op]) diff --git a/ndsl/comm/null_comm.py b/ndsl/comm/null_comm.py index 5ca92359..2d67c78f 100644 --- a/ndsl/comm/null_comm.py +++ b/ndsl/comm/null_comm.py @@ -1,25 +1,31 @@ import copy -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, TypeVar, cast from ndsl.comm.comm_abc import Comm, ReductionOperator, Request +T = TypeVar("T") + + class NullAsyncResult(Request): - def __init__(self, recvbuf=None): + def __init__(self, recvbuf: Any = None) -> None: self._recvbuf = recvbuf - def wait(self): + def wait(self) -> None: if self._recvbuf is not None: self._recvbuf[:] = 0.0 -class NullComm(Comm): +class NullComm(Comm[T]): """ A class with a subset of the mpi4py Comm API, but which 'receives' a fill value (default zero) instead of using MPI. """ - def __init__(self, rank, total_ranks, fill_value=0.0): + default_fill_value: T = cast(T, 0) + + def __init__(self, rank: int, total_ranks: int, fill_value: T = default_fill_value): """ Args: rank: rank to mock @@ -30,56 +36,56 @@ def __init__(self, rank, total_ranks, fill_value=0.0): self.rank = rank self.total_ranks = total_ranks self._fill_value = fill_value - self._split_comms: Mapping[Any, NullComm] = {} + self._split_comms: Mapping[Any, list[NullComm]] = {} - def __repr__(self): + def __repr__(self) -> str: return f"NullComm(rank={self.rank}, total_ranks={self.total_ranks})" - def Get_rank(self): + def Get_rank(self) -> int: return self.rank - def Get_size(self): + def Get_size(self) -> int: return self.total_ranks - def bcast(self, value, root=0): + def bcast(self, value: T | None, root: int = 0) -> T | None: return value - def barrier(self): + def barrier(self) -> None: return - def Barrier(self): + def Barrier(self) -> None: return - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] if recvbuf is not None: recvbuf[:] = self._fill_value - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] if recvbuf is not None: recvbuf[:] = self._fill_value - def allgather(self, sendobj): + def allgather(self, sendobj: T) -> list[T]: return [copy.deepcopy(sendobj) for _ in range(self.total_ranks)] - def Send(self, sendbuf, dest, **kwargs): + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] pass - def Isend(self, sendbuf, dest, **kwargs): + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return NullAsyncResult() - def Recv(self, recvbuf, source, **kwargs): + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] recvbuf[:] = self._fill_value - def Irecv(self, recvbuf, source, **kwargs): + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] return NullAsyncResult(recvbuf) - def sendrecv(self, sendbuf, dest, **kwargs): + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] return sendbuf - def Split(self, color, key): + def Split(self, color, key) -> Comm: # type: ignore[no-untyped-def] # key argument is ignored, assumes we're calling the ranks from least to # greatest when mocking Split - self._split_comms[color] = self._split_comms.get(color, []) + self._split_comms[color] = self._split_comms.get(color, []) # type: ignore[index] rank = len(self._split_comms[color]) total_ranks = rank + 1 new_comm = NullComm( @@ -91,9 +97,15 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: return self._fill_value - def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + # TODO: what about reduction operator `op`? recvobj = sendobj return recvobj + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + raise NotImplementedError("NullComm.Allreduce_inplace") diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index 6b8750a1..db0150ca 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import abc import copy import functools -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast +from collections.abc import Callable, Sequence +from typing import Self, TypeVar, cast +import f90nml import numpy as np import ndsl.constants as constants @@ -55,21 +59,28 @@ def get_tile_number(tile_rank: int, total_ranks: int) -> int: class Partitioner(abc.ABC): - @abc.abstractmethod - def __init__(self): - self.tile = None - self.layout = None + tile: TilePartitioner + layout: tuple[int, int] + + def __init__( + self, tile: TilePartitioner, layout: tuple[int, int] | list[int] + ) -> None: + self.tile = tile + if len(layout) != 2: + raise ValueError( + f"Expected layout to be a tuple/list of two integers. Got {layout} instead." + ) + self.layout = tuple(layout) # type: ignore[assignment] @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... + def boundary(self, boundary_type: int, rank: int) -> bd.SimpleBoundary | None: ... @abc.abstractmethod - def tile_index(self, rank: int): + def tile_index(self, rank: int) -> int: pass @abc.abstractmethod - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + def global_extent(self, rank_metadata: QuantityMetadata) -> tuple[int, ...]: """Return the shape of a full tile representation for the given dimensions. Args: @@ -87,7 +98,7 @@ def subtile_slice( global_dims: Sequence[str], global_extent: Sequence[int], overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: + ) -> tuple[int | slice, ...]: """Return the subtile slice of a given rank on an array. Global refers to the domain being partitioned. For example, for a partitioning @@ -113,7 +124,7 @@ def subtile_extent( self, global_metadata: QuantityMetadata, rank: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: """Return the shape of a single rank representation for the given dimensions. Args: @@ -134,19 +145,18 @@ def total_ranks(self) -> int: class TilePartitioner(Partitioner): def __init__( self, - layout: Tuple[int, int], + layout: tuple[int, int] | list[int], edge_interior_ratio: float = 1.0, ): """Create an object for fv3gfs tile decomposition.""" - self.layout = layout self.edge_interior_ratio = edge_interior_ratio - self.tile = self + super().__init__(self, layout) - def tile_index(self, rank: int): + def tile_index(self, rank: int) -> int: return 0 @classmethod - def from_namelist(cls, namelist): + def from_namelist(cls, namelist: f90nml.Namelist) -> Self: """Initialize a TilePartitioner from a Fortran namelist. Args: @@ -154,7 +164,7 @@ def from_namelist(cls, namelist): """ return cls(layout=namelist["fv_core_nml"]["layout"]) - def subtile_index(self, rank: int) -> Tuple[int, int]: + def subtile_index(self, rank: int) -> tuple[int, int]: """ Return the (y, x) subtile position of a given rank as an integer number of subtiles. @@ -166,8 +176,8 @@ def total_ranks(self) -> int: return self.layout[0] * self.layout[1] def global_extent( - self, rank_metadata: Union[Quantity, QuantityMetadata] - ) -> Tuple[int, ...]: + self, rank_metadata: Quantity | QuantityMetadata + ) -> tuple[int, ...]: """Return the shape of a full tile representation for the given dimensions. Args: @@ -184,7 +194,7 @@ def subtile_extent( self, global_metadata: QuantityMetadata, rank: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: """Return the shape of a single rank representation for the given dimensions. Args: @@ -210,7 +220,7 @@ def subtile_slice( global_dims: Sequence[str], global_extent: Sequence[int], overlap: bool = False, - ) -> Tuple[slice, ...]: + ) -> tuple[slice, ...]: """Return the subtile slice of a given rank on an array. Global refers to the domain being partitioned. For example, for a partitioning @@ -250,7 +260,7 @@ def on_tile_left(self, rank: int) -> bool: def on_tile_right(self, rank: int) -> bool: return on_tile_right(self.subtile_index(rank), self.layout) - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + def boundary(self, boundary_type: int, rank: int) -> bd.SimpleBoundary | None: """Returns a boundary of the requested type for a given rank. Target ranks will be on the same tile as the given rank, wrapping around as @@ -269,7 +279,7 @@ def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary] @functools.lru_cache(maxsize=DEFAULT_CACHE_SIZE) def _cached_boundary( self, boundary_type: int, rank: int - ) -> Optional[bd.SimpleBoundary]: + ) -> bd.SimpleBoundary | None: boundary = { WEST: self._left_edge, EAST: self._right_edge, @@ -330,18 +340,18 @@ def _bottom_edge(self, rank: int) -> bd.SimpleBoundary: n_clockwise_rotations=0, ) - def _top_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _top_left_corner(self, rank: int) -> bd.SimpleBoundary | None: return _get_corner(constants.NORTHWEST, rank, self._left_edge, self._top_edge) - def _top_right_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _top_right_corner(self, rank: int) -> bd.SimpleBoundary | None: return _get_corner(constants.NORTHEAST, rank, self._right_edge, self._top_edge) - def _bottom_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _bottom_left_corner(self, rank: int) -> bd.SimpleBoundary | None: return _get_corner( constants.SOUTHWEST, rank, self._left_edge, self._bottom_edge ) - def _bottom_right_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _bottom_right_corner(self, rank: int) -> bd.SimpleBoundary | None: return _get_corner( constants.SOUTHEAST, rank, self._right_edge, self._bottom_edge ) @@ -358,7 +368,7 @@ def _get_corner( rank: int, edge_func_1: Callable[[int], bd.Boundary], edge_func_2: Callable[[int], bd.Boundary], -): +) -> bd.SimpleBoundary: edge_1 = edge_func_1(rank) edge_2 = edge_func_2(edge_1.to_rank) rotations = edge_1.n_clockwise_rotations + edge_2.n_clockwise_rotations @@ -379,10 +389,10 @@ def __init__(self, tile: TilePartitioner): """ if not isinstance(tile, TilePartitioner): raise TypeError("tile must be a TilePartitioner") - self.tile = tile + super().__init__(tile, tile.layout) @classmethod - def from_namelist(cls, namelist): + def from_namelist(cls, namelist: f90nml.Namelist) -> Self: """Initialize a CubedSpherePartitioner from a Fortran namelist. Args: @@ -402,16 +412,12 @@ def tile_root_rank(self, rank: int) -> int: """Returns the lowest rank on the same tile as a given rank.""" return self.tile.total_ranks * (rank // self.tile.total_ranks) - @property - def layout(self) -> Tuple[int, int]: - return self.tile.layout - @property def total_ranks(self) -> int: """the number of ranks on the cubed sphere""" return 6 * self.tile.total_ranks - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + def boundary(self, boundary_type: int, rank: int) -> bd.SimpleBoundary | None: """Returns a boundary of the requested type for a given rank, or None. On tile corners, the boundary across that corner does not exist. @@ -429,7 +435,7 @@ def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary] @functools.lru_cache(maxsize=DEFAULT_CACHE_SIZE) def _cached_boundary( self, boundary_type: int, rank: int - ) -> Optional[bd.SimpleBoundary]: + ) -> bd.SimpleBoundary | None: boundary = { WEST: self._left_edge, EAST: self._right_edge, @@ -530,7 +536,7 @@ def _bottom_edge(self, rank: int) -> bd.SimpleBoundary: boundary.to_rank -= self.tile.total_ranks return boundary - def _top_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _top_left_corner(self, rank: int) -> bd.SimpleBoundary | None: if self.tile.on_tile_top(rank) and self.tile.on_tile_left(rank): corner = None else: @@ -545,7 +551,7 @@ def _top_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: ) return corner - def _top_right_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _top_right_corner(self, rank: int) -> bd.SimpleBoundary | None: if on_tile_top(self.tile.subtile_index(rank), self.layout) and on_tile_right( self.tile.subtile_index(rank), self.layout ): @@ -562,7 +568,7 @@ def _top_right_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: ) return corner - def _bottom_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _bottom_left_corner(self, rank: int) -> bd.SimpleBoundary | None: if on_tile_bottom(self.tile.subtile_index(rank)) and on_tile_left( self.tile.subtile_index(rank) ): @@ -579,7 +585,7 @@ def _bottom_left_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: ) return corner - def _bottom_right_corner(self, rank: int) -> Optional[bd.SimpleBoundary]: + def _bottom_right_corner(self, rank: int) -> bd.SimpleBoundary | None: if on_tile_bottom(self.tile.subtile_index(rank)) and on_tile_right( self.tile.subtile_index(rank), self.layout ): @@ -613,7 +619,7 @@ def _get_corner( n_clockwise_rotations=rotations, ) - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + def global_extent(self, rank_metadata: QuantityMetadata) -> tuple[int, ...]: """Return the shape of a full cube representation for the given dimensions. Args: @@ -630,7 +636,7 @@ def subtile_extent( self, cube_metadata: QuantityMetadata, rank: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: """Return the shape of a single rank representation for the given dimensions. Args: @@ -649,7 +655,7 @@ def subtile_slice( global_dims: Sequence[str], global_extent: Sequence[int], overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: + ) -> tuple[int | slice, ...]: """Return the subtile slice of a given rank on an array. Global refers to the domain being partitioned. For example, for a partitioning @@ -682,24 +688,24 @@ def subtile_slice( ) -def on_tile_left(subtile_index: Tuple[int, int]) -> bool: +def on_tile_left(subtile_index: tuple[int, int]) -> bool: return subtile_index[1] == 0 -def on_tile_right(subtile_index: Tuple[int, int], layout: Tuple[int, int]) -> bool: +def on_tile_right(subtile_index: tuple[int, int], layout: tuple[int, int]) -> bool: return subtile_index[1] == layout[1] - 1 -def on_tile_top(subtile_index: Tuple[int, int], layout: Tuple[int, int]) -> bool: +def on_tile_top(subtile_index: tuple[int, int], layout: tuple[int, int]) -> bool: return subtile_index[0] == layout[0] - 1 -def on_tile_bottom(subtile_index: Tuple[int, int]) -> bool: +def on_tile_bottom(subtile_index: tuple[int, int]) -> bool: return subtile_index[0] == 0 def rotate_subtile_rank( - rank: int, layout: Tuple[int, int], n_clockwise_rotations: int + rank: int, layout: tuple[int, int], n_clockwise_rotations: int ) -> int: """Returns the rank position where this rank would be if you rotated the tile n_clockwise_rotations times. @@ -716,21 +722,21 @@ def rotate_subtile_rank( return to_tile_rank -def transpose_subtile_rank(rank, layout): +def transpose_subtile_rank(rank: int, layout: tuple[int, int]) -> int: """Returns the rank position where this rank would be if you transposed the tile. """ return transform_subtile_rank(np.transpose, rank, layout) -def fliplr_subtile_rank(rank, layout): +def fliplr_subtile_rank(rank: int, layout: tuple[int, int]) -> int: """Returns the rank position where this rank would be if you flipped the tile along a vertical axis """ return transform_subtile_rank(np.fliplr, rank, layout) -def flipud_subtile_rank(rank, layout): +def flipud_subtile_rank(rank: int, layout: tuple[int, int]) -> int: """Returns the rank position where this rank would be if you flipped the tile along a horizontal axis """ @@ -740,8 +746,8 @@ def flipud_subtile_rank(rank, layout): def transform_subtile_rank( transform_func: Callable[[np.ndarray], np.ndarray], rank: int, - layout: Tuple[int, int], -): + layout: tuple[int, int], +) -> int: """Returns the rank position where this rank would be if you performed a transformation on the tile which strictly moves ranks. """ @@ -752,24 +758,24 @@ def transform_subtile_rank( def subtile_index( - rank: int, ranks_per_tile: int, layout: Tuple[int, int] -) -> Tuple[int, int]: + rank: int, ranks_per_tile: int, layout: tuple[int, int] +) -> tuple[int, int]: within_tile_rank = rank % ranks_per_tile j = within_tile_rank // layout[1] i = within_tile_rank % layout[1] return j, i -def is_even(value: Union[int, float]) -> bool: +def is_even(value: int | float) -> bool: return value % 2 == 0 def tile_extent_from_rank_metadata( dims: Sequence[str], rank_extent: Sequence[int], - layout: Tuple[int, int], + layout: tuple[int, int], edge_interior_ratio: float = 1.0, -) -> Tuple[int, ...]: +) -> tuple[int, ...]: """ Returns the extent of a tile given data about a single rank, and the tile layout. @@ -801,11 +807,11 @@ def rank_slice_from_tile_metadata( dims: Sequence[str], *, extent: Sequence[int], - layout: Tuple[int, int], - subtile_index: Tuple[int, int], + layout: tuple[int, int], + subtile_index: tuple[int, int], edge_interior_ratio: float, overlap: bool, -) -> Tuple[slice, ...]: +) -> tuple[slice, ...]: return _rank_slice_from_tile_metadata_cached( dims=tuple(dims), extent=tuple(extent), @@ -818,14 +824,14 @@ def rank_slice_from_tile_metadata( @functools.lru_cache(maxsize=DEFAULT_CACHE_SIZE) def _rank_slice_from_tile_metadata_cached( - dims: Tuple[str, ...], + dims: tuple[str, ...], *, - extent: Tuple[int, ...], - layout: Tuple[int, int], - subtile_index: Tuple[int, int], + extent: tuple[int, ...], + layout: tuple[int, int], + subtile_index: tuple[int, int], edge_interior_ratio: float, overlap: bool, -) -> Tuple[slice, ...]: +) -> tuple[slice, ...]: # detect if one of the given dims is the tile dimension and ignore it cartesian_dims = discard_dimension(dims, constants.TILE_DIM, data=dims) cartesian_extent = discard_dimension(dims, constants.TILE_DIM, data=extent) @@ -868,16 +874,18 @@ def _rank_slice_from_tile_metadata_cached( T = TypeVar("T") -def discard_dimension(dims, dim_name: str, data: Sequence[T]) -> List[T]: +def discard_dimension( + dims: tuple[str, ...], dim_name: str, data: Sequence[T] +) -> list[T]: return [item for (item, dim) in zip(data, dims) if dim != dim_name] def _subtile_extents_from_tile_metadata( dims: Sequence[str], tile_extent: Sequence[int], - layout: Tuple[int, int], + layout: tuple[int, int], edge_interior_ratio: float = 1.0, -) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: +) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Returns the extent of a given rank given data about a tile, and the tile layout. @@ -918,7 +926,7 @@ def _valid_edge_tile_sizes( # steps through all valid sizes to sort them: # [start, counting down to 1, counting up from start] - for i in range(len(unsorted_valid_sizes) + 1): + for _i in range(len(unsorted_valid_sizes) + 1): index = start + factor * offset if index in unsorted_valid_sizes and index not in valid_sizes: valid_sizes.append(index) @@ -989,7 +997,7 @@ def _valid_edge_tile_sizes( def extent_from_metadata( dims: Sequence[str], extent: Sequence[int], layout_factors: np.ndarray -) -> Tuple[int, ...]: +) -> tuple[int, ...]: return_extents = [] for dim, rank_extent, layout_factor in zip(dims, extent, layout_factors): if dim in constants.INTERFACE_DIMS: @@ -1004,11 +1012,11 @@ def extent_from_metadata( def subtile_slice( dims: Sequence[str], global_extent: Sequence[int], - layout: Tuple[int, int], - subtile_index: Tuple[int, int], + layout: tuple[int, int], + subtile_index: tuple[int, int], edge_interior_ratio: float = 1.0, overlap: bool = False, -) -> Tuple[slice, ...]: +) -> tuple[slice, ...]: """ Returns the slice of data within a tile's computational domain belonging to a single rank. diff --git a/ndsl/constants.py b/ndsl/constants.py index 7092f361..82d16d30 100644 --- a/ndsl/constants.py +++ b/ndsl/constants.py @@ -1,5 +1,6 @@ import os from enum import Enum +from typing import Literal import numpy as np @@ -16,13 +17,29 @@ class ConstantVersions(Enum): GEOS = "GEOS" # Constant as defined in GEOS v11.4.2 -CONST_VERSION_AS_STR = os.environ.get("PACE_CONSTANTS", "UFS") +def _get_constant_version( + default: Literal["GFDL", "UFS", "GEOS"] = "UFS", +) -> Literal["GFDL", "UFS", "GEOS"]: + if os.getenv("PACE_CONSTANTS", ""): + ndsl_log.warning("PACE_CONSTANTS is deprecated. Use NDSL_CONSTANTS instead.") + if os.getenv("NDSL_CONSTANTS", ""): + ndsl_log.warning( + "PACE_CONSTANTS and NDSL_CONSTANTS were both specified. NDSL_CONSTANTS will take precedence." + ) -try: - CONST_VERSION = ConstantVersions[CONST_VERSION_AS_STR] - ndsl_log.info(f"Constant selected: {CONST_VERSION}") -except KeyError as e: - raise RuntimeError(f"Constants {CONST_VERSION_AS_STR} is not implemented, abort.") + constants_as_str = os.getenv("NDSL_CONSTANTS", os.getenv("PACE_CONSTANTS", default)) + expected: list[Literal["GFDL", "UFS", "GEOS"]] = ["GFDL", "UFS", "GEOS"] + + if constants_as_str not in expected: + raise RuntimeError( + f"Constants '{constants_as_str}' is not implemented, abort. Valid values are {expected}." + ) + + return constants_as_str # type: ignore + + +CONST_VERSION = ConstantVersions[_get_constant_version()] +ndsl_log.info(f"Constant selected: {CONST_VERSION}") ##################### # Common constants diff --git a/ndsl/debug/config.py b/ndsl/debug/config.py index 94ee7119..fb747ca3 100644 --- a/ndsl/debug/config.py +++ b/ndsl/debug/config.py @@ -31,7 +31,7 @@ ndsl_debugger = None -def _set_debugger(): +def _set_debugger() -> None: config = os.getenv("NDSL_DEBUG_CONFIG", "") if not os.path.exists(config): if config != "": diff --git a/ndsl/debug/debugger.py b/ndsl/debug/debugger.py index 7e1f60fe..b9de8631 100644 --- a/ndsl/debug/debugger.py +++ b/ndsl/debug/debugger.py @@ -2,6 +2,7 @@ import numbers import os import pathlib +from typing import Any import pandas as pd import xarray as xr @@ -26,7 +27,7 @@ class Debugger: calls_count: dict[str, int] = dataclasses.field(default_factory=dict) track_parameter_count: dict[str, int] = dataclasses.field(default_factory=dict) - def _to_xarray(self, data, name) -> xr.DataArray: + def _to_xarray(self, data: Any, name: str | None) -> xr.DataArray: if isinstance(data, Quantity): if self.save_compute_domain_only: mem = data.field @@ -42,13 +43,13 @@ def _to_xarray(self, data, name) -> xr.DataArray: or pd.api.types.is_string_dtype(data) or isinstance(data, numbers.Number) ): - return xr.DataArray(data) + return xr.DataArray(data, name=name) else: ndsl_log.error(f"[Debugger] Cannot save data of type {type(data)}") return xr.DataArray([0]) return xr.DataArray(mem, dims=[f"dim_{i}_{s}" for i, s in enumerate(shp)]) - def track_data(self, data_as_dict, source_as_name, is_in) -> None: + def track_data(self, data_as_dict: dict, source_as_name: str, is_in: bool) -> None: for name, data in data_as_dict.items(): if name not in self.track_parameter_by_name: continue @@ -71,10 +72,10 @@ def track_data(self, data_as_dict, source_as_name, is_in) -> None: self.track_parameter_count[name] += 1 - def save_as_dataset(self, data_as_dict, savename, is_in) -> None: - """Save dictionnary of data to NetCDF + def save_as_dataset(self, data_as_dict: dict, savename: str, is_in: bool) -> None: + """Save dictionary of data to NetCDF - Note: Unknown types in the dictionnary won't be saved. + Note: Unknown types in the dictionary won't be saved. """ if savename not in self.stencils_or_class: return @@ -102,7 +103,7 @@ def save_as_dataset(self, data_as_dict, savename, is_in) -> None: except ValueError as e: ndsl_log.error(f"[DebugInfo] Failure to save {savename}: {e}") - def increment_call_count(self, savename: str): + def increment_call_count(self, savename: str) -> None: """Increment the call count for this savename""" if savename not in self.calls_count.keys(): self.calls_count[savename] = 0 diff --git a/ndsl/debug/tooling.py b/ndsl/debug/tooling.py index 4ec89bbc..aebc83d1 100644 --- a/ndsl/debug/tooling.py +++ b/ndsl/debug/tooling.py @@ -1,15 +1,17 @@ import inspect +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any from ndsl.debug.config import ndsl_debugger -def instrument(func) -> Callable: +def instrument(func: Callable) -> Callable: @wraps(func) - def wrapper(self, *args: Any, **kwargs: Any): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if ndsl_debugger is None: return func(self, *args, **kwargs) + savename = func.__qualname__ params = inspect.signature(func).parameters data_as_dict = {} diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index e3fe0cc8..5fa508d2 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -1,8 +1,10 @@ # Literal precision for both GT4Py & NDSL import os import sys +from typing import Literal from ndsl.comm.mpi import MPI +from ndsl.logging import ndsl_log gt4py_config_module = "gt4py.cartesian.config" @@ -12,8 +14,35 @@ " Please import `ndsl.dsl` or any `ndsl` module " " before any `gt4py` imports." ) -NDSL_GLOBAL_PRECISION = int(os.getenv("PACE_FLOAT_PRECISION", "64")) -os.environ["GT4PY_LITERAL_PRECISION"] = str(NDSL_GLOBAL_PRECISION) + + +def _get_literal_precision(default: Literal["32", "64"] = "64") -> Literal["32", "64"]: + if os.getenv("PACE_FLOAT_PRECISION", ""): + ndsl_log.warning( + "PACE_FLOAT_PRECISION is deprecated. Use NDSL_LITERAL_PRECISION instead." + ) + if os.getenv("NDSL_LITERAL_PRECISION", ""): + ndsl_log.warning( + "PACE_FLOAT_PRECISION and NDSL_LOGLEVEL were both specified. NDSL_LITERAL_PRECISION will take precedence." + ) + + precision = os.getenv( + "NDSL_LITERAL_PRECISION", os.getenv("PACE_FLOAT_PRECISION", default) + ) + + expected: list[Literal["32", "64"]] = ["32", "64"] + if precision in expected: + return precision # type: ignore + + ndsl_log.warning( + f"Unexpected literal precision '{precision}', falling back to '{default}'. Valid values are {expected}." + ) + return default + + +NDSL_GLOBAL_PRECISION = int(_get_literal_precision()) +os.environ["GT4PY_LITERAL_INT_PRECISION"] = str(NDSL_GLOBAL_PRECISION) +os.environ["GT4PY_LITERAL_FLOAT_PRECISION"] = str(NDSL_GLOBAL_PRECISION) # Set cache names for default gt backends workflow diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index edf563b7..1c1e7ec8 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -6,7 +6,7 @@ def identify_code_path( rank: int, partitioner: Partitioner, ) -> FV3CodePath: - if partitioner.layout == (1, 1) or partitioner.layout == [1, 1]: + if partitioner.layout == (1, 1): return FV3CodePath.All elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: raise NotImplementedError( diff --git a/ndsl/dsl/caches/codepath.py b/ndsl/dsl/caches/codepath.py index 8ebf9492..61591ccf 100644 --- a/ndsl/dsl/caches/codepath.py +++ b/ndsl/dsl/caches/codepath.py @@ -23,10 +23,10 @@ class FV3CodePath(enum.Enum): Bottom = "FV3_B" Center = "FV3_C" - def __str__(self): + def __str__(self) -> str: return self.value - def __repr__(self): + def __repr__(self) -> str: return self.value def __format__(self, format_spec: str) -> str: diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index e69de29b..597c2a31 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -0,0 +1,5 @@ +from .dace_config import DaceConfig +from .orchestration import orchestrate, orchestrate_function + + +__all__ = ["DaceConfig", "orchestrate", "orchestrate_function"] diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index 90d36a46..155583a0 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -1,6 +1,6 @@ -from typing import List, Optional, Tuple - +from dace import config as dace_conf from dace.sdfg import SDFG +from gt4py.cartesian import config as gt_config from ndsl.dsl.caches.cache_location import get_cache_directory, get_cache_fullpath from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration @@ -11,7 +11,7 @@ # Distributed compilation -def unblock_waiting_tiles(comm, sdfg_path: str) -> None: +def unblock_waiting_tiles(comm, sdfg_path: str) -> None: # type: ignore if comm and comm.Get_size() > 1: for tile in range(1, 6): tilesize = comm.Get_size() / 6 @@ -23,8 +23,8 @@ def build_info_filepath() -> str: def write_build_info( - sdfg: SDFG, layout: Tuple[int, int], resolution_per_tile: List[int], backend: str -): + sdfg: SDFG, layout: tuple[int, int], resolution_per_tile: list[int], backend: str +) -> None: """Write down all relevant information on the build to identify it at load time.""" # Dev NOTE: we should be able to leverage sdfg.make_key to get a hash or @@ -48,9 +48,9 @@ def write_build_info( def get_sdfg_path( daceprog_name: str, config: DaceConfig, - sdfg_file_path: Optional[str] = None, - override_run_only=False, -) -> Optional[str]: + sdfg_file_path: str | None = None, + override_run_only: bool = False, +) -> str | None: """Build an SDFG path from the qualified program name or it's direct path to .sdfg Args: @@ -101,12 +101,12 @@ def get_sdfg_path( f"cannot be run with current resolution {config.tile_resolution}" ) - print(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}") + ndsl_log.debug(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}") return sdfg_dir_path -def set_distributed_caches(config: "DaceConfig"): +def set_distributed_caches(config: DaceConfig) -> None: """In Run mode, check required file then point current rank cache to source cache""" # Execute specific initialization per orchestration state @@ -127,14 +127,25 @@ def set_distributed_caches(config: "DaceConfig"): ) # Set read/write caches to the target rank - from gt4py.cartesian import config as gt_config - if config.do_compile: verb = "reading/writing" else: verb = "reading" gt_config.cache_settings["dir_name"] = get_cache_directory(config.code_path) + + # NOTE: In the (rare) case we orchestrate code _without_ any stencils, we need + # to set the build folder. The other code is in FrozenStencil and deals with the + # case of `dace` used in both orchestrated and not orchestrated. + # A better build system would deal with this in BOTH cases. + dace_conf.Config.set( + "default_build_folder", + value="{gt_root}/{gt_cache}/dacecache".format( + gt_root=gt_config.cache_settings["root_path"], + gt_cache=gt_config.cache_settings["dir_name"], + ), + ) + ndsl_log.info( f"[{orchestration_mode}] Rank {config.my_rank} " f"{verb} cache {gt_config.cache_settings['dir_name']}" diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 27f17375..d76e10da 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -1,18 +1,23 @@ +from __future__ import annotations + import enum import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Self import dace.config -from dace.codegen.compiled_sdfg import CompiledSDFG from dace.frontend.python.parser import DaceProgram +from gt4py.cartesian.config import GT4PY_COMPILE_OPT_LEVEL from ndsl.comm.communicator import Communicator +from ndsl.comm.null_comm import NullComm from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import get_precision +from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp +from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector # This can be turned on to revert compilation for orchestration @@ -21,6 +26,21 @@ DEACTIVATE_DISTRIBUTED_DACE_COMPILE = False +def _debug_dace_orchestration() -> bool: + """ + Debugging Dace orchestration deeper can be done by turning on `syncdebug`. + We control this Dace configuration below with our own override. + """ + if os.getenv("PACE_DACE_DEBUG", ""): + ndsl_log.warning("PACE_DACE_DEBUG is deprecated. Use NDSL_DACE_DEBUG instead.") + if os.getenv("NDSL_DACE_DEBUG", ""): + ndsl_log.warning( + "PACE_DACE_DEBUG and NDSL_DACE_DEBUG were both specified. NDSL_DACE_DEBUG will take precedence." + ) + + return os.getenv("NDSL_DACE_DEBUG", os.getenv("PACE_DACE_DEBUG", "False")) == "True" + + def _is_corner(rank: int, partitioner: Partitioner) -> bool: if partitioner.tile.on_tile_bottom(rank): if partitioner.tile.on_tile_left(rank): @@ -35,28 +55,28 @@ def _is_corner(rank: int, partitioner: Partitioner) -> bool: return False -def _smallest_rank_bottom(x: int, y: int, layout: Tuple[int, int]): +def _smallest_rank_bottom(x: int, y: int, layout: tuple[int, int]) -> bool: return y == 0 and x == 1 -def _smallest_rank_top(x: int, y: int, layout: Tuple[int, int]): +def _smallest_rank_top(x: int, y: int, layout: tuple[int, int]) -> bool: return y == layout[1] - 1 and x == 1 -def _smallest_rank_left(x: int, y: int, layout: Tuple[int, int]): +def _smallest_rank_left(x: int, y: int, layout: tuple[int, int]) -> bool: return x == 0 and y == 1 -def _smallest_rank_right(x: int, y: int, layout: Tuple[int, int]): +def _smallest_rank_right(x: int, y: int, layout: tuple[int, int]) -> bool: return x == layout[0] - 1 and y == 1 -def _smallest_rank_middle(x: int, y: int, layout: Tuple[int, int]): +def _smallest_rank_middle(x: int, y: int, layout: tuple[int, int]) -> bool: return layout[0] > 1 and layout[1] > 1 and x == 1 and y == 1 def _determine_compiling_ranks( - config: "DaceConfig", + config: DaceConfig, partitioner: Partitioner, ) -> bool: """ @@ -126,48 +146,54 @@ class DaCeOrchestration(enum.Enum): Run = 3 -class FrozenCompiledSDFG: - """ - Cache transform args to allow direct execution of the CSDFG - - Args: - csdfg: compiled SDFG, e.g. loaded .so - sdfg_args: transformed args to align for CSDFG direct execution - - WARNING: No checks are done on arguments, any memory swap (free/realloc) - will lead to difficult to debug misbehavior - """ - - def __init__( - self, daceprog: DaceProgram, csdfg: CompiledSDFG, args, kwargs - ) -> None: - self.csdfg = csdfg - self.sdfg_args = daceprog._create_sdfg_args(csdfg.sdfg, args, kwargs) - - def __call__(self): - return self.csdfg(**self.sdfg_args) - - class DaceConfig: def __init__( self, - communicator: Optional[Communicator], + communicator: Communicator | None, backend: str, tile_nx: int = 0, tile_nz: int = 0, - orchestration: Optional[DaCeOrchestration] = None, + orchestration: DaCeOrchestration | None = None, + time: bool = False, ): + """Specialize the DaCe configuration for NDSL use. + + Dev note: This class wrongly carries two runtime values: + - `loaded_precompiled_SDFG`: cache of SDFG loaded post build + - `performance_collector`: runtime timer shared for all runtime call + of orchestrate code + + Args: + communicator: used for setting the distributed caches + backend: string for the backend + tile_nx: x/y domain size for a single time + tile_nz: z domain size for a single time + orchestration: orchestration mode from DaCeOrchestration + time: trigger performance collection, available to user with + `performance_collector` + """ + # Recording SDFG loaded for fast re-access # ToDo: DaceConfig becomes a bit more than a read-only config - # with this. Should be refactor into a DaceExecutor carrying a config - self.loaded_precompiled_SDFG: Dict[DaceProgram, FrozenCompiledSDFG] = {} + # with this. Should be refactored into a DaceExecutor carrying a config + self.loaded_precompiled_SDFG: dict[DaceProgram, dace.CompiledSDFG] = {} + self.performance_collector = ( + PerformanceCollector( + "InternalOrchestrationTimer", + comm=( + communicator.comm if communicator is not None else NullComm(0, 6, 0) + ), + ) + if time + else NullPerformanceCollector() + ) # Temporary. This is a bit too out of the ordinary for the common user. # We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` # backend that would signify both the `CPU|GPU` split and the orchestration mode if orchestration is None: fv3_dacemode_env_var = os.getenv("FV3_DACEMODE", "Python") - # The below condition guard against defining empty FV3_DACEMODE and + # The below condition guards against defining empty FV3_DACEMODE and # awkward behavior of os.getenv returning "" even when not defined if fv3_dacemode_env_var is None or fv3_dacemode_env_var == "": fv3_dacemode_env_var = "Python" @@ -175,25 +201,47 @@ def __init__( else: self._orchestrate = orchestration - # Debugging Dace orchestration deeper can be done by turning on `syncdebug` - # We control this Dace configuration below with our own override - dace_debug_env_var = os.getenv("PACE_DACE_DEBUG", "False") == "True" + # We hijack the optimization level of GT4Py because we don't + # have the configuration at NDSL level, but we do use the GT4Py + # level + # TODO: if GT4PY opt level is funneled via NDSL - use it here + optimization_level = int(GT4PY_COMPILE_OPT_LEVEL) # Set the configuration of DaCe to a rigid & tested set of divergence # from the defaults when orchestrating - if orchestration != DaCeOrchestration.Python: + if self.is_dace_orchestrated(): + # Detecting neoverse-v1/2 requires an external package, we swap it + # for a read on GH200 nodes themselves. + is_arm_neoverse = ( + cp is not None + and cp.cuda.runtime.getDeviceProperties(0)["name"] + == b"NVIDIA GH200 480GB" + ) + + if optimization_level == 0: + dace.config.Config.set("compiler", "build_type", value="Debug") + elif optimization_level == 2 or optimization_level == 1: + dace.config.Config.set("compiler", "build_type", value="RelWithDebInfo") + else: + dace.config.Config.set("compiler", "build_type", value="Release") + # Required to True for gt4py storage/memory dace.config.Config.set( "compiler", "allow_view_arguments", value=True, ) + # Resolve "march/mtune" option for GPU + # - turn on numeric-centric SSE by default + # - Neoverse-V2 Grace CPU is too new for GCC 14 and -march=native will fail + # - use alternative march=armv8-a instead + march_cpu = "armv8-a" if is_arm_neoverse else "native" # Removed --fmath dace.config.Config.set( "compiler", "cpu", "args", - value="-std=c++14 -fPIC -Wall -Wextra -O3", + value=f"-march={march_cpu} -std=c++17 -fPIC -Wall -Wextra -O{optimization_level}", ) # Potentially buggy - deactivate dace.config.Config.set( @@ -202,17 +250,20 @@ def __init__( "openmp_sections", value=0, ) + # Resolve "march/mtune" option for GPU + # - turn on numeric-centric SSE by default + # - Neoverse-V2 Grace CPU will fail + # - use alternative mcpu=native instead + march_option = "-mcpu=native" if is_arm_neoverse else "-march=native" # Removed --fast-math dace.config.Config.set( "compiler", "cuda", "args", - value="-std=c++14 -Xcompiler -fPIC -O3 -Xcompiler -march=native", + value=f"-std=c++14 -Xcompiler -fPIC -O3 -Xcompiler {march_option}", ) - cuda_sm = 60 - if cp: - cuda_sm = cp.cuda.Device(0).compute_capability + cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 dace.config.Config.set("compiler", "cuda", "cuda_arch", value=f"{cuda_sm}") # Block size/thread count is defaulted to an average value for recent # hardware (Pascal and upward). The problem of setting an optimized @@ -261,7 +312,7 @@ def __init__( # Enable to debug GPU failures dace.config.Config.set( - "compiler", "cuda", "syncdebug", value=dace_debug_env_var + "compiler", "cuda", "syncdebug", value=_debug_dace_orchestration() ) if get_precision() == 32: @@ -273,7 +324,7 @@ def __init__( value="c", ) - # attempt to kill the dace.conf to avoid confusion + # Attempt to kill the dace.conf to avoid confusion if dace.config.Config._cfg_filename: try: os.remove(dace.config.Config._cfg_filename) @@ -303,13 +354,10 @@ def __init__( set_distributed_caches(self) - if ( - self._orchestrate != DaCeOrchestration.Python - and "dace" not in self._backend - ): + if self.is_dace_orchestrated() and "dace" not in self._backend: raise RuntimeError( - "DaceConfig: orchestration can only be leverage " - f"on dace or dace:gpu not on {self._backend}" + "DaceConfig: orchestration can only be leveraged " + f"with the `dace:*` backends, not with {self._backend}." ) def is_dace_orchestrated(self) -> bool: @@ -327,7 +375,7 @@ def get_orchestrate(self) -> DaCeOrchestration: def get_sync_debug(self) -> bool: return dace.config.Config.get_bool("compiler", "cuda", "syncdebug") - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), "_backend": self._backend, @@ -338,7 +386,7 @@ def as_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: dict) -> Self: config = cls( None, backend=data["_backend"], diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5c30367f..43b6c4f3 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -1,9 +1,16 @@ +from __future__ import annotations + +import numbers import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from collections.abc import Callable, Sequence +from typing import Any -import dace -import gt4py.storage +from dace import SDFG, CompiledSDFG from dace import compiletime as DaceCompiletime +from dace import dtypes +from dace import method as dace_method +from dace import nodes +from dace import program as dace_program from dace.dtypes import DeviceType as DaceDeviceType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible @@ -11,21 +18,23 @@ from dace.transformation.auto.auto_optimize import make_transients_persistent from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass +from gt4py import storage +import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, DaceConfig, DaCeOrchestration, - FrozenCompiledSDFG, ) from ndsl.dsl.dace.sdfg_debug_passes import ( negative_delp_checker, negative_qtracers_checker, sdfg_nan_checker, ) -from ndsl.dsl.dace.sdfg_opt_passes import splittable_region_expansion +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -35,12 +44,19 @@ from ndsl.optional_imports import cupy as cp +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False +"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" + +_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] +"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" + + def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" return func -def _upload_to_device(host_data: List[Any]) -> None: +def _upload_to_device(host_data: list) -> None: """Make sure any ndarrays gets uploaded to the device This will raise an assertion if cupy is not installed. @@ -52,17 +68,17 @@ def _upload_to_device(host_data: List[Any]) -> None: def _download_results_from_dace( - config: DaceConfig, dace_result: Optional[List[Any]], args: List[Any] -): + config: DaceConfig, dace_result: list | None +) -> list | None: """Move all data from DaCe memory space to GT4Py""" if dace_result is None: return None backend = config.get_backend() - return [gt4py.storage.from_array(result, backend=backend) for result in dace_result] + return [storage.from_array(result, backend=backend) for result in dace_result] -def _to_gpu(sdfg: dace.SDFG): +def _to_gpu(sdfg: SDFG) -> None: """Flag memory in SDFG to GPU. Force deactivate OpenMP sections for sanity.""" @@ -70,22 +86,22 @@ def _to_gpu(sdfg: dace.SDFG): allmaps = [ (me, state) for me, state in sdfg.all_nodes_recursive() - if isinstance(me, dace.nodes.MapEntry) + if isinstance(me, nodes.MapEntry) ] topmaps = [ (me, state) for me, state in allmaps if get_parent_map(state, me) is None ] # Set storage of arrays to GPU, scalarizable arrays will be set on registers - for sd, _aname, arr in sdfg.arrays_recursive(): + for _sd, _aname, arr in sdfg.arrays_recursive(): if arr.shape == (1,): - arr.storage = dace.StorageType.Register + arr.storage = dtypes.StorageType.Register else: - arr.storage = dace.StorageType.GPU_Global + arr.storage = dtypes.StorageType.GPU_Global # All maps will be schedule on GPU for mapentry, _state in topmaps: - mapentry.schedule = dace.ScheduleType.GPU_Device + mapentry.schedule = dtypes.ScheduleType.GPU_Device # Deactivate OpenMP sections for sd in sdfg.all_sdfgs_recursive(): @@ -93,41 +109,82 @@ def _to_gpu(sdfg: dace.SDFG): def _simplify( - sdfg: dace.SDFG, + sdfg: SDFG, *, validate: bool = True, validate_all: bool = False, verbose: bool = False, -): +) -> None: """Override of sdfg.simplify to skip failing transformation per https://github.com/spcl/dace/issues/1328 """ - return SimplifyPass( + SimplifyPass( validate=validate, validate_all=validate_all, verbose=verbose, + # We disable ScalarToSymbolPromotion because it might push symbols onto edges + # that DaCe itself can't parse anymore later, e.g. casts, inlined function + # calls or (complicated) field accesses. + skip=["ScalarToSymbolPromotion"], ).apply_pass(sdfg, {}) def _build_sdfg( - dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs -): - """Build the .so out of the SDFG on the top tile ranks only""" + dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any +) -> None: + """Build the .so out of the SDFG on the top tile ranks only.""" is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile + device_type = DaceDeviceType.GPU if config.is_gpu_backend() else DaceDeviceType.CPU if is_compiling: + with DaCeProgress(config, "Validate original SDFG"): + sdfg.validate() + + # Fully specialize all known symbols and then propagate these changes in the simplify + # pass that follows. This is not only a smart idea in general, but also simplifies (haha) + # the schedule tree (optimization) roundtrip. + with DaCeProgress(config, "Fully specialize symbols"): + for my_sdfg in sdfg.all_sdfgs_recursive(): + if my_sdfg.parent_nsdfg_node is not None: + repl_dict = {} + for sym, val in my_sdfg.parent_nsdfg_node.symbol_mapping.items(): + if isinstance(val, numbers.Number): + repl_dict[sym] = val + my_sdfg.replace_dict(repl_dict) + + with DaCeProgress(config, "Simplify (1)"): + _simplify(sdfg) + + if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + stree = sdfg.as_schedule_tree() + + with DaCeProgress(config, "Schedule Tree: optimization"): + if config.is_gpu_backend(): + GPUPipeline().run(stree) + else: + CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + + with DaCeProgress(config, "Schedule Tree: go back to SDFG"): + sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) + # Make the transients array persistents if config.is_gpu_backend(): + # TODO + # The following should happen on the stree level _to_gpu(sdfg) - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.GPU) + + make_transients_persistent(sdfg=sdfg, device=device_type) # Upload args to device _upload_to_device(list(args) + list(kwargs.values())) else: + # TODO + # The following should happen on the stree level for _sd, _aname, arr in sdfg.arrays_recursive(): if arr.shape == (1,): arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU) + make_transients_persistent(sdfg=sdfg, device=device_type) # Build non-constants & non-transients from the sdfg_kwargs sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) @@ -139,29 +196,18 @@ def _build_sdfg( if k in sdfg_kwargs and tup[1].transient: del sdfg_kwargs[k] - with DaCeProgress(config, "Simplify (1/2)"): - _simplify(sdfg, validate=False, verbose=True) - - # Perform pre-expansion fine tuning - with DaCeProgress(config, "Split regions"): - splittable_region_expansion(sdfg, verbose=True) - - # Expand the stencil computation Library Nodes with the right expansion - with DaCeProgress(config, "Expand"): - sdfg.expand_library_nodes() - - with DaCeProgress(config, "Simplify (2/2)"): - _simplify(sdfg, validate=False, verbose=True) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) # Move all memory that can be into a pool to lower memory pressure. # Change Persistent memory (sub-SDFG) into Scope and flag it. with DaCeProgress(config, "Turn Persistents into pooled Scope"): memory_pooled = 0.0 for _sd, _aname, arr in sdfg.arrays_recursive(): - if arr.lifetime == dace.AllocationLifetime.Persistent: + if arr.lifetime == dtypes.AllocationLifetime.Persistent: arr.pool = True memory_pooled += arr.total_size * arr.dtype.bytes - arr.lifetime = dace.AllocationLifetime.Scope + arr.lifetime = dtypes.AllocationLifetime.Scope memory_pooled = float(memory_pooled) / (1024 * 1024) ndsl_log.debug( f"{DaCeProgress.default_prefix(config)} Pooled {memory_pooled} mb", @@ -175,10 +221,15 @@ def _build_sdfg( negative_delp_checker(sdfg) negative_qtracers_checker(sdfg) + with DaCeProgress(config, "Validate before compile"): + sdfg.validate() + # Compile with DaCeProgress(config, "Codegen & compile"): sdfg.compile() - write_build_info(sdfg, config.layout, config.tile_resolution, config._backend) + write_build_info( + sdfg, config.layout, config.tile_resolution, config.get_backend() + ) # Printing analysis of the compiled SDFG with DaCeProgress(config, "Build finished. Running memory static analysis"): @@ -191,7 +242,10 @@ def _build_sdfg( # On Build: all ranks sync, then exit. # On BuildAndRun: all ranks sync, then load the SDFG from # the expected path (made available by build). - # We use a "FrozenCompiledSDFG" to minimize re-entry cost at call time + # We use a "CompiledSDFG" which keep the `so` online but _won't_ + # do the marshalling of the arguments at call time. For this we call + # `dace_program._create_sdfg_args`. There's optimization potential for + # re-entry cost there. mode = config.get_orchestrate() # DEV NOTE: we explicitly use MPI.COMM_WORLD here because it is @@ -204,60 +258,83 @@ def _build_sdfg( if mode == DaCeOrchestration.BuildAndRun: if not is_compiling: ndsl_log.info( - f"{DaCeProgress.default_prefix(config)} Rank is not compiling." + f"{DaCeProgress.default_prefix(config)} Rank is not compiling. " "Waiting for compilation to end on all other ranks..." ) MPI.COMM_WORLD.Barrier() with DaCeProgress(config, "Loading"): sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True) + if sdfg_path is None: + raise ValueError("Couldn't load SDFG post build") compiledSDFG, _ = dace_program.load_precompiled_sdfg( sdfg_path, *args, **kwargs ) - config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG( - dace_program, compiledSDFG, args, kwargs - ) - - return _call_sdfg(dace_program, sdfg, config, args, kwargs) + config.loaded_precompiled_SDFG[dace_program] = compiledSDFG def _call_sdfg( - dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs -): - """Dispatch the SDFG execution and/or build""" - # Pre-compiled SDFG code path does away with any data checks and - # cached the marshalling - leading to almost direct C call - # DaceProgram performs argument transformation & checks for a cost ~200ms - # of overhead - if dace_program in config.loaded_precompiled_SDFG: + dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any +) -> list | None: + """Dispatch to either SDFG execution and/or build.""" + + with config.performance_collector.timestep_timer.clock(f"{dace_program.name}.Call"): + # Check if we need to build first + mode = config.get_orchestrate() + if ( + mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun] + and dace_program not in config.loaded_precompiled_SDFG # already cached + ): + ndsl_log.info("Building DaCe orchestration") + _build_sdfg(dace_program, sdfg, config, args, kwargs) + + if mode not in [DaCeOrchestration.BuildAndRun, DaCeOrchestration.Run]: + raise ValueError(f"Unexpected DaceOrchestration mode `{mode}`.") + + if dace_program not in config.loaded_precompiled_SDFG: + raise RuntimeError( + "Dace program not found in cache. Are you running `DaCeOrchestration.Run` " + "without a pre-filled cache folder? Try `DacCeOrchestration.BuildAndRun` instead." + ) + + # Pre-compiled SDFG code path does away with any data checks and + # cached the marshalling - leading to almost direct C call + # DaceProgram performs argument transformation & checks for a cost ~200ms + # of overhead with DaCeProgress(config, "Run"): if config.is_gpu_backend(): _upload_to_device(list(args) + list(kwargs.values())) - res = config.loaded_precompiled_SDFG[dace_program]() - res = _download_results_from_dace( - config, res, list(args) + list(kwargs.values()) - ) - return res - mode = config.get_orchestrate() - if mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun]: - ndsl_log.info("Building DaCe orchestration") - return _build_sdfg(dace_program, sdfg, config, args, kwargs) - - if mode == DaCeOrchestration.Run: - # We should never hit this, it should be caught by the - # loaded_precompiled_SDFG check above - raise RuntimeError("Unexpected call - pre-compiled SDFG failed to load") - else: - raise NotImplementedError(f"Mode '{mode}' unimplemented at call time") + # NOTE: this will go over declared arguments and closure arguments. + # It is a very slow piece of code. Compiled SDFG comes with a "fast_call" + # function that expects all pointers to have been made C worthy. This is + # something we did with "FrozenCompileSDFG" but we undid because it meant that + # external changing memory (reallocation...) would quietly fail + with config.performance_collector.timestep_timer.clock( + f"{dace_program.name}.ArgMarshalling" + ): + current_sdfg_args = dace_program._create_sdfg_args( + config.loaded_precompiled_SDFG[dace_program].sdfg, args, kwargs + ) + + with config.performance_collector.timestep_timer.clock( + f"{dace_program.name}.Runtime" + ): + results = config.loaded_precompiled_SDFG[dace_program]( + **current_sdfg_args + ) + + config.performance_collector.collect_performance() + + return _download_results_from_dace(config, results) def _parse_sdfg( dace_program: DaceProgram, config: DaceConfig, - *args, - **kwargs, -) -> Optional[dace.SDFG]: + *args: Any, + **kwargs: Any, +) -> SDFG | CompiledSDFG | None: """Return an SDFG depending on cache existence. Either parses, load a .sdfg or load .so (as a compiled sdfg) @@ -288,6 +365,7 @@ def _parse_sdfg( **kwargs, save=False, simplify=False, + validate=False, # TODO: should we have a "debug flag" to turn this on? ) return sdfg @@ -298,9 +376,7 @@ def _parse_sdfg( with DaCeProgress(config, "Load precompiled .sdfg (.so)"): compiledSDFG, _ = dace_program.load_precompiled_sdfg(sdfg_path, *args, **kwargs) - config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG( - dace_program, compiledSDFG, args, kwargs - ) + config.loaded_precompiled_SDFG[dace_program] = compiledSDFG return compiledSDFG @@ -313,13 +389,13 @@ class _LazyComputepathFunction(SDFGConvertible): that will be compiled but not regenerated. """ - def __init__(self, func: Callable, config: DaceConfig): + def __init__(self, func: Callable, config: DaceConfig) -> None: self.func = func self.config = config - self.daceprog: DaceProgram = dace.program(self.func) + self.daceprog: DaceProgram = dace_program(self.func) self._sdfg = None - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] assert self.config.is_dace_orchestrated() sdfg = _parse_sdfg( self.daceprog, @@ -336,23 +412,23 @@ def __call__(self, *args, **kwargs): ) @property - def global_vars(self): + def global_vars(self): # type: ignore[no-untyped-def] return self.daceprog.global_vars @global_vars.setter - def global_vars(self, value): + def global_vars(self, value): # type: ignore[no-untyped-def] self.daceprog.global_vars = value - def __sdfg__(self, *args, **kwargs): + def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] return _parse_sdfg(self.daceprog, self.config, *args, **kwargs) - def __sdfg_closure__(self, *args, **kwargs): + def __sdfg_closure__(self, *args, **kwargs): # type: ignore[no-untyped-def] return self.daceprog.__sdfg_closure__(*args, **kwargs) - def __sdfg_signature__(self): + def __sdfg_signature__(self): # type: ignore[no-untyped-def] return self.daceprog.argnames, self.daceprog.constant_args - def closure_resolver(self, constant_args, given_args, parent_closure=None): + def closure_resolver(self, constant_args, given_args, parent_closure=None): # type: ignore[no-untyped-def] return self.daceprog.closure_resolver(constant_args, given_args, parent_closure) @@ -367,24 +443,26 @@ class _LazyComputepathMethod: # In order to not regenerate SDFG for the same obj.method callable # we cache the SDFGEnabledCallable we have already init - bound_callables: Dict[Tuple[int, int], "SDFGEnabledCallable"] = dict() + bound_callables: dict[tuple[int, int], SDFGEnabledCallable] = dict() class SDFGEnabledCallable(SDFGConvertible): - def __init__(self, lazy_method: "_LazyComputepathMethod", obj_to_bind): - methodwrapper = dace.method(lazy_method.func) + def __init__( + self, lazy_method: _LazyComputepathMethod, obj_to_bind: object + ) -> None: + methodwrapper = dace_method(lazy_method.func) self.obj_to_bind = obj_to_bind self.lazy_method = lazy_method self.daceprog: DaceProgram = methodwrapper.__get__(obj_to_bind) @property - def global_vars(self): + def global_vars(self): # type: ignore[no-untyped-def] return self.daceprog.global_vars @global_vars.setter - def global_vars(self, value): + def global_vars(self, value): # type: ignore[no-untyped-def] self.daceprog.global_vars = value - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] assert self.lazy_method.config.is_dace_orchestrated() sdfg = _parse_sdfg( self.daceprog, @@ -400,16 +478,16 @@ def __call__(self, *args, **kwargs): kwargs, ) - def __sdfg__(self, *args, **kwargs): + def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] return _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs) - def __sdfg_closure__(self, reevaluate=None): + def __sdfg_closure__(self, reevaluate=None): # type: ignore[no-untyped-def] return self.daceprog.__sdfg_closure__(reevaluate) - def __sdfg_signature__(self): + def __sdfg_signature__(self): # type: ignore[no-untyped-def] return self.daceprog.argnames, self.daceprog.constant_args - def closure_resolver(self, constant_args, given_args, parent_closure=None): + def closure_resolver(self, constant_args, given_args, parent_closure=None): # type: ignore[no-untyped-def] return self.daceprog.closure_resolver( constant_args, given_args, parent_closure ) @@ -418,13 +496,13 @@ def __init__(self, func: Callable, config: DaceConfig): self.func = func self.config = config - def __get__(self, obj, objtype=None) -> SDFGEnabledCallable: + def __get__(self, obj: object, objtype: Any = None) -> SDFGEnabledCallable: """Return SDFGEnabledCallable wrapping original obj.method from cache. Update cache first if need be""" if (id(obj), id(self.func)) not in _LazyComputepathMethod.bound_callables: - _LazyComputepathMethod.bound_callables[ - (id(obj), id(self.func)) - ] = _LazyComputepathMethod.SDFGEnabledCallable(self, obj) + _LazyComputepathMethod.bound_callables[(id(obj), id(self.func))] = ( + _LazyComputepathMethod.SDFGEnabledCallable(self, obj) + ) return _LazyComputepathMethod.bound_callables[(id(obj), id(self.func))] @@ -432,10 +510,10 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable: def orchestrate( *, obj: object, - config: Optional[DaceConfig], + config: DaceConfig, method_to_orchestrate: str = "__call__", - dace_compiletime_args: Optional[Sequence[str]] = None, -): + dace_compiletime_args: Sequence[str] | None = None, +) -> None: """ Orchestrate a method of an object with DaCe. The method object is patched in place, replacing the original Callable with @@ -449,83 +527,85 @@ def orchestrate( dace_compiletime_args: list of names of arguments to be flagged has dace.compiletime for orchestration to behave """ + if not config.is_dace_orchestrated(): + return + if config is None: raise ValueError("DaCe config cannot be None") + if not hasattr(obj, method_to_orchestrate): + raise RuntimeError( + f"Could not orchestrate, " + f"{type(obj).__name__}.{method_to_orchestrate} " + "does not exist." + ) + if dace_compiletime_args is None: dace_compiletime_args = [] - if config.is_dace_orchestrated(): - if not hasattr(obj, method_to_orchestrate): - raise RuntimeError( - f"Could not orchestrate, " - f"{type(obj).__name__}.{method_to_orchestrate} " - "does not exists" - ) - - func = type.__getattribute__(type(obj), method_to_orchestrate) - - # Flag argument as dace.constant - for argument in dace_compiletime_args: - func.__annotations__[argument] = DaceCompiletime - - # Build DaCe orchestrated wrapper - # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config).__get__(obj) - - if method_to_orchestrate == "__call__": - # Grab the function from the type of the child class - # Dev note: we need to use type for dunder call because: - # a = A() - # a() - # resolved to: type(a).__call__(a) - # therefore patching the instance call (e.g a.__call__) is not enough. - # We could patch the type(self), ergo the class itself - # but that would patch _every_ instance of A. - # What we can do is patch the instance.__class__ with a local made class - # in order to keep each instance with it's own patch. - # - # Re: type:ignore - # Mypy is unhappy about dynamic class name and the devs (per github - # issues discussion) is to make a plugin. Too much work -> ignore mypy - - class _(type(obj)): # type: ignore - __qualname__ = f"{type(obj).__qualname__}_patched" - __name__ = f"{type(obj).__name__}_patched" - - def __call__(self, *arg, **kwarg): - return wrapped(*arg, **kwarg) - - def __sdfg__(self, *args, **kwargs): - return wrapped.__sdfg__(*args, **kwargs) - - def __sdfg_closure__(self, reevaluate=None): - return wrapped.__sdfg_closure__(reevaluate) - - def __sdfg_signature__(self): - return wrapped.__sdfg_signature__() - - def closure_resolver( - self, constant_args, given_args, parent_closure=None - ): - return wrapped.closure_resolver( - constant_args, given_args, parent_closure - ) - - # We keep the original class type name to not perturb - # the workflows that uses it to build relevant info (path, hash...) - previous_cls_name = type(obj).__name__ - obj.__class__ = _ - type(obj).__name__ = previous_cls_name - else: - # For regular attribute - we can just patch as usual - setattr(obj, method_to_orchestrate, wrapped) + func = type.__getattribute__(type(obj), method_to_orchestrate) + + # Flag argument as dace.constant + for argument in dace_compiletime_args: + func.__annotations__[argument] = DaceCompiletime + + # Build DaCe orchestrated wrapper + # This is a JIT object, e.g. DaCe compilation will happen on call + wrapped = _LazyComputepathMethod(func, config).__get__(obj) + + if method_to_orchestrate == "__call__": + # Grab the function from the type of the child class + # Dev note: we need to use type for dunder call because: + # a = A() + # a() + # resolved to: type(a).__call__(a) + # therefore patching the instance call (e.g a.__call__) is not enough. + # We could patch the type(self), ergo the class itself + # but that would patch _every_ instance of A. + # What we can do is patch the instance.__class__ with a local made class + # in order to keep each instance with it's own patch. + # + # Re: type:ignore + # Mypy is unhappy about dynamic class name and the devs (per github + # issues discussion) is to make a plugin. Too much work -> ignore mypy + + class _(type(obj)): # type: ignore + __qualname__ = f"{type(obj).__qualname__}_patched" + __name__ = f"{type(obj).__name__}_patched" + + def __call__(self, *arg, **kwarg): # type: ignore[no-untyped-def] + return wrapped(*arg, **kwarg) + + def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] + sdfg = wrapped.__sdfg__(*args, **kwargs) + sdfg.validate() + return sdfg + + def __sdfg_closure__(self, reevaluate=None): # type: ignore[no-untyped-def] + return wrapped.__sdfg_closure__(reevaluate) + + def __sdfg_signature__(self): # type: ignore[no-untyped-def] + return wrapped.__sdfg_signature__() + + def closure_resolver(self, constant_args, given_args, parent_closure=None): # type: ignore[no-untyped-def] + return wrapped.closure_resolver( + constant_args, given_args, parent_closure + ) + + # We keep the original class type name to not perturb + # the workflows that uses it to build relevant info (path, hash...) + previous_cls_name = type(obj).__name__ + obj.__class__ = _ + type(obj).__name__ = previous_cls_name + else: + # For regular attribute - we can just patch as usual + setattr(obj, method_to_orchestrate, wrapped) def orchestrate_function( - config: DaceConfig = None, - dace_compiletime_args: Optional[Sequence[str]] = None, -) -> Union[Callable[..., Any], _LazyComputepathFunction]: + config: DaceConfig, + dace_compiletime_args: Sequence[str] | None = None, +) -> Callable[..., Any] | _LazyComputepathFunction: """ Decorator orchestrating a method of an object with DaCe. If the model configuration doesn't demand orchestration, this won't do anything. @@ -539,8 +619,8 @@ def orchestrate_function( if dace_compiletime_args is None: dace_compiletime_args = [] - def _decorator(func: Callable[..., Any]): - def _wrapper(*args, **kwargs): + def _decorator(func: Callable[..., Any]): # type: ignore[no-untyped-def] + def _wrapper(*args, **kwargs): # type: ignore[no-untyped-def] for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime return _LazyComputepathFunction(func, config) diff --git a/ndsl/dsl/dace/replacements.py b/ndsl/dsl/dace/replacements.py new file mode 100644 index 00000000..ca7bed0c --- /dev/null +++ b/ndsl/dsl/dace/replacements.py @@ -0,0 +1,40 @@ +"""This module uses DaCe's op_repository feature to override symbols/AST objects +during parsing and replace them with an SDFG compatible representation. This +allows custom NDSL syntax, objects and symbols to be natively orchestratable.""" + +from dace import SDFG, SDFGState, dtypes +from dace.frontend.common import op_repository as oprepo +from dace.frontend.python.newast import ProgramVisitor +from dace.frontend.python.replacements import ( + UfuncInput, + UfuncOutput, + _datatype_converter, +) + +from ndsl.dsl.typing import Float, Int + + +@oprepo.replaces("Float") +def _convert_Float( + _pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arg: UfuncInput +) -> UfuncOutput: + """Replace `Float(x)` with a typecast of `x` to the proper floating precision type""" + return _datatype_converter( + sdfg, + state, + arg, + dtype=dtypes.dtype_to_typeclass(Float), + ) + + +@oprepo.replaces("Int") +def _convert_Int( + _pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arg: UfuncInput +) -> UfuncOutput: + """Replace `Int(x)` with a typecast of `x` to the proper integer precision type""" + return _datatype_converter( + sdfg, + state, + arg, + dtype=dtypes.dtype_to_typeclass(Int), + ) diff --git a/ndsl/dsl/dace/sdfg/loop_transform.py b/ndsl/dsl/dace/sdfg/loop_transform.py new file mode 100644 index 00000000..7e6cf1d4 --- /dev/null +++ b/ndsl/dsl/dace/sdfg/loop_transform.py @@ -0,0 +1,19 @@ +from dace import SDFG, ScheduleType, nodes + + +def make_SDFG_CPU_sequential(sdfg: SDFG) -> None: + """Utility to turn a CPU-based SDFG to pure serial by removing OpenMP""" + # Disable OpenMP sections + for sd in sdfg.all_sdfgs_recursive(): + sd.openmp_sections = False + + # Disable OpenMP maps + for node, _ in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.EntryNode): + schedule = getattr(node, "schedule", False) + if schedule in ( + ScheduleType.CPU_Multicore, + ScheduleType.CPU_Persistent, + ScheduleType.Default, + ): + node.schedule = ScheduleType.Sequential diff --git a/ndsl/dsl/dace/sdfg_debug_passes.py b/ndsl/dsl/dace/sdfg_debug_passes.py index d1632af7..c40d327d 100644 --- a/ndsl/dsl/dace/sdfg_debug_passes.py +++ b/ndsl/dsl/dace/sdfg_debug_passes.py @@ -1,5 +1,4 @@ import copy -from typing import List, Optional, Tuple import dace import sympy as sp @@ -14,11 +13,11 @@ def _filter_all_maps( sdfg: dace.SDFG, - whitelist: List[str] = None, - blacklist: List[str] = None, - skip_dynamic_memlet=True, -) -> List[ - Tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]] + whitelist: list[str] | None = None, + blacklist: list[str] | None = None, + skip_dynamic_memlet: bool = True, +) -> list[ + tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]] ]: """ Grab all maps outputs and filter by variable name (either black or whitelist) @@ -33,9 +32,13 @@ def _filter_all_maps( A list of access nodes, with their state & edges organized as [state, node, edges] """ + if whitelist is None: + whitelist = [] + if blacklist is None: + blacklist = [] - checks: List[ - Tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]] + checks: list[ + tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]] ] = [] all_maps = [ (me, state) @@ -54,13 +57,11 @@ def _filter_all_maps( continue node = sdutil.get_last_view_node(state, e.dst) # Whitelist - if whitelist is not None: - if all([varname not in node.data for varname in whitelist]): - continue + if all([varname not in node.data for varname in whitelist]): + continue # Blacklist - if blacklist is not None: - if any([varname in node.data for varname in blacklist]): - continue + if any([varname in node.data for varname in blacklist]): + continue # Skip dynamic (region) outputs if skip_dynamic_memlet and state.memlet_path(e)[0].data.dynamic: dynamic_skipped += 1 @@ -81,8 +82,8 @@ def _check_node( check_c_code: str, comment_c_code: str, assert_out: bool = False, - array_range: Optional[List[Tuple[int, int, int]]] = None, -): + array_range: list[tuple[int, int, int]] | None = None, +) -> None: """ Grab all maps outputs and filter by variable name (either black or whitelist) @@ -119,7 +120,7 @@ def _check_node( index_printf = ", ".join(["%d"] * len(input_array.shape)) # Get range from memlet (which may not be the entire array size) - def evaluate(expr): + def evaluate(expr): # type: ignore[no-untyped-def] return expr.subs({sp.Function("int_floor"): symbolic.int_floor}) # A bug in DaCe can lead to an edge labeled for storage on CPU @@ -173,7 +174,7 @@ def evaluate(expr): ) -def trace_all_outputs_at_index(sdfg: dace.SDFG, i: int, j: int, k: int): +def trace_all_outputs_at_index(sdfg: dace.SDFG, i: int, j: int, k: int) -> None: """Prints value for all variable when written for a specific index. @@ -230,7 +231,7 @@ def negative_delp_checker(sdfg: dace.SDFG) -> None: ndsl_log.info(f"Added {len(all_maps_filtered)} delp* < 0 checks") -def negative_qtracers_checker(sdfg: dace.SDFG): +def negative_qtracers_checker(sdfg: dace.SDFG) -> None: """ Adds a negative check on every tracers via their name when written to. Assert when check is True. @@ -268,20 +269,25 @@ def negative_qtracers_checker(sdfg: dace.SDFG): def sdfg_nan_checker( sdfg: dace.SDFG, - i_range: Optional[Tuple[int, int, int]] = None, - j_range: Optional[Tuple[int, int, int]] = None, - k_range: Optional[Tuple[int, int, int]] = None, -): + i_range: tuple[int, int, int] | None = None, + j_range: tuple[int, int, int] | None = None, + k_range: tuple[int, int, int] | None = None, +) -> None: """ Insert a check on array after each computational map to check for NaN in the domain. Assert when check is True. """ all_maps_filtered = _filter_all_maps(sdfg, blacklist=["diss_estd"]) - if i_range or j_range or k_range: - array_range = [i_range, j_range, k_range] - else: + if i_range is None and j_range is None and k_range is None: array_range = None + else: + if i_range is not None and j_range is not None and k_range is not None: + array_range = [i_range, j_range, k_range] + else: + raise RuntimeError( + "It looks like you have to specify either all or not of the ranges." + ) for state, node, e in all_maps_filtered: _check_node( @@ -299,7 +305,7 @@ def sdfg_nan_checker( ndsl_log.info(f"Added {len(all_maps_filtered)} NaN checks") -def sdfg_execution_progress(sdfg: dace.SDFG): +def sdfg_execution_progress(sdfg: dace.SDFG) -> None: all_maps_filtered = _filter_all_maps(sdfg) for state, node, e in all_maps_filtered: diff --git a/ndsl/dsl/dace/sdfg_opt_passes.py b/ndsl/dsl/dace/sdfg_opt_passes.py deleted file mode 100644 index b7582cc5..00000000 --- a/ndsl/dsl/dace/sdfg_opt_passes.py +++ /dev/null @@ -1,24 +0,0 @@ -import dace - -from ndsl.logging import ndsl_log - - -def splittable_region_expansion(sdfg: dace.SDFG, verbose: bool = False): - """ - Set certain StencilComputation library nodes to expand to a different - schedule if they contain small splittable regions. - """ - from gt4py.cartesian.gtc.dace.nodes import StencilComputation - - for node, _ in sdfg.all_nodes_recursive(): - if isinstance(node, StencilComputation): - if node.has_splittable_regions() and "corner" in node.label: - node.expansion_specification = [ - "Sections", - "Stages", - "J", - "I", - "K", - ] - if verbose: - ndsl_log.debug(f"Reordered schedule for {node.label}") diff --git a/ndsl/dsl/dace/stree/__init__.py b/ndsl/dsl/dace/stree/__init__.py new file mode 100644 index 00000000..6435e662 --- /dev/null +++ b/ndsl/dsl/dace/stree/__init__.py @@ -0,0 +1,4 @@ +from .pipeline import CPUPipeline, GPUPipeline + + +__all__ = ["CPUPipeline", "GPUPipeline"] diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py new file mode 100644 index 00000000..47c764b3 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,4 @@ +from .axis_merge import AxisIterator, CartesianAxisMerge + + +__all__ = ["AxisIterator", "CartesianAxisMerge"] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py new file mode 100644 index 00000000..262a6021 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +import copy +import re +from typing import Any + +import dace +import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.properties import CodeBlock + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( + AxisIterator, + no_data_dependencies_on_cartesian_axis, +) +from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( + detect_cycle, + list_index, + swap_node_position_in_tree, +) + + +def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a map over the given axis.""" + map_parameter = node.node.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def _both_same_single_axis_maps( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + return ( + (len(first.node.params) == 1 and len(second.node.params) == 1) # Single axis + and first.node.params[0] == second.node.params[0] # Same axis + and _is_axis_map(first, axis) # Correct axis + ) + + +def _can_merge_axis_maps( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + return _both_same_single_axis_maps( + first, second, axis + ) and no_data_dependencies_on_cartesian_axis( + first, + second, + axis, + ) + + +class InsertOvercomputationGuard(stree.ScheduleNodeTransformer): + def __init__( + self, + axis_as_string: str, + *, + merged_range: dace.subsets.Range, + original_range: dace.subsets.Range, + ): + self._axis_as_string = axis_as_string + self._merged_range = merged_range + self._original_range = original_range + + def _execution_condition(self) -> CodeBlock: + # NOTE range.ranges are inclusive, e.g. + # Range(0:4) -> ranges = (start=1, stop=3, step=1) + range = self._original_range + start = range.ranges[0][0] + stop = range.ranges[0][1] + step = range.ranges[0][2] + return CodeBlock( + f"{self._axis_as_string} >= {start} " + f"and {self._axis_as_string} <= {stop} " + f"and ({self._axis_as_string} - {start}) % {step} == 0" + ) + + def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: + all_children_are_maps = all( + [isinstance(child, stree.MapScope) for child in node.children] + ) + if not all_children_are_maps: + if self._merged_range != self._original_range: + node.children = [ + stree.IfScope( + condition=self._execution_condition(), children=node.children + ) + ] + return node + + node.children = self.visit(node.children) + return node + + +def _get_next_node( + nodes: list[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, +) -> stree.ScheduleTreeNode: + return nodes[list_index(nodes, node) + 1] + + +def _last_node( + nodes: list[stree.ScheduleTreeNode], node: stree.ScheduleTreeNode +) -> bool: + return list_index(nodes, node) >= len(nodes) - 1 + + +def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str: + axis_clean = f"{axis.as_str()}" + pattern = f"{axis.as_str()}_[0-9]*" + + return re.sub(pattern, axis_clean, name_to_normalize) + + +class NormalizeAxisSymbol(stree.ScheduleNodeVisitor): + def __init__(self, axis: AxisIterator) -> None: + self.axis = axis + + def visit_MapScope( + self, + map_scope: stree.MapScope, + axis_replacements: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + if axis_replacements is None: + axis_replacements = {} + for index, param in enumerate(map_scope.node.params): + sanitized_param = _sanitize_axis(self.axis, param) + axis_replacements[param] = sanitized_param + map_scope.node.params[index] = sanitized_param + + # visit children + for child in map_scope.children: + self.visit(child, axis_rpl_dict=axis_replacements) + + def visit_TaskletNode( + self, + node: stree.TaskletNode, + axis_replacements: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + if axis_replacements is None: + axis_replacements = {} + for memlets in node.in_memlets.values(): + memlets.replace(axis_replacements) + for memlets in node.out_memlets.values(): + memlets.replace(axis_replacements) + + +class CartesianAxisMerge(stree.ScheduleNodeTransformer): + """Merge a cartesian axis if they are contiguous in code-flow. + + Can do: + - merge a given axis with the next maps at the same recursion level + - can overcompute (eager) to allow for more merging at the cost of an if + + Args: + axis: AxisIterator to be merged + eager: overcompute with a conditional guard + """ + + def __init__( + self, + axis: AxisIterator, + *, + eager: bool = True, + ) -> None: + self.axis = axis + self.eager = eager + + def __str__(self) -> str: + return f"CartesianAxisMerge({self.axis.name})" + + def _merge_node( + self, + node: stree.ScheduleTreeNode, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + """Direct code to the correct resolver for the node (e.g. visitor) + + Dev Note: Order matters! + Default behavior for base class must be _after_ bespoke leaf class + behavior (e.g. IfScope before ControlFlowScope) + """ + + if isinstance(node, stree.MapScope): + return self._map_overcompute_merge(node, nodes) + elif isinstance(node, stree.IfScope): + return self._push_ifelse_down(node, nodes) + elif isinstance(node, stree.TaskletNode): + return self._push_tasklet_down(node, nodes) + elif isinstance(node, stree.ControlFlowScope): + return self._default_control_flow(node, nodes) + else: + ndsl_log.debug( + f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends." + ) + return 0 + + def _default_control_flow( + self, + the_control_flow: stree.ControlFlowScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + if len(the_control_flow.children) != 0: + return self._merge(the_control_flow) + + return 0 + + def _push_tasklet_down( + self, + the_tasklet: stree.TaskletNode, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + """Push tasklet into a consecutive map.""" + in_memlets = the_tasklet.input_memlets() + if len(in_memlets) != 0: + if "__pystate" in [tasklet.data for tasklet in the_tasklet.input_memlets()]: + return 0 # Tasklet is a callback + + next_index = list_index(nodes, the_tasklet) + if next_index == len(nodes): + return 0 # Last node - done + + next_node = nodes[next_index + 1] + + # Before checking the possibility of merging - attempt to surface + # a map from the next nodes + merged = self._merge_node(next_node, nodes) + + # Attempt to push the tasklet in the next map + ndsl_log.debug(" Push tasklet down into next map") + next_node = nodes[next_index + 1] + if isinstance(next_node, stree.MapScope): + next_node.children.insert(0, the_tasklet) + the_tasklet.parent = next_node + nodes.remove(the_tasklet) + merged += self._merge_node(next_node, nodes) + + return merged + + def _push_ifelse_down( + self, + the_if: stree.IfScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + merged = 0 + + # Recurse down if/else/elif + if_index = list_index(nodes, the_if) + if len(the_if.children) != 0: + merged += self._merge_node(the_if.children[0], the_if.children) + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) + ): + merged += self._merge_node(else_node, else_node.children) + else: + break + + # Look at swapping if/else/elif first map w/ control flow + + # Gather all first maps - if they do not exists, get out + all_maps = [] + if isinstance(the_if.children[0], stree.MapScope): + all_maps.append(the_if.children[0]) + else: + return merged + for else_index in range(if_index + 1, len(nodes)): + else_node = nodes[else_index] + if else_index < len(nodes) and ( + isinstance(else_node, stree.ElseScope) + or isinstance(else_node, stree.ElifScope) + ): + if isinstance(else_node.children[0], stree.MapScope): + all_maps.append(else_node.children[0]) + else: + return merged + + else: + break + + # Check for mergeability + if len(all_maps) > 1: + the_map = all_maps[0] + for _map in all_maps[1:]: + if not _can_merge_axis_maps(the_map, _map, self.axis): + return merged + + # We are good to go - swap it all + ndsl_log.debug(f" Push IF {the_if.condition.as_string} down") + inner_if_map = the_if.children[0] + + # Swap IF & maps + if_index = list_index(nodes, the_if) + swap_node_position_in_tree(the_if, inner_if_map) + + # Swap ELIF/ELSE & maps + for else_index in range(if_index + 1, len(nodes)): + if else_index < len(nodes) and ( + isinstance(nodes[else_index], stree.ElseScope) + or isinstance(nodes[else_index], stree.ElifScope) + ): + swap_node_position_in_tree( + nodes[else_index], nodes[else_index].children[0] + ) + else: + break + + # Merge the Maps + assert isinstance(nodes[if_index], stree.MapScope) + merged += self._map_overcompute_merge(nodes[if_index], nodes) + + return merged + + def _map_overcompute_merge( + self, + the_map: stree.MapScope, + nodes: list[stree.ScheduleTreeNode], + ) -> int: + if _last_node(nodes, the_map): + return 0 + + next_node = _get_next_node(nodes, the_map) + + # If the next node is not a MapScope - recurse + if not isinstance(next_node, stree.MapScope): + merged = self._merge_node(next_node, nodes) + new_next_node = _get_next_node(nodes, the_map) + if new_next_node == next_node: + return merged + return merged + self._merge_node(the_map, nodes) + + # Attempt to merge consecutive maps + if not _can_merge_axis_maps(the_map, next_node, self.axis): + return 0 + + # Over compute to merge: + # - force-merge by expanding the ranges + # - then, guard children to only run in their respective range + first_range = the_map.node.map.range + second_range = next_node.node.map.range + merged_range = dace.subsets.Range( + [ + ( + f"min({first_range.ranges[0][0]}, {second_range.ranges[0][0]})", + f"max({first_range.ranges[0][1]}, {second_range.ranges[0][1]})", + 1, # NOTE: we can optimize this to gcd later + ) + ] + ) + + ndsl_log.debug( + f" Merge {self.axis.name} map: {first_range} ⋃ {second_range} -> {merged_range}" + ) + + # push IfScope down if children are just maps + axis_as_str = the_map.node.params[0] + first_map = InsertOvercomputationGuard( + axis_as_str, merged_range=merged_range, original_range=first_range + ).visit(the_map) + second_map = InsertOvercomputationGuard( + axis_as_str, + merged_range=merged_range, + original_range=second_range, + ).visit(next_node) + merged_children: list[stree.MapScope] = [ + *first_map.children, + *second_map.children, + ] + first_map.children = merged_children + + # TODO also merge containers and symbols (if applicable) + first_map.node.map.range = merged_range + + # delete now-merged second_map + del nodes[list_index(nodes, next_node)] + + return 1 + + def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int: + merged = 0 + + if __debug__: + detect_cycle(node.children, set()) + + i_candidate = 0 + while i_candidate < len(node.children): + next_node = node.children[i_candidate] + merged += self._merge_node(next_node, node.children) + i_candidate += 1 + + if __debug__: + detect_cycle(node.children, set()) + + return merged + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + """Merge as many maps as possible. + + The algorithm works as follows: + - Start merging - move nodes to surface maps as much as possible + - Try to merge the surfaced maps + - When done, count the number of actual merges + - If NO merges - restore the previous children + (undo potential changes that didn't lead to map merge) + Then exit. + """ + + # TODO: many interval generate many iterator name right now + # e.g. _k_0, _k_1... + # This makes merging more difficult. We could write a pre-pass + # that cleans this up BUT we have an issue with the THIS_K feature + # in the tasklet... + # NormalizeAxisSymbol(self.axis).visit(node) + + overall_merged = 0 + i = 0 + while True: + i += 1 + ndsl_log.debug(f"🔥 Merge attempt #{i}") + previous_children = copy.deepcopy(node.children) + try: + merged = self._merge(node) + overall_merged += merged + if __debug__: + detect_cycle(node.children, set()) + except RecursionError as re: + raise re + + # If we didn't merge, we revert the children + # to the previous state + if merged == 0: + ndsl_log.debug("🥹 No merges, revert!") + node.children = previous_children + break + + ndsl_log.debug( + f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged" + ) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py new file mode 100644 index 00000000..0626133e --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -0,0 +1,149 @@ +from enum import Enum + +import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.memlet import Memlet + +from ndsl import ndsl_log + + +class AxisIterator(Enum): + _I = ("__i", 0) + _J = ("__j", 1) + _K = ("__k", 2) + + def as_str(self) -> str: + return self.value[0] + + def as_cartesian_index(self) -> int: + return self.value[1] + + +def no_data_dependencies_on_cartesian_axis( + first: stree.MapScope, + second: stree.MapScope, + axis: AxisIterator, +) -> bool: + """Check for read after write. Allow when indexation on the axis + is not offset.""" + + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # TODO: this can be optimized to allow non-overlapping intervals and such in the future + + if write.subset.dims() <= axis.as_cartesian_index(): + # Dimension does not exist + continue + + previous_axis_index = write.subset[axis.as_cartesian_index()][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + ndsl_log.debug( + f"[{axis.name} Merge] Found read after write conflict " + f"for {write.data} " + f"w/ different offset to {axis.name} (" + f"write at {write.subset[axis.as_cartesian_index()][0]}, " + f"read at {read.subset[axis.as_cartesian_index()][0]})" + ) + return False + return True + + +def no_data_dependencies( + first: stree.MapScope, + second: stree.MapScope, + restrict_check_to_k: bool = False, +) -> bool: + write_collector = MemletCollector(collect_reads=False) + write_collector.visit(first) + read_collector = MemletCollector(collect_writes=False) + read_collector.visit(second) + for write in write_collector.out_memlets: + # Make sure we don't have read after write conditions. + # TODO: this can be optimized to allow non-overlapping intervals and such in the future + if restrict_check_to_k: + if write.subset.dims() < 3: + # Case of 2D write - no K dependency + continue + + previous_k_index = write.subset[2][0] + for read in read_collector.in_memlets: + if write.data == read.data: + if previous_k_index != read.subset[2][0]: + print( + "[K Merge] Found read after write conflict " + f"for {write.data} " + "w/ different offset to K (" + f"write at {write.subset[2][0]}, " + f"read at {read.subset[2][0]})" + ) + return False + + else: + if write.data in [read.data for read in read_collector.in_memlets]: + print( + f"[All dims merge] Found potential read after write conflict for {write.data}" + ) + return False + return True + + +class MemletCollector(stree.ScheduleNodeVisitor): + """Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls.""" + + in_memlets: list[Memlet] + out_memlets: list[Memlet] + + def __init__( + self, *, collect_reads: bool = True, collect_writes: bool = True + ) -> None: + self._collect_reads = collect_reads + self._collect_writes = collect_writes + + self.in_memlets = [] + self.out_memlets = [] + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + if self._collect_reads: + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + if self._collect_writes: + self.out_memlets.extend([memlet for memlet in node.out_memlets.values()]) + + def visit_LibraryCall(self, node: stree.LibraryCall) -> None: + if self._collect_reads: + if isinstance(node.in_memlets, set): + self.in_memlets.extend(node.in_memlets) + else: + assert isinstance(node.in_memlets, dict) + self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) + + if self._collect_writes: + if isinstance(node.out_memlets, set): + self.out_memlets.extend(node.out_memlets) + else: + assert isinstance(node.out_memlets, dict) + self.out_memlets.extend( + [memlet for memlet in node.out_memlets.values()] + ) + + +def has_dynamic_memlets(first: stree.MapScope, second: stree.MapScope) -> bool: + first_collector = MemletCollector() + second_collector = MemletCollector() + first_collector.visit(first) + second_collector.visit(second) + has_dynamic_memlets = any( + [ + memlet.dynamic + for memlet in [ + *first_collector.in_memlets, + *first_collector.out_memlets, + *second_collector.in_memlets, + *second_collector.out_memlets, + ] + ] + ) + return has_dynamic_memlets diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py new file mode 100644 index 00000000..2583ec2d --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -0,0 +1,21 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as stree +import dace.subsets as sbs + + +class SpecializeCartesianMaps(stree.ScheduleNodeVisitor): + def __init__(self, mappings: dict[str, int]) -> None: + super().__init__() + self._mappings = mappings + + def visit_MapScope(self, node: stree.MapScope) -> None: + dims = [] + for p in node.node.map.params: + if p == "__i": + dims.append((0, self._mappings["__I"], 1)) + if p == "__j": + dims.append((0, self._mappings["__J"], 1)) + if p.startswith("__k"): + dims.append((0, self._mappings["__K"], 1)) + node.node.map.range = sbs.Range(dims) + + self.visit(node.children) diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py new file mode 100644 index 00000000..b965fc4a --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -0,0 +1,44 @@ +from typing import Collection + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + + +def swap_node_position_in_tree( + top_node: stree.ScheduleTreeScope, child_node: stree.ScheduleTreeScope +) -> None: + """Top node becomes child, child becomes top node""" + # Take refs before swap + top_children = top_node.parent.children + top_level_parent = top_node.parent + + # Swap childrens + top_node.children = child_node.children + child_node.children = [top_node] + top_children.insert(list_index(top_children, top_node), child_node) + + # Re-parent + top_node.parent = child_node + child_node.parent = top_level_parent + + # Remove now-pushed original node + top_children.remove(top_node) + + +def detect_cycle(nodes: list[stree.ScheduleTreeNode], visited: set) -> None: + """Detect the cycles in the tree.""" + # Dev note: isn't there a DaCe tool for this?! + for n in nodes: + if id(n) in visited: + breakpoint() + visited.add(id(n)) + if hasattr(n, "children"): + detect_cycle(n.children, visited) + + +def list_index( + collection: Collection[stree.ScheduleTreeNode], + node: stree.ScheduleTreeNode, +) -> int: + """Check if node is in list with "is" operator.""" + # compare with "is" to get memory comparison. ".index()" uses value comparison + return next(index for index, element in enumerate(collection) if element is node) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py new file mode 100644 index 00000000..10fb77cd --- /dev/null +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge + + +class StreePipeline(ABC): + @abstractmethod + def __hash__(self) -> int: + raise NotImplementedError("Missing implementation of __hash__") + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError("Missing implementation of __repr__") + + @abstractmethod + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + raise NotImplementedError("Missing implementation of run") + + +class CPUPipeline(StreePipeline): + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: + self.passes = ( + passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] + ) + + def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree + + +class GPUPipeline(StreePipeline): + def __init__( + self, passes: list[stree.ScheduleNodeTransformer] | None = None + ) -> None: + self.passes = passes if passes else [] + + def __repr__(self) -> str: + return str([type(p) for p in self.passes]) + + def __hash__(self) -> int: + return hash(repr(self)) + + def run( + self, + stree: stree.ScheduleTreeRoot, + verbose: bool = False, + ) -> stree.ScheduleTreeRoot: + for p in self.passes: + if verbose: + print(f"[Stree OPT] {p}") + p.visit(stree) + + return stree diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index 05fa5754..cb35503e 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -1,7 +1,7 @@ import json import time from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any import dace import numpy as np @@ -30,22 +30,20 @@ def __enter__(self) -> None: ndsl_log.debug(f"{self.prefix} {self.label}...") self.start = time.time() - def __exit__(self, _type, _val, _traceback) -> None: + def __exit__(self, _type, _val, _traceback) -> None: # type: ignore elapsed = time.time() - self.start ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed}s.") -def _is_ref(sd: dace.sdfg.SDFG, aname: str): - found = False +def _is_ref(sd: dace.sdfg.SDFG, aname: str) -> bool: for node, state in sd.all_nodes_recursive(): if not isinstance(state, dace.sdfg.SDFGState): continue if state.parent is sd: if isinstance(node, dace.nodes.AccessNode) and aname == node.data: - found = True - break + return True - return found + return False # ---------------------------------------------------------- @@ -68,19 +66,19 @@ class StorageReport: unreferenced_in_bytes: int = 0 in_pooled_in_bytes: int = 0 top_level_in_bytes: int = 0 - details: List[ArrayReport] = field(default_factory=list) + details: list[ArrayReport] = field(default_factory=list) def memory_static_analysis( sdfg: dace.sdfg.SDFG, -) -> Dict[dace.StorageType, StorageReport]: +) -> dict[dace.StorageType, StorageReport]: """Analysis an SDFG for memory pressure. The results split memory by type (dace.StorageType) and account for allocated, unreferenced and top level (e.g. top-most SDFG) memory """ # We report all allocation type - allocations: Dict[dace.StorageType, StorageReport] = {} + allocations: dict[dace.StorageType, StorageReport] = {} for storage_type in dace.StorageType: allocations[storage_type] = StorageReport(name=storage_type) @@ -132,7 +130,7 @@ def memory_static_analysis( def report_memory_static_analysis( sdfg: dace.sdfg.SDFG, - allocations: Dict[dace.StorageType, StorageReport], + allocations: dict[dace.StorageType, StorageReport], detail_report: bool = False, ) -> str: """Create a human readable report form the memory analysis results""" @@ -189,7 +187,7 @@ def copy_kernel(q_in: FloatField, q_out: FloatField) -> None: class MaxBandwidthBenchmarkProgram: - def __init__(self, size, backend) -> None: + def __init__(self, size: Any, backend: str) -> None: from ndsl.dsl.dace.orchestration import DaCeOrchestration, orchestrate dace_config = DaceConfig( @@ -205,16 +203,17 @@ def __init__(self, size, backend) -> None: ) orchestrate(obj=self, config=dace_config) - def __call__(self, A, B, n: int) -> None: - for i in dace.nounroll(range(n)): + def __call__(self, A: Any, B: Any, n: int) -> None: + for _i in dace.nounroll(range(n)): self.copy_stencil(A, B) def kernel_theoretical_timing( sdfg: dace.sdfg.SDFG, - hardware_bw_in_GB_s: Optional[float] = None, - backend: Optional[str] = None, -) -> Dict[str, float]: + *, + backend: str, + hardware_bw_in_GB_s: float | None = None, +) -> dict[str, float]: """Compute a lower timing bound for kernels with the following hypothesis: - Performance is memory bound, e.g. arithmetic intensity isn't counted @@ -271,7 +270,7 @@ def kernel_theoretical_timing( (me, state) for me, state in allmaps if get_parent_map(state, me) is None ] - result: Dict[str, float] = {} + result: dict[str, float] = {} for node, state in topmaps: nsdfg = state.parent mx = state.exit_node(node) @@ -319,9 +318,9 @@ def kernel_theoretical_timing( def report_kernel_theoretical_timing( - timings: Dict[str, float], + timings: dict[str, float], human_readable: bool = True, - out_format: Optional[str] = None, + out_format: str | None = None, ) -> str: """Produce a human readable or CSV of the kernel timings""" result_string = f"Maps processed: {len(timings)}.\n" @@ -343,16 +342,16 @@ def report_kernel_theoretical_timing( def kernel_theoretical_timing_from_path( sdfg_path: str, - hardware_bw_in_GB_s: Optional[float] = None, - backend: Optional[str] = None, - output_format: Optional[str] = None, + backend: str, + hardware_bw_in_GB_s: float | None = None, + output_format: str | None = None, ) -> str: """Load an SDFG and report the theoretical kernel timings""" print(f"Running kernel_theoretical_timing for {sdfg_path}") timings = kernel_theoretical_timing( dace.SDFG.from_file(sdfg_path), - hardware_bw_in_GB_s=hardware_bw_in_GB_s, backend=backend, + hardware_bw_in_GB_s=hardware_bw_in_GB_s, ) return report_kernel_theoretical_timing( timings, diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index 78a68fa4..bdccbe3b 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, List, Optional +from typing import Any from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor @@ -18,10 +18,10 @@ class WrappedHaloUpdater: def __init__( self, updater: HaloUpdater, - state, - qty_x_names: List[str], - qty_y_names: List[str] = None, - comm: Optional[Communicator] = None, + state: Any, # no typehint on purpose to avoid dependency on PyFV3 + qty_x_names: list[str], + qty_y_names: list[str] | None = None, + comm: Communicator | None = None, ) -> None: self._updater = updater self._state = state @@ -30,19 +30,19 @@ def __init__( self._comm = comm @staticmethod - def check_for_attribute(state: Any, attr: str): + def check_for_attribute(state: Any, attr: str) -> bool: if dataclasses.is_dataclass(state): - return state.__getattribute__(attr) - elif isinstance(state, dict): + return state.__getattribute__(attr) # type: ignore + if isinstance(state, dict): return attr in state.keys() return False @dace_inhibitor - def start(self): + def start(self) -> None: if self._qtx_y_names is None: if dataclasses.is_dataclass(self._state): self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names] + [self._state.__getattribute__(x) for x in self._qtx_x_names] # type: ignore ) elif isinstance(self._state, dict): self._updater.start([self._state[x] for x in self._qtx_x_names]) @@ -51,8 +51,8 @@ def start(self): else: if dataclasses.is_dataclass(self._state): self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names], - [self._state.__getattribute__(y) for y in self._qtx_y_names], + [self._state.__getattribute__(x) for x in self._qtx_x_names], # type: ignore + [self._state.__getattribute__(y) for y in self._qtx_y_names], # type: ignore ) elif isinstance(self._state, dict): self._updater.start( @@ -63,16 +63,18 @@ def start(self): raise NotImplementedError @dace_inhibitor - def wait(self): + def wait(self) -> None: self._updater.wait() @dace_inhibitor - def update(self): + def update(self) -> None: self.start() self.wait() @dace_inhibitor - def interface(self): + def interface(self) -> None: + assert self._comm is not None + assert self._qtx_y_names is not None assert len(self._qtx_x_names) == 1 assert len(self._qtx_y_names) == 1 self._comm.synchronize_vector_interfaces( diff --git a/ndsl/dsl/gt4py/__init__.py b/ndsl/dsl/gt4py/__init__.py index 7c051fb0..3fa0b3dc 100644 --- a/ndsl/dsl/gt4py/__init__.py +++ b/ndsl/dsl/gt4py/__init__.py @@ -26,12 +26,18 @@ computation, cos, cosh, + erf, + erfc, exp, externals, + float32, + float64, floor, function, gamma, horizontal, + int32, + int64, interval, isfinite, isinf, @@ -42,6 +48,8 @@ min, mod, region, + round, + round_away_from_zero, sin, sinh, sqrt, @@ -51,3 +59,65 @@ trunc, types, ) + + +__all__ = [ + "BACKWARD", + "FORWARD", + "IJ", + "IJK", + "IK", + "JK", + "PARALLEL", + "Field", + "GlobalTable", + "I", + "J", + "K", + "Sequence", + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", + "cbrt", + "ceil", + "compile_assert", + "computation", + "cos", + "cosh", + "erf", + "erfc", + "exp", + "externals", + "float32", + "float64", + "floor", + "function", + "gamma", + "horizontal", + "int32", + "int64", + "interval", + "isfinite", + "isinf", + "isnan", + "log", + "log10", + "max", + "min", + "mod", + "region", + "round", + "round_away_from_zero", + "sin", + "sinh", + "sqrt", + "stencil", + "tan", + "tanh", + "trunc", + "types", +] diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 5f3619e7..3b9c5fd9 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,12 +1,14 @@ +from collections.abc import Callable, Sequence from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any import numpy as np +import numpy.typing as npt from gt4py import storage as gt_storage from gt4py.cartesian import backend as gt_backend from ndsl.constants import N_HALO_DEFAULT -from ndsl.dsl.typing import DTypes, Field, Float +from ndsl.dsl.typing import DTypes, Float from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp @@ -48,10 +50,10 @@ def wrapper(*args, **kwargs) -> Any: def _mask_to_dimensions( - mask: Tuple[bool, ...], shape: Sequence[int] -) -> List[Union[str, int]]: + mask: tuple[bool, ...], shape: Sequence[int] +) -> list[str | int]: assert len(mask) >= 3 - dimensions: List[Union[str, int]] = [] + dimensions: list[str | int] = [] for i, axis in enumerate(("I", "J", "K")): if mask[i]: dimensions.append(axis) @@ -63,13 +65,13 @@ def _mask_to_dimensions( return dimensions -def _translate_origin(origin: Sequence[int], mask: Tuple[bool, ...]) -> Sequence[int]: +def _translate_origin(origin: Sequence[int], mask: tuple[bool, ...]) -> Sequence[int]: if len(origin) == int(sum(mask)): # Correct length. Assumed to be correctly specified. return origin assert len(mask) == 3 - final_origin: List[int] = [] + final_origin: list[int] = [] for i, has_axis in enumerate(mask): if has_axis: final_origin.append(origin[i]) @@ -79,19 +81,19 @@ def _translate_origin(origin: Sequence[int], mask: Tuple[bool, ...]) -> Sequence def make_storage_data( - data: Field, - shape: Optional[Tuple[int, ...]] = None, - origin: Tuple[int, ...] = origin, + data: npt.NDArray, + shape: tuple[int, ...] | None = None, + origin: tuple[int, ...] = origin, *, backend: str, dtype: DTypes = Float, - mask: Optional[Tuple[bool, ...]] = None, - start: Tuple[int, ...] = (0, 0, 0), - dummy: Optional[Tuple[int, ...]] = None, + mask: tuple[bool, ...] | None = None, + start: tuple[int, ...] = (0, 0, 0), + dummy: tuple[int, ...] | None = None, axis: int = 2, max_dim: int = 3, read_only: bool = True, -) -> Field: +) -> npt.NDArray: """Create a new gt4py storage from the given data. Args: @@ -129,7 +131,7 @@ def make_storage_data( if mask is None: if not read_only: - default_mask: Tuple[bool, ...] = (True, True, True) + default_mask: tuple[bool, ...] = (True, True, True) else: if n_dims == 1: if axis == 1: @@ -137,7 +139,7 @@ def make_storage_data( default_mask = (True, True, False) shape = (1, shape[axis]) else: - default_mask = tuple([i == axis for i in range(max_dim)]) # type: ignore + default_mask = tuple([i == axis for i in range(max_dim)]) elif dummy or axis != 2: default_mask = (True, True, True) else: @@ -196,16 +198,16 @@ def make_storage_data( def _make_storage_data_1d( - data: Field, - shape: Tuple[int, ...], - start: Tuple[int, ...] = (0, 0, 0), - dummy: Optional[Tuple[int, ...]] = None, + data: npt.NDArray, + shape: tuple[int, ...], + start: tuple[int, ...] = (0, 0, 0), + dummy: tuple[int, ...] | None = None, axis: int = 2, read_only: bool = True, *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) buffer = zeros(shape[axis], dtype=dtype, backend=backend) @@ -233,16 +235,16 @@ def _make_storage_data_1d( def _make_storage_data_2d( - data: Field, - shape: Tuple[int, ...], - start: Tuple[int, ...] = (0, 0, 0), - dummy: Optional[Tuple[int, ...]] = None, + data: npt.NDArray, + shape: tuple[int, ...], + start: tuple[int, ...] = (0, 0, 0), + dummy: tuple[int, ...] | None = None, axis: int = 2, read_only: bool = True, *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: # axis refers to which axis should be repeated (when making a full 3d data), # dummy refers to a singleton axis do_reshape = dummy or axis != 2 @@ -270,13 +272,13 @@ def _make_storage_data_2d( def _make_storage_data_3d( - data: Field, - shape: Tuple[int, ...], - start: Tuple[int, ...] = (0, 0, 0), + data: npt.NDArray, + shape: tuple[int, ...], + start: tuple[int, ...] = (0, 0, 0), *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: istart, jstart, kstart = start isize, jsize, ksize = data.shape buffer = zeros(shape, dtype=dtype, backend=backend) @@ -289,13 +291,13 @@ def _make_storage_data_3d( def _make_storage_data_Nd( - data: Field, - shape: Tuple[int, ...], - start: Tuple[int, ...] = None, + data: npt.NDArray, + shape: tuple[int, ...], + start: tuple[int, ...] | None = None, *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: if start is None: start = tuple([0] * data.ndim) buffer = zeros(shape, dtype=dtype, backend=backend) @@ -305,13 +307,13 @@ def _make_storage_data_Nd( def make_storage_from_shape( - shape: Tuple[int, ...], - origin: Tuple[int, ...] = origin, + shape: tuple[int, ...], + origin: tuple[int, ...] = origin, *, backend: str, dtype: DTypes = Float, - mask: Optional[Tuple[bool, ...]] = None, -) -> Field: + mask: tuple[bool, ...] | None = None, +) -> npt.NDArray: """Create a new gt4py storage of a given shape filled with zeros. Args: @@ -348,21 +350,21 @@ def make_storage_from_shape( def make_storage_dict( - data: Field, - shape: Optional[Tuple[int, ...]] = None, - origin: Tuple[int, ...] = origin, - start: Tuple[int, ...] = (0, 0, 0), - dummy: Optional[Tuple[int, ...]] = None, - names: Optional[List[str]] = None, + data: npt.NDArray, + shape: tuple[int, ...] | None = None, + origin: tuple[int, ...] = origin, + start: tuple[int, ...] = (0, 0, 0), + dummy: tuple[int, ...] | None = None, + names: list[str] | None = None, axis: int = 2, *, backend: str, dtype: DTypes = Float, -) -> Dict[str, "Field"]: +) -> dict[str, npt.NDArray]: assert names is not None, "for 4d variable storages, specify a list of names" if shape is None: shape = data.shape - data_dict: Dict[str, Field] = dict() + data_dict: dict[str, npt.NDArray] = dict() for i in range(data.shape[3]): data_dict[names[i]] = make_storage_data( squeeze(data[:, :, :, i]), @@ -473,12 +475,12 @@ def moveaxis(array, source: int, destination: int): return xp.moveaxis(array, source, destination) -def tile(array, reps: Union[int, Tuple[int, ...]]): +def tile(array, reps: int | tuple[int, ...]): xp = cp if cp and type(array) is cp.ndarray else np return xp.tile(array, reps) -def squeeze(array, axis: Union[int, Tuple[int]] = None): +def squeeze(array, axis: int | tuple[int] | None = None): xp = cp if cp and type(array) is cp.ndarray else np return xp.squeeze(array, axis) @@ -504,7 +506,7 @@ def unique( return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, - axis: Union[int, Tuple[int]] = None, + axis: int | tuple[int] | None = None, ): xp = cp if cp and type(array) is cp.ndarray else np return xp.unique(array, return_index, return_inverse, return_counts, axis) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py new file mode 100644 index 00000000..721ae38b --- /dev/null +++ b/ndsl/dsl/ndsl_runtime.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import inspect +import warnings +from collections.abc import Callable +from typing import Any + +from ndsl.dsl.dace import DaceConfig, orchestrate +from ndsl.dsl.typing import Float +from ndsl.initialization.allocator import QuantityFactory +from ndsl.quantity import Local, Quantity + + +_TOP_LEVEL: object | None = None + + +class NDSLRuntime: + """Base class to tool runtime code, allows use of Locals, orchestration and + debug tools. + + The __call__ function will automatically be orchestrated.""" + + def __init__(self, dace_config: DaceConfig) -> None: + self._dace_config = dace_config + # Use this flag to detect that the init wasn't done properly + self._base_class_was_properly_super_init = True + + def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None: + # WARNING: no code outside the `init_decorator` this is cls + # function, it will be called ONLY ONCE for monkey-patching the + # Class - not the instance ! + + def init_decorator(previous_init: Callable) -> Callable: + def new_init( + self: NDSLRuntime, + *args: list[Any], + **kwargs: dict[str, Any], + ) -> None: + global _TOP_LEVEL + if _TOP_LEVEL is None: + _TOP_LEVEL = self + previous_init(self, *args, **kwargs) + self.__post_init__() + + return new_init + + cls.__init__ = init_decorator(cls.__init__) # type: ignore[method-assign] + + def __post_init__(self) -> None: + if not hasattr(self, "_base_class_was_properly_super_init"): + raise RuntimeError( + f"Class {type(self).__name__} inherit from NDSLRuntime but didn't call super().__init__." + ) + + # Check quantity allocation of NDSLRuntime supervised code + if _TOP_LEVEL == self: + + def check_for_quantity(object_: object) -> None: + for key, value in object_.__dict__.items(): + if isinstance(value, Quantity) and not isinstance(value, Local): + warnings.warn( + f"{type(self).__name__}.{key} is a Quantity instead of a Locals" + " on a NDSLRuntime - our eyebrows are frowned.", + UserWarning, + stacklevel=2, + ) + elif isinstance(value, NDSLRuntime): + check_for_quantity(value) + + check_for_quantity(self) + + # Orchestrate __call__ by default + if callable(self): + orchestrate( + obj=self, + config=self._dace_config, + ) + print(type(self)) + + def __getattribute__(self, name: str) -> Any: + attr = super().__getattribute__(name) + # We look at the direct caller frame for our own `self` + # in the locals. + # All other cases are forbidden. + if isinstance(attr, Local): + frame = inspect.currentframe() + if frame is None: + raise NotImplementedError( + "Locals check cannot locate frame. Talk to the team." + ) + caller_frame = frame.f_back + if ( + not caller_frame + or "self" not in caller_frame.f_locals + or not isinstance(caller_frame.f_locals["self"], type(self)) + ): + # We expect the original class to have been monkey-patched + # See `dace.dsl.orchestration.orchestrate` + unpatched_name = type(self).__name__[: -len("_patched")] + raise RuntimeError( + f"Forbidden Local access: {name} called outside of {unpatched_name}." + ) + + return attr + + def make_local( + self, + quantity_factory: QuantityFactory, + dims: list[str], + dtype: type = Float, + units: str = "unspecified", + *, + allow_mismatch_float_precision: bool = False, + ) -> Local: + quantity = quantity_factory.zeros( + dims, + units, + dtype, + allow_mismatch_float_precision=allow_mismatch_float_precision, + ) + return Local( + data=quantity.data, + dims=quantity.dims, + units=quantity.units, + origin=quantity.origin, + extent=quantity.extent, + gt4py_backend=quantity.gt4py_backend, + allow_mismatch_float_precision=allow_mismatch_float_precision, + ) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index d829bbd0..f557c3b3 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -1,26 +1,18 @@ +from __future__ import annotations + import copy import dataclasses import inspect -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) +import numbers +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any, cast import dace import numpy as np from gt4py.cartesian import config as gt_config from gt4py.cartesian import definitions as gt_definitions from gt4py.cartesian import gtscript +from gt4py.cartesian.definitions import FieldInfo from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from gt4py.cartesian.stencil_object import StencilObject @@ -32,15 +24,27 @@ from ndsl.debug import ndsl_debugger from ndsl.dsl.dace.orchestration import SDFGConvertible from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig -from ndsl.dsl.typing import Float, Index3D, cast_to_index3d -from ndsl.initialization.sizer import GridSizer, SubtileGridSizer +from ndsl.dsl.typing import ( + BoolFieldIJ, + Float, + FloatFieldIJ, + FloatFieldIJ32, + FloatFieldIJ64, + Index3D, + Int, + IntFieldIJ, + IntFieldIJ32, + IntFieldIJ64, + cast_to_index3d, +) +from ndsl.initialization import GridSizer, SubtileGridSizer from ndsl.logging import ndsl_log from ndsl.quantity import Quantity from ndsl.quantity.field_bundle import FieldBundleType, MarkupFieldBundleType from ndsl.testing.comparison import LegacyMetric -def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id): +def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id): # type: ignore[no-untyped-def] report_head = f"comparing against numpy for func {function_name}, gt_id {gt_id}:" report_segments = [] for i, (arg, numpy_arg) in enumerate(zip(args, args_copy)): @@ -66,7 +70,7 @@ def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id print(report_head + report_body) -def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str: +def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label: str) -> str: metric = LegacyMetric( reference_values=arg, computed_values=numpy_arg, @@ -74,7 +78,7 @@ def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str: ignore_near_zero_errors=False, near_zero=0, ) - return metric.__repr__() + return f"{label}: {metric.__repr__()}" @dataclasses.dataclass @@ -85,17 +89,17 @@ class TimingCollector: exec_info: contains info about the execution of each stencil. """ - build_info: Dict[str, dict] = dataclasses.field(default_factory=dict) - exec_info: Dict[str, Any] = dataclasses.field( + build_info: dict[str, dict] = dataclasses.field(default_factory=dict) + exec_info: dict[str, Any] = dataclasses.field( default_factory=lambda: {"__aggregate_data": True} ) - def build_report(self, key: str = "build_time", **kwargs) -> str: + def build_report(self, key: str = "build_time", **kwargs: Any) -> str: return type(self)._show_report( self.build_info, self.build_info.keys(), key, **kwargs ) - def exec_report(self, key: str = "total_run_time", **kwargs) -> str: + def exec_report(self, key: str = "total_run_time", **kwargs: Any) -> str: # NOTE: Uses the build_info keys to distinguish stencils return type(self)._show_report( self.exec_info, self.build_info.keys(), key, **kwargs @@ -103,7 +107,7 @@ def exec_report(self, key: str = "total_run_time", **kwargs) -> str: @staticmethod def _show_report( - infos: Dict[str, Any], + infos: dict[str, Any], keys: Iterable[str], secondary_key: str, *, @@ -124,7 +128,7 @@ def _show_report( format = f".{digits}e" - outputs: List[str] = [f"Total: {sum(d[1] for d in data):{format}}"] + outputs: list[str] = [f"Total: {sum(d[1] for d in data):{format}}"] for name, val in sorted_data: if len(name) > name_width: width = int(name_width / 2) - 3 @@ -149,13 +153,13 @@ class CompareToNumpyStencil: def __init__( self, func: Callable[..., None], - origin: Union[Tuple[int, ...], Mapping[str, Tuple[int, ...]]], - domain: Tuple[int, ...], + origin: tuple[int, ...] | Mapping[str, tuple[int, ...]], + domain: tuple[int, ...], stencil_config: StencilConfig, - externals: Optional[Mapping[str, Any]] = None, - skip_passes: Optional[Tuple[str, ...]] = None, - timing_collector: Optional[TimingCollector] = None, - comm: Optional[Comm] = None, + externals: Mapping[str, Any] | None = None, + skip_passes: tuple[str, ...] = (), + timing_collector: TimingCollector | None = None, + comm: Comm | None = None, ): self._actual = FrozenStencil( func=func, @@ -172,7 +176,6 @@ def __init__( rebuild=stencil_config.compilation_config.rebuild, validate_args=stencil_config.compilation_config.validate_args, format_source=True, - device_sync=None, run_mode=RunMode.BuildAndRun, use_minimal_caching=False, ) @@ -194,8 +197,8 @@ def __init__( def __call__( self, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: args_copy = copy.deepcopy(args) kwargs_copy = copy.deepcopy(kwargs) @@ -211,12 +214,12 @@ def __call__( ) -def _stencil_object_name(stencil_object) -> str: +def _stencil_object_name(stencil_object: StencilObject) -> str: """Returns a unique name for each gt4py stencil object, including the hash.""" return type(stencil_object).__name__ -def get_pair_rank(rank: int, size: int): +def get_pair_rank(rank: int, size: int) -> int: dycore_ranks = size // 2 if rank < dycore_ranks: return rank + dycore_ranks @@ -224,7 +227,7 @@ def get_pair_rank(rank: int, size: int): return rank - dycore_ranks -def compare_ranks(comm: Comm, data) -> Mapping[str, int]: +def compare_ranks(comm: Comm, data: dict) -> Mapping[str, int]: rank = comm.Get_rank() size = comm.Get_size() pair_rank = get_pair_rank(rank, size) @@ -255,13 +258,13 @@ class FrozenStencil(SDFGConvertible): def __init__( self, func: Callable[..., None], - origin: Union[Tuple[int, ...], Mapping[str, Tuple[int, ...]]], - domain: Tuple[int, ...], + origin: tuple[int, ...] | Mapping[str, tuple[int, ...]], + domain: tuple[int, ...], stencil_config: StencilConfig, - externals: Optional[Mapping[str, Any]] = None, - skip_passes: Tuple[str, ...] = (), - timing_collector: Optional[TimingCollector] = None, - comm: Optional[Comm] = None, + externals: Mapping[str, Any] | None = None, + skip_passes: tuple[str, ...] = (), + timing_collector: TimingCollector | None = None, + comm: Comm | None = None, ): """ Args: @@ -277,7 +280,6 @@ def __init__( """ if isinstance(origin, tuple): origin = cast_to_index3d(origin) - origin = cast(Union[Index3D, Mapping[str, Tuple[int, ...]]], origin) self.origin = origin self.domain: Index3D = cast_to_index3d(domain) self.stencil_config: StencilConfig = stencil_config @@ -296,10 +298,13 @@ def __init__( stencil_kwargs = self.stencil_config.stencil_kwargs( skip_passes=skip_passes, func=func ) - self.stencil_object: StencilObject | None = None + self.stencil_object: StencilObject self._argument_names = tuple(inspect.getfullargspec(func).args) + # NOTE: this is also down in `dace/build.py` for orchestration + # This is still needed for non-orchestrated used of DaCe. + # A better build system would take care of BOTH of those at the same time if "dace" in self.stencil_config.compilation_config.backend: dace.Config.set( "default_build_folder", @@ -313,6 +318,21 @@ def __init__( len(self._argument_names) > 0 ), "A stencil with no arguments? You may be double decorating" + # Overloading `dtypes` to allow parsing of NDSL concepts + ndsl_dtypes = { + # Mixed precision + float: Float, + int: Int, + # 2D temporaries + "FloatFieldIJ": FloatFieldIJ, + "FloatFieldIJ32": FloatFieldIJ32, + "FloatFieldIJ64": FloatFieldIJ64, + "IntFieldIJ": IntFieldIJ, + "IntFieldIJ32": IntFieldIJ32, + "IntFieldIJ64": IntFieldIJ64, + "BoolFieldIJ": BoolFieldIJ, + } + # Keep compilation at __init__ if we are not orchestrated. # If we orchestrate, move the compilation at call time to make sure # disable_codegen do not lead to call to uncompiled stencils, which fails @@ -321,7 +341,7 @@ def __init__( self.stencil_object = gtscript.lazy_stencil( definition=func, externals=externals, - dtypes={float: Float}, + dtypes=ndsl_dtypes, **stencil_kwargs, build_info=(build_info := {}), # type: ignore ) @@ -345,10 +365,10 @@ def __init__( self.stencil_object = gtscript.stencil( definition=func, externals=externals, - dtypes={float: Float}, + dtypes=ndsl_dtypes, **stencil_kwargs, build_info=(build_info := {}), - ) # type: ignore + ) if ( compilation_config.use_minimal_caching @@ -357,35 +377,37 @@ def __init__( ): unblock_waiting_tiles(MPI.COMM_WORLD) - self._timing_collector.build_info[ - _stencil_object_name(self.stencil_object) - ] = build_info + self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = ( + build_info + ) field_info = self.stencil_object.field_info - self._field_origins: Dict[ - str, Tuple[int, ...] - ] = FrozenStencil._compute_field_origins(field_info, self.origin) - """mapping from field names to field origins""" + self._field_origins: dict[str, tuple[int, ...]] = ( + FrozenStencil._compute_field_origins(field_info, self.origin) + ) + """Mapping from field names to field origins""" - self._stencil_run_kwargs: Dict[str, Any] = { + self._stencil_run_kwargs: dict[str, Any] = { "_origin_": self._field_origins, "_domain_": self.domain, } - self._written_fields: List[str] = FrozenStencil._get_written_fields(field_info) + self._written_fields = FrozenStencil._get_written_fields(field_info) if stencil_config.compilation_config.run_mode == RunMode.Build: - def nothing_function(*args, **kwargs): + def nothing_function(*args, **kwargs): # type: ignore[no-untyped-def] pass - setattr(self, "__call__", nothing_function) + setattr(self, "__call__", nothing_function) # noqa: B010 - def __call__(self, *args, **kwargs) -> None: + def __call__(self, *args: Any, **kwargs: Any) -> None: # Verbose stencil execution if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") + self._validate_quantity_sizes(*args, **kwargs) + # Marshal arguments args_list = list(args) _convert_quantities_to_storage(args_list, kwargs) @@ -420,7 +442,7 @@ def __call__(self, *args, **kwargs) -> None: domain=self.domain, validate_args=True, exec_info=self._timing_collector.exec_info, - ) # type: ignore + ) else: self.stencil_object.run( **args_as_kwargs, @@ -447,8 +469,10 @@ def __call__(self, *args, **kwargs) -> None: @classmethod def _compute_field_origins( - cls, field_info_mapping, origin: Union[Index3D, Mapping[str, Tuple[int, ...]]] - ) -> Dict[str, Tuple[int, ...]]: + cls, + field_info_mapping: dict[str, gt_definitions.FieldInfo], + origin: Index3D | Mapping[str, tuple[int, ...]], + ) -> dict[str, tuple[int, ...]]: """ Computes the origin for each field in the stencil call. @@ -461,8 +485,8 @@ def _compute_field_origins( origin_mapping: a mapping from field names to origins """ if isinstance(origin, tuple): - field_origins: Dict[str, Tuple[int, ...]] = {"_all_": origin} - origin_tuple: Tuple[int, ...] = origin + field_origins: dict[str, tuple[int, ...]] = {"_all_": origin} + origin_tuple: tuple[int, ...] = origin else: field_origins = {**origin} origin_tuple = origin["_all_"] @@ -475,6 +499,9 @@ def _compute_field_origins( for ax in field_info.axes: origin_index = {"I": 0, "J": 1, "K": 2}[ax] field_origin_list.append(origin_tuple[origin_index]) + for i, _data_dim in enumerate(field_info.data_dims): + if field_info.mask[len(field_info.domain_mask) + i]: + field_origin_list.append(0) field_origin = tuple(field_origin_list) else: field_origin = origin_tuple @@ -482,7 +509,7 @@ def _compute_field_origins( return field_origins @classmethod - def _get_written_fields(cls, field_info) -> List[str]: + def _get_written_fields(cls, field_info: dict[str, FieldInfo]) -> list[str]: """Returns the list of fields that are written. Args: @@ -492,10 +519,7 @@ def _get_written_fields(cls, field_info) -> List[str]: field_name for field_name in field_info if field_info[field_name] - and bool( - field_info[field_name].access - & gt_definitions.AccessKind.WRITE # type: ignore - ) + and bool(field_info[field_name].access & gt_definitions.AccessKind.WRITE) ] return write_fields @@ -505,7 +529,7 @@ def _get_oir_pipeline(cls, skip_passes: Sequence[str]) -> OirPipeline: skip_steps = [step_map[pass_name] for pass_name in skip_passes] return DefaultPipeline(skip=skip_steps) - def __sdfg__(self, *args, **kwargs): + def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] """Implemented SDFG generation""" args_as_kwargs = dict(zip(self._argument_names, args)) return self.stencil_object.__sdfg__( @@ -515,22 +539,55 @@ def __sdfg__(self, *args, **kwargs): **kwargs, ) - def __sdfg_signature__(self): + def __sdfg_signature__(self): # type: ignore[no-untyped-def] """Implemented SDFG signature lookup""" return self.stencil_object.__sdfg_signature__() - def __sdfg_closure__(self, *args, **kwargs): + def __sdfg_closure__(self, *args, **kwargs): # type: ignore[no-untyped-def] """Implemented SDFG closure build""" return self.stencil_object.__sdfg_closure__(*args, **kwargs) - def closure_resolver(self, constant_args, given_args, parent_closure=None): + def closure_resolver(self, constant_args, given_args, parent_closure=None): # type: ignore[no-untyped-def] """Implemented SDFG closure resolver build""" return self.stencil_object.closure_resolver( constant_args, given_args, parent_closure=parent_closure ) + def _validate_quantity_sizes(self, *args, **kwargs): # type: ignore[no-untyped-def] + """Checks that the sizes of quantities are compatible with the domain of the stencil. + + This function emits a warning in case one of the dimensions does not match. + + """ + all_args_as_kwargs = dict(zip(self._argument_names, tuple(list(args)))) | kwargs + + domain_sizes = { + axis_name: axis_size + for axis_names, axis_size in zip([X_DIMS, Y_DIMS, Z_DIMS], self.domain) + for axis_name in axis_names + } -def _convert_quantities_to_storage(args, kwargs): + for name, argument in all_args_as_kwargs.items(): + if isinstance(argument, Quantity): + for axis, quantity_size in zip(argument.dims, argument.extent): + full_size = quantity_size + if axis in (X_DIMS + Y_DIMS): + full_size += 2 * argument.metadata.n_halo + if ( + axis in (X_DIMS + Y_DIMS + Z_DIMS) + and full_size < domain_sizes[axis] + ): + ndsl_log.warning( + f"Quantity `{name}` is too small for the targeted " + f"domain in axis {axis}: {full_size} < {domain_sizes[axis]}." + ) + elif not isinstance(argument, numbers.Real): + ndsl_log.warning( + f"Found an array-type argument {name} that is not a Quantity. Some domain-size checks are omitted." + ) + + +def _convert_quantities_to_storage(args, kwargs): # type: ignore[no-untyped-def] for i, arg in enumerate(args): try: # Check that 'dims' is an attribute of arg. If so, @@ -588,28 +645,28 @@ def __init__( self.east_edge = east_edge @property - def domain(self): + def domain(self) -> Index3D: return self._domain @domain.setter - def domain(self, domain): + def domain(self, domain: Index3D) -> None: self._domain = domain self._sizer = SubtileGridSizer( nx=domain[0], ny=domain[1], nz=domain[2], n_halo=self.n_halo, - extra_dim_lengths={}, + data_dimensions={}, ) @classmethod def from_sizer_and_communicator( cls, sizer: GridSizer, comm: Communicator - ) -> "GridIndexing": + ) -> GridIndexing: # TODO: if this class is refactored to split off the *_edge booleans, # this init routine can be refactored to require only a GridSizer domain = cast( - Tuple[int, int, int], + tuple[int, int, int], sizer.get_extent([X_DIM, Y_DIM, Z_DIM]), ) south_edge = comm.tile.partitioner.on_tile_bottom(comm.rank) @@ -626,7 +683,7 @@ def from_sizer_and_communicator( ) @property - def max_shape(self): + def max_shape(self) -> Index3D: """ Maximum required storage shape, corresponding to the shape of a cell-corner variable with maximum halo points. @@ -639,74 +696,74 @@ def max_shape(self): return self.domain_full(add=(1, 1, 1 + self.origin[2])) @property - def isc(self): - """start of the compute domain along the x-axis""" + def isc(self) -> int: + """Start of the compute domain along the x-axis""" return self.origin[0] @property - def iec(self): - """last index of the compute domain along the x-axis""" + def iec(self) -> int: + """Last index of the compute domain along the x-axis""" return self.origin[0] + self.domain[0] - 1 @property - def jsc(self): - """start of the compute domain along the y-axis""" + def jsc(self) -> int: + """Start of the compute domain along the y-axis""" return self.origin[1] @property - def jec(self): - """last index of the compute domain along the y-axis""" + def jec(self) -> int: + """Last index of the compute domain along the y-axis""" return self.origin[1] + self.domain[1] - 1 @property - def isd(self): - """start of the full domain including halos along the x-axis""" + def isd(self) -> int: + """Start of the full domain including halos along the x-axis""" return self.origin[0] - self.n_halo @property - def ied(self): - """index of the last data point along the x-axis""" + def ied(self) -> int: + """Index of the last data point along the x-axis""" return self.isd + self.domain[0] + 2 * self.n_halo - 1 @property - def jsd(self): - """start of the full domain including halos along the y-axis""" + def jsd(self) -> int: + """Start of the full domain including halos along the y-axis""" return self.origin[1] - self.n_halo @property - def jed(self): - """index of the last data point along the y-axis""" + def jed(self) -> int: + """Index of the last data point along the y-axis""" return self.jsd + self.domain[1] + 2 * self.n_halo - 1 @property - def nw_corner(self): + def nw_corner(self) -> bool: return self.north_edge and self.west_edge @property - def sw_corner(self): + def sw_corner(self) -> bool: return self.south_edge and self.west_edge @property - def ne_corner(self): + def ne_corner(self) -> bool: return self.north_edge and self.east_edge @property - def se_corner(self): + def se_corner(self) -> bool: return self.south_edge and self.east_edge - def origin_full(self, add: Index3D = (0, 0, 0)): + def origin_full(self, add: Index3D = (0, 0, 0)) -> Index3D: """ Returns the origin of the full domain including halos, plus an optional offset. """ return (self.isd + add[0], self.jsd + add[1], self.origin[2] + add[2]) - def origin_compute(self, add: Index3D = (0, 0, 0)): + def origin_compute(self, add: Index3D = (0, 0, 0)) -> Index3D: """ Returns the origin of the compute domain, plus an optional offset """ return (self.isc + add[0], self.jsc + add[1], self.origin[2] + add[2]) - def domain_full(self, add: Index3D = (0, 0, 0)): + def domain_full(self, add: Index3D = (0, 0, 0)) -> Index3D: """ Returns the shape of the full domain including halos, plus an optional offset. """ @@ -716,7 +773,7 @@ def domain_full(self, add: Index3D = (0, 0, 0)): self.domain[2] + add[2], ) - def domain_compute(self, add: Index3D = (0, 0, 0)): + def domain_compute(self, add: Index3D = (0, 0, 0)) -> Index3D: """ Returns the shape of the compute domain, plus an optional offset. """ @@ -727,10 +784,8 @@ def domain_compute(self, add: Index3D = (0, 0, 0)): ) def axis_offsets( - self, - origin: Tuple[int, ...], - domain: Tuple[int, ...], - ) -> Dict[str, Any]: + self, origin: tuple[int, ...], domain: tuple[int, ...] + ) -> dict[str, Any]: if self.west_edge: i_start = gtscript.I[0] + self.origin[0] - origin[0] else: @@ -775,7 +830,7 @@ def axis_offsets( def get_origin_domain( self, dims: Sequence[str], halos: Sequence[int] = tuple() - ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Get the origin and domain for a computation that occurs over a certain grid configuration (given by dims) and a certain number of halo points. @@ -795,7 +850,26 @@ def get_origin_domain( domain[i] += 2 * n return tuple(origin), tuple(domain) - def _origin_from_dims(self, dims: Iterable[str]) -> List[int]: + def get_2d_compute_origin_domain( + self, + klevel: int = 0, + ) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Get the origin and domain for a computation that occurs on the lowest klevel over a certain grid + configuration (given by dims) and a certain number of halo points. + + Args: + klevel: the vertical level of the domain, defaults to zero + + Returns: + origin: origin of the computation + domain: shape of the computation + """ + origin = (self.isc, self.jsc, klevel) + domain = (self.iec + 1 - self.isc, self.jec + 1 - self.jsc, 1) + return (origin, domain) + + def _origin_from_dims(self, dims: Iterable[str]) -> list[int]: return_origin = [] for dim in dims: if dim in X_DIMS: @@ -808,7 +882,7 @@ def _origin_from_dims(self, dims: Iterable[str]) -> List[int]: def get_shape( self, dims: Sequence[str], halos: Sequence[int] = tuple() - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: """ Get the storage shape required for an array with the given dimensions which is accessed up to a given number of halo points. @@ -831,7 +905,9 @@ def get_shape( shape[i] += n return tuple(shape) - def restrict_vertical(self, k_start=0, nk=None) -> "GridIndexing": + def restrict_vertical( + self, k_start: int = 0, nk: int | None = None + ) -> GridIndexing: """ Returns a copy of itself with modified vertical origin and domain. @@ -877,7 +953,7 @@ def __init__( self, config: StencilConfig, grid_indexing: GridIndexing, - comm: Optional[Comm] = None, + comm: Comm | None = None, ): """ Args: @@ -893,17 +969,17 @@ def __init__( self.comm = comm @property - def backend(self): + def backend(self) -> str: return self.config.compilation_config.backend def from_origin_domain( self, func: Callable[..., None], - origin: Union[Tuple[int, ...], Mapping[str, Tuple[int, ...]]], - domain: Tuple[int, ...], - externals: Optional[Mapping[str, Any]] = None, - skip_passes: Tuple[str, ...] = (), - ) -> Union[FrozenStencil, CompareToNumpyStencil]: + origin: tuple[int, ...] | Mapping[str, tuple[int, ...]], + domain: tuple[int, ...], + externals: Mapping[str, Any] | None = None, + skip_passes: tuple[str, ...] = (), + ) -> FrozenStencil | CompareToNumpyStencil: """ Args: func: stencil definition function @@ -914,7 +990,7 @@ def from_origin_domain( skip_passes: compiler passes to skip when building stencil """ if self.config.compare_to_numpy: - cls: Type = CompareToNumpyStencil + cls: type = CompareToNumpyStencil else: cls = FrozenStencil return cls( @@ -933,9 +1009,9 @@ def from_dims_halo( func: Callable[..., None], compute_dims: Sequence[str], compute_halos: Sequence[int] = tuple(), - externals: Optional[Mapping[str, Any]] = None, - skip_passes: Tuple[str, ...] = (), - ) -> Union[FrozenStencil, CompareToNumpyStencil]: + externals: Mapping[str, Any] | None = None, + skip_passes: tuple[str, ...] = (), + ) -> FrozenStencil | CompareToNumpyStencil: """ Initialize a stencil from dimensions and number of halo points. @@ -969,29 +1045,31 @@ def from_dims_halo( skip_passes=skip_passes, ) - def restrict_vertical(self, k_start=0, nk=None) -> "StencilFactory": + def restrict_vertical( + self, k_start: int = 0, nk: int | None = None + ) -> StencilFactory: return StencilFactory( config=self.config, grid_indexing=self.grid_indexing.restrict_vertical(k_start=k_start, nk=nk), comm=self.comm, ) - def build_report(self, key: str = "build_time", **kwargs) -> str: + def build_report(self, key: str = "build_time", **kwargs: Any) -> str: """Report all stencils built by this factory.""" return self.timing_collector.build_report(key, **kwargs) - def exec_report(self, key: str = "total_run_time", **kwargs) -> str: + def exec_report(self, key: str = "total_run_time", **kwargs: Any) -> str: """Report all stencils executed that were built by this factory.""" return self.timing_collector.exec_report(key, **kwargs) def get_stencils_with_varied_bounds( func: Callable[..., None], - origins: List[Index3D], - domains: List[Index3D], + origins: list[Index3D], + domains: list[Index3D], stencil_factory: StencilFactory, - externals: Optional[Mapping[str, Any]] = None, -) -> List[Union[FrozenStencil, CompareToNumpyStencil]]: + externals: Mapping[str, Any] | None = None, +) -> list[FrozenStencil | CompareToNumpyStencil]: assert len(origins) == len(domains), ( "Lists of origins and domains need to have the same length, you provided " + str(len(origins)) diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 4d3eafab..b4b57183 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import dataclasses import enum import hashlib -from typing import Any, Callable, Dict, Hashable, Iterable, Optional, Sequence, Tuple +from collections.abc import Callable, Hashable, Iterable, Sequence +from typing import Any, Self from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline @@ -35,9 +38,9 @@ def __init__( device_sync: bool = False, run_mode: RunMode = RunMode.BuildAndRun, use_minimal_caching: bool = False, - communicator: Optional[Communicator] = None, + communicator: Communicator | None = None, ) -> None: - if (not ("gpu" in backend or "cuda" in backend)) and device_sync is True: + if "gpu" not in backend and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") # GT4Py backend args self.backend = backend @@ -117,8 +120,8 @@ def determine_compiling_equivalent( raise RuntimeError("Illegal partition specified") def get_decomposition_info_from_comm( - self, communicator: Optional[Communicator] - ) -> Tuple[int, int, int, bool]: + self, communicator: Communicator | None + ) -> tuple[int, int, int, bool]: if communicator: self.check_communicator(communicator) rank = communicator.rank @@ -138,7 +141,7 @@ def get_decomposition_info_from_comm( is_compiling = True return rank, size, equivalent_compiling_rank, is_compiling - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: return { "backend": self.backend, "rebuild": self.rebuild, @@ -150,7 +153,7 @@ def as_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: dict) -> Self: instance = cls( backend=data.get("backend", "numpy"), rebuild=data.get("rebuild", False), @@ -168,31 +171,43 @@ def from_dict(cls, data: dict): class StencilConfig(Hashable): compare_to_numpy: bool = False compilation_config: CompilationConfig = CompilationConfig() - dace_config: Optional[DaceConfig] = None verbose: bool = False + dace_config: DaceConfig = dataclasses.field(init=False) - def __post_init__(self): - self.backend_opts = { - "device_sync": self.compilation_config.device_sync, - "format_source": self.compilation_config.format_source, - } - self._hash = self._compute_hash() + def __init__( + self, + *, + compare_to_numpy: bool = False, + compilation_config: CompilationConfig | None = None, + verbose: bool = False, + dace_config: DaceConfig | None = None, + ): + if compilation_config is None: + compilation_config = CompilationConfig() - # We need a DaceConfig to know if orchestration is part of the build system - # but we can't hash it very well (for now). The workaround is to make - # sure we have a default Python orchestrated config. - if self.dace_config is None: - self.dace_config = DaceConfig( + self.compare_to_numpy = compare_to_numpy + self.compilation_config = compilation_config + self.verbose = verbose + self.dace_config = ( + dace_config + if dace_config is not None + else DaceConfig( communicator=None, backend=self.compilation_config.backend, orchestration=DaCeOrchestration.Python, ) + ) + self.backend_opts = { + "device_sync": self.compilation_config.device_sync, + "format_source": self.compilation_config.format_source, + } + self._hash = self._compute_hash() @property - def backend(self): + def backend(self) -> str: return self.compilation_config.backend - def _compute_hash(self): + def _compute_hash(self) -> int: md5 = hashlib.md5() md5.update(self.compilation_config.backend.encode()) for attr in ( @@ -203,16 +218,16 @@ def _compute_hash(self): self.backend_opts["format_source"], ): md5.update(bytes(attr)) - attr = self.backend_opts.get("device_sync", None) + attr = self.backend_opts.get("device_sync", False) if attr: md5.update(bytes(attr)) md5.update(bytes(self.compilation_config.run_mode.value)) return int(md5.hexdigest(), base=16) - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: object) -> bool: try: return self.__hash__() == other.__hash__() except AttributeError: @@ -220,7 +235,7 @@ def __eq__(self, other): def stencil_kwargs( self, *, func: Callable[..., None], skip_passes: Iterable[str] = () - ): + ) -> dict: kwargs = { "backend": self.compilation_config.backend, "rebuild": self.compilation_config.rebuild, @@ -231,7 +246,7 @@ def stencil_kwargs( kwargs.pop("device_sync", None) if skip_passes or kwargs.get("skip_passes", ()): kwargs["oir_pipeline"] = StencilConfig._get_oir_pipeline( - list(kwargs.pop("skip_passes", ())) + list(skip_passes) # type: ignore + list(kwargs.pop("skip_passes", ())) + list(skip_passes) # type: ignore[call-overload] ) return kwargs diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 5f60f401..7229cc9b 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -1,9 +1,10 @@ -import os -from typing import Tuple, TypeAlias, Union, cast +from typing import TypeAlias import numpy as np from gt4py.cartesian import gtscript +from ndsl.dsl import NDSL_GLOBAL_PRECISION + # A Field Field = gtscript.Field @@ -19,11 +20,11 @@ K = gtscript.K # noqa: E741 # Union of valid data types (from gt4py.cartesian.gtscript) -DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64] +DTypes = bool | np.bool_ | int | np.int32 | np.int64 | float | np.float32 | np.float64 def get_precision() -> int: - return int(os.getenv("PACE_FLOAT_PRECISION", "64")) + return NDSL_GLOBAL_PRECISION # We redefine the type as a way to distinguish @@ -35,10 +36,10 @@ def get_precision() -> int: NDSL_64BIT_INT_TYPE: TypeAlias = np.int64 -def global_set_precision() -> Tuple[TypeAlias, TypeAlias]: +def global_set_precision() -> tuple[TypeAlias, TypeAlias]: """Set the global precision for all references of Float and Int in the codebase. Defaults to 64 bit.""" - global Float, Int + global Float, Int # noqa: F824 global ... is unused precision_in_bit = get_precision() if precision_in_bit == 64: return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE @@ -92,10 +93,10 @@ def global_set_precision() -> Tuple[TypeAlias, TypeAlias]: BoolFieldK = Field[gtscript.K, Bool] BoolFieldIJ = Field[gtscript.IJ, Bool] -Index3D = Tuple[int, int, int] +Index3D = tuple[int, int, int] -def set_4d_field_size(n, dtype): +def set_4d_field_size(n: int, dtype: type): # type: ignore[no-untyped-def] """ Defines a 4D field with a given size and type The extra data dimension is not parallel @@ -103,13 +104,13 @@ def set_4d_field_size(n, dtype): return Field[gtscript.IJK, (dtype, (n,))] -def cast_to_index3d(val: Tuple[int, ...]) -> Index3D: +def cast_to_index3d(val: tuple[int, ...]) -> Index3D: if len(val) != 3: - raise ValueError(f"expected 3d index, received {val}") - return cast(Index3D, val) + raise ValueError(f"Expected 3d index, received {val}") + return val -def is_float(dtype: type): +def is_float(dtype: type) -> bool: """Expected floating point type""" return dtype in [ Float, diff --git a/ndsl/exceptions.py b/ndsl/exceptions.py index fa5d118a..4511ea69 100644 --- a/ndsl/exceptions.py +++ b/ndsl/exceptions.py @@ -1,7 +1,16 @@ # flake8: noqa +import warnings + from ndsl.comm.local_comm import ConcurrencyError from ndsl.units import UnitsError class OutOfBoundsError(ValueError): - pass + def __init__(self, *args) -> None: + warnings.warn( + "Usage of `OutOfBoundsError` is discouraged. The class will be " + "removed in the next version in favor of using the built-in `IndexError`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args) diff --git a/ndsl/filesystem.py b/ndsl/filesystem.py index 66c8142e..df2e709f 100644 --- a/ndsl/filesystem.py +++ b/ndsl/filesystem.py @@ -1,16 +1,37 @@ +import warnings + import fsspec def get_fs(path: str) -> fsspec.AbstractFileSystem: """Return the fsspec filesystem required to handle a given path.""" + warnings.warn( + "Usage of `get_fs()` is discouraged if favor `os.path` and `pathlib` " + "modules. The function will be removed in the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) fs, _, _ = fsspec.get_fs_token_paths(path) return fs def is_file(filename): + warnings.warn( + "Usage of `is_file()` is discouraged if favor of plain `os.path.isfile()`. " + "The function will be removed in the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) return get_fs(filename).isfile(filename) def open(filename, *args, **kwargs): + warnings.warn( + "Usage of `open()` is discouraged if favor the python built-in file " + "open context manager. The function will be removed in the next version " + "of NDSL.", + DeprecationWarning, + stacklevel=2, + ) fs = get_fs(filename) return fs.open(filename, *args, **kwargs) diff --git a/ndsl/global_config.py b/ndsl/global_config.py deleted file mode 100644 index 8cfc8a3b..00000000 --- a/ndsl/global_config.py +++ /dev/null @@ -1,48 +0,0 @@ -import functools -import os -from typing import Optional - - -def getenv_bool(name: str, default: str) -> bool: - indicator = os.getenv(name, default).title() - return indicator == "True" - - -def set_backend(new_backend: str): - global _BACKEND - _BACKEND = new_backend - - -def get_backend() -> str: - return _BACKEND - - -def set_rebuild(flag: bool): - global _REBUILD - _REBUILD = flag - - -def get_rebuild() -> bool: - return _REBUILD - - -def set_validate_args(new_validate_args: bool): - global _VALIDATE_ARGS - _VALIDATE_ARGS = new_validate_args - - -# Set to "False" to skip validating gt4py stencil arguments -@functools.lru_cache(maxsize=None) -def get_validate_args() -> bool: - return _VALIDATE_ARGS - - -# Options -# CPU: numpy, gt:cpu_ifirst, gt:cpu_kfirst -# GPU: gt:gpu, cuda -_BACKEND: Optional[str] = None - -# If TRUE, all caches will bypassed and stencils recompiled -# if FALSE, caches will be checked and rebuild if code changes -_REBUILD: bool = getenv_bool("FV3_STENCIL_REBUILD_FLAG", "False") -_VALIDATE_ARGS: bool = True diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index fabe72bf..c397a706 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -9,3 +9,17 @@ HorizontalGridData, VerticalGridData, ) + + +__all__ = [ + "HybridPressureCoefficients", + "GridDefinitions", + "MetricTerms", + "AngleGridData", + "ContravariantGridData", + "DampingCoefficients", + "DriverGridData", + "GridData", + "HorizontalGridData", + "VerticalGridData", +] diff --git a/ndsl/grid/eta.py b/ndsl/grid/eta.py index 35ac510d..21a3f18d 100644 --- a/ndsl/grid/eta.py +++ b/ndsl/grid/eta.py @@ -1,11 +1,12 @@ import math -import os from dataclasses import dataclass -from typing import Optional, Tuple +from pathlib import Path import numpy as np import xarray as xr +from ndsl import logging + ETA_0 = 0.252 SURFACE_PRESSURE = 1.0e5 # units of (Pa), from Table VI of DCMIP2016 @@ -29,8 +30,8 @@ class HybridPressureCoefficients: bk: np.ndarray -def _load_ak_bk_from_file(eta_file: str) -> Tuple[np.ndarray, np.ndarray]: - if not os.path.isfile(eta_file): +def _load_ak_bk_from_file(eta_file: Path) -> tuple[np.ndarray, np.ndarray]: + if not Path.is_file(eta_file): raise ValueError(f"eta file {eta_file} does not exist") # read file into ak, bk arrays @@ -43,9 +44,9 @@ def _load_ak_bk_from_file(eta_file: str) -> Tuple[np.ndarray, np.ndarray]: def set_hybrid_pressure_coefficients( km: int, - eta_file: str, - ak_data: Optional[np.ndarray] = None, - bk_data: Optional[np.ndarray] = None, + eta_file: Path | None = None, + ak_data: np.ndarray | None = None, + bk_data: np.ndarray | None = None, ) -> HybridPressureCoefficients: """ Sets the coefficients describing the hybrid pressure coordinates. @@ -61,8 +62,16 @@ def set_hybrid_pressure_coefficients( a HybridPressureCoefficients dataclass """ if ak_data is None or bk_data is None: + if eta_file is None: + raise ValueError( + "Please specify either and `eta_file` or eta data as `ak_data` and `bk_data`." + ) ak, bk = _load_ak_bk_from_file(eta_file) else: + if eta_file is not None: + logging.ndsl_log.warning( + f"Ignoring eta_file {eta_file} since `ak_data` and `bk_data` were given." + ) ak, bk = ak_data, bk_data # check size of ak and bk array is km+1 @@ -96,7 +105,7 @@ def vertical_coordinate(eta_value) -> np.ndarray: return (eta_value - ETA_0) * math.pi * 0.5 -def compute_eta(ak, bk) -> Tuple[np.ndarray, np.ndarray]: +def compute_eta(ak, bk) -> tuple[np.ndarray, np.ndarray]: """ Equation (1) JRMS2006 eta is the vertical coordinate and eta_v is an auxiliary vertical coordinate diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 7b28c2ff..1bc37bc5 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import dataclasses import functools import warnings -from typing import Optional, Tuple +from pathlib import Path import numpy as np @@ -51,8 +53,7 @@ set_tile_border_dyc, ) from ndsl.grid.mirror import mirror_grid -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer +from ndsl.initialization import QuantityFactory, SubtileGridSizer from ndsl.quantity import Quantity from ndsl.stencils.corners import ( fill_corners_2d, @@ -62,17 +63,6 @@ ) -# TODO: when every environment in python3.8, remove -# this custom decorator -def cached_property(func): - @property - @functools.lru_cache() - def wrapper(self, *args, **kwargs): - return func(self, *args, **kwargs) - - return wrapper - - def ignore_zero_division(func): @functools.wraps(func) def wrapped(*args, **kwargs): @@ -86,7 +76,7 @@ def wrapped(*args, **kwargs): def quantity_cast_to_model_float( quantity_factory: QuantityFactory, qty_64: Quantity ) -> Quantity: - """Copy & cast from 64-bit float to model precision if need be""" + """Copy & cast from 64-bit float to model precision if need be.""" qty = quantity_factory.zeros(qty_64.dims, qty_64.units, dtype=Float) qty.data[:] = qty_64.data[:] return qty @@ -94,7 +84,7 @@ def quantity_cast_to_model_float( @dataclasses.dataclass class GridDefinition: - dims: Tuple[str, ...] + dims: tuple[str, ...] units: str @@ -238,9 +228,9 @@ def __init__( dy_const: float = 1000.0, deglat: float = 15.0, extdgrid: bool = False, - eta_file: Optional[str] = None, - ak: Optional[np.ndarray] = None, - bk: Optional[np.ndarray] = None, + eta_file: str | Path | None = None, + ak: np.ndarray | None = None, + bk: np.ndarray | None = None, ): self._grid_type = grid_type self._dx_const = dx_const @@ -252,8 +242,8 @@ def __init__( self._tile_partitioner = self._comm.tile.partitioner self._rank = self._comm.rank self.quantity_factory = quantity_factory - self.quantity_factory.set_extra_dim_lengths( - **{ + self.quantity_factory.add_data_dimensions( + { self.LON_OR_LAT_DIM: 2, self.TILE_DIM: 6, self.CARTESIAN_DIM: 3, @@ -276,7 +266,7 @@ def __init__( # This will carry the public version of the grid # for the selected floating point precision self._grid = None - npx, npy, ndims = self._tile_partitioner.global_extent(self._grid_64) + npx, npy, _ = self._tile_partitioner.global_extent(self._grid_64) self._npx = npx self._npy = npy self._npz = self.quantity_factory.sizer.get_extent(Z_DIM)[0] @@ -296,27 +286,26 @@ def __init__( self._dy_agrid = None self._dx_center = None self._dy_center = None - self._area: Optional[Quantity] = None - self._area64: Optional[Quantity] = None + self._area: Quantity | None = None + self._area64: Quantity | None = None self._area_c = None if eta_file is not None or ak is not None or bk is not None: + if type(eta_file) is str: + # Temporary cast. Tob be removed once + eta_file = Path(eta_file) ( self._ks, self._ptop, self._ak, self._bk, - ) = self._set_hybrid_pressure_coefficients(eta_file, ak, bk) - else: - self._ks = self.quantity_factory.zeros( - [], - "", - dtype=Float, - ) - self._ptop = self.quantity_factory.zeros( - [], - "Pa", - dtype=Float, + ) = self._set_hybrid_pressure_coefficients( + eta_file, # type: ignore + ak, + bk, ) + else: + self._ks = 0 + self._ptop = 0 self._ak = self.quantity_factory.zeros( [Z_INTERFACE_DIM], "Pa", @@ -464,11 +453,11 @@ def from_external( quantity_factory, communicator, grid_type, - eta_file: str = "None", - ) -> "MetricTerms": + eta_file: Path | None = None, + ) -> MetricTerms: """ Generates a metric terms object, using input from data contained in an - externally generated tile file + externally generated tile file. """ terms = MetricTerms( quantity_factory=quantity_factory, @@ -482,6 +471,12 @@ def from_external( terms._grid_64.view[:, :, 0] = rad_conv * x terms._grid_64.view[:, :, 1] = rad_conv * y + terms._comm.halo_update(terms._grid_64, n_points=terms._halo) + + fill_corners_2d( + terms._grid_64.data, terms._grid_indexing, gridtype="B", direction="x" + ) + terms._init_agrid() return terms @@ -498,18 +493,13 @@ def from_tile_sizing( dx_const: float = 1000.0, dy_const: float = 1000.0, deglat: float = 15.0, - eta_file: str = "None", - ) -> "MetricTerms": + eta_file: Path | None = None, + ) -> MetricTerms: sizer = SubtileGridSizer.from_tile_params( nx_tile=npx - 1, ny_tile=npy - 1, nz=npz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={ - cls.LON_OR_LAT_DIM: 2, - cls.TILE_DIM: 6, - cls.CARTESIAN_DIM: 3, - }, layout=communicator.partitioner.tile.layout, ) quantity_factory = QuantityFactory.from_backend(sizer, backend=backend) @@ -533,9 +523,7 @@ def grid(self): @property def dgrid_lon_lat(self): - """ - the longitudes and latitudes of the cell corners - """ + """The longitudes and latitudes of the cell corners.""" return self.grid @property @@ -552,9 +540,7 @@ def agrid(self): @property def agrid_lon_lat(self): - """ - the longitudes and latitudes of the cell centers - """ + """The longitudes and latitudes of the cell centers.""" return self.agrid @property @@ -566,6 +552,7 @@ def lon(self): extent=self.grid.extent[:2], units=self.grid.units, gt4py_backend=self.grid.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) @property @@ -577,6 +564,7 @@ def lat(self) -> Quantity: extent=self.grid.extent[:2], units=self.grid.units, gt4py_backend=self.grid.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) @property @@ -588,6 +576,7 @@ def lon_agrid(self) -> Quantity: extent=self.agrid.extent[:2], units=self.agrid.units, gt4py_backend=self.agrid.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) @property @@ -599,73 +588,61 @@ def lat_agrid(self) -> Quantity: extent=self.agrid.extent[:2], units=self.agrid.units, gt4py_backend=self.agrid.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) @property def dx(self) -> Quantity: - """ - the distance between grid corners along the x-direction - """ + """The distance between grid corners along the x-direction.""" if self._dx is None: self._dx, self._dy = self._compute_dxdy() - return self._dx + return self._dx # type: ignore[return-value] @property def dy(self) -> Quantity: - """ - the distance between grid corners along the y-direction - """ + """The distance between grid corners along the y-direction.""" if self._dy is None: self._dx, self._dy = self._compute_dxdy() - return self._dy + return self._dy # type: ignore[return-value] @property def dxa(self) -> Quantity: - """ - the with of each grid cell along the x-direction - """ + """The with of each grid cell along the x-direction.""" if self._dx_agrid is None: self._dx_agrid, self._dy_agrid = self._compute_dxdy_agrid() - return self._dx_agrid + return self._dx_agrid # type: ignore[return-value] @property def dya(self) -> Quantity: - """ - the with of each grid cell along the y-direction - """ + """The with of each grid cell along the y-direction.""" if self._dy_agrid is None: self._dx_agrid, self._dy_agrid = self._compute_dxdy_agrid() - return self._dy_agrid + return self._dy_agrid # type: ignore[return-value] @property def dxc(self) -> Quantity: - """ - the distance between cell centers along the x-direction - """ + """The distance between cell centers along the x-direction.""" if self._dx_center is None: self._dx_center, self._dy_center = self._compute_dxdy_center() - return self._dx_center + return self._dx_center # type: ignore[return-value] @property def dyc(self) -> Quantity: - """ - the distance between cell centers along the y-direction - """ + """The distance between cell centers along the y-direction.""" if self._dy_center is None: self._dx_center, self._dy_center = self._compute_dxdy_center() - return self._dy_center + return self._dy_center # type: ignore[return-value] @property - def ks(self) -> Quantity: - """ - number of levels where the vertical coordinate is purely pressure-based - """ + def ks(self) -> int: + """Number of levels where the vertical coordinate is purely pressure-based.""" return self._ks @property def ak(self) -> Quantity: """ - the ak coefficient used to calculate the pressure at a given k-level: + The ak coefficient used to calculate the pressure at a given k-level: + pk = ak + (bk * ps) """ return self._ak @@ -673,66 +650,66 @@ def ak(self) -> Quantity: @property def bk(self) -> Quantity: """ - the bk coefficient used to calculate the pressure at a given k-level: + The bk coefficient used to calculate the pressure at a given k-level: + pk = ak + (bk * ps) """ return self._bk @property - def ptop(self) -> Quantity: - """ - the pressure of the top of atmosphere level - """ + def ptop(self) -> int: + """The pressure of the top of atmosphere level.""" return self._ptop @property def ec1(self) -> Quantity: """ - cartesian components of the local unit vector - in the x-direction at the cell centers - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vector in the x-direction at the cell centers. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ec1 is None: self._ec1, self._ec2 = self._calculate_center_vectors() - return self._ec1 + return self._ec1 # type: ignore[return-value] @property def ec2(self) -> Quantity: """ - cartesian components of the local unit vector - in the y-direction at the cell centers - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vector in the y-direction at the cell centers. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ec2 is None: self._ec1, self._ec2 = self._calculate_center_vectors() - return self._ec2 + return self._ec2 # type: ignore[return-value] @property def ew1(self) -> Quantity: """ - cartesian components of the local unit vector - in the x-direction at the left/right cell edges - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vector in the x-direction at the left/right cell edges. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ew1 is None: self._ew1, self._ew2 = self._calculate_vectors_west() - return self._ew1 + return self._ew1 # type: ignore[return-value] @property def ew2(self) -> Quantity: """ - cartesian components of the local unit vector - in the y-direction at the left/right cell edges - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vector in the y-direction at the left/right cell edges. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ew2 is None: self._ew1, self._ew2 = self._calculate_vectors_west() - return self._ew2 + return self._ew2 # type: ignore[return-value] @property def cos_sg1(self) -> Quantity: """ Cosine of the angle at point 1 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -741,12 +718,13 @@ def cos_sg1(self) -> Quantity: """ if self._cos_sg1 is None: self._init_cell_trigonometry() - return self._cos_sg1 + return self._cos_sg1 # type: ignore[return-value] @property def cos_sg2(self) -> Quantity: """ Cosine of the angle at point 2 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -755,12 +733,13 @@ def cos_sg2(self) -> Quantity: """ if self._cos_sg2 is None: self._init_cell_trigonometry() - return self._cos_sg2 + return self._cos_sg2 # type: ignore[return-value] @property def cos_sg3(self) -> Quantity: """ Cosine of the angle at point 3 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -769,12 +748,13 @@ def cos_sg3(self) -> Quantity: """ if self._cos_sg3 is None: self._init_cell_trigonometry() - return self._cos_sg3 + return self._cos_sg3 # type: ignore[return-value] @property def cos_sg4(self) -> Quantity: """ Cosine of the angle at point 4 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -783,12 +763,13 @@ def cos_sg4(self) -> Quantity: """ if self._cos_sg4 is None: self._init_cell_trigonometry() - return self._cos_sg4 + return self._cos_sg4 # type: ignore[return-value] @property def cos_sg5(self) -> Quantity: """ Cosine of the angle at point 5 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -798,12 +779,13 @@ def cos_sg5(self) -> Quantity: """ if self._cos_sg5 is None: self._init_cell_trigonometry() - return self._cos_sg5 + return self._cos_sg5 # type: ignore[return-value] @property def cos_sg6(self) -> Quantity: """ Cosine of the angle at point 6 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -812,12 +794,13 @@ def cos_sg6(self) -> Quantity: """ if self._cos_sg6 is None: self._init_cell_trigonometry() - return self._cos_sg6 + return self._cos_sg6 # type: ignore[return-value] @property def cos_sg7(self) -> Quantity: """ Cosine of the angle at point 7 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -826,12 +809,13 @@ def cos_sg7(self) -> Quantity: """ if self._cos_sg7 is None: self._init_cell_trigonometry() - return self._cos_sg7 + return self._cos_sg7 # type: ignore[return-value] @property def cos_sg8(self) -> Quantity: """ Cosine of the angle at point 8 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -840,12 +824,13 @@ def cos_sg8(self) -> Quantity: """ if self._cos_sg8 is None: self._init_cell_trigonometry() - return self._cos_sg8 + return self._cos_sg8 # type: ignore[return-value] @property def cos_sg9(self) -> Quantity: """ Cosine of the angle at point 9 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -854,12 +839,13 @@ def cos_sg9(self) -> Quantity: """ if self._cos_sg9 is None: self._init_cell_trigonometry() - return self._cos_sg9 + return self._cos_sg9 # type: ignore[return-value] @property def sin_sg1(self) -> Quantity: """ Sine of the angle at point 1 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -868,12 +854,13 @@ def sin_sg1(self) -> Quantity: """ if self._sin_sg1 is None: self._init_cell_trigonometry() - return self._sin_sg1 + return self._sin_sg1 # type: ignore[return-value] @property def sin_sg2(self) -> Quantity: """ Sine of the angle at point 2 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -882,12 +869,13 @@ def sin_sg2(self) -> Quantity: """ if self._sin_sg2 is None: self._init_cell_trigonometry() - return self._sin_sg2 + return self._sin_sg2 # type: ignore[return-value] @property def sin_sg3(self) -> Quantity: """ Sine of the angle at point 3 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -896,12 +884,13 @@ def sin_sg3(self) -> Quantity: """ if self._sin_sg3 is None: self._init_cell_trigonometry() - return self._sin_sg3 + return self._sin_sg3 # type: ignore[return-value] @property def sin_sg4(self) -> Quantity: """ Sine of the angle at point 4 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -910,27 +899,30 @@ def sin_sg4(self) -> Quantity: """ if self._sin_sg4 is None: self._init_cell_trigonometry() - return self._sin_sg4 + return self._sin_sg4 # type: ignore[return-value] @property def sin_sg5(self) -> Quantity: """ Sine of the angle at point 5 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 | | 6---2---7 + For the center point this is one minus the inner product of ec1 and ec2 squared """ if self._sin_sg5 is None: self._init_cell_trigonometry() - return self._sin_sg5 + return self._sin_sg5 # type: ignore[return-value] @property def sin_sg6(self) -> Quantity: """ Sine of the angle at point 6 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -939,12 +931,13 @@ def sin_sg6(self) -> Quantity: """ if self._sin_sg6 is None: self._init_cell_trigonometry() - return self._sin_sg6 + return self._sin_sg6 # type: ignore[return-value] @property def sin_sg7(self) -> Quantity: """ Sine of the angle at point 7 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -953,12 +946,13 @@ def sin_sg7(self) -> Quantity: """ if self._sin_sg7 is None: self._init_cell_trigonometry() - return self._sin_sg7 + return self._sin_sg7 # type: ignore[return-value] @property def sin_sg8(self) -> Quantity: """ Sine of the angle at point 8 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -967,12 +961,13 @@ def sin_sg8(self) -> Quantity: """ if self._sin_sg8 is None: self._init_cell_trigonometry() - return self._sin_sg8 + return self._sin_sg8 # type: ignore[return-value] @property def sin_sg9(self) -> Quantity: """ Sine of the angle at point 9 of the 'supergrid' within each grid cell: + 9---4---8 | | 1 5 3 @@ -981,71 +976,84 @@ def sin_sg9(self) -> Quantity: """ if self._sin_sg9 is None: self._init_cell_trigonometry() - return self._sin_sg9 + return self._sin_sg9 # type: ignore[return-value] @property def cosa(self) -> Quantity: """ - cosine of angle between coordinate lines at the cell corners - averaged to ensure consistent answers + Cosine of angle between coordinate lines at the cell corners. + + Averaged to ensure consistent answers. """ if self._cosa is None: self._init_cell_trigonometry() - return self._cosa + return self._cosa # type: ignore[return-value] @property def sina(self) -> Quantity: """ - as cosa but sine + Sine of angle between coordinate lines at the cell corners. + + Averaged to ensure consistent answers. """ if self._sina is None: self._init_cell_trigonometry() - return self._sina + return self._sina # type: ignore[return-value] @property def cosa_u(self) -> Quantity: """ - as cosa but defined at the left and right cell edges + Cosine of angle between coordinate lines at the left and right cell edges. + + Averaged to ensure consistent answers. """ if self._cosa_u is None: self._init_cell_trigonometry() - return self._cosa_u + return self._cosa_u # type: ignore[return-value] @property def cosa_v(self) -> Quantity: """ - as cosa but defined at the top and bottom cell edges + Cosine of angle between coordinate lines at the top and bottom cell edges. + + Averaged to ensure consistent answers. """ if self._cosa_v is None: self._init_cell_trigonometry() - return self._cosa_v + return self._cosa_v # type: ignore[return-value] @property def cosa_s(self) -> Quantity: """ as cosa but defined at cell centers + + Averaged to ensure consistent answers. """ if self._cosa_s is None: self._init_cell_trigonometry() - return self._cosa_s + return self._cosa_s # type: ignore[return-value] @property def sina_u(self) -> Quantity: """ - as cosa_u but with sine + Sine of angle between coordinate lines at the left and right cell edges. + + Averaged to ensure consistent answers. """ if self._sina_u is None: self._init_cell_trigonometry() - return self._sina_u + return self._sina_u # type: ignore[return-value] @property def sina_v(self) -> Quantity: """ - as cosa_v but with sine + Sine of angle between coordinate lines at the top and bottom cell edges. + + Averaged to ensure consistent answers. """ if self._sina_v is None: self._init_cell_trigonometry() - return self._sina_v + return self._sina_v # type: ignore[return-value] @property def rsin_u(self) -> Quantity: @@ -1055,7 +1063,7 @@ def rsin_u(self) -> Quantity: """ if self._rsin_u is None: self._init_cell_trigonometry() - return self._rsin_u + return self._rsin_u # type: ignore[return-value] @property def rsin_v(self) -> Quantity: @@ -1065,7 +1073,7 @@ def rsin_v(self) -> Quantity: """ if self._rsin_v is None: self._init_cell_trigonometry() - return self._rsin_v + return self._rsin_v # type: ignore[return-value] @property def rsina(self) -> Quantity: @@ -1075,7 +1083,7 @@ def rsina(self) -> Quantity: """ if self._rsina is None: self._init_cell_trigonometry() - return self._rsina + return self._rsina # type: ignore[return-value] @property def rsin2(self) -> Quantity: @@ -1085,71 +1093,75 @@ def rsin2(self) -> Quantity: """ if self._rsin2 is None: self._init_cell_trigonometry() - return self._rsin2 + return self._rsin2 # type: ignore[return-value] @property def l2c_v(self) -> Quantity: """ - angular momentum correction for converting v-winds - from lat/lon to cartesian coordinates + Angular momentum correction for converting v-winds + from lat/lon to cartesian coordinates. """ if self._l2c_v is None: self._l2c_v, self._l2c_u = self._calculate_latlon_momentum_correction() - return self._l2c_v + return self._l2c_v # type: ignore[return-value] @property def l2c_u(self) -> Quantity: """ - angular momentum correction for converting u-winds - from lat/lon to cartesian coordinates + Angular momentum correction for converting u-winds + from lat/lon to cartesian coordinates. """ if self._l2c_u is None: self._l2c_v, self._l2c_u = self._calculate_latlon_momentum_correction() - return self._l2c_u + return self._l2c_u # type: ignore[return-value] @property def es1(self) -> Quantity: """ - cartesian components of the local unit vetcor - in the x-direation at the top/bottom cell edges, - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vetcor + in the x-direation at the top/bottom cell edges. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._es1 is None: self._es1, self._es2 = self._calculate_vectors_south() - return self._es1 + return self._es1 # type: ignore[return-value] @property def es2(self) -> Quantity: """ - cartesian components of the local unit vetcor - in the y-direation at the top/bottom cell edges, - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vetcor + in the y-direation at the top/bottom cell edges. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._es2 is None: self._es1, self._es2 = self._calculate_vectors_south() - return self._es2 + return self._es2 # type: ignore[return-value] @property def ee1(self) -> Quantity: """ - cartesian components of the local unit vetcor - in the x-direation at the cell corners, - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vetcor + in the x-direation at the cell corners. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ee1 is None: self._ee1, self._ee2 = self._calculate_xy_unit_vectors() - return self._ee1 + return self._ee1 # type: ignore[return-value] @property def ee2(self) -> Quantity: """ - cartesian components of the local unit vetcor - in the y-direation at the cell corners, - 3d array whose last dimension is length 3 and indicates cartesian x/y/z value + Cartesian components of the local unit vetcor + in the y-direation at the cell corners. + + 3d array whose last dimension is length 3 and indicates cartesian x/y/z value. """ if self._ee2 is None: self._ee1, self._ee2 = self._calculate_xy_unit_vectors() - return self._ee2 + return self._ee2 # type: ignore[return-value] @property def divg_u(self) -> Quantity: @@ -1163,7 +1175,7 @@ def divg_u(self) -> Quantity: self._divg_u, self._divg_v, ) = self._calculate_divg_del6() - return self._divg_u + return self._divg_u # type: ignore[return-value] @property def divg_v(self) -> Quantity: @@ -1177,7 +1189,7 @@ def divg_v(self) -> Quantity: self._divg_u, self._divg_v, ) = self._calculate_divg_del6() - return self._divg_v + return self._divg_v # type: ignore[return-value] @property def del6_u(self) -> Quantity: @@ -1191,7 +1203,7 @@ def del6_u(self) -> Quantity: self._divg_u, self._divg_v, ) = self._calculate_divg_del6() - return self._del6_u + return self._del6_u # type: ignore[return-value] @property def del6_v(self) -> Quantity: @@ -1205,67 +1217,69 @@ def del6_v(self) -> Quantity: self._divg_u, self._divg_v, ) = self._calculate_divg_del6() - return self._del6_v + return self._del6_v # type: ignore[return-value] @property def vlon(self) -> Quantity: """ - unit vector in eastward longitude direction, - 3d array whose last dimension is length 3 and indicates x/y/z value + Unit vector in eastward longitude direction. + + 3d array whose last dimension is length 3 and indicates x/y/z value. """ if self._vlon is None: self._vlon, self._vlat = self._calculate_unit_vectors_lonlat() - return self._vlon + return self._vlon # type: ignore[return-value] @property def vlat(self) -> Quantity: """ - unit vector in northward latitude direction, - 3d array whose last dimension is length 3 and indicates x/y/z value + Unit vector in northward latitude direction. + + 3d array whose last dimension is length 3 and indicates x/y/z value. """ if self._vlat is None: self._vlon, self._vlat = self._calculate_unit_vectors_lonlat() - return self._vlat + return self._vlat # type: ignore[return-value] @property def z11(self) -> Quantity: """ - vector product of horizontal component of the cell-center vector - with the unit longitude vector + Vector product of horizontal component of the cell-center vector + with the unit longitude vector. """ if self._z11 is None: self._z11, self._z12, self._z21, self._z22 = self._calculate_grid_z() - return self._z11 + return self._z11 # type: ignore[return-value] @property def z12(self) -> Quantity: """ - vector product of horizontal component of the cell-center vector - with the unit latitude vector + Vector product of horizontal component of the cell-center vector + with the unit latitude vector. """ if self._z12 is None: self._z11, self._z12, self._z21, self._z22 = self._calculate_grid_z() - return self._z12 + return self._z12 # type: ignore[return-value] @property def z21(self) -> Quantity: """ - vector product of vertical component of the cell-center vector - with the unit longitude vector + Vector product of vertical component of the cell-center vector + with the unit longitude vector. """ if self._z21 is None: self._z11, self._z12, self._z21, self._z22 = self._calculate_grid_z() - return self._z21 + return self._z21 # type: ignore[return-value] @property def z22(self) -> Quantity: """ - vector product of vertical component of the cell-center vector - with the unit latitude vector + Vector product of vertical component of the cell-center vector + with the unit latitude vector. """ if self._z22 is None: self._z11, self._z12, self._z21, self._z22 = self._calculate_grid_z() - return self._z22 + return self._z22 # type: ignore[return-value] @property def a11(self) -> Quantity: @@ -1274,7 +1288,7 @@ def a11(self) -> Quantity: """ if self._a11 is None: self._a11, self._a12, self._a21, self._a22 = self._calculate_grid_a() - return self._a11 + return self._a11 # type: ignore[return-value] @property def a12(self) -> Quantity: @@ -1283,7 +1297,7 @@ def a12(self) -> Quantity: """ if self._a12 is None: self._a11, self._a12, self._a21, self._a22 = self._calculate_grid_a() - return self._a12 + return self._a12 # type: ignore[return-value] @property def a21(self) -> Quantity: @@ -1292,7 +1306,7 @@ def a21(self) -> Quantity: """ if self._a21 is None: self._a11, self._a12, self._a21, self._a22 = self._calculate_grid_a() - return self._a21 + return self._a21 # type: ignore[return-value] @property def a22(self) -> Quantity: @@ -1301,12 +1315,12 @@ def a22(self) -> Quantity: """ if self._a22 is None: self._a11, self._a12, self._a21, self._a22 = self._calculate_grid_a() - return self._a22 + return self._a22 # type: ignore[return-value] @property def edge_w(self) -> Quantity: """ - factor to interpolate scalars from a to c grid at the western grid edge + Factor to interpolate scalars from a to c grid at the western grid edge. """ if self._edge_w is None: ( @@ -1315,12 +1329,12 @@ def edge_w(self) -> Quantity: self._edge_s, self._edge_n, ) = self._calculate_edge_factors() - return self._edge_w + return self._edge_w # type: ignore[return-value] @property def edge_e(self) -> Quantity: """ - factor to interpolate scalars from a to c grid at the eastern grid edge + Factor to interpolate scalars from a to c grid at the eastern grid edge. """ if self._edge_e is None: ( @@ -1329,12 +1343,12 @@ def edge_e(self) -> Quantity: self._edge_s, self._edge_n, ) = self._calculate_edge_factors() - return self._edge_e + return self._edge_e # type: ignore[return-value] @property def edge_s(self) -> Quantity: """ - factor to interpolate scalars from a to c grid at the southern grid edge + Factor to interpolate scalars from a to c grid at the southern grid edge. """ if self._edge_s is None: ( @@ -1343,12 +1357,12 @@ def edge_s(self) -> Quantity: self._edge_s, self._edge_n, ) = self._calculate_edge_factors() - return self._edge_s + return self._edge_s # type: ignore[return-value] @property def edge_n(self) -> Quantity: """ - factor to interpolate scalars from a to c grid at the northern grid edge + Factor to interpolate scalars from a to c grid at the northern grid edge. """ if self._edge_n is None: ( @@ -1357,12 +1371,12 @@ def edge_n(self) -> Quantity: self._edge_s, self._edge_n, ) = self._calculate_edge_factors() - return self._edge_n + return self._edge_n # type: ignore[return-value] @property def edge_vect_w_1d(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the western grid edge + Factor to interpolate vectors from a to c grid at the western grid edge. """ if self._edge_vect_w is None: ( @@ -1371,25 +1385,25 @@ def edge_vect_w_1d(self) -> Quantity: self._edge_vect_s, self._edge_vect_n, ) = self._calculate_edge_a2c_vect_factors() - return self._edge_vect_w + return self._edge_vect_w # type: ignore[return-value] @property def edge_vect_w(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the western grid edge - repeated in x and y to be used in stencils + Factor to interpolate vectors from a to c grid at the western grid edge + repeated in x and y to be used in stencils. """ if self._edge_vect_w_2d is None: ( self._edge_vect_e_2d, self._edge_vect_w_2d, ) = self._calculate_2d_edge_a2c_vect_factors() - return self._edge_vect_w_2d + return self._edge_vect_w_2d # type: ignore[return-value] @property def edge_vect_e_1d(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the eastern grid edge + Factor to interpolate vectors from a to c grid at the eastern grid edge. """ if self._edge_vect_e is None: ( @@ -1398,25 +1412,25 @@ def edge_vect_e_1d(self) -> Quantity: self._edge_vect_s, self._edge_vect_n, ) = self._calculate_edge_a2c_vect_factors() - return self._edge_vect_e + return self._edge_vect_e # type: ignore[return-value] @property def edge_vect_e(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the eastern grid edge - repeated in x and y to be used in stencils + Factor to interpolate vectors from a to c grid at the eastern grid edge + repeated in x and y to be used in stencils. """ if self._edge_vect_e_2d is None: ( self._edge_vect_e_2d, self._edge_vect_w_2d, ) = self._calculate_2d_edge_a2c_vect_factors() - return self._edge_vect_e_2d + return self._edge_vect_e_2d # type: ignore[return-value] @property def edge_vect_s(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the southern grid edge + Factor to interpolate vectors from a to c grid at the southern grid edge. """ if self._edge_vect_s is None: ( @@ -1425,12 +1439,12 @@ def edge_vect_s(self) -> Quantity: self._edge_vect_s, self._edge_vect_n, ) = self._calculate_edge_a2c_vect_factors() - return self._edge_vect_s + return self._edge_vect_s # type: ignore[return-value] @property def edge_vect_n(self) -> Quantity: """ - factor to interpolate vectors from a to c grid at the northern grid edge + Factor to interpolate vectors from a to c grid at the northern grid edge. """ if self._edge_vect_n is None: ( @@ -1439,100 +1453,94 @@ def edge_vect_n(self) -> Quantity: self._edge_vect_s, self._edge_vect_n, ) = self._calculate_edge_a2c_vect_factors() - return self._edge_vect_n + return self._edge_vect_n # type: ignore[return-value] @property def da_min(self) -> float: """ - the minimum agrid cell area across all ranks, - if mpi is not present and the communicator is a DummyComm this will be - the minimum on the local rank + The minimum agrid cell area across all ranks. + + If mpi is not present and the communicator is a DummyComm this will be + the minimum on the local rank. """ if self._da_min is None: self._reduce_global_area_minmaxes() - return self._da_min + return self._da_min # type: ignore[return-value] @property def da_max(self) -> float: """ - the maximum agrid cell area across all ranks, - if mpi is not present and the communicator is a DummyComm this will be - the maximum on the local rank + Fhe maximum agrid cell area across all ranks. + + Ff mpi is not present and the communicator is a DummyComm this will be + the maximum on the local rank. """ if self._da_max is None: self._reduce_global_area_minmaxes() - return self._da_max + return self._da_max # type: ignore[return-value] @property def da_min_c(self) -> float: """ - the minimum cgrid cell area across all ranks, - if mpi is not present and the communicator is a DummyComm this will be - the minimum on the local rank + The minimum cgrid cell area across all ranks. + + If mpi is not present and the communicator is a DummyComm this will be + the minimum on the local rank. """ if self._da_min_c is None: self._reduce_global_area_minmaxes() - return self._da_min_c + return self._da_min_c # type: ignore[return-value] @property def da_max_c(self) -> float: """ - the maximum cgrid cell area across all ranks, - if mpi is not present and the communicator is a DummyComm this will be - the maximum on the local rank + The maximum cgrid cell area across all ranks. + + If mpi is not present and the communicator is a DummyComm this will be + the maximum on the local rank. """ if self._da_max_c is None: self._reduce_global_area_minmaxes() - return self._da_max_c + return self._da_max_c # type: ignore[return-value] @property def area(self) -> Quantity: - """ - the area of each a-grid cell - """ + """The area of each a-grid cell.""" if self._area is None: self._area, self._area64 = self._compute_area() return self._area @property def area64(self) -> Quantity: - """ - the area of each a-grid cell, at 64-bit precision - """ + """The area of each a-grid cell, at 64-bit precision.""" if self._area64 is None: self._area, self._area64 = self._compute_area() return self._area64 @property def area_c(self) -> Quantity: - """ - the area of each c-grid cell - """ + """The area of each c-grid cell.""" if self._area_c is None: self._area_c = self._compute_area_c() - return self._area_c + return self._area_c # type: ignore[return-value] - @cached_property + @functools.cached_property def _dgrid_xyz_64(self) -> Quantity: - """ - cartesian coordinates of each dgrid cell center - """ + """Cartesian coordinates of each dgrid cell center.""" return lon_lat_to_xyz( self._grid_64.data[:, :, 0], self._grid_64.data[:, :, 1], self._np ) - @cached_property + @functools.cached_property def _agrid_xyz_64(self) -> Quantity: - """ - cartesian coordinates of each agrid cell center - """ + """Cartesian coordinates of each agrid cell center.""" return lon_lat_to_xyz( self._agrid_64.data[:-1, :-1, 0], self._agrid_64.data[:-1, :-1, 1], self._np, ) - @cached_property + @functools.cached_property def rarea(self) -> Quantity: """ 1/cell area @@ -1544,9 +1552,10 @@ def rarea(self) -> Quantity: extent=self.area.extent, units="m^-2", gt4py_backend=self.area.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property def rarea_c(self) -> Quantity: """ 1/cgrid cell area @@ -1558,9 +1567,10 @@ def rarea_c(self) -> Quantity: extent=self.area_c.extent, units="m^-2", gt4py_backend=self.area_c.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdx(self) -> Quantity: """ @@ -1573,9 +1583,10 @@ def rdx(self) -> Quantity: extent=self.dx.extent, units="m^-1", gt4py_backend=self.dx.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdy(self) -> Quantity: """ @@ -1588,9 +1599,10 @@ def rdy(self) -> Quantity: extent=self.dy.extent, units="m^-1", gt4py_backend=self.dy.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdxa(self) -> Quantity: """ @@ -1603,9 +1615,10 @@ def rdxa(self) -> Quantity: extent=self.dxa.extent, units="m^-1", gt4py_backend=self.dxa.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdya(self) -> Quantity: """ @@ -1618,9 +1631,10 @@ def rdya(self) -> Quantity: extent=self.dya.extent, units="m^-1", gt4py_backend=self.dya.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdxc(self) -> Quantity: """ @@ -1633,9 +1647,10 @@ def rdxc(self) -> Quantity: extent=self.dxc.extent, units="m^-1", gt4py_backend=self.dxc.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) - @cached_property + @functools.cached_property @ignore_zero_division def rdyc(self) -> Quantity: """ @@ -1648,6 +1663,7 @@ def rdyc(self) -> Quantity: extent=self.dyc.extent, units="m^-1", gt4py_backend=self.dyc.gt4py_backend, + number_of_halo_points=N_HALO_DEFAULT, ) def _init_cartesian(self): @@ -2184,20 +2200,10 @@ def _compute_area_c_cartesian(self): def _set_hybrid_pressure_coefficients( self, - eta_file, - ak_data: Optional[np.ndarray] = None, - bk_data: Optional[np.ndarray] = None, - ): - ks = self.quantity_factory.zeros( - [], - "", - dtype=Float, - ) - ptop = self.quantity_factory.zeros( - [], - "Pa", - dtype=Float, - ) + eta_file: Path | None = None, + ak_data: np.ndarray | None = None, + bk_data: np.ndarray | None = None, + ) -> tuple[int, int, Quantity, Quantity]: ak = self.quantity_factory.zeros( [Z_INTERFACE_DIM], "Pa", diff --git a/ndsl/grid/geometry.py b/ndsl/grid/geometry.py index 74441cde..c6862693 100644 --- a/ndsl/grid/geometry.py +++ b/ndsl/grid/geometry.py @@ -207,7 +207,7 @@ def calculate_supergrid_cos_sin( cos_sg[abs(1.0 - cos_sg) < 1e-15] = 1.0 - sin_sg_tmp = 1.0 - cos_sg ** 2 + sin_sg_tmp = 1.0 - cos_sg**2 sin_sg_tmp[sin_sg_tmp < 0] = 0.0 sin_sg = np.sqrt(sin_sg_tmp) sin_sg[sin_sg > 1.0] = 1.0 diff --git a/ndsl/grid/global_setup.py b/ndsl/grid/global_setup.py index 60bd3c3b..1f8f68dd 100644 --- a/ndsl/grid/global_setup.py +++ b/ndsl/grid/global_setup.py @@ -40,7 +40,7 @@ def gnomonic_grid(grid_type: int, lon, lat, np): # closer to the Fortran code def global_gnomonic_ed(lon, lat, np): im = lon.shape[0] - 1 - alpha = np.arcsin(3 ** -0.5) + alpha = np.arcsin(3**-0.5) dely = np.multiply(2.0, alpha) / float(im) pp = np.zeros((3, im + 1, im + 1)) @@ -68,16 +68,16 @@ def global_gnomonic_ed(lon, lat, np): i = 0 for j in range(1, im): pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np) - pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j] - pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j] + pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j] + pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j] j = 0 for i in range(1, im): pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np) - pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j] - pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j] + pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j] + pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j] - pp[0, :, :] = -(3 ** -0.5) + pp[0, :, :] = -(3**-0.5) for j in range(1, im + 1): # copy y-z face of the cube along j=0 pp[1, 1:, j] = pp[1, 1:, 0] diff --git a/ndsl/grid/gnomonic.py b/ndsl/grid/gnomonic.py index 0fc421ef..02d41e20 100644 --- a/ndsl/grid/gnomonic.py +++ b/ndsl/grid/gnomonic.py @@ -39,7 +39,7 @@ def local_gnomonic_ed( _check_shapes(lon, lat) # tile_im, wedge_dict, corner_dict, global_is, global_js im = lon.shape[0] - 1 - alpha = np.arcsin(3 ** -0.5) + alpha = np.arcsin(3**-0.5) tile_im = npx - 1 dely = np.multiply(2.0, alpha / float(tile_im)) halo = 3 @@ -96,10 +96,10 @@ def local_gnomonic_ed( lon_west_tile_edge[i, j], lat_west_tile_edge[i, j], np ) pp_west_tile_edge[1, i, j] = ( - -pp_west_tile_edge[1, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j] + -pp_west_tile_edge[1, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j] ) pp_west_tile_edge[2, i, j] = ( - -pp_west_tile_edge[2, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j] + -pp_west_tile_edge[2, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j] ) if west_edge: pp[:, 0, :] = pp_west_tile_edge[:, 0, :] @@ -110,10 +110,10 @@ def local_gnomonic_ed( lon_south_tile_edge[i, j], lat_south_tile_edge[i, j], np ) pp_south_tile_edge[1, i, j] = ( - -pp_south_tile_edge[1, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j] + -pp_south_tile_edge[1, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j] ) pp_south_tile_edge[2, i, j] = ( - -pp_south_tile_edge[2, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j] + -pp_south_tile_edge[2, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j] ) if south_edge: pp[:, :, 0] = pp_south_tile_edge[:, :, 0] @@ -138,7 +138,7 @@ def local_gnomonic_ed( if north_edge and east_edge: pp[:, im, im] = _latlon2xyz(lon_east, lat_north, np) - pp[0, :, :] = -(3 ** -0.5) + pp[0, :, :] = -(3**-0.5) for j in range(start_j, im + 1): # copy y-z face of the cube along j=0 pp[1, start_i:, j] = pp_south_tile_edge[1, start_i:, 0] # pp[1,:,0] @@ -167,14 +167,14 @@ def _corner_to_center_mean(corner_array): def normalize_vector(np, *vector_components): scale = np.divide( 1.0, - np.sum(np.asarray([item ** 2.0 for item in vector_components]), axis=0) ** 0.5, + np.sum(np.asarray([item**2.0 for item in vector_components]), axis=0) ** 0.5, ) return np.asarray([item * scale for item in vector_components]) def normalize_xyz(xyz): # double transpose to broadcast along last dimension instead of first - return (xyz.T / ((xyz ** 2).sum(axis=-1) ** 0.5).T).T + return (xyz.T / ((xyz**2).sum(axis=-1) ** 0.5).T).T def lon_lat_midpoint(lon1, lon2, lat1, lat2, np): @@ -606,7 +606,7 @@ def get_rectangle_area(p1, p2, p3, p4, radius, np): ) in ((p3, p2, p4), (p4, p3, p1), (p1, p4, p2)): total_angle += spherical_angle(q1, q2, q3, np) - return (total_angle - 2 * PI) * radius ** 2 + return (total_angle - 2 * PI) * radius**2 def get_triangle_area(p1, p2, p3, radius, np): @@ -618,7 +618,7 @@ def get_triangle_area(p1, p2, p3, radius, np): total_angle = spherical_angle(p1, p2, p3, np) for q1, q2, q3 in ((p2, p3, p1), (p3, p1, p2)): total_angle += spherical_angle(q1, q2, q3, np) - return (total_angle - PI) * radius ** 2 + return (total_angle - PI) * radius**2 def fortran_vector_spherical_angle(e1, e2, e3): @@ -678,8 +678,7 @@ def spherical_angle(p_center, p2, p3, np): p = np.cross(p_center, p2) q = np.cross(p_center, p3) angle = np.arccos( - np.sum(p * q, axis=-1) - / np.sqrt(np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1)) + np.sum(p * q, axis=-1) / np.sqrt(np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1)) ) if not np.isscalar(angle): angle[np.isnan(angle)] = 0.0 @@ -696,7 +695,7 @@ def spherical_cos(p_center, p2, p3, np): p = np.cross(p_center, p2) q = np.cross(p_center, p3) return np.sum(p * q, axis=-1) / np.sqrt( - np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1) + np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1) ) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 2fbc34a3..dd612b22 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -1,19 +1,18 @@ +from __future__ import annotations + import dataclasses +import os import pathlib import xarray as xr +import ndsl.constants as constants +from ndsl.constants import Z_DIM, Z_INTERFACE_DIM # TODO: if we can remove translate tests in favor of checkpointer tests, # we can remove this "disallowed" import (ndsl.util does not depend on ndsl.dsl) -try: - from ndsl.dsl.gt4py_utils import is_gpu_backend, split_cartesian_into_storages -except ImportError: - split_cartesian_into_storages = None -import ndsl.constants as constants -from ndsl.constants import Z_DIM, Z_INTERFACE_DIM +from ndsl.dsl.gt4py_utils import is_gpu_backend, split_cartesian_into_storages from ndsl.dsl.typing import Float -from ndsl.filesystem import get_fs from ndsl.grid.generation import MetricTerms from ndsl.initialization.allocator import QuantityFactory from ndsl.quantity import Quantity @@ -86,7 +85,7 @@ class HorizontalGridData: edge_n: Quantity @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "HorizontalGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> HorizontalGridData: return cls( lon=metric_terms.lon, lat=metric_terms.lat, @@ -128,15 +127,15 @@ class VerticalGridData: """ Terms defining the vertical grid. - Eulerian vertical grid is defined by p = ak + bk * p_ref + Eulerian vertical grid is defined by p = ak + bk * p_ref. """ # TODO: make these non-optional, make FloatFieldK a true type and use it ak: Quantity bk: Quantity """ - reference pressure (Pa) used to define pressure at vertical interfaces, - where p = ak + bk * p_ref + Reference pressure (Pa) used to define pressure at vertical interfaces, + where p = ak + bk * p_ref. """ def __post_init__(self): @@ -145,7 +144,7 @@ def __post_init__(self): self._p_interface = None @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "VerticalGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> VerticalGridData: return cls( ak=metric_terms.ak, bk=metric_terms.bk, @@ -153,14 +152,13 @@ def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "VerticalGridData": @classmethod def from_restart(cls, restart_path: str, quantity_factory: QuantityFactory): - fs = get_fs(restart_path) - restart_files = fs.ls(restart_path) + restart_files = os.listdir(restart_path) data_file = restart_files[ [fname.endswith("fv_core.res.nc") for fname in restart_files].index(True) ] ak_bk_data_file = pathlib.Path(restart_path) / data_file - if not fs.isfile(ak_bk_data_file): + if not ak_bk_data_file.is_file(): raise ValueError( """vertical_grid_from_restart is true, but no fv_core.res.nc in restart data file.""" @@ -168,7 +166,7 @@ def from_restart(cls, restart_path: str, quantity_factory: QuantityFactory): ak = quantity_factory.zeros([Z_INTERFACE_DIM], units="Pa") bk = quantity_factory.zeros([Z_INTERFACE_DIM], units="") - with fs.open(ak_bk_data_file, "rb") as f: + with open(ak_bk_data_file, "rb") as f: ds = xr.open_dataset(f).isel(Time=0).drop_vars("Time") ak.view[:] = ds["ak"].values bk.view[:] = ds["bk"].values @@ -177,9 +175,7 @@ def from_restart(cls, restart_path: str, quantity_factory: QuantityFactory): @property def p_ref(self) -> float: - """ - reference pressure (Pa) - """ + """Reference pressure (Pa)""" return 1e5 @property @@ -191,6 +187,7 @@ def p_interface(self) -> Quantity: dims=[Z_INTERFACE_DIM], units="Pa", gt4py_backend=self.ak.gt4py_backend, + number_of_halo_points=self.ak.metadata.n_halo, ) return self._p_interface @@ -207,6 +204,7 @@ def p(self) -> Quantity: dims=[Z_DIM], units="Pa", gt4py_backend=self.p_interface.gt4py_backend, + number_of_halo_points=self.p_interface.metadata.n_halo, ) return self._p @@ -223,17 +221,16 @@ def dp(self) -> Quantity: dims=[Z_DIM], units="Pa", gt4py_backend=self.ak.gt4py_backend, + number_of_halo_points=self.ak.metadata.n_halo, ) return self._dp_ref @property def ptop(self) -> Float: - """ - top of atmosphere pressure (Pa) - """ + """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if is_gpu_backend(self.ak.gt4py_backend): + if self.ak.gt4py_backend is not None and is_gpu_backend(self.ak.gt4py_backend): return Float(self.ak.view[0].get()) else: return Float(self.ak.view[0]) @@ -258,9 +255,7 @@ class ContravariantGridData: rsin2: Quantity @classmethod - def new_from_metric_terms( - cls, metric_terms: MetricTerms - ) -> "ContravariantGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> ContravariantGridData: return cls( cosa=metric_terms.cosa, cosa_u=metric_terms.cosa_u, @@ -303,7 +298,7 @@ class AngleGridData: cos_sg9: Quantity @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "AngleGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> AngleGridData: return cls( sin_sg1=metric_terms.sin_sg1, sin_sg2=metric_terms.sin_sg2, @@ -342,14 +337,14 @@ def __init__( self._vertical_data = vertical_data self._contravariant_data = contravariant_data self._angle_data = angle_data - if fc is not None: - self._fC = GridData._fC_from_data(fc, horizontal_data.lat) - else: - self._fC = None - if fc_agrid is not None: - self._fC_agrid = GridData._fC_from_data(fc_agrid, horizontal_data.lat) - else: - self._fC_agrid = None + self._fC = ( + None if fc is None else GridData._fC_from_data(fc, horizontal_data.lat) + ) + self._fC_agrid = ( + None + if fc_agrid is None + else GridData._fC_from_data(fc_agrid, horizontal_data.lat) + ) @classmethod def new_from_metric_terms(cls, metric_terms: MetricTerms): @@ -388,6 +383,7 @@ def _fC_from_data(data, lat: Quantity) -> Quantity: origin=lat.origin, extent=lat.extent, gt4py_backend=lat.gt4py_backend, + number_of_halo_points=lat.metadata.n_halo, ) @staticmethod @@ -752,7 +748,7 @@ class DriverGridData: grid_type: int @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "DriverGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> DriverGridData: return cls.new_from_grid_variables( vlon=metric_terms.vlon, vlat=metric_terms.vlon, @@ -777,7 +773,7 @@ def new_from_grid_variables( es1: Quantity, ew2: Quantity, grid_type: int = 0, - ) -> "DriverGridData": + ) -> DriverGridData: try: vlon1, vlon2, vlon3 = split_quantity_along_last_dim(vlon) vlat1, vlat2, vlat3 = split_quantity_along_last_dim(vlat) @@ -810,16 +806,16 @@ def new_from_grid_variables( ) -def split_quantity_along_last_dim(quantity): +def split_quantity_along_last_dim(quantity: Quantity) -> list[Quantity]: """Split a quantity along the last dimension into a list of quantities. Args: - quantity: Quantity to split. + quantity (Quantity): Quantity to split. Returns: - List of quantities. + list[Quantity]: List of quantities. """ - return_list = [] + return_list: list[Quantity] = [] for i in range(quantity.data.shape[-1]): return_list.append( Quantity( @@ -829,6 +825,7 @@ def split_quantity_along_last_dim(quantity): origin=quantity.origin[:-1], extent=quantity.extent[:-1], gt4py_backend=quantity.gt4py_backend, + number_of_halo_points=quantity.metadata.n_halo, ) ) return return_list diff --git a/ndsl/grid/mirror.py b/ndsl/grid/mirror.py index 3f3a4b6b..5f37aeb7 100644 --- a/ndsl/grid/mirror.py +++ b/ndsl/grid/mirror.py @@ -237,7 +237,7 @@ def _rot_3d(axis, p, angle, np, right_hand_grid, degrees=False, convert=False): y2 = -s * p1[0] + c * p1[1] z2 = p1[2] else: - assert False, "axis must be in [1,2,3]" + raise AssertionError("axis must be in [1,2,3]") if convert: p2 = _cartesian_to_spherical([x2, y2, z2], np, right_hand_grid) diff --git a/ndsl/grid/stretch_transformation.py b/ndsl/grid/stretch_transformation.py index 9a481ab9..65b0f1a7 100644 --- a/ndsl/grid/stretch_transformation.py +++ b/ndsl/grid/stretch_transformation.py @@ -1,12 +1,12 @@ import copy -from typing import Tuple, TypeVar, Union +from typing import TypeVar import numpy as np from ndsl.quantity import Quantity -T = TypeVar("T", bound=Union[Quantity, np.ndarray]) +T = TypeVar("T", bound=Quantity | np.ndarray) def direct_transform( @@ -17,7 +17,7 @@ def direct_transform( lon_target: float, lat_target: float, np, -) -> Tuple[T, T]: +) -> tuple[T, T]: """ The direct_transform subroutine from fv_grid_utils.F90. Takes in latitude and longitude in radians. @@ -56,8 +56,8 @@ def direct_transform( lon_p, lat_p = np.deg2rad(lon_target), np.deg2rad(lat_target) sin_p, cos_p = np.sin(lat_p), np.cos(lat_p) - c2p1 = 1.0 + stretch_factor ** 2 - c2m1 = 1.0 - stretch_factor ** 2 + c2p1 = 1.0 + stretch_factor**2 + c2m1 = 1.0 - stretch_factor**2 # first limit longitude so it's between 0 and 2pi lon_data[lon_data < 0] += 2 * np.pi @@ -102,4 +102,4 @@ def direct_transform( lon_out = lon_transformed lat_out = lat_transformed - return lon_out, lat_out # type: ignore + return lon_out, lat_out diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index e16177d5..deaead80 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -3,3 +3,10 @@ HaloDataTransformerCPU, HaloDataTransformerGPU, ) + + +__all__ = [ + "HaloDataTransformer", + "HaloDataTransformerCPU", + "HaloDataTransformerGPU", +] diff --git a/ndsl/halo/cuda_kernels.py b/ndsl/halo/cuda_kernels.py index 510e9f8e..1dd636c2 100644 --- a/ndsl/halo/cuda_kernels.py +++ b/ndsl/halo/cuda_kernels.py @@ -2,7 +2,7 @@ from ndsl.optional_imports import cupy as cp -def pack_scalar_code(float_dtype: str): +def pack_scalar_code(float_dtype: str) -> str: """Pack into o_destinationBuffer data from i_sourceArray. The indexation into i_sourceArray is stored in i_indexes. @@ -33,7 +33,7 @@ def pack_scalar_code(float_dtype: str): ) -def unpack_scalar_code(float_dtype: str): +def unpack_scalar_code(float_dtype: str) -> str: """Unpack into o_destinationArray data from i_sourceBuffer. The indexation into o_destinationArray is stored in i_indexes. diff --git a/ndsl/halo/data_transformer.py b/ndsl/halo/data_transformer.py index f3133974..3ffbc01e 100644 --- a/ndsl/halo/data_transformer.py +++ b/ndsl/halo/data_transformer.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import abc +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Sequence, Tuple from uuid import UUID, uuid1 import numpy as np @@ -28,7 +30,7 @@ # Simple pool of streams to lower the driver pressure # Use _pop/_push_stream to manipulate the pool -STREAM_POOL: List["cp.cuda.Stream"] = [] +STREAM_POOL: list["cp.cuda.Stream"] = [] def _pop_stream() -> "cp.cuda.Stream": @@ -37,7 +39,7 @@ def _pop_stream() -> "cp.cuda.Stream": return STREAM_POOL.pop() -def _push_stream(stream: "cp.cuda.Stream"): +def _push_stream(stream: "cp.cuda.Stream") -> None: STREAM_POOL.append(stream) @@ -47,14 +49,14 @@ def _push_stream(stream: "cp.cuda.Stream"): # Keyed cached - key is a str at the moment to go around the fact that # a slice is not hashable. getting a string from # Tuple(slices, rotation, shape, strides, itemsize) e.g. # noqa -# str(Tuple[Any, int, Tuple[int], Tuple[int], int]) # noqa -INDICES_CACHE: Dict[str, "cp.ndarray"] = {} +# str(tuple[Any, int, tuple[int], tuple[int], int]) # noqa +INDICES_CACHE: dict[str, "cp.ndarray"] = {} -def _build_flatten_indices( +def _build_flatten_indices( # type: ignore[no-untyped-def] key, shape, - slices: Tuple[slice], + slices: tuple[slice], dims, strides, itemsize: int, @@ -103,7 +105,7 @@ def _build_flatten_indices( # HaloDataTransformer helpers -def _slices_size(slices: Tuple[slice, ...]) -> int: +def _slices_size(slices: tuple[slice, ...]) -> int: """Compute linear size from slices.""" length = 1 for s in slices: @@ -129,11 +131,11 @@ class HaloExchangeSpec: """ specification: QuantityHaloSpec - pack_slices: Tuple[slice, ...] + pack_slices: tuple[slice, ...] pack_clockwise_rotation: int - unpack_slices: Tuple[slice, ...] + unpack_slices: tuple[slice, ...] - def __post_init__(self): + def __post_init__(self) -> None: self._id = uuid1() self.pack_buffer_size = _slices_size(self.pack_slices) self._unpack_buffer_size = _slices_size(self.unpack_slices) @@ -176,17 +178,17 @@ class HaloDataTransformer(abc.ABC): returned to an internal buffer pool. """ - _pack_buffer: Optional[Buffer] - _unpack_buffer: Optional[Buffer] + _pack_buffer: Buffer | None + _unpack_buffer: Buffer | None - _infos_x: Tuple[HaloExchangeSpec, ...] - _infos_y: Tuple[HaloExchangeSpec, ...] + _infos_x: tuple[HaloExchangeSpec, ...] + _infos_y: tuple[HaloExchangeSpec, ...] def __init__( self, np_module: NumpyModule, exchange_descriptors_x: Sequence[HaloExchangeSpec], - exchange_descriptors_y: Optional[Sequence[HaloExchangeSpec]] = None, + exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None, ) -> None: """ Args: @@ -219,7 +221,7 @@ def __init__( self._unpack_buffer = None self._compile() - def finalize(self): + def finalize(self) -> None: """Deletion routine, making sure all buffers were inserted back into cache.""" # Synchronize all work self.synchronize() @@ -237,8 +239,8 @@ def finalize(self): def get( np_module: NumpyModule, exchange_descriptors_x: Sequence[HaloExchangeSpec], - exchange_descriptors_y: Optional[Sequence[HaloExchangeSpec]] = None, - ) -> "HaloDataTransformer": + exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None, + ) -> HaloDataTransformer: """Construct a module from a numpy-like module. Args: @@ -301,7 +303,7 @@ def get_pack_buffer(self) -> Buffer: self.synchronize() return self._pack_buffer - def _compile(self): + def _compile(self) -> None: """Allocate contiguous memory buffers from description queued.""" # Compute required size @@ -316,10 +318,10 @@ def _compile(self): # Retrieve two properly sized buffers self._pack_buffer = Buffer.pop_from_cache( - self._np_module.zeros, (buffer_size), dtype + self._np_module.zeros, (buffer_size,), dtype # type: ignore[arg-type] ) self._unpack_buffer = Buffer.pop_from_cache( - self._np_module.zeros, (buffer_size), dtype + self._np_module.zeros, (buffer_size,), dtype # type: ignore[arg-type] ) def ready(self) -> bool: @@ -329,9 +331,9 @@ def ready(self) -> bool: @abc.abstractmethod def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Pack all given quantities into a single send Buffer. Does not guarantee the buffer returned by `get_unpack_buffer` has @@ -350,9 +352,9 @@ def async_pack( @abc.abstractmethod def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Unpack the buffer into destination quantities. Does not guarantee the buffer returned by `get_unpack_buffer` has @@ -368,7 +370,7 @@ def async_unpack( pass @abc.abstractmethod - def synchronize(self): + def synchronize(self) -> None: """Synchronize all operations. Guarantees all memory is now safe to access. @@ -382,7 +384,7 @@ class HaloDataTransformerCPU(HaloDataTransformer): Default behavior, could be done with any numpy-like library. """ - def synchronize(self): + def synchronize(self) -> None: if self._pack_buffer is not None: self._pack_buffer.finalize_memory_transfer() if self._unpack_buffer is not None: @@ -390,9 +392,9 @@ def synchronize(self): def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: self._pack_scalar(quantities_x) @@ -404,7 +406,7 @@ def async_pack( assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened - def _pack_scalar(self, quantities: List[Quantity]): + def _pack_scalar(self, quantities: list[Quantity]) -> None: if __debug__: if len(quantities) != len(self._infos_x): raise RuntimeError( @@ -433,7 +435,9 @@ def _pack_scalar(self, quantities: List[Quantity]): ) offset += data_size - def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity]): + def _pack_vector( + self, quantities_x: list[Quantity], quantities_y: list[Quantity] + ) -> None: if __debug__: if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( self._infos_y @@ -481,9 +485,9 @@ def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: self._unpack_scalar(quantities_x) @@ -495,7 +499,7 @@ def async_unpack( assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened - def _unpack_scalar(self, quantities: List[Quantity]): + def _unpack_scalar(self, quantities: list[Quantity]) -> None: if __debug__: if len(quantities) != len(self._infos_x): raise RuntimeError( @@ -517,8 +521,8 @@ def _unpack_scalar(self, quantities: List[Quantity]): offset += data_size def _unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] - ): + self, quantities_x: list[Quantity], quantities_y: list[Quantity] + ) -> None: if __debug__: if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( self._infos_y @@ -573,19 +577,19 @@ class HaloDataTransformerGPU(HaloDataTransformer): class _CuKernelArgs: """All arguments required for the CUDA kernels.""" - stream: "cp.cuda.Stream" - x_send_indices: "cp.ndarray" - x_recv_indices: "cp.ndarray" - y_send_indices: Optional["cp.ndarray"] - y_recv_indices: Optional["cp.ndarray"] + stream: cp.cuda.Stream + x_send_indices: cp.ndarray + x_recv_indices: cp.ndarray + y_send_indices: cp.ndarray | None + y_recv_indices: cp.ndarray | None def __init__( self, np_module: NumpyModule, exchange_descriptors_x: Sequence[HaloExchangeSpec], - exchange_descriptors_y: Optional[Sequence[HaloExchangeSpec]] = None, + exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None, ) -> None: - self._cu_kernel_args: Dict[UUID, HaloDataTransformerGPU._CuKernelArgs] = {} + self._cu_kernel_args: dict[UUID, HaloDataTransformerGPU._CuKernelArgs] = {} super().__init__( np_module, exchange_descriptors_x, @@ -595,7 +599,7 @@ def __init__( def _flatten_indices( self, exchange_data: HaloExchangeSpec, - slices: Tuple[slice], + slices: tuple[slice], rotate: bool, ) -> "cp.ndarray": """Extract a flat array of indices from the memory layout and the slice. @@ -630,7 +634,7 @@ def _flatten_indices( # We don't return a copy since the indices are read-only in the algorithm return INDICES_CACHE[key] - def _compile(self): + def _compile(self) -> None: # Super to get buffer allocation super()._compile() # Allocate the streams & build the indices arrays @@ -639,10 +643,10 @@ def _compile(self): self._cu_kernel_args[info_x._id] = HaloDataTransformerGPU._CuKernelArgs( stream=_pop_stream(), x_send_indices=self._flatten_indices( - info_x, info_x.pack_slices, True + info_x, info_x.pack_slices, True # type: ignore[arg-type] ), x_recv_indices=self._flatten_indices( - info_x, info_x.unpack_slices, False + info_x, info_x.unpack_slices, False # type: ignore[arg-type] ), y_send_indices=None, y_recv_indices=None, @@ -653,33 +657,33 @@ def _compile(self): self._cu_kernel_args[info_x._id] = HaloDataTransformerGPU._CuKernelArgs( stream=_pop_stream(), x_send_indices=self._flatten_indices( - info_x, info_x.pack_slices, True + info_x, info_x.pack_slices, True # type: ignore[arg-type] ), x_recv_indices=self._flatten_indices( - info_x, info_x.unpack_slices, False + info_x, info_x.unpack_slices, False # type: ignore[arg-type] ), y_send_indices=self._flatten_indices( - info_y, info_y.pack_slices, True + info_y, info_y.pack_slices, True # type: ignore[arg-type] ), y_recv_indices=self._flatten_indices( - info_y, info_y.unpack_slices, False + info_y, info_y.unpack_slices, False # type: ignore[arg-type] ), ) - def synchronize(self): + def synchronize(self) -> None: if self._CODE_PATH_DEVICE_WIDE_SYNC: self._safe_synchronize() else: self._streamed_synchronize() - def _streamed_synchronize(self): + def _streamed_synchronize(self) -> None: for cu_kernel in self._cu_kernel_args.values(): cu_kernel.stream.synchronize() - def _safe_synchronize(self): + def _safe_synchronize(self) -> None: device_synchronize() - def _get_stream(self, stream) -> "cp.cuda.stream": + def _get_stream(self, stream: "cp.cuda.stream") -> "cp.cuda.stream": if self._CODE_PATH_DEVICE_WIDE_SYNC: return cp.cuda.Stream.null else: @@ -687,9 +691,9 @@ def _get_stream(self, stream) -> "cp.cuda.stream": def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Pack the quantities into a single buffer via streamed cuda kernels Writes into self._pack_buffer using self._x_infos and self._y_infos @@ -710,7 +714,7 @@ def async_pack( else: raise RuntimeError(f"Unimplemented {self._type} pack") - def _opt_pack_scalar(self, quantities: List[Quantity]): + def _opt_pack_scalar(self, quantities: list[Quantity]) -> None: """Specialized packing for scalar. See async_pack docs for usage.""" if __debug__: if len(quantities) != len(self._infos_x): @@ -766,8 +770,8 @@ def _opt_pack_scalar(self, quantities: List[Quantity]): offset += info_x.pack_buffer_size def _opt_pack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] - ): + self, quantities_x: list[Quantity], quantities_y: list[Quantity] + ) -> None: """Specialized packing for vectors. See async_pack docs for usage.""" if __debug__: if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( @@ -840,9 +844,9 @@ def _opt_pack_vector( def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Unpack the quantities from a single buffer via streamed cuda kernels Reads from self._unpack_buffer using self._x_infos and self._y_infos @@ -862,7 +866,7 @@ def async_unpack( else: raise RuntimeError(f"Unimplemented {self._type} unpack") - def _opt_unpack_scalar(self, quantities: List[Quantity]): + def _opt_unpack_scalar(self, quantities: list[Quantity]) -> None: """Specialized unpacking for scalars. See async_unpack docs for usage.""" if __debug__: if len(quantities) != len(self._infos_x): @@ -917,8 +921,8 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]): offset += info_x._unpack_buffer_size def _opt_unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] - ): + self, quantities_x: list[Quantity], quantities_y: list[Quantity] + ) -> None: """Specialized unpacking for vectors. See async_unpack docs for usage.""" if __debug__: if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( @@ -988,7 +992,7 @@ def _opt_unpack_vector( # Next transformer offset into send buffer offset += edge_size - def finalize(self): + def finalize(self) -> None: super().finalize() # Push the streams back in the pool for cu_info in self._cu_kernel_args.values(): diff --git a/ndsl/halo/rotate.py b/ndsl/halo/rotate.py index 1a7ad7e7..99b618a8 100644 --- a/ndsl/halo/rotate.py +++ b/ndsl/halo/rotate.py @@ -1,7 +1,7 @@ import ndsl.constants as constants -def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): +def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): # type: ignore[no-untyped-def] n_clockwise_rotations = n_clockwise_rotations % 4 if n_clockwise_rotations == 0: pass @@ -34,7 +34,7 @@ def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): return data -def rotate_vector_data(x_data, y_data, n_clockwise_rotations, dims, numpy): +def rotate_vector_data(x_data, y_data, n_clockwise_rotations, dims, numpy): # type: ignore[no-untyped-def] x_data = rotate_scalar_data(x_data, dims, numpy, n_clockwise_rotations) y_data = rotate_scalar_data(y_data, dims, numpy, n_clockwise_rotations) data = [x_data, y_data] diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 76f7608f..40a9e1aa 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -1,10 +1,14 @@ +from __future__ import annotations + from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING import numpy as np import ndsl.constants as constants from ndsl.buffer import Buffer +from ndsl.comm import Comm from ndsl.comm.boundary import Boundary from ndsl.halo.data_transformer import HaloDataTransformer, HaloExchangeSpec from ndsl.halo.rotate import rotate_scalar_data @@ -17,10 +21,10 @@ if TYPE_CHECKING: from ndsl.comm.communicator import Communicator -_HaloSendTuple = Tuple[AsyncRequest, Buffer] -_HaloRecvTuple = Tuple[AsyncRequest, Buffer, np.ndarray] -_HaloRequestSendList = List[_HaloSendTuple] -_HaloRequestRecvList = List[_HaloRecvTuple] +_HaloSendTuple = tuple[AsyncRequest, Buffer] +_HaloRecvTuple = tuple[AsyncRequest, Buffer, np.ndarray] +_HaloRequestSendList = list[_HaloSendTuple] +_HaloRequestRecvList = list[_HaloRecvTuple] TIMER_HALO_EX_KEY = "halo_exchange_global" @@ -43,9 +47,9 @@ class HaloUpdater: def __init__( self, - comm: "Communicator", + comm: Communicator, tag: int, - transformers: Dict[int, HaloDataTransformer], + transformers: dict[int, HaloDataTransformer], timer: Timer, ): """Build the updater. @@ -61,20 +65,20 @@ def __init__( self._tag = tag self._transformers = transformers self._timer = timer - self._recv_requests: List[AsyncRequest] = [] - self._send_requests: List[AsyncRequest] = [] - self._inflight_x_quantities: Optional[Tuple[Quantity, ...]] = None - self._inflight_y_quantities: Optional[Tuple[Quantity, ...]] = None + self._recv_requests: list[AsyncRequest] = [] + self._send_requests: list[AsyncRequest] = [] + self._inflight_x_quantities: tuple[Quantity, ...] | None = None + self._inflight_y_quantities: tuple[Quantity, ...] | None = None self._finalize_on_wait = False - def force_finalize_on_wait(self): + def force_finalize_on_wait(self) -> None: """HaloDataTransformer are finalized after a wait call This is a temporary fix. See DSL-816 which will remove self._finalize_on_wait. """ self._finalize_on_wait = True - def __del__(self): + def __del__(self) -> None: """Clean up all buffers on garbage collection""" if ( self._inflight_x_quantities is not None @@ -90,13 +94,13 @@ def __del__(self): @classmethod def from_scalar_specifications( cls, - comm: "Communicator", + comm: Communicator, numpy_like_module: NumpyModule, specifications: Iterable[QuantityHaloSpec], boundaries: Iterable[Boundary], tag: int, - optional_timer: Optional[Timer] = None, - ) -> "HaloUpdater": + optional_timer: Timer | None = None, + ) -> HaloUpdater: """ Create/retrieve as many packed buffer as needed and queue the slices to exchange. @@ -131,7 +135,7 @@ def from_scalar_specifications( # Create the data transformers to support pack/unpack # One transformer per target rank - transformers: Dict[int, HaloDataTransformer] = {} + transformers: dict[int, HaloDataTransformer] = {} for rank, exchange_specs in exchange_specs_dict.items(): transformers[rank] = HaloDataTransformer.get( numpy_like_module, exchange_specs @@ -142,14 +146,14 @@ def from_scalar_specifications( @classmethod def from_vector_specifications( cls, - comm: "Communicator", + comm: Communicator, numpy_like_module: NumpyModule, specifications_x: Iterable[QuantityHaloSpec], specifications_y: Iterable[QuantityHaloSpec], boundaries: Iterable[Boundary], tag: int, - optional_timer: Optional[Timer] = None, - ) -> "HaloUpdater": + optional_timer: Timer | None = None, + ) -> HaloUpdater: """ Create/retrieve as many packed buffer as needed and queue the slices to exchange. @@ -207,18 +211,18 @@ def from_vector_specifications( def update( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Exchange the data and blocks until finished.""" self.start(quantities_x, quantities_y) self.wait() def start( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, - ): + quantities_x: list[Quantity], + quantities_y: list[Quantity] | None = None, + ) -> None: """Start data exchange.""" self._comm._device_synchronize() @@ -269,7 +273,7 @@ def start( self._timer.stop(TIMER_HALO_EX_KEY) - def wait(self): + def wait(self) -> None: """Finalize data exchange.""" if __debug__ and self._inflight_x_quantities is None: raise RuntimeError('Halo update "wait" call before "start"') @@ -288,7 +292,7 @@ def wait(self): with self._timer.clock("unpack"): for buffer in self._transformers.values(): buffer.async_unpack( - self._inflight_x_quantities, self._inflight_y_quantities + self._inflight_x_quantities, self._inflight_y_quantities # type: ignore[arg-type] ) if self._finalize_on_wait: for transformer in self._transformers.values(): @@ -310,8 +314,8 @@ def __init__( self, send_data: _HaloRequestSendList, recv_data: _HaloRequestRecvList, - timer: Optional[Timer] = None, - ): + timer: Timer | None = None, + ) -> None: """Build a halo request. Args: send_data: a tuple of the MPI request and the buffer sent @@ -323,7 +327,7 @@ def __init__( self._recv_data = recv_data self._timer: Timer = timer if timer is not None else NullTimer() - def wait(self): + def wait(self) -> None: """Wait & unpack data into destination buffers Clean up by inserting back all buffers back in cache for potential reuse @@ -341,7 +345,7 @@ def wait(self): Buffer.push_to_cache(transfer_buffer) -def on_c_grid(x_quantity, y_quantity): +def on_c_grid(x_quantity: Quantity, y_quantity: Quantity) -> bool: if ( constants.X_DIM not in x_quantity.dims or constants.Y_INTERFACE_DIM not in x_quantity.dims @@ -359,11 +363,11 @@ def on_c_grid(x_quantity, y_quantity): class VectorInterfaceHaloUpdater: def __init__( self, - comm, + comm: Comm, boundaries: Mapping[int, Boundary], force_cpu: bool = False, - timer: Optional[Timer] = None, - ): + timer: Timer | None = None, + ) -> None: """Initialize a CubedSphereCommunicator. Args: @@ -416,12 +420,12 @@ def start_synchronize_vector_interfaces( return HaloUpdateRequest(send_requests, recv_requests, self.timer) def _Isend_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_quantity: Quantity, y_quantity: Quantity, tag: int = 0 ) -> _HaloRequestSendList: south_boundary = self.boundaries[constants.SOUTH] west_boundary = self.boundaries[constants.WEST] south_data = x_quantity.view.southwest.sel( - **{ + **{ # type: ignore[arg-type] constants.Y_INTERFACE_DIM: 0, constants.X_DIM: slice( 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] @@ -437,7 +441,7 @@ def _Isend_vector_shared_boundary( if south_boundary.n_clockwise_rotations in (3, 2): south_data = -south_data west_data = y_quantity.view.southwest.sel( - **{ + **{ # type: ignore[arg-type] constants.X_INTERFACE_DIM: 0, constants.Y_DIM: slice( 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] @@ -478,12 +482,12 @@ def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: return module def _Irecv_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_quantity: Quantity, y_quantity: Quantity, tag: int = 0 ) -> _HaloRequestRecvList: north_rank = self.boundaries[constants.NORTH].to_rank east_rank = self.boundaries[constants.EAST].to_rank north_data = x_quantity.view.northwest.sel( - **{ + **{ # type: ignore[arg-type] constants.Y_INTERFACE_DIM: -1, constants.X_DIM: slice( 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] @@ -491,7 +495,7 @@ def _Irecv_vector_shared_boundary( } ) east_data = y_quantity.view.southeast.sel( - **{ + **{ # type: ignore[arg-type] constants.X_INTERFACE_DIM: -1, constants.Y_DIM: slice( 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] @@ -514,7 +518,7 @@ def _Irecv_vector_shared_boundary( ] return recv_requests - def _Isend(self, numpy_module, in_array, **kwargs) -> _HaloSendTuple: + def _Isend(self, numpy_module, in_array, **kwargs) -> _HaloSendTuple: # type: ignore[no-untyped-def] # copy the resulting view in a contiguous array for transfer with self.timer.clock("pack"): buffer = Buffer.pop_from_cache( @@ -526,7 +530,7 @@ def _Isend(self, numpy_module, in_array, **kwargs) -> _HaloSendTuple: request = self.comm.Isend(buffer.array, **kwargs) return (request, buffer) - def _Irecv(self, numpy_module, out_array, **kwargs) -> _HaloRecvTuple: + def _Irecv(self, numpy_module, out_array, **kwargs) -> _HaloRecvTuple: # type: ignore[no-untyped-def] # Prepare a contiguous buffer to receive data with self.timer.clock("Irecv"): buffer = Buffer.pop_from_cache( diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index 8f40c7af..09551c8a 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1 +1,10 @@ -from .sizer import GridSizer +from .grid_sizer import GridSizer # isort: skip +from .allocator import QuantityFactory +from .subtile_grid_sizer import SubtileGridSizer + + +__all__ = [ + "GridSizer", + "QuantityFactory", + "SubtileGridSizer", +] diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 85ee17dd..e9d8857e 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -1,16 +1,20 @@ -from typing import Callable, Optional, Sequence +from __future__ import annotations + +import warnings +from collections.abc import Callable, Sequence +from typing import Any import numpy as np from gt4py import storage as gt_storage from ndsl.constants import SPATIAL_DIMS from ndsl.dsl.typing import Float -from ndsl.initialization.sizer import GridSizer +from ndsl.initialization import GridSizer from ndsl.quantity import Quantity, QuantityHaloSpec class StorageNumpy: - def __init__(self, backend: str): + def __init__(self, backend: str) -> None: """Initialize an object which behaves like the numpy module, but uses gt4py storage objects for zeros, ones, and empty. @@ -19,29 +23,78 @@ def __init__(self, backend: str): """ self.backend = backend - def empty(self, *args, **kwargs) -> np.ndarray: + def empty(self, *args: Any, **kwargs: Any) -> np.ndarray: return gt_storage.empty(*args, backend=self.backend, **kwargs) - def ones(self, *args, **kwargs) -> np.ndarray: + def ones(self, *args: Any, **kwargs: Any) -> np.ndarray: return gt_storage.ones(*args, backend=self.backend, **kwargs) - def zeros(self, *args, **kwargs) -> np.ndarray: + def zeros(self, *args: Any, **kwargs: Any) -> np.ndarray: return gt_storage.zeros(*args, backend=self.backend, **kwargs) class QuantityFactory: - def __init__(self, sizer: GridSizer, numpy): + def __init__( # type: ignore + self, sizer: GridSizer, numpy, *, silence_deprecation_warning: bool = False + ) -> None: + if not silence_deprecation_warning: + warnings.warn( + "Usage of QuantityFactory(sizer, numpy) is discouraged and will change " + "in the next release. Use QuantityFactory.from_backend(sizer, backend) " + "instead for a stable experience across the release.", + DeprecationWarning, + 2, + ) self.sizer: GridSizer = sizer self._numpy = numpy - def set_extra_dim_lengths(self, **kwargs): + def set_extra_dim_lengths(self, **kwargs: Any) -> None: """ Set the length of extra (non-x/y/z) dimensions. """ - self.sizer.extra_dim_lengths.update(kwargs) + warnings.warn( + "`QuantityFactory.set_extra_dim_lengths` is deprecated. " + "Use `add_data_dimensions` or `update_data_dimensions`.", + DeprecationWarning, + 2, + ) + self.sizer.data_dimensions.update(kwargs) + + def update_data_dimensions( + self, + data_dimension_descriptions: dict[str, int], + ) -> None: + """ + Update the length of data (non-x/y/z) dimensions, unknown data dimensions + will be added, existing ones updated. + + Args: + data_dimension_descriptions: Dict of name/length pairs + """ + self.sizer.data_dimensions.update(data_dimension_descriptions) + + def add_data_dimensions( + self, + data_dimension_descriptions: dict[str, int], + ) -> None: + """ + Add new data (non-x/y/z) dimensions via a key-length pair. If the dimension + already exists, it will error out. + + Args: + data_dimension_descriptions: Dict of name/length pairs + """ + for name in data_dimension_descriptions.keys(): + if name in self.sizer.data_dimensions.keys(): + raise ValueError( + f"[NDSL] Data dimension {name} already exists! " + "Use `update_data_dimensions` if you meant to update the length." + ) + + self.sizer.data_dimensions.update(data_dimension_descriptions) @classmethod - def from_backend(cls, sizer: GridSizer, backend: str): + def from_backend(cls, sizer: GridSizer, backend: str) -> QuantityFactory: """Initialize a QuantityFactory to use a specific gt4py backend. Args: @@ -49,23 +102,32 @@ def from_backend(cls, sizer: GridSizer, backend: str): backend: gt4py backend """ numpy = StorageNumpy(backend) - return cls(sizer, numpy) + # Don't print the deprecation warning in this case + return cls(sizer, numpy, silence_deprecation_warning=True) - def _backend(self) -> Optional[str]: - try: + def _backend(self) -> str | None: + if isinstance(self._numpy, StorageNumpy): return self._numpy.backend - except AttributeError: - return None + + return None def empty( self, dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: + """Allocate a Quantity - values are random. + + Equivalent to `numpy.empty`""" return self._allocate( - self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision + self._numpy.empty, + dims, + units, + dtype, + allow_mismatch_float_precision, ) def zeros( @@ -73,10 +135,18 @@ def zeros( dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: + """Allocate a Quantity and fill it with the value 0. + + Equivalent to `numpy.zeros`""" return self._allocate( - self._numpy.zeros, dims, units, dtype, allow_mismatch_float_precision + self._numpy.zeros, + dims, + units, + dtype, + allow_mismatch_float_precision, ) def ones( @@ -84,19 +154,50 @@ def ones( dims: Sequence[str], units: str, dtype: type = Float, + *, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: + """Allocate a Quantity and fill it with the value 1. + + Equivalent to `numpy.ones`""" return self._allocate( - self._numpy.ones, dims, units, dtype, allow_mismatch_float_precision + self._numpy.ones, + dims, + units, + dtype, + allow_mismatch_float_precision, ) + def full( + self, + dims: Sequence[str], + units: str, + value: Any, # no type hint because it would be a TypeVar = type[dtype] and mypy says no + dtype: type = Float, + *, + allow_mismatch_float_precision: bool = False, + ) -> Quantity: + """Allocate a Quantity and fill it with the value. + + Equivalent to `numpy.full`""" + quantity = self._allocate( + self._numpy.empty, + dims, + units, + dtype, + allow_mismatch_float_precision, + ) + quantity.data[:] = value + return quantity + def from_array( self, data: np.ndarray, dims: Sequence[str], units: str, + *, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: """ Create a Quantity from a numpy array. @@ -117,8 +218,9 @@ def from_compute_array( data: np.ndarray, dims: Sequence[str], units: str, + *, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: """ Create a Quantity from a numpy array. @@ -141,14 +243,16 @@ def _allocate( units: str, dtype: type = Float, allow_mismatch_float_precision: bool = False, - ): + ) -> Quantity: origin = self.sizer.get_origin(dims) extent = self.sizer.get_extent(dims) shape = self.sizer.get_shape(dims) dimensions = [ - axis - if any(dim in axis_dims for axis_dims in SPATIAL_DIMS) - else str(shape[index]) + ( + axis + if any(dim in axis_dims for axis_dims in SPATIAL_DIMS) + else str(shape[index]) + ) for index, (dim, axis) in enumerate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) @@ -167,12 +271,13 @@ def _allocate( extent=extent, gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, + number_of_halo_points=self.sizer.n_halo, ) def get_quantity_halo_spec( self, dims: Sequence[str], - n_halo: Optional[int] = None, + n_halo: int | None = None, dtype: type = Float, ) -> QuantityHaloSpec: """Build memory specifications for the halo update. diff --git a/ndsl/initialization/grid_sizer.py b/ndsl/initialization/grid_sizer.py new file mode 100644 index 00000000..961ab793 --- /dev/null +++ b/ndsl/initialization/grid_sizer.py @@ -0,0 +1,35 @@ +import warnings +from collections.abc import Sequence +from dataclasses import dataclass + + +@dataclass +class GridSizer: + nx: int + """Length of the x compute dimension for produced arrays.""" + ny: int + """Length of the y compute dimension for produced arrays.""" + nz: int + """Length of the z compute dimension for produced arrays.""" + n_halo: int + """Number of horizontal halo points for produced arrays.""" + data_dimensions: dict[str, int] + """Name/Lengths pair of any non-x/y/z dimensions, such as land or radiation dimensions.""" + + @property + def extra_dim_lengths(self) -> dict[str, int]: + warnings.warn( + "`GridSizer.extra_dim_lengths` is a deprecated API, use `GridSizer.data_dimensions`.", + DeprecationWarning, + 2, + ) + return self.data_dimensions + + def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: + raise NotImplementedError() + + def get_extent(self, dims: Sequence[str]) -> tuple[int, ...]: + raise NotImplementedError() + + def get_shape(self, dims: Sequence[str]) -> tuple[int, ...]: + raise NotImplementedError() diff --git a/ndsl/initialization/sizer.py b/ndsl/initialization/subtile_grid_sizer.py similarity index 72% rename from ndsl/initialization/sizer.py rename to ndsl/initialization/subtile_grid_sizer.py index 8ad3d196..ff31a1c3 100644 --- a/ndsl/initialization/sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -1,32 +1,11 @@ -import dataclasses -from typing import Dict, Iterable, Sequence, Tuple +import warnings +from collections.abc import Iterable +from typing import Self import ndsl.constants as constants from ndsl.comm.partitioner import TilePartitioner from ndsl.constants import N_HALO_DEFAULT - - -@dataclasses.dataclass -class GridSizer: - nx: int - """length of the x compute dimension for produced arrays""" - ny: int - """length of the y compute dimension for produced arrays""" - nz: int - """length of the z compute dimension for produced arrays""" - n_halo: int - """number of horizontal halo points for produced arrays""" - extra_dim_lengths: Dict[str, int] - """lengths of any non-x/y/z dimensions, such as land or radiation dimensions""" - - def get_origin(self, dims: Sequence[str]) -> Tuple[int, ...]: - raise NotImplementedError() - - def get_extent(self, dims: Sequence[str]) -> Tuple[int, ...]: - raise NotImplementedError() - - def get_shape(self, dims: Sequence[str]) -> Tuple[int, ...]: - raise NotImplementedError() +from ndsl.initialization.grid_sizer import GridSizer class SubtileGridSizer(GridSizer): @@ -37,11 +16,13 @@ def from_tile_params( ny_tile: int, nz: int, n_halo: int, - extra_dim_lengths: Dict[str, int], - layout: Tuple[int, int], - tile_partitioner: TilePartitioner = None, + layout: tuple[int, int], + *, + data_dimensions: dict[str, int] | None = None, + tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, - ): + extra_dim_lengths: dict[str, int] | None = None, + ) -> Self: """Create a SubtileGridSizer from parameters about the full tile. Args: @@ -49,13 +30,24 @@ def from_tile_params( ny_tile: number of y cell centers on the tile nz: number of vertical levels n_halo: number of halo points - extra_dim_lengths: lengths of any non-x/y/z dimensions, + data_dimensions: lengths of any non-x/y/z dimensions, such as land or radiation dimensions layout: (y, x) number of ranks along tile edges tile_partitioner (optional): partitioner object for the tile. By default, a TilePartitioner is created with the given layout tile_rank (optional): rank of this subtile. + extra_dim_lengths: DEPRECATED API - use `data_dimensions` """ + if data_dimensions is None: + data_dimensions = {} + + if extra_dim_lengths is not None: + warnings.warn( + "`extra_dim_lengths` is a deprecated name, please use `data_dimensions` instead.", + DeprecationWarning, + 2, + ) + data_dimensions = extra_dim_lengths if tile_partitioner is None: tile_partitioner = TilePartitioner(layout) y_slice, x_slice = tile_partitioner.subtile_slice( @@ -77,15 +69,15 @@ def from_tile_params( "SubtileGridSizer::from_tile_params: Compute domain extent must be greater than halo size" ) - return cls(nx, ny, nz, n_halo, extra_dim_lengths) + return cls(nx, ny, nz, n_halo, data_dimensions) @classmethod def from_namelist( cls, namelist: dict, - tile_partitioner: TilePartitioner = None, + tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, - ): + ) -> Self: """Create a SubtileGridSizer from a Fortran namelist. Args: @@ -114,19 +106,18 @@ def from_namelist( "expected to find nx_tile or fv_core_nml" ) return cls.from_tile_params( - nx_tile, - ny_tile, - nz, - N_HALO_DEFAULT, - {}, - layout, - tile_partitioner, - tile_rank, + nx_tile=nx_tile, + ny_tile=ny_tile, + nz=nz, + n_halo=N_HALO_DEFAULT, + layout=layout, + tile_partitioner=tile_partitioner, + tile_rank=tile_rank, ) @property - def dim_extents(self) -> Dict[str, int]: - return_dict = self.extra_dim_lengths.copy() + def dim_extents(self) -> dict[str, int]: + return_dict = self.data_dimensions.copy() return_dict.update( { constants.X_DIM: self.nx, @@ -139,18 +130,18 @@ def dim_extents(self) -> Dict[str, int]: ) return return_dict - def get_origin(self, dims: Iterable[str]) -> Tuple[int, ...]: + def get_origin(self, dims: Iterable[str]) -> tuple[int, ...]: return_list = [ self.n_halo if dim in constants.HORIZONTAL_DIMS else 0 for dim in dims ] return tuple(return_list) - def get_extent(self, dims: Iterable[str]) -> Tuple[int, ...]: + def get_extent(self, dims: Iterable[str]) -> tuple[int, ...]: extents = self.dim_extents return tuple(extents[dim] for dim in dims) - def get_shape(self, dims: Iterable[str]) -> Tuple[int, ...]: - shape_dict = self.extra_dim_lengths.copy() + def get_shape(self, dims: Iterable[str]) -> tuple[int, ...]: + shape_dict = self.data_dimensions.copy() # must pad non-interface variables to have the same shape as interface variables shape_dict.update( { diff --git a/ndsl/io.py b/ndsl/io.py index 0aa2b7d4..f4b387ac 100644 --- a/ndsl/io.py +++ b/ndsl/io.py @@ -3,7 +3,6 @@ import cftime import xarray as xr -import ndsl.filesystem as filesystem from ndsl.quantity import Quantity @@ -20,7 +19,7 @@ } -def to_xarray_dataset(state) -> xr.Dataset: +def to_xarray_dataset(state: dict) -> xr.Dataset: data_vars = { name: value.data_as_xarray for name, value in state.items() if name != "time" } @@ -39,7 +38,7 @@ def write_state(state: dict, filename: str) -> None: if "time" not in state: raise ValueError('state must include a value for "time"') ds = to_xarray_dataset(state) - with filesystem.open(filename, "wb") as f: + with open(filename, "wb") as f: ds.to_netcdf(f) @@ -68,7 +67,7 @@ def read_state(filename: str) -> dict: state: a model state dictionary """ out_dict = {} - with filesystem.open(filename, "rb") as f: + with open(filename, "rb") as f: time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) ds = xr.open_dataset(f, decode_times=time_coder) for name, value in ds.data_vars.items(): @@ -79,7 +78,7 @@ def read_state(filename: str) -> dict: return out_dict -def _get_integer_tokens(line, n_tokens): +def _get_integer_tokens(line: str, n_tokens: int) -> list[int]: all_tokens = line.split() return [int(token) for token in all_tokens[:n_tokens]] diff --git a/ndsl/logging.py b/ndsl/logging.py index 73b7979c..183054b6 100644 --- a/ndsl/logging.py +++ b/ndsl/logging.py @@ -5,11 +5,9 @@ import sys from typing import Annotated -from mpi4py import MPI +from ndsl.comm.mpi import MPI -LOGLEVEL = os.environ.get("PACE_LOGLEVEL", "INFO").upper() - # Python log levels are hierarchical, therefore setting INFO # means DEBUG and everything lower will be logged. AVAILABLE_LOG_LEVELS = { @@ -21,12 +19,33 @@ } +def _get_log_level(default: str = "info") -> str: + if os.getenv("PACE_LOGLEVEL", ""): + logging.warning("PACE_LOGLEVEL is deprecated. Use NDSL_LOGLEVEL instead.") + if os.getenv("NDSL_LOGLEVEL", ""): + logging.warning( + "PACE_LOGLEVEL and NDSL_LOGLEVEL were both specified. NDSL_LOGLEVEL will take precedence." + ) + + loglevel = os.getenv("NDSL_LOGLEVEL", os.getenv("PACE_LOGLEVEL", default)).lower() + + if loglevel in AVAILABLE_LOG_LEVELS.keys(): + return loglevel + + logging.warning( + f"Unknown log level '{loglevel}', falling back to '{default}'. Valid values are: {AVAILABLE_LOG_LEVELS.keys()}." + ) + return default + + def _ndsl_logger() -> logging.Logger: + log_level = _get_log_level() + name_log = logging.getLogger(__name__) - name_log.setLevel(LOGLEVEL) + name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level]) handler = logging.StreamHandler(sys.stdout) - handler.setLevel(LOGLEVEL) + handler.setLevel(AVAILABLE_LOG_LEVELS[log_level]) formatter = logging.Formatter( fmt=( f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|" @@ -40,14 +59,16 @@ def _ndsl_logger() -> logging.Logger: def _ndsl_logger_on_rank_0() -> logging.Logger: + log_level = _get_log_level() + name_log = logging.getLogger(f"{__name__}_on_rank_0") - name_log.setLevel(LOGLEVEL) + name_log.setLevel(AVAILABLE_LOG_LEVELS[log_level]) rank = MPI.COMM_WORLD.Get_rank() if rank == 0: handler = logging.StreamHandler(sys.stdout) - handler.setLevel(LOGLEVEL) + handler.setLevel(AVAILABLE_LOG_LEVELS[log_level]) formatter = logging.Formatter( fmt=( f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|" @@ -62,9 +83,9 @@ def _ndsl_logger_on_rank_0() -> logging.Logger: return name_log -ndsl_log: Annotated[ - logging.Logger, "NDSL Python logger, logs on all rank" -] = _ndsl_logger() +ndsl_log: Annotated[logging.Logger, "NDSL Python logger, logs on all rank"] = ( + _ndsl_logger() +) ndsl_log_on_rank_0: Annotated[ logging.Logger, "NDSL Python logger, logs on rank 0 only" diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index 5d732315..e3f5037f 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,2 +1,8 @@ from .protocol import Monitor from .zarr_monitor import ZarrMonitor + + +__all__ = [ + "Monitor", + "ZarrMonitor", +] diff --git a/ndsl/monitor/convert.py b/ndsl/monitor/convert.py index a62af01a..48323bd9 100644 --- a/ndsl/monitor/convert.py +++ b/ndsl/monitor/convert.py @@ -3,7 +3,7 @@ from ndsl.optional_imports import cupy -def to_numpy(array, dtype=None) -> np.ndarray: +def to_numpy(array, dtype=None) -> np.ndarray: # type: ignore[no-untyped-def] """ Input array can be a numpy array or a cupy array. Returns numpy array. """ diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 945483d6..e2cb417f 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -1,15 +1,12 @@ import os from pathlib import Path -from typing import Any, Dict, List, Optional, Set from warnings import warn -import fsspec import numpy as np import xarray as xr from ndsl.comm.communicator import Communicator from ndsl.dsl.typing import Float, get_precision -from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.quantity import Quantity @@ -25,7 +22,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int): self._units = initial.units self._i_time = 1 - def append(self, quantity: Quantity): + def append(self, quantity: Quantity) -> None: # Allow mismatch precision here since this is I/O self._data[self._i_time, ...] = to_numpy( quantity.transpose(self._dims, allow_mismatch_float_precision=True).view[:] @@ -46,17 +43,14 @@ def data(self) -> Quantity: class _ChunkedNetCDFWriter: FILENAME_FORMAT = "state_{chunk:04d}_tile{tile}.nc" - def __init__( - self, path: str, tile: int, fs: fsspec.AbstractFileSystem, time_chunk_size: int - ): + def __init__(self, path: str, tile: int, time_chunk_size: int) -> None: self._path = path self._tile = tile - self._fs = fs self._time_chunk_size = time_chunk_size self._i_time = 0 - self._chunked: Optional[Dict[str, _TimeChunkedVariable]] = None - self._times: List[Any] = [] - self._time_units: Optional[str] = None + self._chunked: dict[str, _TimeChunkedVariable] | None = None + self._times: list = [] + self._time_units: str | None = None def append(self, state): ndsl_log.debug("appending at time %d", self._i_time) @@ -111,13 +105,13 @@ class NetCDFMonitor: _CONSTANT_FILENAME = "constants" - def __init__( + def __init__( # type: ignore[no-untyped-def] self, path: str, communicator: Communicator, time_chunk_size: int = 1, precision=Float, - ): + ) -> None: """Create a NetCDFMonitor. Args: @@ -128,11 +122,10 @@ def __init__( rank = communicator.rank self._tile_index = communicator.partitioner.tile_index(rank) self._path = path - self._fs = get_fs(path) self._communicator = communicator self._time_chunk_size = time_chunk_size - self.__writer: Optional[_ChunkedNetCDFWriter] = None - self._expected_vars: Optional[Set[str]] = None + self.__writer: _ChunkedNetCDFWriter | None = None + self._expected_vars: set[str] | None = None self._transfer_type = precision if self._transfer_type == np.float32 and get_precision() > 32: warn( @@ -145,7 +138,6 @@ def _writer(self): self.__writer = _ChunkedNetCDFWriter( path=self._path, tile=self._tile_index, - fs=self._fs, time_chunk_size=self._time_chunk_size, ) return self.__writer @@ -178,7 +170,7 @@ def store(self, state: dict) -> None: if state is not None: # we are on root rank self._writer.append(state) - def store_constant(self, state: Dict[str, Quantity]) -> None: + def store_constant(self, state: dict[str, Quantity]) -> None: state = self._communicator.gather_state( state, transfer_type=self._transfer_type ) @@ -189,7 +181,7 @@ def store_constant(self, state: Dict[str, Quantity]) -> None: for name, quantity in state.items(): path_for_grid = constants_filename + "_" + name + ".nc" - if self._fs.exists(path_for_grid): + if os.path.exists(path_for_grid): ds = xr.open_dataset(path_for_grid) ds = ds.load() ds[name] = xr.DataArray( diff --git a/ndsl/monitor/protocol.py b/ndsl/monitor/protocol.py index e1448044..3e574bf1 100644 --- a/ndsl/monitor/protocol.py +++ b/ndsl/monitor/protocol.py @@ -1,4 +1,4 @@ -from typing import Dict, Protocol +from typing import Protocol from ndsl.quantity import Quantity @@ -12,8 +12,6 @@ def store(self, state: dict) -> None: """Append the model state dictionary to the stored data.""" ... - def store_constant(self, state: Dict[str, Quantity]) -> None: - ... + def store_constant(self, state: dict[str, Quantity]) -> None: ... - def cleanup(self): - ... + def cleanup(self) -> None: ... diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 20d5db7a..99deec2b 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from datetime import datetime, timedelta -from typing import List, Tuple, Union +from typing import TypeVar import cftime import xarray as xr @@ -14,22 +16,24 @@ __all__ = ["ZarrMonitor"] +T = TypeVar("T") + class DummyComm: - def Get_rank(self): + def Get_rank(self) -> int: return 0 - def Get_size(self): + def Get_size(self) -> int: return 1 - def bcast(self, value, root=0): + def bcast(self, value: T, root: int = 0) -> T: assert root == 0, ( "DummyComm should only be used on a single core, " "so root should only ever be 0" ) return value - def barrier(self): + def barrier(self) -> None: return @@ -40,11 +44,11 @@ class ZarrMonitor: def __init__( self, - store: Union[str, "zarr.storage.MutableMapping"], + store: str | zarr.storage.MutableMapping, partitioner: Partitioner, mode: str = "w", - mpi_comm=DummyComm(), - ): + mpi_comm: DummyComm | None = None, + ) -> None: """Create a ZarrMonitor. Args: @@ -54,6 +58,9 @@ def __init__( mpi_comm: mpi4py comm object to use for communications. By default, will use a dummy comm object that works in single-core mode. """ + if mpi_comm is None: + mpi_comm = DummyComm() + if mpi_comm.Get_rank() == 0: group = zarr.open_group(store, mode=mode) else: @@ -61,7 +68,7 @@ def __init__( self._group = mpi_comm.bcast(group) self._comm = mpi_comm self._writers = None - self._constants: List[str] = [] + self._constants: list[str] = [] self.partitioner = partitioner def _init_writers(self, state): @@ -127,10 +134,10 @@ def store_constant(self, state: dict) -> None: name=name, partitioner=self.partitioner, ) - constant_writer.append(quantity) # type: ignore[index] + constant_writer.append(quantity) self._constants.append(name) - def cleanup(self): + def cleanup(self) -> None: pass @@ -182,7 +189,7 @@ def _init_zarr_root(self, quantity): fill_value=None, ) - def sync_array(self): + def sync_array(self) -> None: self.array = self.comm.bcast(self.array, root=0) def _match_dim_order(self, quantity): @@ -263,10 +270,10 @@ def _check_units(self, new_quantity): def array_chunks( - layout: Tuple[int, int], - tile_array_shape: Tuple[int, ...], - array_dims: Tuple[str, ...], -): + layout: tuple[int, int], + tile_array_shape: tuple[int, ...], + array_dims: tuple[str, ...], +) -> tuple: layout_by_dims = list_by_dims(array_dims, layout, 1) chunks_list = [] for extent, dim, n_ranks in zip(tile_array_shape, array_dims, layout_by_dims): @@ -379,7 +386,7 @@ def append(self, time): self.comm.barrier() -def get_calendar(time: Union[datetime, timedelta, cftime.datetime]): +def get_calendar(time: datetime | timedelta | cftime.datetime) -> str: try: return time.calendar # type: ignore except AttributeError: diff --git a/ndsl/namelist.py b/ndsl/namelist.py index e040c2cc..00180ace 100644 --- a/ndsl/namelist.py +++ b/ndsl/namelist.py @@ -1,5 +1,6 @@ import dataclasses -from typing import Tuple +import warnings +from typing import Any, Self import f90nml @@ -234,7 +235,7 @@ class NamelistDefaults: """flag for turning on shallow water conditions in dyn core""" @classmethod - def as_dict(cls): + def as_dict(cls) -> dict: return { name: default for name, default in cls.__dict__.items() @@ -264,17 +265,17 @@ class Namelist: """ dycore_only: bool = DEFAULT_BOOL # fdiag: float - # knob_ugwp_azdir: Tuple[int, int, int, int] + # knob_ugwp_azdir: tuple[int, int, int, int] # knob_ugwp_doaxyz: int # knob_ugwp_doheat: int # knob_ugwp_dokdis: int - # knob_ugwp_effac: Tuple[int, int, int, int] + # knob_ugwp_effac: tuple[int, int, int, int] # knob_ugwp_ndx4lh: int # knob_ugwp_solver: int - # knob_ugwp_source: Tuple[int, int, int, int] - # knob_ugwp_stoch: Tuple[int, int, int, int] + # knob_ugwp_source: tuple[int, int, int, int] + # knob_ugwp_stoch: tuple[int, int, int, int] # knob_ugwp_version: int - # knob_ugwp_wvspec: Tuple[int, int, int, int] + # knob_ugwp_wvspec: tuple[int, int, int, int] # launch_level: int # reiflag: int # reimax: float @@ -344,7 +345,7 @@ class Namelist: kord_tm: int = DEFAULT_INT kord_tr: int = DEFAULT_INT kord_wz: int = DEFAULT_INT - layout: Tuple[int, int] = (1, 1) + layout: tuple[int, int] = (1, 1) # make_nh: bool # mountain: bool n_split: int = DEFAULT_INT @@ -608,17 +609,26 @@ class Namelist: """Flag to replace cosz with daily mean value in physics""" @classmethod - def from_f90nml(cls, namelist: f90nml.Namelist): + def from_f90nml(cls, namelist: f90nml.Namelist) -> Self: namelist_dict = namelist_to_flatish_dict(namelist.items()) namelist_dict = { key: value for key, value in namelist_dict.items() - if key in cls.__dataclass_fields__ # type: ignore + if key in cls.__dataclass_fields__ } return cls(**namelist_dict) + def __post_init__(self) -> None: + warnings.warn( + "Usage of `ndsl.Namelist` is discouraged. The class will be " + "removed in the next version together with `NamelistDefaults`, see " + "https://github.com/NOAA-GFDL/NDSL/issues/64.", + DeprecationWarning, + stacklevel=2, + ) -def namelist_to_flatish_dict(nml_input): + +def namelist_to_flatish_dict(nml_input: Any) -> dict: nml = dict(nml_input) for name, value in nml.items(): if isinstance(value, f90nml.Namelist): diff --git a/ndsl/optional_imports.py b/ndsl/optional_imports.py index d1079fb7..4cc81ea6 100644 --- a/ndsl/optional_imports.py +++ b/ndsl/optional_imports.py @@ -1,11 +1,14 @@ +from typing import Any + + class RaiseWhenAccessed: - def __init__(self, err): + def __init__(self, err: ModuleNotFoundError) -> None: self._err = err - def __getattr__(self, _): + def __getattr__(self, _: Any) -> None: raise self._err - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: dict) -> None: raise self._err diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index 28e03bc6..359bc761 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,2 +1,5 @@ from .config import PerformanceConfig from .timer import NullTimer, Timer + + +__all__ = ["PerformanceConfig", "NullTimer", "Timer"] diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index 8ec7a817..c6ec0d97 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -2,7 +2,7 @@ import os.path import subprocess from collections.abc import Mapping -from typing import List, Protocol +from typing import Protocol import numpy as np @@ -25,52 +25,49 @@ class AbstractPerformanceCollector(Protocol): total_timer: Timer timestep_timer: Timer - def collect_performance(self): - ... + def collect_performance(self) -> None: ... def write_out_performance( self, backend: str, is_orchestrated: bool, dt_atmos: float, - ): - ... + ) -> None: ... def write_out_rank_0( self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str - ): - ... + ) -> None: ... @classmethod - def start_cuda_profiler(cls): + def start_cuda_profiler(cls) -> None: if GPU_AVAILABLE: cp.cuda.profiler.start() @classmethod - def stop_cuda_profiler(cls): + def stop_cuda_profiler(cls) -> None: if GPU_AVAILABLE: cp.cuda.profiler.stop() @classmethod - def mark_cuda_profiler(cls, message: str): + def mark_cuda_profiler(cls, message: str) -> None: if GPU_AVAILABLE: cp.cuda.nvtx.Mark(message) class PerformanceCollector(AbstractPerformanceCollector): - def __init__(self, experiment_name: str, comm: Comm): - self.times_per_step: List[Mapping[str, float]] = [] - self.hits_per_step: List[Mapping[str, int]] = [] + def __init__(self, experiment_name: str, comm: Comm) -> None: + self.times_per_step: list[Mapping[str, float]] = [] + self.hits_per_step: list[Mapping[str, int]] = [] self.timestep_timer = Timer() self.total_timer = Timer() self.experiment_name = experiment_name self.comm = comm - def clear(self): + def clear(self) -> None: self.times_per_step = [] self.hits_per_step = [] - def collect_performance(self): + def collect_performance(self) -> None: """ Take the accumulated timings and flush them into a new entry in times_per_step and hits_per_step. @@ -81,13 +78,13 @@ def collect_performance(self): def write_out_rank_0( self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str - ): + ) -> None: if self.comm.Get_rank() == 0: git_hash = "None" while {} in self.hits_per_step: self.hits_per_step.remove({}) keys = collect_keys_from_data(self.times_per_step) - data: List[float] = [] + data: list[float] = [] timing_info = {} for timer_name in keys: data.clear() @@ -120,7 +117,7 @@ def write_out_performance( backend: str, is_orchestrated: bool, dt_atmos: float, - ): + ) -> None: if self.comm.Get_rank() == 0: try: driver_path = os.path.dirname(__file__) @@ -146,7 +143,7 @@ def write_out_performance( len(self.hits_per_step) - 1, backend, is_orchestrated, - git_hash, + git_hash, # type: ignore[arg-type] self.comm, self.hits_per_step, self.times_per_step, @@ -156,11 +153,11 @@ def write_out_performance( class NullPerformanceCollector(AbstractPerformanceCollector): - def __init__(self): + def __init__(self) -> None: self.total_timer = NullTimer() self.timestep_timer = NullTimer() - def collect_performance(self): + def collect_performance(self) -> None: pass def write_out_performance( @@ -168,10 +165,10 @@ def write_out_performance( backend: str, is_orchestrated: bool, dt_atmos: float, - ): + ) -> None: pass def write_out_rank_0( self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str - ): + ) -> None: pass diff --git a/ndsl/performance/config.py b/ndsl/performance/config.py index 99e6109f..2f789429 100644 --- a/ndsl/performance/config.py +++ b/ndsl/performance/config.py @@ -33,8 +33,8 @@ def build(self, comm: Comm) -> AbstractPerformanceCollector: else: return NullPerformanceCollector() - def build_profiler(self): + def build_profiler(self) -> Profiler | NullProfiler: if self.collect_cProfile: return Profiler() - else: - return NullProfiler() + + return NullProfiler() diff --git a/ndsl/performance/profiler.py b/ndsl/performance/profiler.py index 3f4f4575..7573587a 100644 --- a/ndsl/performance/profiler.py +++ b/ndsl/performance/profiler.py @@ -2,15 +2,15 @@ class Profiler: - def __init__(self): + def __init__(self) -> None: self._enabled = True self.profiler = cProfile.Profile() self.profiler.disable() - def enable(self): + def enable(self) -> None: self.profiler.enable() - def dump_stats(self, filename: str): + def dump_stats(self, filename: str) -> None: self.profiler.disable() self._enabled = False self.profiler.dump_stats(filename) @@ -27,14 +27,14 @@ class NullProfiler: Meant to be used in place of an optional profiler. """ - def __init__(self): + def __init__(self) -> None: self.profiler = None self._enabled = False - def enable(self): + def enable(self) -> None: pass - def dump_stats(self, filename: str): + def dump_stats(self, filename: str) -> None: pass @property diff --git a/ndsl/performance/report.py b/ndsl/performance/report.py index dbd486d2..a1ccd1da 100644 --- a/ndsl/performance/report.py +++ b/ndsl/performance/report.py @@ -1,8 +1,9 @@ import copy import dataclasses import json +from collections.abc import Mapping from datetime import datetime -from typing import Any, Dict, List, Mapping +from typing import Any import numpy as np @@ -33,7 +34,7 @@ class Report: sim_status: str = "Finished" SYPD: float = 0.0 - def __post_init__(self): + def __post_init__(self) -> None: self.SYPD = get_sypd(self.times, self.dt_atmos) @@ -56,7 +57,7 @@ def get_experiment_info( return experiment -def collect_keys_from_data(times_per_step: List[Mapping[str, float]]) -> List[str]: +def collect_keys_from_data(times_per_step: list[Mapping[str, float]]) -> list[str]: """Collects all the keys in the list of dicts and returns a sorted version""" keys = set() for data_point in times_per_step: @@ -68,15 +69,15 @@ def collect_keys_from_data(times_per_step: List[Mapping[str, float]]) -> List[st def gather_timing_data( - times_per_step: List[Mapping[str, float]], - comm, + times_per_step: list[Mapping[str, float]], + comm: Comm, root: int = 0, -) -> Dict[str, Any]: +) -> dict[str, Any]: """returns an updated version of the results dictionary owned by the root node to hold data on the substeps as well as the main loop timers""" is_root = comm.Get_rank() == root keys = collect_keys_from_data(times_per_step) - data: List[float] = [] + data: list[float] = [] timing_info = {} for timer_name in keys: data.clear() @@ -91,7 +92,7 @@ def gather_timing_data( comm.Gather(sendbuf, recvbuf, root=0) if is_root: timing_info[timer_name] = TimeReport( - hits=0, times=copy.deepcopy(recvbuf.tolist()) + hits=0, times=copy.deepcopy(recvbuf.tolist()) # type: ignore[union-attr] # (recvbuf is defined on root rank) ) return timing_info @@ -104,8 +105,8 @@ def write_to_timestamped_json(experiment: Report) -> None: def gather_hit_counts( - hits_per_step: List[Mapping[str, int]], timing_info: Dict[str, TimeReport] -) -> Dict[str, TimeReport]: + hits_per_step: list[Mapping[str, int]], timing_info: dict[str, TimeReport] +) -> dict[str, TimeReport]: """collects the hit count across all timers called in a program execution""" for data_point in hits_per_step: for name, value in data_point.items(): @@ -113,7 +114,7 @@ def gather_hit_counts( return timing_info -def get_sypd(timing_info: Dict[str, TimeReport], dt_atmos: float) -> float: +def get_sypd(timing_info: dict[str, TimeReport], dt_atmos: float) -> float: if "mainloop" in timing_info: is_list_of_list = any( isinstance(el, list) for el in timing_info["mainloop"].times @@ -135,8 +136,8 @@ def collect_data_and_write_to_file( is_orchestrated: bool, git_hash: str, comm: Comm, - hits_per_step: List, - times_per_step: List, + hits_per_step: list, + times_per_step: list, experiment_name: str, dt_atmos: float, ) -> None: diff --git a/ndsl/performance/timer.py b/ndsl/performance/timer.py index e4629923..7d5a08ee 100644 --- a/ndsl/performance/timer.py +++ b/ndsl/performance/timer.py @@ -1,6 +1,6 @@ import warnings +from collections.abc import Mapping from timeit import default_timer as time -from typing import Mapping from ndsl.optional_imports import cupy as cp from ndsl.utils import GPU_AVAILABLE @@ -9,16 +9,16 @@ class Timer: """Class to accumulate timings for named operations.""" - def __init__(self): - self._clock_starts = {} - self._accumulated_time = {} - self._hit_count = {} - self._enabled = True + def __init__(self) -> None: + self._clock_starts: dict = {} + self._accumulated_time: dict = {} + self._hit_count: dict = {} + self._enabled: bool = True # Check if we have CUDA device and it's ready to # perform tasks - self._can_time_CUDA = GPU_AVAILABLE + self._can_time_CUDA: bool = GPU_AVAILABLE - def start(self, name: str): + def start(self, name: str) -> None: """Start timing a given named operation.""" if self._can_time_CUDA: cp.cuda.Device(0).synchronize() @@ -29,7 +29,7 @@ def start(self, name: str): else: self._clock_starts[name] = time() - def stop(self, name: str): + def stop(self, name: str) -> None: """Stop timing a given named operation, add the time elapsed to accumulated timing and increase the hit count. """ @@ -46,7 +46,7 @@ def stop(self, name: str): else: self._hit_count[name] += 1 - def clock(self, name: str): + def clock(self, name: str): # type: ignore """Context manager to produce timings of operations. Args: @@ -70,20 +70,20 @@ def clock(self, name: str): # which self-destroys itself when called, we can't orchestrate # it easily in DaCe. Waiting for a fix DaCe side to this Python # ridiculousness (see contelib.py:_GeneratorContextManager.__enter__) - def dace_inhibitor(func): + def dace_inhibitor(func): # type: ignore[no-untyped-def] return func class Wrapper: - def __init__(self, timer, name) -> None: + def __init__(self, timer: Timer, name: str) -> None: self.timer = timer self.name = name @dace_inhibitor - def __enter__(self): + def __enter__(self) -> None: self.timer.start(name) @dace_inhibitor - def __exit__(self, type, value, traceback): + def __exit__(self, type, value, traceback): # type: ignore[no-untyped-def] self.timer.stop(name) return Wrapper(self, name) @@ -97,6 +97,7 @@ def times(self) -> Mapping[str, float]: "incomplete times are not included: " f"{list(self._clock_starts.keys())}", RuntimeWarning, + stacklevel=2, ) return self._accumulated_time.copy() @@ -109,19 +110,20 @@ def hits(self) -> Mapping[str, int]: "incomplete times are not included: " f"{list(self._clock_starts.keys())}", RuntimeWarning, + stacklevel=2, ) return self._hit_count.copy() - def reset(self): + def reset(self) -> None: """Remove all accumulated timings.""" self._accumulated_time.clear() self._hit_count.clear() - def enable(self): + def enable(self) -> None: """Enable the Timer.""" self._enabled = True - def disable(self): + def disable(self) -> None: """Disable the Timer.""" if len(self._clock_starts) > 0: raise RuntimeError( @@ -142,11 +144,11 @@ class NullTimer(Timer): Meant to be used in place of an optional timer. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self._enabled = False + self._enabled: bool = False - def enable(self): + def enable(self) -> None: """Enable the Timer.""" raise NotImplementedError( "NullTimer cannot be enabled, maybe create a Timer and " diff --git a/ndsl/performance/tools.py b/ndsl/performance/tools.py index 7a20ecd8..80ed377f 100644 --- a/ndsl/performance/tools.py +++ b/ndsl/performance/tools.py @@ -1,5 +1,3 @@ -from typing import Optional - import click from ndsl.dsl.dace.utils import ( @@ -47,12 +45,12 @@ ) def command_line( action: str, - sdfg_path: Optional[str], - report_detail: Optional[bool], - hardware_bw_in_gb_s: Optional[float], - output_format: Optional[str], - backend: Optional[str], -): + sdfg_path: str, + report_detail: bool, + hardware_bw_in_gb_s: float | None, + output_format: str | None, + backend: str, +) -> None: """ Run tooling. """ @@ -62,10 +60,10 @@ def command_line( print( kernel_theoretical_timing_from_path( sdfg_path, + backend=backend, hardware_bw_in_GB_s=( None if hardware_bw_in_gb_s == 0 else hardware_bw_in_gb_s ), - backend=backend, output_format=output_format, ) ) diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index 43751528..26120596 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,11 +1,9 @@ from .metadata import QuantityHaloSpec, QuantityMetadata from .quantity import Quantity +from .state import State -__all__ = [ - "Quantity", - "QuantityMetadata", - "QuantityHaloSpec", - "FieldBundle", - "FieldBundleType", -] +from .local import Local # isort: skip + + +__all__ = ["Local", "Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"] diff --git a/ndsl/quantity/bounds.py b/ndsl/quantity/bounds.py index 419cff22..c9618721 100644 --- a/ndsl/quantity/bounds.py +++ b/ndsl/quantity/bounds.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple, Union +from collections.abc import Sequence import numpy as np @@ -44,7 +44,7 @@ def _get_array_index(self, index): self._dims, self._origin, self._extent, self._boundary_type, index ) - def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: + def sel(self, **kwargs: slice | int) -> np.ndarray: """Convenience method to perform indexing using dimension names without knowing dimension order. @@ -80,9 +80,9 @@ class BoundedArrayView: array, while view.interior[-1:1, -1:1, :] would also include one halo point. """ - def __init__( + def __init__( # type: ignore self, array, dims: Sequence[str], origin: Sequence[int], extent: Sequence[int] - ): + ) -> None: self._data = array self._dims = tuple(dims) self._origin = tuple(origin) @@ -104,12 +104,12 @@ def __init__( ) @property - def origin(self) -> Tuple[int, ...]: + def origin(self) -> tuple[int, ...]: """the start of the computational domain""" return self._origin @property - def extent(self) -> Tuple[int, ...]: + def extent(self) -> tuple[int, ...]: """the shape of the computational domain""" return self._extent diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py index df637fe2..f33ae8a6 100644 --- a/ndsl/quantity/field_bundle.py +++ b/ndsl/quantity/field_bundle.py @@ -25,13 +25,14 @@ class FieldBundle: """ _quantity: Quantity + _per_name_view: dict[str, Quantity] = {} _indexer: _FieldBundleIndexer = {} def __init__( self, bundle_name: str, quantity: Quantity, - mapping: _FieldBundleIndexer = {}, + mapping: _FieldBundleIndexer | None = None, register_type: bool = False, ): """ @@ -45,6 +46,9 @@ def __init__( mapping: sparse dict of [name, index] to be able to call tracers by name. register_type: boolean to register the type as part of initialization. """ + if mapping is None: + mapping = {} + if len(quantity.shape) != 4: raise NotImplementedError("FieldBundle implementation restricted to 4D") @@ -56,7 +60,7 @@ def __init__( assert len(quantity.shape) == 4 FieldBundleType.register(bundle_name, quantity.shape[3:]) - def map(self, index: _DataDimensionIndex, name: str): + def map(self, index: _DataDimensionIndex, name: str) -> None: """Map a single `index` to ` name`""" self._indexer[name] = index @@ -80,25 +84,31 @@ def __getattr__(self, name: str) -> Quantity: return None # type: ignore # ToDo: extend the dims below to work with more than 4 dims assert len(self._quantity.data.shape) == 4 - return Quantity( - data=self._quantity.data[:, :, :, self.index(name)], - dims=self._quantity.dims[:-1], - units=self._quantity.units, - origin=self._quantity.origin[:-1], - extent=self._quantity.extent[:-1], - ) + + if name not in self._per_name_view: + # Memoize the Quantities returned here to ensue that we only ever + # have one `field.a_name`-Quantity floating around. If not, DaCe + # orchestration gets (rightly so) confused. + self._per_name_view[name] = Quantity( + data=self._quantity.data[:, :, :, self.index(name)], + dims=self._quantity.dims[:-1], + units=self._quantity.units, + origin=self._quantity.origin[:-1], + extent=self._quantity.extent[:-1], + ) + return self._per_name_view[name] def index(self, name: str) -> int: """Get index from name.""" return self._indexer[name] @property - def __array_interface__(self): + def __array_interface__(self): # type: ignore[no-untyped-def] """Memory interface for CPU.""" return self._quantity.__array_interface__ @property - def __cuda_array_interface__(self): + def __cuda_array_interface__(self): # type: ignore[no-untyped-def] """Memory interface for GPU memory as defined by cupy.""" return self._quantity.__cuda_array_interface__ @@ -118,8 +128,8 @@ def extend_3D_quantity_factory( extra_dims: dict of [name, size] of the data dimensions to add. """ new_factory = copy.copy(quantity_factory) - new_factory.set_extra_dim_lengths( - **{ + new_factory.add_data_dimensions( + { **extra_dims, } ) @@ -141,17 +151,17 @@ class FieldBundleType: """Field Bundle Types to help with static sizing of Data Dimensions. Methods: - register: Register a type by sizing it's data dimensions + register: Register a type by sizing its data dimensions T: access any registered types for type hinting. """ _field_type_registrar: dict[str, gtscript._FieldDescriptor] = {} @classmethod - def register( + def register( # type: ignore cls, name: str, data_dims: tuple[int], dtype=Float ) -> gtscript._FieldDescriptor: - """Register a name type by name by giving the size of it's data dimensions. + """Register a name type by name by giving the size of its data dimensions. The same type cannot be registered twice and will error out. diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py new file mode 100644 index 00000000..910999ab --- /dev/null +++ b/ndsl/quantity/local.py @@ -0,0 +1,43 @@ +from typing import Any, Sequence + +import dace +import numpy as np + +from ndsl.optional_imports import cupy +from ndsl.quantity import Quantity + + +if cupy is None: + import numpy as cupy + + +class Local(Quantity): + """Local is a Quantity that cannot be used outside of the class + it was allocated in.""" + + def __init__( + self, + data: np.ndarray | cupy.ndarray, + dims: Sequence[str], + units: str, + origin: Sequence[int] | None = None, + extent: Sequence[int] | None = None, + gt4py_backend: str | None = None, + allow_mismatch_float_precision: bool = False, + ): + super().__init__( + data, + dims, + units, + origin, + extent, + gt4py_backend, + allow_mismatch_float_precision, + ) + self._transient = True + + def __descriptor__(self) -> Any: + """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" + data = dace.data.create_datadescriptor(self.data) + data.transient = True + return data diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index d7ddba0f..7e7b4f16 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import dataclasses -from typing import Any, Dict, Tuple, Union +from typing import Any import numpy as np @@ -13,11 +15,13 @@ @dataclasses.dataclass class QuantityMetadata: - origin: Tuple[int, ...] + origin: tuple[int, ...] "the start of the computational domain" - extent: Tuple[int, ...] + extent: tuple[int, ...] "the shape of the computational domain" - dims: Tuple[str, ...] + n_halo: int + "Number of halo-points used in the horizontal" + dims: tuple[str, ...] "names of each dimension" units: str "units of the quantity" @@ -25,11 +29,11 @@ class QuantityMetadata: "ndarray-like type used to store the data" dtype: type "dtype of the data in the ndarray-like object" - gt4py_backend: Union[str, None] = None + gt4py_backend: str | None = None "backend to use for gt4py storages" @property - def dim_lengths(self) -> Dict[str, int]: + def dim_lengths(self) -> dict[str, int]: """mapping of dimension names to their lengths""" return dict(zip(self.dims, self.extent)) @@ -45,7 +49,7 @@ def np(self) -> NumpyModule: f"quantity underlying data is of unexpected type {self.data_type}" ) - def duplicate_metadata(self, metadata_copy): + def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None: metadata_copy.origin = self.origin metadata_copy.extent = self.extent metadata_copy.dims = self.dims @@ -60,11 +64,11 @@ class QuantityHaloSpec: """Describe the memory to be exchanged, including size of the halo.""" n_points: int - strides: Tuple[int] + strides: tuple[int] itemsize: int - shape: Tuple[int] - origin: Tuple[int, ...] - extent: Tuple[int, ...] - dims: Tuple[str, ...] + shape: tuple[int] + origin: tuple[int, ...] + extent: tuple[int, ...] + dims: tuple[str, ...] numpy_module: NumpyModule dtype: Any diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 6d83cd89..45312d2f 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import warnings -from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast +from collections.abc import Iterable, Sequence +from typing import Any, cast import dace import matplotlib.pyplot as plt @@ -7,9 +10,9 @@ import xarray as xr from gt4py import storage as gt_storage from gt4py.cartesian import backend as gt_backend -from mpi4py import MPI import ndsl.constants as constants +from ndsl.comm.mpi import MPI from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy from ndsl.quantity.bounds import BoundedArrayView @@ -22,35 +25,40 @@ class Quantity: - """ - Data container for physical quantities. - """ + """Data container for physical quantities.""" def __init__( self, - data, + data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, - origin: Optional[Sequence[int]] = None, - extent: Optional[Sequence[int]] = None, - gt4py_backend: Union[str, None] = None, + origin: Sequence[int] | None = None, + extent: Sequence[int] | None = None, + gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, + number_of_halo_points: int = 0, ): - """ - Initialize a Quantity. + """Initialize a Quantity. Args: - data: ndarray-like object containing the underlying data - dims: dimension names for each axis - units: units of the quantity - origin: first point in data within the computational domain - extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception. Will raise - a TypeError if this is given with a gt4py storage type as data - """ + data (_type_): ndarray-like object containing the underlying data + dims (Sequence[str]): dimension names for each axis + units (str): units of the quantity + origin (Sequence[int] | None, optional): first point in data within the + computational domain. Defaults to None. + extent (Sequence[int] | None, optional): number of points along each axis + within the computational domain. Defaults to None. + gt4py_backend (str | None, optional): backend to use for gt4py storages, + if not given this will be derived from a Storage + if given as the data argument. Defaults to None. + allow_mismatch_float_precision (bool, optional): allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + Raises: + ValueError: Data-type mismatch between configuration and input-data + TypeError: Typing of the data that does not fit + """ if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -81,14 +89,15 @@ def __init__( if gt4py_backend is not None: gt4py_backend_cls = gt_backend.from_name(gt4py_backend) - assert gt4py_backend_cls is not None is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] - dimensions: Tuple[Union[str, int], ...] = tuple( + dimensions: tuple[str | int, ...] = tuple( [ - axis - if any(dim in axis_dims for axis_dims in constants.SPATIAL_DIMS) - else str(data.shape[index]) + ( + axis # type: ignore # mypy can't parse this list construction of hell + if any(dim in axis_dims for axis_dims in constants.SPATIAL_DIMS) + else str(data.shape[index]) + ) for index, (dim, axis) in enumerate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) @@ -106,8 +115,6 @@ def __init__( ) ) else: - if data is None: - raise TypeError("requires 'data' to be passed") # We have no info about the gt4py_backend, so just assign it. self._data = data @@ -115,6 +122,7 @@ def __init__( self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), + n_halo=number_of_halo_points, dims=tuple(dims), units=units, data_type=type(self._data), @@ -130,10 +138,11 @@ def __init__( def from_data_array( cls, data_array: xr.DataArray, - origin: Sequence[int] = None, - extent: Sequence[int] = None, - gt4py_backend: Union[str, None] = None, - ) -> "Quantity": + origin: Sequence[int] | None = None, + extent: Sequence[int] | None = None, + gt4py_backend: str | None = None, + number_of_halo_points: int = 0, + ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -149,14 +158,17 @@ def from_data_array( raise ValueError("need units attribute to create Quantity from DataArray") return cls( data_array.values, - cast(Tuple[str], data_array.dims), + cast(tuple[str], data_array.dims), data_array.attrs["units"], origin=origin, extent=extent, + number_of_halo_points=number_of_halo_points, gt4py_backend=gt4py_backend, ) - def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> None: + def to_netcdf( + self, path: str, name: str = "var", rank: int = -1, all_data: bool = False + ) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: if all_data: self.data_as_xarray.to_dataset(name=name).to_netcdf( @@ -168,6 +180,15 @@ def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> No ) def halo_spec(self, n_halo: int) -> QuantityHaloSpec: + # This is a preliminary check to see if this is ever triggered. + # If not, we can remove it down the line and change the call signature. + if n_halo != self._metadata.n_halo: + warnings.warn( + "Found inconsistency with number of halo points in Quantity:" + + f"{n_halo} vs {self._metadata.n_halo}", + UserWarning, + stacklevel=2, + ) return QuantityHaloSpec( n_halo, self.data.strides, @@ -180,14 +201,14 @@ def halo_spec(self, n_halo: int) -> QuantityHaloSpec: self.metadata.dtype, ) - def __repr__(self): + def __repr__(self) -> str: return ( f"Quantity(\n data=\n{self.data},\n dims={self.dims},\n" f" units={self.units},\n origin={self.origin},\n" f" extent={self.extent}\n)" ) - def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: + def sel(self, **kwargs: slice | int) -> np.ndarray: """Convenience method to perform indexing on `view` using dimension names without knowing dimension order. @@ -200,7 +221,7 @@ def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: """ return self.view[tuple(kwargs.get(dim, slice(None, None)) for dim in self.dims)] - def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: Tuple): + def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: tuple): # type: ignore """Allocates an ndarray with optimal memory layout, and copies the data over.""" storage = gt_storage.from_array( data, @@ -217,11 +238,11 @@ def metadata(self) -> QuantityMetadata: @property def units(self) -> str: - """units of the quantity""" + """Units of the quantity""" return self.metadata.units @property - def gt4py_backend(self) -> Union[str, None]: + def gt4py_backend(self) -> str | None: return self.metadata.gt4py_backend @property @@ -229,8 +250,8 @@ def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @property - def dims(self) -> Tuple[str, ...]: - """names of each dimension""" + def dims(self) -> tuple[str, ...]: + """Names of each dimension""" return self.metadata.dims @property @@ -239,6 +260,7 @@ def values(self) -> np.ndarray: "values exists only for backwards-compatibility with " "DataArray and will be removed, use .view[:] instead", DeprecationWarning, + stacklevel=2, ) return_array = np.asarray(self.view[:]) return_array.flags.writeable = False @@ -246,7 +268,7 @@ def values(self) -> np.ndarray: @property def view(self) -> BoundedArrayView: - """a view into the computational domain of the underlying data""" + """A view into the computational domain of the underlying data""" return self._compute_domain_view @property @@ -254,23 +276,38 @@ def field(self) -> np.ndarray | cupy.ndarray: return self._compute_domain_view[:] @property - def data(self) -> Union[np.ndarray, cupy.ndarray]: - """the underlying array of data""" + def data(self) -> np.ndarray | cupy.ndarray: + """The underlying array of data""" return self._data @data.setter - def data(self, inputData): - if type(inputData) in [np.ndarray, cupy.ndarray]: - self._data = inputData + def data(self, input_data: np.ndarray | cupy.ndarray) -> None: + if type(input_data) not in [np.ndarray, cupy.ndarray]: + raise TypeError( + "Quantity.data buffer swap failed: " + f"given data is not an array (type: {type(input_data)})" + ) + + if input_data.shape < self.extent: + raise ValueError( + "Quantity.data buffer swap failed: " + f"new data ({input_data.shape}) is smaller " + f"than expected extent ({self.extent})." + ) + + self._data = input_data + self._compute_domain_view = BoundedArrayView( + self.data, self.dims, self.origin, self.extent + ) @property - def origin(self) -> Tuple[int, ...]: - """the start of the computational domain""" + def origin(self) -> tuple[int, ...]: + """The start of the computational domain""" return self.metadata.origin @property - def extent(self) -> Tuple[int, ...]: - """the shape of the computational domain""" + def extent(self) -> tuple[int, ...]: + """The shape of the computational domain""" return self.metadata.extent @property @@ -288,19 +325,20 @@ def np(self) -> NumpyModule: return self.metadata.np @property - def __array_interface__(self): + def __array_interface__(self): # type: ignore[no-untyped-def] return self.data.__array_interface__ @property - def __cuda_array_interface__(self): + def __cuda_array_interface__(self): # type: ignore[no-untyped-def] return self.data.__cuda_array_interface__ @property - def shape(self): + def shape(self): # type: ignore[no-untyped-def] return self.data.shape def __descriptor__(self) -> Any: """The descriptor is a property that dace uses. + This relies on `dace` capacity to read out data from the buffer protocol. If the internal data given doesn't follow the protocol it will most likely fail. @@ -309,9 +347,9 @@ def __descriptor__(self) -> Any: def transpose( self, - target_dims: Sequence[Union[str, Iterable[str]]], + target_dims: Sequence[str | Iterable[str]], allow_mismatch_float_precision: bool = False, - ) -> "Quantity": + ) -> Quantity: """Change the dimension order of this Quantity. Args: @@ -365,7 +403,7 @@ def transpose( transposed._attrs = self._attrs return transposed - def plot_k_level(self, k_index=0): + def plot_k_level(self, k_index: int = 0) -> None: field = self.data print( "Min and max values:", @@ -382,11 +420,13 @@ def plot_k_level(self, k_index=0): plt.show() -def _transpose_sequence(sequence, order): +def _transpose_sequence(sequence, order): # type: ignore[no-untyped-def] return sequence.__class__(sequence[i] for i in order) -def _collapse_dims(target_dims, dims): +def _collapse_dims( + target_dims: Sequence[str | Iterable[str]], dims: tuple[str, ...] +) -> list[str]: return_list = [] for target in target_dims: if isinstance(target, str): @@ -412,7 +452,7 @@ def _collapse_dims(target_dims, dims): return return_list -def _validate_quantity_property_lengths(shape, dims, origin, extent): +def _validate_quantity_property_lengths(shape, dims, origin, extent): # type: ignore[no-untyped-def] n_dims = len(shape) for var, desc in ( (dims, "dimension names"), @@ -425,7 +465,7 @@ def _validate_quantity_property_lengths(shape, dims, origin, extent): ) -def _ensure_int_tuple(arg, arg_name): +def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: return_list = [] for item in arg: try: diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py new file mode 100644 index 00000000..3bbe0684 --- /dev/null +++ b/ndsl/quantity/state.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import dataclasses +from collections.abc import Callable +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Any, Self, TypeAlias + +import dacite +import xarray as xr +from numpy.typing import ArrayLike + +from ndsl.comm.mpi import MPI +from ndsl.types import Number + + +if TYPE_CHECKING: + from ndsl import QuantityFactory + +StateMemoryMapping: TypeAlias = dict[str, dict | ArrayLike | None] + + +@dataclasses.dataclass +class State: + """Base class for state objects in models. + + A State groups a collection of (possibly nested) Quantities in a dataclass. + + This baseclass implements common initialization functions and serialization. + + Typical usage example: + + ```python + class MyState(State): + pass + + my_state = MyState.zeros(quantity_factory) + + # ... + + my_state.to_netcdf() + ``` + """ + + @classmethod + def _init(cls, quantity_factory_allocator: Callable) -> Self: + """Allocate memory and init with a blind quantity init operation""" + + def _init_recursive(cls: Any) -> dict: + initial_quantities = {} + for _field in dataclasses.fields(cls): + if dataclasses.is_dataclass(_field.type): + initial_quantities[_field.name] = _init_recursive(_field.type) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + initial_quantities[_field.name] = quantity_factory_allocator( + _field.metadata["dims"], + _field.metadata["units"], + dtype=_field.metadata["dtype"], + allow_mismatch_float_precision=True, + ) + + return initial_quantities + + dict_of_quantities = _init_recursive(cls) + return dacite.from_dict(data_class=cls, data=dict_of_quantities) + + class _FactorySwapDimensionsDefinitions: + """INTERNAL: QuantityFactory carry a sizer which has a full definition of the dimensions. + It's this sizer that is leveraged for the factory to figure out allocations. + In a regular pattern of use, data dimensions fields tend to be _the exception_ rather + than the rule and therefore would need a Factory defined _for a few cases_. + We bring this tool to override temporarily the allocations based on a single descriptions of + the data dimensions at allocation time. + """ + + def __init__(self, factory: QuantityFactory, ddims: dict[str, int]): + self._ddims = ddims + self._factory = factory + + def __enter__(self) -> None: + self._original_dims = self._factory.sizer.data_dimensions + self._factory.sizer.data_dimensions = self._ddims + + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._factory.sizer.data_dimensions = self._original_dims + + @classmethod + def empty( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] | None = None, + ) -> Self: + """Allocate all quantities. Do not expect 0 on values, values are random. + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + if data_dimensions is None: + data_dimensions = {} + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.empty) + return state + + @classmethod + def zeros( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] | None = None, + ) -> Self: + """Allocate all quantities and fill their value to zeros + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + if data_dimensions is None: + data_dimensions = {} + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.zeros) + return state + + @classmethod + def ones( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] | None = None, + ) -> Self: + """Allocate all quantities and fill their value to ones + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + if data_dimensions is None: + data_dimensions = {} + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.ones) + return state + + @classmethod + def full( + cls, + quantity_factory: QuantityFactory, + value: Number, + *, + data_dimensions: dict[str, int] | None = None, + ) -> Self: + """Allocate all quantities and fill them with the input value + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + value: number to initialize the buffers with. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + if data_dimensions is None: + data_dimensions = {} + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.empty) + state.fill(value) + return state + + @classmethod + def copy_memory( + cls, + quantity_factory: QuantityFactory, + memory_map: StateMemoryMapping, + *, + data_dimensions: dict[str, int] | None = None, + ) -> Self: + """Allocate all quantities and fill their value based + on the given memory map. See `update_from_memory`. + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + memory_map: Dict of name/buffer. See `update_from_memory`. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + if data_dimensions is None: + data_dimensions = {} + + state = cls.zeros(quantity_factory, data_dimensions=data_dimensions) + state.update_copy_memory(memory_map) + + return state + + @classmethod + def move_memory( + cls, + quantity_factory: QuantityFactory, + memory_map: StateMemoryMapping, + *, + data_dimensions: dict[str, int] | None = None, + check_shape_and_strides: bool = True, + ) -> Self: + """Allocate all quantities and move memory based on + on the given memory map. See `update_move_memory`. + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + memory_map: Dict of name/buffer. See `update_from_memory`. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + check_shape_and_strides: Check for every given buffer that the shape & strides match the + previously allocated memory. + """ + if data_dimensions is None: + data_dimensions = {} + + state = cls.zeros(quantity_factory, data_dimensions=data_dimensions) + state.update_move_memory( + memory_map, + check_shape_and_strides=check_shape_and_strides, + ) + + return state + + def fill(self, value: Number) -> None: + def _fill_recursive( + state: State, + value: Number, + ) -> None: + for _field in dataclasses.fields(state): + if dataclasses.is_dataclass(_field.type): + _fill_recursive(state.__getattribute__(_field.name), value) + else: + state.__getattribute__(_field.name).field[:] = value + + _fill_recursive(self, value) + + def update_copy_memory(self, memory_map: dict[str, Any]) -> None: + """Copy data into the Quantities carried by the state. + + The memory map must follow the dataclass naming convention, e.g. + + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) + + } + ``` + + The memory map can be sparse. + """ + + def _update_from_memory_recursive( + state: State, + memory_map: StateMemoryMapping, + ) -> None: + for name, array in memory_map.items(): + if array is None: + raise TypeError( + f"State memory copy: illegal copy from None for attribute {name}" + ) + elif isinstance(array, dict): + _update_from_memory_recursive(state.__getattribute__(name), array) + else: + try: + state.__getattribute__(name).field[:] = array + except Exception as e: + e.add_note( + f"Error when initializing field {name} on state {type(self)}" + ) + raise e + + _update_from_memory_recursive(self, memory_map) + + def update_move_memory( + self, + memory_map: StateMemoryMapping, + *, + check_shape_and_strides: bool = True, + ) -> None: + """Move memory into the Quantities carried by the state. + Memory is moved rather than copied (e.g. buffers are swapped) + + The memory map must follow the dataclass naming convention, e.g. + + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) + + } + ``` + + The memory map can be sparse. + + Args: + memory_map: Dictionary of keys to buffers. Buffers must be np.ArrayLike + check_shape_and_strides: check that the given buffers have the same + shape and strides as the original quantity + """ + + def _update_zero_copy_recursive( + state: State, memory_map: StateMemoryMapping + ) -> None: + for name, array in memory_map.items(): + if array is None: + state.__setattr__(name, None) + elif isinstance(array, dict): + _update_zero_copy_recursive(state.__getattribute__(name), array) + else: + quantity = state.__getattribute__(name) + if check_shape_and_strides: + if array.shape != quantity.field.shape: + e = ValueError("Shape mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(state)}") + e.add_note( + f" Shapes: {array.shape} != {quantity.field.shape}" + ) + raise e + if array.strides != quantity.data.strides: + e = ValueError("Stride mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(state)}") + e.add_note( + f" Strides: {array.strides} != {quantity.data.strides}" + ) + raise e + try: + quantity.data = array + except Exception as e: + e.add_note(f" Error on {name} for {type(state)}") + raise e + + _update_zero_copy_recursive(self, memory_map) + + def _netcdf_name(self, directory_path: Path) -> Path: + """Resolve rank-tied postfix if needed""" + rank_postfix = "" + if MPI.COMM_WORLD.Get_size() > 1: + rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" + return directory_path / f"{type(self).__name__}{rank_postfix}.nc4" + + def to_netcdf(self, directory_path: Path | None = None) -> None: + """ + Save state to NetCDF. Can be reloaded with `update_from_netcdf`. + + If applicable, will save separate NetCDF files for each running rank. + + The file names are deduced from the class name, and post fix with rank number + in the case of a multi-process use. + + Args: + directory_path: directory to save the netcdf in + """ + if directory_path is None: + directory_path = Path("./") + + def _save_recursive(state: State) -> dict: + local_data = {} + for _field in dataclasses.fields(state): + if dataclasses.is_dataclass(_field.type): + local_data[_field.name] = xr.Dataset( + data_vars=_save_recursive(state.__getattribute__(_field.name)) + ) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + local_data[_field.name] = state.__getattribute__( + _field.name + ).field_as_xarray + + return local_data + + datatree = _save_recursive(self) + + # Move top-level into their own dataset in the "/" prefix + # to match DataTree expected format + top_level = {} + for key, value in datatree.items(): + if not isinstance(value, xr.Dataset): + top_level[key] = value + for key, _value in top_level.items(): + datatree.pop(key) + datatree["/"] = xr.Dataset(data_vars=top_level) + + xr.DataTree.from_dict(datatree).to_netcdf(self._netcdf_name(directory_path)) + + def update_from_netcdf(self, directory_path: Path) -> None: + """This is a mirror of the `to_netcdf` method NOT a generic + NetCDF loader. It expects the NetCDF to be named with the auto-naming scheme + of `to_netcdf`. + + Args: + directory_path: directory carrying the netcdf saved with `to_netcdf` + + """ + datatree = xr.open_datatree(self._netcdf_name(directory_path)) + datatree_as_dict = datatree.to_dict() + + # All other cases - recursing downward + def _load_recursive( + data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset, + ) -> dict: + local_data_dict = {} + for name, data_array in data_tree_as_dict.items(): + # Case of the top_level "/" + if name == "/": + for root_name, root_data_array in datatree_as_dict["/"].items(): + local_data_dict[root_name] = root_data_array.to_numpy() + else: + # Get the leading `/` out + if isinstance(data_array, xr.Dataset): + local_data_dict[name[1:]] = _load_recursive(data_array) + else: + local_data_dict[name] = data_array.to_numpy() + + return local_data_dict + + data_as_numpy_dict = _load_recursive(datatree_as_dict) + + self.update_copy_memory(data_as_numpy_dict) diff --git a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index 7983f8a9..9cf577fe 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -1,11 +1,11 @@ import copy import os -from typing import BinaryIO, Generator, Iterable +from collections.abc import Generator, Iterable +from typing import BinaryIO import xarray as xr import ndsl.constants as constants -import ndsl.filesystem as filesystem import ndsl.io as io from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index @@ -24,9 +24,9 @@ def open_restart( dirname: str, communicator: Communicator, label: str = "", - only_names: Iterable[str] = None, - to_state: dict = None, - tracer_properties: RestartProperties = None, + only_names: Iterable[str] | None = None, + to_state: dict | None = None, + tracer_properties: RestartProperties | None = None, ): """Load restart files output by the Fortran model into a state dictionary. @@ -54,16 +54,16 @@ def open_restart( raise ValueError("no restart files found at {}".format(dirname)) for filename in filenames: - with filesystem.open(filename, "rb") as file: + with open(filename, "rb") as file: state.update( load_partial_state_from_restart_file( file, restart_properties, only_names=only_names ) ) coupler_res_filename = get_coupler_res_filename(dirname, label) - if filesystem.is_file(coupler_res_filename): + if os.path.isfile(coupler_res_filename): if only_names is None or "time" in only_names: - with filesystem.open(coupler_res_filename, "r") as f: + with open(coupler_res_filename, "r") as f: state["time"] = io.get_current_date_from_coupler_res(f) if to_state is None: state = communicator.tile.scatter_state(state) @@ -78,7 +78,7 @@ def get_coupler_res_filename(dirname, label): def restart_files(dirname, tile_index, label) -> Generator[BinaryIO, None, None]: for filename in restart_filenames(dirname, tile_index, label): - with filesystem.open(filename, "rb") as f: + with open(filename, "rb") as f: yield f @@ -89,7 +89,7 @@ def restart_filenames(dirname, tile_index, label): filename = os.path.join(dirname, prepend_label(name, label) + suffix) if ( (name in RESTART_NAMES) - or filesystem.is_file(filename) + or os.path.isfile(filename) or os.path.exists(filename) ): return_list.append(filename) @@ -166,7 +166,7 @@ def load_partial_state_from_restart_file( return state -def _get_restart_standard_names(restart_properties: RestartProperties = None): +def _get_restart_standard_names(restart_properties: RestartProperties | None = None): """Return a list of variable names needed for a smooth restart. By default uses restart_properties from RESTART_PROPERTIES.""" if restart_properties is None: diff --git a/ndsl/restart/_properties.py b/ndsl/restart/_properties.py index c4bd5011..b361d329 100644 --- a/ndsl/restart/_properties.py +++ b/ndsl/restart/_properties.py @@ -1,4 +1,4 @@ -from typing import Iterable, Mapping, Union +from collections.abc import Iterable, Mapping from ..constants import ( X_DIM, @@ -11,7 +11,7 @@ ) -RestartProperties = Mapping[str, Mapping[str, Union[str, Iterable[str]]]] +RestartProperties = Mapping[str, Mapping[str, str | Iterable[str]]] RESTART_PROPERTIES: RestartProperties = { "accumulated_x_courant_number": { "dims": [Z_DIM, Y_DIM, X_DIM], diff --git a/ndsl/restart/restart_properties.yml b/ndsl/restart/restart_properties.yml index f19c2ce4..48ab0f9c 100644 --- a/ndsl/restart/restart_properties.yml +++ b/ndsl/restart/restart_properties.yml @@ -1,568 +1,568 @@ accumulated_x_courant_number: dims: - - z - - y - - x + - z + - y + - x restart_name: cx units: '' accumulated_x_mass_flux: dims: - - z - - y - - x_interface + - z + - y + - x_interface restart_name: mfx units: unknown accumulated_y_courant_number: dims: - - z - - y - - x + - z + - y + - x restart_name: cy units: unknown accumulated_y_mass_flux: dims: - - z - - y_interface - - x + - z + - y_interface + - x restart_name: mfy units: unknown air_temperature: dims: - - z - - y - - x + - z + - y + - x restart_name: T units: degK air_temperature_after_physics: dims: - - z - - y - - x + - z + - y + - x restart_name: gt0 units: K air_temperature_at_2m: dims: - - y - - x + - y + - x restart_name: t2m units: degK area_of_grid_cell: dims: - - y - - x + - y + - x restart_name: area units: m^2 atmosphere_hybrid_a_coordinate: dims: - - z_interface + - z_interface restart_name: ak units: Pa atmosphere_hybrid_b_coordinate: dims: - - z_interface + - z_interface restart_name: bk units: '' canopy_water: dims: - - y - - x + - y + - x restart_name: canopy units: unknown clear_sky_downward_longwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: dnfx0 restart_name: sfcflw units: W/m^2 clear_sky_downward_shortwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: dnfx0 restart_name: sfcfsw units: W/m^2 clear_sky_upward_longwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: upfx0 restart_name: sfcflw units: W/m^2 clear_sky_upward_longwave_flux_at_top_of_atmosphere: dims: - - y - - x + - y + - x fortran_subname: upfx0 restart_name: topflw units: W/m^2 clear_sky_upward_shortwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: upfx0 restart_name: sfcfsw units: W/m^2 clear_sky_upward_shortwave_flux_at_top_of_atmosphere: dims: - - y - - x + - y + - x fortran_subname: upfx0 restart_name: topfsw units: W/m^2 convective_cloud_bottom_pressure: dims: - - y - - x + - y + - x restart_name: cvb units: Pa convective_cloud_fraction: dims: - - y - - x + - y + - x restart_name: cv units: '' convective_cloud_top_pressure: dims: - - y - - x + - y + - x restart_name: cvt units: Pa deep_soil_temperature: dims: - - y - - x + - y + - x restart_name: tg3 units: degK dissipation_estimate_from_heat_source: dims: - - z - - y - - x + - z + - y + - x restart_name: diss_est units: unknown eastward_wind: dims: - - z - - y - - x + - z + - y + - x restart_name: ua units: m/s eastward_wind_after_physics: dims: - - z - - y - - x + - z + - y + - x restart_name: gu0 units: m/s eastward_wind_at_surface: dims: - - y - - x + - y + - x restart_name: u_srf units: m/s fh_parameter: description: used in PBL scheme dims: - - y - - x + - y + - x restart_name: ffhh units: unknown fm_at_10m: description: Ratio of sigma level 1 wind and 10m wind dims: - - y - - x + - y + - x restart_name: f10m units: unknown fm_parameter: description: used in PBL scheme dims: - - y - - x + - y + - x restart_name: ffmm units: unknown fractional_coverage_with_strong_cosz_dependency: dims: - - y - - x + - y + - x restart_name: facsf units: '' fractional_coverage_with_weak_cosz_dependency: dims: - - y - - x + - y + - x restart_name: facwf units: '' friction_velocity: dims: - - y - - x + - y + - x restart_name: uustar units: m/s ice_fraction_over_open_water: dims: - - y - - x + - y + - x restart_name: fice units: '' interface_pressure: dims: - - y - - z_interface - - x + - y + - z_interface + - x restart_name: pe units: Pa interface_pressure_raised_to_power_of_kappa: dims: - - z_interface - - y - - x + - z_interface + - y + - x restart_name: pk units: unknown land_sea_mask: description: sea=0, land=1, sea-ice=2 dims: - - y - - x + - y + - x restart_name: slmsk units: '' latent_heat_flux: dims: - - y - - x + - y + - x restart_name: dqsfci units: W/m^2 latitude: dims: - - y - - x + - y + - x restart_name: xlat units: radians layer_mean_pressure_raised_to_power_of_kappa: dims: - - z - - y - - x + - z + - y + - x restart_name: pkz units: unknown liquid_soil_moisture: dims: - - z_soil - - y - - x + - z_soil + - y + - x restart_name: slc units: unknown logarithm_of_interface_pressure: dims: - - y - - z_interface - - x + - y + - z_interface + - x restart_name: peln units: ln(Pa) longitude: dims: - - y - - x + - y + - x restart_name: xlon units: radians maximum_fractional_coverage_of_green_vegetation: dims: - - y - - x + - y + - x restart_name: shdmax units: '' maximum_snow_albedo_in_fraction: dims: - - y - - x + - y + - x restart_name: snoalb units: '' mean_cos_zenith_angle: dims: - - y - - x + - y + - x restart_name: coszen units: '' mean_near_infrared_albedo_with_strong_cosz_dependency: dims: - - y - - x + - y + - x restart_name: alnsf units: '' mean_near_infrared_albedo_with_weak_cosz_dependency: dims: - - y - - x + - y + - x restart_name: alnwf units: '' mean_visible_albedo_with_strong_cosz_dependency: dims: - - y - - x + - y + - x restart_name: alvsf units: '' mean_visible_albedo_with_weak_cosz_dependency: dims: - - y - - x + - y + - x restart_name: alvwf units: '' minimum_fractional_coverage_of_green_vegetation: dims: - - y - - x + - y + - x restart_name: shdmin units: '' northward_wind: dims: - - z - - y - - x + - z + - y + - x restart_name: va units: m/s northward_wind_after_physics: dims: - - z - - y - - x + - z + - y + - x restart_name: gv0 units: m/s northward_wind_at_surface: dims: - - y - - x + - y + - x restart_name: v_srf units: m/s pressure_thickness_of_atmospheric_layer: dims: - - z - - y - - x + - z + - y + - x restart_name: delp units: Pa sea_ice_thickness: dims: - - y - - x + - y + - x restart_name: hice units: unknown sensible_heat_flux: dims: - - y - - x + - y + - x restart_name: dtsfci units: W/m^2 snow_cover_in_fraction: dims: - - y - - x + - y + - x restart_name: sncovr units: '' snow_depth_water_equivalent: dims: - - y - - x + - y + - x restart_name: snwdph units: mm snow_rain_flag: description: snow/rain flag for precipitation dims: - - y - - x + - y + - x restart_name: srflag units: '' soil_temperature: dims: - - z_soil - - y - - x + - z_soil + - y + - x restart_name: stc units: degK soil_type: dims: - - y - - x + - y + - x restart_name: stype units: '' specific_humidity_at_2m: dims: - - y - - x + - y + - x restart_name: q2m units: kg/kg surface_geopotential: dims: - - y - - x + - y + - x restart_name: phis units: m^2 s^-2 surface_pressure: dims: - - y - - x + - y + - x restart_name: ps units: Pa surface_roughness: dims: - - y - - x + - y + - x restart_name: zorl units: cm surface_slope_type: description: used in land surface model dims: - - y - - x + - y + - x restart_name: slope units: '' surface_temperature: description: surface skin temperature dims: - - y - - x + - y + - x restart_name: tsea units: degK surface_temperature_over_ice_fraction: dims: - - y - - x + - y + - x restart_name: tisfc units: degK total_condensate_mixing_ratio: dims: - - z - - y - - x + - z + - y + - x restart_name: q_con units: kg/kg total_precipitation: dims: - - y - - x + - y + - x restart_name: tprcp units: m total_sky_downward_longwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: dnfxc restart_name: sfcflw units: W/m^2 total_sky_downward_shortwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: dnfxc restart_name: sfcfsw units: W/m^2 total_sky_downward_shortwave_flux_at_top_of_atmosphere: dims: - - y - - x + - y + - x fortran_subname: dnfxc restart_name: topfsw units: W/m^2 total_sky_upward_longwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: upfxc restart_name: sfcflw units: W/m^2 total_sky_upward_longwave_flux_at_top_of_atmosphere: dims: - - y - - x + - y + - x fortran_subname: upfxc restart_name: topflw units: W/m^2 total_sky_upward_shortwave_flux_at_surface: dims: - - y - - x + - y + - x fortran_subname: upfxc restart_name: sfcfsw units: W/m^2 total_sky_upward_shortwave_flux_at_top_of_atmosphere: dims: - - y - - x + - y + - x fortran_subname: upfxc restart_name: topfsw units: W/m^2 total_soil_moisture: dims: - - z_soil - - y - - x + - z_soil + - y + - x restart_name: smc units: unknown vegetation_fraction: dims: - - y - - x + - y + - x restart_name: vfrac units: '' vegetation_type: dims: - - y - - x + - y + - x restart_name: vtype units: '' vertical_pressure_velocity: dims: - - z - - y - - x + - z + - y + - x restart_name: omga units: Pa/s vertical_thickness_of_atmospheric_layer: dims: - - z - - y - - x + - z + - y + - x restart_name: DZ units: m vertical_wind: dims: - - z - - y - - x + - z + - y + - x restart_name: W units: m/s water_equivalent_of_accumulated_snow_depth: description: weasd in Fortran code, over land and sea ice only dims: - - y - - x + - y + - x restart_name: sheleg units: kg/m^2 x_wind: dims: - - z - - y_interface - - x + - z + - y_interface + - x restart_name: u units: m/s x_wind_on_c_grid: dims: - - z - - y - - x_interface + - z + - y + - x_interface restart_name: uc units: m/s y_wind: dims: - - z - - y - - x_interface + - z + - y + - x_interface restart_name: v units: m/s y_wind_on_c_grid: dims: - - z - - y_interface - - x + - z + - y_interface + - x restart_name: vc units: m/s diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 5115a28e..8a635187 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,8 @@ from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid + + +__all__ = [ + "CopyCorners", + "CopyCornersXY", + "FillCornersBGrid", +] diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py index 18f44afb..e8c0d475 100644 --- a/ndsl/stencils/basic_operations.py +++ b/ndsl/stencils/basic_operations.py @@ -2,7 +2,7 @@ from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, IntField, IntFieldIJ -def copy_defn(q_in: FloatField, q_out: FloatField): +def copy_defn(q_in: FloatField, q_out: FloatField) -> None: """ Copy q_in to q_out. @@ -14,7 +14,7 @@ def copy_defn(q_in: FloatField, q_out: FloatField): q_out = q_in -def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): +def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField) -> None: """ Multiplies every element of q_out by every element of the adjustment field over the interval, replacing the elements of q_out by the result of the multiplication. @@ -27,7 +27,7 @@ def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): q_out = q_out * adjustment -def set_value_defn(q_out: FloatField, value: Float): +def set_value_defn(q_out: FloatField, value: Float) -> None: """ Sets every element of q_out to the value specified by value argument. @@ -39,7 +39,7 @@ def set_value_defn(q_out: FloatField, value: Float): q_out = value -def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): +def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField) -> None: """ Divides every element of q_out by every element of the adjustment field over the interval, replacing the elements of q_out by the result of the multiplication. @@ -57,7 +57,7 @@ def select_k( out_field: FloatFieldIJ, k_mask: IntField, k_select: IntFieldIJ, -): +) -> None: """ Saves a specific k-index of a 3D field to a new 2D array. The k-value can be different for each i,j point. @@ -77,7 +77,7 @@ def select_k( def average_in( q_out: FloatField, adjustment: FloatField, -): +) -> None: """ Averages every element of q_out with every element of the adjustment field, overwriting q_out. @@ -91,7 +91,7 @@ def average_in( @function -def sign(a, b): +def sign(a, b): # type: ignore[no-untyped-def] """ Defines a_sign_b as the absolute value of a, and checks if b is positive or negative, assigning the analogous sign value to a_sign_b. a_sign_b is returned. @@ -105,7 +105,7 @@ def sign(a, b): @function -def dim(a, b): +def dim(a, b): # type: ignore[no-untyped-def] """ Calculates a - b, camped to 0, i.e. max(a - b, 0). """ diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 39194232..4f19fbf9 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -1,3 +1,5 @@ +from typing import Any + from gt4py.cartesian.gtscript import ( __INLINED, PARALLEL, @@ -24,9 +26,9 @@ def mock_exchange( - quantity, - domain_2d, -): + quantity: Any, + domain_2d: list[list], +) -> None: isc = domain_2d[0][0] iec = domain_2d[0][1] isd = domain_2d[1][0] @@ -68,7 +70,7 @@ def c2l_ord2( a22: FloatFieldIJ, ua: FloatField, va: FloatField, -): +) -> None: """ Args: u (in): @@ -110,7 +112,7 @@ def ord4_transform( a22: FloatFieldIJ, ua: FloatField, va: FloatField, -): +) -> None: """ Args: u (in): @@ -156,7 +158,7 @@ class CubedToLatLon: def __init__( self, - state, # No type hint on purpose to remove dependency on pyFV3 + state: Any, # No type hint on purpose to avoid dependency on pyFV3 stencil_factory: StencilFactory, quantity_factory: QuantityFactory, grid_data: GridData, @@ -252,7 +254,7 @@ def __call__( v: FloatField, ua: FloatField, va: FloatField, - ): + ) -> None: """ Interpolate D-grid to A-grid winds at latitude-longitude coordinates. Args: diff --git a/ndsl/stencils/corners.py b/ndsl/stencils/corners.py index 18c1f6c1..021a439b 100644 --- a/ndsl/stencils/corners.py +++ b/ndsl/stencils/corners.py @@ -1,8 +1,10 @@ -from typing import Optional, Sequence, Tuple +import warnings +from collections.abc import Sequence from gt4py.cartesian import gtscript from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region +from ndsl import StencilFactory from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -10,7 +12,7 @@ Y_INTERFACE_DIM, Z_INTERFACE_DIM, ) -from ndsl.dsl.stencil import GridIndexing, StencilFactory +from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import FloatField @@ -22,6 +24,13 @@ class CopyCorners: def __init__(self, direction: str, stencil_factory: StencilFactory) -> None: """The grid for this stencil""" + warnings.warn( + "Usage of the GT4Py implementation of CopyCorners is discouraged and will" + "be removed in the next release. Use `CopyCornersX` or `CopyCornersY` in PyFV3" + "for a more future-proof implementation of the same code.", + DeprecationWarning, + 2, + ) grid_indexing = stencil_factory.grid_indexing n_halo = grid_indexing.n_halo @@ -116,8 +125,8 @@ def __call__(self, field: FloatField): def kslice_from_inputs( - kstart: int, nk: Optional[int], grid_indexer: GridIndexing -) -> Tuple[slice, int]: + kstart: int, nk: int | None, grid_indexer: GridIndexing +) -> tuple[slice, int]: # This expects ints, but it casts in case something was implicitly converted # to a float before this call. if nk is None: @@ -1001,9 +1010,6 @@ def fill_corners_dgrid_defn( from __externals__ import i_end, i_start, j_end, j_start with computation(PARALLEL), interval(...): - # this line of code is used to fix the missing symbol crash due to the node visitor depth limitation - acoef = mysign - x_out = x_out # sw corner with horizontal(region[i_start - 1, j_start - 1]): x_out = mysign * y_in[0, 1, 0] diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index 4be2c60a..a55d8647 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -14,3 +14,22 @@ pad_field_in_j, read_serialized_data, ) + + +__all__ = [ + "Grid", + "ParallelTranslate", + "ParallelTranslate2Py", + "ParallelTranslate2PyState", + "ParallelTranslateBaseSlicing", + "ParallelTranslateGrid", + "SavepointCase", + "Translate", + "assert_same_temporaries", + "TranslateFortranData2Py", + "TranslateGrid", + "pad_field_in_j", + "read_serialized_data", + "dataset_to_dict", + "copy_temporaries", +] diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 6da880fd..4e0bb428 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -1,13 +1,15 @@ import os import re -from typing import Optional, Tuple +from pathlib import Path +from typing import Any -import f90nml import pytest import xarray as xr import yaml +from f90nml import Namelist from ndsl import CompilationConfig, StencilConfig, StencilFactory +from ndsl.comm import Comm from ndsl.comm.communicator import ( Communicator, CubedSphereCommunicator, @@ -16,14 +18,17 @@ from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.namelist import Namelist + +# TODO: Remove NdslNamelist import after Issue#64 is resolved. +from ndsl.namelist import Namelist as NdslNamelist from ndsl.stencils.testing.grid import Grid # type: ignore from ndsl.stencils.testing.parallel_translate import ParallelTranslate -from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict +from ndsl.stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict from ndsl.stencils.testing.translate import TranslateGrid +from ndsl.utils import grid_params_from_f90nml, load_f90nml -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: """Option for the Translate Test system See -h or inline help for details. @@ -73,6 +78,12 @@ def pytest_addoption(parser): default=1, help="How many indices of failures to print from worst to best. Default to 1.", ) + parser.addoption( + "--no_legacy_namelist", + action="store_true", + default=False, + help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.", + ) parser.addoption( "--grid", action="store", @@ -105,7 +116,7 @@ def pytest_addoption(parser): ) -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: # register an additional marker config.addinivalue_line( "markers", "sequential(name): mark test as running sequentially on ranks" @@ -120,32 +131,33 @@ def pytest_configure(config): @pytest.fixture() -def data_path(pytestconfig): +def data_path(pytestconfig: pytest.Config) -> tuple[Path, Path]: return data_path_and_namelist_filename_from_config(pytestconfig) -def data_path_and_namelist_filename_from_config(config) -> Tuple[str, str]: - data_path = config.getoption("data_path") - namelist_filename = os.path.join(data_path, "input.nml") - return data_path, namelist_filename +def data_path_and_namelist_filename_from_config( + config: pytest.Config, +) -> tuple[Path, Path]: + data_path = Path(config.getoption("data_path")) + return data_path, data_path / "input.nml" @pytest.fixture -def threshold_overrides(pytestconfig): +def threshold_overrides(pytestconfig: pytest.Config) -> dict | None: return thresholds_from_file(pytestconfig) -def thresholds_from_file(config): +def thresholds_from_file(config: pytest.Config) -> dict | None: thresholds_file = config.getoption("threshold_overrides_file") if thresholds_file is None: return None return yaml.safe_load(open(thresholds_file, "r")) -def get_test_class(test_name): +def get_test_class(test_name: str) -> type | None: translate_class_name = f"Translate{test_name.replace('-', '_')}" try: - return_class = getattr(translate, translate_class_name) # noqa: F821 + return_class = getattr(translate, translate_class_name) # type: ignore[name-defined] # noqa: F821 except AttributeError as err: if translate_class_name in err.args[0]: return_class = None @@ -154,29 +166,32 @@ def get_test_class(test_name): return return_class -def is_parallel_test(test_name): +def is_parallel_test(test_name: str) -> bool: test_class = get_test_class(test_name) if test_class is None: return False - else: - return issubclass(test_class, ParallelTranslate) + return issubclass(test_class, ParallelTranslate) -def get_test_class_instance(test_name, grid, namelist, stencil_factory): +def get_test_class_instance( + test_name: str, grid: Grid, namelist: Namelist, stencil_factory: StencilFactory +) -> Translate: translate_class = get_test_class(test_name) if translate_class is None: - return None - else: - return translate_class(grid, namelist, stencil_factory) + raise ValueError( + f"Could not find translate test class for test name '{test_name}'." + ) + + return translate_class(grid, namelist, stencil_factory) -def get_all_savepoint_names(metafunc, data_path): +def get_all_savepoint_names(metafunc: Any, data_path: Path) -> set[str]: only_names = metafunc.config.getoption("which_modules") if only_names is None: - savepoint_names = [ + names = [ fname[:-3] for fname in os.listdir(data_path) if re.match(r".*\.nc", fname) ] - savepoint_names = [s[:-3] for s in savepoint_names if s.endswith("-In")] + savepoint_names = set([s[:-3] for s in names if s.endswith("-In")]) else: savepoint_names = set(only_names.split(",")) savepoint_names.discard("") @@ -186,7 +201,7 @@ def get_all_savepoint_names(metafunc, data_path): return savepoint_names -def get_sequential_savepoint_names(metafunc, data_path): +def get_sequential_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: all_names = get_all_savepoint_names(metafunc, data_path) sequential_names = [] for name in all_names: @@ -195,7 +210,7 @@ def get_sequential_savepoint_names(metafunc, data_path): return sequential_names -def get_parallel_savepoint_names(metafunc, data_path): +def get_parallel_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: all_names = get_all_savepoint_names(metafunc, data_path) parallel_names = [] for name in all_names: @@ -204,31 +219,29 @@ def get_parallel_savepoint_names(metafunc, data_path): return parallel_names -def get_ranks(metafunc, layout): +def get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range: only_rank = metafunc.config.getoption("which_rank") + if only_rank is not None: + return [int(only_rank)] + topology = metafunc.config.getoption("topology") - if only_rank is None: - if topology == "doubly-periodic": - total_ranks = layout[0] * layout[1] - elif topology == "cubed-sphere": - total_ranks = 6 * layout[0] * layout[1] - else: - raise NotImplementedError(f"Topology {topology} is unknown.") - return range(total_ranks) + + if topology == "doubly-periodic": + total_ranks = layout[0] * layout[1] + elif topology == "cubed-sphere": + total_ranks = 6 * layout[0] * layout[1] else: - return [int(only_rank)] + raise NotImplementedError(f"Topology {topology} is unknown.") + + return range(total_ranks) -def get_savepoint_restriction(metafunc): +def get_savepoint_restriction(metafunc: Any) -> int | None: svpt = metafunc.config.getoption("which_savepoint") return int(svpt) if svpt else None -def get_namelist(namelist_filename): - return Namelist.from_f90nml(f90nml.read(namelist_filename)) - - -def get_config(backend: str, communicator: Optional[Communicator]): +def get_config(backend: str, communicator: Communicator | None) -> StencilConfig: stencil_config = StencilConfig( compilation_config=CompilationConfig( backend=backend, rebuild=False, validate_args=True @@ -241,16 +254,23 @@ def get_config(backend: str, communicator: Optional[Communicator]): return stencil_config -def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backend: str): +def sequential_savepoint_cases( + metafunc: Any, data_path: Path, namelist_filename: Path, *, backend: str +) -> list[SavepointCase]: savepoint_names = get_sequential_savepoint_names(metafunc, data_path) - namelist = get_namelist(namelist_filename) + namelist = load_f90nml(namelist_filename) + grid_params = grid_params_from_f90nml(namelist) stencil_config = get_config(backend, None) - ranks = get_ranks(metafunc, namelist.layout) + ranks = get_ranks(metafunc, grid_params["layout"]) savepoint_to_replay = get_savepoint_restriction(metafunc) grid_mode = metafunc.config.getoption("grid") topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") + + # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. + use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") + return _savepoint_cases( savepoint_names, ranks, @@ -263,46 +283,49 @@ def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backen topology_mode, sort_report=sort_report, no_report=no_report, + use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) def _savepoint_cases( - savepoint_names, - ranks, - savepoint_to_replay, - stencil_config, + savepoint_names: list[str], + ranks: list[int] | range, + savepoint_to_replay: int | None, + stencil_config: StencilConfig, namelist: Namelist, backend: str, - data_path: str, + data_path: Path, grid_mode: str, - topology_mode: bool, + topology_mode: str, sort_report: str, no_report: bool, -): + use_legacy_namelist: bool, # Issue#64: tmp flag +) -> list[SavepointCase]: + grid_params = grid_params_from_f90nml(namelist) return_list = [] for rank in ranks: if grid_mode == "default": grid = Grid._make( - namelist.npx, - namelist.npy, - namelist.npz, - namelist.layout, + grid_params["npx"], + grid_params["npy"], + grid_params["npz"], + grid_params["layout"], rank, backend, ) elif grid_mode == "file" or grid_mode == "compute": - ds_grid: xr.Dataset = xr.open_dataset( - os.path.join(data_path, "Grid-Info.nc") - ).isel(savepoint=0) + ds_grid: xr.Dataset = xr.open_dataset(data_path / "Grid-Info.nc").isel( + savepoint=0 + ) grid = TranslateGrid( dataset_to_dict(ds_grid.isel(rank=rank)), rank=rank, - layout=namelist.layout, + layout=grid_params["layout"], backend=backend, ).python_grid() if grid_mode == "compute": compute_grid_data( - grid, namelist, backend, namelist.layout, topology_mode + grid, grid_params, backend, grid_params["layout"], topology_mode ) else: raise NotImplementedError(f"Grid mode {grid_mode} is unknown.") @@ -312,12 +335,18 @@ def _savepoint_cases( grid_indexing=grid.grid_indexing, ) for test_name in sorted(list(savepoint_names)): + # Temporary check (Issue#64): TODO Remove check and conversion from + # f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed + if use_legacy_namelist and not isinstance(namelist, NdslNamelist): + assert isinstance(namelist, Namelist) + namelist = NdslNamelist.from_f90nml(namelist) + testobj = get_test_class_instance( test_name, grid, namelist, stencil_factory ) - n_calls = xr.open_dataset( - os.path.join(data_path, f"{test_name}-In.nc") - ).dims["savepoint"] + n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[ + "savepoint" + ] if savepoint_to_replay is not None: savepoint_iterator = range(savepoint_to_replay, savepoint_to_replay + 1) else: @@ -337,28 +366,45 @@ def _savepoint_cases( return return_list -def compute_grid_data(grid, namelist, backend, layout, topology_mode): +def compute_grid_data( + grid: Grid, + grid_params: dict, + backend: str, + layout: tuple[int, int], + topology_mode: str, +) -> None: grid.make_grid_data( - npx=namelist.npx, - npy=namelist.npy, - npz=namelist.npz, + npx=grid_params["npx"], + npy=grid_params["npy"], + npz=grid_params["npz"], communicator=get_communicator(MPIComm(), layout, topology_mode), backend=backend, ) def parallel_savepoint_cases( - metafunc, data_path, namelist_filename, mpi_rank, *, backend: str, comm -): - namelist = get_namelist(namelist_filename) + metafunc: Any, + data_path: Path, + namelist_filename: Path, + mpi_rank: int, + *, + backend: str, + comm: Comm, +) -> list[SavepointCase]: + namelist = load_f90nml(namelist_filename) + grid_params = grid_params_from_f90nml(namelist) topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") - communicator = get_communicator(comm, namelist.layout, topology_mode) + communicator = get_communicator(comm, grid_params["layout"], topology_mode) stencil_config = get_config(backend, communicator) savepoint_names = get_parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = get_savepoint_restriction(metafunc) + + # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. + use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") + return _savepoint_cases( savepoint_names, [mpi_rank], @@ -371,31 +417,35 @@ def parallel_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, + use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: Any) -> None: backend = metafunc.config.getoption("backend") - if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: + if MPI.COMM_WORLD.Get_size() > 1: if metafunc.function.__name__ == "test_parallel_savepoint": generate_parallel_stencil_tests(metafunc, backend=backend) elif metafunc.function.__name__ == "test_sequential_savepoint": generate_sequential_stencil_tests(metafunc, backend=backend) -def generate_sequential_stencil_tests(metafunc, *, backend: str): +def generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None: data_path, namelist_filename = data_path_and_namelist_filename_from_config( metafunc.config ) savepoint_cases = sequential_savepoint_cases( - metafunc, data_path, namelist_filename, backend=backend + metafunc, + data_path, + namelist_filename, + backend=backend, ) metafunc.parametrize( "case", savepoint_cases, ids=[str(item) for item in savepoint_cases] ) -def generate_parallel_stencil_tests(metafunc, *, backend: str): +def generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None: data_path, namelist_filename = data_path_and_namelist_filename_from_config( metafunc.config ) @@ -414,41 +464,41 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): ) -def get_communicator(comm, layout, topology_mode): +def get_communicator( + comm: Comm, layout: tuple[int, int], topology_mode: str +) -> Communicator: + tile_partitioner = TilePartitioner(layout) if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"): - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(comm, partitioner) - else: - partitioner = TilePartitioner(layout) - communicator = TileCommunicator(comm, partitioner) - return communicator + return CubedSphereCommunicator(comm, CubedSpherePartitioner(tile_partitioner)) + + return TileCommunicator(comm, tile_partitioner) @pytest.fixture() -def print_failures(pytestconfig): +def print_failures(pytestconfig: pytest.Config) -> str: return pytestconfig.getoption("print_failures") @pytest.fixture() -def failure_stride(pytestconfig): +def failure_stride(pytestconfig: pytest.Config) -> int: return int(pytestconfig.getoption("failure_stride")) @pytest.fixture() -def multimodal_metric(pytestconfig): +def multimodal_metric(pytestconfig: pytest.Config) -> bool: return bool(pytestconfig.getoption("multimodal_metric")) @pytest.fixture() -def grid(pytestconfig): +def grid(pytestconfig: pytest.Config) -> str: return pytestconfig.getoption("grid") @pytest.fixture() -def topology_mode(pytestconfig): +def topology_mode(pytestconfig: pytest.Config) -> str: return pytestconfig.getoption("topology_mode") @pytest.fixture() -def backend(pytestconfig): +def backend(pytestconfig: pytest.Config) -> str: return pytestconfig.getoption("backend") diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 6d4b4b38..41c7d356 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -1,5 +1,4 @@ # type: ignore -from typing import Dict, Tuple import numpy as np @@ -19,8 +18,7 @@ VerticalGridData, ) from ndsl.halo.data_transformer import QuantityHaloSpec -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer +from ndsl.initialization import QuantityFactory, SubtileGridSizer from ndsl.quantity import Quantity @@ -85,9 +83,12 @@ def __init__( rank, layout, backend, - data_fields={}, + data_fields: dict | None = None, local_indices=False, ): + if data_fields is None: + data_fields = {} + self.rank = rank self.backend = backend self.partitioner = TilePartitioner(layout) @@ -145,7 +146,7 @@ def sizer(self): ny_tile=self.npy - 1, nz=self.npz, n_halo=self.halo, - extra_dim_lengths={ + data_dimensions={ MetricTerms.LON_OR_LAT_DIM: 2, MetricTerms.TILE_DIM: 6, MetricTerms.CARTESIAN_DIM: 3, @@ -166,7 +167,7 @@ def quantity_factory(self) -> QuantityFactory: def make_quantity( self, array, - dims=[X_DIM, Y_DIM, Z_DIM], + dims=(X_DIM, Y_DIM, Z_DIM), units="Unknown", origin=None, extent=None, @@ -181,7 +182,7 @@ def quantity_dict_update( self, data_dict, varname, - dims=[X_DIM, Y_DIM, Z_DIM], + dims=(X_DIM, Y_DIM, Z_DIM), units="Unknown", ): data_dict[varname + "_quantity"] = self.quantity_wrap( @@ -191,7 +192,7 @@ def quantity_dict_update( def quantity_wrap( self, data, - dims=[X_DIM, Y_DIM, Z_DIM], + dims=(X_DIM, Y_DIM, Z_DIM), units="unknown", ): origin = self.sizer.get_origin(dims) @@ -381,11 +382,11 @@ def y3d_compute_domain_x_dict(self): } return {**self.default_domain_dict(), **horizontal_dict} - def domain_shape_full(self, *, add: Tuple[int, int, int] = (0, 0, 0)): + def domain_shape_full(self, *, add: tuple[int, int, int] = (0, 0, 0)): """Domain shape for the full array including halo points.""" return (self.nid + add[0], self.njd + add[1], self.npz + add[2]) - def domain_shape_compute(self, *, add: Tuple[int, int, int] = (0, 0, 0)): + def domain_shape_compute(self, *, add: tuple[int, int, int] = (0, 0, 0)): """Compute domain shape excluding halo points.""" return (self.nic + add[0], self.njc + add[1], self.npz + add[2]) @@ -414,11 +415,11 @@ def uvar_edge_halo(self, var): def vvar_edge_halo(self, var): return self.copy_right_edge(var, self.ie + 1, self.je + 2) - def compute_origin(self, add: Tuple[int, int, int] = (0, 0, 0)): + def compute_origin(self, add: tuple[int, int, int] = (0, 0, 0)): """Start of the compute domain (e.g. (halo, halo, 0))""" return (self.is_ + add[0], self.js + add[1], add[2]) - def full_origin(self, add: Tuple[int, int, int] = (0, 0, 0)): + def full_origin(self, add: tuple[int, int, int] = (0, 0, 0)): """Start of the full array including halo points (e.g. (0, 0, 0))""" return (self.isd + add[0], self.jsd + add[1], add[2]) @@ -440,7 +441,7 @@ def get_halo_update_spec( shape, origin, halo_points, - dims=[X_DIM, Y_DIM, Z_DIM], + dims=(X_DIM, Y_DIM, Z_DIM), ) -> QuantityHaloSpec: """Build memory specifications for the halo update.""" return self.quantity_factory.get_quantity_halo_spec( @@ -449,7 +450,7 @@ def get_halo_update_spec( ) @property - def grid_indexing(self) -> "GridIndexing": + def grid_indexing(self) -> GridIndexing: return GridIndexing( domain=tuple(int(item) for item in self.domain_shape_compute()), n_halo=self.halo, @@ -460,7 +461,7 @@ def grid_indexing(self) -> "GridIndexing": ) @property - def damping_coefficients(self) -> "DampingCoefficients": + def damping_coefficients(self) -> DampingCoefficients: if self._damping_coefficients is not None: return self._damping_coefficients self._damping_coefficients = DampingCoefficients( @@ -473,18 +474,18 @@ def damping_coefficients(self) -> "DampingCoefficients": ) return self._damping_coefficients - def set_damping_coefficients(self, damping_coefficients: "DampingCoefficients"): + def set_damping_coefficients(self, damping_coefficients: DampingCoefficients): self._damping_coefficients = damping_coefficients @property - def grid_data(self) -> "GridData": + def grid_data(self) -> GridData: if self._grid_data is not None: return self._grid_data # The translate code pads ndarray axes with zeros in certain cases, # in particular the vertical axis. Since we're deprecating those tests, # we simply "fix" those arrays here. - clipped_data: Dict[str, Quantity] = {} + clipped_data: dict[str, Quantity] = {} for name in ( "ee1", "ee2", @@ -835,7 +836,7 @@ def driver_grid_data(self) -> DriverGridData: ) return self._driver_grid_data - def set_grid_data(self, grid_data: "GridData"): + def set_grid_data(self, grid_data: GridData): self._grid_data = grid_data def make_grid_data(self, npx, npy, npz, communicator, backend): diff --git a/ndsl/stencils/testing/parallel_translate.py b/ndsl/stencils/testing/parallel_translate.py index 7df16a17..85b87020 100644 --- a/ndsl/stencils/testing/parallel_translate.py +++ b/ndsl/stencils/testing/parallel_translate.py @@ -1,17 +1,21 @@ import copy from types import SimpleNamespace -from typing import Any, Dict, List +from typing import Any import numpy as np import pytest from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS from ndsl.dsl import gt4py_utils as utils + +# TODO: Remove once ndsl.Namelist is gone (Issue#64) +from ndsl.namelist import Namelist as NdslNamelist from ndsl.quantity import Quantity from ndsl.stencils.testing.translate import ( TranslateFortranData2Py, read_serialized_data, ) +from ndsl.utils import grid_params_from_f90nml class ParallelTranslate: @@ -22,8 +26,8 @@ class ParallelTranslate: mmr_ulp = TranslateFortranData2Py.mmr_ulp compute_grid_option = False tests_grid = False - inputs: Dict[str, Any] = {} - outputs: Dict[str, Any] = {} + inputs: dict[str, Any] = {} + outputs: dict[str, Any] = {} def __init__(self, rank_grids, namelist, stencil_factory, *args, **kwargs): if len(args) > 0: @@ -59,7 +63,7 @@ def __init__(self, rank_grids, namelist, stencil_factory, *args, **kwargs): self.namelist = namelist self.skip_test = False - def state_list_from_inputs_list(self, inputs_list: List[dict]) -> list: + def state_list_from_inputs_list(self, inputs_list: list[dict]) -> list: state_list = [] for inputs in inputs_list: state_list.append(self.state_from_inputs(inputs)) @@ -100,7 +104,7 @@ def collect_input_data(self, serializer, savepoint): return input_data def outputs_from_state(self, state: dict): - return_dict: Dict[str, np.ndarray] = {} + return_dict: dict[str, np.ndarray] = {} if len(self.outputs) == 0: return return_dict for name, properties in self.outputs.items(): @@ -129,7 +133,14 @@ def rank_grids(self): @property def layout(self): - return self.namelist.layout + # TODO: Once ndsl.namelist.Namelist is gone (Issue#64), + # remove this check in favor of f90nml.namelist.Namelist + if isinstance(self.namelist, NdslNamelist): + return self.namelist.layout + + # Assumption: namelist is f90nml.namelist.Namelist + grid_params = grid_params_from_f90nml(self.namelist) + return grid_params["layout"] def compute_sequential(self, inputs_list, communicator_list): """Compute the outputs while iterating over a set of communicator diff --git a/ndsl/stencils/testing/savepoint.py b/ndsl/stencils/testing/savepoint.py index 2708011e..22263b40 100644 --- a/ndsl/stencils/testing/savepoint.py +++ b/ndsl/stencils/testing/savepoint.py @@ -1,6 +1,6 @@ import dataclasses -import os -from typing import Dict, Protocol, Union +from pathlib import Path +from typing import Any, Protocol import numpy as np import xarray as xr @@ -8,13 +8,13 @@ from ndsl.stencils.testing.grid import Grid # type: ignore -def dataset_to_dict(ds: xr.Dataset) -> Dict[str, Union[np.ndarray, float, int]]: +def dataset_to_dict(ds: xr.Dataset) -> dict[str, np.ndarray | float | int]: return { name: _process_if_scalar(array.values) for name, array in ds.data_vars.items() } -def _process_if_scalar(value: np.ndarray) -> Union[np.ndarray, float, int]: +def _process_if_scalar(value: np.ndarray) -> np.ndarray | float | int: if len(value.shape) == 0: return value.max() # trick to make sure we get the right type back else: @@ -22,7 +22,7 @@ def _process_if_scalar(value: np.ndarray) -> Union[np.ndarray, float, int]: class DataLoader: - def __init__(self, rank: int, data_path: str): + def __init__(self, rank: int, data_path: Path) -> None: self._data_path = data_path self._rank = rank @@ -31,23 +31,20 @@ def load( name: str, postfix: str = "", i_call: int = 0, - ) -> Dict[str, Union[np.ndarray, float, int]]: + ) -> dict[str, np.ndarray | float | int]: return dataset_to_dict( - xr.open_dataset(os.path.join(self._data_path, f"{name}{postfix}.nc")) + xr.open_dataset(self._data_path / f"{name}{postfix}.nc") .isel(rank=self._rank) .isel(savepoint=i_call) ) class Translate(Protocol): - def collect_input_data(self, ds: xr.Dataset) -> dict: - ... + def collect_input_data(self, ds: xr.Dataset) -> dict: ... - def compute(self, data: dict): - ... + def compute(self, data: dict) -> Any: ... - def extra_data_load(self, data_loader: DataLoader): - ... + def extra_data_load(self, data_loader: DataLoader) -> None: ... @dataclasses.dataclass @@ -57,29 +54,29 @@ class SavepointCase: """ savepoint_name: str - data_dir: str + data_dir: Path i_call: int testobj: Translate grid: Grid sort_report: str no_report: bool - def __str__(self): + def __str__(self) -> str: return f"{self.savepoint_name}-rank={self.grid.rank}-call={self.i_call}" @property def exists(self) -> bool: return ( - xr.open_dataset( - os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc") - ).sizes["rank"] + xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc").sizes[ + "rank" + ] > self.grid.rank ) @property def ds_in(self) -> xr.Dataset: return ( - xr.open_dataset(os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc")) + xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc") .isel(rank=self.grid.rank) .isel(savepoint=self.i_call) ) @@ -87,9 +84,7 @@ def ds_in(self) -> xr.Dataset: @property def ds_out(self) -> xr.Dataset: return ( - xr.open_dataset( - os.path.join(self.data_dir, f"{self.savepoint_name}-Out.nc") - ) + xr.open_dataset(self.data_dir / f"{self.savepoint_name}-Out.nc") .isel(rank=self.grid.rank) .isel(savepoint=self.i_call) ) diff --git a/ndsl/stencils/testing/serialbox_to_netcdf.py b/ndsl/stencils/testing/serialbox_to_netcdf.py index e0dac98f..e93ace01 100644 --- a/ndsl/stencils/testing/serialbox_to_netcdf.py +++ b/ndsl/stencils/testing/serialbox_to_netcdf.py @@ -1,7 +1,8 @@ import argparse +import copy import os import shutil -from typing import Any, Dict, Optional +from typing import Any import f90nml import numpy as np @@ -57,20 +58,19 @@ def get_all_savepoint_names(serializer): return savepoint_names -def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None): - if data_name: - name = data_name - else: - name = f"Generator_rank{rank}" - return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name) # type: ignore +def get_serializer( + data_path: str, rank: int, data_name: str | None = None +) -> serialbox.Serializer: + name = data_name if data_name else f"Generator_rank{rank}" + return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name) def main( data_path: str, output_path: str, merge_blocks: bool, - data_name: Optional[str] = None, -): + data_name: str | None = None, +) -> None: os.makedirs(output_path, exist_ok=True) namelist_filename_in = os.path.join(data_path, "input.nml") @@ -81,7 +81,7 @@ def main( if namelist_filename_out != namelist_filename_in: shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out) namelist = f90nml.read(namelist_filename_out) - fv_core_nml: Dict[str, Any] = namelist["fv_core_nml"] # type: ignore + fv_core_nml: dict[str, Any] = namelist["fv_core_nml"] if fv_core_nml["grid_type"] <= 3: total_ranks = 6 * fv_core_nml["layout"][0] * fv_core_nml["layout"][1] else: @@ -94,7 +94,7 @@ def main( savepoint_names = get_all_savepoint_names(serializer_0) for savepoint_name in sorted(list(savepoint_names)): - rank_list = [] + rank_list: list[dict[str, Any]] = [] names_list = list( serializer_0.fields_at_savepoint( serializer_0.get_savepoint(savepoint_name)[0] @@ -106,8 +106,17 @@ def main( serializer = get_serializer(data_path, rank, data_name) serializer_list.append(serializer) savepoints = serializer.get_savepoint(savepoint_name) - rank_data: Dict[str, Any] = {} + rank_names_list = list(serializer.fields_at_savepoint(savepoints[0])) + rank_data: dict[str, Any] = {} for name in set(names_list): + if name not in rank_names_list: + data = copy.deepcopy(rank_list[rank - 1][name][0]) + data[:] = np.nan + rank_data[name] = [data] + print( + f"Skipping {name} for rank {rank} - no data, will fill with NaN" + ) + continue rank_data[name] = [] for savepoint in savepoints: rank_data[name].append( diff --git a/ndsl/stencils/testing/temporaries.py b/ndsl/stencils/testing/temporaries.py index 45c43e4f..c42f3413 100644 --- a/ndsl/stencils/testing/temporaries.py +++ b/ndsl/stencils/testing/temporaries.py @@ -1,12 +1,11 @@ import copy -from typing import List import numpy as np from ndsl.quantity import Quantity -def copy_temporaries(obj, max_depth: int) -> dict: +def copy_temporaries(obj: object, max_depth: int) -> dict: temporaries = {} attrs = [a for a in dir(obj) if not a.startswith("__")] for attr_name in attrs: @@ -16,7 +15,7 @@ def copy_temporaries(obj, max_depth: int) -> dict: attr = None if isinstance(attr, Quantity): temporaries[attr_name] = copy.deepcopy(np.asarray(attr.data)) - elif attr.__class__.__module__.split(".")[0] in ("pyFV3"): # type: ignore + elif attr.__class__.__module__.split(".")[0] in ("pyFV3"): if max_depth > 0: sub_temporaries = copy_temporaries(attr, max_depth - 1) if len(sub_temporaries) > 0: @@ -24,13 +23,13 @@ def copy_temporaries(obj, max_depth: int) -> dict: return temporaries -def assert_same_temporaries(dict1: dict, dict2: dict): +def assert_same_temporaries(dict1: dict, dict2: dict) -> None: diffs = _assert_same_temporaries(dict1, dict2) if len(diffs) > 0: raise AssertionError(f"{len(diffs)} differing temporaries found: {diffs}") -def _assert_same_temporaries(dict1: dict, dict2: dict) -> List[str]: +def _assert_same_temporaries(dict1: dict, dict2: dict) -> list[str]: differences = [] for attr in dict1: attr1 = dict1[attr] diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 18271a4a..a350e76f 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -1,7 +1,7 @@ # type: ignore import copy import os -from typing import Any, Dict, List +from typing import Any import numpy as np import pytest @@ -68,9 +68,9 @@ def process_override(threshold_overrides, testobj, test_name, backend): for key in testobj.out_vars.keys(): if key not in testobj.ignore_near_zero_errors: testobj.ignore_near_zero_errors[key] = {} - testobj.ignore_near_zero_errors[key][ - "near_zero" - ] = float(match["all_other_near_zero"]) + testobj.ignore_near_zero_errors[key]["near_zero"] = ( + float(match["all_other_near_zero"]) + ) else: raise TypeError( @@ -79,7 +79,9 @@ def process_override(threshold_overrides, testobj, test_name, backend): if "multimodal" in match: parsed_multimodal = match["multimodal"] if "absolute_epsilon" in parsed_multimodal: - testobj.mmr_absolute_eps = float(parsed_multimodal["absolute_eps"]) + testobj.mmr_absolute_eps = float( + parsed_multimodal["absolute_epsilon"] + ) if "relative_fraction" in parsed_multimodal: testobj.mmr_relative_fraction = float( parsed_multimodal["relative_fraction"] @@ -140,7 +142,7 @@ def _get_thresholds(compute_function, input_data) -> None: @pytest.mark.sequential @pytest.mark.skipif( - MPI is not None and MPI.COMM_WORLD.Get_size() > 1, + MPI.COMM_WORLD.Get_size() > 1, reason="Running in parallel with mpi", ) def test_sequential_savepoint( @@ -155,7 +157,7 @@ def test_sequential_savepoint( xy_indices=True, ): if case.testobj is None: - pytest.xfail( + raise ValueError( f"No translate object available for savepoint {case.savepoint_name}." ) stencil_config = StencilConfig( @@ -177,7 +179,28 @@ def test_sequential_savepoint( return if not case.exists: pytest.skip(f"Data at rank {case.grid.rank} does not exist.") - input_data = dataset_to_dict(case.ds_in) + + if hasattr(case.testobj, "override_input_netcdf_name"): + import xarray as xr + + from ndsl.logging import ndsl_log + + out_data = ( + xr.open_dataset( + os.path.join( + case.data_dir, f"{case.testobj.override_input_netcdf_name}.nc" + ) + ) + .isel(rank=case.grid.rank) + .isel(savepoint=case.i_call) + ) + input_data = dataset_to_dict(out_data) + ndsl_log.warning( + f"You are loading {case.testobj.override_input_netcdf_name} as a custom input file! Here be dragons." + ) + else: + input_data = dataset_to_dict(case.ds_in) + input_names = ( case.testobj.serialnames(case.testobj.in_vars["data_vars"]) + case.testobj.in_vars["parameters"] @@ -194,9 +217,30 @@ def test_sequential_savepoint( case.testobj.extra_data_load(DataLoader(case.grid.rank, case.data_dir)) # run python version of functionality output = case.testobj.compute(input_data) - failing_names: List[str] = [] - passing_names: List[str] = [] - all_ref_data = dataset_to_dict(case.ds_out) + failing_names: list[str] = [] + passing_names: list[str] = [] + if hasattr(case.testobj, "override_output_netcdf_name"): + import xarray as xr + + from ndsl.logging import ndsl_log + + out_data = ( + xr.open_dataset( + os.path.join( + case.data_dir, f"{case.testobj.override_output_netcdf_name}.nc" + ) + ) + .isel(rank=case.grid.rank) + .isel(savepoint=case.i_call) + ) + + output_data = dataset_to_dict(out_data) + ndsl_log.warning( + f"You are loading {case.testobj.override_output_netcdf_name} as a custom output file! Here be dragons." + ) + else: + output_data = dataset_to_dict(case.ds_out) + all_ref_data = output_data ref_data_out = {} results = {} @@ -266,7 +310,7 @@ def state_from_savepoint(serializer, savepoint, name_to_std_name): properties = RESTART_PROPERTIES origin = gt_utils.origin state = {} - for name, std_name in name_to_std_name.items(): + for name, _std_name in name_to_std_name.items(): array = serializer.read(name, savepoint) extent = tuple(np.asarray(array.shape) - 2 * np.asarray(origin)) state["air_temperature"] = Quantity( @@ -293,7 +337,7 @@ def get_tile_communicator(comm, layout): @pytest.mark.parallel @pytest.mark.skipif( - MPI is None or MPI.COMM_WORLD.Get_size() == 1, + MPI.COMM_WORLD.Get_size() == 1, reason="Not running in parallel with mpi", ) def test_parallel_savepoint( @@ -322,8 +366,8 @@ def test_parallel_savepoint( ) communicator = get_communicator(mpi_comm, layout) if case.testobj is None: - pytest.xfail( - f"no translate object available for savepoint {case.savepoint_name}" + raise ValueError( + f"No translate object available for savepoint {case.savepoint_name}" ) stencil_config = StencilConfig( compilation_config=CompilationConfig(backend=backend), @@ -353,7 +397,7 @@ def test_parallel_savepoint( out_vars.update(list(case.testobj._base.out_vars.keys())) failing_names = [] passing_names = [] - ref_data: Dict[str, Any] = {} + ref_data: dict[str, Any] = {} all_ref_data = dataset_to_dict(case.ds_out) results = {} @@ -425,7 +469,7 @@ def test_parallel_savepoint( def _report_results( savepoint_name: str, rank: int, - results: Dict[str, BaseMetric], + results: dict[str, BaseMetric], ) -> None: detail_dir = f"{OUTDIR}/details" os.makedirs(detail_dir, exist_ok=True) @@ -446,10 +490,10 @@ def _report_results( def _save_datatree( testobj, # first list over rank, second list over savepoint - inputs_list: List[Dict[str, List[np.ndarray]]], - output_list: List[Dict[str, List[np.ndarray]]], - ref_data: Dict[str, List[np.ndarray]], - names: List[str], + inputs_list: list[dict[str, list[np.ndarray]]], + output_list: list[dict[str, list[np.ndarray]]], + ref_data: dict[str, list[np.ndarray]], + names: list[str], ): import xarray as xr @@ -460,6 +504,11 @@ def _save_datatree( varname = names[index] # Read in dimensions and attributes if hasattr(testobj, "outputs") and testobj.outputs != {}: + if not isinstance(testobj.outputs, dict): + raise ValueError( + f"Expecting `outputs` on translate test to be a dict, got {type(testobj.outputs)}." + " Are you overriding `self.outputs`?" + ) dims = [ dim_name + f"_{index}" for dim_name in testobj.outputs[varname]["dims"] ] @@ -504,11 +553,11 @@ def _save_datatree( def save_netcdf( testobj, # first list over rank, second list over savepoint - inputs_list: List[Dict[str, List[np.ndarray]]], - output_list: List[Dict[str, List[np.ndarray]]], - ref_data: Dict[str, List[np.ndarray]], - failing_names: List[str], - passing_names: List[str], + inputs_list: list[dict[str, list[np.ndarray]]], + output_list: list[dict[str, list[np.ndarray]]], + ref_data: dict[str, list[np.ndarray]], + failing_names: list[str], + passing_names: list[str], out_filename, ): import xarray as xr diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 0acee958..8cf78282 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -1,11 +1,11 @@ import logging -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import numpy as np +import numpy.typing as npt import ndsl.dsl.gt4py_utils as utils from ndsl.dsl.stencil import StencilFactory -from ndsl.dsl.typing import Field, Float, Int # noqa: F401 from ndsl.optional_imports import cupy as cp from ndsl.quantity import Quantity from ndsl.stencils.testing.grid import Grid # type: ignore @@ -29,15 +29,15 @@ def pad_field_in_j(field, nj: int, backend: str): def as_numpy( - value: Union[Dict[str, Any], Quantity, np.ndarray], -) -> Union[np.ndarray, Dict[str, np.ndarray]]: - def _convert(value: Union[Quantity, np.ndarray]) -> np.ndarray: + value: dict[str, Any] | Quantity | np.ndarray, +) -> np.ndarray | dict[str, np.ndarray]: + def _convert(value: Quantity | np.ndarray) -> np.ndarray: if isinstance(value, Quantity): return value.data - elif cp is not None and isinstance(value, cp.ndarray): - return cp.asnumpy(value) elif isinstance(value, np.ndarray): return value + elif cp is not None and isinstance(value, cp.ndarray): + return cp.asnumpy(value) else: raise TypeError(f"Unrecognized value type: {type(value)}") @@ -60,17 +60,23 @@ class TranslateFortranData2Py: mmr_relative_fraction = -1 mmr_ulp = -1 - def __init__(self, grid, stencil_factory: StencilFactory, origin=utils.origin): + def __init__( + self, + grid, + stencil_factory: StencilFactory, + origin=utils.origin, + skip_test: bool = False, + ): self.origin = origin self.stencil_factory = stencil_factory - self.in_vars: Dict[str, Any] = {"data_vars": {}, "parameters": []} - self.out_vars: Dict[str, Any] = {} - self.write_vars: List = [] + self.in_vars: dict[str, Any] = {"data_vars": {}, "parameters": []} + self.out_vars: dict[str, Any] = {} + self.write_vars: list = [] self.grid = grid - self.maxshape: Tuple[int, ...] = grid.domain_shape_full(add=(1, 1, 1)) + self.maxshape: tuple[int, ...] = grid.domain_shape_full(add=(1, 1, 1)) self.ordered_input_vars = None - self.ignore_near_zero_errors: Dict[str, Any] = {} - self.skip_test: bool = False + self.ignore_near_zero_errors: dict[str, Any] = {} + self.skip_test = skip_test def extra_data_load(self, data_loader: DataLoader): pass @@ -79,7 +85,7 @@ def setup(self, inputs) -> None: """Transform inputs to gt4py.storages specification (correct device, layout)""" self.make_storage_data_input_vars(inputs) - def compute_func(self, **inputs) -> Optional[dict[str, Any]]: + def compute_func(self, **inputs) -> dict[str, Any] | None: """Compute function to transform the dictionary of `inputs`. Must return a dictionary of updated variables""" raise NotImplementedError("Implement a child class compute method") @@ -100,7 +106,7 @@ def compute_from_storage(self, inputs) -> dict[str, Any]: Hypothesis: `inputs` are `gt4py.storages` - Return: Outputs in the form of a Dict[str, gt4py.storages] + Return: Outputs in the form of a dict[str, gt4py.storages] """ outputs = self.compute_func(**inputs) if outputs is not None: @@ -124,17 +130,17 @@ def make_storage_data( istart: int = 0, jstart: int = 0, kstart: int = 0, - dummy_axes: Optional[Tuple[int, int, int]] = None, + dummy_axes: tuple[int, int, int] | None = None, axis: int = 2, - names_4d: Optional[List[str]] = None, + names_4d: list[str] | None = None, read_only: bool = False, full_shape: bool = False, - ) -> "Field": + ) -> dict[str, npt.NDArray] | npt.NDArray: """Copy input data into a gt4py.storage with given shape. `array` is copied. Takes care of the device upload if necessary. - Return: Array in the form of a Dict[str, gt4py.storages] + Return: Array in the form of a dict[str, gt4py.storages] """ use_shape = list(self.maxshape) if dummy_axes: @@ -195,7 +201,10 @@ def collect_start_indices(self, datashape, varinfo): return istart, jstart, kstart def make_storage_data_input_vars( - self, inputs, storage_vars=None, dict_4d=True + self, + inputs, + storage_vars=None, + dict_4d=True, ) -> None: """From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionary to update inputs to their configured shape. @@ -286,7 +295,18 @@ def slice_output(self, inputs, out_data=None) -> dict[str, Any]: ) out[serialname] = var4d else: - slice_tuple = self.grid.slice_dict(ds, len(data_result.shape)) + # Get slice for data dimensions (after original 3D) + if len(data_result.shape) > 3: + data_dims_slice = tuple( + [slice(0, ddim_end) for ddim_end in data_result.shape[3:]] + ) + else: + data_dims_slice = () + # Slice combine the expected cartesian and data_dims + cartesian_slice = self.grid.slice_dict( + ds, min(len(data_result.shape), 3) + ) + slice_tuple = cartesian_slice + data_dims_slice out[serialname] = np.squeeze(data_result[slice_tuple]) if "kaxis" in info: out[serialname] = np.moveaxis(out[serialname], 2, info["kaxis"]) diff --git a/ndsl/stencils/tridiag.py b/ndsl/stencils/tridiag.py index e1fe50c0..87b7311e 100644 --- a/ndsl/stencils/tridiag.py +++ b/ndsl/stencils/tridiag.py @@ -10,7 +10,7 @@ def tridiag_solve( d: FloatField, x: FloatField, delta: FloatField, -): +) -> None: """ This stencil solves a square, k x k tridiagonal matrix system with coefficients a, b, and c, and vectors p and d using the Thomas algorithm: @@ -58,7 +58,7 @@ def masked_tridiag_solve( x: FloatField, delta: FloatField, mask: BoolFieldIJ, -): +) -> None: """ Same as tridiag_solve but restricted to a subset of horizontal points diff --git a/ndsl/testing/README.md b/ndsl/testing/README.md index 549d3467..74a89a8b 100644 --- a/ndsl/testing/README.md +++ b/ndsl/testing/README.md @@ -112,7 +112,7 @@ where fields other than `var1` and `var2` will use `global_value`. Stencil_name: - backend: multimodal: - absolute_eps: + absolute_epsilon: relative_fraction: ulp_threshold: ``` diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 7862904d..44b4077b 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Union +from abc import ABC, abstractmethod +from typing import Any import numpy as np import numpy.typing as npt @@ -20,7 +21,7 @@ def _fixed_width_float_2e(value: np.floating[Any]) -> str: return f"{value:.2e}" -class BaseMetric: +class BaseMetric(ABC): def __init__( self, reference_values: np.ndarray, @@ -30,17 +31,17 @@ def __init__( self.computed = np.atleast_1d(computed_values) self.check = False - def __str__(self) -> str: - ... + @abstractmethod + def __str__(self) -> str: ... - def __repr__(self) -> str: - ... + @abstractmethod + def __repr__(self) -> str: ... - def report(self, file_path: Optional[str] = None) -> List[str]: - ... + @abstractmethod + def report(self, file_path: str | None = None) -> list[str]: ... - def one_line_report(self) -> str: - ... + @abstractmethod + def one_line_report(self) -> str: ... class LegacyMetric(BaseMetric): @@ -56,7 +57,7 @@ def __init__( reference_values: np.ndarray, computed_values: np.ndarray, eps: float, - ignore_near_zero_errors: Union[dict, bool], + ignore_near_zero_errors: bool | dict, near_zero: float, ): super().__init__(reference_values, computed_values) @@ -70,21 +71,23 @@ def __init__( def _compute_errors( self, - ignore_near_zero_errors, - near_zero, + ignore_near_zero_errors: bool | dict, + near_zero: float, ) -> npt.NDArray[np.bool_]: if self.references.dtype in (np.float64, np.int64, np.float32, np.int32): - denom = self.references - denom[self.references == 0] = self.computed[self.references == 0] + # Rule number 1: Never touch the reference data! + denom = self.references.copy() + # Avoid division by 0. If reference is 0, we expect the computed value to be 0 too. + # (abs(computed - reference) / 1.0) is a good value for the error in this case. + denom[self.references == 0] = 1.0 self._calculated_metric = np.asarray( np.abs((self.computed - self.references) / denom) ) - self._calculated_metric[denom == 0] = 0.0 elif self.references.dtype in (np.bool_, bool): self._calculated_metric = np.logical_xor(self.computed, self.references) else: raise TypeError( - f"received data with unexpected dtype {self.references.dtype}" + f"Received data with unexpected dtype `{self.references.dtype}`." ) success = np.logical_or( np.logical_and(np.isnan(self.computed), np.isnan(self.references)), @@ -116,7 +119,7 @@ def one_line_report(self) -> str: else: return "❌ Numerical failures" - def report(self, file_path: Optional[str] = None) -> List[str]: + def report(self, file_path: str | None = None) -> list[str]: report = [] report.append(self.one_line_report()) if not self.check: @@ -189,17 +192,17 @@ def __repr__(self) -> str: class _Metric: - def __init__(self, value): - self._value: float = value - self.is_default = True + def __init__(self, value: float) -> None: + self._value = value + self.is_default: bool = True @property def value(self) -> float: return self._value @value.setter - def value(self, _value: float): - self._value = _value + def value(self, value: float) -> None: + self._value = value self.is_default = False @@ -227,8 +230,7 @@ def __init__( relative_fraction_override: float = -1, ulp_override: float = -1, sort_report: str = "ulp", - **kwargs, - ): + ) -> None: super().__init__(reference_values, computed_values) self.absolute_distance = np.empty_like(self.references) self.absolute_distance_metric = np.empty_like(self.references, dtype=np.bool_) @@ -320,7 +322,7 @@ def one_line_report(self) -> str: all_indices = len(self.references.flatten()) return f"❌ Numerical failures: {failed_indices}/{all_indices} failed - metric: {metric_thresholds}" - def report(self, file_path: Optional[str] = None) -> List[str]: + def report(self, file_path: str | None = None) -> list[str]: report = [] report.append(self.one_line_report()) failed_indices = np.logical_not(self.success).nonzero() diff --git a/ndsl/testing/perturbation.py b/ndsl/testing/perturbation.py index 25e42302..b97b82d1 100644 --- a/ndsl/testing/perturbation.py +++ b/ndsl/testing/perturbation.py @@ -1,9 +1,9 @@ -from typing import Mapping +from collections.abc import Mapping import numpy as np -def perturb(input: Mapping[str, np.ndarray]): +def perturb(input: Mapping[str, np.ndarray]) -> None: """ Adds roundoff-level noise to the input array in-place through multiplication. diff --git a/ndsl/types.py b/ndsl/types.py index e3461c39..57aa6b3b 100644 --- a/ndsl/types.py +++ b/ndsl/types.py @@ -1,20 +1,20 @@ import functools -from typing import Iterable, TypeVar +from collections.abc import Iterable +from typing import TypeAlias import numpy as np from typing_extensions import Protocol -Array = TypeVar("Array") +Number: TypeAlias = int | float | np.int32 | np.int64 | np.float32 | np.float64 class Allocator(Protocol): - def __call__(self, shape: Iterable[int], dtype: type): + def __call__(self, shape: Iterable[int], dtype: type) -> None: pass class NumpyModule(Protocol): - empty: Allocator zeros: Allocator ones: Allocator @@ -43,6 +43,6 @@ def asarray(self, *args, **kwargs): class AsyncRequest(Protocol): """Define the result of an over-the-network capable communication API""" - def wait(self): + def wait(self) -> None: """Block the current thread waiting for the request to be completed""" ... diff --git a/ndsl/units.py b/ndsl/units.py index 73414715..7fb441b1 100644 --- a/ndsl/units.py +++ b/ndsl/units.py @@ -1,11 +1,33 @@ +import warnings + + def ensure_equal_units(units1: str, units2: str) -> None: + warnings.warn( + "`ensure_equal_units` is unused and usage is discouraged. The function " + "will be removed in the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) if not units_are_equal(units1, units2): raise UnitsError(f"incompatible units {units1} and {units2}") def units_are_equal(units1: str, units2: str) -> bool: + warnings.warn( + "`units_are_equal` is unused and usage is discouraged. The function will " + "be removed in the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) return units1.strip() == units2.strip() class UnitsError(Exception): - pass + def __init__(self, *args) -> None: + warnings.warn( + "`UnitsError` is unused and usage is discouraged. The class will be " + "removed in the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args) diff --git a/ndsl/utils.py b/ndsl/utils.py index 36316680..37a4bbbf 100644 --- a/ndsl/utils.py +++ b/ndsl/utils.py @@ -1,6 +1,9 @@ +from collections.abc import Iterable, Sequence from enum import EnumMeta -from typing import Iterable, Sequence, Tuple, TypeVar, Union +from pathlib import Path +from typing import TypeVar +import f90nml import numpy as np import ndsl.constants as constants @@ -23,13 +26,13 @@ class MetaEnumStr(EnumMeta): - def __contains__(cls, item) -> bool: + def __contains__(cls, item: object) -> bool: return item in cls.__members__.keys() def list_by_dims( dims: Sequence[str], horizontal_list: Sequence[T], non_horizontal_value: T -) -> Tuple[T, ...]: +) -> tuple[T, ...]: """Take in a list of dimensions, a (y, x) set of values, and a value for any non-horizontal dimensions. Return a list of length len(dims) with the value for each dimension. @@ -53,12 +56,12 @@ def is_c_contiguous(array: np.ndarray) -> bool: return array.flags["C_CONTIGUOUS"] -def ensure_contiguous(maybe_array: Union[np.ndarray, None]) -> None: +def ensure_contiguous(maybe_array: np.ndarray | None) -> None: if maybe_array is not None and not is_contiguous(maybe_array): raise BufferError("dlpack: buffer is not contiguous") -def safe_assign_array(to_array: np.ndarray, from_array: np.ndarray): +def safe_assign_array(to_array: np.ndarray, from_array: np.ndarray) -> None: """Failproof assignment for array on different devices. The memory will be downloaded/uploaded from GPU if need be. @@ -78,7 +81,7 @@ def safe_assign_array(to_array: np.ndarray, from_array: np.ndarray): raise -def device_synchronize(): +def device_synchronize() -> None: """Synchronize all memory communication""" if GPU_AVAILABLE: cp.cuda.runtime.deviceSynchronize() @@ -104,10 +107,121 @@ def safe_mpi_allocate( if cp and (allocator is cp.empty or allocator is cp.zeros): original_allocator = cp.cuda.get_allocator() cp.cuda.set_allocator(cp.get_default_memory_pool().malloc) - array = allocator(shape, dtype=dtype) # type: np.ndarray + array = allocator(shape, dtype=dtype) # type: ignore # np.ndarray cp.cuda.set_allocator(original_allocator) else: - array = allocator(shape, dtype=dtype) + array = allocator(shape, dtype=dtype) # type: ignore # np.ndarray if __debug__ and cp and isinstance(array, cp.ndarray): raise RuntimeError("cupy allocation might not be MPI-safe") return array + + +######################################################## +# Helpers for loading and working with Fortran Namelists +# TODO: Consider moving these to a separate utils/namelist.py + + +DEFAULT_GRID_NML_GROUPS = ["fv_core_nml"] + + +def flatten_nml_to_dict(nml: f90nml.Namelist) -> dict: + """Returns a flattened dict version of a f90nml.namelist.Namelist + + Args: + nml: f90nml.Namelist + """ + nml_dict = dict(nml) + for name, value in nml_dict.items(): + if isinstance(value, f90nml.Namelist): + nml_dict[name] = flatten_nml_to_dict(value) + flatter_namelist = {} + for key, value in nml_dict.items(): + if isinstance(value, dict): + for subkey, subvalue in value.items(): + if subkey in flatter_namelist: + raise ValueError( + "Cannot flatten this namelist, duplicate keys: " + subkey + ) + flatter_namelist[subkey] = subvalue + else: + flatter_namelist[key] = value + return flatter_namelist + + +# TODO: Consider a more universal loader, e.g., load_config(path), +# rather than a f90nml-specific loader (See PR#246). +def load_f90nml(namelist_path: Path) -> f90nml.Namelist: + """Loads a Fortran namelist given its path and return a f90nml.Namelist + + Args: + namelist_path: Path to the Fortran namelist file + """ + return f90nml.read(namelist_path) + + +def load_f90nml_as_dict( + namelist_path: Path, + flatten: bool = True, + target_groups: list[str] | None = None, +) -> dict: + """Loads a Fortran namelist given its path and returns a + dict representation. If target_groups are specified, then + the dict is created using only those groups. + + Args: + namelist_path: Path to the Fortran namelist file + flatten: If True, flattens the loaded namelist (without groups) before + returning it. (Default: True) Otherwise, it returns the f90nml.Namelist + dict representation. + target_groups: If 'None' is specified, then all groups are + considered. (Default: None) Otherwise, only parameters + from the specified groups are considered. + """ + nml = load_f90nml(namelist_path) + return f90nml_as_dict(nml, flatten=flatten, target_groups=target_groups) + + +def f90nml_as_dict( + nml: f90nml.Namelist, + flatten: bool = True, + target_groups: list[str] | None = None, +) -> dict: + """Uses a f90nml.Namelist and returns a dict representation. + If target_groups are specified, then the dict is created using only those + groups. The return dicts can be flattened further to remove the group + information or keep the group information. + + Args: + nml: f90nml.Namelist + flatten: If True, flattens the loaded namelist (without groups) before + returning it. (Default: True) Otherwise, it returns the f90nml.Namelist + dict representation. + target_groups: If 'None' is specified, then all groups are + considered. (Default: None) Otherwise, only parameters + from the specified groups are considered. + """ + if target_groups is not None: + extracted_groups = f90nml.Namelist() + for group in target_groups: + if group in nml.keys(): + extracted_groups[group] = nml[group] + else: + extracted_groups = nml + + if flatten: + return flatten_nml_to_dict(extracted_groups) + return extracted_groups.todict() + + +def grid_params_from_f90nml(nml: f90nml.Namelist) -> dict: + """Uses a f90nml.Namelist and returns a dict representation + of parameters useful for grid generation. The return dict + will be flattened with key-value pairs from the nml's + DEFAULT_GRID_NML_GROUPS. + + Args: + nml: f90nml.Namelist + """ + # TODO: Consider returning a {Cartesian,CubeSphere}GridParameters class + # rather than a dict (See PR#246). + return f90nml_as_dict(nml, flatten=True, target_groups=DEFAULT_GRID_NML_GROUPS) diff --git a/ndsl/viz/cube_sphere.py b/ndsl/viz/cube_sphere.py index 09018d1f..c6e5e3a3 100644 --- a/ndsl/viz/cube_sphere.py +++ b/ndsl/viz/cube_sphere.py @@ -26,6 +26,9 @@ def plot_cube_sphere( lon = comm.gather(grid_data.lon) if comm.rank == 0: + # We are on the root rank so comm.gather() did gather. This is just to make mypy happy. + assert data is not None and lat is not None and lon is not None + fig, ax = plt.subplots(1, 1, subplot_kw={"projection": ccrs.Robinson()}) pcolormesh_cube( lat.view[:] * 180.0 / np.pi, diff --git a/ndsl/viz/fv3/_plot_cube.py b/ndsl/viz/fv3/_plot_cube.py index 8942d494..4ebde49f 100644 --- a/ndsl/viz/fv3/_plot_cube.py +++ b/ndsl/viz/fv3/_plot_cube.py @@ -58,15 +58,15 @@ def plot_cube( grid_metadata: GridMetadata = WRAPPER_GRID_METADATA, plotting_function: str = "pcolormesh", ax: plt.axes = None, - row: str = None, - col: str = None, - col_wrap: int = None, - projection: ccrs.Projection = None, + row: str | None = None, + col: str | None = None, + col_wrap: int | None = None, + projection: ccrs.Projection | None = None, colorbar: bool = True, cmap_percentiles_lim: bool = True, - cbar_label: str = None, + cbar_label: str | None = None, coastlines: bool = True, - coastlines_kwargs: dict = None, + coastlines_kwargs: dict | None = None, **kwargs, ): """Plots an xr.DataArray containing tiled cubed sphere gridded data @@ -93,12 +93,12 @@ def plot_cube( ax: Axes onto which the map should be plotted; must be created with a cartopy projection argument. If not supplied, axes are generated - with a projection. If ax is suppled, faceting is disabled. + with a projection. If ax is supplied, faceting is disabled. row: - Name of diemnsion to be faceted along subplot rows. Must not be a + Name of dimension to be faceted along subplot rows. Must not be a tile, lat, or lon dimension. Defaults to no row facets. col: - Name of diemnsion to be faceted along subplot columns. Must not be + Name of dimension to be faceted along subplot columns. Must not be a tile, lat, or lon dimension. Defaults to no column facets. col_wrap: If only one of `col`, `row` is specified, number of columns to plot @@ -124,7 +124,7 @@ def plot_cube( figure (plt.Figure): matplotlib figure object onto which axes grid is created axes (np.ndarray): - Array of `plt.axes` objects assocated with map subplots if faceting; + Array of `plt.axes` objects associated with map subplots if faceting; otherwise array containing single axes object. handles (list): List or nested list of matplotlib object handles associated with @@ -134,7 +134,7 @@ def plot_cube( arg is True, else None. facet_grid (xarray.plot.facetgrid): xarray plotting facetgrid for multi-axes case. In single-axes case, - retunrs None. + returns None. Example: # plot diag winds at two times @@ -205,7 +205,7 @@ def plot_cube( fig, ax = plt.subplots(1, 1, subplot_kw={"projection": projection}) else: fig = ax.figure - handle = _plot_func_short(array, ax=ax) + handle = _plot_func_short(array, ax=ax) # type: ignore axes = np.array(ax) handles = [handle] facet_grid = None diff --git a/ndsl/viz/fv3/_plot_diagnostics.py b/ndsl/viz/fv3/_plot_diagnostics.py index 9c102759..e54f709e 100644 --- a/ndsl/viz/fv3/_plot_diagnostics.py +++ b/ndsl/viz/fv3/_plot_diagnostics.py @@ -8,6 +8,7 @@ """ + import os import matplotlib.pyplot as plt diff --git a/ndsl/viz/fv3/_plot_helpers.py b/ndsl/viz/fv3/_plot_helpers.py index 75da6983..7c64f6e1 100644 --- a/ndsl/viz/fv3/_plot_helpers.py +++ b/ndsl/viz/fv3/_plot_helpers.py @@ -1,5 +1,4 @@ import textwrap -from typing import Optional, Tuple import numpy as np @@ -60,7 +59,11 @@ def _min_max_from_percentiles(x, min_percentile=2, max_percentile=98): def _infer_color_limits( - xmin: float, xmax: float, vmin: float = None, vmax: float = None, cmap: str = None + xmin: float, + xmax: float, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | None = None, ): """ "auto-magical" handling of color limits and colormap if not supplied by user @@ -100,22 +103,24 @@ def _infer_color_limits( else: vmin, vmax = xmin, xmax cmap = "viridis" if not cmap else cmap - elif vmin is None: + elif vmin is None and vmax is not None: if xmin < 0 and vmax > 0: vmin = -vmax cmap = "RdBu_r" if not cmap else cmap else: vmin = xmin cmap = "viridis" if not cmap else cmap - elif vmax is None: + elif vmax is None and vmin is not None: if xmax > 0 and vmin < 0: vmax = -vmin cmap = "RdBu_r" if not cmap else cmap else: vmax = xmax cmap = "viridis" if not cmap else cmap - elif not cmap: + elif not cmap and vmin is not None and vmax is not None: cmap = "RdBu_r" if vmin == -vmax else "viridis" + else: + raise ValueError("Inconsistent arguments supplied.") return vmin, vmax, cmap @@ -147,11 +152,11 @@ def _get_var_label(attrs: dict, var_name: str, max_line_length: int = 30): def infer_cmap_params( data: np.ndarray, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - cmap: Optional[str] = None, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | None = None, robust: bool = False, -) -> Tuple[float, float, str]: +) -> tuple[float, float, str]: """Determine useful colorbar limits and cmap for given data. Args: diff --git a/ndsl/viz/fv3/_timestep_histograms.py b/ndsl/viz/fv3/_timestep_histograms.py index 38985478..7743fa82 100644 --- a/ndsl/viz/fv3/_timestep_histograms.py +++ b/ndsl/viz/fv3/_timestep_histograms.py @@ -1,5 +1,5 @@ import datetime -from typing import Sequence, Union +from collections.abc import Sequence import matplotlib.pyplot as plt import numpy as np @@ -8,7 +8,7 @@ def plot_daily_and_hourly_hist( - time_list: Sequence[Union[datetime.datetime, np.datetime64]], + time_list: Sequence[datetime.datetime | np.datetime64], ) -> plt.figure: """Given a sequence of datetimes (anything that can be handled by pandas) create and return 2-subplot figure with histograms of daily and hourly counts.""" diff --git a/ndsl/viz/fv3/grid_metadata.py b/ndsl/viz/fv3/grid_metadata.py index 2171360e..f44df8be 100644 --- a/ndsl/viz/fv3/grid_metadata.py +++ b/ndsl/viz/fv3/grid_metadata.py @@ -5,8 +5,7 @@ class GridMetadata(abc.ABC): @property @abc.abstractmethod - def coord_vars(self) -> dict: - ... + def coord_vars(self) -> dict: ... @dataclasses.dataclass diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..cf8e0a69 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,122 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = ["setuptools >= 80", "setuptools-scm>=8"] + +[project] +authors = [{name = "NOAA/NASA"}] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Atmospheric Science", + "Private :: Do Not Upload", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11" +] +dynamic = ["dependencies", "version"] +license = "Apache-2.0" +license-files = ["LICENSE.txt", "ndsl/viz/fv3/README.md"] +name = "ndsl" +readme = "README.md" +requires-python = ">=3.11,<3.12" + +[project.optional-dependencies] +demos = ["ipython", "ipykernel"] +dev = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "pre-commit", "flake8-pyproject", "build"] +docs = ["mkdocs-material", "mkdocstrings[python]"] +extras = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "ndsl[dev]"] +test = ["pytest", "coverage"] + +[project.scripts] +ndsl-serialbox_to_netcdf = "ndsl.stencils.testing.serialbox_to_netcdf:entry_point" + +[project.urls] +Repository = "https://github.com/NOAA-GFDL/NDSL" + +[tool.aliases] + +[tool.coverage.run] +branch = true +omit = ["tests/*", "*gt_cache*", ".dacecache*", "external/*", "__init__.py"] +parallel = true +source_pkgs = ["ndsl"] + +[tool.flake8] +exclude = ["docs"] +extend-ignore = [ + # Redundant with black formatter + "E302", # Blank lines before function + "E501", # Line too long + "E704", # One statement per line + "W293", # Blank line contains whitespace + "W503", # Linebreak before binary operator + # Clashes with NDSL/stencil syntax + "E203", # Space before ":", e.g. `field[1, :]` + "F841", # Local variable assigned but unused + # other + "B019" # We'd like to keep using functools.lru_cache +] +max-line-length = 88 + +[tool.isort] +default_section = "THIRDPARTY" +known_third_party = "f90nml,pytest,xarray,numpy,mpi4py,gt4py,dace" +lines_after_imports = 2 +profile = "black" +sections = "FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER" + +[tool.mypy] +disallow_incomplete_defs = true +disallow_untyped_defs = true +# TODO: understand what's wrong with imports and fix +ignore_missing_imports = true +strict_equality = true +strict_optional = true +warn_redundant_casts = true +warn_unreachable = true +warn_unused_configs = true +warn_unused_ignores = true + +[[tool.mypy.overrides]] +disallow_incomplete_defs = false +disallow_untyped_defs = false +module = [ + # TODO: fix issues and remove this override + "ndsl.grid.*", + "ndsl.viz.*", + # Unclear what happens to this module in the future + "ndsl.dsl.gt4py_utils", + # We don't fix legacy (unless we really need it) + "ndsl.restart._legacy_restart", + # These files are deprecated + "ndsl.exceptions", + "ndsl.stencils.corners", + "ndsl.units", + # translate test system is "beyond repair" + "ndsl.stencils.testing.translate", + "ndsl.stencils.testing.parallel_translate" +] + +[[tool.mypy.overrides]] +disallow_untyped_defs = false +module = [ + # Unclear what the state of these modules is + "ndsl.monitor.*", + "ndsl.stencils.testing.serialbox_to_netcdf", + # deprecated + "ndsl.filesystem", + # TODO: fix issue and remove this override + "ndsl.types", + "ndsl.comm._boundary_utils", + "ndsl.quantity.bounds" +] + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +include = ["ndsl", "ndsl.*"] + +[tool.setuptools_scm] +local_scheme = "dirty-tag" +version_scheme = "guess-next-dev" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index ce80d954..00000000 --- a/setup.cfg +++ /dev/null @@ -1,41 +0,0 @@ -[flake8] -exclude = docs -ignore = E203,E501,W293,W503,E302,E203,F841 -max-line-length = 88 - -[aliases] - -[tool:isort] -line_length = 88 -force_grid_wrap = 0 -include_trailing_comma = true -multi_line_output = 3 -use_parentheses = true -lines_after_imports = 2 -default_section = THIRDPARTY -sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER -known_third_party = f90nml,pytest,xarray,numpy,mpi4py,gt4py,dace - -[mypy] -ignore_missing_imports = True -follow_imports = normal -namespace_packages = True -strict_optional = False -warn_unreachable = True -explicit_package_bases = True - -[coverage:run] -parallel = true -branch = true -omit = - tests/* - *gt_cache* - .dacecache* - external/* - __init__.py -source_pkgs = ndsl - -[metadata] -# Include the license file in the generated wheel, see -# https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file -license_files = LICENSE.txt diff --git a/setup.py b/setup.py index 453c4e07..1719226f 100644 --- a/setup.py +++ b/setup.py @@ -1,70 +1,31 @@ -import os from pathlib import Path -from typing import List -from setuptools import find_namespace_packages, setup +from setuptools import setup def local_pkg(name: str, relative_path: str) -> str: """Returns an absolute path to a local package.""" - path = f"{name} @ file://{Path(os.path.abspath(__file__)).parent / relative_path}" - return path + return f"{name} @ file://{Path(__file__).absolute().parent / relative_path}" -docs_requirements = ["mkdocs-material"] -demos_requirements = ["ipython", "ipykernel"] -test_requirements = ["pytest", "pytest-subtests", "coverage"] - -develop_requirements = test_requirements + docs_requirements + ["pre-commit"] - -extras_requires = { - "demos": demos_requirements, - "develop": develop_requirements, - "docs": docs_requirements, - "test": test_requirements, -} - -requirements: List[str] = [ +requirements: list[str] = [ local_pkg("gt4py", "external/gt4py"), local_pkg("dace", "external/dace"), - "mpi4py==3.1.5", + "mpi4py>=4.1", "cftime", "xarray>=2025.01.2", # datatree + fixes "f90nml>=1.1.0", "fsspec", - "netcdf4==1.7.1", + "netcdf4==1.7.2", "scipy", # restart capacities only "h5netcdf", # for xarray "dask", # for xarray "numpy==1.26.4", "matplotlib", # for plotting in boilerplate "cartopy", # for plotting in ndsl.viz + "pytest-subtests", # for translate tests + "dacite", # for state ] -setup( - author="NOAA/NASA", - python_requires=">=3.11", - classifiers=[ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Natural Language :: English", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.11", - ], - install_requires=requirements, - extras_require=extras_requires, - name="ndsl", - license="Apache 2.0 license", - packages=find_namespace_packages(include=["ndsl", "ndsl.*"]), - include_package_data=True, - url="https://github.com/NOAA-GFDL/NDSL", - version="2025.05.00", - zip_safe=False, - entry_points={ - "console_scripts": [ - "ndsl-serialbox_to_netcdf = ndsl.stencils.testing.serialbox_to_netcdf:entry_point", - ] - }, -) +setup(install_requires=requirements) diff --git a/tests/checkpointer/__init__.py b/tests/checkpointer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 761e773f..0aeb2f84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,6 @@ import numpy as np import pytest - -try: - import ndsl.dsl # noqa: F401 -except ModuleNotFoundError as error: - error.msg = f"NDSL cannot be loaded because {error.msg}" - raise error - from ndsl.optional_imports import cupy diff --git a/tests/data/eta/README.md b/tests/data/eta/README.md new file mode 100644 index 00000000..1e6ed4b1 --- /dev/null +++ b/tests/data/eta/README.md @@ -0,0 +1,17 @@ +# Data for hybrid pressure calculations at each vertical level + +NDSL requires the coefficients necessary for calculation of the pressure at each k-level to be supplied during a run. The equation for calculating these pressures takes the form: + +$$\\P_k = a_k + b_k * P_s\\$$ + +where $P_k$ (also $\eta$) is the pressure at the k-level, $a_k$ and $b_k$, the needed coefficients, and $P_s$ the surface level pressure. These coefficients must be supplied in a NetCDF file format, and in a monotonically increasing format. + +# Testing of $\eta$ calculation coefficients data file input + +Current unit tests check that an input file will set the values correctly, check that a clear error is thrown when no file is supplied, and that the data contained within is monotonically increasing. To run these tests contained in [tests/grid/test_eta.py](../../../tests/grid/test_eta.py) after a successful installation of NDSL, run: + +```shell +pytest tests/grid/test_eta.py +``` + +from the top level of the repository. Sample data files for these tests are contained in the directory [tests/data/eta](../../../tests/data/eta/) diff --git a/tests/data/eta/eta79.nc b/tests/data/eta/eta79.nc new file mode 100644 index 00000000..d0ffc9a0 Binary files /dev/null and b/tests/data/eta/eta79.nc differ diff --git a/tests/data/eta/eta91.nc b/tests/data/eta/eta91.nc new file mode 100644 index 00000000..10efe21f Binary files /dev/null and b/tests/data/eta/eta91.nc differ diff --git a/tests/data/eta/non_mono_eta79.nc b/tests/data/eta/non_mono_eta79.nc new file mode 100644 index 00000000..3836ceab Binary files /dev/null and b/tests/data/eta/non_mono_eta79.nc differ diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py new file mode 100644 index 00000000..d0c9f1b0 --- /dev/null +++ b/tests/dsl/orchestration/test_call.py @@ -0,0 +1,42 @@ +from ndsl import QuantityFactory, StencilFactory +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.dace.orchestration import orchestrate +from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval + + +def _stencil(out: Field[float]): + with computation(PARALLEL), interval(...): + out = out + 1 + + +class OrchestratedProgram: + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo(_stencil, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, out_qty): + self.stencil(out_qty) + + +def test_memory_reallocation(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 2, 0 + ) + code = OrchestratedProgram(stencil_factory, quantity_factory) + qty_A = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "A") + qty_B = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "B") + + code(qty_A) + assert (qty_A.field[0, 0, :] == 2).all() + + code(qty_A) + assert (qty_A.field[0, 0, :] == 3).all() + + code(qty_B) + assert (qty_A.field[0, 0, :] == 3).all() + assert (qty_B.field[0, 0, :] == 2).all() diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index d550c123..5b8ef1d7 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -1,5 +1,7 @@ import os import shutil +import sys +from pathlib import Path import pytest from gt4py.cartesian import config as gt_config @@ -16,30 +18,18 @@ from ndsl.comm.mpi import MPI from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +from ndsl.dsl.stencil import CompareToNumpyStencil, FrozenStencil +from tests.dsl import utils -def _make_storage( - func, - grid_indexing, - stencil_config: StencilConfig, - *, - dtype=float, - aligned_index=(0, 0, 0), -): - return func( - backend=stencil_config.compilation_config.backend, - shape=grid_indexing.domain, - dtype=dtype, - aligned_index=aligned_index, - ) - - -def _stencil(inp: Field[float], out: Field[float], scalar: float): +def _stencil(inp: Field[float], out: Field[float]): with computation(PARALLEL), interval(...): out = inp -def _build_stencil(backend, orchestrated: DaCeOrchestration): +def _build_stencil( + backend: str, orchestrated: DaCeOrchestration +) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: # Make stencil and verify it ran grid_indexing = GridIndexing( domain=(5, 5, 5), @@ -65,115 +55,143 @@ def _build_stencil(backend, orchestrated: DaCeOrchestration): class OrchestratedProgram: - def __init__(self, backend, orchestration): + def __init__(self, backend, orchestration: DaCeOrchestration): self.stencil, grid_indexing, stencil_config = _build_stencil( backend, orchestration ) orchestrate(obj=self, config=stencil_config.dace_config) - self.inp = _make_storage(ones, grid_indexing, stencil_config, dtype=float) - self.out = _make_storage(empty, grid_indexing, stencil_config, dtype=float) + self.inp = utils.make_storage(ones, grid_indexing, stencil_config, dtype=float) + self.out = utils.make_storage(empty, grid_indexing, stencil_config, dtype=float) def __call__(self): - self.stencil(self.inp, self.out, self.inp[0, 0, 0]) + self.stencil(self.inp, self.out) -@pytest.mark.parametrize( - "backend", - [ - pytest.param("dace:cpu"), - ], -) @pytest.mark.skipif( - MPI is not None, reason="relocatibility checked with a one-rank setup" + MPI.COMM_WORLD.Get_size() > 1, reason="Relocatability checked with a one-rank setup" ) -def test_relocatability_orchestration(backend): - original_root_directory = gt_config.cache_settings["root_path"] - working_dir = str(os.getcwd()) - +def test_relocatability_orchestration() -> None: # Compile on default - p0 = OrchestratedProgram(backend, DaCeOrchestration.BuildAndRun) + p0 = OrchestratedProgram("dace:cpu", DaCeOrchestration.BuildAndRun) p0() - assert os.path.exists( - f"{working_dir}/.gt_cache_FV3_A/dacecache/" - "test_caches_OrchestratedProgam___call__", - ) or os.path.exists( - f"{working_dir}/.gt_cache_FV3_A/dacecache/OrchestratedProgam___call__", + + expected_cache_dir = ( + Path.cwd() + / ".gt_cache_FV3_A" + / "dacecache" + / "tests_dsl_test_caches_OrchestratedProgram___call__" ) + assert expected_cache_dir.exists() - # Compile in another directory - custom_path = f"{working_dir}/.my_cache_path" - gt_config.cache_settings["root_path"] = custom_path +@pytest.mark.skipif( + MPI.COMM_WORLD.Get_size() > 1, reason="Relocatability checked with a one-rank setup" +) +def test_relocatability_orchestration_tmpdir(tmpdir) -> None: + gt_config.cache_settings["root_path"] = tmpdir + + # Compile in temporary directory that is only available in this test session. + backend = "dace:cpu" p1 = OrchestratedProgram(backend, DaCeOrchestration.BuildAndRun) p1() - assert os.path.exists( - f"{custom_path}/.gt_cache_FV3_A/dacecache/" - "test_caches_OrchestratedProgam___call__", - ) or os.path.exists( - f"{working_dir}/.gt_cache_FV3_A/dacecache/OrchestratedProgam___call__", + + expected_cache_dir = ( + tmpdir + / ".gt_cache_FV3_A" + / "dacecache" + / "tests_dsl_test_caches_OrchestratedProgram___call__" ) + assert expected_cache_dir.exists() - # Check relocability by copying the second cache directory, + # Check relocatability by copying the second cache directory, # changing the path of gt_config.cache_settings and trying to Run on it - relocated_path = f"{working_dir}/.my_relocated_cache_path" - shutil.copytree(custom_path, relocated_path, dirs_exist_ok=True) + relocated_path = tmpdir / ".my_relocated_cache_path" + shutil.copytree(tmpdir, relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path p2 = OrchestratedProgram(backend, DaCeOrchestration.Run) p2() # Generate a file exists error to check for bad path - bogus_path = "./nope/notatall/nothappening" + bogus_path = "./nope/not_at_all/not_happening" gt_config.cache_settings["root_path"] = bogus_path with pytest.raises(RuntimeError): OrchestratedProgram(backend, DaCeOrchestration.Run) - # Restore cache settings - gt_config.cache_settings["root_path"] = original_root_directory - -@pytest.mark.parametrize( - "backend", - [ - pytest.param("dace:cpu"), - ], -) @pytest.mark.skipif( - MPI is not None, reason="relocatibility checked with a one-rank setup" + MPI.COMM_WORLD.Get_size() > 1, reason="Relocatability checked with a one-rank setup" ) -def test_relocatability(backend: str): - # Restore original dir name +def test_relocatability() -> None: gt_config.cache_settings["dir_name"] = os.environ.get( "GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}" ) - - backend_sanitized = backend.replace(":", "") + gt_config.cache_settings["root_path"] = Path.cwd() # Compile on default + backend = "dace:cpu" p0 = OrchestratedProgram(backend, DaCeOrchestration.Python) p0() - assert os.path.exists( - f"./.gt_cache_000000/py38_1013/{backend_sanitized}/test_caches/_stencil/" + + backend_sanitized = backend.replace(":", "") + python_version = f"py{sys.version_info[0]}{sys.version_info[1]}" + expected_cache_path = ( + Path.cwd() + / ".gt_cache_000000" + / f"{python_version}_1013" + / f"{backend_sanitized}" + / "tests" + / "dsl" + / "test_caches" + / "_stencil" ) + assert expected_cache_path.exists() - # Compile in another directory - custom_path = "./.my_cache_path" - gt_config.cache_settings["root_path"] = custom_path +@pytest.mark.skipif( + MPI.COMM_WORLD.Get_size() > 1, reason="Relocatability checked with a one-rank setup" +) +def test_relocatability_tmpdir(tmpdir) -> None: + gt_config.cache_settings["dir_name"] = os.environ.get( + "GT_CACHE_DIR_NAME", f".gt_cache_{MPI.COMM_WORLD.Get_rank():06}" + ) + gt_config.cache_settings["root_path"] = tmpdir + + # Compile in another directory + backend = "dace:cpu" p1 = OrchestratedProgram(backend, DaCeOrchestration.Python) p1() - assert os.path.exists( - f"{custom_path}/.gt_cache_000000/py38_1013/{backend_sanitized}" - "/test_caches/_stencil/" + + backend_sanitized = backend.replace(":", "") + python_version = f"py{sys.version_info[0]}{sys.version_info[1]}" + expected_cache_path = ( + tmpdir + / ".gt_cache_000000" + / f"{python_version}_1013" + / f"{backend_sanitized}" + / "tests" + / "dsl" + / "test_caches" + / "_stencil" ) + assert expected_cache_path.exists() - # Check relocability by copying the second cache directory, + # Check relocatability by copying the first cache directory, # changing the path of gt_config.cache_settings and trying to Run on it - relocated_path = "./.my_relocated_cache_path" - shutil.copytree("./.gt_cache_000000", relocated_path, dirs_exist_ok=True) + relocated_path = tmpdir / ".my_relocated_cache_path" + shutil.copytree(tmpdir / ".gt_cache_000000", relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path + p2 = OrchestratedProgram(backend, DaCeOrchestration.Python) p2() - assert os.path.exists( - f"{relocated_path}/.gt_cache_000000/py38_1013/{backend_sanitized}" - "/test_caches/_stencil/" + + relocated_cache_path = ( + relocated_path + / ".gt_cache_000000" + / f"{python_version}_1013" + / f"{backend_sanitized}" + / "tests" + / "dsl" + / "test_caches" + / "_stencil" ) + assert relocated_cache_path.exists() diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index fa323b06..326eb222 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -30,7 +30,7 @@ def test_safety_checks(): ) def test_check_communicator_valid( size: int, use_minimal_caching: bool, run_mode: RunMode -): +) -> None: partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6))))) ) @@ -50,7 +50,7 @@ def test_check_communicator_valid( ) def test_check_communicator_invalid( nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode -): +) -> None: partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny))) comm = NullComm(rank=0, total_ranks=nx * ny * 6) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) @@ -61,7 +61,7 @@ def test_check_communicator_invalid( config.check_communicator(cubed_sphere_comm) -def test_get_decomposition_info_from_no_comm(): +def test_get_decomposition_info_from_no_comm() -> None: config = CompilationConfig() ( computed_rank, @@ -86,7 +86,7 @@ def test_get_decomposition_info_from_no_comm(): ) def test_get_decomposition_info_from_comm( rank: int, size: int, is_compiling: bool, equivalent: int -): +) -> None: partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6)))) ) @@ -123,8 +123,8 @@ def test_get_decomposition_info_from_comm( ], ) def test_determine_compiling_equivalent( - rank, size, minimal_caching, run_mode, equivalent -): + rank: int, size: int, minimal_caching: bool, run_mode: RunMode, equivalent: int +) -> None: config = CompilationConfig(use_minimal_caching=minimal_caching, run_mode=run_mode) partitioner = CubedSpherePartitioner( TilePartitioner((sqrt(size / 6), sqrt(size / 6))) @@ -138,7 +138,7 @@ def test_determine_compiling_equivalent( ) -def test_as_dict(): +def test_as_dict() -> None: config = CompilationConfig() asdict = config.as_dict() assert asdict["backend"] == "numpy" @@ -151,7 +151,7 @@ def test_as_dict(): assert len(asdict) == 7 -def test_from_dict(): +def test_from_dict() -> None: specification_dict = {} config = CompilationConfig.from_dict(specification_dict) assert config.backend == "numpy" diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index c044cb16..e89b69c2 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -1,6 +1,7 @@ import unittest.mock from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.dace.dace_config import _determine_compiling_ranks from ndsl.dsl.dace.orchestration import orchestrate, orchestrate_function @@ -11,8 +12,8 @@ """ -def test_orchestrate_function_calls_dace(): - def foo(): +def test_orchestrate_function_calls_dace() -> None: + def foo() -> None: pass dace_config = DaceConfig( @@ -29,8 +30,8 @@ def foo(): assert mock_call_sdfg.call_args.args[0].f == foo -def test_orchestrate_function_does_not_call_dace(): - def foo(): +def test_orchestrate_function_does_not_call_dace() -> None: + def foo() -> None: pass dace_config = DaceConfig( @@ -46,7 +47,7 @@ def foo(): assert not mock_call_sdfg.called -def test_orchestrate_calls_dace(): +def test_orchestrate_calls_dace() -> None: dace_config = DaceConfig( communicator=None, backend="gtc:dace", @@ -54,10 +55,10 @@ def test_orchestrate_calls_dace(): ) class A: - def __init__(self): + def __init__(self) -> None: orchestrate(obj=self, config=dace_config, method_to_orchestrate="foo") - def foo(self): + def foo(self) -> None: pass with unittest.mock.patch( @@ -68,7 +69,7 @@ def foo(self): assert mock_call_sdfg.called -def test_orchestrate_does_not_call_dace(): +def test_orchestrate_does_not_call_dace() -> None: dace_config = DaceConfig( communicator=None, backend="gtc:dace", @@ -76,10 +77,10 @@ def test_orchestrate_does_not_call_dace(): ) class A: - def __init__(self): + def __init__(self) -> None: orchestrate(obj=self, config=dace_config, method_to_orchestrate="foo") - def foo(self): + def foo(self) -> None: pass with unittest.mock.patch( @@ -90,14 +91,14 @@ def foo(self): assert not mock_call_sdfg.called -def test_orchestrate_distributed_build(): +def test_orchestrate_distributed_build() -> None: dummy_dace_config = DaceConfig( communicator=None, backend="gtc:dace", orchestration=DaCeOrchestration.BuildAndRun, ) - def _does_compile(rank, partitioner) -> bool: + def _does_compile(rank: int, partitioner: Partitioner) -> bool: dummy_dace_config.layout = partitioner.layout dummy_dace_config.rank_size = partitioner.layout[0] * partitioner.layout[1] * 6 dummy_dace_config.my_rank = rank diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index 22b840cb..6ffb0c95 100644 --- a/tests/dsl/test_skip_passes.py +++ b/tests/dsl/test_skip_passes.py @@ -23,7 +23,7 @@ def stencil_definition(a: FloatField): a = 0.0 -def test_skip_passes_becomes_oir_pipeline(): +def test_skip_passes_becomes_oir_pipeline() -> None: backend = "numpy" dace_config = DaceConfig(None, backend) config = StencilConfig( diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 5348f346..7130853e 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,26 +1,32 @@ -from gt4py.storage import empty, ones +from unittest.mock import MagicMock, patch -from ndsl import CompilationConfig, GridIndexing, StencilConfig, StencilFactory -from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +import numpy as np +import pytest +from gt4py.storage import empty, ones +from ndsl import ( + CompilationConfig, + FrozenStencil, + GridIndexing, + StencilConfig, + StencilFactory, +) +from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, computation, interval +from ndsl.dsl.typing import ( + BoolFieldIJ, + FloatField, + FloatFieldIJ, + FloatFieldIJ32, + FloatFieldIJ64, + IntFieldIJ, + IntFieldIJ32, + IntFieldIJ64, +) +from ndsl.quantity import Quantity +from tests.dsl import utils -def _make_storage( - func, - grid_indexing, - stencil_config: StencilConfig, - *, - dtype=float, - aligned_index=(0, 0, 0), -): - return func( - backend=stencil_config.compilation_config.backend, - shape=grid_indexing.domain, - dtype=dtype, - aligned_index=aligned_index, - ) - -def test_timing_collector(): +def test_timing_collector() -> None: grid_indexing = GridIndexing( domain=(5, 5, 5), n_halo=2, @@ -46,9 +52,107 @@ def func(inp: Field[float], out: Field[float]): build_report = stencil_factory.build_report(key="parse_time") assert "func" in build_report - inp = _make_storage(ones, grid_indexing, stencil_config, dtype=float) - out = _make_storage(empty, grid_indexing, stencil_config, dtype=float) + inp = utils.make_storage(ones, grid_indexing, stencil_config, dtype=float) + out = utils.make_storage(empty, grid_indexing, stencil_config, dtype=float) test(inp, out) exec_report = stencil_factory.exec_report() assert "func" in exec_report + + +@pytest.mark.parametrize("klevel,expected_origin_k", [(None, 0), (1, 1), (30, 30)]) +def test_grid_indexing_get_2d_compute_origin_domain( + klevel: int | None, + expected_origin_k: int, +): + indexing = GridIndexing( + domain=(12, 12, 79), + n_halo=3, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, + ) + + if klevel is None: + origin, domain = indexing.get_2d_compute_origin_domain() + else: + origin, domain = indexing.get_2d_compute_origin_domain(klevel) + + assert origin[2] == expected_origin_k + assert domain[2] == 1 + + +def copy_stencil(q_in: FloatField, q_out: FloatField): # type: ignore + with computation(PARALLEL), interval(...): + q_out[0, 0, 0] = q_in + + +@pytest.mark.parametrize( + "extent,dimensions,domain,call_count", + [ + ((20, 20, 30), ["x", "y", "z"], (20, 20, 20), 0), + ((20, 20), ["x", "y"], (20, 20, 30), 0), + ((20, 20), ["x_interface", "y"], (20, 20, 30), 0), + ((20, 20), ["x", "y_interface"], (20, 20, 30), 0), + ((20,), ["z"], (20, 20, 10), 0), + ((20,), ["z_interface"], (20, 20, 10), 0), + ((15, 20, 30), ["x", "y", "z"], (20, 20, 30), 1), + ((20, 15, 30), ["x", "y", "z"], (20, 20, 30), 1), + ((20, 20, 15), ["x", "y", "z"], (20, 20, 30), 1), + ], +) +def test_domain_size_comparison( + extent: tuple[int], + dimensions: list[str], + domain: tuple[int], + call_count: int, +): + quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent) + stencil = FrozenStencil( + copy_stencil, + origin=(0, 0, 0), + domain=domain, + stencil_config=MagicMock(spec=StencilConfig()), + ) + # with expectation: + warning_mock = MagicMock() + with patch("ndsl.ndsl_log.warning", warning_mock): + stencil._validate_quantity_sizes(quantity) + + assert warning_mock.call_count == call_count + + +def two_dim_temporaries_stencil(q_out: FloatField) -> None: + with computation(FORWARD), interval(0, 1): + tmp_2d_fij: FloatFieldIJ = 1.0 + tmp_2d_fij32: FloatFieldIJ32 = 2.0 + tmp_3d_fij64: FloatFieldIJ64 = 3.0 + tmp_3d_iij: IntFieldIJ = 4 + tmp_3d_iij32: IntFieldIJ32 = 5 + tmp_3d_iij64: IntFieldIJ64 = 6 + mask: BoolFieldIJ = q_out >= 0 + + with computation(PARALLEL), interval(...): + if mask: + q_out = ( + tmp_2d_fij + + tmp_2d_fij32 + + tmp_3d_fij64 + + tmp_3d_iij + + tmp_3d_iij32 + + tmp_3d_iij64 + ) + + +def test_stencil_2D_temporaries() -> None: + domain = (2, 2, 5) + quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + stencil = FrozenStencil( + two_dim_temporaries_stencil, + origin=(0, 0, 0), + domain=domain, + stencil_config=MagicMock(spec=StencilConfig()), + ) + stencil(quantity) + assert (quantity.data[1, 1, :] == 21.0).all() diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 7e6b4da3..89498695 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -14,7 +14,7 @@ def test_same_config_equal( validate_args: bool, format_source: bool, compare_to_numpy: bool, -): +) -> None: dace_config = DaceConfig( communicator=None, backend=backend, @@ -30,7 +30,6 @@ def test_same_config_equal( compare_to_numpy=compare_to_numpy, dace_config=dace_config, ) - assert config == config same_config = StencilConfig( compilation_config=CompilationConfig( @@ -46,19 +45,14 @@ def test_same_config_equal( assert config == same_config -@pytest.mark.parametrize("validate_args", [True]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("format_source", [True]) -@pytest.mark.parametrize("compare_to_numpy", [True]) def test_different_backend_not_equal( - backend: str, - rebuild: bool, - validate_args: bool, - format_source: bool, - device_sync: bool, - compare_to_numpy: bool, -): + backend: str = "numpy", + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( communicator=None, backend=backend, @@ -77,7 +71,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - backend="fakebackend", + backend="fake_backend", rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -89,19 +83,14 @@ def test_different_backend_not_equal( assert config != different_config -@pytest.mark.parametrize("validate_args", [True]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("format_source", [True]) -@pytest.mark.parametrize("compare_to_numpy", [True]) def test_different_rebuild_not_equal( - backend: str, - rebuild: bool, - validate_args: bool, - format_source: bool, - device_sync: bool, - compare_to_numpy: bool, -): + backend: str = "numpy", + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( communicator=None, backend=backend, @@ -132,18 +121,13 @@ def test_different_rebuild_not_equal( assert config != different_config -@pytest.mark.parametrize("validate_args", [True]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("format_source", [True]) -@pytest.mark.parametrize("compare_to_numpy", [True]) def test_different_device_sync_not_equal( - rebuild: bool, - validate_args: bool, - format_source: bool, - device_sync: bool, - compare_to_numpy: bool, -): + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( communicator=None, backend="gt:gpu", @@ -174,19 +158,14 @@ def test_different_device_sync_not_equal( assert config != different_config -@pytest.mark.parametrize("validate_args", [True]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("format_source", [True]) -@pytest.mark.parametrize("compare_to_numpy", [True]) def test_different_validate_args_not_equal( - backend: str, - rebuild: bool, - validate_args: bool, - format_source: bool, - device_sync: bool, - compare_to_numpy: bool, -): + backend: str = "numpy", + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( None, backend, @@ -217,19 +196,14 @@ def test_different_validate_args_not_equal( assert config != different_config -@pytest.mark.parametrize("validate_args", [True]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("format_source", [True]) -@pytest.mark.parametrize("compare_to_numpy", [True]) def test_different_format_source_not_equal( - backend: str, - rebuild: bool, - validate_args: bool, - format_source: bool, - device_sync: bool, - compare_to_numpy: bool, -): + backend: str = "numpy", + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -265,7 +239,7 @@ def test_different_compare_to_numpy_not_equal( format_source: bool = True, rebuild: bool = True, validate_args: bool = False, -): +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 65bf1cf2..ce9de962 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -13,7 +13,10 @@ from ndsl.dsl.gt4py import PARALLEL, computation, horizontal, interval, region from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import CompareToNumpyStencil, get_stencils_with_varied_bounds -from ndsl.dsl.typing import FloatField +from ndsl.dsl.typing import Field, FloatField + + +BACKENDS = ["numpy", "dace:cpu"] def copy_stencil(q_in: FloatField, q_out: FloatField): @@ -36,7 +39,7 @@ def add_1_in_region_stencil(q_in: FloatField, q_out: FloatField): q_out = q_in + 1.0 -def setup_data_vars(backend: str): +def setup_data_vars(backend: str) -> tuple[Field, Field]: shape = (7, 7, 3) q = make_storage_from_shape(shape, backend=backend) q[:] = 1.0 @@ -68,7 +71,8 @@ def get_stencil_factory(backend: str) -> StencilFactory: return StencilFactory(config=config, grid_indexing=indexing) -def test_get_stencils_with_varied_bounds(backend: str): +@pytest.mark.parametrize("backend", BACKENDS) +def test_get_stencils_with_varied_bounds(backend: str) -> None: origins = [(2, 2, 0), (1, 1, 0)] domains = [(1, 1, 3), (2, 2, 3)] factory = get_stencil_factory(backend) @@ -87,7 +91,8 @@ def test_get_stencils_with_varied_bounds(backend: str): np.testing.assert_array_equal(q.data, q_ref.data) -def test_get_stencils_with_varied_bounds_and_regions(backend: str): +@pytest.mark.parametrize("backend", BACKENDS) +def test_get_stencils_with_varied_bounds_and_regions(backend: str) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 0)] domains = [(1, 1, 3), (2, 2, 3)] @@ -107,7 +112,8 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str): np.testing.assert_array_equal(q_orig.data, q_ref.data) -def test_stencil_vertical_bounds(backend: str): +@pytest.mark.parametrize("backend", BACKENDS) +def test_stencil_vertical_bounds(backend: str) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 1)] domains = [(1, 1, 3), (2, 2, 4)] @@ -124,9 +130,11 @@ def test_stencil_vertical_bounds(backend: str): assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 4 +@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) -def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool): - backend = "numpy" +def test_stencil_factory_numpy_comparison_from_dims_halo( + backend: str, enabled: bool +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -159,9 +167,11 @@ def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool): assert isinstance(stencil, FrozenStencil) +@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) -def test_stencil_factory_numpy_comparison_from_origin_domain(enabled: bool): - backend = "numpy" +def test_stencil_factory_numpy_comparison_from_origin_domain( + backend: str, enabled: bool +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -192,8 +202,8 @@ def test_stencil_factory_numpy_comparison_from_origin_domain(enabled: bool): assert isinstance(stencil, FrozenStencil) -def test_stencil_factory_numpy_comparison_runs_without_exceptions(): - backend = "numpy" +@pytest.mark.parametrize("backend", BACKENDS) +def test_stencil_factory_numpy_comparison_runs_without_exceptions(backend: str) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( diff --git a/tests/dsl/test_stencil_tables.py b/tests/dsl/test_stencil_tables.py new file mode 100644 index 00000000..0b31826a --- /dev/null +++ b/tests/dsl/test_stencil_tables.py @@ -0,0 +1,79 @@ +import numpy as np +from gt4py.storage import ones, zeros + +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + GridIndexing, + StencilConfig, + StencilFactory, + orchestrate, +) +from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, GlobalTable, computation, interval +from ndsl.dsl.stencil import CompareToNumpyStencil +from tests.dsl import utils + + +def _stencil(inp: GlobalTable[np.int32, (5,)], out: Field[np.float64]) -> None: + with computation(PARALLEL), interval(0, -1): + out[0, 0, 0] = inp.A[1] + with computation(FORWARD), interval(-1, None): + out[0, 0, 0] = inp.A[1] + inp.A[2] + + +def _build_stencil( + backend: str, orchestrated: DaCeOrchestration +) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: + # Make stencil and verify it ran + grid_indexing = GridIndexing( + domain=(5, 5, 5), + n_halo=2, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, + ) + + stencil_config = StencilConfig( + compilation_config=CompilationConfig(backend=backend, rebuild=True), + dace_config=DaceConfig(None, backend, 5, 5, orchestrated), + ) + + stencil_factory = StencilFactory(stencil_config, grid_indexing) + + built_stencil = stencil_factory.from_origin_domain( + _stencil, origin=(0, 0, 0), domain=grid_indexing.domain + ) + + return built_stencil, grid_indexing, stencil_config + + +class OrchestratedProgram: + def __init__(self, backend, orchestration: DaCeOrchestration): + self.stencil, grid_indexing, stencil_config = _build_stencil( + backend, orchestration + ) + orchestrate(obj=self, config=stencil_config.dace_config) + + self.inp = ones(shape=(5,), dtype=np.int32, backend=backend) + self.inp[1] = 42 + self.out = utils.make_storage(zeros, grid_indexing, stencil_config, dtype=float) + + def __call__(self): + self.stencil(self.inp, self.out) + + +def test_stecil_with_table_orchestrated() -> None: + program = OrchestratedProgram( + backend="dace:cpu", orchestration=DaCeOrchestration.BuildAndRun + ) + + # run the orchestrated stencil + program() + + # validate output + for k in range(4): + assert (program.out[:, :, k] == 42).all() + assert (program.out[:, :, 4] == 43).all() diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 94ea894c..0458389e 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -1,7 +1,6 @@ import contextlib import unittest.mock -import gt4py.cartesian.gtscript import numpy as np import pytest @@ -16,7 +15,23 @@ from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage -from ndsl.dsl.typing import Float, FloatField +from ndsl.dsl.typing import ( + BoolFieldIJ, + Float, + FloatField, + FloatFieldIJ, + FloatFieldIJ32, + FloatFieldIJ64, + Int, + IntFieldIJ, + IntFieldIJ32, + IntFieldIJ64, +) + + +# GT4Py direct import need to be down after any `ndsl` +import gt4py.cartesian.gtscript # isort: skip +from gt4py.cartesian import definitions # isort: skip def get_stencil_config( @@ -24,7 +39,7 @@ def get_stencil_config( backend: str, orchestration: DaCeOrchestration = DaCeOrchestration.Python, **kwargs, -): +) -> StencilConfig: dace_config = DaceConfig(None, backend=backend, orchestration=orchestration) config = StencilConfig( compilation_config=CompilationConfig( @@ -46,46 +61,57 @@ def mock_gtscript_stencil(mock): gt4py.cartesian.gtscript.stencil = original_stencil -class MockFieldInfo: - def __init__(self, axes): - self.axes = axes +class MockFieldInfo(definitions.FieldInfo): + def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] = ()): + # defaults + access = definitions.AccessKind.READ + boundary = None + dtype = np.float64 + + super().__init__( + axes=axes, + data_dims=data_dims, + access=access, + boundary=boundary, + dtype=dtype, + ) @pytest.mark.parametrize( "field_info, origin, field_origins", [ pytest.param( - {"a": MockFieldInfo(["I"])}, + {"a": MockFieldInfo(axes=("I"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1,)}, id="single_field_I", ), pytest.param( - {"a": MockFieldInfo(["J"])}, + {"a": MockFieldInfo(axes=("J"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (2,)}, id="single_field_J", ), pytest.param( - {"a": MockFieldInfo(["K"])}, + {"a": MockFieldInfo(axes=("K"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (3,)}, id="single_field_K", ), pytest.param( - {"a": MockFieldInfo(["I", "J"])}, + {"a": MockFieldInfo(axes=("I", "J"))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1, 2)}, id="single_field_IJ", ), pytest.param( - {"a": MockFieldInfo(["I", "J", "K"])}, + {"a": MockFieldInfo(axes=("I", "J", "K"))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, id="single_field_origin_mapping", ), pytest.param( - {"a": MockFieldInfo(["I", "J", "K"]), "b": MockFieldInfo(["I"])}, + {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I"))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3), "b": (1,)}, id="two_fields_update_origin_mapping", @@ -97,14 +123,26 @@ def __init__(self, axes): id="single_field_None", ), pytest.param( - {"a": MockFieldInfo(["I", "J"]), "b": MockFieldInfo(["I", "J", "K"])}, + { + "a": MockFieldInfo(axes=("I", "J")), + "b": MockFieldInfo(axes=("I", "J", "K")), + }, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1, 2), "b": (1, 2, 3)}, id="two_fields", ), + pytest.param( + { + "field": MockFieldInfo(axes=("I", "J", "K")), + "table": MockFieldInfo(data_dims=(5,)), + }, + (1, 2, 3), + {"_all_": (1, 2, 3), "field": (1, 2, 3), "table": (0,)}, + id="field_and_table", + ), ], ) -def test_compute_field_origins(field_info, origin, field_origins): +def test_compute_field_origins(field_info, origin, field_origins) -> None: result = FrozenStencil._compute_field_origins(field_info, origin) assert result == field_origins @@ -115,16 +153,13 @@ def copy_stencil(q_in: FloatField, q_out: FloatField): @pytest.mark.parametrize("validate_args", [True, False]) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [False]) -@pytest.mark.parametrize("format_source", [False]) def test_copy_frozen_stencil( - backend: str, - rebuild: bool, validate_args: bool, - format_source: bool, - device_sync: bool, -): + backend: str = "numpy", + rebuild: bool = False, + format_source: bool = False, + device_sync: bool = False, +) -> None: config = get_stencil_config( backend=backend, rebuild=rebuild, @@ -147,15 +182,12 @@ def test_copy_frozen_stencil( np.testing.assert_array_equal(q_in, q_out) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [False]) -@pytest.mark.parametrize("format_source", [False]) def test_frozen_stencil_raises_if_given_origin( - backend: str, - rebuild: bool, - format_source: bool, - device_sync: bool, -): + backend: str = "numpy", + rebuild: bool = False, + format_source: bool = False, + device_sync: bool = False, +) -> None: # only guaranteed when validating args config = get_stencil_config( backend=backend, @@ -177,14 +209,11 @@ def test_frozen_stencil_raises_if_given_origin( stencil(q_in, q_out, origin=(0, 0, 0)) -@pytest.mark.parametrize("device_sync", [False]) -@pytest.mark.parametrize("rebuild", [False]) -@pytest.mark.parametrize("format_source", [False]) def test_frozen_stencil_raises_if_given_domain( - backend: str, - rebuild: bool, - format_source: bool, - device_sync: bool, + backend: str = "numpy", + rebuild: bool = False, + format_source: bool = False, + device_sync: bool = False, ): # only guaranteed when validating args config = get_stencil_config( @@ -212,11 +241,11 @@ def test_frozen_stencil_raises_if_given_domain( [[False, False, False, False], [True, False, False, False]], ) def test_frozen_stencil_kwargs_passed_to_init( - backend: str, rebuild: bool, validate_args: bool, format_source: bool, device_sync: bool, + backend: str = "numpy", ): config = get_stencil_config( backend=backend, @@ -246,7 +275,19 @@ def test_frozen_stencil_kwargs_passed_to_init( externals={}, **config.stencil_kwargs(func=copy_stencil), build_info={}, - dtypes={float: Float}, + dtypes={ + # Mixed precision + float: Float, + int: Int, + # 2D temporaries + "FloatFieldIJ": FloatFieldIJ, + "FloatFieldIJ32": FloatFieldIJ32, + "FloatFieldIJ64": FloatFieldIJ64, + "IntFieldIJ": IntFieldIJ, + "IntFieldIJ32": IntFieldIJ32, + "IntFieldIJ64": IntFieldIJ64, + "BoolFieldIJ": BoolFieldIJ, + }, ) @@ -255,9 +296,9 @@ def field_after_parameter_stencil(q_in: FloatField, param: float, q_out: FloatFi q_out = param * q_in -def test_frozen_field_after_parameter(backend): +def test_frozen_field_after_parameter() -> None: config = get_stencil_config( - backend=backend, + backend="numpy", rebuild=False, validate_args=False, format_source=False, @@ -272,42 +313,31 @@ def test_frozen_field_after_parameter(backend): ) -@pytest.mark.parametrize("backend", ("numpy", "cuda")) -@pytest.mark.parametrize("rebuild", [True]) -@pytest.mark.parametrize("validate_args", [True]) def test_backend_options( - backend: str, - rebuild: bool, - validate_args: bool, -): + rebuild: bool = True, + validate_args: bool = True, +) -> None: + backend = "numpy" expected_options = { - "numpy": { - "backend": "numpy", - "rebuild": True, - "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", - }, - "cuda": { - "backend": "cuda", - "rebuild": True, - "device_sync": False, - "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", - }, + "backend": "numpy", + "rebuild": True, + "format_source": False, + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", } actual = get_stencil_config( - backend=backend, rebuild=rebuild, validate_args=validate_args + backend=backend, + rebuild=rebuild, + validate_args=validate_args, ).stencil_kwargs(func=copy_stencil) - expected = expected_options[backend] - assert actual == expected + assert actual == expected_options def get_mock_quantity(): return unittest.mock.MagicMock(spec=Quantity) -def test_convert_quantities_to_storage_no_args(): +def test_convert_quantities_to_storage_no_args() -> None: args = [] kwargs = {} _convert_quantities_to_storage(args, kwargs) @@ -315,7 +345,7 @@ def test_convert_quantities_to_storage_no_args(): assert len(kwargs) == 0 -def test_convert_quantities_to_storage_one_arg_quantity(): +def test_convert_quantities_to_storage_one_arg_quantity() -> None: quantity = get_mock_quantity() args = [quantity] kwargs = {} @@ -325,7 +355,7 @@ def test_convert_quantities_to_storage_one_arg_quantity(): assert len(kwargs) == 0 -def test_convert_quantities_to_storage_one_kwarg_quantity(): +def test_convert_quantities_to_storage_one_kwarg_quantity() -> None: quantity = get_mock_quantity() args = [] kwargs = {"val": quantity} @@ -335,7 +365,7 @@ def test_convert_quantities_to_storage_one_kwarg_quantity(): assert kwargs["val"] == quantity.data -def test_convert_quantities_to_storage_one_arg_nonquantity(): +def test_convert_quantities_to_storage_one_arg_nonquantity() -> None: non_quantity = unittest.mock.MagicMock(spec=tuple) args = [non_quantity] kwargs = {} @@ -345,7 +375,7 @@ def test_convert_quantities_to_storage_one_arg_nonquantity(): assert len(kwargs) == 0 -def test_convert_quantities_to_storage_one_kwarg_non_quantity(): +def test_convert_quantities_to_storage_one_kwarg_non_quantity() -> None: non_quantity = unittest.mock.MagicMock(spec=tuple) args = [] kwargs = {"val": non_quantity} diff --git a/tests/dsl/utils.py b/tests/dsl/utils.py new file mode 100644 index 00000000..5e8535e5 --- /dev/null +++ b/tests/dsl/utils.py @@ -0,0 +1,17 @@ +from ndsl import GridIndexing, StencilConfig + + +def make_storage( + func, + grid_indexing: GridIndexing, + stencil_config: StencilConfig, + *, + dtype=float, + aligned_index=(0, 0, 0), +): + return func( + backend=stencil_config.compilation_config.backend, + shape=grid_indexing.domain, + dtype=dtype, + aligned_index=aligned_index, + ) diff --git a/tests/grid/__init__.py b/tests/grid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/grid/generate_eta_files.py b/tests/grid/generate_eta_files.py deleted file mode 100755 index 1fb4d5ee..00000000 --- a/tests/grid/generate_eta_files.py +++ /dev/null @@ -1,399 +0,0 @@ -import numpy as np -import xarray as xr - - -""" -This notebook uses the python xarray module -to create an eta_file containing ak and bk coefficients -for km=79 and km=91. The coefficients are written out to -eta79.nc and eta91.nc netcdf files respectively - -To run this script: `python3 ./generate_eta_files.py` -""" - -# km = 79 -ak = xr.DataArray( - dims=["km1"], - attrs=dict(units="Pa", _FillValue=False), - data=np.array( - [ - 3.000000e02, - 6.467159e02, - 1.045222e03, - 1.469188e03, - 1.897829e03, - 2.325385e03, - 2.754396e03, - 3.191294e03, - 3.648332e03, - 4.135675e03, - 4.668282e03, - 5.247940e03, - 5.876271e03, - 6.554716e03, - 7.284521e03, - 8.066738e03, - 8.902188e03, - 9.791482e03, - 1.073499e04, - 1.162625e04, - 1.237212e04, - 1.299041e04, - 1.349629e04, - 1.390277e04, - 1.422098e04, - 1.446058e04, - 1.462993e04, - 1.473633e04, - 1.478617e04, - 1.478511e04, - 1.473812e04, - 1.464966e04, - 1.452370e04, - 1.436382e04, - 1.417324e04, - 1.395491e04, - 1.371148e04, - 1.344540e04, - 1.315890e04, - 1.285407e04, - 1.253280e04, - 1.219685e04, - 1.184788e04, - 1.148739e04, - 1.111682e04, - 1.073748e04, - 1.035062e04, - 9.957395e03, - 9.558875e03, - 9.156069e03, - 8.749922e03, - 8.341315e03, - 7.931065e03, - 7.519942e03, - 7.108648e03, - 6.698281e03, - 6.290007e03, - 5.884984e03, - 5.484372e03, - 5.089319e03, - 4.700960e03, - 4.320421e03, - 3.948807e03, - 3.587201e03, - 3.236666e03, - 2.898237e03, - 2.572912e03, - 2.261667e03, - 1.965424e03, - 1.685079e03, - 1.421479e03, - 1.175419e03, - 9.476516e02, - 7.388688e02, - 5.497130e02, - 3.807626e02, - 2.325417e02, - 1.054810e02, - -8.381903e-04, - 0.000000e00, - ] - ), -) -bk = xr.DataArray( - dims=["km1"], - attrs=dict(units="None", _FillValue=False), - data=np.array( - [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.00106595, - 0.00412866, - 0.00900663, - 0.01554263, - 0.02359921, - 0.03305481, - 0.0438012, - 0.05574095, - 0.06878554, - 0.08285347, - 0.09786981, - 0.1137643, - 0.130471, - 0.1479275, - 0.1660746, - 0.1848558, - 0.2042166, - 0.2241053, - 0.2444716, - 0.2652672, - 0.286445, - 0.3079604, - 0.3297701, - 0.351832, - 0.3741062, - 0.3965532, - 0.4191364, - 0.4418194, - 0.4645682, - 0.48735, - 0.5101338, - 0.5328897, - 0.5555894, - 0.5782067, - 0.6007158, - 0.6230936, - 0.6452944, - 0.6672683, - 0.6889648, - 0.7103333, - 0.7313231, - 0.7518838, - 0.7719651, - 0.7915173, - 0.8104913, - 0.828839, - 0.846513, - 0.8634676, - 0.8796583, - 0.8950421, - 0.9095779, - 0.9232264, - 0.9359506, - 0.9477157, - 0.9584892, - 0.9682413, - 0.9769447, - 0.9845753, - 0.9911126, - 0.9965372, - 1.0, - ] - ), -) -coefficients = xr.Dataset(data_vars={"ak": ak, "bk": bk}) -coefficients.to_netcdf("eta79.nc") - - -# km = 91 -ak = xr.DataArray( - dims=["km1"], - attrs=dict(units="Pa", _FillValue=False), - data=np.array( - [ - 1.00000000e00, - 1.75000000e00, - 2.75000000e00, - 4.09999990e00, - 5.98951054e00, - 8.62932968e00, - 1.22572632e01, - 1.71510906e01, - 2.36545467e01, - 3.21627693e01, - 4.31310921e01, - 5.71100426e01, - 7.46595764e01, - 9.64470978e01, - 1.23169769e02, - 1.55601318e02, - 1.94594009e02, - 2.41047531e02, - 2.95873840e02, - 3.60046967e02, - 4.34604828e02, - 5.20628723e02, - 6.19154846e02, - 7.31296021e02, - 8.58240906e02, - 1.00106561e03, - 1.16092859e03, - 1.33903992e03, - 1.53650012e03, - 1.75448938e03, - 1.99417834e03, - 2.25667407e03, - 2.54317139e03, - 2.85476392e03, - 3.19258569e03, - 3.55775366e03, - 3.95135107e03, - 4.37428662e03, - 4.82711084e03, - 5.31022168e03, - 5.82387793e03, - 6.36904248e03, - 6.94875244e03, - 7.56691992e03, - 8.22634277e03, - 8.93120996e03, - 9.68446191e03, - 1.04822725e04, - 1.13182793e04, - 1.21840771e04, - 1.30655674e04, - 1.39532207e04, - 1.48307285e04, - 1.56872617e04, - 1.65080645e04, - 1.72810996e04, - 1.79942988e04, - 1.86363223e04, - 1.91961797e04, - 1.96640723e04, - 2.00301914e04, - 2.02853691e04, - 2.04215254e04, - 2.04300684e04, - 2.03028730e04, - 2.00323711e04, - 1.96110664e04, - 1.90313848e04, - 1.82866426e04, - 1.73777930e04, - 1.63224639e04, - 1.51444033e04, - 1.38725674e04, - 1.25404785e04, - 1.11834170e04, - 9.83532715e03, - 8.52630664e03, - 7.28224512e03, - 6.12326074e03, - 5.06350684e03, - 4.11124902e03, - 3.27000122e03, - 2.53922729e03, - 1.91530762e03, - 1.39244995e03, - 9.63134766e02, - 6.20599365e02, - 3.57989502e02, - 1.69421387e02, - 5.10314941e01, - 2.48413086e00, - 0.00000000e00, - ] - ), -) -bk = xr.DataArray( - dims=["km1"], - attrs=dict(units="None", _FillValue=False), - data=np.array( - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 3.50123992e-06, - 2.81484008e-05, - 9.38666999e-05, - 2.28561999e-04, - 5.12343016e-04, - 1.04712998e-03, - 1.95625005e-03, - 3.42317997e-03, - 5.58632007e-03, - 8.65428988e-03, - 1.27844000e-02, - 1.81719996e-02, - 2.49934997e-02, - 3.34198996e-02, - 4.36249003e-02, - 5.57769015e-02, - 7.00351968e-02, - 8.65636021e-02, - 1.05520003e-01, - 1.27051994e-01, - 1.51319996e-01, - 1.78477004e-01, - 2.08675995e-01, - 2.42069006e-01, - 2.78813988e-01, - 3.19043010e-01, - 3.62558991e-01, - 4.08596009e-01, - 4.56384987e-01, - 5.05111992e-01, - 5.53902984e-01, - 6.01903021e-01, - 6.48333013e-01, - 6.92534983e-01, - 7.33981013e-01, - 7.72292018e-01, - 8.07236016e-01, - 8.38724971e-01, - 8.66774976e-01, - 8.91497016e-01, - 9.13065016e-01, - 9.31702971e-01, - 9.47658002e-01, - 9.61175978e-01, - 9.72495019e-01, - 9.81844008e-01, - 9.89410996e-01, - 9.95342016e-01, - 1.00000000e00, - ] - ), -) -coefficients = xr.Dataset(data_vars={"ak": ak, "bk": bk}) -coefficients.to_netcdf("eta91.nc") - -# km = diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index 1acd9c6d..4090b3b5 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 - -import os +from pathlib import Path import numpy as np import pytest @@ -18,61 +16,30 @@ """ -This test checks to ensure that ak and bk -values are read-in and stored properly. -In addition, this test checks to ensure that -the function set_hybrid_pressure_coefficients -fails as expected if the computed eta values -vary non-monotonically and if the eta_file +This test checks to ensure that ak and bk values are read-in and stored properly. In +addition, this test checks to ensure that the function set_hybrid_pressure_coefficients +fails as expected if the computed eta values vary non-monotonically and if the eta_file is not provided. """ -def set_answers(eta_file): - - """ - Read in the expected values of ak and bk - arrays from the input eta NetCDF files. - """ - - data = xr.open_dataset(eta_file) - return data["ak"].values, data["bk"].values - - -def write_non_mono_eta_file(in_eta_file, out_eta_file): +@pytest.mark.parametrize("levels", [79, 91]) +def test_set_hybrid_pressure_coefficients_correct(levels): """ - Reads in file eta79.nc and alters randomly chosen ak/bk values - This tests the expected failure of set_eta_hybrid_coefficients - for coefficients that lead to non-monotonically increasing - eta values + This test checks to see that the ak and bk arrays are read-in correctly and are + stored as expected. Both values of km=79 and km=91 are tested and both tests are + expected to pass with the stored ak and bk values agreeing with the values read-in + directly from the NetCDF file. """ - data = xr.open_dataset(in_eta_file) - data["ak"].values[10] = data["ak"].values[0] - data["bk"].values[20] = 0.0 - - data.to_netcdf(out_eta_file) - - -@pytest.mark.parametrize("km", [79, 91]) -def test_set_hybrid_pressure_coefficients_correct(km): - - """This test checks to see that the ak and bk arrays - are read-in correctly and are stored as - expected. Both values of km=79 and km=91 are - tested and both tests are expected to pass - with the stored ak and bk values agreeing with the - values read-in directly from the NetCDF file. - """ - - working_dir = str(os.getcwd()) - eta_file = f"{working_dir}/eta{km}.nc" + eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" + eta_data = xr.open_dataset(eta_file) backend = "numpy" layout = (1, 1) - nz = km + nz = levels ny = 48 nx = 48 nhalo = 3 @@ -86,7 +53,6 @@ def test_set_hybrid_pressure_coefficients_correct(km): ny_tile=ny, nz=nz, n_halo=nhalo, - extra_dim_lengths={}, layout=layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, @@ -100,28 +66,22 @@ def test_set_hybrid_pressure_coefficients_correct(km): ak_results = metric_terms.ak.data bk_results = metric_terms.bk.data - ak_answers, bk_answers = set_answers(f"eta{km}.nc") + ak_answers, bk_answers = eta_data["ak"].values, eta_data["bk"].values - if ak_answers.size != ak_results.size: - raise ValueError("Unexpected size of bk") - if bk_answers.size != bk_results.size: - raise ValueError("Unexpected size of ak") + assert ak_answers.size == ak_results.size, "Unexpected size of bk" + assert bk_answers.size == bk_results.size, "Unexpected size of ak" - if not np.array_equal(ak_answers, ak_results): - raise ValueError("Unexpected value of ak") - if not np.array_equal(bk_answers, bk_results): - raise ValueError("Unexpected value of bk") + assert np.array_equal(ak_answers, ak_results), "Unexpected value of ak" + assert np.array_equal(bk_answers, bk_results), "Unexpected value of bk" def test_set_hybrid_pressure_coefficients_nofile(): - """ - This test checks to see that the program - fails when the eta_file is not specified + This test checks to see that the program fails when the eta_file is not specified in the yaml configuration file. """ - eta_file = "NULL" + eta_file = Path("NULL") backend = "numpy" @@ -141,7 +101,6 @@ def test_set_hybrid_pressure_coefficients_nofile(): ny_tile=ny, nz=nz, n_halo=nhalo, - extra_dim_lengths={}, layout=layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, @@ -149,36 +108,23 @@ def test_set_hybrid_pressure_coefficients_nofile(): quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - try: - metric_terms = MetricTerms( + with pytest.raises(ValueError, match=f"eta file {eta_file} does not exist"): + MetricTerms( quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file, ) - except Exception as error: - if str(error) == "eta file NULL does not exist": - pytest.xfail("testing eta file not correctly specified") - else: - pytest.fail(f"ERROR {error}") def test_set_hybrid_pressure_coefficients_not_mono(): - """ - This test checks to see that the program - fails when the computed eta values increase - non-monotonically. For the latter test, the - eta_file is specified in test_config_not_mono.yaml - file and the ak and bk values in the eta_file - have been changed nonsensically to result in - erroneous eta values. + This test checks to see that the program fails when the computed eta values + increase non-monotonically. For the latter test, the eta_file is specified in + test_config_not_mono.yaml file and the ak and bk values in the eta_file have been + changed nonsensically to result in erroneous eta values. """ - working_dir = str(os.getcwd()) - in_eta_file = f"{working_dir}/eta79.nc" - out_eta_file = "eta_not_mono_79.nc" - write_non_mono_eta_file(in_eta_file, out_eta_file) - eta_file = out_eta_file + eta_file = str(Path.cwd()) + "/tests/data/eta/non_mono_eta79.nc" backend = "numpy" @@ -198,7 +144,6 @@ def test_set_hybrid_pressure_coefficients_not_mono(): ny_tile=ny, nz=nz, n_halo=nhalo, - extra_dim_lengths={}, layout=layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, @@ -206,18 +151,9 @@ def test_set_hybrid_pressure_coefficients_not_mono(): quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - try: - metric_terms = MetricTerms( + with pytest.raises(ValueError, match="ETA values are not monotonically increasing"): + MetricTerms( quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file, ) - except Exception as error: - if os.path.isfile(out_eta_file): - os.remove(out_eta_file) - if str(error) == "ETA values are not monotonically increasing": - pytest.xfail("testing eta values are not monotonically increasing") - else: - pytest.fail( - "ERROR in testing eta values not are not monotonically increasing" - ) diff --git a/tests/initialization/__init__.py b/tests/initialization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/initialization/test_allocator.py b/tests/initialization/test_allocator.py new file mode 100644 index 00000000..3ee01e43 --- /dev/null +++ b/tests/initialization/test_allocator.py @@ -0,0 +1,19 @@ +import warnings + +import numpy as np +import pytest + +from ndsl import QuantityFactory + + +def test_QuantityFactory_constructor_warns() -> None: + with pytest.warns( + DeprecationWarning, + match="Usage of QuantityFactory.* is discouraged and will change", + ): + QuantityFactory(None, np) + + # Make sure no warnings are emitted if users use `QuantityFactory.from_backend()` + with warnings.catch_warnings(): + warnings.simplefilter("error") + QuantityFactory.from_backend(None, "numpy") diff --git a/tests/mpi/mpi_comm.py b/tests/mpi/mpi_comm.py index 0052f777..ad92926d 100644 --- a/tests/mpi/mpi_comm.py +++ b/tests/mpi/mpi_comm.py @@ -1,4 +1,4 @@ -from mpi4py import MPI +from ndsl.comm.mpi import MPI if MPI.COMM_WORLD.Get_size() == 1: diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index bec096dd..6cab1023 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -15,14 +15,14 @@ @pytest.fixture def layout(): - if MPI is not None: - size = MPI.COMM_WORLD.Get_size() - ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile ** 0.5) - return (ranks_per_edge, ranks_per_edge) - else: + if MPI is None: return (1, 1) + size = MPI.COMM_WORLD.Get_size() + ranks_per_tile = size // 6 + ranks_per_edge = int(ranks_per_tile**0.5) + return (ranks_per_edge, ranks_per_edge) + @pytest.fixture(params=[0.1, 1.0]) def edge_interior_ratio(request): @@ -47,9 +47,7 @@ def communicator(cube_partitioner): ) -@pytest.mark.skipif( - MPI is None, reason="mpi4py is not available or pytest was not run in parallel" -) +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_all_reduce(communicator): backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index b6c38e95..76aca71e 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -37,14 +37,14 @@ def dtype(numpy): @pytest.fixture def layout(): - if MPI is not None: - size = MPI.COMM_WORLD.Get_size() - ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile ** 0.5) - return (ranks_per_edge, ranks_per_edge) - else: + if MPI is None: return (1, 1) + size = MPI.COMM_WORLD.Get_size() + ranks_per_tile = size // 6 + ranks_per_edge = int(ranks_per_tile**0.5) + return (ranks_per_edge, ranks_per_edge) + @pytest.fixture def nz(): @@ -281,9 +281,7 @@ def depth_quantity( return quantity -@pytest.mark.skipif( - MPI is None, reason="mpi4py is not available or pytest was not run in parallel" -) +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_depth_halo_update( depth_quantity, communicator, @@ -332,9 +330,7 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): return quantity -@pytest.mark.skipif( - MPI is None, reason="mpi4py is not available or pytest was not run in parallel" -) +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_zeros_halo_update( zeros_quantity, communicator, @@ -371,9 +367,7 @@ def test_zeros_halo_update( ) -@pytest.mark.skipif( - MPI is None, reason="mpi4py is not available or pytest was not run in parallel" -) +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_zeros_vector_halo_update( zeros_quantity, communicator, diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 6b441702..0da2f21b 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -293,9 +293,7 @@ def dummy_results(worker_function, dummy_list, numpy): return result_list -@pytest.mark.skipif( - MPI is None, reason="mpi4py is not available or pytest was not run in parallel" -) +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_worker(comm, dummy_results, mpi_results, numpy): comm.barrier() # synchronize the test "dots" across ranks if comm.Get_rank() == 0: @@ -304,7 +302,7 @@ def test_worker(comm, dummy_results, mpi_results, numpy): if isinstance(mpi, numpy.ndarray): numpy.testing.assert_array_equal(np.asarray(dummy), np.asarray(mpi)) elif isinstance(mpi, Exception): - assert type(dummy) == type(mpi) + assert type(dummy) is type(mpi) assert dummy.args == mpi.args else: assert dummy == mpi diff --git a/tests/quantity/__init__.py b/tests/quantity/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 2b4954d1..dccfa94f 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -160,7 +160,7 @@ def test_compute_view_edit_all_domain(quantity, n_halo, n_dims, extent_1d): pytest.skip("cannot edit an empty domain") quantity.data[:] = 0.0 quantity.view[:] = 1 - assert quantity.np.sum(quantity.data) == extent_1d ** n_dims + assert quantity.np.sum(quantity.data) == extent_1d**n_dims if n_dims > 1: quantity.np.testing.assert_array_equal(quantity.data[:n_halo, :], 0.0) quantity.np.testing.assert_array_equal( @@ -261,3 +261,31 @@ def test_to_data_array(quantity): assert ( quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data ), "data memory address is not equal" + + +def test_data_setter(): + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + + # After allocation - field and data are the same (origin is 0) + assert quantity.data.shape == quantity.field.shape + + # Allows swap: new array is bigger than Q.shape + new_array = np.ones((10,)) + new_array[:] = 2 + quantity.data = new_array + + # After swap - field and data points to the same memory + # BUT field still respects the original origin/extent + assert (quantity.data[:] == 2).all() + assert (quantity.field[:] == 2).all() + assert quantity.data.shape != quantity.field.shape + assert quantity.field.shape == (5,) + + # Expected fail: new array is too small + new_array = np.ones((2,)) + with pytest.raises(ValueError, match="Quantity.data buffer swap failed.*"): + quantity.data = new_array + + # Expected fail: new array is not even an array + with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): + quantity.data = "meh" diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py new file mode 100644 index 00000000..4c3cd1b5 --- /dev/null +++ b/tests/quantity/test_state.py @@ -0,0 +1,146 @@ +import dataclasses +from pathlib import Path + +import numpy as np + +from ndsl import Quantity, State +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float + + +@dataclasses.dataclass +class CodeState(State): + @dataclasses.dataclass + class InnerA: + A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + @dataclasses.dataclass + class InnerB: + B: Quantity = dataclasses.field( + metadata={ + "name": "B", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "1", + "intent": "?", + "dtype": Float, + } + ) + + inner_A: InnerA + inner_B: InnerB + C: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def test_basic_state(tmpdir): + _, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + + # Test allocator + microphys_state = CodeState.ones(quantity_factory) + assert (microphys_state.inner_A.A.field[:] == 1.0).all() + + # Test NetCDF round trip + microphys_state.inner_A.A.field[:] = 42.42 + microphys_state.to_netcdf(Path(tmpdir)) + microphys_state2 = CodeState.zeros(quantity_factory) + microphys_state2.update_from_netcdf(Path(tmpdir)) + assert (microphys_state2.inner_A.A.field[:] == 42.42).all() + + # Test memory move (no copy) + a = np.ones((5, 5, 3)) + b = np.ones((5, 5, 3)) + c = np.ones((5, 5, 3)) + b[:] = 23.23 + microphys_state2.update_move_memory( + {"inner_A": {"A": a}, "inner_B": {"B": b}, "C": c}, + check_shape_and_strides=False, + ) + assert (microphys_state2.inner_A.A.field[:] == 1.0).all() + assert (microphys_state2.inner_B.B.field[:] == 23.23).all() + + # Test fill + microphys_state2.fill(18.18) + assert (microphys_state2.inner_A.A.field[:] == 18.18).all() + assert (microphys_state2.inner_B.B.field[:] == 18.18).all() + + # Test full + microphys_state3 = CodeState.full(quantity_factory, 90.90) + assert (microphys_state3.inner_A.A.field[:] == 90.90).all() + assert (microphys_state3.inner_B.B.field[:] == 90.90).all() + assert (microphys_state3.C.field[:] == 90.90).all() + + +@dataclasses.dataclass +class CodeStateWithDDim(State): + @dataclasses.dataclass + class InnerA: + ddim_A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim1"], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + @dataclasses.dataclass + class InnerB: + ddim_B: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim2"], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + inner_A: InnerA + inner_B: InnerB + no_ddim: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def test_state_ddim(): + _, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + + # Test allocator + microphys_state = CodeStateWithDDim.ones( + quantity_factory, + data_dimensions={ + "ExtraDim1": 3, + "ExtraDim2": 4, + }, + ) + assert (microphys_state.no_ddim.field[:] == 1.0).all() + assert microphys_state.inner_A.ddim_A.field.shape == (5, 5, 3, 3) + assert (microphys_state.inner_A.ddim_A.field[:] == 1.0).all() + assert microphys_state.inner_B.ddim_B.field.shape == (5, 5, 3, 4) + assert (microphys_state.inner_B.ddim_B.field[:] == 1.0).all() diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py new file mode 100644 index 00000000..fb32a35d --- /dev/null +++ b/tests/stree_optimizer/test_optimization.py @@ -0,0 +1,69 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def stencil_A(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + +def stencil_B(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_B = stencil_factory.from_dims_halo( + func=stencil_B, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, in_field: FloatField, out_field: FloatField): + self.stencil_A(in_field, out_field) + self.stencil_B(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + """Dev Note: + + The below code sucessfully merges top level K loop (2 loops) + How do we test it?! Running doesn't test merging and the compilation + is a near-black box. We could reach in the `dace_config.compiled_sdfg` + cache but it's keyed on the dace.program and if we can reach the program + well we can reach the SDFG and turn it into an stree for verification + Should we run orchestration "by hand"? + Can we intercept the `stree` ? After all we just want to check that! + + Test is deactivated for now""" + + return True + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/stree_optimizer/test_pipeline.py new file mode 100644 index 00000000..6d4c6f74 --- /dev/null +++ b/tests/stree_optimizer/test_pipeline.py @@ -0,0 +1,48 @@ +from ndsl import StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def double_map(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + with computation(PARALLEL), interval(...): + out_field = out_field + in_field * 3 + + +class TriviallyMergeableCode: + def __init__(self, stencil_factory: StencilFactory): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo( + func=double_map, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, in_field: FloatField, out_field: FloatField): + self.stencil(in_field, out_field) + + +def test_stree_roundtrip_no_opt(): + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + code = TriviallyMergeableCode(stencil_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + # Temporarily flip the internal switch + import ndsl.dsl.dace.orchestration as orch + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_PASSES = [] + + code(in_qty, out_qty) + + assert (out_qty.field[:] == 4).all() + + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/test_4d_fields.py b/tests/test_4d_fields.py new file mode 100644 index 00000000..f8e8bac1 --- /dev/null +++ b/tests/test_4d_fields.py @@ -0,0 +1,84 @@ +from ndsl import StencilFactory +from ndsl.boilerplate import ( + get_factories_single_tile, + get_factories_single_tile_orchestrated, +) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval, max +from ndsl.dsl.typing import Float, FloatField, set_4d_field_size + + +TRACER_DIM = "n_tracers" +FloatFieldTracer = set_4d_field_size(9, Float) +ntracers = 9 +ntke = 8 +fill_value = 42.0 + + +def sample_4d_stencil(q_in: FloatFieldTracer, q_out: FloatField): + from __externals__ import ntke + + with computation(PARALLEL), interval(...): + q_out = max(q_in[0, 0, 0][ntke], 1.0e-9) + + +class SampleCalculation: + def __init__(self, stencil_factory: StencilFactory, *, ntke: int): + self._test_calc = stencil_factory.from_dims_halo( + func=sample_4d_stencil, + externals={ + "ntke": ntke, + }, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, q_in: FloatFieldTracer, q_out: FloatField): + self._test_calc(q_in, q_out) + + +def test_non_orchestrated_call() -> None: + stencil_factory, quantity_factory = get_factories_single_tile(24, 24, 91, 3) + quantity_factory.add_data_dimensions( + { + TRACER_DIM: ntracers, + } + ) + + q_out = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_DIM], + units="unknown", + ) + q_in = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_DIM, TRACER_DIM], + units="unknown", + ) + q_in.field[:, :, :, ntke] = fill_value + + calc = SampleCalculation(stencil_factory, ntke=ntke) + calc(q_in, q_out) + assert (q_out.field[:, :, :] == fill_value).all() + + +def test_orchestrated_call() -> None: + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 24, 24, 91, 3 + ) + quantity_factory.add_data_dimensions( + { + TRACER_DIM: ntracers, + } + ) + + q_out = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_DIM], + units="unknown", + ) + q_in = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_DIM, TRACER_DIM], + units="unknown", + ) + q_in.field[:, :, :, ntke] = fill_value + + calc = SampleCalculation(stencil_factory, ntke=ntke) + calc(q_in, q_out) + assert (q_out.field[:, :, :] == fill_value).all() diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index 574163b5..b531453d 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from ndsl import QuantityFactory, StencilFactory from ndsl.constants import X_DIM, Y_DIM, Z_DIM @@ -42,6 +43,10 @@ def test_boilerplate_import_numpy(): nx=5, ny=5, nz=2, nhalo=1 ) + # Ensure backend is propagated to StencilFactory and QuantityFactory + assert stencil_factory.backend == "numpy" + assert quantity_factory._backend() == "numpy" + _copy_ops(stencil_factory, quantity_factory) @@ -57,4 +62,17 @@ def test_boilerplate_import_orchestrated_cpu(): nx=5, ny=5, nz=2, nhalo=1 ) + # Ensure backend is propagated to StencilFactory and QuantityFactory + assert stencil_factory.backend == "dace:cpu" + assert quantity_factory._backend() == "dace:cpu" + _copy_ops(stencil_factory, quantity_factory) + + +def test_boilerplate_non_dace_based_orchestration_raises(): + from ndsl.boilerplate import get_factories_single_tile_orchestrated + + with pytest.raises(ValueError, match="Only .* backends can be orchestrated."): + get_factories_single_tile_orchestrated( + nx=5, ny=5, nz=2, nhalo=1, backend="numpy" + ) diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index 5674bfc9..d2d7f64e 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -75,7 +75,7 @@ def perform_serial_halo_updates( def test_Recv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6, fill_value=0.0)) + comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 @@ -85,7 +85,7 @@ def test_Recv_inserts_data(): def test_Irecv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6, fill_value=0.0)) + comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 @@ -97,7 +97,7 @@ def test_Irecv_inserts_data(): def test_bcast_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6, fill_value=0.0)) + comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.bcast_objects) == 0 diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 75be7acf..236e1eb9 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -228,7 +228,7 @@ def test_cube_scatter_no_recv_quantity( result_list.append(communicator.scatter(send_quantity=cube_quantity)) else: result_list.append(communicator.scatter()) - for rank, (result, scattered) in enumerate(zip(result_list, scattered_quantities)): + for _rank, (result, scattered) in enumerate(zip(result_list, scattered_quantities)): assert_quantity_equals(result, scattered) @@ -246,7 +246,7 @@ def test_cube_scatter_with_recv_quantity( else: result = communicator.scatter(recv_quantity=recv) assert result is recv - for rank, (result, scattered) in enumerate( + for _rank, (result, scattered) in enumerate( zip(recv_quantities, scattered_quantities) ): assert_quantity_equals(result, scattered) @@ -304,7 +304,7 @@ def test_cube_scatter_state_with_recv_state( result = communicator.scatter_state(recv_state=state) assert result["time"] == time assert result["air_temperature"] is recv - for rank, (result, scattered) in enumerate( + for _rank, (result, scattered) in enumerate( zip(recv_quantities, scattered_quantities) ): assert_quantity_equals(result, scattered) diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index bf7363e3..2160e23e 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -70,8 +70,8 @@ def test_build_cache_path( @pytest.mark.skipif( - MPI is None or MPI.COMM_WORLD.Get_size() != 6, - reason="mpi4py is not available or pytest was not run in parallel", + MPI is None, + reason="pytest is not run in parallel", ) def test_unblock_waiting_tiles(): comm = MPI.COMM_WORLD diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index a401e698..ebad25ef 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -70,12 +70,12 @@ def namelist(nx_tile, ny_tile, nz, layout): def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths): if request.param == "from_tile_params": sizer = SubtileGridSizer.from_tile_params( - nx_tile, - ny_tile, - nz, - N_HALO_DEFAULT, - extra_dimension_lengths, - layout, + nx_tile=nx_tile, + ny_tile=ny_tile, + nz=nz, + n_halo=N_HALO_DEFAULT, + layout=layout, + data_dimensions=extra_dimension_lengths, ) elif request.param == "from_namelist": sizer = SubtileGridSizer.from_namelist(namelist) @@ -191,7 +191,7 @@ def test_subtile_dimension_sizer_shape(sizer, dim_case): def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, numpy) + allocator = QuantityFactory.from_backend(sizer, "numpy") quantity = allocator.zeros(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -202,7 +202,7 @@ def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): def test_allocator_ones(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, numpy) + allocator = QuantityFactory.from_backend(sizer, "numpy") quantity = allocator.ones(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -212,11 +212,25 @@ def test_allocator_ones(numpy, sizer, dim_case, units, dtype): assert numpy.all(quantity.data == 1) -def test_allocator_empty(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, numpy) +def test_allocator_empty(sizer, dim_case, units, dtype): + allocator = QuantityFactory.from_backend(sizer, "numpy") quantity = allocator.empty(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims assert quantity.origin == dim_case.origin assert quantity.extent == dim_case.extent assert quantity.data.shape == dim_case.shape + + +def test_allocator_data_dimensions_operations(sizer): + quantity_factory = QuantityFactory.from_backend(sizer, "numpy") + quantity_factory.add_data_dimensions({"D0": 11}) + assert "D0" in quantity_factory.sizer.data_dimensions.keys() + assert quantity_factory.sizer.data_dimensions["D0"] == 11 + quantity_factory.update_data_dimensions({"D0": 22}) + assert quantity_factory.sizer.data_dimensions["D0"] == 22 + with pytest.raises( + ValueError, + match="Use `update_data_dimensions` if you meant to update the length.", + ): + quantity_factory.add_data_dimensions({"D0": 33}) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 00000000..f5d66eb7 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,8 @@ +import pytest + +from ndsl import OutOfBoundsError + + +def test_OutOfBoundsError_is_deprecation() -> None: + with pytest.deprecated_call(): + OutOfBoundsError("This should trigger a deprecation warning.") diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 00000000..5f05af82 --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,19 @@ +import pytest + +import ndsl.filesystem as fs + + +def test_is_file_is_deprecated() -> None: + with pytest.deprecated_call(): + fs.is_file("path/to/my_file.txt") + + +def test_open_is_deprecated() -> None: + with pytest.deprecated_call(): + with fs.open("README.md", "r"): + pass + + +def test_get_fs_is_deprecated() -> None: + with pytest.deprecated_call(): + fs.get_fs("path/to/my/file.txt") diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 40595669..dab27cb3 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -1,7 +1,8 @@ -""" Test of the GPU to GPU communication strategy. +"""Test of the GPU to GPU communication strategy. Those test use halo_update but are separated from the entire """ + import contextlib import functools @@ -92,7 +93,7 @@ def gpu_communicators(cube_partitioner): @contextlib.contextmanager def module_count_calls_to_zeros(module): - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused N_ZEROS_CALLS[module.zeros] = 0 def count_calls(func): @@ -100,7 +101,7 @@ def count_calls(func): @functools.wraps(func) def wrapped(*args, **kwargs): - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused N_ZEROS_CALLS[func] = N_ZEROS_CALLS[func] + 1 return func(*args, **kwargs) @@ -135,7 +136,7 @@ def test_halo_update_only_communicate_on_gpu(backend, gpu_communicators): halo_updater.wait() # We expect no np calls and several cp calls - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused print(f"Results {N_ZEROS_CALLS}") assert N_ZEROS_CALLS[cp.zeros] > 0 assert N_ZEROS_CALLS[np.zeros] == 0 @@ -165,6 +166,6 @@ def test_halo_update_communicate_though_cpu(backend, cpu_communicators): halo_updater.wait() # We expect several np calls and several cp calls - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused assert N_ZEROS_CALLS[np.zeros] > 0 assert N_ZEROS_CALLS[cp.zeros] == 0 diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index 9c1ac220..ca76ab8b 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -8,7 +8,6 @@ CubedSpherePartitioner, DummyComm, HaloUpdater, - OutOfBoundsError, Quantity, TileCommunicator, TilePartitioner, @@ -300,7 +299,7 @@ def depth_quantity_list( """A list of quantities whose value indicates the distance from the computational domain boundary.""" return_list = [] - for rank in range(total_ranks): + for _rank in range(total_ranks): data = numpy.empty(shape, dtype=dtype) data[:] = numpy.nan for n_inside in range(max(n_points, max(extent) // 2), -1, -1): @@ -337,7 +336,7 @@ def tile_depth_quantity_list( """A list of quantities whose value indicates the distance from the computational domain boundary for a single tile.""" return_list = [] - for rank in range(single_tile_ranks): + for _rank in range(single_tile_ranks): data = numpy.empty(shape, dtype=dtype) data[:] = numpy.nan for n_inside in range(max(n_points, max(extent) // 2), -1, -1): @@ -383,7 +382,7 @@ def test_halo_update_timer( ranks_per_tile, ): """ - test that halo update produces nonzero timings for all expected labels + Test that halo update produces nonzero timings for all expected labels. """ halo_updater_list = [] for communicator, quantity in zip(communicator_list, zeros_quantity_list): @@ -422,7 +421,7 @@ def test_depth_halo_update( boundary_dict, ranks_per_tile, ): - """test that written values have the correct orientation""" + """Test that written values have the correct orientation.""" sample_quantity = depth_quantity_list[0] y_dim, x_dim = get_horizontal_dims(sample_quantity.dims) y_index = sample_quantity.dims.index(y_dim) @@ -462,7 +461,7 @@ def test_depth_tile_halo_update( boundary_dict, ranks_per_tile, ): - """test that written values have the correct orientation on a tile""" + """Test that written values have the correct orientation on a tile.""" sample_quantity = tile_depth_quantity_list[0] y_dim, x_dim = get_horizontal_dims(sample_quantity.dims) y_index = sample_quantity.dims.index(y_dim) @@ -498,7 +497,7 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, """A list of quantities whose values are 0 in the computational domain and 1 outside of it.""" return_list = [] - for rank in range(total_ranks): + for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( data, @@ -519,7 +518,7 @@ def zeros_quantity_tile_list( """A list of quantities whose values are 0 in the computational domain and 1 outside of it on a single tile.""" return_list = [] - for rank in range(single_tile_ranks): + for _rank in range(single_tile_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( data, @@ -542,10 +541,12 @@ def test_too_many_points_requested( n_points_update, ): """ - test that an exception is raised when trying to update more halo points than exist + Test that an exception is raised when trying to update more halo points than exist. """ for communicator, quantity in zip(communicator_list, zeros_quantity_list): - with pytest.raises(OutOfBoundsError): + with pytest.raises( + IndexError, match="Boundary slice extends past end of domain.*" + ): communicator.start_halo_update(quantity, n_points_update) @@ -559,11 +560,11 @@ def test_too_many_points_requested_tile( n_points_update, ): """ - test that an exception is raised when trying to update more halo points than exist - on a tile + Test that an exception is raised when trying to update more halo points than exist + on a tile. """ for communicator, quantity in zip(tile_communicator_list, zeros_quantity_tile_list): - with pytest.raises(OutOfBoundsError): + with pytest.raises(IndexError): communicator.start_halo_update(quantity, n_points_update) @@ -621,7 +622,7 @@ def test_tile_halo_update_unsupported_layout( tile_communicator_list, n_points_update, ): - """test that correct exception is raised if layout is unsupported""" + """Test that correct exception is raised if layout is unsupported.""" # if you delete this test because this is now implemented, # please add the appropriate layout cases to the main halo update test for communicator, quantity in zip(tile_communicator_list, zeros_quantity_tile_list): @@ -642,8 +643,8 @@ def test_zeros_tile_halo_update( boundary_dict, ranks_per_tile, ): - """test that zeros from adjacent domains get written over ones on local halo - on a single tile""" + """Test that zeros from adjacent domains get written over ones on local halo + on a single tile.""" halo_updater_list = [] if 0 < n_points_update <= n_points: for communicator, quantity in zip( @@ -688,7 +689,7 @@ def test_zeros_vector_halo_update( boundary_dict, ranks_per_tile, ): - """test that zeros from adjacent domains get written over ones on local halo""" + """Test that zeros from adjacent domains get written over ones on local halo.""" x_list = zeros_quantity_list y_list = copy.deepcopy(x_list) if 0 < n_points_update <= n_points: @@ -738,8 +739,8 @@ def test_zeros_vector_tile_halo_update( boundary_dict, ranks_per_tile, ): - """test that zeros from adjacent domains get written over ones on local halo - on a single tile""" + """Test that zeros from adjacent domains get written over ones on local halo + on a single tile.""" x_list = zeros_quantity_tile_list y_list = copy.deepcopy(x_list) if 0 < n_points_update <= n_points: @@ -794,7 +795,7 @@ def test_vector_halo_update_timer( ranks_per_tile, ): """ - test that halo update produces nonzero timings for all expected labels + Test that halo update produces nonzero timings for all expected labels. """ x_list = zeros_quantity_list y_list = copy.deepcopy(x_list) diff --git a/tests/test_namelist.py b/tests/test_namelist.py new file mode 100644 index 00000000..a32919e7 --- /dev/null +++ b/tests/test_namelist.py @@ -0,0 +1,8 @@ +import pytest + +from ndsl import Namelist + + +def test_ndsl_namelist_deprecation() -> None: + with pytest.deprecated_call(): + my_namelist = Namelist() diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py new file mode 100644 index 00000000..ad334842 --- /dev/null +++ b/tests/test_ndsl_runtime.py @@ -0,0 +1,106 @@ +from typing import Any + +import pytest + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory +from ndsl.boilerplate import ( + get_factories_single_tile, + get_factories_single_tile_orchestrated, +) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField + + +def the_copy_stencil(from_: FloatField, to: FloatField) -> None: + with computation(PARALLEL), interval(...): + to = from_ + + +class Code(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(dace_config=stencil_factory.config.dace_config) + self.copy = stencil_factory.from_dims_halo( + the_copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] + ) + self.local = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def test_check(self) -> None: + assert self.local.__descriptor__().transient + + def __call__(self, A, B) -> None: # type: ignore[no-untyped-def] + self.copy(A, self.local) + self.copy(self.local, B) + + +class BadCode_NoSuperInit(NDSLRuntime): + def __init__(self) -> None: + # Forget to init + pass + + +class Code_NoCall(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(dace_config=stencil_factory.config.dace_config) + pass + + def run(self, A: Any, B: Any) -> None: + pass + + +def test_runtime_make_local() -> None: + stencil_factory, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="numpy" + ) + A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + + code = Code(stencil_factory, quantity_factory) + + # Check that local is not reachable outside of Code + with pytest.raises(RuntimeError, match="Forbidden Local access:"): + assert code.local.__descriptor__().transient + + # Check the local is properly transient - with access in Code + code.test_check() + + # Check regular quantity are not transient + assert not A_.__descriptor__().transient + assert not B_.__descriptor__().transient + + +def test_runtime_has_orchestracted_call() -> None: + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + code = Code(stencil_factory, quantity_factory) + code(A_, B_) + + # We monkey patch the class, a __name__ attribute is now available + # and the original Class name is postfixed with "_patched" + assert hasattr(code, "__name__") + assert code.__name__ == "Code_patched" + assert (A_.field[:] == B_.field[:]).all() + + +def test_runtime_does_not_orchestrate_when_call_is_not_present() -> None: + stencil_factory, _ = get_factories_single_tile_orchestrated( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + code = Code_NoCall(stencil_factory) + + # We didn't monkey patch the class, no __name__ on object + # and the original Class name is intact + assert not hasattr(code, "__name__") + assert type(code).__name__ == "Code_NoCall" + + +def test_runtime_fail_when_not_super_init() -> None: + with pytest.raises( + RuntimeError, match="inherit from NDSLRuntime but didn't call super()" + ): + bad_code = BadCode_NoSuperInit() diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 6bd15eda..cc0a105e 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -27,7 +27,7 @@ total_ranks = 6 * ranks_per_tile rank = 0 for tile in range(6): - for subtile in range(ranks_per_tile): + for _subtile in range(ranks_per_tile): rank_list.append(rank) total_rank_list.append(total_ranks) tile_index_list.append(tile) @@ -63,7 +63,7 @@ def test_get_tile_index(rank, total_ranks, tile_index): for layout in ((1, 1), (1, 2), (2, 2), (2, 3)): rank = 0 - for tile in range(6): + for _tile in range(6): for y_subtile in range(layout[0]): for x_subtile in range(layout[1]): rank_list.append(rank) diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 918d1ea4..60ef583a 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -219,7 +219,7 @@ def test_tile_scatter_no_recv_quantity( result_list.append(communicator.scatter(send_quantity=tile_quantity)) else: result_list.append(communicator.scatter()) - for rank, (result, scattered) in enumerate(zip(result_list, scattered_quantities)): + for _rank, (result, scattered) in enumerate(zip(result_list, scattered_quantities)): assert result.dims == scattered.dims assert result.units == scattered.units assert result.extent == scattered.extent @@ -240,7 +240,7 @@ def test_tile_scatter_with_recv_quantity( else: result = communicator.scatter(recv_quantity=recv) assert result is recv - for rank, (result, scattered) in enumerate( + for _rank, (result, scattered) in enumerate( zip(recv_quantities, scattered_quantities) ): assert result.dims == scattered.dims @@ -328,7 +328,7 @@ def test_tile_scatter_state_with_recv_state( result = communicator.scatter_state(recv_state=state) assert result["time"] == time assert result["air_temperature"] is recv - for rank, (result, scattered) in enumerate( + for _rank, (result, scattered) in enumerate( zip(recv_quantities, scattered_quantities) ): assert result.dims == scattered.dims @@ -354,7 +354,7 @@ def test_tile_scatter_state_with_recv_state_without_time( result = communicator.scatter_state(recv_state=state) assert result["air_temperature"] is recv assert "time" not in result - for rank, (result, scattered) in enumerate( + for _rank, (result, scattered) in enumerate( zip(recv_quantities, scattered_quantities) ): assert result.dims == scattered.dims diff --git a/tests/test_timer.py b/tests/test_timer.py index bb8ec3a6..7970b5c9 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -69,7 +69,7 @@ def test_consecutive_start_stops(timer): time.sleep(0.01) timer.stop("label") previous_time = timer.times["label"] - for i in range(5): + for _i in range(5): timer.start("label") time.sleep(0.01) timer.stop("label") @@ -83,7 +83,7 @@ def test_consecutive_clocks(timer): with timer.clock("label"): time.sleep(0.01) previous_time = timer.times["label"] - for i in range(5): + for _i in range(5): with timer.clock("label"): time.sleep(0.01) assert timer.times["label"] >= previous_time + 0.01 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..70dd9739 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,18 @@ +import pytest + +from ndsl.units import UnitsError, ensure_equal_units, units_are_equal + + +def test_UnitsError_is_deprecated() -> None: + with pytest.deprecated_call(): + UnitsError() + + +def test_units_are_equal_is_deprecated() -> None: + with pytest.deprecated_call(): + units_are_equal("asdf", "asdf") + + +def test_ensure_equal_units_is_deprecated() -> None: + with pytest.deprecated_call(): + ensure_equal_units("asdf", "asdf") diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index a67b6599..b3847ee4 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -324,7 +324,7 @@ def test_array_chunks(layout, tile_array_shape, array_dims, target): assert result == target -def _assert_no_nulls(dataset: "xr.Dataset"): +def _assert_no_nulls(dataset: xr.Dataset): number_of_null = dataset["var"].isnull().sum().item() total_size = dataset["var"].size