diff --git a/.github/PULL_REQUEST_TEMPLATE/README.md b/.github/PULL_REQUEST_TEMPLATE/README.md new file mode 100644 index 00000000..e5980e53 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/README.md @@ -0,0 +1,9 @@ +# Pull request templates + +- `../pull_request_template.md`: The default pull request template. Used for PRs. +- `release.md`: Special template used for releasing a new version of NDSL. +- `release-patch.md`: Special template used for patch releases. + +Note: GitHub has limited support for multiple pull request templates. Most notably, templates can only be [selected by an URL query parameter](https://github.com/orgs/community/discussions/4620) and there's currently [no way to set a title](https://github.com/orgs/community/discussions/63965). + +Note: GitHub does not support having the default pull request template in this folder. All templates in this folder can only be used with the `template=` URL parameter (see note above). diff --git a/.github/PULL_REQUEST_TEMPLATE/release-patch.md b/.github/PULL_REQUEST_TEMPLATE/release-patch.md new file mode 100644 index 00000000..c588dd8f --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/release-patch.md @@ -0,0 +1,32 @@ +# Release NDSL version `YYYY.MM.PP` + +This PR patches release `YYYY.MM.PP` because + +1. reason +2. reason +3. ... + +## Pre-release checklist + +Things to do before the patch release. Helps to keep the fallout from this release as minimal as possible. + +- [ ] setup a draft PR in [NOAA-GFDL/pace](https://github.com/NOAA-GFDL/pace) with updated submodules for `NDSL`, `pyFV3`, and `pySHiELD`. + Don't merge yet - just let CI run and fix potential issues before the release. To be merged afterwards, see post-release checklist. + +## Release checklist + +What to do to actually release: + +- [x] create this PR to merge changes from `my-patches` into `main` + - use "squash merge" +- [ ] once merged, create a GitHub release and tag the new version + - version format is `[year].[month].[patch]`. Increase the patch version, e.g. `2025.10.01` if this is patching the `2025.10.00` release. + - let GitHub auto-generate release notes from the last tagged version +- [ ] send an announcement on Mattermost + +## Post-release checklist + +What to do after a release: + +- [ ] update the pace PR from the pre-commit checklist to include the released version of NDSL and merge it. +- [ ] in NDSL, merge `main` back into `develop` (potentially adding a commit to fix the issue "properly") diff --git a/.github/PULL_REQUEST_TEMPLATE/release.md b/.github/PULL_REQUEST_TEMPLATE/release.md new file mode 100644 index 00000000..00aa3403 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/release.md @@ -0,0 +1,26 @@ +# Release NDSL version `YYYY.MM.00` + +## Pre-release checklist + +Things to do before the release. Helps to keep the fallout from this release as minimal as possible. + +- [ ] setup a draft PR in [NOAA-GFDL/pace](https://github.com/NOAA-GFDL/pace) with updated submodules for `NDSL`, `pyFV3`, and `pySHiELD`. + Don't merge yet - just let CI run and fix potential issues before the release. To be merged afterwards, see post-release checklist. + +## Release checklist + +What to do to actually release: + +- [x] create this PR to merge changes from `develop` into `main` + - merge as "Merge commit" +- [ ] once merged, create a GitHub release and tag the new version + - version format is `[year].[month].[patch]`, e.g. `2025.10.00` + - let GitHub auto-generate release notes from the last tagged version +- [ ] send an announcement on Mattermost + +## Post-release checklist + +What to do after a release: + +- [ ] update the pace PR from the pre-commit checklist to include the released version of NDSL and merge it. +- [ ] merge breaking changes in NDSL (e.g. search for deprecation warnings) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/pull_request_template.md similarity index 100% rename from .github/PULL_REQUEST_TEMPLATE.md rename to .github/pull_request_template.md diff --git a/.github/workflows/create-cache.yaml b/.github/workflows/create-cache.yaml index 6d823662..f5e38a6b 100644 --- a/.github/workflows/create-cache.yaml +++ b/.github/workflows/create-cache.yaml @@ -14,7 +14,7 @@ on: # Cancel running jobs if there's a newer push concurrency: - group: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref }} + group: ndsl-${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: @@ -45,8 +45,8 @@ jobs: # GitHub Actions cache of pyFV3 test data pyFV3_test_data: - uses: NOAA-GFDL/pyFV3/.github/workflows/create_cache.yml@develop + uses: NOAA-GFDL/pyFV3/.github/workflows/create_cache.yaml@develop # GitHub Actions cache of pySHiELD test data pySHiELD_test_data: - uses: NOAA-GFDL/pySHiELD/.github/workflows/create_cache.yml@develop + uses: NOAA-GFDL/pySHiELD/.github/workflows/create_cache.yaml@develop diff --git a/.github/workflows/docs_build.yaml b/.github/workflows/docs_build.yaml index 3b0f1ed1..4b59ed78 100644 --- a/.github/workflows/docs_build.yaml +++ b/.github/workflows/docs_build.yaml @@ -28,7 +28,7 @@ jobs: python-version: '3.11' - name: Install mkdocs - run: pip install mkdocs-material mkdocstrings[python] + run: pip install mkdocs-material mkdocstrings[python] mkdocs-exclude - name: Build docs run: mkdocs build diff --git a/.github/workflows/docs_deploy.yaml b/.github/workflows/docs_deploy.yaml index 1e6659d3..cbbb381c 100644 --- a/.github/workflows/docs_deploy.yaml +++ b/.github/workflows/docs_deploy.yaml @@ -29,7 +29,7 @@ jobs: python-version: 3.11 - name: Install dependencies - run: pip install mkdocs-material mkdocstrings[python] + run: pip install mkdocs-material mkdocstrings[python] mkdocs-exclude - name: Deploy docs to GitHub Pages run: mkdocs gh-deploy --force diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index bd90f66c..ef9fdcea 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -32,7 +32,7 @@ jobs: run: pip3 install mpich - name: Install Python packages - run: pip3 install .[test] + run: pip3 install .[test,zarr] - name: Run serial-cpu tests run: coverage run --rcfile=pyproject.toml -m pytest tests diff --git a/docs/docstrings/testing/dummy_comm.md b/docs/docstrings/testing/dummy_comm.md deleted file mode 100644 index 538c01d2..00000000 --- a/docs/docstrings/testing/dummy_comm.md +++ /dev/null @@ -1,3 +0,0 @@ -# dummy_comm - -::: testing.dummy_comm diff --git a/docs/docstrings/top/exceptions.md b/docs/docstrings/top/exceptions.md deleted file mode 100644 index 8803fee0..00000000 --- a/docs/docstrings/top/exceptions.md +++ /dev/null @@ -1,3 +0,0 @@ -# exceptions - -::: exceptions diff --git a/docs/docstrings/top/filesystem.md b/docs/docstrings/top/filesystem.md deleted file mode 100644 index 9f05cc3a..00000000 --- a/docs/docstrings/top/filesystem.md +++ /dev/null @@ -1,3 +0,0 @@ -# filesystem - -::: filesystem diff --git a/docs/docstrings/top/namelist.md b/docs/docstrings/top/namelist.md deleted file mode 100644 index c129312a..00000000 --- a/docs/docstrings/top/namelist.md +++ /dev/null @@ -1,3 +0,0 @@ -# namelist - -::: namelist diff --git a/docs/docstrings/top/units.md b/docs/docstrings/top/units.md deleted file mode 100644 index 5d2efcf8..00000000 --- a/docs/docstrings/top/units.md +++ /dev/null @@ -1,3 +0,0 @@ -# units - -::: units diff --git a/docs/internal/README.md b/docs/internal/README.md new file mode 100644 index 00000000..d8eea9dc --- /dev/null +++ b/docs/internal/README.md @@ -0,0 +1,3 @@ +# Internal documentation + +This folder contains internal / developer documentation and processes, e.g. how to build a release. This folder is thus ignored when building the public facing documentation from the `docs/` folder. diff --git a/docs/internal/release.md b/docs/internal/release.md new file mode 100644 index 00000000..7b804a82 --- /dev/null +++ b/docs/internal/release.md @@ -0,0 +1,24 @@ +# Release a new version + +This internal documentation guides you through the process of releasing a new version of NDSL. It is very simple: + +1. Click [create a release](https://github.com/NOAA-GFDL/NDSL/compare/main...develop?expand=1&template=release.md) and follow the steps in the release checklist. + +## Patch release + +Every now and then, we'll need to patch the currently released version of NDSL. To do so, follow these steps: + +1. Create a branch from `main`. +2. Commit your changes on that branch. +3. Use the following URL and follow the steps in the patch release checklist. + +As an example, you'd go and create branch `my-patches` from `main` + +```bash +git checkout main +git switch -c my-patches +# do changes ... +git push +``` + +and in that case, the URL with the patch release template is: . diff --git a/docs/user/index.md b/docs/user/index.md index 46ddc3eb..106ec159 100644 --- a/docs/user/index.md +++ b/docs/user/index.md @@ -12,13 +12,13 @@ NDSL tries to have sensible defaults. In cases you want tweak something, here ar ### Literal precision (float/int) -Unspecified integer and floating point literals (e.g. `42` and `3.1415`) default to 64-bit precision. This can be changed with the environment variable `PACE_FLOAT_PRECISION`. +Unspecified integer and floating point literals (e.g. `42` and `3.1415`) default to 64-bit precision. This can be changed with the environment variable `NDSL_LITERAL_PRECISION`. For mixed precision code, you can specify the "hard coded" precision with type hints and casts, e.g. ```python with computation(PARALLEL), interval(...): - # Either 32-bit or 64-bit depending on `PACE_FLOAT_PRECISION` + # Either 32-bit or 64-bit depending on `NDSL_LITERAL_PRECISION` my_int = 42 my_float = 3.1415 diff --git a/external/dace b/external/dace index 1033dfcf..4a9f4602 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 1033dfcf9d118856d82c6ee8d6f6cfacec662335 +Subproject commit 4a9f46027147a52e2b0ac9eedeb101c3ab27d0bf diff --git a/external/gt4py b/external/gt4py index e140f707..a7429094 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit e140f70731b723c519239e027237fb6281f4733b +Subproject commit a7429094d7dd9418a2e3c1e57b2e9c783d79250d diff --git a/mkdocs.yml b/mkdocs.yml index 1ca8682c..24aa01e0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,15 +24,12 @@ nav: - "boilerplate": docstrings/top/boilerplate.md - "buffer": docstrings/top/buffer.md - "constants": docstrings/top/constants.md - - "exceptions": docstrings/top/exceptions.md - - "filesystem": docstrings/top/filesystem.md - "io": docstrings/top/io.md - "logging": docstrings/top/logging.md - "namelist": docstrings/top/namelist.md - "optional_imports": docstrings/top/optional_imports.md - "types": docstrings/top/types.md - "typing": docstrings/top/typing.md - - "units": docstrings/top/units.md - "utils": docstrings/top/utils.md - checkpointer: - "base": docstrings/checkpointer/base.md @@ -106,7 +103,6 @@ nav: - "tridiag": docstrings/stencils/tridiag.md - testing: - "comparison": docstrings/testing/comparison.md - - "dummy_comm": docstrings/testing/dummy_comm.md - "perturbation": docstrings/testing/perturbation.md - viz: - "cube_sphere": docstrings/viz/cube_sphere.md @@ -149,6 +145,9 @@ plugins: paths: [./ndsl] # Adjust this path to where your Python modules are options: show_source: false + - exclude: + glob: + - internal/* watch: # reload when the glossary file is updated diff --git a/ndsl/__init__.py b/ndsl/__init__.py index c7730282..3c2a018c 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -19,18 +19,15 @@ from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig -from .exceptions import OutOfBoundsError from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .initialization import GridSizer, QuantityFactory, SubtileGridSizer from .monitor.netcdf_monitor import NetCDFMonitor -from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport from .quantity import Local, Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import -from .testing.dummy_comm import DummyComm from .types import Allocator from .utils import MetaEnumStr @@ -62,7 +59,6 @@ "CompilationConfig", "RunMode", "StencilConfig", - "OutOfBoundsError", "HaloExchangeSpec", "HaloUpdater", "HaloUpdateRequest", @@ -72,7 +68,6 @@ "SubtileGridSizer", "ndsl_log", "NetCDFMonitor", - "Namelist", "NullPerformanceCollector", "PerformanceCollector", "NullProfiler", @@ -83,7 +78,6 @@ "Quantity", "FieldBundle", "FieldBundleType", - "DummyComm", "Allocator", "MetaEnumStr", "State", diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index d17d6349..c6fb82f8 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -3,7 +3,7 @@ DaceConfig, DaCeOrchestration, GridIndexing, - NullComm, + MPIComm, QuantityFactory, RunMode, StencilConfig, @@ -54,6 +54,13 @@ def _get_factories( ) if topology == "tile": + mpi_comm = MPIComm() + if mpi_comm.Get_size() != 1: + raise ValueError( + "Single tile topology requested with an MPI communicator of size " + f"{mpi_comm.Get_size()} > 1. Re-configure MPI to run on only one rank." + ) + partitioner = TilePartitioner((1, 1)) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -63,13 +70,13 @@ def _get_factories( layout=partitioner.layout, tile_partitioner=partitioner, ) - comm = TileCommunicator(comm=NullComm(0, 1, 42), partitioner=partitioner) + comm = TileCommunicator(comm=mpi_comm, partitioner=partitioner) else: raise NotImplementedError(f"Topology {topology} is not implemented.") grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) - quantity_factory = QuantityFactory.from_backend(sizer, backend) + quantity_factory = QuantityFactory(sizer, backend=backend) return stencil_factory, quantity_factory diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 580f881a..ec886bae 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -5,7 +5,7 @@ CachingRequestReader, CachingRequestWriter, ) -from .comm_abc import Comm, Request +from .comm_abc import Comm, ReductionOperator, Request __all__ = [ @@ -15,5 +15,6 @@ "CachingRequestReader", "CachingRequestWriter", "Comm", + "ReductionOperator", "Request", ] diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 6eee4514..983a35b2 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -108,7 +108,7 @@ def _create_all_reduce_quantity( units=input_metadata.units, origin=input_metadata.origin, extent=input_metadata.extent, - gt4py_backend=input_metadata.gt4py_backend, + backend=input_metadata.backend, allow_mismatch_float_precision=False, ) return all_reduce_quantity @@ -228,7 +228,7 @@ def _get_gather_recv_quantity( units=send_metadata.units, origin=tuple([0 for dim in send_metadata.dims]), extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -241,7 +241,7 @@ def _get_scatter_recv_quantity( send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore dims=send_metadata.dims, units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -326,6 +326,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): # dims=quantity.dims, units=quantity.units, allow_mismatch_float_precision=True, + backend=quantity.backend, ) if recv_state is not None and name in recv_state: tile_quantity = self.gather( @@ -841,7 +842,7 @@ def _get_gather_recv_quantity( units=metadata.units, origin=(0,) + tuple([0 for dim in metadata.dims]), extent=global_extent, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -861,7 +862,7 @@ def _get_scatter_recv_quantity( metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore dims=metadata.dims[1:], units=metadata.units, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity diff --git a/ndsl/comm/null_comm.py b/ndsl/comm/null_comm.py index 2d67c78f..bee499ce 100644 --- a/ndsl/comm/null_comm.py +++ b/ndsl/comm/null_comm.py @@ -1,4 +1,5 @@ import copy +import warnings from collections.abc import Mapping from typing import Any, TypeVar, cast @@ -33,6 +34,12 @@ def __init__(self, rank: int, total_ranks: int, fill_value: T = default_fill_val fill_value: fill halos with this value when performing halo updates. """ + warnings.warn( + "NullComm is deprecated and will be removed with the next version of NDSL. " + "Use MPIComm or LocalComm instead.", + DeprecationWarning, + stacklevel=2, + ) self.rank = rank self.total_ranks = total_ranks self._fill_value = fill_value diff --git a/ndsl/constants.py b/ndsl/constants.py index 82d16d30..a02f70f6 100644 --- a/ndsl/constants.py +++ b/ndsl/constants.py @@ -20,14 +20,7 @@ class ConstantVersions(Enum): def _get_constant_version( default: Literal["GFDL", "UFS", "GEOS"] = "UFS", ) -> Literal["GFDL", "UFS", "GEOS"]: - if os.getenv("PACE_CONSTANTS", ""): - ndsl_log.warning("PACE_CONSTANTS is deprecated. Use NDSL_CONSTANTS instead.") - if os.getenv("NDSL_CONSTANTS", ""): - ndsl_log.warning( - "PACE_CONSTANTS and NDSL_CONSTANTS were both specified. NDSL_CONSTANTS will take precedence." - ) - - constants_as_str = os.getenv("NDSL_CONSTANTS", os.getenv("PACE_CONSTANTS", default)) + constants_as_str = os.getenv("NDSL_CONSTANTS", default) expected: list[Literal["GFDL", "UFS", "GEOS"]] = ["GFDL", "UFS", "GEOS"] if constants_as_str not in expected: diff --git a/ndsl/debug/debugger.py b/ndsl/debug/debugger.py index b9de8631..14037428 100644 --- a/ndsl/debug/debugger.py +++ b/ndsl/debug/debugger.py @@ -9,6 +9,7 @@ from ndsl.logging import ndsl_log from ndsl.quantity import Quantity +from ndsl.quantity.field_bundle import FieldBundle @dataclasses.dataclass @@ -87,6 +88,8 @@ def save_as_dataset(self, data_as_dict: dict, savename: str, is_in: bool) -> Non data_arrays[f"{name}.{field.name}"] = self._to_xarray( getattr(data, field.name), field.name ) + elif isinstance(data, FieldBundle): + data_arrays[name] = data.quantity.field_as_xarray else: data_arrays[name] = self._to_xarray(data, name) diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 5fa508d2..f931129b 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -17,18 +17,7 @@ def _get_literal_precision(default: Literal["32", "64"] = "64") -> Literal["32", "64"]: - if os.getenv("PACE_FLOAT_PRECISION", ""): - ndsl_log.warning( - "PACE_FLOAT_PRECISION is deprecated. Use NDSL_LITERAL_PRECISION instead." - ) - if os.getenv("NDSL_LITERAL_PRECISION", ""): - ndsl_log.warning( - "PACE_FLOAT_PRECISION and NDSL_LOGLEVEL were both specified. NDSL_LITERAL_PRECISION will take precedence." - ) - - precision = os.getenv( - "NDSL_LITERAL_PRECISION", os.getenv("PACE_FLOAT_PRECISION", default) - ) + precision = os.getenv("NDSL_LITERAL_PRECISION", default) expected: list[Literal["32", "64"]] = ["32", "64"] if precision in expected: diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 1c1e7ec8..87d608dd 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -5,7 +5,20 @@ def identify_code_path( rank: int, partitioner: Partitioner, + single_code_path: bool, ) -> FV3CodePath: + """Determine which code path your rank will hit. + + If single_code_path is True, single_code_path is True, + only one code path exists (case of doubly periodic grid). + If single_code_path is False, we are in the case of the + cube-sphere and we will look at our position on the tile.""" + + # Doubly-periodic or single tile grid + if single_code_path: + return FV3CodePath.All + + # Cube-sphere if partitioner.layout == (1, 1): return FV3CodePath.All elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index d76e10da..600a12c3 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -8,14 +8,13 @@ from dace.frontend.python.parser import DaceProgram from gt4py.cartesian.config import GT4PY_COMPILE_OPT_LEVEL +from ndsl import LocalComm from ndsl.comm.communicator import Communicator -from ndsl.comm.null_comm import NullComm from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import get_precision -from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector @@ -31,14 +30,7 @@ def _debug_dace_orchestration() -> bool: Debugging Dace orchestration deeper can be done by turning on `syncdebug`. We control this Dace configuration below with our own override. """ - if os.getenv("PACE_DACE_DEBUG", ""): - ndsl_log.warning("PACE_DACE_DEBUG is deprecated. Use NDSL_DACE_DEBUG instead.") - if os.getenv("NDSL_DACE_DEBUG", ""): - ndsl_log.warning( - "PACE_DACE_DEBUG and NDSL_DACE_DEBUG were both specified. NDSL_DACE_DEBUG will take precedence." - ) - - return os.getenv("NDSL_DACE_DEBUG", os.getenv("PACE_DACE_DEBUG", "False")) == "True" + return os.getenv("NDSL_DACE_DEBUG", "False") == "True" def _is_corner(rank: int, partitioner: Partitioner) -> bool: @@ -110,6 +102,9 @@ def _determine_compiling_ranks( 15 -> 8 """ + if config._single_code_path: + return config.my_rank == 0 + # Tile 0 compiles if partitioner.tile_index(config.my_rank) != 0: return False @@ -155,6 +150,7 @@ def __init__( tile_nz: int = 0, orchestration: DaCeOrchestration | None = None, time: bool = False, + single_code_path: bool = False, ): """Specialize the DaCe configuration for NDSL use. @@ -171,8 +167,11 @@ def __init__( orchestration: orchestration mode from DaCeOrchestration time: trigger performance collection, available to user with `performance_collector` + single_codepath: code is expected to be the same on every rank (case + of column-physics) and therefore can be compiled once """ + self._single_code_path = single_code_path # Recording SDFG loaded for fast re-access # ToDo: DaceConfig becomes a bit more than a read-only config # with this. Should be refactored into a DaceExecutor carrying a config @@ -181,7 +180,7 @@ def __init__( PerformanceCollector( "InternalOrchestrationTimer", comm=( - communicator.comm if communicator is not None else NullComm(0, 6, 0) + LocalComm(0, 6, {}) if communicator is None else communicator.comm ), ) if time @@ -339,7 +338,11 @@ def __init__( if communicator: self.my_rank = communicator.rank self.rank_size = communicator.comm.Get_size() - self.code_path = identify_code_path(self.my_rank, communicator.partitioner) + self.code_path = identify_code_path( + self.my_rank, + communicator.partitioner, + self._single_code_path, + ) self.layout = communicator.partitioner.layout self.do_compile = ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 43b6c4f3..049a1095 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -16,6 +16,7 @@ from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.transformation.auto.auto_optimize import make_transients_persistent +from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass from gt4py import storage @@ -34,7 +35,12 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations import ( + AxisIterator, + CartesianAxisMerge, + CartesianRefineTransients, + CleanUpScheduleTree, +) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -42,14 +48,12 @@ ) from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp +from ndsl.quantity import Quantity, State _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" -_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] -"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" - def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -125,7 +129,7 @@ def _simplify( # We disable ScalarToSymbolPromotion because it might push symbols onto edges # that DaCe itself can't parse anymore later, e.g. casts, inlined function # calls or (complicated) field accesses. - skip=["ScalarToSymbolPromotion"], + skip={"ScalarToSymbolPromotion"}, ).apply_pass(sdfg, {}) @@ -156,14 +160,39 @@ def _build_sdfg( _simplify(sdfg) if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + # Break all loops into uni-dimensional loops to simplify optimizations + sdfg.apply_transformations_repeated(MapExpansion, validate=True) stree = sdfg.as_schedule_tree() with DaCeProgress(config, "Schedule Tree: optimization"): if config.is_gpu_backend(): GPUPipeline().run(stree) else: - CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + passes = [] + + if config.get_backend() == "dace:cpu_kfirst": + passes.extend( + [ + CleanUpScheduleTree(), + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianAxisMerge(AxisIterator._K), + CartesianRefineTransients(config.get_backend()), + ] + ) + else: + passes.extend( + [ + CleanUpScheduleTree(), + CartesianAxisMerge(AxisIterator._K), + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianRefineTransients(config.get_backend()), + ] + ) + CPUPipeline(passes=passes).run(stree) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) @@ -543,12 +572,18 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] - func = type.__getattribute__(type(obj), method_to_orchestrate) + func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime + for arg_name, annotation in func.__annotations__.items(): + if annotation in [Quantity, State] or ( + isinstance(annotation, type) and issubclass(annotation, State) + ): + func.__annotations__[arg_name] = DaceCompiletime + # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call wrapped = _LazyComputepathMethod(func, config).__get__(obj) diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 47c764b3..73497f93 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,11 @@ from .axis_merge import AxisIterator, CartesianAxisMerge +from .clean_tree import CleanUpScheduleTree +from .refine_transients import CartesianRefineTransients -__all__ = ["AxisIterator", "CartesianAxisMerge"] +__all__ = [ + "AxisIterator", + "CartesianAxisMerge", + "CartesianRefineTransients", + "CleanUpScheduleTree", +] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 262a6021..1ee2ff70 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -20,6 +20,10 @@ ) +# Buggy passes that should work +PUSH_IFSCOPE_DOWNWARD = False + + def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a map over the given axis.""" map_parameter = node.node.params @@ -171,7 +175,7 @@ def __init__( self.eager = eager def __str__(self) -> str: - return f"CartesianAxisMerge({self.axis.name})" + return f"CartesianAxisMerge_{self.axis.name}" def _merge_node( self, @@ -187,7 +191,7 @@ def _merge_node( if isinstance(node, stree.MapScope): return self._map_overcompute_merge(node, nodes) - elif isinstance(node, stree.IfScope): + elif PUSH_IFSCOPE_DOWNWARD and isinstance(node, stree.IfScope): return self._push_ifelse_down(node, nodes) elif isinstance(node, stree.TaskletNode): return self._push_tasklet_down(node, nodes) @@ -221,7 +225,7 @@ def _push_tasklet_down( return 0 # Tasklet is a callback next_index = list_index(nodes, the_tasklet) - if next_index == len(nodes): + if next_index == len(nodes) - 1: return 0 # Last node - done next_node = nodes[next_index + 1] @@ -231,7 +235,6 @@ def _push_tasklet_down( merged = self._merge_node(next_node, nodes) # Attempt to push the tasklet in the next map - ndsl_log.debug(" Push tasklet down into next map") next_node = nodes[next_index + 1] if isinstance(next_node, stree.MapScope): next_node.children.insert(0, the_tasklet) @@ -292,7 +295,6 @@ def _push_ifelse_down( return merged # We are good to go - swap it all - ndsl_log.debug(f" Push IF {the_if.condition.as_string} down") inner_if_map = the_if.children[0] # Swap IF & maps @@ -322,18 +324,20 @@ def _map_overcompute_merge( the_map: stree.MapScope, nodes: list[stree.ScheduleTreeNode], ) -> int: - if _last_node(nodes, the_map): - return 0 + # End of nodes OR + # Not the right axis + # --> recurse + if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis): + merged = 0 + for child in the_map.children: + merged += self._merge_node(child, the_map.children) + return merged next_node = _get_next_node(nodes, the_map) - # If the next node is not a MapScope - recurse + # Next node is not a MapScope - no merge if not isinstance(next_node, stree.MapScope): - merged = self._merge_node(next_node, nodes) - new_next_node = _get_next_node(nodes, the_map) - if new_next_node == next_node: - return merged - return merged + self._merge_node(the_map, nodes) + return 0 # Attempt to merge consecutive maps if not _can_merge_axis_maps(the_map, next_node, self.axis): @@ -354,10 +358,6 @@ def _map_overcompute_merge( ] ) - ndsl_log.debug( - f" Merge {self.axis.name} map: {first_range} ⋃ {second_range} -> {merged_range}" - ) - # push IfScope down if children are just maps axis_as_str = the_map.node.params[0] first_map = InsertOvercomputationGuard( @@ -418,11 +418,17 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # in the tasklet... # NormalizeAxisSymbol(self.axis).visit(node) + # TODO: we are meging single axis, we could prefix those runs by moving + # if scope down inside the map if it has the proper axis, preparing + # for a better merging scope. If we can't nerge, we can revert this + # orep step + overall_merged = 0 + passes_apply = 0 i = 0 while True: i += 1 - ndsl_log.debug(f"🔥 Merge attempt #{i}") + # ndsl_log.debug(f"🔥 Merge attempt #{i}") previous_children = copy.deepcopy(node.children) try: merged = self._merge(node) @@ -435,10 +441,11 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # If we didn't merge, we revert the children # to the previous state if merged == 0: - ndsl_log.debug("🥹 No merges, revert!") + # ndsl_log.debug("🥹 No merges, revert!") node.children = previous_children break + passes_apply += 1 ndsl_log.debug( - f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged" + f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged in {passes_apply} passes" ) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py new file mode 100644 index 00000000..41056ec4 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl import ndsl_log + + +class CleanUpScheduleTree(stree.ScheduleNodeTransformer): + """Clean up unused nodes, or nodes barrying further optimizations.""" + + def __init__(self) -> None: + self.cleaned_state_boundaries = 0 + + def __str__(self) -> str: + return "CleanUpScheduleTree" + + def _remove_state_boundaries_from_my_childs( + self, node: stree.ScheduleTreeScope + ) -> None: + to_remove = [ + child + for child in node.children + if isinstance(child, stree.StateBoundaryNode) + ] + for to_remove_child in to_remove: + self.cleaned_state_boundaries += 1 + node.children.remove(to_remove_child) + + def visit_WhileScope(self, node: stree.WhileScope) -> stree.WhileScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_ForScope(self, node: stree.ForScope) -> stree.ForScope: + self._remove_state_boundaries_from_my_childs(node) + + # We might have inherited a proper `loop_range` from the SDFG + # but the data (sdfg) it relies on is no longer valid. + node.header.loop_range = lambda: None + + for child in node.children: + self.visit(child) + + return node + + def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_IfScope(self, node: stree.IfScope) -> stree.IfScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + ndsl_log.debug( + f"Clean up StateBoundary : {self.cleaned_state_boundaries} nodes" + ) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py new file mode 100644 index 00000000..e24788bc --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import warnings + +import dace.data +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator + + +def _change_index_of_tuple( + old_tuple: tuple[int, ...], index: int, value: int = 1 +) -> tuple[int, ...]: + """Return a copy of the given tuple with `old_tuple[index]` being replaced by `value`. + + Args: + old_tuple: to be copied + index: at which index to replace a value + value: to replace `old_tuple[index]` + """ + new_list = list(old_tuple) + new_list[index] = value + return tuple(new_list) + + +def _reduce_cartesian_axis_size_to_1( + axis: AxisIterator, + transient_map_reads: dace.subsets.Range | None, + transient_map_writes: dace.subsets.Range | None, + transient_data: dace.data.Data, + ijk_order: tuple[int, int, int], +) -> bool: + """Reduce dimension size of transient to 1 if all access (reads and writes) + are atomic""" + + # Dev Note: Better dataflow analysis would look at exactly + # what's goin on here! + + # Assume 3D cartesian! + if len(transient_data.shape) < 3: + warnings.warn( + f"Potential non-3D array: {transient_data}, skipping.", + UserWarning, + stacklevel=2, + ) + return False + + read_write_range: dace.subsets.Range = dace.subsets.union( + transient_map_reads, transient_map_writes + ) + + if read_write_range is None: + return False + + if read_write_range.size()[axis.as_cartesian_index()] != 1: + return False + + # This transient read and write access is done on exactly one element + # therefore this dimension can be removed. BUT we are not truly + # removing it, we are reducing it to 1 to not have to deal + # with different slicing. + transient_data.shape = _change_index_of_tuple( + transient_data.shape, + axis.as_cartesian_index(), + value=1, + ) + + if len(transient_data.shape) == 3: + layout = [*ijk_order] + else: + data_dim_count = len(transient_data.shape) - 3 + layout = [dim + data_dim_count for dim in ijk_order] + [ + i - 1 for i in range(data_dim_count, 0, -1) + ] + + transient_data.set_strides_from_layout(*layout) + transient_data.lifetime = dace.dtypes.AllocationLifetime.State + return True + + +class CollectTransientRangeAccess(stree.ScheduleNodeVisitor): + """Unionize all transient arrays access into a single Range.""" + + def __init__(self) -> None: + # Map access is a `list` instead of a `set` because we want to double count + # access that are in/out as two access on the axis. + self.transients_range_writes: dict[str, dace.subsets.Range | None] = {} + self.transients_range_reads: dict[str, dace.subsets.Range | None] = {} + + def __str__(self) -> str: + return "CartesianCollectMaps" + + def _record_access( + self, + memlets: stree.MemletSet, + recording_set: dict[str, dace.subsets.Range | None], + ) -> None: + for memlet in memlets: + data = self.containers[memlet.data] + if data.transient and isinstance(data, dace.data.Array): + if not isinstance(memlet.subset, dace.subsets.Range): + raise NotImplementedError( + "Memlet refining only works with Range subsets" + ) + recording_set[memlet.data] = dace.subsets.union( + recording_set[memlet.data], memlet.subset + ) + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + self._record_access(node.input_memlets(), self.transients_range_writes) + self._record_access(node.output_memlets(), self.transients_range_reads) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for name, data in self.containers.items(): + if data.transient and isinstance(data, dace.data.Array): + self.transients_range_writes[name] = None + self.transients_range_reads[name] = None + + for child in node.children: + self.visit(child) + + +class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): + """Rebuild memlets from containers to ensure they are scope to the right size.""" + + def __init__(self, refined_arrays: set[str]) -> None: + self._refined_arrays = refined_arrays + + def __str__(self) -> str: + return "RefineTransientAxis" + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for memlet in [*node.output_memlets(), *node.input_memlets()]: + if memlet.data not in self._refined_arrays: + continue + array = self.containers[memlet.data] + if array.transient: + if not isinstance(memlet.subset, dace.subsets.Range): + raise NotImplementedError( + "Memlet refining only works with Range subsets" + ) + + # Reduce "refined" dimension to a single element, effectively + # eliminating it. + for index, _ in enumerate(memlet.subset.ranges): + if array.shape[index] == 1: + memlet.subset.ranges[index] = (0, 0, 1) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for child in node.children: + self.visit(child) + + +class CartesianRefineTransients(stree.ScheduleNodeTransformer): + """Refine (reduce dimensionality) of transients based on their true use in + the cartesian dimensions. + + + It can do: + - Looking at usage of a transient in a cartesian axis (e.g. loop over a + cartesian axis) it will reduce that axis to 1 if all access are atomic + (exactly _one_ element of the array is ever worked on) + + It should but cannot do/will bug if: + - If the transient is _written_ before being _read_ this won't catch it (not its job), but we could + - With better dataflow analysis, we can reduce the dimensions to the correct lowest + size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element + buffer), instead of the defensive _no refine_ strategy used now. We have _most_ of the + info in the `Range` + - Current action when detecting a valid candidate is to reduce the size of the dimension + to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce + the cache access significantly. This also has been implemented to _not_ deal with offset/slicing + downstream impact of removing an axis. Nevertheless the xis should be removed if it's not + used. + - It only knows how to deal with 3D cartesian and 3D cartesian + data dimensions. Anything else will + fail `_reduce_cartesian_axes_size_to_1` calculation + + More tests: + - Test for dataflow with offset + - Test for I/J refine but not in K + - Test for J refine but not in I or K + - Test with dataflow: if/else, while, etc. + - Test with ForScope (FORWARD/BACKWARD) instead of Map + + Coding traps: + - We reduce the "refined" dimensions to 1 which is functionally eliminating it. This is solid. In the + case of the one we can't eliminate we don't do anything. We could find the "smallest buffer size" needed + and reduce the local dimension to it. BUT if we do this, we have to take into account the offset into + memory (e.g. halo) for the `RebuildMemletsFromContainers`! + """ + + def __init__(self, backend: str) -> None: + warnings.warn( + "CartesianRefineTransients is a WIP. It's usage is *severely* limited " + "and will most likely lead to bad numerics. Check the docs, check utest.", + UserWarning, + stacklevel=2, + ) + + if backend in ["dace:cpu_kfirst"]: + self.ijk_order = (2, 1, 0) + elif backend in ["dace:cpu", "dace:gpu"]: + self.ijk_order = (1, 0, 2) + else: + raise NotImplementedError( + "[Schedule Tree Opt] CartesianRefineTransient not implemented for " + f"backend {backend}" + ) + + self.refined_array: set[str] = set() + + def __str__(self) -> str: + return "CartesianRefineTransients" + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + collect_map = CollectTransientRangeAccess() + collect_map.visit(node) + + # Remove Axis + refined_transient = 0 + for name, data in node.containers.items(): + if not (data.transient and isinstance(data, dace.data.Array)): + continue + refined = False + for axis in AxisIterator: + refined |= _reduce_cartesian_axis_size_to_1( + axis, + collect_map.transients_range_reads[name], + collect_map.transients_range_writes[name], + data, + self.ijk_order, + ) + + refined_transient += 1 if refined else 0 + self.refined_array.add(name) + + RebuildMemletsFromContainers(self.refined_array).visit(node) + + ndsl_log.debug(f"🚀 {refined_transient} Transient refined") diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 10fb77cd..702af7fe 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -3,6 +3,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.logging import ndsl_log_on_rank_0 class StreePipeline(ABC): @@ -42,10 +43,14 @@ def run( stree: stree.ScheduleTreeRoot, verbose: bool = False, ) -> stree.ScheduleTreeRoot: - for p in self.passes: + for i, p in enumerate(self.passes): if verbose: - print(f"[Stree OPT] {p}") + path = f"pass{i}_{p}.txt" + ndsl_log_on_rank_0.info(f"[Stree OPT] {p} (saving {path} after)") p.visit(stree) + if verbose: + with open(path, "w") as f: + f.write(stree.as_string()) return stree diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index cb35503e..ae6b1d80 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -32,7 +32,7 @@ def __enter__(self) -> None: def __exit__(self, _type, _val, _traceback) -> None: # type: ignore elapsed = time.time() - self.start - ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed}s.") + ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed:.2f}s.") def _is_ref(sd: dace.sdfg.SDFG, aname: str) -> bool: diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 721ae38b..c798ded4 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -75,7 +75,6 @@ def check_for_quantity(object_: object) -> None: obj=self, config=self._dace_config, ) - print(type(self)) def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) @@ -124,6 +123,6 @@ def make_local( units=quantity.units, origin=quantity.origin, extent=quantity.extent, - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, ) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index f557c3b3..990c685a 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -240,7 +240,6 @@ def compare_ranks(comm: Comm, data: dict) -> Mapping[str, int]: other = comm.sendrecv(array, pair_rank) arr_diffs = np.sum(np.logical_and(~np.isnan(array), array != other)) if arr_diffs > 0: - print(name, rank, pair_rank, array, other) differences[name] = arr_diffs return differences @@ -290,6 +289,8 @@ def __init__( else: self._timing_collector = timing_collector + self._arguments_already_checked = False + if externals is None: externals = {} self.externals = externals @@ -406,7 +407,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") - self._validate_quantity_sizes(*args, **kwargs) + if ( + not self._arguments_already_checked + and self.stencil_config.compilation_config.validate_args + ): + self._validate_quantity_sizes(*args, **kwargs) # Marshal arguments args_list = list(args) @@ -430,7 +435,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: ndsl_debugger.track_data(all_args, self._func_qualname, is_in=True) # Execute stencil - if self.stencil_config.compilation_config.validate_args: + if ( + not self._arguments_already_checked + and self.stencil_config.compilation_config.validate_args + ): if __debug__ and "origin" in kwargs: raise TypeError("origin cannot be passed to FrozenStencil call") if __debug__ and "domain" in kwargs: @@ -443,6 +451,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: validate_args=True, exec_info=self._timing_collector.exec_info, ) + self._arguments_already_checked = True else: self.stencil_object.run( **args_as_kwargs, diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index b4b57183..f90db27a 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from typing import Any, Self +from gt4py.cartesian.backend import from_name as check_backend_existence from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.communicator import Communicator @@ -43,6 +44,7 @@ def __init__( if "gpu" not in backend and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") # GT4Py backend args + check_backend_existence(backend) self.backend = backend self.rebuild = rebuild self.validate_args = validate_args diff --git a/ndsl/exceptions.py b/ndsl/exceptions.py deleted file mode 100644 index 4511ea69..00000000 --- a/ndsl/exceptions.py +++ /dev/null @@ -1,16 +0,0 @@ -# flake8: noqa -import warnings - -from ndsl.comm.local_comm import ConcurrencyError -from ndsl.units import UnitsError - - -class OutOfBoundsError(ValueError): - def __init__(self, *args) -> None: - warnings.warn( - "Usage of `OutOfBoundsError` is discouraged. The class will be " - "removed in the next version in favor of using the built-in `IndexError`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args) diff --git a/ndsl/filesystem.py b/ndsl/filesystem.py deleted file mode 100644 index df2e709f..00000000 --- a/ndsl/filesystem.py +++ /dev/null @@ -1,37 +0,0 @@ -import warnings - -import fsspec - - -def get_fs(path: str) -> fsspec.AbstractFileSystem: - """Return the fsspec filesystem required to handle a given path.""" - warnings.warn( - "Usage of `get_fs()` is discouraged if favor `os.path` and `pathlib` " - "modules. The function will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - fs, _, _ = fsspec.get_fs_token_paths(path) - return fs - - -def is_file(filename): - warnings.warn( - "Usage of `is_file()` is discouraged if favor of plain `os.path.isfile()`. " - "The function will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - return get_fs(filename).isfile(filename) - - -def open(filename, *args, **kwargs): - warnings.warn( - "Usage of `open()` is discouraged if favor the python built-in file " - "open context manager. The function will be removed in the next version " - "of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - fs = get_fs(filename) - return fs.open(filename, *args, **kwargs) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 1bc37bc5..7296fa76 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -502,7 +502,7 @@ def from_tile_sizing( n_halo=N_HALO_DEFAULT, layout=communicator.partitioner.tile.layout, ) - quantity_factory = QuantityFactory.from_backend(sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) return cls( quantity_factory=quantity_factory, communicator=communicator, @@ -551,7 +551,7 @@ def lon(self): origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -563,7 +563,7 @@ def lat(self) -> Quantity: origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -575,7 +575,7 @@ def lon_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -587,7 +587,7 @@ def lat_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1551,7 +1551,7 @@ def rarea(self) -> Quantity: origin=self.area.origin, extent=self.area.extent, units="m^-2", - gt4py_backend=self.area.gt4py_backend, + backend=self.area.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1566,7 +1566,7 @@ def rarea_c(self) -> Quantity: origin=self.area_c.origin, extent=self.area_c.extent, units="m^-2", - gt4py_backend=self.area_c.gt4py_backend, + backend=self.area_c.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1582,7 +1582,7 @@ def rdx(self) -> Quantity: origin=self.dx.origin, extent=self.dx.extent, units="m^-1", - gt4py_backend=self.dx.gt4py_backend, + backend=self.dx.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1598,7 +1598,7 @@ def rdy(self) -> Quantity: origin=self.dy.origin, extent=self.dy.extent, units="m^-1", - gt4py_backend=self.dy.gt4py_backend, + backend=self.dy.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1614,7 +1614,7 @@ def rdxa(self) -> Quantity: origin=self.dxa.origin, extent=self.dxa.extent, units="m^-1", - gt4py_backend=self.dxa.gt4py_backend, + backend=self.dxa.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1630,7 +1630,7 @@ def rdya(self) -> Quantity: origin=self.dya.origin, extent=self.dya.extent, units="m^-1", - gt4py_backend=self.dya.gt4py_backend, + backend=self.dya.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1646,7 +1646,7 @@ def rdxc(self) -> Quantity: origin=self.dxc.origin, extent=self.dxc.extent, units="m^-1", - gt4py_backend=self.dxc.gt4py_backend, + backend=self.dxc.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1662,7 +1662,7 @@ def rdyc(self) -> Quantity: origin=self.dyc.origin, extent=self.dyc.extent, units="m^-1", - gt4py_backend=self.dyc.gt4py_backend, + backend=self.dyc.backend, number_of_halo_points=N_HALO_DEFAULT, ) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index dd612b22..d907e49a 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -186,7 +186,7 @@ def p_interface(self) -> Quantity: p_interface_data, dims=[Z_INTERFACE_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._p_interface @@ -203,7 +203,7 @@ def p(self) -> Quantity: p_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.p_interface.gt4py_backend, + backend=self.p_interface.backend, number_of_halo_points=self.p_interface.metadata.n_halo, ) return self._p @@ -220,7 +220,7 @@ def dp(self) -> Quantity: dp_ref_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._dp_ref @@ -230,7 +230,7 @@ def ptop(self) -> Float: """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if self.ak.gt4py_backend is not None and is_gpu_backend(self.ak.gt4py_backend): + if self.ak.backend is not None and is_gpu_backend(self.ak.backend): return Float(self.ak.view[0].get()) else: return Float(self.ak.view[0]) @@ -382,7 +382,7 @@ def _fC_from_data(data, lat: Quantity) -> Quantity: dims=lat.dims, origin=lat.origin, extent=lat.extent, - gt4py_backend=lat.gt4py_backend, + backend=lat.backend, number_of_halo_points=lat.metadata.n_halo, ) @@ -824,7 +824,7 @@ def split_quantity_along_last_dim(quantity: Quantity) -> list[Quantity]: units=quantity.units, origin=quantity.origin[:-1], extent=quantity.extent[:-1], - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, number_of_halo_points=quantity.metadata.n_halo, ) ) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index e9d8857e..f36cd02e 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -13,53 +13,18 @@ from ndsl.quantity import Quantity, QuantityHaloSpec -class StorageNumpy: - def __init__(self, backend: str) -> None: - """Initialize an object which behaves like the numpy module, but uses - gt4py storage objects for zeros, ones, and empty. +class QuantityFactory: + def __init__(self, sizer: GridSizer, *, backend: str) -> None: + """ + Initialize a QuantityFactory from a GridSizer and a GT4Py backend name. Args: - backend: gt4py backend + sizer: GridSizer object that determines the array sizes. + backend: GT4Py backend name used for performance-optimized allocation. """ + self.sizer = sizer self.backend = backend - def empty(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.empty(*args, backend=self.backend, **kwargs) - - def ones(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.ones(*args, backend=self.backend, **kwargs) - - def zeros(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.zeros(*args, backend=self.backend, **kwargs) - - -class QuantityFactory: - def __init__( # type: ignore - self, sizer: GridSizer, numpy, *, silence_deprecation_warning: bool = False - ) -> None: - if not silence_deprecation_warning: - warnings.warn( - "Usage of QuantityFactory(sizer, numpy) is discouraged and will change " - "in the next release. Use QuantityFactory.from_backend(sizer, backend) " - "instead for a stable experience across the release.", - DeprecationWarning, - 2, - ) - self.sizer: GridSizer = sizer - self._numpy = numpy - - def set_extra_dim_lengths(self, **kwargs: Any) -> None: - """ - Set the length of extra (non-x/y/z) dimensions. - """ - warnings.warn( - "`QuantityFactory.set_extra_dim_lengths` is deprecated. " - "Use `add_data_dimensions` or `update_data_dimensions`.", - DeprecationWarning, - 2, - ) - self.sizer.data_dimensions.update(kwargs) - def update_data_dimensions( self, data_dimension_descriptions: dict[str, int], @@ -95,21 +60,22 @@ def add_data_dimensions( @classmethod def from_backend(cls, sizer: GridSizer, backend: str) -> QuantityFactory: - """Initialize a QuantityFactory to use a specific gt4py backend. + """Initialize a QuantityFactory to use a specific GT4Py backend. + + Note: This method is deprecated. Please change your code to use the + constructor instead. Args: - sizer: object which determines array sizes - backend: gt4py backend + sizer: GridSizer object that determines the array sizes. + backend: GT4Py backend name used for performance-optimized allocation. """ - numpy = StorageNumpy(backend) - # Don't print the deprecation warning in this case - return cls(sizer, numpy, silence_deprecation_warning=True) - - def _backend(self) -> str | None: - if isinstance(self._numpy, StorageNumpy): - return self._numpy.backend - - return None + warnings.warn( + "QuantityFactory.from_backend(sizer, backend) is deprecated. Use " + "QuantityFactory(sizer, backend=backend) instead.", + DeprecationWarning, + stacklevel=2, + ) + return cls(sizer, backend=backend) def empty( self, @@ -119,15 +85,11 @@ def empty( *, allow_mismatch_float_precision: bool = False, ) -> Quantity: - """Allocate a Quantity - values are random. + """Allocate a Quantity and fill it with uninitialized (undefined) values. Equivalent to `numpy.empty`""" return self._allocate( - self._numpy.empty, - dims, - units, - dtype, - allow_mismatch_float_precision, + gt_storage.empty, dims, units, dtype, allow_mismatch_float_precision ) def zeros( @@ -142,11 +104,7 @@ def zeros( Equivalent to `numpy.zeros`""" return self._allocate( - self._numpy.zeros, - dims, - units, - dtype, - allow_mismatch_float_precision, + gt_storage.zeros, dims, units, dtype, allow_mismatch_float_precision ) def ones( @@ -161,11 +119,7 @@ def ones( Equivalent to `numpy.ones`""" return self._allocate( - self._numpy.ones, - dims, - units, - dtype, - allow_mismatch_float_precision, + gt_storage.ones, dims, units, dtype, allow_mismatch_float_precision ) def full( @@ -177,11 +131,11 @@ def full( *, allow_mismatch_float_precision: bool = False, ) -> Quantity: - """Allocate a Quantity and fill it with the value. + """Allocate a Quantity and fill it with the given value. Equivalent to `numpy.full`""" quantity = self._allocate( - self._numpy.empty, + gt_storage.empty, dims, units, dtype, @@ -199,10 +153,10 @@ def from_array( allow_mismatch_float_precision: bool = False, ) -> Quantity: """ - Create a Quantity from a numpy array. + Create a Quantity from values in the `data` array. - That numpy array must correspond to the correct shape and extent - for the given dims. + This copies the values of `data` into the resulting Quantity. The data + array thus must correspond to the correct shape and extent for the given dims. """ base = self.zeros( dims=dims, @@ -222,10 +176,12 @@ def from_compute_array( allow_mismatch_float_precision: bool = False, ) -> Quantity: """ - Create a Quantity from a numpy array. + Create a Quantity from values of the compute domain. - That numpy array must correspond to the correct shape and extent - of the compute domain for the given dims. + This function will allocate the full Quantity (including potential + halo points) to zero. The values of `data` are then copied into + the compute domain. That numpy array must correspond to the correct + shape and extent of the compute domain for the given dims. """ base = self.zeros( dims=dims, @@ -257,19 +213,22 @@ def _allocate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) ] - try: - data = allocator( - shape, dtype=dtype, aligned_index=origin, dimensions=dimensions - ) - except TypeError: - data = allocator(shape, dtype=dtype) + + data = allocator( + shape, + dtype=dtype, + aligned_index=origin, + dimensions=dimensions, + backend=self.backend, + ) + return Quantity( data, dims=dims, units=units, origin=origin, extent=extent, - gt4py_backend=self._backend(), + backend=self.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, number_of_halo_points=self.sizer.n_halo, ) diff --git a/ndsl/initialization/grid_sizer.py b/ndsl/initialization/grid_sizer.py index 961ab793..684fa5ab 100644 --- a/ndsl/initialization/grid_sizer.py +++ b/ndsl/initialization/grid_sizer.py @@ -1,10 +1,10 @@ -import warnings +from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @dataclass -class GridSizer: +class GridSizer(ABC): nx: int """Length of the x compute dimension for produced arrays.""" ny: int @@ -16,20 +16,11 @@ class GridSizer: data_dimensions: dict[str, int] """Name/Lengths pair of any non-x/y/z dimensions, such as land or radiation dimensions.""" - @property - def extra_dim_lengths(self) -> dict[str, int]: - warnings.warn( - "`GridSizer.extra_dim_lengths` is a deprecated API, use `GridSizer.data_dimensions`.", - DeprecationWarning, - 2, - ) - return self.data_dimensions + @abstractmethod + def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: ... - def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() + @abstractmethod + def get_extent(self, dims: Sequence[str]) -> tuple[int, ...]: ... - def get_extent(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() - - def get_shape(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() + @abstractmethod + def get_shape(self, dims: Sequence[str]) -> tuple[int, ...]: ... diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index ff31a1c3..c86bbd44 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -1,4 +1,3 @@ -import warnings from collections.abc import Iterable from typing import Self @@ -21,7 +20,6 @@ def from_tile_params( data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, - extra_dim_lengths: dict[str, int] | None = None, ) -> Self: """Create a SubtileGridSizer from parameters about the full tile. @@ -36,18 +34,10 @@ def from_tile_params( tile_partitioner (optional): partitioner object for the tile. By default, a TilePartitioner is created with the given layout tile_rank (optional): rank of this subtile. - extra_dim_lengths: DEPRECATED API - use `data_dimensions` """ if data_dimensions is None: data_dimensions = {} - if extra_dim_lengths is not None: - warnings.warn( - "`extra_dim_lengths` is a deprecated name, please use `data_dimensions` instead.", - DeprecationWarning, - 2, - ) - data_dimensions = extra_dim_lengths if tile_partitioner is None: tile_partitioner = TilePartitioner(layout) y_slice, x_slice = tile_partitioner.subtile_slice( diff --git a/ndsl/logging.py b/ndsl/logging.py index 183054b6..7892a647 100644 --- a/ndsl/logging.py +++ b/ndsl/logging.py @@ -20,14 +20,7 @@ def _get_log_level(default: str = "info") -> str: - if os.getenv("PACE_LOGLEVEL", ""): - logging.warning("PACE_LOGLEVEL is deprecated. Use NDSL_LOGLEVEL instead.") - if os.getenv("NDSL_LOGLEVEL", ""): - logging.warning( - "PACE_LOGLEVEL and NDSL_LOGLEVEL were both specified. NDSL_LOGLEVEL will take precedence." - ) - - loglevel = os.getenv("NDSL_LOGLEVEL", os.getenv("PACE_LOGLEVEL", default)).lower() + loglevel = os.getenv("NDSL_LOGLEVEL", default).lower() if loglevel in AVAILABLE_LOG_LEVELS.keys(): return loglevel diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index e2cb417f..a95c532e 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -21,6 +21,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int): self._dims = initial.dims self._units = initial.units self._i_time = 1 + self._backend = initial.backend def append(self, quantity: Quantity) -> None: # Allow mismatch precision here since this is I/O @@ -37,6 +38,7 @@ def data(self) -> Quantity: dims=("time",) + tuple(self._dims), units=self._units, allow_mismatch_float_precision=True, + backend=self._backend, ) diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 99deec2b..8d2469ba 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import datetime, timedelta from typing import TypeVar @@ -7,6 +8,7 @@ import xarray as xr import ndsl.constants as constants +from ndsl.comm import Comm, ReductionOperator, Request from ndsl.comm.partitioner import Partitioner, subtile_slice from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy @@ -19,14 +21,16 @@ T = TypeVar("T") -class DummyComm: +class DummyComm(Comm[T]): + """Dummy comm object that works in single-core mode.""" + def Get_rank(self) -> int: return 0 def Get_size(self) -> int: return 1 - def bcast(self, value: T, root: int = 0) -> T: + def bcast(self, value: T | None, root: int = 0) -> T | None: assert root == 0, ( "DummyComm should only be used on a single core, " "so root should only ever be 0" @@ -36,6 +40,47 @@ def bcast(self, value: T, root: int = 0) -> T: def barrier(self) -> None: return + def Barrier(self) -> None: + raise NotImplementedError("DummyComm.Barrier") + + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Scatter") + + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Gather") + + def allgather(self, sendobj: T) -> list[T]: + raise NotImplementedError("DummyComm.allgather") + + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Send") + + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.sendrcv") + + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Isend") + + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Recv") + + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Irecv") + + def Split(self, color, key) -> DummyComm: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Split") + + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: + raise NotImplementedError("DummyComm.allreduce") + + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + raise NotImplementedError("DummyComm.Allreduce") + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + raise NotImplementedError("DummyComm.Allreduce_inplace") + class ZarrMonitor: """ @@ -46,8 +91,9 @@ def __init__( self, store: str | zarr.storage.MutableMapping, partitioner: Partitioner, + *, mode: str = "w", - mpi_comm: DummyComm | None = None, + mpi_comm: Comm | None = None, ) -> None: """Create a ZarrMonitor. @@ -59,6 +105,11 @@ def __init__( use a dummy comm object that works in single-core mode. """ if mpi_comm is None: + warnings.warn( + "`mpi_comm` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) mpi_comm = DummyComm() if mpi_comm.Get_rank() == 0: diff --git a/ndsl/namelist.py b/ndsl/namelist.py deleted file mode 100644 index 00180ace..00000000 --- a/ndsl/namelist.py +++ /dev/null @@ -1,647 +0,0 @@ -import dataclasses -import warnings -from typing import Any, Self - -import f90nml - - -DEFAULT_INT = 0 -DEFAULT_STR = "" -DEFAULT_FLOAT = 0.0 -DEFAULT_BOOL = False - - -# Global set of namelist defaults, attached to class for namespacing and static typing -class NamelistDefaults: - layout = (1, 1) - grid_type = 0 - dx_const = 1000.0 - dy_const = 1000.0 - deglat = 15.0 - u_max = 350.0 - do_f3d = False - inline_q = False - do_skeb = False - """Save dissipation estimate""" - use_logp = False - moist_phys = True - check_negative = False - # gfdl_cloud_microphys.F90 - tau_r2g = 900.0 - """rain freezing during fast_sat""" - tau_smlt = 900.0 - """snow melting""" - tau_g2r = 600.0 - """graupel melting to rain""" - tau_imlt = 600.0 - """cloud ice melting""" - tau_i2s = 1000.0 - """cloud ice to snow auto - conversion""" - tau_l2r = 900.0 - """cloud water to rain auto - conversion""" - tau_g2v = 1200.0 - """graupel sublimation""" - tau_v2g = 21600.0 - """graupel deposition -- make it a slow process""" - sat_adj0 = 0.90 - """adjustment factor (0: no, 1: full) during fast_sat_adj""" - ql_gen = 1.0e-3 - """max new cloud water during remapping step if fast_sat_adj = .t.""" - ql_mlt = 2.0e-3 - """max value of cloud water allowed from melted cloud ice""" - qs_mlt = 1.0e-6 - """max cloud water due to snow melt""" - ql0_max = 2.0e-3 - """max cloud water value (auto converted to rain)""" - t_sub = 184.0 - """min temp for sublimation of cloud ice""" - qi_gen = 1.82e-6 - """max cloud ice generation during remapping step""" - qi_lim = 1.0 - """cloud ice limiter to prevent large ice build up""" - qi0_max = 1.0e-4 - """max cloud ice value (by other sources)""" - rad_snow = True - """consider snow in cloud fraction calculation""" - rad_rain = True - """consider rain in cloud fraction calculation""" - rad_graupel = True - """consider graupel in cloud fraction calculation""" - tintqs = False - """use temperature in the saturation mixing in PDF""" - dw_ocean = 0.10 - """base value for ocean""" - dw_land = 0.15 - """base value for subgrid deviation / variability over land""" - # cloud scheme 0 - ? - # 1: old fvgfs gfdl) mp implementation - # 2: binary cloud scheme (0 / 1) - icloud_f = 0 - cld_min = 0.05 - """minimum cloud fraction""" - tau_l2v = 300.0 - """cloud water to water vapor (evaporation)""" - tau_v2l = 90.0 - """water vapor to cloud water (condensation)""" - c2l_ord = 4 - regional = False - m_split = 0 - convert_ke = False - breed_vortex_inline = False - use_old_omega = True - use_logp = False - rf_fast = False - p_ref = 1e5 - """Surface pressure used to construct a horizontally-uniform reference""" - adiabatic = False - nf_omega = 1 - fv_sg_adj = -1 - n_sponge = 1 - fast_sat_adj = True - qc_crt = 5.0e-8 - """Minimum condensate mixing ratio to allow partial cloudiness""" - c_cracw = 0.8 - """Rain accretion efficiency""" - c_paut = 0.5 - """Autoconversion cloud water to rain (use 0.5 to reduce autoconversion)""" - c_pgacs = 0.01 - """Snow to graupel "accretion" eff. (was 0.1 in zetac)""" - c_psaci = 0.05 - """Accretion: cloud ice to snow (was 0.1 in zetac)""" - ccn_l = 300.0 - """CCN over land (cm^-3)""" - ccn_o = 100.0 - """CCN over ocean (cm^-3)""" - const_vg = False - """Fall velocity tuning constant of graupel""" - const_vi = False - """Fall velocity tuning constant of ice""" - const_vr = False - """Fall velocity tuning constant of rain water""" - const_vs = False - """Fall velocity tuning constant of snow""" - vi_fac = 1.0 - """if const_vi: 1/3""" - vs_fac = 1.0 - """if const_vs: 1.""" - vg_fac = 1.0 - """if const_vg: 2.""" - vr_fac = 1.0 - """if const_vr: 4.""" - de_ice = False - """To prevent excessive build-up of cloud ice from external sources""" - do_qa = True - """Do inline cloud fraction""" - do_sedi_heat = False - """Transport of heat in sedimentation""" - do_sedi_w = True - """Transport of vertical motion in sedimentation""" - fix_negative = True - """Fix negative water species""" - irain_f = 0 - """Cloud water to rain auto conversion scheme""" - mono_prof = False - """Perform terminal fall with mono ppm scheme""" - mp_time = 225.0 - """Maximum microphysics timestep (sec)""" - prog_ccn = False - """Do prognostic ccn (yi ming's method)""" - qi0_crt = 8e-05 - """Cloud ice to snow autoconversion threshold""" - qs0_crt = 0.003 - """Snow to graupel density threshold (0.6e-3 in purdue lin scheme)""" - rh_inc = 0.2 - """RH increment for complete evaporation of cloud water and cloud ice""" - rh_inr = 0.3 - """RH increment for minimum evaporation of rain""" - rthresh = 1e-05 - """Critical cloud drop radius (micrometers)""" - sedi_transport = True - """Transport of momentum in sedimentation""" - use_ppm = False - """Use ppm fall scheme""" - vg_max = 16.0 - """Maximum fall speed for graupel""" - vi_max = 1.0 - """Maximum fall speed for ice""" - vr_max = 16.0 - """Maximum fall speed for rain""" - vs_max = 2.0 - """Maximum fall speed for snow""" - z_slope_ice = True - """Use linear mono slope for autoconversions""" - z_slope_liq = True - """Use linear mono slope for autoconversions""" - tice = 273.16 - """set tice = 165. to turn off ice - phase phys (kessler emulator)""" - alin = 842.0 - """value for 'a' in lin1983""" - clin = 4.8 - """"c" in lin 1983, 4.8 -- > 6. (to enhance ql -- > qs)""" - mom4ice = False - lsm = 1 - redrag = False - isatmedmf = 0 - """which version of satmedmfvdif to use""" - dspheat = False - """flag for tke dissipative heating""" - xkzm_h = 1.0 - """background vertical diffusion for heat q over ocean""" - xkzm_m = 1.0 - """background vertical diffusion for momentum over ocean""" - xkzm_hl = 1.0 - """background vertical diffusion for heat q over land""" - xkzm_ml = 1.0 - """background vertical diffusion for momentum over land""" - xkzm_hi = 1.0 - """background vertical diffusion for heat q over ice""" - xkzm_mi = 1.0 - """background vertical diffusion for momentum over ice""" - xkzm_ho = 1.0 - """background vertical diffusion for heat q over ocean""" - xkzm_mo = 1.0 - """background vertical diffusion for momentum over ocean""" - xkzm_s = 1.0 - """sigma threshold for background mom. diffusion""" - xkzm_lim = 0.01 - """background vertical diffusion limit""" - xkzminv = 0.15 - """diffusivity in inversion layers""" - xkgdx = 25.0e3 - """background vertical diffusion threshold""" - rlmn = 30.0 - """lower-limiter on asymtotic mixing length in satmedmfdiff""" - rlmx = 300.0 - """upper-limiter on asymtotic mixing length in satmedmfdiff""" - do_dk_hb19 = False - """flag for using hb19 background diff formula in satmedmfdiff""" - cap_k0_land = True - """flag for applying limter on background diff in inversion layer over land in satmedmfdiff""" - ncld = 1 - """choice of cloud scheme""" - c0s_shal = 0.002 - """c_e for shallow convection (Han and Pan, 2011, eq(6))""" - c1_shal = 5.0e-4 - """conversion parameter of detrainment from liquid water into convetive precipitaiton""" - clam_shal = 0.3 - """conversion parameter of detrainment from liquid water into grid-scale cloud water""" - pgcon_shal = 0.55 - """control the reduction in momentum transport""" - asolfac_shal = 0.89 - """aerosol-aware parameter based on Lim & Hong (2012): asolfac= cx / c0s(=.002), cx = min([-0.7 ln(Nccn) + 24]*1.e-4, c0s), Nccn: CCN number concentration in cm^(-3), Until a realistic Nccn is provided, typical Nccns are assumed, as Nccn=100 for sea and Nccn=7000 for land""" - lsoil = 4 - """Number of soil levels in land surface model""" - sw_dynamics = False - """flag for turning on shallow water conditions in dyn core""" - - @classmethod - def as_dict(cls) -> dict: - return { - name: default - for name, default in cls.__dict__.items() - if not name.startswith("_") - } - - -@dataclasses.dataclass -class Namelist: - # data_set: Any - # date_out_of_range: str - # do_sst_pert: bool - # interp_oi_sst: bool - # no_anom_sst: bool - # sst_pert: float - # sst_pert_type: str - # use_daily: bool - # use_ncep_ice: bool - # use_ncep_sst: bool - # blocksize: int - # chksum_debug: bool - """ - note: dycore_only may not be used in this model - the same way it is in the Fortran version, watch for - consequences of these inconsistencies, or more closely - parallel the Fortran structure - """ - dycore_only: bool = DEFAULT_BOOL - # fdiag: float - # knob_ugwp_azdir: tuple[int, int, int, int] - # knob_ugwp_doaxyz: int - # knob_ugwp_doheat: int - # knob_ugwp_dokdis: int - # knob_ugwp_effac: tuple[int, int, int, int] - # knob_ugwp_ndx4lh: int - # knob_ugwp_solver: int - # knob_ugwp_source: tuple[int, int, int, int] - # knob_ugwp_stoch: tuple[int, int, int, int] - # knob_ugwp_version: int - # knob_ugwp_wvspec: tuple[int, int, int, int] - # launch_level: int - # reiflag: int - # reimax: float - # reimin: float - # rewmax: float - # rewmin: float - # atmos_nthreads: int - # calendar: Any - # current_date: Any - days: int = 0 - dt_atmos: int = DEFAULT_INT - # dt_ocean: Any - hours: int = 0 - # memuse_verbose: Any - minutes: int = 0 - # months: Any - # ncores_per_node: Any - seconds: int = 0 - # use_hyper_thread: Any - # max_axes: Any - # max_files: Any - # max_num_axis_sets: Any - # prepend_date: Any - # checker_tr: Any - # filtered_terrain: Any - # gfs_dwinds: Any - # levp: Any - # nt_checker: Any - # checksum_required: Any - # max_files_r: Any - # max_files_w: Any - # clock_grain: Any - # domains_stack_size: Any - # print_memory_usage: Any - a_imp: float = DEFAULT_FLOAT - # adjust_dry_mass: Any - beta: float = DEFAULT_FLOAT - # consv_am: Any - consv_te: float = DEFAULT_FLOAT - d2_bg: float = DEFAULT_FLOAT - d2_bg_k1: float = DEFAULT_FLOAT - d2_bg_k2: float = DEFAULT_FLOAT - d4_bg: float = DEFAULT_FLOAT - d_con: float = DEFAULT_FLOAT - d_ext: float = DEFAULT_FLOAT - dddmp: float = DEFAULT_FLOAT - delt_max: float = DEFAULT_FLOAT - # dnats: int - do_sat_adj: bool = DEFAULT_BOOL - do_vort_damp: bool = DEFAULT_BOOL - # dwind_2d: Any - # external_ic: Any - fill: bool = DEFAULT_BOOL - # fill_dp: bool - # fv_debug: Any - # gfs_phil: Any - hord_dp: int = DEFAULT_INT - hord_mt: int = DEFAULT_INT - hord_tm: int = DEFAULT_INT - hord_tr: int = DEFAULT_INT - hord_vt: int = DEFAULT_INT - hydrostatic: bool = DEFAULT_BOOL - # io_layout: Any - k_split: int = DEFAULT_INT - ke_bg: float = DEFAULT_FLOAT - kord_mt: int = DEFAULT_INT - kord_tm: int = DEFAULT_INT - kord_tr: int = DEFAULT_INT - kord_wz: int = DEFAULT_INT - layout: tuple[int, int] = (1, 1) - # make_nh: bool - # mountain: bool - n_split: int = DEFAULT_INT - # na_init: Any - # ncep_ic: Any - # nggps_ic: Any - nord: int = DEFAULT_INT - npx: int = DEFAULT_INT - npy: int = DEFAULT_INT - npz: int = DEFAULT_INT - ntiles: int = DEFAULT_INT - # nudge: Any - # nudge_qv: Any - nwat: int = DEFAULT_INT - p_fac: float = DEFAULT_FLOAT - # phys_hydrostatic: Any - # print_freq: Any - # range_warn: Any - # reset_eta: Any - rf_cutoff: float = DEFAULT_FLOAT - tau: float = DEFAULT_FLOAT - # tau_h2o: Any - # use_hydro_pressure: Any - vtdm4: float = DEFAULT_FLOAT - # warm_start: bool - z_tracer: bool = DEFAULT_BOOL - c_cracw: float = NamelistDefaults.c_cracw - c_paut: float = NamelistDefaults.c_paut - c_pgacs: float = NamelistDefaults.c_pgacs - c_psaci: float = NamelistDefaults.c_psaci - ccn_l: float = NamelistDefaults.ccn_l - ccn_o: float = NamelistDefaults.ccn_o - const_vg: bool = NamelistDefaults.const_vg - const_vi: bool = NamelistDefaults.const_vi - const_vr: bool = NamelistDefaults.const_vr - const_vs: bool = NamelistDefaults.const_vs - qc_crt: float = NamelistDefaults.qc_crt - vs_fac: float = NamelistDefaults.vs_fac - vg_fac: float = NamelistDefaults.vg_fac - vi_fac: float = NamelistDefaults.vi_fac - vr_fac: float = NamelistDefaults.vr_fac - de_ice: bool = NamelistDefaults.de_ice - do_qa: bool = NamelistDefaults.do_qa - do_sedi_heat: bool = NamelistDefaults.do_sedi_heat - do_sedi_w: bool = NamelistDefaults.do_sedi_w - fast_sat_adj: bool = NamelistDefaults.fast_sat_adj - fix_negative: bool = NamelistDefaults.fix_negative - irain_f: int = NamelistDefaults.irain_f - mono_prof: bool = NamelistDefaults.mono_prof - mp_time: float = NamelistDefaults.mp_time - prog_ccn: bool = NamelistDefaults.prog_ccn - qi0_crt: float = NamelistDefaults.qi0_crt - qs0_crt: float = NamelistDefaults.qs0_crt - rh_inc: float = NamelistDefaults.rh_inc - rh_inr: float = NamelistDefaults.rh_inr - # rh_ins: Any - rthresh: float = NamelistDefaults.rthresh - sedi_transport: bool = NamelistDefaults.sedi_transport - # use_ccn: Any - use_ppm: bool = NamelistDefaults.use_ppm - vg_max: float = NamelistDefaults.vg_max - vi_max: float = NamelistDefaults.vi_max - vr_max: float = NamelistDefaults.vr_max - vs_max: float = NamelistDefaults.vs_max - z_slope_ice: bool = NamelistDefaults.z_slope_ice - z_slope_liq: bool = NamelistDefaults.z_slope_liq - tice: float = NamelistDefaults.tice - alin: float = NamelistDefaults.alin - clin: float = NamelistDefaults.clin - mom4ice: bool = NamelistDefaults.mom4ice - lsm: int = NamelistDefaults.lsm - redrag: bool = NamelistDefaults.redrag - isatmedmf: int = NamelistDefaults.isatmedmf - dspheat: bool = NamelistDefaults.dspheat - xkzm_h: float = NamelistDefaults.xkzm_h - xkzm_m: float = NamelistDefaults.xkzm_m - xkzm_hl: float = NamelistDefaults.xkzm_hl - xkzm_ml: float = NamelistDefaults.xkzm_ml - xkzm_hi: float = NamelistDefaults.xkzm_hi - xkzm_mi: float = NamelistDefaults.xkzm_mi - xkzm_ho: float = NamelistDefaults.xkzm_ho - xkzm_mo: float = NamelistDefaults.xkzm_mo - xkzm_s: float = NamelistDefaults.xkzm_s - xkzm_lim: float = NamelistDefaults.xkzm_lim - xkzminv: float = NamelistDefaults.xkzminv - xkgdx: float = NamelistDefaults.xkgdx - rlmn: float = NamelistDefaults.rlmn - rlmx: float = NamelistDefaults.rlmx - do_dk_hb19: bool = NamelistDefaults.do_dk_hb19 - cap_k0_land: bool = NamelistDefaults.cap_k0_land - c0s_shal: float = NamelistDefaults.c0s_shal - c1_shal: float = NamelistDefaults.c1_shal - clam_shal: float = NamelistDefaults.clam_shal - pgcon_shal: float = NamelistDefaults.pgcon_shal - asolfac_shal: float = NamelistDefaults.asolfac_shal - ncld: int = NamelistDefaults.ncld - # cal_pre: Any - # cdmbgwd: Any - # cnvcld: Any - # cnvgwd: Any - # debug: Any - # do_deep: Any - # dspheat: Any - # fhcyc: Any - # fhlwr: Any - # fhswr: Any - # fhzero: Any - # hybedmf: Any - # iaer: Any - # ialb: Any - # ico2: Any - # iems: Any - # imfdeepcnv: Any - # imfshalcnv: Any - # imp_physics: Any - # isol: Any - # isot: Any - # isubc_lw: Any - # isubc_sw: Any - # ivegsrc: Any - # ldiag3d: Any - # lwhtr: Any - # nst_anl: Any - # pdfcld: Any - # pre_rad: Any - # prslrd0: Any - # random_clds: Any - # redrag: Any - # satmedmf: Any - # shal_cnv: Any - # swhtr: Any - # trans_trac: Any - # use_ufo: Any - # xkzm_h: Any - # xkzm_m: Any - # xkzminv: Any - # interp_method: Any - # lat_s: Any - # lon_s: Any - # ntrunc: Any - # fabsl: Any - # faisl: Any - # faiss: Any - # fnabsc: Any - # fnacna: Any - # fnaisc: Any - # fnalbc: Any - # fnalbc2: Any - # fnglac: Any - # fnmskh: Any - # fnmxic: Any - # fnslpc: Any - # fnsmcc: Any - # fnsnoa: Any - # fnsnoc: Any - # fnsotc: Any - # fntg3c: Any - # fntsfa: Any - # fntsfc: Any - # fnvegc: Any - # fnvetc: Any - # fnvmnc: Any - # fnvmxc: Any - # fnzorc: Any - # fsicl: Any - # fsics: Any - # fslpl: Any - # fsmcl: Any - # fsnol: Any - # fsnos: Any - # fsotl: Any - # ftsfl: Any - # ftsfs: Any - # fvetl: Any - # fvmnl: Any - # fvmxl: Any - # ldebug: Any - grid_type: int = NamelistDefaults.grid_type - dx_const: float = NamelistDefaults.dx_const - dy_const: float = NamelistDefaults.dy_const - deglat: float = NamelistDefaults.deglat - u_max: float = NamelistDefaults.u_max - do_f3d: bool = NamelistDefaults.do_f3d - inline_q: bool = NamelistDefaults.inline_q - do_skeb: bool = NamelistDefaults.do_skeb - """save dissipation estimate""" - use_logp: bool = NamelistDefaults.use_logp - moist_phys: bool = NamelistDefaults.moist_phys - check_negative: bool = NamelistDefaults.check_negative - # gfdl_cloud_microphys.F90 - tau_r2g: float = NamelistDefaults.tau_r2g - """rain freezing during fast_sat""" - tau_smlt: float = NamelistDefaults.tau_smlt - """snow melting""" - tau_g2r: float = NamelistDefaults.tau_g2r - """graupel melting to rain""" - tau_imlt: float = NamelistDefaults.tau_imlt - """cloud ice melting""" - tau_i2s: float = NamelistDefaults.tau_i2s - """cloud ice to snow auto - conversion""" - tau_l2r: float = NamelistDefaults.tau_l2r - """cloud water to rain auto - conversion""" - tau_g2v: float = NamelistDefaults.tau_g2v - """graupel sublimation""" - tau_v2g: float = NamelistDefaults.tau_v2g - """graupel deposition -- make it a slow process""" - sat_adj0: float = NamelistDefaults.sat_adj0 - """adjustment factor (0: no 1: full) during fast_sat_adj""" - ql_gen: float = 1.0e-3 - """max new cloud water during remapping step if fast_sat_adj = .t.""" - ql_mlt: float = NamelistDefaults.ql_mlt - """max value of cloud water allowed from melted cloud ice""" - qs_mlt: float = NamelistDefaults.qs_mlt - """max cloud water due to snow melt""" - ql0_max: float = NamelistDefaults.ql0_max - """max cloud water value (auto converted to rain)""" - t_sub: float = NamelistDefaults.t_sub - """min temp for sublimation of cloud ice""" - qi_gen: float = NamelistDefaults.qi_gen - """max cloud ice generation during remapping step""" - qi_lim: float = NamelistDefaults.qi_lim - """cloud ice limiter to prevent large ice build up""" - qi0_max: float = NamelistDefaults.qi0_max - """max cloud ice value (by other sources)""" - rad_snow: bool = NamelistDefaults.rad_snow - """consider snow in cloud fraction calculation""" - rad_rain: bool = NamelistDefaults.rad_rain - """consider rain in cloud fraction calculation""" - rad_graupel: bool = NamelistDefaults.rad_graupel - """consider graupel in cloud fraction calculation""" - tintqs: bool = NamelistDefaults.tintqs - """use temperature in the saturation mixing in PDF""" - dw_ocean: float = NamelistDefaults.dw_ocean - """base value for ocean""" - dw_land: float = NamelistDefaults.dw_land - """base value for subgrid deviation / variability over land""" - # cloud scheme 0 - ? - # 1: old fvgfs gfdl) mp implementation - # 2: binary cloud scheme (0 / 1) - icloud_f: int = NamelistDefaults.icloud_f - cld_min: float = NamelistDefaults.cld_min - """minimum cloud fraction""" - tau_l2v: float = NamelistDefaults.tau_l2v - """cloud water to water vapor (evaporation)""" - tau_v2l: float = NamelistDefaults.tau_v2l - """water vapor to cloud water (condensation)""" - c2l_ord: int = NamelistDefaults.c2l_ord - regional: bool = NamelistDefaults.regional - m_split: int = NamelistDefaults.m_split - convert_ke: bool = NamelistDefaults.convert_ke - breed_vortex_inline: bool = NamelistDefaults.breed_vortex_inline - use_old_omega: bool = NamelistDefaults.use_old_omega - rf_fast: bool = NamelistDefaults.rf_fast - adiabatic: bool = NamelistDefaults.adiabatic - nf_omega: int = NamelistDefaults.nf_omega - fv_sg_adj: int = NamelistDefaults.fv_sg_adj - n_sponge: int = NamelistDefaults.n_sponge - lsoil: int = NamelistDefaults.lsoil - daily_mean: bool = False - sw_dynamics: bool = NamelistDefaults.sw_dynamics - """Flag to replace cosz with daily mean value in physics""" - - @classmethod - def from_f90nml(cls, namelist: f90nml.Namelist) -> Self: - namelist_dict = namelist_to_flatish_dict(namelist.items()) - namelist_dict = { - key: value - for key, value in namelist_dict.items() - if key in cls.__dataclass_fields__ - } - return cls(**namelist_dict) - - def __post_init__(self) -> None: - warnings.warn( - "Usage of `ndsl.Namelist` is discouraged. The class will be " - "removed in the next version together with `NamelistDefaults`, see " - "https://github.com/NOAA-GFDL/NDSL/issues/64.", - DeprecationWarning, - stacklevel=2, - ) - - -def namelist_to_flatish_dict(nml_input: Any) -> dict: - nml = dict(nml_input) - for name, value in nml.items(): - if isinstance(value, f90nml.Namelist): - nml[name] = namelist_to_flatish_dict(value) - flatter_namelist = {} - for key, value in nml.items(): - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if subkey in flatter_namelist: - raise ValueError( - "Cannot flatten this namelist, duplicate keys: " + subkey - ) - flatter_namelist[subkey] = subvalue - else: - flatter_namelist[key] = value - return flatter_namelist diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py index f33ae8a6..c66b338f 100644 --- a/ndsl/quantity/field_bundle.py +++ b/ndsl/quantity/field_bundle.py @@ -95,6 +95,7 @@ def __getattr__(self, name: str) -> Quantity: units=self._quantity.units, origin=self._quantity.origin[:-1], extent=self._quantity.extent[:-1], + backend=self._quantity.backend, ) return self._per_name_view[name] diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 910999ab..438def1b 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +import warnings +from collections.abc import Sequence +from typing import Any import dace import numpy as np @@ -20,21 +22,38 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, ): + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( data, dims, units, - origin, - extent, - gt4py_backend, - allow_mismatch_float_precision, + origin=origin, + extent=extent, + allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=backend, ) - self._transient = True def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 7e7b4f16..45a14445 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -30,7 +30,9 @@ class QuantityMetadata: dtype: type "dtype of the data in the ndarray-like object" gt4py_backend: str | None = None - "backend to use for gt4py storages" + "Deprecated. Use backend instead." + backend: str | None = None + "GT4Py backend name. Used for performance optimal data allocation." @property def dim_lengths(self) -> dict[str, int]: @@ -57,6 +59,7 @@ def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None: metadata_copy.data_type = self.data_type metadata_copy.dtype = self.dtype metadata_copy.gt4py_backend = self.gt4py_backend + metadata_copy.backend = self.backend @dataclasses.dataclass diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 45312d2f..5890fcfe 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -32,6 +32,8 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, @@ -41,24 +43,40 @@ def __init__( """Initialize a Quantity. Args: - data (_type_): ndarray-like object containing the underlying data - dims (Sequence[str]): dimension names for each axis - units (str): units of the quantity - origin (Sequence[int] | None, optional): first point in data within the - computational domain. Defaults to None. - extent (Sequence[int] | None, optional): number of points along each axis + data: ndarray-like object containing the underlying data + dims: dimension names for each axis + units: units of the quantity + backend: GT4Py backend name. We ensure that the data is allocated in a + performance optimal way for that backend and copy if necessary. + origin: first point in data within the + computational domain. Defaults to None. + extent: number of points along each axis within the computational domain. Defaults to None. - gt4py_backend (str | None, optional): backend to use for gt4py storages, - if not given this will be derived from a Storage - if given as the data argument. Defaults to None. - allow_mismatch_float_precision (bool, optional): allow for precision that is + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. - number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + number_of_halo_points: Number of halo points used. Defaults to 0. Raises: ValueError: Data-type mismatch between configuration and input-data TypeError: Typing of the data that does not fit """ + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -80,6 +98,11 @@ def __init__( if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. + warnings.warn( + "Usage of basic data in Quantities is deprecated. Please use it with a numpy or cuppy ndarray instead.", + DeprecationWarning, + stacklevel=2, + ) data = np.asarray(data) if not isinstance(data, (np.ndarray, cupy.ndarray)): @@ -87,8 +110,10 @@ def __init__( f"Only supports numpy.ndarray and cupy.ndarray, got {type(data)}" ) - if gt4py_backend is not None: - gt4py_backend_cls = gt_backend.from_name(gt4py_backend) + _validate_quantity_property_lengths(data.shape, dims, origin, extent) + + if backend is not None: + gt4py_backend_cls = gt_backend.from_name(backend) is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] dimensions: tuple[str | int, ...] = tuple( @@ -104,21 +129,25 @@ def __init__( ] ) - self._data = ( - data - if is_optimal_layout(data, dimensions) - else self._initialize_data( + if is_optimal_layout(data, dimensions): + self._data = data + else: + warnings.warn( + f"Suboptimal data layout found. Copying data to optimally align for backend '{backend}'.", + UserWarning, + stacklevel=2, + ) + self._data = gt_storage.from_array( data, - origin=origin, - gt4py_backend=gt4py_backend, + data.dtype, + backend=backend, + aligned_index=origin, dimensions=dimensions, ) - ) else: - # We have no info about the gt4py_backend, so just assign it. + # We have no info about the gt4py backend, so just assign it. self._data = data - _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), @@ -127,7 +156,8 @@ def __init__( units=units, data_type=type(self._data), dtype=data.dtype, - gt4py_backend=gt4py_backend, + backend=backend, + gt4py_backend=backend, ) self._attrs = {} # type: ignore[var-annotated] self._compute_domain_view = BoundedArrayView( @@ -138,10 +168,12 @@ def __init__( def from_data_array( cls, data_array: xr.DataArray, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, number_of_halo_points: int = 0, + backend: str | None = None, ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -150,12 +182,26 @@ def from_data_array( data_array origin: first point in data within the computational domain extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points: Number of halo points used. Defaults to 0. + backend: GT4Py backend name. If given, we allocate data in a performance + optimal way for this backend. Overrides any potentially saved `backend` + in `data.attrs["backend"]`. """ if "units" not in data_array.attrs: raise ValueError("need units attribute to create Quantity from DataArray") + + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + return cls( data_array.values, cast(tuple[str], data_array.dims), @@ -163,13 +209,15 @@ def from_data_array( origin=origin, extent=extent, number_of_halo_points=number_of_halo_points, - gt4py_backend=gt4py_backend, + backend=_resolve_backend(data_array, backend), ) def to_netcdf( self, path: str, name: str = "var", rank: int = -1, all_data: bool = False ) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: + if rank < 0: + rank = MPI.COMM_WORLD.Get_rank() if all_data: self.data_as_xarray.to_dataset(name=name).to_netcdf( f"{path}__r{rank}.nc4" @@ -221,17 +269,6 @@ def sel(self, **kwargs: slice | int) -> np.ndarray: """ return self.view[tuple(kwargs.get(dim, slice(None, None)) for dim in self.dims)] - def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: tuple): # type: ignore - """Allocates an ndarray with optimal memory layout, and copies the data over.""" - storage = gt_storage.from_array( - data, - data.dtype, - backend=gt4py_backend, - aligned_index=origin, - dimensions=dimensions, - ) - return storage - @property def metadata(self) -> QuantityMetadata: return self._metadata @@ -243,29 +280,26 @@ def units(self) -> str: @property def gt4py_backend(self) -> str | None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) return self.metadata.gt4py_backend + @property + def backend(self) -> str | None: + return self.metadata.backend + @property def attrs(self) -> dict: - return dict(**self._attrs, units=self._metadata.units) + return dict(**self._attrs, units=self.units, backend=self.backend) @property def dims(self) -> tuple[str, ...]: """Names of each dimension""" return self.metadata.dims - @property - def values(self) -> np.ndarray: - warnings.warn( - "values exists only for backwards-compatibility with " - "DataArray and will be removed, use .view[:] instead", - DeprecationWarning, - stacklevel=2, - ) - return_array = np.asarray(self.view[:]) - return_array.flags.writeable = False - return return_array - @property def view(self) -> BoundedArrayView: """A view into the computational domain of the underlying data""" @@ -397,19 +431,14 @@ def transpose( units=self.units, origin=_transpose_sequence(self.origin, transpose_order), extent=_transpose_sequence(self.extent, transpose_order), - gt4py_backend=self.gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=self.backend, ) transposed._attrs = self._attrs return transposed def plot_k_level(self, k_index: int = 0) -> None: field = self.data - print( - "Min and max values:", - field[:, :, k_index].min(), - field[:, :, k_index].max(), - ) plt.xlabel("I") plt.ylabel("J") @@ -476,3 +505,16 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: f"unexpected type {type(item)}" ) return tuple(return_list) + + +def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: + if backend is not None: + # Forced backend name takes precedence + return backend + + # If backend name was serialized with data, take this one + if "backend" in data.attrs: + return data.attrs["backend"] + + # else, fall back to assume python-based layout. + return "debug" diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 8a635187..08c24561 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,8 +1,4 @@ -from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .corners import CopyCornersXY, FillCornersBGrid -__all__ = [ - "CopyCorners", - "CopyCornersXY", - "FillCornersBGrid", -] +__all__ = ["CopyCornersXY", "FillCornersBGrid"] diff --git a/ndsl/stencils/column_operations.py b/ndsl/stencils/column_operations.py new file mode 100644 index 00000000..25ede5de --- /dev/null +++ b/ndsl/stencils/column_operations.py @@ -0,0 +1,103 @@ +import typing + +from ndsl.dsl.gt4py import function + + +@typing.no_type_check +@function +def column_max(field, start_index, end_index): + """ + Find the maximum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [max value, index of max value] + """ + max_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level) + old = field.at(K=max_index) + if new > old: + max_index = level + level += 1 + + return field.at(K=max_index), max_index + + +@typing.no_type_check +@function +def column_max_ddim(field, ddim, start_index, end_index): + """ + Find the maximum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [max value, index of max value] + """ + max_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=max_index, ddim=[ddim]) + if new > old: + max_index = level + level += 1 + + return field.at(K=max_index, ddim=[ddim]), max_index + + +@typing.no_type_check +@function +def column_min(field, start_index, end_index): + """ + Find the minimum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [min value, index of min value] + """ + min_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level) + old = field.at(K=min_index) + if new < old: + min_index = level + level += 1 + + return field.at(K=min_index), min_index + + +@typing.no_type_check +@function +def column_min_ddim(field, ddim, start_index, end_index): + """ + Find the minimum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [min value, index of min value] + """ + min_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=min_index, ddim=[ddim]) + if new < old: + min_index = level + level += 1 + + return field.at(K=min_index, ddim=[ddim]), min_index diff --git a/ndsl/stencils/corners.py b/ndsl/stencils/corners.py index 021a439b..d3fd2c0d 100644 --- a/ndsl/stencils/corners.py +++ b/ndsl/stencils/corners.py @@ -5,69 +5,11 @@ from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region from ndsl import StencilFactory -from ndsl.constants import ( - X_DIM, - X_INTERFACE_DIM, - Y_DIM, - Y_INTERFACE_DIM, - Z_INTERFACE_DIM, -) +from ndsl.constants import X_INTERFACE_DIM, Y_INTERFACE_DIM, Z_INTERFACE_DIM from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import FloatField -class CopyCorners: - """ - Helper-class to copy corners corresponding to the fortran functions - copy_corners_x or copy_corners_y respectively - """ - - def __init__(self, direction: str, stencil_factory: StencilFactory) -> None: - """The grid for this stencil""" - warnings.warn( - "Usage of the GT4Py implementation of CopyCorners is discouraged and will" - "be removed in the next release. Use `CopyCornersX` or `CopyCornersY` in PyFV3" - "for a more future-proof implementation of the same code.", - DeprecationWarning, - 2, - ) - grid_indexing = stencil_factory.grid_indexing - - n_halo = grid_indexing.n_halo - origin, domain = grid_indexing.get_origin_domain( - dims=[X_DIM, Y_DIM, Z_INTERFACE_DIM], halos=(n_halo, n_halo) - ) - - ax_offsets = grid_indexing.axis_offsets(origin, domain) - if direction == "x": - self._copy_corners = stencil_factory.from_origin_domain( - func=copy_corners_x_stencil_defn, - origin=origin, - domain=domain, - externals={ - **ax_offsets, - }, - ) - elif direction == "y": - self._copy_corners = stencil_factory.from_origin_domain( - func=copy_corners_y_stencil_defn, - origin=origin, - domain=domain, - externals={ - **ax_offsets, - }, - ) - else: - raise ValueError("Direction must be either 'x' or 'y'") - - def __call__(self, field: FloatField): - """ - Fills cell quantity field using corners from itself and multipliers - in the direction specified initialization of the instance of this class. - """ - self._copy_corners(field, field) - - class CopyCornersXY: """ Helper-class to copy corners corresponding to the Fortran functions @@ -87,6 +29,13 @@ def __init__( y_field: 3D gt4py storage to use for y-differenceable field (x-differenceable field uses same memory as base field) """ + warnings.warn( + "Usage of CopyCornersXY is deprecated and will be removed in the next release. " + "Use `CopyCornersX` and `CopyCornersY` in PyFV3 for a more future-proof " + "implementation of the corner code.", + DeprecationWarning, + stacklevel=2, + ) grid_indexing = stencil_factory.grid_indexing origin, domain = grid_indexing.get_origin_domain( dims=dims, halos=(grid_indexing.n_halo, grid_indexing.n_halo) @@ -313,126 +262,6 @@ def fill_corners_3cells_mult_y( return q -def copy_corners_x_stencil_defn(q_in: FloatField, q_out: FloatField): - from __externals__ import i_end, i_start, j_end, j_start - - with computation(PARALLEL), interval(...): - with horizontal( - region[i_start - 3, j_start - 3], region[i_end + 3, j_start - 3] - ): - q_out = q_in[0, 5, 0] - with horizontal( - region[i_start - 2, j_start - 3], region[i_end + 3, j_start - 2] - ): - q_out = q_in[-1, 4, 0] - with horizontal( - region[i_start - 1, j_start - 3], region[i_end + 3, j_start - 1] - ): - q_out = q_in[-2, 3, 0] - with horizontal( - region[i_start - 3, j_start - 2], region[i_end + 2, j_start - 3] - ): - q_out = q_in[1, 4, 0] - with horizontal( - region[i_start - 2, j_start - 2], region[i_end + 2, j_start - 2] - ): - q_out = q_in[0, 3, 0] - with horizontal( - region[i_start - 1, j_start - 2], region[i_end + 2, j_start - 1] - ): - q_out = q_in[-1, 2, 0] - with horizontal( - region[i_start - 3, j_start - 1], region[i_end + 1, j_start - 3] - ): - q_out = q_in[2, 3, 0] - with horizontal( - region[i_start - 2, j_start - 1], region[i_end + 1, j_start - 2] - ): - q_out = q_in[1, 2, 0] - with horizontal( - region[i_start - 1, j_start - 1], region[i_end + 1, j_start - 1] - ): - q_out = q_in[0, 1, 0] - with horizontal(region[i_start - 3, j_end + 1], region[i_end + 1, j_end + 3]): - q_out = q_in[2, -3, 0] - with horizontal(region[i_start - 2, j_end + 1], region[i_end + 1, j_end + 2]): - q_out = q_in[1, -2, 0] - with horizontal(region[i_start - 1, j_end + 1], region[i_end + 1, j_end + 1]): - q_out = q_in[0, -1, 0] - with horizontal(region[i_start - 3, j_end + 2], region[i_end + 2, j_end + 3]): - q_out = q_in[1, -4, 0] - with horizontal(region[i_start - 2, j_end + 2], region[i_end + 2, j_end + 2]): - q_out = q_in[0, -3, 0] - with horizontal(region[i_start - 1, j_end + 2], region[i_end + 2, j_end + 1]): - q_out = q_in[-1, -2, 0] - with horizontal(region[i_start - 3, j_end + 3], region[i_end + 3, j_end + 3]): - q_out = q_in[0, -5, 0] - with horizontal(region[i_start - 2, j_end + 3], region[i_end + 3, j_end + 2]): - q_out = q_in[-1, -4, 0] - with horizontal(region[i_start - 1, j_end + 3], region[i_end + 3, j_end + 1]): - q_out = q_in[-2, -3, 0] - - -def copy_corners_y_stencil_defn(q_in: FloatField, q_out: FloatField): - from __externals__ import i_end, i_start, j_end, j_start - - with computation(PARALLEL), interval(...): - with horizontal( - region[i_start - 3, j_start - 3], region[i_start - 3, j_end + 3] - ): - q_out = q_in[5, 0, 0] - with horizontal( - region[i_start - 2, j_start - 3], region[i_start - 3, j_end + 2] - ): - q_out = q_in[4, 1, 0] - with horizontal( - region[i_start - 1, j_start - 3], region[i_start - 3, j_end + 1] - ): - q_out = q_in[3, 2, 0] - with horizontal( - region[i_start - 3, j_start - 2], region[i_start - 2, j_end + 3] - ): - q_out = q_in[4, -1, 0] - with horizontal( - region[i_start - 2, j_start - 2], region[i_start - 2, j_end + 2] - ): - q_out = q_in[3, 0, 0] - with horizontal( - region[i_start - 1, j_start - 2], region[i_start - 2, j_end + 1] - ): - q_out = q_in[2, 1, 0] - with horizontal( - region[i_start - 3, j_start - 1], region[i_start - 1, j_end + 3] - ): - q_out = q_in[3, -2, 0] - with horizontal( - region[i_start - 2, j_start - 1], region[i_start - 1, j_end + 2] - ): - q_out = q_in[2, -1, 0] - with horizontal( - region[i_start - 1, j_start - 1], region[i_start - 1, j_end + 1] - ): - q_out = q_in[1, 0, 0] - with horizontal(region[i_end + 1, j_start - 3], region[i_end + 3, j_end + 1]): - q_out = q_in[-3, 2, 0] - with horizontal(region[i_end + 2, j_start - 3], region[i_end + 3, j_end + 2]): - q_out = q_in[-4, 1, 0] - with horizontal(region[i_end + 3, j_start - 3], region[i_end + 3, j_end + 3]): - q_out = q_in[-5, 0, 0] - with horizontal(region[i_end + 1, j_start - 2], region[i_end + 2, j_end + 1]): - q_out = q_in[-2, 1, 0] - with horizontal(region[i_end + 2, j_start - 2], region[i_end + 2, j_end + 2]): - q_out = q_in[-3, 0, 0] - with horizontal(region[i_end + 3, j_start - 2], region[i_end + 2, j_end + 3]): - q_out = q_in[-4, -1, 0] - with horizontal(region[i_end + 1, j_start - 1], region[i_end + 1, j_end + 1]): - q_out = q_in[-1, 0, 0] - with horizontal(region[i_end + 2, j_start - 1], region[i_end + 1, j_end + 2]): - q_out = q_in[-2, -1, 0] - with horizontal(region[i_end + 3, j_start - 1], region[i_end + 1, j_end + 3]): - q_out = q_in[-3, -2, 0] - - def copy_corners_xy_stencil_defn( q_in: FloatField, q_out_x: FloatField, q_out_y: FloatField ): diff --git a/ndsl/stencils/testing/best_guess_diff.py b/ndsl/stencils/testing/best_guess_diff.py new file mode 100644 index 00000000..7bb0ca82 --- /dev/null +++ b/ndsl/stencils/testing/best_guess_diff.py @@ -0,0 +1,194 @@ +import argparse +import pathlib + +import numpy as np +import xarray as xr +import yaml + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + "Attempt to diff two NetCDFs with similar data." + "Differences that can be reconciled are strict domain vs halo, variable name mapping." + "The program will report on assumptions taken." + ) + parser.add_argument( + "netcdf_A", + type=str, + help="path of first NetCDFs, named A in the logs.", + ) + parser.add_argument( + "netcdf_B", + type=str, + help="path of second NetCDFs, named B in the logs.", + ) + parser.add_argument( + "-nm", + "--name_mapping", + type=str, + help="[Optional] Yaml file describing the mapping of the names.", + ) + parser.add_argument( + "-ha", + "--halo", + type=int, + default=3, + help="[Optional] Halo size if any, default to 3.", + ) + return parser + + +def main( + netcdf_A: str, + netcdf_B: str, + name_mapping: str | None = None, + halo: int = 3, +) -> None: + A = xr.open_dataset(netcdf_A) + B = xr.open_dataset(netcdf_B) + name_map = {} + if name_mapping is not None: + with open(name_mapping) as f: + name_map = yaml.safe_load(f) + + dataset = {} + for name_A, data_A in A.items(): + print(f"Best guess for {name_A} from A:") + # Resolve name + resolved_name_B = None + if name_A not in B.keys(): + if name_A not in name_map.keys(): + print(" [Failed] name can't be found in B nor in name mapping") + else: + resolved_name_B = name_map[name_A] + else: + resolved_name_B = name_A + + if resolved_name_B is None: + continue + print(f" [Hyp] use {resolved_name_B} for B") + + # Resolve domain size + data_B = B[resolved_name_B] + if len(data_A.shape) >= 5: + print( + " [Hyp] A data dims are >= 5, assuming savepoints/rank are the firt two and going A[0, 0, ::]" + ) + data_A = data_A[0, 0, ::] + if len(data_B.shape) >= 5: + print( + " [Hyp] B data dims are >= 5, assuming savepoints/rank are the firt two and going B[0, 0, ::]" + ) + data_B = data_B[0, 0, ::] + + if len(data_A.shape) != len(data_B.shape): + print( + f" [Failed] A is shape {len(data_A.shape)}, B is shape {len(data_B.shape)}: can't reconcile." + ) + continue + + # - Assume we have 0 == I, 1 == J now + resolved_I = None + A_uses_halo = False + B_uses_halo = False + if data_A.shape[0] != data_B.shape[0]: + if data_A.shape[0] < data_B.shape[0]: + if data_B.shape[0] - 2 * halo != data_A.shape[0]: + print( + f" [Failed] B in dim I is too big, even with halo substracted {data_B.shape[0]} (halo: {halo})" + ) + else: + B_uses_halo = True + resolved_I = data_A.shape[0] + else: + if data_A.shape[0] - 2 * halo != data_B.shape[0]: + print( + f" [Failed] A in dim I is too big, even with halo substracted {data_A.shape[0]} (halo: {halo})" + ) + else: + A_uses_halo = True + resolved_I = data_B.shape[0] + else: + resolved_I = data_A.shape[0] + + if resolved_I is None: + continue + + print(f" [Hyp] Using {resolved_I} as I dim size") + + resolved_J = None + if data_A.shape[1] != data_B.shape[1]: + if data_A.shape[1] < data_B.shape[1]: + if data_B.shape[1] - 2 * halo != data_A.shape[1]: + print( + f" [Failed] B in dim J is too big, even with halo substracted {data_B.shape[1]} (halo: {halo})" + ) + else: + resolved_J = data_A.shape[1] + else: + if data_A.shape[1] - 2 * halo != data_B.shape[1]: + print( + f" [Failed] A in dim J is too big, even with halo substracted {data_A.shape[1]} (halo: {halo})" + ) + else: + resolved_J = data_B.shape[1] + else: + resolved_J = data_A.shape[1] + + if resolved_J is None: + continue + + print(f" [Hyp] Using {resolved_J} as J dim size") + + # - Assume 2 == K + if data_A.shape[2] != data_B.shape[2]: + print( + f" [Failed] Can't reconcile K dim: A ({data_A.shape[2]}) != B ({data_B.shape[2]})" + ) + continue + resolved_K = data_A.shape[2] + + print(f" [Hyp] Using {resolved_K} as K dim size") + + # We should now be ready to diff + if A_uses_halo: + data_A = data_A[halo:-halo, halo:-halo, ::] + if B_uses_halo: + data_B = data_B[halo:-halo, halo:-halo, ::] + + dims = [f"D{i}_{s}" for i, s in enumerate(data_A.shape)] + absolute_diff = data_A.data - data_B.data + dataset[name_A] = xr.DataArray( + absolute_diff, + dims=dims, + ) + + # ULP diffs + max_values = np.maximum( + np.absolute(data_A.data.flatten()), np.absolute(data_B.data.flatten()) + ) + ulp_diff = np.divide(np.abs(absolute_diff.flatten()), np.spacing(max_values)) + ulp_diff = np.sort(ulp_diff) + dataset[f"ulp_{name_A}"] = xr.DataArray( + ulp_diff, + dims=[f"GP_{ulp_diff.shape[0]}"], + ) + + print(" [Success]") + + xr.Dataset(dataset).to_netcdf(f"best_guest_diff_{pathlib.Path(netcdf_A).stem}.nc4") + + +def entry_point() -> None: + parser = get_parser() + args = parser.parse_args() + main( + netcdf_A=args.netcdf_A, + netcdf_B=args.netcdf_B, + name_mapping=args.name_mapping, + halo=args.halo, + ) + + +if __name__ == "__main__": + entry_point() diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 4e0bb428..1e5f41f7 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -18,9 +18,6 @@ from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig - -# TODO: Remove NdslNamelist import after Issue#64 is resolved. -from ndsl.namelist import Namelist as NdslNamelist from ndsl.stencils.testing.grid import Grid # type: ignore from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict @@ -78,12 +75,6 @@ def pytest_addoption(parser: pytest.Parser) -> None: default=1, help="How many indices of failures to print from worst to best. Default to 1.", ) - parser.addoption( - "--no_legacy_namelist", - action="store_true", - default=False, - help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.", - ) parser.addoption( "--grid", action="store", @@ -268,9 +259,6 @@ def sequential_savepoint_cases( sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") - # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. - use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") - return _savepoint_cases( savepoint_names, ranks, @@ -283,7 +271,6 @@ def sequential_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, - use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) @@ -299,7 +286,6 @@ def _savepoint_cases( topology_mode: str, sort_report: str, no_report: bool, - use_legacy_namelist: bool, # Issue#64: tmp flag ) -> list[SavepointCase]: grid_params = grid_params_from_f90nml(namelist) return_list = [] @@ -335,12 +321,6 @@ def _savepoint_cases( grid_indexing=grid.grid_indexing, ) for test_name in sorted(list(savepoint_names)): - # Temporary check (Issue#64): TODO Remove check and conversion from - # f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed - if use_legacy_namelist and not isinstance(namelist, NdslNamelist): - assert isinstance(namelist, Namelist) - namelist = NdslNamelist.from_f90nml(namelist) - testobj = get_test_class_instance( test_name, grid, namelist, stencil_factory ) @@ -402,9 +382,6 @@ def parallel_savepoint_cases( grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = get_savepoint_restriction(metafunc) - # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. - use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") - return _savepoint_cases( savepoint_names, [mpi_rank], @@ -417,7 +394,6 @@ def parallel_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, - use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 41c7d356..6437f561 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -159,9 +159,7 @@ def sizer(self): @property def quantity_factory(self) -> QuantityFactory: if self._quantity_factory is None: - self._quantity_factory = QuantityFactory.from_backend( - self.sizer, backend=self.backend - ) + self._quantity_factory = QuantityFactory(self.sizer, backend=self.backend) return self._quantity_factory def make_quantity( diff --git a/ndsl/stencils/testing/parallel_translate.py b/ndsl/stencils/testing/parallel_translate.py index 85b87020..d23d5459 100644 --- a/ndsl/stencils/testing/parallel_translate.py +++ b/ndsl/stencils/testing/parallel_translate.py @@ -7,9 +7,6 @@ from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS from ndsl.dsl import gt4py_utils as utils - -# TODO: Remove once ndsl.Namelist is gone (Issue#64) -from ndsl.namelist import Namelist as NdslNamelist from ndsl.quantity import Quantity from ndsl.stencils.testing.translate import ( TranslateFortranData2Py, @@ -133,12 +130,6 @@ def rank_grids(self): @property def layout(self): - # TODO: Once ndsl.namelist.Namelist is gone (Issue#64), - # remove this check in favor of f90nml.namelist.Namelist - if isinstance(self.namelist, NdslNamelist): - return self.namelist.layout - - # Assumption: namelist is f90nml.namelist.Namelist grid_params = grid_params_from_f90nml(self.namelist) return grid_params["layout"] diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index a350e76f..8ba789cb 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -258,6 +258,11 @@ def test_sequential_savepoint( output_data = gt_utils.asarray(output[varname]) if multimodal_metric: metric = MultiModalFloatMetric( + input_values=( + original_input_data[varname] + if varname in original_input_data.keys() + else None + ), reference_values=ref_data, computed_values=output_data, absolute_eps_override=case.testobj.mmr_absolute_eps, @@ -391,6 +396,7 @@ def test_parallel_savepoint( if not case.exists: pytest.skip(f"Data at rank {case.grid.rank} does not exists") input_data = dataset_to_dict(case.ds_in) + original_input_data = copy.deepcopy(input_data) # run python version of functionality output = case.testobj.compute_parallel(input_data, communicator) out_vars = set(case.testobj.outputs.keys()) @@ -414,13 +420,16 @@ def test_parallel_savepoint( output_data = gt_utils.asarray(output[varname]) if multimodal_metric: metric = MultiModalFloatMetric( + input_values=( + original_input_data[varname] + if varname in original_input_data.keys() + else None + ), reference_values=ref_data[varname][0], computed_values=output_data, absolute_eps_override=case.testobj.mmr_absolute_eps, relative_fraction_override=case.testobj.mmr_relative_fraction, ulp_override=case.testobj.mmr_ulp, - ignore_near_zero_errors=ignore_near_zero, - near_zero=case.testobj.near_zero, sort_report=case.sort_report, ) else: diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 44b4077b..c1a5b1ef 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -224,6 +224,7 @@ class MultiModalFloatMetric(BaseMetric): def __init__( self, + input_values: np.ndarray | None, reference_values: np.ndarray, computed_values: np.ndarray, absolute_eps_override: float = -1, @@ -256,6 +257,15 @@ def __init__( self.check = np.all(self.success) self.sort_report = sort_report + # We might have sliced outputs in the translate test. Rather than funnel the slicing + # all the way down, we bail out if we can measure input vs reference + if input_values is not None and input_values.shape == reference_values.shape: + self.number_changing_values = ( + (input_values != reference_values).flatten().shape[0] + ) + else: + self.number_changing_values = None + def _compute_all_metrics( self, ) -> npt.NDArray[np.bool_]: @@ -329,9 +339,19 @@ def report(self, file_path: str | None = None) -> list[str]: # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) full_count = len(self.references.flatten()) - failures_pct = round(100.0 * (bad_indices_count / full_count), 2) + failures_of_all_grid_points_pct = round( + 100.0 * (bad_indices_count / full_count), 2 + ) + if self.number_changing_values is not None: + failures_of_changing_gridpoint_pct = round( + 100.0 * (bad_indices_count / self.number_changing_values), 2 + ) + report_local_failures = f"Failures (changing grid points) ({bad_indices_count}/{self.number_changing_values}) ({failures_of_changing_gridpoint_pct}%)\n" + else: + report_local_failures = "" report = [ - f"All failures ({bad_indices_count}/{full_count}) ({failures_pct}%),\n", + f"{report_local_failures}" + f"Failures (all grid points) ({bad_indices_count}/{full_count}) ({failures_of_all_grid_points_pct}%)\n", f"Index Computed Reference " f"{'🔶 ' if not self.absolute_eps.is_default else ''}Absolute E(<{self.absolute_eps.value:.2e}) " f"{'🔶 ' if not self.relative_fraction.is_default else ''}Relative E(<{self.relative_fraction.value * 100:.2e}%) " diff --git a/ndsl/testing/dummy_comm.py b/ndsl/testing/dummy_comm.py deleted file mode 100644 index f3e93817..00000000 --- a/ndsl/testing/dummy_comm.py +++ /dev/null @@ -1 +0,0 @@ -from ndsl.comm.local_comm import LocalComm as DummyComm # noqa diff --git a/ndsl/units.py b/ndsl/units.py deleted file mode 100644 index 7fb441b1..00000000 --- a/ndsl/units.py +++ /dev/null @@ -1,33 +0,0 @@ -import warnings - - -def ensure_equal_units(units1: str, units2: str) -> None: - warnings.warn( - "`ensure_equal_units` is unused and usage is discouraged. The function " - "will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - if not units_are_equal(units1, units2): - raise UnitsError(f"incompatible units {units1} and {units2}") - - -def units_are_equal(units1: str, units2: str) -> bool: - warnings.warn( - "`units_are_equal` is unused and usage is discouraged. The function will " - "be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - return units1.strip() == units2.strip() - - -class UnitsError(Exception): - def __init__(self, *args) -> None: - warnings.warn( - "`UnitsError` is unused and usage is discouraged. The class will be " - "removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args) diff --git a/pyproject.toml b/pyproject.toml index cf8e0a69..8bf3ab7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,11 +23,13 @@ requires-python = ">=3.11,<3.12" [project.optional-dependencies] demos = ["ipython", "ipykernel"] dev = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "pre-commit", "flake8-pyproject", "build"] -docs = ["mkdocs-material", "mkdocstrings[python]"] +docs = ["mkdocs-material", "mkdocstrings[python]", "mkdocs-exclude"] extras = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "ndsl[dev]"] test = ["pytest", "coverage"] +zarr = ["zarr<3"] [project.scripts] +best_guess_diff = "ndsl.stencils.testing.best_guess_diff:entry_point" ndsl-serialbox_to_netcdf = "ndsl.stencils.testing.serialbox_to_netcdf:entry_point" [project.urls] diff --git a/setup.py b/setup.py index 1719226f..6d109c36 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ def local_pkg(name: str, relative_path: str) -> str: "cftime", "xarray>=2025.01.2", # datatree + fixes "f90nml>=1.1.0", - "fsspec", "netcdf4==1.7.2", "scipy", # restart capacities only "h5netcdf", # for xarray diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index d0c9f1b0..2ff2c16a 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -1,8 +1,11 @@ -from ndsl import QuantityFactory, StencilFactory +import dataclasses + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +from ndsl.quantity import Quantity, State def _stencil(out: Field[float]): @@ -40,3 +43,40 @@ def test_memory_reallocation(): code(qty_B) assert (qty_A.field[0, 0, :] == 3).all() assert (qty_B.field[0, 0, :] == 2).all() + + +@dataclasses.dataclass +class AState(State): + the_quantity: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +class DefaultTypeProgram(NDSLRuntime): + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ): + super().__init__(stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo(_stencil, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, a_quantity: Quantity, a_state: AState): + self.stencil(a_quantity) + self.stencil(a_state.the_quantity) + + +def test_default_types_are_compiletime(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 2, 0 + ) + qty_A = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "A") + state_A = AState.zeros(quantity_factory) + code = DefaultTypeProgram(stencil_factory, quantity_factory) + code(qty_A, state_A) diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 326eb222..9f9c165c 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -7,7 +7,7 @@ CompilationConfig, CubedSphereCommunicator, CubedSpherePartitioner, - NullComm, + LocalComm, RunMode, TilePartitioner, ) @@ -34,7 +34,7 @@ def test_check_communicator_valid( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6))))) ) - comm = NullComm(rank=0, total_ranks=size) + comm = LocalComm(rank=0, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -52,7 +52,7 @@ def test_check_communicator_invalid( nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode ) -> None: partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny))) - comm = NullComm(rank=0, total_ranks=nx * ny * 6) + comm = LocalComm(rank=0, total_ranks=nx * ny * 6, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -90,7 +90,7 @@ def test_get_decomposition_info_from_comm( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6)))) ) - comm = NullComm(rank=rank, total_ranks=size) + comm = LocalComm(rank=rank, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig(use_minimal_caching=True, run_mode=RunMode.Run) ( @@ -130,7 +130,7 @@ def test_determine_compiling_equivalent( TilePartitioner((sqrt(size / 6), sqrt(size / 6))) ) comm = unittest.mock.MagicMock() - comm = NullComm(rank=rank, total_ranks=size) + comm = LocalComm(rank=rank, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) assert ( config.determine_compiling_equivalent(rank, cubed_sphere_comm.partitioner) diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 7130853e..1f7f3836 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -108,7 +108,9 @@ def test_domain_size_comparison( domain: tuple[int], call_count: int, ): - quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent) + quantity = Quantity( + np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug" + ) stencil = FrozenStencil( copy_stencil, origin=(0, 0, 0), @@ -147,7 +149,9 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None: def test_stencil_2D_temporaries() -> None: domain = (2, 2, 5) - quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) stencil = FrozenStencil( two_dim_temporaries_stencil, origin=(0, 0, 0), @@ -156,3 +160,30 @@ def test_stencil_2D_temporaries() -> None: ) stencil(quantity) assert (quantity.data[1, 1, :] == 21.0).all() + + +@pytest.mark.parametrize( + "iterations", + [2, 1], +) +def test_validation_call_count(iterations: tuple[int]): + domain = (2, 2, 5) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) + stencil_config = StencilConfig( + compilation_config=CompilationConfig(backend="numpy", rebuild=True) + ) + stencil = FrozenStencil( + copy_stencil, + origin=(0, 0, 0), + domain=domain, + stencil_config=stencil_config, + ) + # with expectation: + counting_mock = MagicMock() + with patch.object(FrozenStencil, "_validate_quantity_sizes", counting_mock): + for _i in range(iterations): + stencil(quantity, quantity) + + assert counting_mock.call_count == 1 diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 89498695..5cdf79dc 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("rebuild", [True, False]) @pytest.mark.parametrize("format_source", [True, False]) @pytest.mark.parametrize("compare_to_numpy", [True, False]) -@pytest.mark.parametrize("backend", ["numpy", "gt_gpu"]) +@pytest.mark.parametrize("backend", ["numpy", "gt:gpu"]) def test_same_config_equal( backend: str, rebuild: bool, @@ -71,7 +71,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - backend="fake_backend", + backend="debug", rebuild=rebuild, validate_args=validate_args, format_source=format_source, diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 0458389e..49b7b2a9 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -313,24 +313,38 @@ def test_frozen_field_after_parameter() -> None: ) +@pytest.mark.parametrize("backend", ("numpy", "gt:gpu")) def test_backend_options( + backend: str, rebuild: bool = True, validate_args: bool = True, ) -> None: - backend = "numpy" expected_options = { - "backend": "numpy", - "rebuild": True, - "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "numpy": { + "backend": "numpy", + "rebuild": True, + "format_source": False, + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + }, + "gt:gpu": { + "backend": "gt:gpu", + "rebuild": True, + "device_sync": False, + "format_source": False, + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + }, } actual = get_stencil_config( - backend=backend, - rebuild=rebuild, - validate_args=validate_args, + backend=backend, rebuild=rebuild, validate_args=validate_args ).stencil_kwargs(func=copy_stencil) - assert actual == expected_options + expected = expected_options[backend] + assert actual == expected + + +def test_illegal_backend_options(): + with pytest.raises(ValueError): + get_stencil_config(backend="illegal") def get_mock_quantity(): diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index 4090b3b5..c4fab0eb 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -1,13 +1,11 @@ from pathlib import Path -import numpy as np import pytest -import xarray as xr from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - NullComm, + LocalComm, QuantityFactory, SubtileGridSizer, TilePartitioner, @@ -23,58 +21,6 @@ """ -@pytest.mark.parametrize("levels", [79, 91]) -def test_set_hybrid_pressure_coefficients_correct(levels): - """ - This test checks to see that the ak and bk arrays are read-in correctly and are - stored as expected. Both values of km=79 and km=91 are tested and both tests are - expected to pass with the stored ak and bk values agreeing with the values read-in - directly from the NetCDF file. - """ - - eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" - eta_data = xr.open_dataset(eta_file) - - backend = "numpy" - - layout = (1, 1) - - nz = levels - ny = 48 - nx = 48 - nhalo = 3 - - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) - - sizer = SubtileGridSizer.from_tile_params( - nx_tile=nx, - ny_tile=ny, - nz=nz, - n_halo=nhalo, - layout=layout, - tile_partitioner=partitioner.tile, - tile_rank=communicator.tile.rank, - ) - - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - - metric_terms = MetricTerms( - quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file - ) - - ak_results = metric_terms.ak.data - bk_results = metric_terms.bk.data - ak_answers, bk_answers = eta_data["ak"].values, eta_data["bk"].values - - assert ak_answers.size == ak_results.size, "Unexpected size of bk" - assert bk_answers.size == bk_results.size, "Unexpected size of ak" - - assert np.array_equal(ak_answers, ak_results), "Unexpected value of ak" - assert np.array_equal(bk_answers, bk_results), "Unexpected value of bk" - - def test_set_hybrid_pressure_coefficients_nofile(): """ This test checks to see that the program fails when the eta_file is not specified @@ -94,7 +40,9 @@ def test_set_hybrid_pressure_coefficients_nofile(): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) + communicator = CubedSphereCommunicator( + LocalComm(rank=0, total_ranks=6, buffer_dict={}), partitioner + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -106,7 +54,7 @@ def test_set_hybrid_pressure_coefficients_nofile(): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) with pytest.raises(ValueError, match=f"eta file {eta_file} does not exist"): MetricTerms( @@ -137,7 +85,9 @@ def test_set_hybrid_pressure_coefficients_not_mono(): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) + communicator = CubedSphereCommunicator( + LocalComm(rank=0, total_ranks=6, buffer_dict={}), partitioner + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -149,7 +99,7 @@ def test_set_hybrid_pressure_coefficients_not_mono(): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) with pytest.raises(ValueError, match="ETA values are not monotonically increasing"): MetricTerms( diff --git a/tests/initialization/test_allocator.py b/tests/initialization/test_allocator.py index 3ee01e43..dad8d407 100644 --- a/tests/initialization/test_allocator.py +++ b/tests/initialization/test_allocator.py @@ -1,19 +1,8 @@ -import warnings - -import numpy as np import pytest from ndsl import QuantityFactory -def test_QuantityFactory_constructor_warns() -> None: - with pytest.warns( - DeprecationWarning, - match="Usage of QuantityFactory.* is discouraged and will change", - ): - QuantityFactory(None, np) - - # Make sure no warnings are emitted if users use `QuantityFactory.from_backend()` - with warnings.catch_warnings(): - warnings.simplefilter("error") - QuantityFactory.from_backend(None, "numpy") +def test_QuantityFactory_from_backend_warns() -> None: + with pytest.deprecated_call(): + QuantityFactory.from_backend(None, backend="numpy") diff --git a/tests/mpi/__init__.py b/tests/mpi/__init__.py index e69de29b..ad92926d 100644 --- a/tests/mpi/__init__.py +++ b/tests/mpi/__init__.py @@ -0,0 +1,6 @@ +from ndsl.comm.mpi import MPI + + +if MPI.COMM_WORLD.Get_size() == 1: + # not run as a parallel test, disable MPI tests + MPI = None diff --git a/tests/mpi/mpi_comm.py b/tests/mpi/mpi_comm.py deleted file mode 100644 index ad92926d..00000000 --- a/tests/mpi/mpi_comm.py +++ /dev/null @@ -1,6 +0,0 @@ -from ndsl.comm.mpi import MPI - - -if MPI.COMM_WORLD.Get_size() == 1: - # not run as a parallel test, disable MPI tests - MPI = None diff --git a/tests/mpi/test_decomposition.py b/tests/mpi/test_decomposition.py new file mode 100644 index 00000000..94fee856 --- /dev/null +++ b/tests/mpi/test_decomposition.py @@ -0,0 +1,18 @@ +from unittest.mock import MagicMock + +import pytest + +from ndsl import MPIComm +from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles +from tests.mpi import MPI + + +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") +def test_unblock_waiting_tiles(): + comm = MPIComm() + compilation_config = MagicMock(compiling_equivalent=0) + + if comm.Get_rank() == 0: + unblock_waiting_tiles(comm._comm) + else: + block_waiting_for_compilation(comm._comm, compilation_config) diff --git a/tests/mpi/test_eta.py b/tests/mpi/test_eta.py new file mode 100644 index 00000000..5c09cbf1 --- /dev/null +++ b/tests/mpi/test_eta.py @@ -0,0 +1,68 @@ +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr + +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + MPIComm, + QuantityFactory, + SubtileGridSizer, + TilePartitioner, +) +from ndsl.grid import MetricTerms +from tests.mpi import MPI + + +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") +@pytest.mark.parametrize("levels", [79, 91]) +def test_set_hybrid_pressure_coefficients_correct(levels): + """ + This test checks to see that the ak and bk arrays are read-in correctly and are + stored as expected. Both values of km=79 and km=91 are tested and both tests are + expected to pass with the stored ak and bk values agreeing with the values read-in + directly from the NetCDF file. + """ + + eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" + eta_data = xr.open_dataset(eta_file) + + backend = "numpy" + + layout = (1, 1) + + nz = levels + ny = 48 + nx = 48 + nhalo = 3 + + partitioner = CubedSpherePartitioner(TilePartitioner(layout)) + communicator = CubedSphereCommunicator(MPIComm(), partitioner) + + sizer = SubtileGridSizer.from_tile_params( + nx_tile=nx, + ny_tile=ny, + nz=nz, + n_halo=nhalo, + layout=layout, + tile_partitioner=partitioner.tile, + tile_rank=communicator.tile.rank, + ) + + quantity_factory = QuantityFactory(sizer, backend=backend) + + metric_terms = MetricTerms( + quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file + ) + + ak_results = metric_terms.ak.data + bk_results = metric_terms.bk.data + ak_answers, bk_answers = eta_data["ak"].values, eta_data["bk"].values + + assert ak_answers.size == ak_results.size, "Unexpected size of bk" + assert bk_answers.size == bk_results.size, "Unexpected size of ak" + + assert np.array_equal(ak_answers, ak_results), "Unexpected value of ak" + assert np.array_equal(bk_answers, bk_results), "Unexpected value of bk" diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 6cab1023..34a3b0dd 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -10,7 +10,7 @@ from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.mpi import MPIComm from ndsl.dsl.typing import Float -from tests.mpi.mpi_comm import MPI +from tests.mpi import MPI @pytest.fixture @@ -58,7 +58,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="Some 1D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5)], dtype=Float) @@ -68,7 +68,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -78,7 +78,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) @@ -98,7 +98,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="New 1D unit", - gt4py_backend=backend, + backend=backend, origin=(8,), extent=(7,), ) @@ -110,7 +110,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -120,7 +120,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) communicator.all_reduce( testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 76aca71e..45278cc5 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -27,7 +27,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from tests.mpi.mpi_comm import MPI +from tests.mpi import MPI @pytest.fixture @@ -271,14 +271,9 @@ def depth_quantity( data[tuple(pos)] = numpy.nan pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan - quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + return Quantity( + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) - return quantity @pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") @@ -320,11 +315,7 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): outside of it.""" data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return quantity diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 0da2f21b..7007d5d5 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,10 +1,10 @@ import numpy as np import pytest -from ndsl import DummyComm +from ndsl import LocalComm from ndsl.buffer import recv_buffer -from ndsl.exceptions import ConcurrencyError -from tests.mpi.mpi_comm import MPI +from ndsl.comm.local_comm import ConcurrencyError +from tests.mpi import MPI worker_function_list = [] @@ -33,11 +33,11 @@ def send_recv(comm, numpy): data = numpy.asarray([rank], dtype=numpy.int32) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -50,11 +50,11 @@ def send_recv_big_data(comm, numpy): data = numpy.ones([5, 3, 96], dtype=numpy.float64) * rank if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -96,11 +96,11 @@ def send_f_contiguous_buffer(comm, numpy): data = numpy.random.uniform(size=[2, 3]).T if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -115,7 +115,7 @@ def send_non_contiguous_buffer(comm, numpy): recv_buffer = numpy.zeros([4, 2, 3]) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: @@ -132,7 +132,7 @@ def send_subarray(comm, numpy): recv_buffer = numpy.zeros([2, 2, 2]) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data[1:-1, 1:-1, 1:-1], dest=rank + 1) if rank > 0: @@ -151,11 +151,11 @@ def recv_to_subarray(comm, numpy): return_value = recv_buffer if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") try: comm.Recv(recv_buffer[1:-1, 1:-1, 1:-1], source=rank - 1) @@ -255,7 +255,7 @@ def dummy_list(total_ranks): return_list = [] for rank in range(total_ranks): return_list.append( - DummyComm(rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer) + LocalComm(rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer) ) return return_list diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index a4f8e812..7ba2eddb 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -36,6 +36,7 @@ def test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) for side in ( WEST, @@ -71,6 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), + backend="debug", ) for side in ( WEST, @@ -109,6 +111,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), + backend="debug", ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index d6b5c7cb..f0c7bb13 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -14,6 +14,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -31,6 +32,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 @@ -46,6 +48,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py new file mode 100644 index 00000000..bb6027f5 --- /dev/null +++ b/tests/quantity/test_local.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from ndsl import Local + + +def test_dace_data_descriptor_is_transient() -> None: + nx = 5 + shape = (nx,) + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend="debug", + ) + array = local.__descriptor__() + assert array.transient + + +def test_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert local.backend == backend + + # make sure we are backwards compatible (for now) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert local.gt4py_backend == backend + + +def test_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index dccfa94f..ef94b45e 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from ndsl import Quantity from ndsl.quantity.bounds import _shift_slice @@ -61,7 +62,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_smaller_data_raises(data, origin, extent, dims, units): @@ -71,25 +74,55 @@ def test_smaller_data_raises(data, origin, extent, dims, units): except IndexError: pass else: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): Quantity( - small_data, origin=origin, extent=extent, dims=dims, units=units + small_data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend="debug", ) def test_smaller_dims_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent, dims=dims[:-1], units=units) + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): + Quantity( + data, + origin=origin, + extent=extent, + dims=dims[:-1], + units=units, + backend="debug", + ) def test_smaller_origin_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin[:-1], extent=extent, dims=dims, units=units) + with pytest.raises(ValueError, match="received .* origins for .* dimensions: .*"): + Quantity( + data, + origin=origin[:-1], + extent=extent, + dims=dims, + units=units, + backend="debug", + ) def test_smaller_extent_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent[:-1], dims=dims, units=units) + with pytest.raises(ValueError, match="received .* extents for .* dimensions: .*"): + Quantity( + data, + origin=origin, + extent=extent[:-1], + dims=dims, + units=units, + backend="debug", + ) def test_data_change_affects_quantity(data, quantity, numpy): @@ -228,20 +261,15 @@ def test_shift_slice(slice_in, shift, extent, slice_out): @pytest.mark.parametrize( "quantity", [ + Quantity(np.array(5), dims=[], units="", backend="debug"), Quantity( - np.array(5), - dims=[], - units="", - ), - Quantity( - np.array([1, 2, 3]), - dims=["dimension"], - units="degK", + np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug" ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", + backend="debug", ), Quantity( np.random.randn(8, 6, 6), @@ -249,6 +277,7 @@ def test_shift_slice(slice_in, shift, extent, slice_out): units="km", origin=(2, 2, 2), extent=(4, 2, 2), + backend="debug", ), ], ) @@ -264,7 +293,7 @@ def test_to_data_array(quantity): def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug") # After allocation - field and data are the same (origin is 0) assert quantity.data.shape == quantity.field.shape @@ -289,3 +318,82 @@ def test_data_setter(): # Expected fail: new array is not even an array with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): quantity.data = "meh" + + +def test_constructor_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + quantity = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we are backwards compatible (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_from_data_array_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + np_data = np.empty(shape) + data_array = xr.DataArray(data=np_data, attrs={"units": "n/a"}) + quantity = Quantity.from_data_array( + data_array, + origin=(0,), + extent=(nx,), + number_of_halo_points=0, + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we don't assign gt4py_backend anymore (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_assign_basic_data_is_deprecated() -> None: + nx = 5 + backend = "debug" + with pytest.deprecated_call( + match="Usage of basic data in Quantities is deprecated" + ): + quantity = Quantity( + data=[0, 1, 2, 3, 4], + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend=backend, + allow_mismatch_float_precision=True, + ) + + # make sure we can still use it (for now) + for i in range(5): + assert quantity.data[i] == i + + +def test_constructor_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index bc39f61f..6d8dc4a4 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -53,7 +53,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_numpy(quantity, backend): @@ -72,7 +74,7 @@ def test_modifying_numpy_data_modifies_view_and_field(): extent=shape, dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) quantity.data[0, 0] = 1 @@ -99,7 +101,7 @@ def test_data_and_field_access_right_full_array_and_compute_domain(): extent=(5, 5), dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) # Write compute domain - test data is written with the offset @@ -139,7 +141,7 @@ def test_accessing_data_does_not_break_view( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] @@ -158,6 +160,6 @@ def test_numpy_data_becomes_cupy_with_gpu_backend( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) assert isinstance(quantity.data, cp.ndarray) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 5e527279..b7ceb0f8 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -35,7 +35,6 @@ def quantity_data_input(initial_data, numpy, backend): array[:] = initial_data else: array = initial_data - print(type(array)) return array @@ -87,6 +86,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, extent=initial_extent, + backend="debug", ) @@ -166,7 +166,13 @@ def param_product(*param_lists): ) @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_transpose( - quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy + quantity: Quantity, + target_dims, + final_data, + final_dims, + final_origin, + final_extent, + numpy, ): result = quantity.transpose(target_dims) numpy.testing.assert_array_equal(result.data, final_data) @@ -174,7 +180,7 @@ def test_transpose( assert result.origin == final_origin assert result.extent == final_extent assert result.units == quantity.units - assert result.gt4py_backend == quantity.gt4py_backend + assert result.backend == quantity.backend @pytest.mark.parametrize( @@ -213,7 +219,9 @@ def test_transpose_invalid_cases( def test_transpose_retains_attrs(numpy): - quantity = Quantity(numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string") + quantity = Quantity( + numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", backend="debug" + ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) assert transposed.attrs == quantity.attrs diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index 73245093..0ce44cfe 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -183,6 +183,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -216,6 +217,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -742,6 +744,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -754,6 +757,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 0, @@ -766,6 +770,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -778,6 +783,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 1, @@ -798,6 +804,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -816,6 +823,7 @@ def test_southwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -834,6 +842,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -846,6 +855,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -858,6 +868,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -870,6 +881,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -890,6 +902,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -908,6 +921,7 @@ def test_southeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -926,6 +940,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -938,6 +953,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -950,6 +966,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -962,6 +979,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -974,6 +992,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -994,6 +1013,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1012,6 +1032,7 @@ def test_northwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1030,6 +1051,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1042,6 +1064,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 8, @@ -1054,6 +1077,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1066,6 +1090,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1078,6 +1103,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 5, @@ -1098,6 +1124,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1116,6 +1143,7 @@ def test_northeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1134,6 +1162,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -1146,6 +1175,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 0), slice(0, 0)), 4, @@ -1158,6 +1188,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1178,6 +1209,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1198,6 +1230,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), + backend="debug", ), (0,), np.array([6, 7, 8]), @@ -1216,6 +1249,7 @@ def test_interior(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) if len(view_slice) == len(quantity.dims): # skip if not transposed_result = transposed_quantity.view.interior[view_slice[::-1]] diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py new file mode 100644 index 00000000..48d2232a --- /dev/null +++ b/tests/stencils/test_stencils.py @@ -0,0 +1,174 @@ +import numpy as np +import pytest + +from ndsl import QuantityFactory, StencilFactory +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import FORWARD, computation, interval +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size +from ndsl.stencils import CopyCornersXY +from ndsl.stencils.column_operations import ( + column_max, + column_max_ddim, + column_min, + column_min_ddim, +) + + +FloatField_ddim = set_4d_field_size(2, Float) + + +@pytest.fixture +def boilerplate() -> tuple[StencilFactory, QuantityFactory]: + return get_factories_single_tile(nx=1, ny=1, nz=10, nhalo=0, backend="dace:cpu") + + +class ColumnOperations: + def __init__(self, stencil_factory: StencilFactory): + + def column_max_stencil( + data: FloatField, max_value: FloatFieldIJ, max_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + max_value, max_index = column_max(data, 0, k_end) + + def column_max_ddim_stencil( + data: FloatField_ddim, max_value: FloatFieldIJ, max_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + max_value, max_index = column_max_ddim(data, 1, 0, k_end) + + def column_min_stencil( + data: FloatField, min_value: FloatFieldIJ, min_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + min_value, min_index = column_min(data, 5, k_end) + + def column_min_ddim_stencil( + data: FloatField_ddim, min_value: FloatFieldIJ, min_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + min_value, min_index = column_min_ddim(data, 1, 5, k_end) + + self._column_max_stencil = stencil_factory.from_dims_halo( + func=column_max_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self._column_max_ddim_stencil = stencil_factory.from_dims_halo( + func=column_max_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self._column_min_stencil = stencil_factory.from_dims_halo( + func=column_min_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self._column_min_ddim_stencil = stencil_factory.from_dims_halo( + func=column_min_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__( + self, + data: FloatField, + max_value: FloatFieldIJ, + max_index: FloatFieldIJ, + min_value: FloatFieldIJ, + min_index: FloatFieldIJ, + data_ddim: FloatField_ddim, + max_value_ddim: FloatFieldIJ, + max_index_ddim: FloatFieldIJ, + min_value_ddim: FloatFieldIJ, + min_index_ddim: FloatFieldIJ, + ): + self._column_max_stencil(data, max_value, max_index) + self._column_max_ddim_stencil(data_ddim, max_value_ddim, max_index_ddim) + self._column_min_stencil(data, min_value, min_index) + self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim) + + +def test_column_operations(boilerplate): + stencil_factory, quantity_factory = boilerplate + quantity_factory.add_data_dimensions({"ddim": 2}) + data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") + data.field[:] = [ + 47.3821, + 2.9157, + 88.6034, + 71.9275, + 53.1412, + 19.4783, + 94.2258, + 36.8099, + 64.0175, + 7.3504, + ] + + data_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, "ddim"], "n/a") + data_ddim.field[:] = [ + [ + [ + [47.3821, 27.4825], + [2.9157, 93.1242], + [88.6034, 14.6347], + [71.9275, 58.2094], + [53.1412, 6.2369], + [19.4783, 71.5457], + [94.2258, 42.3091], + [36.8099, 89.7718], + [64.0175, 63.1910], + [7.3504, 3.4991], + ] + ] + ] + + max_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + + code = ColumnOperations(stencil_factory) + code( + data, + max_value, + max_index, + min_value, + min_index, + data_ddim, + max_value_ddim, + max_index_ddim, + min_value_ddim, + min_index_ddim, + ) + + # 3d field tests + assert max_value.field[:] == np.max(data.field[:], axis=2) + assert max_index.field[:] == np.argmax(data.field[:], axis=2) + assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2) + assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2) + + # 4d field tests + assert max_value_ddim.field[:] == np.max(data_ddim.field[:, :, :, 1], axis=2) + assert max_index_ddim.field[:] == np.argmax(data_ddim.field[:, :, :, 1], axis=2) + assert min_value_ddim.field[:] == np.min(data_ddim.field[:, :, 5:, 1], axis=2) + assert min_index_ddim.field[:] == 5 + np.argmin( + data_ddim.field[:, :, 5:, 1], axis=2 + ) + + +def test_CopyCornersXY_deprecation(boilerplate) -> None: + stencil_factory, _ = boilerplate + + with pytest.deprecated_call(match="Usage of CopyCornersXY is deprecated"): + CopyCornersXY(stencil_factory, [X_DIM, Y_DIM, Z_DIM], None) diff --git a/tests/stree_optimizer/__init__.py b/tests/stree_optimizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stree_optimizer/sdfg_stree_tools.py b/tests/stree_optimizer/sdfg_stree_tools.py new file mode 100644 index 00000000..d5a07154 --- /dev/null +++ b/tests/stree_optimizer/sdfg_stree_tools.py @@ -0,0 +1,36 @@ +from types import TracebackType + +import dace + +import ndsl.dsl.dace.orchestration as orch +from ndsl import StencilFactory + + +def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: + """Get the Precompiled SDFG from the dace config dict where they are cached post + compilation and flush the cache in order for next build to re-use the function.""" + sdfg_repo = stencil_factory.config.dace_config.loaded_precompiled_SDFG + + if len(sdfg_repo.values()) != 1: + raise RuntimeError("Failure to compile SDFG") + sdfg = list(sdfg_repo.values())[0] + + sdfg_repo.clear() + + return sdfg + + +class StreeOptimization: + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_merge.py b/tests/stree_optimizer/test_merge.py new file mode 100644 index 00000000..8fff866d --- /dev/null +++ b/tests/stree_optimizer/test_merge.py @@ -0,0 +1,174 @@ +import dace + +from ndsl import QuantityFactory, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval +from ndsl.dsl.typing import FloatField + +from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge + + +def stencil(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = in_field + 1 + + +def stencil_with_self_assign(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = out_field + in_field + 2 + + +def stencil_with_forward_K(in_field: FloatField, out_field: FloatField) -> None: + with computation(FORWARD), interval(...): + out_field = in_field + 3 + + +def stencil_with_different_intervals( + in_field: FloatField, + out_field: FloatField, +) -> None: + with computation(PARALLEL), interval(1, None): + out_field = in_field + 5 + + +def stencil_with_buffer_read_offset_in_K( + in_field: FloatField, out_field: FloatField, buffer: FloatField +) -> None: + with computation(PARALLEL), interval(1, None): + buffer = in_field + 6 + + with computation(PARALLEL), interval(1, None): + out_field = buffer[K - 1] + 7 + + +class OrchestratedCode: + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ) -> None: + orchestratable_methods = [ + "trivial_merge", + "missing_merge_of_forscope_and_map", + "overcompute_merge", + "block_merge_when_depandencies_is_found", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self.stencil = stencil_factory.from_dims_halo( + func=stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_forward_K = stencil_factory.from_dims_halo( + func=stencil_with_forward_K, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_buffer_read_offset_in_K = stencil_factory.from_dims_halo( + func=stencil_with_buffer_read_offset_in_K, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_different_intervals = stencil_factory.from_dims_halo( + func=stencil_with_different_intervals, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + self._buffer = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], units="") + + def trivial_merge( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil(in_field, out_field) + self.stencil(in_field, out_field) + + def missing_merge_of_forscope_and_map( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil(in_field, out_field) + self.stencil_with_forward_K(in_field, out_field) + self.stencil(in_field, out_field) + + def block_merge_when_depandencies_is_found( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil(in_field, out_field) + self.stencil_with_buffer_read_offset_in_K(in_field, out_field, self._buffer) + + def overcompute_merge( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil(in_field, out_field) + self.stencil_with_different_intervals(in_field, out_field) + + +def test_stree_merge_maps() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + code = OrchestratedCode(stencil_factory, quantity_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + with StreeOptimization(): + # Trivial merge + code.trivial_merge(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + + assert len(all_maps) == 3 + assert (out_qty.field[:] == 2).all() + + # Merge IJ - but do not merge K map & for (missing feature) + code.missing_merge_of_forscope_and_map(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg + all_maps = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + assert len(all_maps) == 4 # 2 IJ + 2 Ks + all_loop_guard_state = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.SDFGState) and me.name.startswith("loop_guard") + ] + assert len(all_loop_guard_state) == 1 # 1 For loop + + # Overcompute merge in K - we merge and introduce an If guard + code.overcompute_merge(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg + all_maps = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + assert len(all_maps) == 3 + + # Forbid merging when data dependancy is detected + code.block_merge_when_depandencies_is_found(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg + all_maps = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py deleted file mode 100644 index fb32a35d..00000000 --- a/tests/stree_optimizer/test_optimization.py +++ /dev/null @@ -1,69 +0,0 @@ -from ndsl import StencilFactory, orchestrate -from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge -from ndsl.dsl.gt4py import PARALLEL, computation, interval -from ndsl.dsl.typing import FloatField - - -def stencil_A(in_field: FloatField, out_field: FloatField): - with computation(PARALLEL), interval(...): - out_field = in_field - - -def stencil_B(in_field: FloatField, out_field: FloatField): - with computation(PARALLEL), interval(...): - out_field = out_field + in_field * 3 - - -class TriviallyMergeableCode: - def __init__(self, stencil_factory: StencilFactory): - orchestrate(obj=self, config=stencil_factory.config.dace_config) - self.stencil_A = stencil_factory.from_dims_halo( - func=stencil_A, - compute_dims=[X_DIM, Y_DIM, Z_DIM], - ) - self.stencil_B = stencil_factory.from_dims_halo( - func=stencil_B, - compute_dims=[X_DIM, Y_DIM, Z_DIM], - ) - - def __call__(self, in_field: FloatField, out_field: FloatField): - self.stencil_A(in_field, out_field) - self.stencil_B(in_field, out_field) - - -def test_stree_roundtrip_no_opt(): - """Dev Note: - - The below code sucessfully merges top level K loop (2 loops) - How do we test it?! Running doesn't test merging and the compilation - is a near-black box. We could reach in the `dace_config.compiled_sdfg` - cache but it's keyed on the dace.program and if we can reach the program - well we can reach the SDFG and turn it into an stree for verification - Should we run orchestration "by hand"? - Can we intercept the `stree` ? After all we just want to check that! - - Test is deactivated for now""" - - return True - domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu" - ) - - code = TriviallyMergeableCode(stencil_factory) - in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") - out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - - # Temporarily flip the internal switch - import ndsl.dsl.dace.orchestration as orch - - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True - orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] - - code(in_qty, out_qty) - - assert (out_qty.field[:] == 4).all() - - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_transient_refine.py b/tests/stree_optimizer/test_transient_refine.py new file mode 100644 index 00000000..16314b8d --- /dev/null +++ b/tests/stree_optimizer/test_transient_refine.py @@ -0,0 +1,144 @@ +from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval +from ndsl.dsl.typing import Float, FloatField + +from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge + + +DATADIM_SIZE = 8 +DDIM_NAME = "DDIM" +DDIM_TYPE = Field[IJK, (Float, (DATADIM_SIZE))] + + +def stencil(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = in_field + 1 + + +def stencil_with_K_offset(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(0, -1): + out_field = in_field[K + 1] + 2 + + +def stencil_with_JK_offset(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = in_field[J + 1, K + 1] + 3 + + +def stencil_with_ddim(in_field: DDIM_TYPE, out_field: DDIM_TYPE) -> None: + with computation(PARALLEL), interval(...): + n = 0 + while n < DATADIM_SIZE: + out_field[0, 0, 0][n] = in_field[0, 0, 0][n] + 4 + n = n + 1 + + +class TransientRefineableCode(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(stencil_factory.config.dace_config) + orchestratable_methods = [ + "refine_to_scalar", + "refine_to_K_buffer", + "refine_to_JK_buffer", + "do_not_refine_datadims", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + self.stencil = stencil_factory.from_dims_halo( + func=stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_K_offset = stencil_factory.from_dims_halo( + func=stencil_with_K_offset, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_JK_offset = stencil_factory.from_dims_halo( + func=stencil_with_JK_offset, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_ddim = stencil_factory.from_dims_halo( + func=stencil_with_ddim, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + self.tmp_ddim = self.make_local( + quantity_factory, [X_DIM, Y_DIM, Z_DIM, DDIM_NAME] + ) + + def refine_to_scalar(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil(self.tmp, out_field) + + def refine_to_K_buffer(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil_with_K_offset(self.tmp, out_field) + + def refine_to_JK_buffer(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil_with_JK_offset(self.tmp, out_field) + + def do_not_refine_datadims(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil_with_ddim(in_field, self.tmp_ddim) + self.stencil_with_ddim(self.tmp_ddim, out_field) + + +def test_stree_roundtrip_transient_is_refined() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + quantity_factory.add_data_dimensions({DDIM_NAME: DATADIM_SIZE}) + in_qty_ddim = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM, DDIM_NAME], "") + out_qty_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, DDIM_NAME], "") + + code = TransientRefineableCode(stencil_factory, quantity_factory) + + with StreeOptimization(): + # Refine to scalar + code.refine_to_scalar(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) + + # Refine cartesian axis to buffers + # IJ merges - K is a buffer + code.refine_to_K_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + 1, + domain[2] + 1, # Quantity are domain size + 1 + ) + + # I merges - JK buffer + code.refine_to_JK_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + domain[1] + 1, # Quantity are domain size + 1 + domain[2] + 1, + ) + + # Refine to remaining data dimensions + code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1 diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 0d707240..74916015 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -128,12 +128,14 @@ def test_copy(): data=np.zeros([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.ones([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) copy(f_in=infield.data, f_out=outfield.data) @@ -148,18 +150,21 @@ def test_adjustmentfactor(): data=np.full(shape=[20, 20], fill_value=2.0), dims=[X_DIM, Y_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=4.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) adfact(factor=factorfield.data, f_out=outfield.data) @@ -173,12 +178,14 @@ def test_setvalue(): data=np.zeros(shape=[20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) setvalue(f_out=outfield.data, value=2.0) @@ -193,18 +200,21 @@ def test_adjustdivide(): data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=1.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) addiv(factor=factorfield.data, f_out=outfield.data) diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index b531453d..0e00d459 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -45,7 +45,7 @@ def test_boilerplate_import_numpy(): # Ensure backend is propagated to StencilFactory and QuantityFactory assert stencil_factory.backend == "numpy" - assert quantity_factory._backend() == "numpy" + assert quantity_factory.backend == "numpy" _copy_ops(stencil_factory, quantity_factory) @@ -64,7 +64,7 @@ def test_boilerplate_import_orchestrated_cpu(): # Ensure backend is propagated to StencilFactory and QuantityFactory assert stencil_factory.backend == "dace:cpu" - assert quantity_factory._backend() == "dace:cpu" + assert quantity_factory.backend == "dace:cpu" _copy_ops(stencil_factory, quantity_factory) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 1498402d..de7723ce 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -187,6 +187,5 @@ def test_new_args_gives_different_buffer(allocator, backend, first_args, second_ @pytest.mark.parametrize("allocator, backend", [["ones", "cupy"]], indirect=True) def test_mpi_unsafe_allocator_exception(backend, allocator): BUFFER_CACHE.clear() - print(allocator) with pytest.raises(RuntimeError): Buffer.pop_from_cache(allocator, shape=(10, 10, 10), dtype=float) diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index d2d7f64e..bdbab4cd 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -8,7 +8,6 @@ CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, - NullComm, Quantity, TilePartitioner, ) @@ -30,6 +29,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, + backend="debug", ) for _ in range(n_ranks) ] @@ -75,29 +75,38 @@ def perform_serial_halo_updates( def test_Recv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 - comm.Recv(recvbuf, source=0) - assert len(comm._data.received_buffers) == 1 - assert comm._data.received_buffers[0].shape == shape + + if comm.Get_rank() == 0: + comm.bcast(np.random.randn(*shape)) + else: + comm.Recv(recvbuf, source=0) + assert len(comm._data.received_buffers) == 1 + assert comm._data.received_buffers[0].shape == shape def test_Irecv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 - req = comm.Irecv(recvbuf, source=0) - assert len(comm._data.received_buffers) == 0 - req.wait() - assert len(comm._data.received_buffers) == 1 - assert comm._data.received_buffers[0].shape == shape + + if comm.Get_rank() == 0: + comm.Isend(np.random.randn(*shape), dest=2) + + if comm.Get_rank() == 2: + req = comm.Irecv(recvbuf, source=0) + assert len(comm._data.received_buffers) == 0 + req.wait() + assert len(comm._data.received_buffers) == 1 + assert comm._data.received_buffers[0].shape == shape def test_bcast_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.bcast_objects) == 0 diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 236e1eb9..5fe4f5b9 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -6,7 +6,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -100,7 +100,7 @@ def communicator_list(layout): for rank in range(total_ranks): return_list.append( CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), timer=Timer(), ) @@ -169,6 +169,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="numpy", ) diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index 2160e23e..2077a609 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -6,13 +6,10 @@ from ndsl import CubedSpherePartitioner, TilePartitioner from ndsl.comm.decomposition import ( - block_waiting_for_compilation, build_cache_path, check_cached_path_exists, determine_rank_is_compiling, - unblock_waiting_tiles, ) -from tests.mpi.mpi_comm import MPI @pytest.mark.parametrize( @@ -35,7 +32,7 @@ def test_determine_rank_is_compiling( def test_check_cached_path_exists(): with pytest.raises(RuntimeError): - check_cached_path_exists("notarealpath") + check_cached_path_exists("not/a/real/path") def test_check_cached_path_exists_working(): @@ -67,18 +64,3 @@ def test_build_cache_path( ) _, rank_str = build_cache_path(compilation_config) assert rank_str == target_rank_str - - -@pytest.mark.skipif( - MPI is None, - reason="pytest is not run in parallel", -) -def test_unblock_waiting_tiles(): - comm = MPI.COMM_WORLD - compilation_config = unittest.mock.MagicMock(compiling_equivalent=0) - rank = comm.Get_rank() - size = comm.Get_size() - if rank != 0: - block_waiting_for_compilation(comm, compilation_config) - if rank == 0: - unblock_waiting_tiles(comm) diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index ebad25ef..9d78a4e0 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -2,7 +2,7 @@ import pytest -from ndsl import QuantityFactory, SubtileGridSizer +from ndsl import GridSizer, QuantityFactory, SubtileGridSizer from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -55,7 +55,7 @@ def extra_dimension_lengths(): @pytest.fixture def namelist(nx_tile, ny_tile, nz, layout): - namelist = { + return { "fv_core_nml": { "npx": nx_tile + 1, "npy": ny_tile + 1, @@ -63,13 +63,14 @@ def namelist(nx_tile, ny_tile, nz, layout): "layout": layout, } } - return namelist @pytest.fixture(params=["from_namelist", "from_tile_params"]) -def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths): +def sizer( + request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths +) -> GridSizer: if request.param == "from_tile_params": - sizer = SubtileGridSizer.from_tile_params( + return SubtileGridSizer.from_tile_params( nx_tile=nx_tile, ny_tile=ny_tile, nz=nz, @@ -77,11 +78,11 @@ def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengt layout=layout, data_dimensions=extra_dimension_lengths, ) - elif request.param == "from_namelist": - sizer = SubtileGridSizer.from_namelist(namelist) - else: - raise NotImplementedError() - return sizer + + if request.param == "from_namelist": + return SubtileGridSizer.from_namelist(namelist) + + raise NotImplementedError() @pytest.fixture @@ -109,7 +110,7 @@ def dtype(request): "z_y_x", ] ) -def dim_case(request, nx, ny, nz): +def dim_case(request, nx, ny, nz) -> DimCase: if request.param == "x_only": return DimCase( (X_DIM,), @@ -117,32 +118,38 @@ def dim_case(request, nx, ny, nz): (nx,), (2 * N_HALO_DEFAULT + nx + 1,), ) - elif request.param == "x_interface_only": + + if request.param == "x_interface_only": return DimCase( (X_INTERFACE_DIM,), (N_HALO_DEFAULT,), (nx + 1,), (2 * N_HALO_DEFAULT + nx + 1,), ) - elif request.param == "y_only": + + if request.param == "y_only": return DimCase( (Y_DIM,), (N_HALO_DEFAULT,), (ny,), (2 * N_HALO_DEFAULT + ny + 1,), ) - elif request.param == "y_interface_only": + + if request.param == "y_interface_only": return DimCase( (Y_INTERFACE_DIM,), (N_HALO_DEFAULT,), (ny + 1,), (2 * N_HALO_DEFAULT + ny + 1,), ) - elif request.param == "z_only": + + if request.param == "z_only": return DimCase((Z_DIM,), (0,), (nz,), (nz + 1,)) - elif request.param == "z_interface_only": + + if request.param == "z_interface_only": return DimCase((Z_INTERFACE_DIM,), (0,), (nz + 1,), (nz + 1,)) - elif request.param == "x_y": + + if request.param == "x_y": return DimCase( ( X_DIM, @@ -155,7 +162,8 @@ def dim_case(request, nx, ny, nz): 2 * N_HALO_DEFAULT + ny + 1, ), ) - elif request.param == "z_y_x": + + if request.param == "z_y_x": return DimCase( ( Z_DIM, @@ -171,6 +179,8 @@ def dim_case(request, nx, ny, nz): ), ) + raise NotImplementedError() + @pytest.mark.cpu_only def test_subtile_dimension_sizer_origin(sizer, dim_case): @@ -191,7 +201,7 @@ def test_subtile_dimension_sizer_shape(sizer, dim_case): def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.zeros(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -202,7 +212,7 @@ def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): def test_allocator_ones(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.ones(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -213,7 +223,7 @@ def test_allocator_ones(numpy, sizer, dim_case, units, dtype): def test_allocator_empty(sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.empty(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -223,7 +233,7 @@ def test_allocator_empty(sizer, dim_case, units, dtype): def test_allocator_data_dimensions_operations(sizer): - quantity_factory = QuantityFactory.from_backend(sizer, "numpy") + quantity_factory = QuantityFactory(sizer, backend="numpy") quantity_factory.add_data_dimensions({"D0": 11}) assert "D0" in quantity_factory.sizer.data_dimensions.keys() assert quantity_factory.sizer.data_dimensions["D0"] == 11 diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py deleted file mode 100644 index f5d66eb7..00000000 --- a/tests/test_exceptions.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from ndsl import OutOfBoundsError - - -def test_OutOfBoundsError_is_deprecation() -> None: - with pytest.deprecated_call(): - OutOfBoundsError("This should trigger a deprecation warning.") diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py deleted file mode 100644 index 5f05af82..00000000 --- a/tests/test_filesystem.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -import ndsl.filesystem as fs - - -def test_is_file_is_deprecated() -> None: - with pytest.deprecated_call(): - fs.is_file("path/to/my_file.txt") - - -def test_open_is_deprecated() -> None: - with pytest.deprecated_call(): - with fs.open("README.md", "r"): - pass - - -def test_get_fs_is_deprecated() -> None: - with pytest.deprecated_call(): - fs.get_fs("path/to/my/file.txt") diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index dab27cb3..b7af3e12 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -12,7 +12,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -56,7 +56,7 @@ def cpu_communicators(cube_partitioner): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), force_cpu=True, @@ -74,7 +74,7 @@ def gpu_communicators(cube_partitioner): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, @@ -137,7 +137,6 @@ def test_halo_update_only_communicate_on_gpu(backend, gpu_communicators): # We expect no np calls and several cp calls global N_ZEROS_CALLS # noqa: F824 global ... is unused - print(f"Results {N_ZEROS_CALLS}") assert N_ZEROS_CALLS[cp.zeros] > 0 assert N_ZEROS_CALLS[np.zeros] == 0 diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 7e3385e7..2d34567a 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -165,20 +165,18 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): """A list of quantities whose values are 42.42 in the computational domain and 1 outside of it.""" sz = _shape_length(shape) - print(f"{shape} {sz}") data = np.arange(0, sz, dtype=dtype).reshape(shape) if "gtc" not in gt4py_backend: # should also test code if gt4py_backend is unset gt4py_backend = None - quantity = Quantity( + return Quantity( data, dims=dims, units=units, origin=origin, extent=extent, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) - return quantity @pytest.fixture(params=[-0, -1, -2, -3]) diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index ca76ab8b..e4ca02df 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -6,8 +6,8 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, HaloUpdater, + LocalComm, Quantity, TileCommunicator, TilePartitioner, @@ -204,7 +204,7 @@ def communicator_list(cube_partitioner: CubedSpherePartitioner): for rank in range(total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, @@ -222,7 +222,7 @@ def tile_communicator_list(tile_partitioner): for rank in range(total_ranks): return_list.append( TileCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=tile_partitioner, @@ -319,11 +319,7 @@ def depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -356,11 +352,7 @@ def tile_depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -500,11 +492,7 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -521,11 +509,7 @@ def zeros_quantity_tile_list( for _rank in range(single_tile_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 6ceb4886..4a101f37 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -3,7 +3,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -103,7 +103,7 @@ def communicator_list(cube_partitioner, total_ranks): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, @@ -126,6 +126,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_legacy_restart.py b/tests/test_legacy_restart.py index 65514747..3e4ecd03 100644 --- a/tests/test_legacy_restart.py +++ b/tests/test_legacy_restart.py @@ -10,7 +10,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -38,7 +38,7 @@ def get_c12_restart_state_list(layout, only_names, tracer_properties): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) @@ -148,7 +148,7 @@ def test_open_c12_restart_empty_to_state_without_crashing(layout): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) @@ -190,7 +190,7 @@ def test_open_c12_restart_to_allocated_state_without_crashing(layout): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) diff --git a/tests/test_namelist.py b/tests/test_namelist.py deleted file mode 100644 index a32919e7..00000000 --- a/tests/test_namelist.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from ndsl import Namelist - - -def test_ndsl_namelist_deprecation() -> None: - with pytest.deprecated_call(): - my_namelist = Namelist() diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 19614486..c4cbe575 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,7 +10,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, NetCDFMonitor, Quantity, TilePartitioner, @@ -39,6 +39,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" + backend = "debug" nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) @@ -53,7 +54,7 @@ def test_monitor_store_multi_rank_state( for rank in range(total_ranks): communicator = CubedSphereCommunicator( partitioner=partitioner, - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) @@ -74,6 +75,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) @@ -87,6 +89,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store(state) @@ -100,6 +103,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) diff --git a/tests/test_null_comm.py b/tests/test_null_comm.py index 74065f67..9ed19d9c 100644 --- a/tests/test_null_comm.py +++ b/tests/test_null_comm.py @@ -1,3 +1,5 @@ +import pytest + from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, @@ -7,10 +9,8 @@ def test_can_create_cube_communicator(): - rank = 2 - total_ranks = 24 - mpi_comm = NullComm(rank, total_ranks) - layout = (2, 2) - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(mpi_comm, partitioner) - communicator.tile.partitioner + with pytest.deprecated_call(match="NullComm is deprecated"): + mpi_comm = NullComm(rank=2, total_ranks=24) + partitioner = CubedSpherePartitioner(TilePartitioner(layout=(2, 2))) + communicator = CubedSphereCommunicator(mpi_comm, partitioner) + assert communicator.tile.partitioner diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index cc0a105e..5b8836ac 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -993,7 +993,9 @@ def test_subtile_extent_with_tile_dimensions( cubedsphere_expected, ): data_array = np.zeros((tile_extent)) - quantity = Quantity(data_array, array_dims, "dimensionless", [0, 0, 0, 0]) + quantity = Quantity( + data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0], backend="debug" + ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) cubedsphere_partitioner = CubedSpherePartitioner(tile_partitioner) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 7db5a621..27db0048 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -3,7 +3,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -56,7 +56,7 @@ def communicator_list(cube_partitioner, total_ranks): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, @@ -81,6 +81,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = numpy.empty((2, 3), dtype=dtype) y_data[:] = rank @@ -90,6 +91,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -147,6 +149,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -155,6 +158,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index d768bb15..9ab00eba 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,6 +1,6 @@ import pytest -from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner +from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM @@ -20,7 +20,7 @@ def get_tile_communicator_list(partitioner): for rank in range(total_ranks): tile_communicator_list.append( TileCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=partitioner, @@ -36,11 +36,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), } @@ -80,16 +82,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), } @@ -137,6 +142,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -144,6 +150,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -151,6 +158,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 60ef583a..4e87eb65 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -3,7 +3,7 @@ import pytest -from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner +from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import ( HORIZONTAL_DIMS, X_DIM, @@ -84,7 +84,7 @@ def communicator_list(layout): for rank in range(total_ranks): return_list.append( TileCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), TilePartitioner(layout), ) ) @@ -150,6 +150,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="debug", ) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 70dd9739..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from ndsl.units import UnitsError, ensure_equal_units, units_are_equal - - -def test_UnitsError_is_deprecated() -> None: - with pytest.deprecated_call(): - UnitsError() - - -def test_units_are_equal_is_deprecated() -> None: - with pytest.deprecated_call(): - units_are_equal("asdf", "asdf") - - -def test_ensure_equal_units_is_deprecated() -> None: - with pytest.deprecated_call(): - ensure_equal_units("asdf", "asdf") diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index b3847ee4..b0943ffb 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -7,7 +7,7 @@ import pytest import xarray as xr -from ndsl import CubedSpherePartitioner, DummyComm, Quantity, TilePartitioner +from ndsl import CubedSpherePartitioner, LocalComm, MPIComm, Quantity, TilePartitioner from ndsl.constants import ( X_DIM, X_DIMS, @@ -91,40 +91,41 @@ def cube_partitioner(tile_partitioner): @pytest.fixture(params=["empty", "one_var_2d", "one_var_3d", "two_vars"]) -def base_state(request, nz, ny, nx, numpy): +def base_state(request, nz, ny, nx, numpy) -> dict: if request.param == "empty": return {} - elif request.param == "one_var_2d": + + if request.param == "one_var_2d": return { "var1": Quantity( - numpy.ones([ny, nx]), - dims=("y", "x"), - units="m", + numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" ) } - elif request.param == "one_var_3d": + + if request.param == "one_var_3d": return { "var1": Quantity( numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="m", + backend="debug", ) } - elif request.param == "two_vars": + + if request.param == "two_vars": return { "var1": Quantity( - numpy.ones([ny, nx]), - dims=("y", "x"), - units="m", + numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" ), "var2": Quantity( numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="degK", + backend="debug", ), } - else: - raise NotImplementedError() + + raise NotImplementedError() @pytest.fixture @@ -139,10 +140,17 @@ def state_list(base_state, n_times, start_time, time_step, numpy): return state_list +@requires_zarr +def test_mpi_comm_will_be_required(cube_partitioner): + with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: + with pytest.deprecated_call(match="`mpi_comm` will be a required argument"): + ZarrMonitor(tempdir, cube_partitioner) + + @requires_zarr def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: - monitor = ZarrMonitor(tempdir, cube_partitioner) + monitor = ZarrMonitor(tempdir, cube_partitioner, mpi_comm=MPIComm()) for state in state_list: monitor.store(state) validate_store(state_list, tempdir, numpy, start_time) @@ -152,7 +160,7 @@ def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): @requires_zarr def validate_xarray_can_open(dirname): # just checking there are no crashes, validate_group checks data - xr.open_zarr(dirname) + xr.open_zarr(dirname, consolidated=False) @requires_zarr @@ -239,8 +247,8 @@ def test_monitor_file_store_multi_rank_state( ZarrMonitor( store, partitioner, - "w", - mpi_comm=DummyComm( + mode="w", + mpi_comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) @@ -253,6 +261,7 @@ def test_monitor_file_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend="debug", ), } monitor_list[rank].store(state) @@ -337,14 +346,19 @@ def _assert_no_nulls(dataset: xr.Dataset): @requires_zarr def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale): store = {} + buffer = {} # initialize store - monitor = ZarrMonitor(store, cube_partitioner) - zero_quantity = Quantity(numpy.zeros([10, 10]), dims=("y", "x"), units="m") + monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) + zero_quantity = Quantity( + numpy.zeros([10, 10]), dims=("y", "x"), units="m", backend="debug" + ) monitor.store({"var": zero_quantity}) # open w/o dask using chunks=None - dataset = xr.open_zarr(store, chunks=None, mask_and_scale=mask_and_scale) + dataset = xr.open_zarr( + store, chunks=None, mask_and_scale=mask_and_scale, consolidated=False + ) _assert_no_nulls(dataset.sel(tile=0)) @@ -354,14 +368,17 @@ def test_values_preserved(cube_partitioner, numpy): units = "m" store = {} + buffer = {} # initialize store - monitor = ZarrMonitor(store, cube_partitioner) - quantity = Quantity(numpy.random.uniform(size=(10, 10)), dims=dims, units=units) + monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) + quantity = Quantity( + numpy.random.uniform(size=(10, 10)), dims=dims, units=units, backend="debug" + ) monitor.store({"var": quantity}) # open w/o dask using chunks=None - dataset = xr.open_zarr(store, chunks=None) + dataset = xr.open_zarr(store, chunks=None, consolidated=False) numpy.testing.assert_array_almost_equal( dataset["var"][0, 0, :, :].values, quantity.data ) @@ -388,7 +405,7 @@ def test_monitor_file_store_inconsistent_calendars( state_list_with_inconsistent_calendars, cube_partitioner, numpy ): with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: - monitor = ZarrMonitor(tempdir, cube_partitioner) + monitor = ZarrMonitor(tempdir, cube_partitioner, mpi_comm=MPIComm()) initial_state, final_state = state_list_with_inconsistent_calendars monitor.store(initial_state) with pytest.raises(ValueError, match="Calendar type"): @@ -406,29 +423,31 @@ def test_monitor_file_store_inconsistent_calendars( ) def diag(request, numpy): dims = request.param - diag = Quantity( - numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m" + return Quantity( + numpy.ones([size + 2 for size in range(len(dims))]), + dims=dims, + units="m", + backend="debug", ) - return diag def _transpose(quantity, dims_2d, dims_3d): if len(quantity.dims) == 2: return quantity.transpose(dims_2d) - elif len(quantity.dims) == 3: + + if len(quantity.dims) == 3: return quantity.transpose(dims_3d) @pytest.fixture(scope="function") def zarr_store(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("diags.zarr") - store = zarr.storage.DirectoryStore(tmpdir) - return store + return zarr.storage.DirectoryStore(tmpdir) @pytest.fixture(scope="function") def zarr_monitor_single_rank(zarr_store, cube_partitioner): - return ZarrMonitor(zarr_store, cube_partitioner) + return ZarrMonitor(zarr_store, cube_partitioner, mpi_comm=MPIComm()) @requires_zarr @@ -440,7 +459,7 @@ def test_transposed_diags_write_across_ranks(diag, cube_partitioner, zarr_store) monitor = ZarrMonitor( zarr_store, cube_partitioner, - mpi_comm=DummyComm( + mpi_comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) @@ -482,6 +501,7 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): numpy.ones([size + 2 for size in range(len(diag.dims))]), dims=new_dims, units="m", + backend="debug", ) with pytest.raises(ValueError) as excinfo: zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) @@ -497,6 +517,6 @@ def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2 = copy.deepcopy(diag) diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) - diag_3 = Quantity(data=diag.values, dims=diag.dims, units="not_m") + diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m", backend="debug") with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3})