From 4b3d440580fff3c16a0b192bcb6bdf075b9ed38f Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 29 Oct 2025 10:05:25 +0100 Subject: [PATCH 01/43] Check gt4py-backend options in config (#291) * For release `2025.03.00` (#127) * updating 4d handling * debug 4d test data * more iter * moving ser_to_nc here * updating datatype in translate test * typing works * fix dict, lint * remove empty line * change from 4d to Nd * Expose `k_start` and `k_end` automatically for any FrozenStencil * Fix k_start + utest * lint * Fix for 2d stencils * Add threshold overrides to the multimodal metric * Always report results, add summary with one liners * Remove "mmr" from the keys * README in testing * Better Latex (?) * Better Latex (?) * fixing a typo that breaks bools in translate tests (#80) * Fix summary filename * Fix report, filename * Fix choosing right absolute difference for F32 * Make robust for NaN value * Detect when array have different dimensions, if only one dimension, collapse Clean up type infer and log work * Lint * Add rank 0 to the data * Check data exists for rank, skip & print if not * Fix bad logic on skip test for parallel * Verbose exported names * Make boilerplate calls more nimble * New option: `which_savepoint` Better error on bad output data Fix missing integer type check * QOL for mypy/flak8 type hints * Add SECONDS_PER_DAY as a constants following mixed precision standards * Lint * Cleanups in dace orchestration Readability improvements in dace orchestration including - early returns - spelling out variable names - fixing typos * Rename program -> dace_program * Make sure all constants adhere to the floating point precision set by the system * Move `is_float` to `dsl.typing` * Move Quantity to sub-directory + breakout the subcomponent * Fix tests * Lint * Remove `cp.ndarray` since cupy is optional * Restore workaround for optional cupy * "GFS" -> "UFS" * Cupy trick for metadata * Add comments for constant explanation * Describe 64/32-bit FloatFields * Make sure the `make_storage_data` respects the array dtype. * Fix logic for MultiModal metric and verbose it * Added an MPI all_reduce for quantities based on SUM operation to communicator.py * linted * Add initial skeleton of pytest test for all reduce * Added assertion tests for 1, 2 and 3D quantities passed through mpi_allreduce_sum * Linted * Added pytest.mark to skip test if mpi4py isn't available * lint changes * Addressed PR comments and added additional CPU backends to unit test * Added setters for various Quantity properties to enable setting of Quantity metadata and data properties. * Added function in QuantityMetadata class that allows copying of Metadata properties from one class to another. Subsequent Quantity setters that performed the copying of QuantityMetadata properties were removed * Expose all SG metric terms in grid_data * Add `Allreduce` and all MPI OP * Update utest * Fix `local_comm` * Fix utest * Enforce `comm_abc.Comm` into Communicator * Fix `comm` object in serial utest * Lint + `MPIComm` on testing architecture * Make sure the correct allocator backend is used for Quantities * Add in_place option for Allreduce * Cleanup ndsl/dsl/dace/utils.py (#96) * Fix typos * DaCeProgress: avoid double assignment of prefix * Add type hints/simplify kernel_theoretical_timing Adding type hints allowed to simplify `kernel_theoretical_timing`. * Fix merge * Hotfix for grid generation use of mpi operators * Merge examples/mpi/.gitignore into top-level .gitignore * Remove hard-coded __version__ numbers Removes hard-coded version numbers from `__init__` files. * Fixing a bunch of typos * hotfix netcdf version for dockerfiles * Updated version number in setup.py to reflect new release, 2025.01.00 * Adding in exception for compute domains with sizes less than or equal to halo size (#103) * Adding in exception for compute domains with less than 4 points to vector_halo_update method * Updated exception in communicator to compare halo size to compute domain size * linting * Moved domain size checker to SubtileGridSizer class method from_tile_params * Fix passing down ak/bk for pressure coefficients when they are available from an outside source (online model case) (#107) * [QOL] Logging, Type Hints and Quantity helpers (#108) * Log on rank 0 Docstrings & typi hints on logger Stencil Config has a `verbose` option On verbose: FrozenStencil log when run (in GT backends) * Update `config` in orchestrate call to solve type hint inconcistencies * Quantity helper `to_netcdf` with multi rank support * Automatic Int precision and stencil regeneration change (#104) * Added feature to enable automatic detection of integer precision. Should remove the need for i32/i64 declaration (although their functionality is still retained) and replace both with the regular Int type * change default rebuild state to false for get_factories * Merged Float and Int precision detection functions into one common path * Re-added old function to fulfil a PACE dependency * updated docstring * Added ability to declare 32 or 64 bit IntFields, overrulling the system precision * Added one dimensional bool fields * Fix error message in typing.py Co-authored-by: Florian Deconinck * output type for global_set_precision --------- Co-authored-by: Florian Deconinck * Bump DaCe to v1.0.1 (#109) Our current DaCe version is some commit from September 2024. Meanwhile DaCe matured to v1 and recently release v1.0.1. This brings the DaCe submodule to the latest stable release version. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Streamline linting workflow (#110) Linting should give fast feedback. The current workflow takes ~3mins where most of the time is spent installing (unnecessary) python packages. To run `pre-commit`, we only need the source files and `pre-commit` itself, which can be installed standalone. This brings runtime of the linting stage down to ~30 seconds. Other changes - update checkout action to v4 - update python setup action to v5 - change python version from 3.11.7 to 3.11 (any patch number will do) This is a follow-up of PR https://github.com/NOAA-GFDL/PyFV3/pull/40 in PyFV3. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * [FIX] Type hint for precision dependant Float, Int (#111) * Fix the type hint of Float, Int * Attempt using TypeAlias * Feature: Adding documentation (#97) * Added doc files * Adding image files to docs * Linting * Updated docs to reflect changes requested in PR 97 * Linting --------- Co-authored-by: Florian Deconinck * [Translate test] Save better reports & netCDF for multiple ranks on failure (#106) * Save reports & netCDF for multiple ranks on failure Fix multi modal threshold for parallel tests * Order field by name in NetCDF * Print all indices in logs. Sort by descernding ULP * Allow sorting by metrics and index with `--sort_report` option * Remove the `rank` froom SavepointCase. Access is done via `grid` * Some docstrings * Adds some quick capacities used in the post-radiation phase of the physics, including the Stefan-Boltzmann constant (#116) * add namelist option * add stephan boltzmann constant * lint * Apply suggestions from code review Change comments to docstring style Co-authored-by: Florian Deconinck --------- Co-authored-by: Florian Deconinck * Adding temperature of h2o triple point (#115) * add ttp * Update ndsl/constants.py Co-authored-by: Florian Deconinck * switch comments to docstrings for autodocs * lint --------- Co-authored-by: Florian Deconinck * [Feature] Porting workflow: enhancing errors readability (#114) * Save all fields (pass and fail) and organize them by field * Option `--no_report` to bypass logging & netcdf save Move logs per variable into a `details` subfolder * Order variable name in serialbox-to-netcdf * `extra_data_load` function to load savepoint data saved outside the canonical savepoint * Docs / Type Hint * Fixed typo in error statment --------- Co-authored-by: Charles Kropiewnicki * Feature: NetCDF output precision configurable (#117) * Removed hard-code of np.float32 from NetCDFMonitor transfer_type, replaced with Float type * Added multiple options for NetCDF precision * Added checking for use of 32 precision and float64 output * Using NumPy type instead of string in NetCDFMonitor precision variable * Added warning to netcdf_monitor.py for mismatch in precision settings * Forgot f-string in warn message of netcdf_monitor * Mixed Precision fixes and QOL (#118) * Ignore `.next` caches * CNST_OP20 is a true 64-bit * Translate: Fix reading parameters with the right precision * Multimodal metric: Skip reporting on expected values * Bad commit * Add license (Apache 2.0) (#105) Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Change deprecated `np.product()` to `np.prod()` (#120) Starting with numpy v1.25.0, `np.product()` is deprecated and `np.prod()` should be used instead. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Update GT4Py and DaCe to bring in refactored GT4Py/DaCe bridge that exposes control flow (#119) * Update DaCe to v1.0.2 DaCe v1.0.2 brings two fixes for DaCe transformations: one for DeadDataflowElimination and one for StateFusion. * Bump gt4py to include refactored gt4py/dace bridge * Test with modified pace pipeline - added this to re-trigger the new pace pipeline after limiting zarr to not install v3 (for now) because of breaking API changes. - added this note to re-trigger after fixing the pace pipeline to not pull requirements from `develop`. - added this note to ret-trigger after fixing the repo name * Revert "Test with modified pace pipeline" This reverts commit cd6560ea6129663d3445fafb36d02f03cb661b4d. --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Grid Mixed Precision and Coriolis force load (+ QOL) (#121) * Pass `dtype` down in allocator utils (gt4py_utils) * Allow coriolis forces to be read in * Edge factors are always 64-bit * Quantity QOL * Make sure to pass `dtype` to load the grid cleanly * Translate grid: load coriolis forces, area 64 is 64-bit * Bad merge * Typo * GEOS version of dz_min (#122) * Doc enhancment (#123) **Description** Port and adaptation of the initial commit of the documentation. Fixes issue https://github.com/NOAA-GFDL/NDSL/issues/113 **Checklist:** - [X] I have performed a self-review of my own code - [X] I have made corresponding changes to the documentation - [X] My changes generate no new warnings * Fix saving NetCDF for parallel translate test (#125) * Release candidate 2025.03.00 (#124) Co-authored-by: Florian Deconinck * Fix for bad merge of 7fdfa5 (#129) --------- Co-authored-by: Oliver Elbert Co-authored-by: Florian Deconinck Co-authored-by: Florian Deconinck Co-authored-by: Oliver Elbert Co-authored-by: Roman Cattaneo <> Co-authored-by: Christopher Kung Co-authored-by: Roman Cattaneo Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki Co-authored-by: Tobias Wicky-Pfund * check for backend existence in config * pc * update stale backend name --------- Co-authored-by: Frank Malatino <142349306+fmalatino@users.noreply.github.com> Co-authored-by: Oliver Elbert Co-authored-by: Florian Deconinck Co-authored-by: Florian Deconinck Co-authored-by: Oliver Elbert Co-authored-by: Christopher Kung Co-authored-by: Roman Cattaneo Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki Co-authored-by: Frank Malatino --- ndsl/dsl/stencil_config.py | 2 ++ tests/dsl/test_stencil_config.py | 4 ++-- tests/dsl/test_stencil_wrapper.py | 32 ++++++++++++++++++++++--------- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index b4b57183..f90db27a 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from typing import Any, Self +from gt4py.cartesian.backend import from_name as check_backend_existence from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.communicator import Communicator @@ -43,6 +44,7 @@ def __init__( if "gpu" not in backend and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") # GT4Py backend args + check_backend_existence(backend) self.backend = backend self.rebuild = rebuild self.validate_args = validate_args diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 89498695..5cdf79dc 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("rebuild", [True, False]) @pytest.mark.parametrize("format_source", [True, False]) @pytest.mark.parametrize("compare_to_numpy", [True, False]) -@pytest.mark.parametrize("backend", ["numpy", "gt_gpu"]) +@pytest.mark.parametrize("backend", ["numpy", "gt:gpu"]) def test_same_config_equal( backend: str, rebuild: bool, @@ -71,7 +71,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - backend="fake_backend", + backend="debug", rebuild=rebuild, validate_args=validate_args, format_source=format_source, diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 0458389e..49b7b2a9 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -313,24 +313,38 @@ def test_frozen_field_after_parameter() -> None: ) +@pytest.mark.parametrize("backend", ("numpy", "gt:gpu")) def test_backend_options( + backend: str, rebuild: bool = True, validate_args: bool = True, ) -> None: - backend = "numpy" expected_options = { - "backend": "numpy", - "rebuild": True, - "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "numpy": { + "backend": "numpy", + "rebuild": True, + "format_source": False, + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + }, + "gt:gpu": { + "backend": "gt:gpu", + "rebuild": True, + "device_sync": False, + "format_source": False, + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + }, } actual = get_stencil_config( - backend=backend, - rebuild=rebuild, - validate_args=validate_args, + backend=backend, rebuild=rebuild, validate_args=validate_args ).stencil_kwargs(func=copy_stencil) - assert actual == expected_options + expected = expected_options[backend] + assert actual == expected + + +def test_illegal_backend_options(): + with pytest.raises(ValueError): + get_stencil_config(backend="illegal") def get_mock_quantity(): From cd5dd0dba69b3058af78ee933b188f02c2b2416d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 29 Oct 2025 14:29:15 +0100 Subject: [PATCH 02/43] fix: allow any Comm object in ZarrMonitor (#292) This PR is fallout from adding types in PR #257 and #258. The `ZarrMonitor` provides a `DummyComm` which is instantiated in case no `Comm` object is given. The type of the `Comm` object in `ZarrMonitor` was wrongly limited to that `DummyComm`, which only broke when we attempted to update the submodule in `pace`. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/comm/__init__.py | 3 ++- ndsl/monitor/zarr_monitor.py | 50 +++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 580f881a..ec886bae 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -5,7 +5,7 @@ CachingRequestReader, CachingRequestWriter, ) -from .comm_abc import Comm, Request +from .comm_abc import Comm, ReductionOperator, Request __all__ = [ @@ -15,5 +15,6 @@ "CachingRequestReader", "CachingRequestWriter", "Comm", + "ReductionOperator", "Request", ] diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 99deec2b..31e73c62 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -7,6 +7,7 @@ import xarray as xr import ndsl.constants as constants +from ndsl.comm import Comm, ReductionOperator, Request from ndsl.comm.partitioner import Partitioner, subtile_slice from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy @@ -19,14 +20,16 @@ T = TypeVar("T") -class DummyComm: +class DummyComm(Comm[T]): + """Dummy comm object that works in single-core mode.""" + def Get_rank(self) -> int: return 0 def Get_size(self) -> int: return 1 - def bcast(self, value: T, root: int = 0) -> T: + def bcast(self, value: T | None, root: int = 0) -> T | None: assert root == 0, ( "DummyComm should only be used on a single core, " "so root should only ever be 0" @@ -36,6 +39,47 @@ def bcast(self, value: T, root: int = 0) -> T: def barrier(self) -> None: return + def Barrier(self) -> None: + raise NotImplementedError("DummyComm.Barrier") + + def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Scatter") + + def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Gather") + + def allgather(self, sendobj: T) -> list[T]: + raise NotImplementedError("DummyComm.allgather") + + def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Send") + + def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.sendrcv") + + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Isend") + + def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Recv") + + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Irecv") + + def Split(self, color, key) -> DummyComm: # type: ignore[no-untyped-def] + raise NotImplementedError("DummyComm.Split") + + def allreduce( + self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP + ) -> T: + raise NotImplementedError("DummyComm.allreduce") + + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + raise NotImplementedError("DummyComm.Allreduce") + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + raise NotImplementedError("DummyComm.Allreduce_inplace") + class ZarrMonitor: """ @@ -47,7 +91,7 @@ def __init__( store: str | zarr.storage.MutableMapping, partitioner: Partitioner, mode: str = "w", - mpi_comm: DummyComm | None = None, + mpi_comm: Comm | None = None, ) -> None: """Create a ZarrMonitor. From 103e63385f314df507cb279fd8a3ea288084575c Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Thu, 30 Oct 2025 08:54:51 +0100 Subject: [PATCH 03/43] Patch domain checks to only happen once (#293) * For release `2025.03.00` (#127) * updating 4d handling * debug 4d test data * more iter * moving ser_to_nc here * updating datatype in translate test * typing works * fix dict, lint * remove empty line * change from 4d to Nd * Expose `k_start` and `k_end` automatically for any FrozenStencil * Fix k_start + utest * lint * Fix for 2d stencils * Add threshold overrides to the multimodal metric * Always report results, add summary with one liners * Remove "mmr" from the keys * README in testing * Better Latex (?) * Better Latex (?) * fixing a typo that breaks bools in translate tests (#80) * Fix summary filename * Fix report, filename * Fix choosing right absolute difference for F32 * Make robust for NaN value * Detect when array have different dimensions, if only one dimension, collapse Clean up type infer and log work * Lint * Add rank 0 to the data * Check data exists for rank, skip & print if not * Fix bad logic on skip test for parallel * Verbose exported names * Make boilerplate calls more nimble * New option: `which_savepoint` Better error on bad output data Fix missing integer type check * QOL for mypy/flak8 type hints * Add SECONDS_PER_DAY as a constants following mixed precision standards * Lint * Cleanups in dace orchestration Readability improvements in dace orchestration including - early returns - spelling out variable names - fixing typos * Rename program -> dace_program * Make sure all constants adhere to the floating point precision set by the system * Move `is_float` to `dsl.typing` * Move Quantity to sub-directory + breakout the subcomponent * Fix tests * Lint * Remove `cp.ndarray` since cupy is optional * Restore workaround for optional cupy * "GFS" -> "UFS" * Cupy trick for metadata * Add comments for constant explanation * Describe 64/32-bit FloatFields * Make sure the `make_storage_data` respects the array dtype. * Fix logic for MultiModal metric and verbose it * Added an MPI all_reduce for quantities based on SUM operation to communicator.py * linted * Add initial skeleton of pytest test for all reduce * Added assertion tests for 1, 2 and 3D quantities passed through mpi_allreduce_sum * Linted * Added pytest.mark to skip test if mpi4py isn't available * lint changes * Addressed PR comments and added additional CPU backends to unit test * Added setters for various Quantity properties to enable setting of Quantity metadata and data properties. * Added function in QuantityMetadata class that allows copying of Metadata properties from one class to another. Subsequent Quantity setters that performed the copying of QuantityMetadata properties were removed * Expose all SG metric terms in grid_data * Add `Allreduce` and all MPI OP * Update utest * Fix `local_comm` * Fix utest * Enforce `comm_abc.Comm` into Communicator * Fix `comm` object in serial utest * Lint + `MPIComm` on testing architecture * Make sure the correct allocator backend is used for Quantities * Add in_place option for Allreduce * Cleanup ndsl/dsl/dace/utils.py (#96) * Fix typos * DaCeProgress: avoid double assignment of prefix * Add type hints/simplify kernel_theoretical_timing Adding type hints allowed to simplify `kernel_theoretical_timing`. * Fix merge * Hotfix for grid generation use of mpi operators * Merge examples/mpi/.gitignore into top-level .gitignore * Remove hard-coded __version__ numbers Removes hard-coded version numbers from `__init__` files. * Fixing a bunch of typos * hotfix netcdf version for dockerfiles * Updated version number in setup.py to reflect new release, 2025.01.00 * Adding in exception for compute domains with sizes less than or equal to halo size (#103) * Adding in exception for compute domains with less than 4 points to vector_halo_update method * Updated exception in communicator to compare halo size to compute domain size * linting * Moved domain size checker to SubtileGridSizer class method from_tile_params * Fix passing down ak/bk for pressure coefficients when they are available from an outside source (online model case) (#107) * [QOL] Logging, Type Hints and Quantity helpers (#108) * Log on rank 0 Docstrings & typi hints on logger Stencil Config has a `verbose` option On verbose: FrozenStencil log when run (in GT backends) * Update `config` in orchestrate call to solve type hint inconcistencies * Quantity helper `to_netcdf` with multi rank support * Automatic Int precision and stencil regeneration change (#104) * Added feature to enable automatic detection of integer precision. Should remove the need for i32/i64 declaration (although their functionality is still retained) and replace both with the regular Int type * change default rebuild state to false for get_factories * Merged Float and Int precision detection functions into one common path * Re-added old function to fulfil a PACE dependency * updated docstring * Added ability to declare 32 or 64 bit IntFields, overrulling the system precision * Added one dimensional bool fields * Fix error message in typing.py Co-authored-by: Florian Deconinck * output type for global_set_precision --------- Co-authored-by: Florian Deconinck * Bump DaCe to v1.0.1 (#109) Our current DaCe version is some commit from September 2024. Meanwhile DaCe matured to v1 and recently release v1.0.1. This brings the DaCe submodule to the latest stable release version. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Streamline linting workflow (#110) Linting should give fast feedback. The current workflow takes ~3mins where most of the time is spent installing (unnecessary) python packages. To run `pre-commit`, we only need the source files and `pre-commit` itself, which can be installed standalone. This brings runtime of the linting stage down to ~30 seconds. Other changes - update checkout action to v4 - update python setup action to v5 - change python version from 3.11.7 to 3.11 (any patch number will do) This is a follow-up of PR https://github.com/NOAA-GFDL/PyFV3/pull/40 in PyFV3. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * [FIX] Type hint for precision dependant Float, Int (#111) * Fix the type hint of Float, Int * Attempt using TypeAlias * Feature: Adding documentation (#97) * Added doc files * Adding image files to docs * Linting * Updated docs to reflect changes requested in PR 97 * Linting --------- Co-authored-by: Florian Deconinck * [Translate test] Save better reports & netCDF for multiple ranks on failure (#106) * Save reports & netCDF for multiple ranks on failure Fix multi modal threshold for parallel tests * Order field by name in NetCDF * Print all indices in logs. Sort by descernding ULP * Allow sorting by metrics and index with `--sort_report` option * Remove the `rank` froom SavepointCase. Access is done via `grid` * Some docstrings * Adds some quick capacities used in the post-radiation phase of the physics, including the Stefan-Boltzmann constant (#116) * add namelist option * add stephan boltzmann constant * lint * Apply suggestions from code review Change comments to docstring style Co-authored-by: Florian Deconinck --------- Co-authored-by: Florian Deconinck * Adding temperature of h2o triple point (#115) * add ttp * Update ndsl/constants.py Co-authored-by: Florian Deconinck * switch comments to docstrings for autodocs * lint --------- Co-authored-by: Florian Deconinck * [Feature] Porting workflow: enhancing errors readability (#114) * Save all fields (pass and fail) and organize them by field * Option `--no_report` to bypass logging & netcdf save Move logs per variable into a `details` subfolder * Order variable name in serialbox-to-netcdf * `extra_data_load` function to load savepoint data saved outside the canonical savepoint * Docs / Type Hint * Fixed typo in error statment --------- Co-authored-by: Charles Kropiewnicki * Feature: NetCDF output precision configurable (#117) * Removed hard-code of np.float32 from NetCDFMonitor transfer_type, replaced with Float type * Added multiple options for NetCDF precision * Added checking for use of 32 precision and float64 output * Using NumPy type instead of string in NetCDFMonitor precision variable * Added warning to netcdf_monitor.py for mismatch in precision settings * Forgot f-string in warn message of netcdf_monitor * Mixed Precision fixes and QOL (#118) * Ignore `.next` caches * CNST_OP20 is a true 64-bit * Translate: Fix reading parameters with the right precision * Multimodal metric: Skip reporting on expected values * Bad commit * Add license (Apache 2.0) (#105) Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Change deprecated `np.product()` to `np.prod()` (#120) Starting with numpy v1.25.0, `np.product()` is deprecated and `np.prod()` should be used instead. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Update GT4Py and DaCe to bring in refactored GT4Py/DaCe bridge that exposes control flow (#119) * Update DaCe to v1.0.2 DaCe v1.0.2 brings two fixes for DaCe transformations: one for DeadDataflowElimination and one for StateFusion. * Bump gt4py to include refactored gt4py/dace bridge * Test with modified pace pipeline - added this to re-trigger the new pace pipeline after limiting zarr to not install v3 (for now) because of breaking API changes. - added this note to re-trigger after fixing the pace pipeline to not pull requirements from `develop`. - added this note to ret-trigger after fixing the repo name * Revert "Test with modified pace pipeline" This reverts commit cd6560ea6129663d3445fafb36d02f03cb661b4d. --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> * Grid Mixed Precision and Coriolis force load (+ QOL) (#121) * Pass `dtype` down in allocator utils (gt4py_utils) * Allow coriolis forces to be read in * Edge factors are always 64-bit * Quantity QOL * Make sure to pass `dtype` to load the grid cleanly * Translate grid: load coriolis forces, area 64 is 64-bit * Bad merge * Typo * GEOS version of dz_min (#122) * Doc enhancment (#123) **Description** Port and adaptation of the initial commit of the documentation. Fixes issue https://github.com/NOAA-GFDL/NDSL/issues/113 **Checklist:** - [X] I have performed a self-review of my own code - [X] I have made corresponding changes to the documentation - [X] My changes generate no new warnings * Fix saving NetCDF for parallel translate test (#125) * Release candidate 2025.03.00 (#124) Co-authored-by: Florian Deconinck * Fix for bad merge of 7fdfa5 (#129) --------- Co-authored-by: Oliver Elbert Co-authored-by: Florian Deconinck Co-authored-by: Florian Deconinck Co-authored-by: Oliver Elbert Co-authored-by: Roman Cattaneo <> Co-authored-by: Christopher Kung Co-authored-by: Roman Cattaneo Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki Co-authored-by: Tobias Wicky-Pfund * check domain size args only once * review & test --------- Co-authored-by: Frank Malatino <142349306+fmalatino@users.noreply.github.com> Co-authored-by: Oliver Elbert Co-authored-by: Florian Deconinck Co-authored-by: Florian Deconinck Co-authored-by: Oliver Elbert Co-authored-by: Christopher Kung Co-authored-by: Roman Cattaneo Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Co-authored-by: Charles Kropiewnicki Co-authored-by: Frank Malatino --- ndsl/dsl/stencil.py | 14 ++++++++++++-- tests/dsl/test_stencil.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index f557c3b3..c1fd297e 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -290,6 +290,8 @@ def __init__( else: self._timing_collector = timing_collector + self._arguments_already_checked = False + if externals is None: externals = {} self.externals = externals @@ -406,7 +408,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") - self._validate_quantity_sizes(*args, **kwargs) + if ( + not self._arguments_already_checked + and self.stencil_config.compilation_config.validate_args + ): + self._validate_quantity_sizes(*args, **kwargs) # Marshal arguments args_list = list(args) @@ -430,7 +436,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: ndsl_debugger.track_data(all_args, self._func_qualname, is_in=True) # Execute stencil - if self.stencil_config.compilation_config.validate_args: + if ( + not self._arguments_already_checked + and self.stencil_config.compilation_config.validate_args + ): if __debug__ and "origin" in kwargs: raise TypeError("origin cannot be passed to FrozenStencil call") if __debug__ and "domain" in kwargs: @@ -443,6 +452,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: validate_args=True, exec_info=self._timing_collector.exec_info, ) + self._arguments_already_checked = True else: self.stencil_object.run( **args_as_kwargs, diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 7130853e..4daa401c 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -156,3 +156,28 @@ def test_stencil_2D_temporaries() -> None: ) stencil(quantity) assert (quantity.data[1, 1, :] == 21.0).all() + + +@pytest.mark.parametrize( + "iterations", + [2, 1], +) +def test_validation_call_count(iterations: tuple[int]): + domain = (2, 2, 5) + quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + stencil_config = StencilConfig( + compilation_config=CompilationConfig(backend="numpy", rebuild=True) + ) + stencil = FrozenStencil( + copy_stencil, + origin=(0, 0, 0), + domain=domain, + stencil_config=stencil_config, + ) + # with expectation: + counting_mock = MagicMock() + with patch.object(FrozenStencil, "_validate_quantity_sizes", counting_mock): + for _i in range(iterations): + stencil(quantity, quantity) + + assert counting_mock.call_count == 1 From bdbf15763565fd5cbe34e634eba708a4e7cdfbb1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Oct 2025 09:25:25 +0100 Subject: [PATCH 04/43] BREAKING CHANGE: change constructor of `QuantityFactory` (#228) * Breaking change: QuantityFactory from GridSizer and backend name Change `QuantityFactory` to initialize from a `GridSizer` (as previously) and a backend name (new). This effectively hides the previous `numpy` argument, which is effectively an internal allocator that users shouldn't need to know about. It's basically what `from_backend()` was doing before (which is now obsolete and was thus removed). This is a BREAKING CHANGE and users will need to update their codes if they instantiated QuantityFactories themselves. For users relying on the `boilerplate` module, no changes need to happen. * Keep QuantityFactory.from_backend() with a deprecation warning * Extended docstings This is mainly to force a new run of the pyshild workflow now that pyshield tests are exclusively using `QuantityFactory.from_backend()` which is compatible with changes proposed in this PR. * More updates to docstrings * fixup after rebase * Unrelated: tests are supposed to return `None` * fixup: move method back to current place --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/boilerplate.py | 2 +- ndsl/grid/generation.py | 2 +- ndsl/initialization/allocator.py | 98 ++++++++++------------ ndsl/stencils/testing/grid.py | 4 +- tests/grid/test_eta.py | 6 +- tests/initialization/test_allocator.py | 17 +--- tests/stree_optimizer/test_optimization.py | 14 ++-- tests/test_boilerplate.py | 4 +- tests/test_dimension_sizer.py | 54 +++++++----- 9 files changed, 94 insertions(+), 107 deletions(-) diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index d17d6349..6d485887 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -69,7 +69,7 @@ def _get_factories( grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) - quantity_factory = QuantityFactory.from_backend(sizer, backend) + quantity_factory = QuantityFactory(sizer, backend=backend) return stencil_factory, quantity_factory diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 1bc37bc5..7fd6eb7d 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -502,7 +502,7 @@ def from_tile_sizing( n_halo=N_HALO_DEFAULT, layout=communicator.partitioner.tile.layout, ) - quantity_factory = QuantityFactory.from_backend(sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) return cls( quantity_factory=quantity_factory, communicator=communicator, diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index e9d8857e..f421c712 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -13,13 +13,13 @@ from ndsl.quantity import Quantity, QuantityHaloSpec -class StorageNumpy: +class _Allocator: def __init__(self, backend: str) -> None: - """Initialize an object which behaves like the numpy module, but uses - gt4py storage objects for zeros, ones, and empty. + """ + Initialize an object that provides gt4py storage objects for zeros(), ones(), and empty(). Args: - backend: gt4py backend + backend: GT4Py backend name used for performance-optimized allocation. """ self.backend = backend @@ -34,19 +34,18 @@ def zeros(self, *args: Any, **kwargs: Any) -> np.ndarray: class QuantityFactory: - def __init__( # type: ignore - self, sizer: GridSizer, numpy, *, silence_deprecation_warning: bool = False - ) -> None: - if not silence_deprecation_warning: - warnings.warn( - "Usage of QuantityFactory(sizer, numpy) is discouraged and will change " - "in the next release. Use QuantityFactory.from_backend(sizer, backend) " - "instead for a stable experience across the release.", - DeprecationWarning, - 2, - ) - self.sizer: GridSizer = sizer - self._numpy = numpy + def __init__(self, sizer: GridSizer, *, backend: str) -> None: + """ + Initialize a QuantityFactory from a GridSizer and a GT4Py backend name. + + Args: + sizer: GridSizer object that determines the array sizes. + backend: GT4Py backend name used for performance-optimized allocation. + """ + self.sizer = sizer + self.backend = backend + + self._allocator = _Allocator(self.backend) def set_extra_dim_lengths(self, **kwargs: Any) -> None: """ @@ -95,21 +94,22 @@ def add_data_dimensions( @classmethod def from_backend(cls, sizer: GridSizer, backend: str) -> QuantityFactory: - """Initialize a QuantityFactory to use a specific gt4py backend. + """Initialize a QuantityFactory to use a specific GT4Py backend. + + Note: This method is deprecated. Please change your code to use the + constructor instead. Args: - sizer: object which determines array sizes - backend: gt4py backend + sizer: GridSizer object that determines the array sizes. + backend: GT4Py backend name used for performance-optimized allocation. """ - numpy = StorageNumpy(backend) - # Don't print the deprecation warning in this case - return cls(sizer, numpy, silence_deprecation_warning=True) - - def _backend(self) -> str | None: - if isinstance(self._numpy, StorageNumpy): - return self._numpy.backend - - return None + warnings.warn( + "QuantityFactory.from_backend(sizer, backend) is deprecated. Use " + "QuantityFactory(sizer, backend=backend) instead.", + DeprecationWarning, + stacklevel=2, + ) + return cls(sizer, backend=backend) def empty( self, @@ -119,15 +119,11 @@ def empty( *, allow_mismatch_float_precision: bool = False, ) -> Quantity: - """Allocate a Quantity - values are random. + """Allocate a Quantity and fill it with uninitialized (undefined) values. Equivalent to `numpy.empty`""" return self._allocate( - self._numpy.empty, - dims, - units, - dtype, - allow_mismatch_float_precision, + self._allocator.empty, dims, units, dtype, allow_mismatch_float_precision ) def zeros( @@ -142,11 +138,7 @@ def zeros( Equivalent to `numpy.zeros`""" return self._allocate( - self._numpy.zeros, - dims, - units, - dtype, - allow_mismatch_float_precision, + self._allocator.zeros, dims, units, dtype, allow_mismatch_float_precision ) def ones( @@ -161,11 +153,7 @@ def ones( Equivalent to `numpy.ones`""" return self._allocate( - self._numpy.ones, - dims, - units, - dtype, - allow_mismatch_float_precision, + self._allocator.ones, dims, units, dtype, allow_mismatch_float_precision ) def full( @@ -177,11 +165,11 @@ def full( *, allow_mismatch_float_precision: bool = False, ) -> Quantity: - """Allocate a Quantity and fill it with the value. + """Allocate a Quantity and fill it with the given value. Equivalent to `numpy.full`""" quantity = self._allocate( - self._numpy.empty, + self._allocator.empty, dims, units, dtype, @@ -199,10 +187,10 @@ def from_array( allow_mismatch_float_precision: bool = False, ) -> Quantity: """ - Create a Quantity from a numpy array. + Create a Quantity from values in the `data` array. - That numpy array must correspond to the correct shape and extent - for the given dims. + This copies the values of `data` into the resulting Quantity. The data + array thus must correspond to the correct shape and extent for the given dims. """ base = self.zeros( dims=dims, @@ -222,10 +210,12 @@ def from_compute_array( allow_mismatch_float_precision: bool = False, ) -> Quantity: """ - Create a Quantity from a numpy array. + Create a Quantity from values of the compute domain. - That numpy array must correspond to the correct shape and extent - of the compute domain for the given dims. + This function will allocate the full Quantity (including potential + halo points) to zero. The values of `data` are then copied into + the compute domain. That numpy array must correspond to the correct + shape and extent of the compute domain for the given dims. """ base = self.zeros( dims=dims, @@ -269,7 +259,7 @@ def _allocate( units=units, origin=origin, extent=extent, - gt4py_backend=self._backend(), + gt4py_backend=self.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, number_of_halo_points=self.sizer.n_halo, ) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 41c7d356..6437f561 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -159,9 +159,7 @@ def sizer(self): @property def quantity_factory(self) -> QuantityFactory: if self._quantity_factory is None: - self._quantity_factory = QuantityFactory.from_backend( - self.sizer, backend=self.backend - ) + self._quantity_factory = QuantityFactory(self.sizer, backend=self.backend) return self._quantity_factory def make_quantity( diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index 4090b3b5..d50e2fd9 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -58,7 +58,7 @@ def test_set_hybrid_pressure_coefficients_correct(levels): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) metric_terms = MetricTerms( quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file @@ -106,7 +106,7 @@ def test_set_hybrid_pressure_coefficients_nofile(): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) with pytest.raises(ValueError, match=f"eta file {eta_file} does not exist"): MetricTerms( @@ -149,7 +149,7 @@ def test_set_hybrid_pressure_coefficients_not_mono(): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer, backend=backend) with pytest.raises(ValueError, match="ETA values are not monotonically increasing"): MetricTerms( diff --git a/tests/initialization/test_allocator.py b/tests/initialization/test_allocator.py index 3ee01e43..dad8d407 100644 --- a/tests/initialization/test_allocator.py +++ b/tests/initialization/test_allocator.py @@ -1,19 +1,8 @@ -import warnings - -import numpy as np import pytest from ndsl import QuantityFactory -def test_QuantityFactory_constructor_warns() -> None: - with pytest.warns( - DeprecationWarning, - match="Usage of QuantityFactory.* is discouraged and will change", - ): - QuantityFactory(None, np) - - # Make sure no warnings are emitted if users use `QuantityFactory.from_backend()` - with warnings.catch_warnings(): - warnings.simplefilter("error") - QuantityFactory.from_backend(None, "numpy") +def test_QuantityFactory_from_backend_warns() -> None: + with pytest.deprecated_call(): + QuantityFactory.from_backend(None, backend="numpy") diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index fb32a35d..6ed29479 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_optimization.py @@ -6,18 +6,18 @@ from ndsl.dsl.typing import FloatField -def stencil_A(in_field: FloatField, out_field: FloatField): +def stencil_A(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = in_field -def stencil_B(in_field: FloatField, out_field: FloatField): +def stencil_B(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = out_field + in_field * 3 class TriviallyMergeableCode: - def __init__(self, stencil_factory: StencilFactory): + def __init__(self, stencil_factory: StencilFactory) -> None: orchestrate(obj=self, config=stencil_factory.config.dace_config) self.stencil_A = stencil_factory.from_dims_halo( func=stencil_A, @@ -28,15 +28,15 @@ def __init__(self, stencil_factory: StencilFactory): compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - def __call__(self, in_field: FloatField, out_field: FloatField): + def __call__(self, in_field: FloatField, out_field: FloatField) -> None: self.stencil_A(in_field, out_field) self.stencil_B(in_field, out_field) -def test_stree_roundtrip_no_opt(): +def test_stree_roundtrip_no_opt() -> None: """Dev Note: - The below code sucessfully merges top level K loop (2 loops) + The below code successfully merges top level K loop (2 loops) How do we test it?! Running doesn't test merging and the compilation is a near-black box. We could reach in the `dace_config.compiled_sdfg` cache but it's keyed on the dace.program and if we can reach the program @@ -46,7 +46,7 @@ def test_stree_roundtrip_no_opt(): Test is deactivated for now""" - return True + return domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend="dace:cpu" diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index b531453d..0e00d459 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -45,7 +45,7 @@ def test_boilerplate_import_numpy(): # Ensure backend is propagated to StencilFactory and QuantityFactory assert stencil_factory.backend == "numpy" - assert quantity_factory._backend() == "numpy" + assert quantity_factory.backend == "numpy" _copy_ops(stencil_factory, quantity_factory) @@ -64,7 +64,7 @@ def test_boilerplate_import_orchestrated_cpu(): # Ensure backend is propagated to StencilFactory and QuantityFactory assert stencil_factory.backend == "dace:cpu" - assert quantity_factory._backend() == "dace:cpu" + assert quantity_factory.backend == "dace:cpu" _copy_ops(stencil_factory, quantity_factory) diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index ebad25ef..9d78a4e0 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -2,7 +2,7 @@ import pytest -from ndsl import QuantityFactory, SubtileGridSizer +from ndsl import GridSizer, QuantityFactory, SubtileGridSizer from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -55,7 +55,7 @@ def extra_dimension_lengths(): @pytest.fixture def namelist(nx_tile, ny_tile, nz, layout): - namelist = { + return { "fv_core_nml": { "npx": nx_tile + 1, "npy": ny_tile + 1, @@ -63,13 +63,14 @@ def namelist(nx_tile, ny_tile, nz, layout): "layout": layout, } } - return namelist @pytest.fixture(params=["from_namelist", "from_tile_params"]) -def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths): +def sizer( + request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths +) -> GridSizer: if request.param == "from_tile_params": - sizer = SubtileGridSizer.from_tile_params( + return SubtileGridSizer.from_tile_params( nx_tile=nx_tile, ny_tile=ny_tile, nz=nz, @@ -77,11 +78,11 @@ def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengt layout=layout, data_dimensions=extra_dimension_lengths, ) - elif request.param == "from_namelist": - sizer = SubtileGridSizer.from_namelist(namelist) - else: - raise NotImplementedError() - return sizer + + if request.param == "from_namelist": + return SubtileGridSizer.from_namelist(namelist) + + raise NotImplementedError() @pytest.fixture @@ -109,7 +110,7 @@ def dtype(request): "z_y_x", ] ) -def dim_case(request, nx, ny, nz): +def dim_case(request, nx, ny, nz) -> DimCase: if request.param == "x_only": return DimCase( (X_DIM,), @@ -117,32 +118,38 @@ def dim_case(request, nx, ny, nz): (nx,), (2 * N_HALO_DEFAULT + nx + 1,), ) - elif request.param == "x_interface_only": + + if request.param == "x_interface_only": return DimCase( (X_INTERFACE_DIM,), (N_HALO_DEFAULT,), (nx + 1,), (2 * N_HALO_DEFAULT + nx + 1,), ) - elif request.param == "y_only": + + if request.param == "y_only": return DimCase( (Y_DIM,), (N_HALO_DEFAULT,), (ny,), (2 * N_HALO_DEFAULT + ny + 1,), ) - elif request.param == "y_interface_only": + + if request.param == "y_interface_only": return DimCase( (Y_INTERFACE_DIM,), (N_HALO_DEFAULT,), (ny + 1,), (2 * N_HALO_DEFAULT + ny + 1,), ) - elif request.param == "z_only": + + if request.param == "z_only": return DimCase((Z_DIM,), (0,), (nz,), (nz + 1,)) - elif request.param == "z_interface_only": + + if request.param == "z_interface_only": return DimCase((Z_INTERFACE_DIM,), (0,), (nz + 1,), (nz + 1,)) - elif request.param == "x_y": + + if request.param == "x_y": return DimCase( ( X_DIM, @@ -155,7 +162,8 @@ def dim_case(request, nx, ny, nz): 2 * N_HALO_DEFAULT + ny + 1, ), ) - elif request.param == "z_y_x": + + if request.param == "z_y_x": return DimCase( ( Z_DIM, @@ -171,6 +179,8 @@ def dim_case(request, nx, ny, nz): ), ) + raise NotImplementedError() + @pytest.mark.cpu_only def test_subtile_dimension_sizer_origin(sizer, dim_case): @@ -191,7 +201,7 @@ def test_subtile_dimension_sizer_shape(sizer, dim_case): def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.zeros(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -202,7 +212,7 @@ def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): def test_allocator_ones(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.ones(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -213,7 +223,7 @@ def test_allocator_ones(numpy, sizer, dim_case, units, dtype): def test_allocator_empty(sizer, dim_case, units, dtype): - allocator = QuantityFactory.from_backend(sizer, "numpy") + allocator = QuantityFactory(sizer, backend="numpy") quantity = allocator.empty(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -223,7 +233,7 @@ def test_allocator_empty(sizer, dim_case, units, dtype): def test_allocator_data_dimensions_operations(sizer): - quantity_factory = QuantityFactory.from_backend(sizer, "numpy") + quantity_factory = QuantityFactory(sizer, backend="numpy") quantity_factory.add_data_dimensions({"D0": 11}) assert "D0" in quantity_factory.sizer.data_dimensions.keys() assert quantity_factory.sizer.data_dimensions["D0"] == 11 From dd620f0052fb0602f463d6e770e542037a3a30e7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Oct 2025 14:47:46 +0100 Subject: [PATCH 05/43] BREAKING CHANGE: remove ndsl/exceptions (#281) * BREAKING CHANGE: remove ndsl/exceptions The module has been deprecated last release and will be removed with this release. * fixup: documentation --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- docs/docstrings/top/exceptions.md | 3 --- mkdocs.yml | 1 - ndsl/__init__.py | 2 -- ndsl/exceptions.py | 16 ---------------- tests/mpi/test_mpi_mock.py | 2 +- tests/test_exceptions.py | 8 -------- 6 files changed, 1 insertion(+), 31 deletions(-) delete mode 100644 docs/docstrings/top/exceptions.md delete mode 100644 ndsl/exceptions.py delete mode 100644 tests/test_exceptions.py diff --git a/docs/docstrings/top/exceptions.md b/docs/docstrings/top/exceptions.md deleted file mode 100644 index 8803fee0..00000000 --- a/docs/docstrings/top/exceptions.md +++ /dev/null @@ -1,3 +0,0 @@ -# exceptions - -::: exceptions diff --git a/mkdocs.yml b/mkdocs.yml index 1ca8682c..34692b3f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,7 +24,6 @@ nav: - "boilerplate": docstrings/top/boilerplate.md - "buffer": docstrings/top/buffer.md - "constants": docstrings/top/constants.md - - "exceptions": docstrings/top/exceptions.md - "filesystem": docstrings/top/filesystem.md - "io": docstrings/top/io.md - "logging": docstrings/top/logging.md diff --git a/ndsl/__init__.py b/ndsl/__init__.py index c7730282..215ee14c 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -19,7 +19,6 @@ from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig -from .exceptions import OutOfBoundsError from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .initialization import GridSizer, QuantityFactory, SubtileGridSizer @@ -62,7 +61,6 @@ "CompilationConfig", "RunMode", "StencilConfig", - "OutOfBoundsError", "HaloExchangeSpec", "HaloUpdater", "HaloUpdateRequest", diff --git a/ndsl/exceptions.py b/ndsl/exceptions.py deleted file mode 100644 index 4511ea69..00000000 --- a/ndsl/exceptions.py +++ /dev/null @@ -1,16 +0,0 @@ -# flake8: noqa -import warnings - -from ndsl.comm.local_comm import ConcurrencyError -from ndsl.units import UnitsError - - -class OutOfBoundsError(ValueError): - def __init__(self, *args) -> None: - warnings.warn( - "Usage of `OutOfBoundsError` is discouraged. The class will be " - "removed in the next version in favor of using the built-in `IndexError`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args) diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 0da2f21b..6436af02 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -3,7 +3,7 @@ from ndsl import DummyComm from ndsl.buffer import recv_buffer -from ndsl.exceptions import ConcurrencyError +from ndsl.comm.local_comm import ConcurrencyError from tests.mpi.mpi_comm import MPI diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py deleted file mode 100644 index f5d66eb7..00000000 --- a/tests/test_exceptions.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from ndsl import OutOfBoundsError - - -def test_OutOfBoundsError_is_deprecation() -> None: - with pytest.deprecated_call(): - OutOfBoundsError("This should trigger a deprecation warning.") From 7118bad6dfca34aea27f8604b69c678a9870f668 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Oct 2025 17:16:35 +0100 Subject: [PATCH 06/43] BREAKING CHANGE: remove deprecated environment variables (#282) Those environment variable were deprecated in the last release and will be removed with this release. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- docs/user/index.md | 4 ++-- ndsl/constants.py | 9 +-------- ndsl/dsl/__init__.py | 13 +------------ ndsl/dsl/dace/dace_config.py | 10 +--------- ndsl/logging.py | 9 +-------- 5 files changed, 6 insertions(+), 39 deletions(-) diff --git a/docs/user/index.md b/docs/user/index.md index 46ddc3eb..106ec159 100644 --- a/docs/user/index.md +++ b/docs/user/index.md @@ -12,13 +12,13 @@ NDSL tries to have sensible defaults. In cases you want tweak something, here ar ### Literal precision (float/int) -Unspecified integer and floating point literals (e.g. `42` and `3.1415`) default to 64-bit precision. This can be changed with the environment variable `PACE_FLOAT_PRECISION`. +Unspecified integer and floating point literals (e.g. `42` and `3.1415`) default to 64-bit precision. This can be changed with the environment variable `NDSL_LITERAL_PRECISION`. For mixed precision code, you can specify the "hard coded" precision with type hints and casts, e.g. ```python with computation(PARALLEL), interval(...): - # Either 32-bit or 64-bit depending on `PACE_FLOAT_PRECISION` + # Either 32-bit or 64-bit depending on `NDSL_LITERAL_PRECISION` my_int = 42 my_float = 3.1415 diff --git a/ndsl/constants.py b/ndsl/constants.py index 82d16d30..a02f70f6 100644 --- a/ndsl/constants.py +++ b/ndsl/constants.py @@ -20,14 +20,7 @@ class ConstantVersions(Enum): def _get_constant_version( default: Literal["GFDL", "UFS", "GEOS"] = "UFS", ) -> Literal["GFDL", "UFS", "GEOS"]: - if os.getenv("PACE_CONSTANTS", ""): - ndsl_log.warning("PACE_CONSTANTS is deprecated. Use NDSL_CONSTANTS instead.") - if os.getenv("NDSL_CONSTANTS", ""): - ndsl_log.warning( - "PACE_CONSTANTS and NDSL_CONSTANTS were both specified. NDSL_CONSTANTS will take precedence." - ) - - constants_as_str = os.getenv("NDSL_CONSTANTS", os.getenv("PACE_CONSTANTS", default)) + constants_as_str = os.getenv("NDSL_CONSTANTS", default) expected: list[Literal["GFDL", "UFS", "GEOS"]] = ["GFDL", "UFS", "GEOS"] if constants_as_str not in expected: diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 5fa508d2..f931129b 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -17,18 +17,7 @@ def _get_literal_precision(default: Literal["32", "64"] = "64") -> Literal["32", "64"]: - if os.getenv("PACE_FLOAT_PRECISION", ""): - ndsl_log.warning( - "PACE_FLOAT_PRECISION is deprecated. Use NDSL_LITERAL_PRECISION instead." - ) - if os.getenv("NDSL_LITERAL_PRECISION", ""): - ndsl_log.warning( - "PACE_FLOAT_PRECISION and NDSL_LOGLEVEL were both specified. NDSL_LITERAL_PRECISION will take precedence." - ) - - precision = os.getenv( - "NDSL_LITERAL_PRECISION", os.getenv("PACE_FLOAT_PRECISION", default) - ) + precision = os.getenv("NDSL_LITERAL_PRECISION", default) expected: list[Literal["32", "64"]] = ["32", "64"] if precision in expected: diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index d76e10da..66d43f17 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -15,7 +15,6 @@ from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import get_precision -from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector @@ -31,14 +30,7 @@ def _debug_dace_orchestration() -> bool: Debugging Dace orchestration deeper can be done by turning on `syncdebug`. We control this Dace configuration below with our own override. """ - if os.getenv("PACE_DACE_DEBUG", ""): - ndsl_log.warning("PACE_DACE_DEBUG is deprecated. Use NDSL_DACE_DEBUG instead.") - if os.getenv("NDSL_DACE_DEBUG", ""): - ndsl_log.warning( - "PACE_DACE_DEBUG and NDSL_DACE_DEBUG were both specified. NDSL_DACE_DEBUG will take precedence." - ) - - return os.getenv("NDSL_DACE_DEBUG", os.getenv("PACE_DACE_DEBUG", "False")) == "True" + return os.getenv("NDSL_DACE_DEBUG", "False") == "True" def _is_corner(rank: int, partitioner: Partitioner) -> bool: diff --git a/ndsl/logging.py b/ndsl/logging.py index 183054b6..7892a647 100644 --- a/ndsl/logging.py +++ b/ndsl/logging.py @@ -20,14 +20,7 @@ def _get_log_level(default: str = "info") -> str: - if os.getenv("PACE_LOGLEVEL", ""): - logging.warning("PACE_LOGLEVEL is deprecated. Use NDSL_LOGLEVEL instead.") - if os.getenv("NDSL_LOGLEVEL", ""): - logging.warning( - "PACE_LOGLEVEL and NDSL_LOGLEVEL were both specified. NDSL_LOGLEVEL will take precedence." - ) - - loglevel = os.getenv("NDSL_LOGLEVEL", os.getenv("PACE_LOGLEVEL", default)).lower() + loglevel = os.getenv("NDSL_LOGLEVEL", default).lower() if loglevel in AVAILABLE_LOG_LEVELS.keys(): return loglevel From d8027ea376752d6c80f24ece8fa38df23063e0d3 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Oct 2025 17:50:11 +0100 Subject: [PATCH 07/43] ci: specialize concurrency group per repo (#287) * ci: per repo concurrency group Note: using `${{ github.repository }}` sounds like a good idea. In practice, that doesn't play nice when the workflow is called from another repository because in that case, `github.repository` resolves to the calling repository. * fix file ending of called workflows --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .github/workflows/create-cache.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/create-cache.yaml b/.github/workflows/create-cache.yaml index 6d823662..f5e38a6b 100644 --- a/.github/workflows/create-cache.yaml +++ b/.github/workflows/create-cache.yaml @@ -14,7 +14,7 @@ on: # Cancel running jobs if there's a newer push concurrency: - group: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref }} + group: ndsl-${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: @@ -45,8 +45,8 @@ jobs: # GitHub Actions cache of pyFV3 test data pyFV3_test_data: - uses: NOAA-GFDL/pyFV3/.github/workflows/create_cache.yml@develop + uses: NOAA-GFDL/pyFV3/.github/workflows/create_cache.yaml@develop # GitHub Actions cache of pySHiELD test data pySHiELD_test_data: - uses: NOAA-GFDL/pySHiELD/.github/workflows/create_cache.yml@develop + uses: NOAA-GFDL/pySHiELD/.github/workflows/create_cache.yaml@develop From d5f3e54fa9010ef6864142bcecccbbc3e74928cf Mon Sep 17 00:00:00 2001 From: Janice Kim Date: Fri, 31 Oct 2025 06:18:44 -0400 Subject: [PATCH 08/43] Remove ndsl.Namelist (#297) * Removing ndsl.Namelist * Removing use_legacy_namelist flag functionality while keeping the flag itself. * - Removing ndsl.Namelist - Removing use_legacy_namelist flag functionality (while keeping the flag itself for now) * linting * Removing namelist.md and test_namelist.py --- docs/docstrings/top/namelist.md | 3 - ndsl/__init__.py | 2 - ndsl/namelist.py | 647 -------------------- ndsl/stencils/testing/conftest.py | 20 +- ndsl/stencils/testing/parallel_translate.py | 9 - tests/test_namelist.py | 8 - 6 files changed, 1 insertion(+), 688 deletions(-) delete mode 100644 docs/docstrings/top/namelist.md delete mode 100644 ndsl/namelist.py delete mode 100644 tests/test_namelist.py diff --git a/docs/docstrings/top/namelist.md b/docs/docstrings/top/namelist.md deleted file mode 100644 index c129312a..00000000 --- a/docs/docstrings/top/namelist.md +++ /dev/null @@ -1,3 +0,0 @@ -# namelist - -::: namelist diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 215ee14c..1c445feb 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -23,7 +23,6 @@ from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .initialization import GridSizer, QuantityFactory, SubtileGridSizer from .monitor.netcdf_monitor import NetCDFMonitor -from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport @@ -70,7 +69,6 @@ "SubtileGridSizer", "ndsl_log", "NetCDFMonitor", - "Namelist", "NullPerformanceCollector", "PerformanceCollector", "NullProfiler", diff --git a/ndsl/namelist.py b/ndsl/namelist.py deleted file mode 100644 index 00180ace..00000000 --- a/ndsl/namelist.py +++ /dev/null @@ -1,647 +0,0 @@ -import dataclasses -import warnings -from typing import Any, Self - -import f90nml - - -DEFAULT_INT = 0 -DEFAULT_STR = "" -DEFAULT_FLOAT = 0.0 -DEFAULT_BOOL = False - - -# Global set of namelist defaults, attached to class for namespacing and static typing -class NamelistDefaults: - layout = (1, 1) - grid_type = 0 - dx_const = 1000.0 - dy_const = 1000.0 - deglat = 15.0 - u_max = 350.0 - do_f3d = False - inline_q = False - do_skeb = False - """Save dissipation estimate""" - use_logp = False - moist_phys = True - check_negative = False - # gfdl_cloud_microphys.F90 - tau_r2g = 900.0 - """rain freezing during fast_sat""" - tau_smlt = 900.0 - """snow melting""" - tau_g2r = 600.0 - """graupel melting to rain""" - tau_imlt = 600.0 - """cloud ice melting""" - tau_i2s = 1000.0 - """cloud ice to snow auto - conversion""" - tau_l2r = 900.0 - """cloud water to rain auto - conversion""" - tau_g2v = 1200.0 - """graupel sublimation""" - tau_v2g = 21600.0 - """graupel deposition -- make it a slow process""" - sat_adj0 = 0.90 - """adjustment factor (0: no, 1: full) during fast_sat_adj""" - ql_gen = 1.0e-3 - """max new cloud water during remapping step if fast_sat_adj = .t.""" - ql_mlt = 2.0e-3 - """max value of cloud water allowed from melted cloud ice""" - qs_mlt = 1.0e-6 - """max cloud water due to snow melt""" - ql0_max = 2.0e-3 - """max cloud water value (auto converted to rain)""" - t_sub = 184.0 - """min temp for sublimation of cloud ice""" - qi_gen = 1.82e-6 - """max cloud ice generation during remapping step""" - qi_lim = 1.0 - """cloud ice limiter to prevent large ice build up""" - qi0_max = 1.0e-4 - """max cloud ice value (by other sources)""" - rad_snow = True - """consider snow in cloud fraction calculation""" - rad_rain = True - """consider rain in cloud fraction calculation""" - rad_graupel = True - """consider graupel in cloud fraction calculation""" - tintqs = False - """use temperature in the saturation mixing in PDF""" - dw_ocean = 0.10 - """base value for ocean""" - dw_land = 0.15 - """base value for subgrid deviation / variability over land""" - # cloud scheme 0 - ? - # 1: old fvgfs gfdl) mp implementation - # 2: binary cloud scheme (0 / 1) - icloud_f = 0 - cld_min = 0.05 - """minimum cloud fraction""" - tau_l2v = 300.0 - """cloud water to water vapor (evaporation)""" - tau_v2l = 90.0 - """water vapor to cloud water (condensation)""" - c2l_ord = 4 - regional = False - m_split = 0 - convert_ke = False - breed_vortex_inline = False - use_old_omega = True - use_logp = False - rf_fast = False - p_ref = 1e5 - """Surface pressure used to construct a horizontally-uniform reference""" - adiabatic = False - nf_omega = 1 - fv_sg_adj = -1 - n_sponge = 1 - fast_sat_adj = True - qc_crt = 5.0e-8 - """Minimum condensate mixing ratio to allow partial cloudiness""" - c_cracw = 0.8 - """Rain accretion efficiency""" - c_paut = 0.5 - """Autoconversion cloud water to rain (use 0.5 to reduce autoconversion)""" - c_pgacs = 0.01 - """Snow to graupel "accretion" eff. (was 0.1 in zetac)""" - c_psaci = 0.05 - """Accretion: cloud ice to snow (was 0.1 in zetac)""" - ccn_l = 300.0 - """CCN over land (cm^-3)""" - ccn_o = 100.0 - """CCN over ocean (cm^-3)""" - const_vg = False - """Fall velocity tuning constant of graupel""" - const_vi = False - """Fall velocity tuning constant of ice""" - const_vr = False - """Fall velocity tuning constant of rain water""" - const_vs = False - """Fall velocity tuning constant of snow""" - vi_fac = 1.0 - """if const_vi: 1/3""" - vs_fac = 1.0 - """if const_vs: 1.""" - vg_fac = 1.0 - """if const_vg: 2.""" - vr_fac = 1.0 - """if const_vr: 4.""" - de_ice = False - """To prevent excessive build-up of cloud ice from external sources""" - do_qa = True - """Do inline cloud fraction""" - do_sedi_heat = False - """Transport of heat in sedimentation""" - do_sedi_w = True - """Transport of vertical motion in sedimentation""" - fix_negative = True - """Fix negative water species""" - irain_f = 0 - """Cloud water to rain auto conversion scheme""" - mono_prof = False - """Perform terminal fall with mono ppm scheme""" - mp_time = 225.0 - """Maximum microphysics timestep (sec)""" - prog_ccn = False - """Do prognostic ccn (yi ming's method)""" - qi0_crt = 8e-05 - """Cloud ice to snow autoconversion threshold""" - qs0_crt = 0.003 - """Snow to graupel density threshold (0.6e-3 in purdue lin scheme)""" - rh_inc = 0.2 - """RH increment for complete evaporation of cloud water and cloud ice""" - rh_inr = 0.3 - """RH increment for minimum evaporation of rain""" - rthresh = 1e-05 - """Critical cloud drop radius (micrometers)""" - sedi_transport = True - """Transport of momentum in sedimentation""" - use_ppm = False - """Use ppm fall scheme""" - vg_max = 16.0 - """Maximum fall speed for graupel""" - vi_max = 1.0 - """Maximum fall speed for ice""" - vr_max = 16.0 - """Maximum fall speed for rain""" - vs_max = 2.0 - """Maximum fall speed for snow""" - z_slope_ice = True - """Use linear mono slope for autoconversions""" - z_slope_liq = True - """Use linear mono slope for autoconversions""" - tice = 273.16 - """set tice = 165. to turn off ice - phase phys (kessler emulator)""" - alin = 842.0 - """value for 'a' in lin1983""" - clin = 4.8 - """"c" in lin 1983, 4.8 -- > 6. (to enhance ql -- > qs)""" - mom4ice = False - lsm = 1 - redrag = False - isatmedmf = 0 - """which version of satmedmfvdif to use""" - dspheat = False - """flag for tke dissipative heating""" - xkzm_h = 1.0 - """background vertical diffusion for heat q over ocean""" - xkzm_m = 1.0 - """background vertical diffusion for momentum over ocean""" - xkzm_hl = 1.0 - """background vertical diffusion for heat q over land""" - xkzm_ml = 1.0 - """background vertical diffusion for momentum over land""" - xkzm_hi = 1.0 - """background vertical diffusion for heat q over ice""" - xkzm_mi = 1.0 - """background vertical diffusion for momentum over ice""" - xkzm_ho = 1.0 - """background vertical diffusion for heat q over ocean""" - xkzm_mo = 1.0 - """background vertical diffusion for momentum over ocean""" - xkzm_s = 1.0 - """sigma threshold for background mom. diffusion""" - xkzm_lim = 0.01 - """background vertical diffusion limit""" - xkzminv = 0.15 - """diffusivity in inversion layers""" - xkgdx = 25.0e3 - """background vertical diffusion threshold""" - rlmn = 30.0 - """lower-limiter on asymtotic mixing length in satmedmfdiff""" - rlmx = 300.0 - """upper-limiter on asymtotic mixing length in satmedmfdiff""" - do_dk_hb19 = False - """flag for using hb19 background diff formula in satmedmfdiff""" - cap_k0_land = True - """flag for applying limter on background diff in inversion layer over land in satmedmfdiff""" - ncld = 1 - """choice of cloud scheme""" - c0s_shal = 0.002 - """c_e for shallow convection (Han and Pan, 2011, eq(6))""" - c1_shal = 5.0e-4 - """conversion parameter of detrainment from liquid water into convetive precipitaiton""" - clam_shal = 0.3 - """conversion parameter of detrainment from liquid water into grid-scale cloud water""" - pgcon_shal = 0.55 - """control the reduction in momentum transport""" - asolfac_shal = 0.89 - """aerosol-aware parameter based on Lim & Hong (2012): asolfac= cx / c0s(=.002), cx = min([-0.7 ln(Nccn) + 24]*1.e-4, c0s), Nccn: CCN number concentration in cm^(-3), Until a realistic Nccn is provided, typical Nccns are assumed, as Nccn=100 for sea and Nccn=7000 for land""" - lsoil = 4 - """Number of soil levels in land surface model""" - sw_dynamics = False - """flag for turning on shallow water conditions in dyn core""" - - @classmethod - def as_dict(cls) -> dict: - return { - name: default - for name, default in cls.__dict__.items() - if not name.startswith("_") - } - - -@dataclasses.dataclass -class Namelist: - # data_set: Any - # date_out_of_range: str - # do_sst_pert: bool - # interp_oi_sst: bool - # no_anom_sst: bool - # sst_pert: float - # sst_pert_type: str - # use_daily: bool - # use_ncep_ice: bool - # use_ncep_sst: bool - # blocksize: int - # chksum_debug: bool - """ - note: dycore_only may not be used in this model - the same way it is in the Fortran version, watch for - consequences of these inconsistencies, or more closely - parallel the Fortran structure - """ - dycore_only: bool = DEFAULT_BOOL - # fdiag: float - # knob_ugwp_azdir: tuple[int, int, int, int] - # knob_ugwp_doaxyz: int - # knob_ugwp_doheat: int - # knob_ugwp_dokdis: int - # knob_ugwp_effac: tuple[int, int, int, int] - # knob_ugwp_ndx4lh: int - # knob_ugwp_solver: int - # knob_ugwp_source: tuple[int, int, int, int] - # knob_ugwp_stoch: tuple[int, int, int, int] - # knob_ugwp_version: int - # knob_ugwp_wvspec: tuple[int, int, int, int] - # launch_level: int - # reiflag: int - # reimax: float - # reimin: float - # rewmax: float - # rewmin: float - # atmos_nthreads: int - # calendar: Any - # current_date: Any - days: int = 0 - dt_atmos: int = DEFAULT_INT - # dt_ocean: Any - hours: int = 0 - # memuse_verbose: Any - minutes: int = 0 - # months: Any - # ncores_per_node: Any - seconds: int = 0 - # use_hyper_thread: Any - # max_axes: Any - # max_files: Any - # max_num_axis_sets: Any - # prepend_date: Any - # checker_tr: Any - # filtered_terrain: Any - # gfs_dwinds: Any - # levp: Any - # nt_checker: Any - # checksum_required: Any - # max_files_r: Any - # max_files_w: Any - # clock_grain: Any - # domains_stack_size: Any - # print_memory_usage: Any - a_imp: float = DEFAULT_FLOAT - # adjust_dry_mass: Any - beta: float = DEFAULT_FLOAT - # consv_am: Any - consv_te: float = DEFAULT_FLOAT - d2_bg: float = DEFAULT_FLOAT - d2_bg_k1: float = DEFAULT_FLOAT - d2_bg_k2: float = DEFAULT_FLOAT - d4_bg: float = DEFAULT_FLOAT - d_con: float = DEFAULT_FLOAT - d_ext: float = DEFAULT_FLOAT - dddmp: float = DEFAULT_FLOAT - delt_max: float = DEFAULT_FLOAT - # dnats: int - do_sat_adj: bool = DEFAULT_BOOL - do_vort_damp: bool = DEFAULT_BOOL - # dwind_2d: Any - # external_ic: Any - fill: bool = DEFAULT_BOOL - # fill_dp: bool - # fv_debug: Any - # gfs_phil: Any - hord_dp: int = DEFAULT_INT - hord_mt: int = DEFAULT_INT - hord_tm: int = DEFAULT_INT - hord_tr: int = DEFAULT_INT - hord_vt: int = DEFAULT_INT - hydrostatic: bool = DEFAULT_BOOL - # io_layout: Any - k_split: int = DEFAULT_INT - ke_bg: float = DEFAULT_FLOAT - kord_mt: int = DEFAULT_INT - kord_tm: int = DEFAULT_INT - kord_tr: int = DEFAULT_INT - kord_wz: int = DEFAULT_INT - layout: tuple[int, int] = (1, 1) - # make_nh: bool - # mountain: bool - n_split: int = DEFAULT_INT - # na_init: Any - # ncep_ic: Any - # nggps_ic: Any - nord: int = DEFAULT_INT - npx: int = DEFAULT_INT - npy: int = DEFAULT_INT - npz: int = DEFAULT_INT - ntiles: int = DEFAULT_INT - # nudge: Any - # nudge_qv: Any - nwat: int = DEFAULT_INT - p_fac: float = DEFAULT_FLOAT - # phys_hydrostatic: Any - # print_freq: Any - # range_warn: Any - # reset_eta: Any - rf_cutoff: float = DEFAULT_FLOAT - tau: float = DEFAULT_FLOAT - # tau_h2o: Any - # use_hydro_pressure: Any - vtdm4: float = DEFAULT_FLOAT - # warm_start: bool - z_tracer: bool = DEFAULT_BOOL - c_cracw: float = NamelistDefaults.c_cracw - c_paut: float = NamelistDefaults.c_paut - c_pgacs: float = NamelistDefaults.c_pgacs - c_psaci: float = NamelistDefaults.c_psaci - ccn_l: float = NamelistDefaults.ccn_l - ccn_o: float = NamelistDefaults.ccn_o - const_vg: bool = NamelistDefaults.const_vg - const_vi: bool = NamelistDefaults.const_vi - const_vr: bool = NamelistDefaults.const_vr - const_vs: bool = NamelistDefaults.const_vs - qc_crt: float = NamelistDefaults.qc_crt - vs_fac: float = NamelistDefaults.vs_fac - vg_fac: float = NamelistDefaults.vg_fac - vi_fac: float = NamelistDefaults.vi_fac - vr_fac: float = NamelistDefaults.vr_fac - de_ice: bool = NamelistDefaults.de_ice - do_qa: bool = NamelistDefaults.do_qa - do_sedi_heat: bool = NamelistDefaults.do_sedi_heat - do_sedi_w: bool = NamelistDefaults.do_sedi_w - fast_sat_adj: bool = NamelistDefaults.fast_sat_adj - fix_negative: bool = NamelistDefaults.fix_negative - irain_f: int = NamelistDefaults.irain_f - mono_prof: bool = NamelistDefaults.mono_prof - mp_time: float = NamelistDefaults.mp_time - prog_ccn: bool = NamelistDefaults.prog_ccn - qi0_crt: float = NamelistDefaults.qi0_crt - qs0_crt: float = NamelistDefaults.qs0_crt - rh_inc: float = NamelistDefaults.rh_inc - rh_inr: float = NamelistDefaults.rh_inr - # rh_ins: Any - rthresh: float = NamelistDefaults.rthresh - sedi_transport: bool = NamelistDefaults.sedi_transport - # use_ccn: Any - use_ppm: bool = NamelistDefaults.use_ppm - vg_max: float = NamelistDefaults.vg_max - vi_max: float = NamelistDefaults.vi_max - vr_max: float = NamelistDefaults.vr_max - vs_max: float = NamelistDefaults.vs_max - z_slope_ice: bool = NamelistDefaults.z_slope_ice - z_slope_liq: bool = NamelistDefaults.z_slope_liq - tice: float = NamelistDefaults.tice - alin: float = NamelistDefaults.alin - clin: float = NamelistDefaults.clin - mom4ice: bool = NamelistDefaults.mom4ice - lsm: int = NamelistDefaults.lsm - redrag: bool = NamelistDefaults.redrag - isatmedmf: int = NamelistDefaults.isatmedmf - dspheat: bool = NamelistDefaults.dspheat - xkzm_h: float = NamelistDefaults.xkzm_h - xkzm_m: float = NamelistDefaults.xkzm_m - xkzm_hl: float = NamelistDefaults.xkzm_hl - xkzm_ml: float = NamelistDefaults.xkzm_ml - xkzm_hi: float = NamelistDefaults.xkzm_hi - xkzm_mi: float = NamelistDefaults.xkzm_mi - xkzm_ho: float = NamelistDefaults.xkzm_ho - xkzm_mo: float = NamelistDefaults.xkzm_mo - xkzm_s: float = NamelistDefaults.xkzm_s - xkzm_lim: float = NamelistDefaults.xkzm_lim - xkzminv: float = NamelistDefaults.xkzminv - xkgdx: float = NamelistDefaults.xkgdx - rlmn: float = NamelistDefaults.rlmn - rlmx: float = NamelistDefaults.rlmx - do_dk_hb19: bool = NamelistDefaults.do_dk_hb19 - cap_k0_land: bool = NamelistDefaults.cap_k0_land - c0s_shal: float = NamelistDefaults.c0s_shal - c1_shal: float = NamelistDefaults.c1_shal - clam_shal: float = NamelistDefaults.clam_shal - pgcon_shal: float = NamelistDefaults.pgcon_shal - asolfac_shal: float = NamelistDefaults.asolfac_shal - ncld: int = NamelistDefaults.ncld - # cal_pre: Any - # cdmbgwd: Any - # cnvcld: Any - # cnvgwd: Any - # debug: Any - # do_deep: Any - # dspheat: Any - # fhcyc: Any - # fhlwr: Any - # fhswr: Any - # fhzero: Any - # hybedmf: Any - # iaer: Any - # ialb: Any - # ico2: Any - # iems: Any - # imfdeepcnv: Any - # imfshalcnv: Any - # imp_physics: Any - # isol: Any - # isot: Any - # isubc_lw: Any - # isubc_sw: Any - # ivegsrc: Any - # ldiag3d: Any - # lwhtr: Any - # nst_anl: Any - # pdfcld: Any - # pre_rad: Any - # prslrd0: Any - # random_clds: Any - # redrag: Any - # satmedmf: Any - # shal_cnv: Any - # swhtr: Any - # trans_trac: Any - # use_ufo: Any - # xkzm_h: Any - # xkzm_m: Any - # xkzminv: Any - # interp_method: Any - # lat_s: Any - # lon_s: Any - # ntrunc: Any - # fabsl: Any - # faisl: Any - # faiss: Any - # fnabsc: Any - # fnacna: Any - # fnaisc: Any - # fnalbc: Any - # fnalbc2: Any - # fnglac: Any - # fnmskh: Any - # fnmxic: Any - # fnslpc: Any - # fnsmcc: Any - # fnsnoa: Any - # fnsnoc: Any - # fnsotc: Any - # fntg3c: Any - # fntsfa: Any - # fntsfc: Any - # fnvegc: Any - # fnvetc: Any - # fnvmnc: Any - # fnvmxc: Any - # fnzorc: Any - # fsicl: Any - # fsics: Any - # fslpl: Any - # fsmcl: Any - # fsnol: Any - # fsnos: Any - # fsotl: Any - # ftsfl: Any - # ftsfs: Any - # fvetl: Any - # fvmnl: Any - # fvmxl: Any - # ldebug: Any - grid_type: int = NamelistDefaults.grid_type - dx_const: float = NamelistDefaults.dx_const - dy_const: float = NamelistDefaults.dy_const - deglat: float = NamelistDefaults.deglat - u_max: float = NamelistDefaults.u_max - do_f3d: bool = NamelistDefaults.do_f3d - inline_q: bool = NamelistDefaults.inline_q - do_skeb: bool = NamelistDefaults.do_skeb - """save dissipation estimate""" - use_logp: bool = NamelistDefaults.use_logp - moist_phys: bool = NamelistDefaults.moist_phys - check_negative: bool = NamelistDefaults.check_negative - # gfdl_cloud_microphys.F90 - tau_r2g: float = NamelistDefaults.tau_r2g - """rain freezing during fast_sat""" - tau_smlt: float = NamelistDefaults.tau_smlt - """snow melting""" - tau_g2r: float = NamelistDefaults.tau_g2r - """graupel melting to rain""" - tau_imlt: float = NamelistDefaults.tau_imlt - """cloud ice melting""" - tau_i2s: float = NamelistDefaults.tau_i2s - """cloud ice to snow auto - conversion""" - tau_l2r: float = NamelistDefaults.tau_l2r - """cloud water to rain auto - conversion""" - tau_g2v: float = NamelistDefaults.tau_g2v - """graupel sublimation""" - tau_v2g: float = NamelistDefaults.tau_v2g - """graupel deposition -- make it a slow process""" - sat_adj0: float = NamelistDefaults.sat_adj0 - """adjustment factor (0: no 1: full) during fast_sat_adj""" - ql_gen: float = 1.0e-3 - """max new cloud water during remapping step if fast_sat_adj = .t.""" - ql_mlt: float = NamelistDefaults.ql_mlt - """max value of cloud water allowed from melted cloud ice""" - qs_mlt: float = NamelistDefaults.qs_mlt - """max cloud water due to snow melt""" - ql0_max: float = NamelistDefaults.ql0_max - """max cloud water value (auto converted to rain)""" - t_sub: float = NamelistDefaults.t_sub - """min temp for sublimation of cloud ice""" - qi_gen: float = NamelistDefaults.qi_gen - """max cloud ice generation during remapping step""" - qi_lim: float = NamelistDefaults.qi_lim - """cloud ice limiter to prevent large ice build up""" - qi0_max: float = NamelistDefaults.qi0_max - """max cloud ice value (by other sources)""" - rad_snow: bool = NamelistDefaults.rad_snow - """consider snow in cloud fraction calculation""" - rad_rain: bool = NamelistDefaults.rad_rain - """consider rain in cloud fraction calculation""" - rad_graupel: bool = NamelistDefaults.rad_graupel - """consider graupel in cloud fraction calculation""" - tintqs: bool = NamelistDefaults.tintqs - """use temperature in the saturation mixing in PDF""" - dw_ocean: float = NamelistDefaults.dw_ocean - """base value for ocean""" - dw_land: float = NamelistDefaults.dw_land - """base value for subgrid deviation / variability over land""" - # cloud scheme 0 - ? - # 1: old fvgfs gfdl) mp implementation - # 2: binary cloud scheme (0 / 1) - icloud_f: int = NamelistDefaults.icloud_f - cld_min: float = NamelistDefaults.cld_min - """minimum cloud fraction""" - tau_l2v: float = NamelistDefaults.tau_l2v - """cloud water to water vapor (evaporation)""" - tau_v2l: float = NamelistDefaults.tau_v2l - """water vapor to cloud water (condensation)""" - c2l_ord: int = NamelistDefaults.c2l_ord - regional: bool = NamelistDefaults.regional - m_split: int = NamelistDefaults.m_split - convert_ke: bool = NamelistDefaults.convert_ke - breed_vortex_inline: bool = NamelistDefaults.breed_vortex_inline - use_old_omega: bool = NamelistDefaults.use_old_omega - rf_fast: bool = NamelistDefaults.rf_fast - adiabatic: bool = NamelistDefaults.adiabatic - nf_omega: int = NamelistDefaults.nf_omega - fv_sg_adj: int = NamelistDefaults.fv_sg_adj - n_sponge: int = NamelistDefaults.n_sponge - lsoil: int = NamelistDefaults.lsoil - daily_mean: bool = False - sw_dynamics: bool = NamelistDefaults.sw_dynamics - """Flag to replace cosz with daily mean value in physics""" - - @classmethod - def from_f90nml(cls, namelist: f90nml.Namelist) -> Self: - namelist_dict = namelist_to_flatish_dict(namelist.items()) - namelist_dict = { - key: value - for key, value in namelist_dict.items() - if key in cls.__dataclass_fields__ - } - return cls(**namelist_dict) - - def __post_init__(self) -> None: - warnings.warn( - "Usage of `ndsl.Namelist` is discouraged. The class will be " - "removed in the next version together with `NamelistDefaults`, see " - "https://github.com/NOAA-GFDL/NDSL/issues/64.", - DeprecationWarning, - stacklevel=2, - ) - - -def namelist_to_flatish_dict(nml_input: Any) -> dict: - nml = dict(nml_input) - for name, value in nml.items(): - if isinstance(value, f90nml.Namelist): - nml[name] = namelist_to_flatish_dict(value) - flatter_namelist = {} - for key, value in nml.items(): - if isinstance(value, dict): - for subkey, subvalue in value.items(): - if subkey in flatter_namelist: - raise ValueError( - "Cannot flatten this namelist, duplicate keys: " + subkey - ) - flatter_namelist[subkey] = subvalue - else: - flatter_namelist[key] = value - return flatter_namelist diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 4e0bb428..2ed29e37 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -18,9 +18,6 @@ from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig - -# TODO: Remove NdslNamelist import after Issue#64 is resolved. -from ndsl.namelist import Namelist as NdslNamelist from ndsl.stencils.testing.grid import Grid # type: ignore from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict @@ -82,7 +79,7 @@ def pytest_addoption(parser: pytest.Parser) -> None: "--no_legacy_namelist", action="store_true", default=False, - help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.", + help="Temporary flag introduced as part of NDSL issue #64. No functionality. Soon to be removed.", ) parser.addoption( "--grid", @@ -268,9 +265,6 @@ def sequential_savepoint_cases( sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") - # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. - use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") - return _savepoint_cases( savepoint_names, ranks, @@ -283,7 +277,6 @@ def sequential_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, - use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) @@ -299,7 +292,6 @@ def _savepoint_cases( topology_mode: str, sort_report: str, no_report: bool, - use_legacy_namelist: bool, # Issue#64: tmp flag ) -> list[SavepointCase]: grid_params = grid_params_from_f90nml(namelist) return_list = [] @@ -335,12 +327,6 @@ def _savepoint_cases( grid_indexing=grid.grid_indexing, ) for test_name in sorted(list(savepoint_names)): - # Temporary check (Issue#64): TODO Remove check and conversion from - # f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed - if use_legacy_namelist and not isinstance(namelist, NdslNamelist): - assert isinstance(namelist, Namelist) - namelist = NdslNamelist.from_f90nml(namelist) - testobj = get_test_class_instance( test_name, grid, namelist, stencil_factory ) @@ -402,9 +388,6 @@ def parallel_savepoint_cases( grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = get_savepoint_restriction(metafunc) - # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. - use_legacy_namelist = not metafunc.config.getoption("no_legacy_namelist") - return _savepoint_cases( savepoint_names, [mpi_rank], @@ -417,7 +400,6 @@ def parallel_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, - use_legacy_namelist=use_legacy_namelist, # Issue#64: tmp flag ) diff --git a/ndsl/stencils/testing/parallel_translate.py b/ndsl/stencils/testing/parallel_translate.py index 85b87020..d23d5459 100644 --- a/ndsl/stencils/testing/parallel_translate.py +++ b/ndsl/stencils/testing/parallel_translate.py @@ -7,9 +7,6 @@ from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS from ndsl.dsl import gt4py_utils as utils - -# TODO: Remove once ndsl.Namelist is gone (Issue#64) -from ndsl.namelist import Namelist as NdslNamelist from ndsl.quantity import Quantity from ndsl.stencils.testing.translate import ( TranslateFortranData2Py, @@ -133,12 +130,6 @@ def rank_grids(self): @property def layout(self): - # TODO: Once ndsl.namelist.Namelist is gone (Issue#64), - # remove this check in favor of f90nml.namelist.Namelist - if isinstance(self.namelist, NdslNamelist): - return self.namelist.layout - - # Assumption: namelist is f90nml.namelist.Namelist grid_params = grid_params_from_f90nml(self.namelist) return grid_params["layout"] diff --git a/tests/test_namelist.py b/tests/test_namelist.py deleted file mode 100644 index a32919e7..00000000 --- a/tests/test_namelist.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from ndsl import Namelist - - -def test_ndsl_namelist_deprecation() -> None: - with pytest.deprecated_call(): - my_namelist = Namelist() From 2145824cf69b8f82ddb56c93b7a1a1e384f78917 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 31 Oct 2025 08:30:05 -0400 Subject: [PATCH 09/43] [feature] Common data types for orchestration via `compiletime` (#296) * `Quantity`, `Local` & `State` default to `dace.compiletime` auto-magically in orchestration * Fix type check, remove `Local` * Unit tests * Fix for type annotations that aren't type --- ndsl/dsl/dace/orchestration.py | 9 +++++- tests/dsl/orchestration/test_call.py | 44 ++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 43b6c4f3..4d959d15 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -42,6 +42,7 @@ ) from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp +from ndsl.quantity import Quantity, State _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False @@ -543,12 +544,18 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] - func = type.__getattribute__(type(obj), method_to_orchestrate) + func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime + for arg_name, annotation in func.__annotations__.items(): + if annotation in [Quantity, State] or ( + isinstance(annotation, type) and issubclass(annotation, State) + ): + func.__annotations__[arg_name] = DaceCompiletime + # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call wrapped = _LazyComputepathMethod(func, config).__get__(obj) diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index d0c9f1b0..2ff2c16a 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -1,8 +1,11 @@ -from ndsl import QuantityFactory, StencilFactory +import dataclasses + +from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated -from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +from ndsl.quantity import Quantity, State def _stencil(out: Field[float]): @@ -40,3 +43,40 @@ def test_memory_reallocation(): code(qty_B) assert (qty_A.field[0, 0, :] == 3).all() assert (qty_B.field[0, 0, :] == 2).all() + + +@dataclasses.dataclass +class AState(State): + the_quantity: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +class DefaultTypeProgram(NDSLRuntime): + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ): + super().__init__(stencil_factory.config.dace_config) + self.stencil = stencil_factory.from_dims_halo(_stencil, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, a_quantity: Quantity, a_state: AState): + self.stencil(a_quantity) + self.stencil(a_state.the_quantity) + + +def test_default_types_are_compiletime(): + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + 5, 5, 2, 0 + ) + qty_A = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "A") + state_A = AState.zeros(quantity_factory) + code = DefaultTypeProgram(stencil_factory, quantity_factory) + code(qty_A, state_A) From f0f6798c403108a902deca3762171867dded2e28 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 31 Oct 2025 13:57:04 +0100 Subject: [PATCH 10/43] BREAKING CHANGE: remove deprecated ndsl/units.py (#283) The module has been deprecated in the last release and is now removed in this release cycle. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- docs/docstrings/top/units.md | 3 --- mkdocs.yml | 1 - ndsl/units.py | 33 --------------------------------- tests/test_utils.py | 18 ------------------ 4 files changed, 55 deletions(-) delete mode 100644 docs/docstrings/top/units.md delete mode 100644 ndsl/units.py delete mode 100644 tests/test_utils.py diff --git a/docs/docstrings/top/units.md b/docs/docstrings/top/units.md deleted file mode 100644 index 5d2efcf8..00000000 --- a/docs/docstrings/top/units.md +++ /dev/null @@ -1,3 +0,0 @@ -# units - -::: units diff --git a/mkdocs.yml b/mkdocs.yml index 34692b3f..2019c893 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,7 +31,6 @@ nav: - "optional_imports": docstrings/top/optional_imports.md - "types": docstrings/top/types.md - "typing": docstrings/top/typing.md - - "units": docstrings/top/units.md - "utils": docstrings/top/utils.md - checkpointer: - "base": docstrings/checkpointer/base.md diff --git a/ndsl/units.py b/ndsl/units.py deleted file mode 100644 index 7fb441b1..00000000 --- a/ndsl/units.py +++ /dev/null @@ -1,33 +0,0 @@ -import warnings - - -def ensure_equal_units(units1: str, units2: str) -> None: - warnings.warn( - "`ensure_equal_units` is unused and usage is discouraged. The function " - "will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - if not units_are_equal(units1, units2): - raise UnitsError(f"incompatible units {units1} and {units2}") - - -def units_are_equal(units1: str, units2: str) -> bool: - warnings.warn( - "`units_are_equal` is unused and usage is discouraged. The function will " - "be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - return units1.strip() == units2.strip() - - -class UnitsError(Exception): - def __init__(self, *args) -> None: - warnings.warn( - "`UnitsError` is unused and usage is discouraged. The class will be " - "removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 70dd9739..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from ndsl.units import UnitsError, ensure_equal_units, units_are_equal - - -def test_UnitsError_is_deprecated() -> None: - with pytest.deprecated_call(): - UnitsError() - - -def test_units_are_equal_is_deprecated() -> None: - with pytest.deprecated_call(): - units_are_equal("asdf", "asdf") - - -def test_ensure_equal_units_is_deprecated() -> None: - with pytest.deprecated_call(): - ensure_equal_units("asdf", "asdf") From 616d02b8c512b51c7319351b301d93ff9557d1fd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 31 Oct 2025 13:57:17 +0100 Subject: [PATCH 11/43] BREAKING CHANGE: removal of extra_dim_lengths (#295) `extra_dim_lengths` on the `GridSizer` was replaced by `data_dimensions` in the `2025.10.00` release. Now that the release is out, let's clean up and remove the deprecated API. This also includes `set_extra_dim_lengths()` in the `QuantityFactory`. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/initialization/allocator.py | 12 ------------ ndsl/initialization/grid_sizer.py | 10 ---------- 2 files changed, 22 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index f421c712..66332ec8 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -47,18 +47,6 @@ def __init__(self, sizer: GridSizer, *, backend: str) -> None: self._allocator = _Allocator(self.backend) - def set_extra_dim_lengths(self, **kwargs: Any) -> None: - """ - Set the length of extra (non-x/y/z) dimensions. - """ - warnings.warn( - "`QuantityFactory.set_extra_dim_lengths` is deprecated. " - "Use `add_data_dimensions` or `update_data_dimensions`.", - DeprecationWarning, - 2, - ) - self.sizer.data_dimensions.update(kwargs) - def update_data_dimensions( self, data_dimension_descriptions: dict[str, int], diff --git a/ndsl/initialization/grid_sizer.py b/ndsl/initialization/grid_sizer.py index 961ab793..f3347e27 100644 --- a/ndsl/initialization/grid_sizer.py +++ b/ndsl/initialization/grid_sizer.py @@ -1,4 +1,3 @@ -import warnings from collections.abc import Sequence from dataclasses import dataclass @@ -16,15 +15,6 @@ class GridSizer: data_dimensions: dict[str, int] """Name/Lengths pair of any non-x/y/z dimensions, such as land or radiation dimensions.""" - @property - def extra_dim_lengths(self) -> dict[str, int]: - warnings.warn( - "`GridSizer.extra_dim_lengths` is a deprecated API, use `GridSizer.data_dimensions`.", - DeprecationWarning, - 2, - ) - return self.data_dimensions - def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: raise NotImplementedError() From 911ae320be023f6de8dfa6fae1fe415f952e0260 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 31 Oct 2025 13:57:40 +0100 Subject: [PATCH 12/43] BREAKING CHANGE: remove deprecated ndsl/filesystem.py (#284) The module was deprecated in the last release and will now be remove in this release cycle. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- docs/docstrings/top/filesystem.md | 3 --- mkdocs.yml | 1 - ndsl/filesystem.py | 37 ------------------------------- setup.py | 1 - tests/test_filesystem.py | 19 ---------------- 5 files changed, 61 deletions(-) delete mode 100644 docs/docstrings/top/filesystem.md delete mode 100644 ndsl/filesystem.py delete mode 100644 tests/test_filesystem.py diff --git a/docs/docstrings/top/filesystem.md b/docs/docstrings/top/filesystem.md deleted file mode 100644 index 9f05cc3a..00000000 --- a/docs/docstrings/top/filesystem.md +++ /dev/null @@ -1,3 +0,0 @@ -# filesystem - -::: filesystem diff --git a/mkdocs.yml b/mkdocs.yml index 2019c893..60fc8e6a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,7 +24,6 @@ nav: - "boilerplate": docstrings/top/boilerplate.md - "buffer": docstrings/top/buffer.md - "constants": docstrings/top/constants.md - - "filesystem": docstrings/top/filesystem.md - "io": docstrings/top/io.md - "logging": docstrings/top/logging.md - "namelist": docstrings/top/namelist.md diff --git a/ndsl/filesystem.py b/ndsl/filesystem.py deleted file mode 100644 index df2e709f..00000000 --- a/ndsl/filesystem.py +++ /dev/null @@ -1,37 +0,0 @@ -import warnings - -import fsspec - - -def get_fs(path: str) -> fsspec.AbstractFileSystem: - """Return the fsspec filesystem required to handle a given path.""" - warnings.warn( - "Usage of `get_fs()` is discouraged if favor `os.path` and `pathlib` " - "modules. The function will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - fs, _, _ = fsspec.get_fs_token_paths(path) - return fs - - -def is_file(filename): - warnings.warn( - "Usage of `is_file()` is discouraged if favor of plain `os.path.isfile()`. " - "The function will be removed in the next version of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - return get_fs(filename).isfile(filename) - - -def open(filename, *args, **kwargs): - warnings.warn( - "Usage of `open()` is discouraged if favor the python built-in file " - "open context manager. The function will be removed in the next version " - "of NDSL.", - DeprecationWarning, - stacklevel=2, - ) - fs = get_fs(filename) - return fs.open(filename, *args, **kwargs) diff --git a/setup.py b/setup.py index 1719226f..6d109c36 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ def local_pkg(name: str, relative_path: str) -> str: "cftime", "xarray>=2025.01.2", # datatree + fixes "f90nml>=1.1.0", - "fsspec", "netcdf4==1.7.2", "scipy", # restart capacities only "h5netcdf", # for xarray diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py deleted file mode 100644 index 5f05af82..00000000 --- a/tests/test_filesystem.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -import ndsl.filesystem as fs - - -def test_is_file_is_deprecated() -> None: - with pytest.deprecated_call(): - fs.is_file("path/to/my_file.txt") - - -def test_open_is_deprecated() -> None: - with pytest.deprecated_call(): - with fs.open("README.md", "r"): - pass - - -def test_get_fs_is_deprecated() -> None: - with pytest.deprecated_call(): - fs.get_fs("path/to/my/file.txt") From a89dd22dd2533add8516e684c8e7bc35d106a265 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 3 Nov 2025 14:48:51 +0100 Subject: [PATCH 13/43] docs: release checklist and documentation (#299) * release checklist and documentation * Add template for patch release * review --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .github/PULL_REQUEST_TEMPLATE/README.md | 7 ++++ .../pull_request_template.md} | 0 .../PULL_REQUEST_TEMPLATE/release-patch.md | 32 +++++++++++++++++++ .github/PULL_REQUEST_TEMPLATE/release.md | 26 +++++++++++++++ .github/workflows/docs_build.yaml | 2 +- .github/workflows/docs_deploy.yaml | 2 +- docs/internal/README.md | 3 ++ docs/internal/release.md | 24 ++++++++++++++ mkdocs.yml | 3 ++ pyproject.toml | 2 +- 10 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE/README.md rename .github/{PULL_REQUEST_TEMPLATE.md => PULL_REQUEST_TEMPLATE/pull_request_template.md} (100%) create mode 100644 .github/PULL_REQUEST_TEMPLATE/release-patch.md create mode 100644 .github/PULL_REQUEST_TEMPLATE/release.md create mode 100644 docs/internal/README.md create mode 100644 docs/internal/release.md diff --git a/.github/PULL_REQUEST_TEMPLATE/README.md b/.github/PULL_REQUEST_TEMPLATE/README.md new file mode 100644 index 00000000..618d2e17 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/README.md @@ -0,0 +1,7 @@ +# Pull request templates + +- `pull_request_template.md`: The default pull request template. Used for PRs. +- `release.md`: Special template used for releasing a new version of NDSL. +- `release-patch.md`: Special template used for patch releases. + +Note: GitHub has limited support for multiple pull request templates. Most notably, templates can only be [selected by an URL query parameter](https://github.com/orgs/community/discussions/4620) and there's currently [no way to set a title](https://github.com/orgs/community/discussions/63965). diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md similarity index 100% rename from .github/PULL_REQUEST_TEMPLATE.md rename to .github/PULL_REQUEST_TEMPLATE/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/release-patch.md b/.github/PULL_REQUEST_TEMPLATE/release-patch.md new file mode 100644 index 00000000..c588dd8f --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/release-patch.md @@ -0,0 +1,32 @@ +# Release NDSL version `YYYY.MM.PP` + +This PR patches release `YYYY.MM.PP` because + +1. reason +2. reason +3. ... + +## Pre-release checklist + +Things to do before the patch release. Helps to keep the fallout from this release as minimal as possible. + +- [ ] setup a draft PR in [NOAA-GFDL/pace](https://github.com/NOAA-GFDL/pace) with updated submodules for `NDSL`, `pyFV3`, and `pySHiELD`. + Don't merge yet - just let CI run and fix potential issues before the release. To be merged afterwards, see post-release checklist. + +## Release checklist + +What to do to actually release: + +- [x] create this PR to merge changes from `my-patches` into `main` + - use "squash merge" +- [ ] once merged, create a GitHub release and tag the new version + - version format is `[year].[month].[patch]`. Increase the patch version, e.g. `2025.10.01` if this is patching the `2025.10.00` release. + - let GitHub auto-generate release notes from the last tagged version +- [ ] send an announcement on Mattermost + +## Post-release checklist + +What to do after a release: + +- [ ] update the pace PR from the pre-commit checklist to include the released version of NDSL and merge it. +- [ ] in NDSL, merge `main` back into `develop` (potentially adding a commit to fix the issue "properly") diff --git a/.github/PULL_REQUEST_TEMPLATE/release.md b/.github/PULL_REQUEST_TEMPLATE/release.md new file mode 100644 index 00000000..00aa3403 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/release.md @@ -0,0 +1,26 @@ +# Release NDSL version `YYYY.MM.00` + +## Pre-release checklist + +Things to do before the release. Helps to keep the fallout from this release as minimal as possible. + +- [ ] setup a draft PR in [NOAA-GFDL/pace](https://github.com/NOAA-GFDL/pace) with updated submodules for `NDSL`, `pyFV3`, and `pySHiELD`. + Don't merge yet - just let CI run and fix potential issues before the release. To be merged afterwards, see post-release checklist. + +## Release checklist + +What to do to actually release: + +- [x] create this PR to merge changes from `develop` into `main` + - merge as "Merge commit" +- [ ] once merged, create a GitHub release and tag the new version + - version format is `[year].[month].[patch]`, e.g. `2025.10.00` + - let GitHub auto-generate release notes from the last tagged version +- [ ] send an announcement on Mattermost + +## Post-release checklist + +What to do after a release: + +- [ ] update the pace PR from the pre-commit checklist to include the released version of NDSL and merge it. +- [ ] merge breaking changes in NDSL (e.g. search for deprecation warnings) diff --git a/.github/workflows/docs_build.yaml b/.github/workflows/docs_build.yaml index 3b0f1ed1..4b59ed78 100644 --- a/.github/workflows/docs_build.yaml +++ b/.github/workflows/docs_build.yaml @@ -28,7 +28,7 @@ jobs: python-version: '3.11' - name: Install mkdocs - run: pip install mkdocs-material mkdocstrings[python] + run: pip install mkdocs-material mkdocstrings[python] mkdocs-exclude - name: Build docs run: mkdocs build diff --git a/.github/workflows/docs_deploy.yaml b/.github/workflows/docs_deploy.yaml index 1e6659d3..cbbb381c 100644 --- a/.github/workflows/docs_deploy.yaml +++ b/.github/workflows/docs_deploy.yaml @@ -29,7 +29,7 @@ jobs: python-version: 3.11 - name: Install dependencies - run: pip install mkdocs-material mkdocstrings[python] + run: pip install mkdocs-material mkdocstrings[python] mkdocs-exclude - name: Deploy docs to GitHub Pages run: mkdocs gh-deploy --force diff --git a/docs/internal/README.md b/docs/internal/README.md new file mode 100644 index 00000000..d8eea9dc --- /dev/null +++ b/docs/internal/README.md @@ -0,0 +1,3 @@ +# Internal documentation + +This folder contains internal / developer documentation and processes, e.g. how to build a release. This folder is thus ignored when building the public facing documentation from the `docs/` folder. diff --git a/docs/internal/release.md b/docs/internal/release.md new file mode 100644 index 00000000..7b804a82 --- /dev/null +++ b/docs/internal/release.md @@ -0,0 +1,24 @@ +# Release a new version + +This internal documentation guides you through the process of releasing a new version of NDSL. It is very simple: + +1. Click [create a release](https://github.com/NOAA-GFDL/NDSL/compare/main...develop?expand=1&template=release.md) and follow the steps in the release checklist. + +## Patch release + +Every now and then, we'll need to patch the currently released version of NDSL. To do so, follow these steps: + +1. Create a branch from `main`. +2. Commit your changes on that branch. +3. Use the following URL and follow the steps in the patch release checklist. + +As an example, you'd go and create branch `my-patches` from `main` + +```bash +git checkout main +git switch -c my-patches +# do changes ... +git push +``` + +and in that case, the URL with the patch release template is: . diff --git a/mkdocs.yml b/mkdocs.yml index 60fc8e6a..4c385bd1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -146,6 +146,9 @@ plugins: paths: [./ndsl] # Adjust this path to where your Python modules are options: show_source: false + - exclude: + glob: + - internal/* watch: # reload when the glossary file is updated diff --git a/pyproject.toml b/pyproject.toml index cf8e0a69..f05bb3b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.11,<3.12" [project.optional-dependencies] demos = ["ipython", "ipykernel"] dev = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "pre-commit", "flake8-pyproject", "build"] -docs = ["mkdocs-material", "mkdocstrings[python]"] +docs = ["mkdocs-material", "mkdocstrings[python]", "mkdocs-exclude"] extras = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "ndsl[dev]"] test = ["pytest", "coverage"] From 7534675816b2f6462b6dc8435c744d7058f5e3ee Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 3 Nov 2025 14:53:13 +0100 Subject: [PATCH 14/43] gt4py update: fix absolute indexin in debug backend (#302) Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index e140f707..fe0a82e4 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit e140f70731b723c519239e027237fb6281f4733b +Subproject commit fe0a82e4dc2a8a40916742078f9b948534077d02 From 25092ad471abe2d43adcad0d3f1556bb1cce233d Mon Sep 17 00:00:00 2001 From: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Date: Mon, 3 Nov 2025 08:57:25 -0500 Subject: [PATCH 15/43] column min/max stencil - value and index (#301) * column min/max and a unit test * working unit test, pre-commit changes * alternative type ignore method * reverted previous change * using boilerplate code * reverting previous change --- ndsl/stencils/column_operations.py | 53 ++++++++++++++++++ tests/stencils/test_stencils.py | 90 ++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 ndsl/stencils/column_operations.py create mode 100644 tests/stencils/test_stencils.py diff --git a/ndsl/stencils/column_operations.py b/ndsl/stencils/column_operations.py new file mode 100644 index 00000000..21838ea2 --- /dev/null +++ b/ndsl/stencils/column_operations.py @@ -0,0 +1,53 @@ +import typing + +from ndsl.dsl.gt4py import function + + +@typing.no_type_check +@function +def column_max(field, start_index, end_index): + """ + Find the maximum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [max value, index of max value] + """ + max_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level) + old = field.at(K=max_index) + if new > old: + max_index = level + level += 1 + + return field.at(K=max_index), max_index + + +@typing.no_type_check +@function +def column_min(field, start_index, end_index): + """ + Find the minimum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [min value, index of min value] + """ + min_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level) + old = field.at(K=min_index) + if new < old: + min_index = level + level += 1 + + return field.at(K=min_index), min_index diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py new file mode 100644 index 00000000..32b1ade9 --- /dev/null +++ b/tests/stencils/test_stencils.py @@ -0,0 +1,90 @@ +import numpy as np + +from ndsl import StencilFactory +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import FORWARD, computation, interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.stencils.column_operations import column_max, column_min + + +nx = 1 +ny = 1 +nz = 10 +nhalo = 0 +backend = "dace:cpu" + +stencil_factory, quantity_factory = get_factories_single_tile( + nx, ny, nz, nhalo, backend +) + + +class ColumnOperations: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + + def column_max_stencil( + data: FloatField, max_value: FloatFieldIJ, max_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + max_value, max_index = column_max(data, 0, k_end) + + def column_min_stencil( + data: FloatField, min_value: FloatFieldIJ, min_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + min_value, min_index = column_min(data, 5, k_end) + + self._column_max_stencil = stencil_factory.from_dims_halo( + func=column_max_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self._column_min_stencil = stencil_factory.from_dims_halo( + func=column_min_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__( + self, + data: FloatField, + max_value: FloatFieldIJ, + max_index: FloatFieldIJ, + min_value: FloatFieldIJ, + min_index: FloatFieldIJ, + ): + self._column_max_stencil(data, max_value, max_index) + self._column_min_stencil(data, min_value, min_index) + + +def test_column_operations(): + data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") + data.field[:] = [ + 47.3821, + 2.9157, + 88.6034, + 71.9275, + 53.1412, + 19.4783, + 94.2258, + 36.8099, + 64.0175, + 7.3504, + ] + + max_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + + code = ColumnOperations(stencil_factory) + print("initalized the class") + code(data, max_value, max_index, min_value, min_index) + + assert max_value.field[:] == np.max(data.field[:], axis=2) + assert max_index.field[:] == np.argmax(data.field[:], axis=2) + assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2) + assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2) From d1be2e70a9cf150d8d0b6e8fc4a049228dc57434 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 3 Nov 2025 20:05:28 +0100 Subject: [PATCH 16/43] build: gt4py udpdate (fix upcasting, abs k test coverage) (#303) This PR updates GT4Py to bring the following up from GT4Py - fix upcasting such that users can have variable k-offsets with expressions consisting of different types. - increase test coverage for absolute k indexing Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index fe0a82e4..97994b86 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit fe0a82e4dc2a8a40916742078f9b948534077d02 +Subproject commit 97994b8692f46483aaad3c2afbde363d350ae565 From dd931b2331767848a89d3b637390713382095879 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 3 Nov 2025 21:06:13 +0100 Subject: [PATCH 17/43] restore default PR template (#305) Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .github/PULL_REQUEST_TEMPLATE/README.md | 4 +++- .github/{PULL_REQUEST_TEMPLATE => }/pull_request_template.md | 0 2 files changed, 3 insertions(+), 1 deletion(-) rename .github/{PULL_REQUEST_TEMPLATE => }/pull_request_template.md (100%) diff --git a/.github/PULL_REQUEST_TEMPLATE/README.md b/.github/PULL_REQUEST_TEMPLATE/README.md index 618d2e17..e5980e53 100644 --- a/.github/PULL_REQUEST_TEMPLATE/README.md +++ b/.github/PULL_REQUEST_TEMPLATE/README.md @@ -1,7 +1,9 @@ # Pull request templates -- `pull_request_template.md`: The default pull request template. Used for PRs. +- `../pull_request_template.md`: The default pull request template. Used for PRs. - `release.md`: Special template used for releasing a new version of NDSL. - `release-patch.md`: Special template used for patch releases. Note: GitHub has limited support for multiple pull request templates. Most notably, templates can only be [selected by an URL query parameter](https://github.com/orgs/community/discussions/4620) and there's currently [no way to set a title](https://github.com/orgs/community/discussions/63965). + +Note: GitHub does not support having the default pull request template in this folder. All templates in this folder can only be used with the `template=` URL parameter (see note above). diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/pull_request_template.md similarity index 100% rename from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md rename to .github/pull_request_template.md From 2c9946c0d74ff8d557db5ccdf8acc8e94a73d391 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 4 Nov 2025 10:22:42 +0100 Subject: [PATCH 18/43] BREAKING CHANGE: last 2025.10.00 deprecations (`CopyCorners`, `Quantity.values()`, `extra_dim_lengths` on `SubtileGridSizer` (#300) * Remove deprecated extra_dim_lengths of SubtileGridSizer This is a follow-up from https://github.com/NOAA-GFDL/NDSL/pull/295. * Remove deprecated CopyCorners * Remove deprecated `Quantity.values()` --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/initialization/subtile_grid_sizer.py | 10 -- ndsl/quantity/quantity.py | 12 -- ndsl/stencils/__init__.py | 8 +- ndsl/stencils/corners.py | 181 +--------------------- 4 files changed, 3 insertions(+), 208 deletions(-) diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index ff31a1c3..c86bbd44 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -1,4 +1,3 @@ -import warnings from collections.abc import Iterable from typing import Self @@ -21,7 +20,6 @@ def from_tile_params( data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, - extra_dim_lengths: dict[str, int] | None = None, ) -> Self: """Create a SubtileGridSizer from parameters about the full tile. @@ -36,18 +34,10 @@ def from_tile_params( tile_partitioner (optional): partitioner object for the tile. By default, a TilePartitioner is created with the given layout tile_rank (optional): rank of this subtile. - extra_dim_lengths: DEPRECATED API - use `data_dimensions` """ if data_dimensions is None: data_dimensions = {} - if extra_dim_lengths is not None: - warnings.warn( - "`extra_dim_lengths` is a deprecated name, please use `data_dimensions` instead.", - DeprecationWarning, - 2, - ) - data_dimensions = extra_dim_lengths if tile_partitioner is None: tile_partitioner = TilePartitioner(layout) y_slice, x_slice = tile_partitioner.subtile_slice( diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 45312d2f..c460b53d 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -254,18 +254,6 @@ def dims(self) -> tuple[str, ...]: """Names of each dimension""" return self.metadata.dims - @property - def values(self) -> np.ndarray: - warnings.warn( - "values exists only for backwards-compatibility with " - "DataArray and will be removed, use .view[:] instead", - DeprecationWarning, - stacklevel=2, - ) - return_array = np.asarray(self.view[:]) - return_array.flags.writeable = False - return return_array - @property def view(self) -> BoundedArrayView: """A view into the computational domain of the underlying data""" diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 8a635187..08c24561 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,8 +1,4 @@ -from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .corners import CopyCornersXY, FillCornersBGrid -__all__ = [ - "CopyCorners", - "CopyCornersXY", - "FillCornersBGrid", -] +__all__ = ["CopyCornersXY", "FillCornersBGrid"] diff --git a/ndsl/stencils/corners.py b/ndsl/stencils/corners.py index 021a439b..4e21f8b1 100644 --- a/ndsl/stencils/corners.py +++ b/ndsl/stencils/corners.py @@ -1,73 +1,14 @@ -import warnings from collections.abc import Sequence from gt4py.cartesian import gtscript from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region from ndsl import StencilFactory -from ndsl.constants import ( - X_DIM, - X_INTERFACE_DIM, - Y_DIM, - Y_INTERFACE_DIM, - Z_INTERFACE_DIM, -) +from ndsl.constants import X_INTERFACE_DIM, Y_INTERFACE_DIM, Z_INTERFACE_DIM from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import FloatField -class CopyCorners: - """ - Helper-class to copy corners corresponding to the fortran functions - copy_corners_x or copy_corners_y respectively - """ - - def __init__(self, direction: str, stencil_factory: StencilFactory) -> None: - """The grid for this stencil""" - warnings.warn( - "Usage of the GT4Py implementation of CopyCorners is discouraged and will" - "be removed in the next release. Use `CopyCornersX` or `CopyCornersY` in PyFV3" - "for a more future-proof implementation of the same code.", - DeprecationWarning, - 2, - ) - grid_indexing = stencil_factory.grid_indexing - - n_halo = grid_indexing.n_halo - origin, domain = grid_indexing.get_origin_domain( - dims=[X_DIM, Y_DIM, Z_INTERFACE_DIM], halos=(n_halo, n_halo) - ) - - ax_offsets = grid_indexing.axis_offsets(origin, domain) - if direction == "x": - self._copy_corners = stencil_factory.from_origin_domain( - func=copy_corners_x_stencil_defn, - origin=origin, - domain=domain, - externals={ - **ax_offsets, - }, - ) - elif direction == "y": - self._copy_corners = stencil_factory.from_origin_domain( - func=copy_corners_y_stencil_defn, - origin=origin, - domain=domain, - externals={ - **ax_offsets, - }, - ) - else: - raise ValueError("Direction must be either 'x' or 'y'") - - def __call__(self, field: FloatField): - """ - Fills cell quantity field using corners from itself and multipliers - in the direction specified initialization of the instance of this class. - """ - self._copy_corners(field, field) - - class CopyCornersXY: """ Helper-class to copy corners corresponding to the Fortran functions @@ -313,126 +254,6 @@ def fill_corners_3cells_mult_y( return q -def copy_corners_x_stencil_defn(q_in: FloatField, q_out: FloatField): - from __externals__ import i_end, i_start, j_end, j_start - - with computation(PARALLEL), interval(...): - with horizontal( - region[i_start - 3, j_start - 3], region[i_end + 3, j_start - 3] - ): - q_out = q_in[0, 5, 0] - with horizontal( - region[i_start - 2, j_start - 3], region[i_end + 3, j_start - 2] - ): - q_out = q_in[-1, 4, 0] - with horizontal( - region[i_start - 1, j_start - 3], region[i_end + 3, j_start - 1] - ): - q_out = q_in[-2, 3, 0] - with horizontal( - region[i_start - 3, j_start - 2], region[i_end + 2, j_start - 3] - ): - q_out = q_in[1, 4, 0] - with horizontal( - region[i_start - 2, j_start - 2], region[i_end + 2, j_start - 2] - ): - q_out = q_in[0, 3, 0] - with horizontal( - region[i_start - 1, j_start - 2], region[i_end + 2, j_start - 1] - ): - q_out = q_in[-1, 2, 0] - with horizontal( - region[i_start - 3, j_start - 1], region[i_end + 1, j_start - 3] - ): - q_out = q_in[2, 3, 0] - with horizontal( - region[i_start - 2, j_start - 1], region[i_end + 1, j_start - 2] - ): - q_out = q_in[1, 2, 0] - with horizontal( - region[i_start - 1, j_start - 1], region[i_end + 1, j_start - 1] - ): - q_out = q_in[0, 1, 0] - with horizontal(region[i_start - 3, j_end + 1], region[i_end + 1, j_end + 3]): - q_out = q_in[2, -3, 0] - with horizontal(region[i_start - 2, j_end + 1], region[i_end + 1, j_end + 2]): - q_out = q_in[1, -2, 0] - with horizontal(region[i_start - 1, j_end + 1], region[i_end + 1, j_end + 1]): - q_out = q_in[0, -1, 0] - with horizontal(region[i_start - 3, j_end + 2], region[i_end + 2, j_end + 3]): - q_out = q_in[1, -4, 0] - with horizontal(region[i_start - 2, j_end + 2], region[i_end + 2, j_end + 2]): - q_out = q_in[0, -3, 0] - with horizontal(region[i_start - 1, j_end + 2], region[i_end + 2, j_end + 1]): - q_out = q_in[-1, -2, 0] - with horizontal(region[i_start - 3, j_end + 3], region[i_end + 3, j_end + 3]): - q_out = q_in[0, -5, 0] - with horizontal(region[i_start - 2, j_end + 3], region[i_end + 3, j_end + 2]): - q_out = q_in[-1, -4, 0] - with horizontal(region[i_start - 1, j_end + 3], region[i_end + 3, j_end + 1]): - q_out = q_in[-2, -3, 0] - - -def copy_corners_y_stencil_defn(q_in: FloatField, q_out: FloatField): - from __externals__ import i_end, i_start, j_end, j_start - - with computation(PARALLEL), interval(...): - with horizontal( - region[i_start - 3, j_start - 3], region[i_start - 3, j_end + 3] - ): - q_out = q_in[5, 0, 0] - with horizontal( - region[i_start - 2, j_start - 3], region[i_start - 3, j_end + 2] - ): - q_out = q_in[4, 1, 0] - with horizontal( - region[i_start - 1, j_start - 3], region[i_start - 3, j_end + 1] - ): - q_out = q_in[3, 2, 0] - with horizontal( - region[i_start - 3, j_start - 2], region[i_start - 2, j_end + 3] - ): - q_out = q_in[4, -1, 0] - with horizontal( - region[i_start - 2, j_start - 2], region[i_start - 2, j_end + 2] - ): - q_out = q_in[3, 0, 0] - with horizontal( - region[i_start - 1, j_start - 2], region[i_start - 2, j_end + 1] - ): - q_out = q_in[2, 1, 0] - with horizontal( - region[i_start - 3, j_start - 1], region[i_start - 1, j_end + 3] - ): - q_out = q_in[3, -2, 0] - with horizontal( - region[i_start - 2, j_start - 1], region[i_start - 1, j_end + 2] - ): - q_out = q_in[2, -1, 0] - with horizontal( - region[i_start - 1, j_start - 1], region[i_start - 1, j_end + 1] - ): - q_out = q_in[1, 0, 0] - with horizontal(region[i_end + 1, j_start - 3], region[i_end + 3, j_end + 1]): - q_out = q_in[-3, 2, 0] - with horizontal(region[i_end + 2, j_start - 3], region[i_end + 3, j_end + 2]): - q_out = q_in[-4, 1, 0] - with horizontal(region[i_end + 3, j_start - 3], region[i_end + 3, j_end + 3]): - q_out = q_in[-5, 0, 0] - with horizontal(region[i_end + 1, j_start - 2], region[i_end + 2, j_end + 1]): - q_out = q_in[-2, 1, 0] - with horizontal(region[i_end + 2, j_start - 2], region[i_end + 2, j_end + 2]): - q_out = q_in[-3, 0, 0] - with horizontal(region[i_end + 3, j_start - 2], region[i_end + 2, j_end + 3]): - q_out = q_in[-4, -1, 0] - with horizontal(region[i_end + 1, j_start - 1], region[i_end + 1, j_end + 1]): - q_out = q_in[-1, 0, 0] - with horizontal(region[i_end + 2, j_start - 1], region[i_end + 1, j_end + 2]): - q_out = q_in[-2, -1, 0] - with horizontal(region[i_end + 3, j_start - 1], region[i_end + 1, j_end + 3]): - q_out = q_in[-3, -2, 0] - - def copy_corners_xy_stencil_defn( q_in: FloatField, q_out_x: FloatField, q_out_y: FloatField ): From f66e473fdee0809c295507a5c398d36412b7e7a7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 5 Nov 2025 09:50:41 +0100 Subject: [PATCH 19/43] refactor: remove leftover debug print statements (#308) This PR just removes a bunch of leftover debug print statements from `ndsl/` and `tests/`. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/dsl/ndsl_runtime.py | 1 - ndsl/dsl/stencil.py | 1 - ndsl/quantity/quantity.py | 5 ----- tests/quantity/test_transpose.py | 1 - tests/stencils/test_stencils.py | 1 - tests/test_buffer.py | 1 - tests/test_g2g_communication.py | 1 - tests/test_halo_data_transformer.py | 4 +--- 8 files changed, 1 insertion(+), 14 deletions(-) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 721ae38b..065cc298 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -75,7 +75,6 @@ def check_for_quantity(object_: object) -> None: obj=self, config=self._dace_config, ) - print(type(self)) def __getattribute__(self, name: str) -> Any: attr = super().__getattribute__(name) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index c1fd297e..990c685a 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -240,7 +240,6 @@ def compare_ranks(comm: Comm, data: dict) -> Mapping[str, int]: other = comm.sendrecv(array, pair_rank) arr_diffs = np.sum(np.logical_and(~np.isnan(array), array != other)) if arr_diffs > 0: - print(name, rank, pair_rank, array, other) differences[name] = arr_diffs return differences diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index c460b53d..289fa644 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -393,11 +393,6 @@ def transpose( def plot_k_level(self, k_index: int = 0) -> None: field = self.data - print( - "Min and max values:", - field[:, :, k_index].min(), - field[:, :, k_index].max(), - ) plt.xlabel("I") plt.ylabel("J") diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 5e527279..745a8cd9 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -35,7 +35,6 @@ def quantity_data_input(initial_data, numpy, backend): array[:] = initial_data else: array = initial_data - print(type(array)) return array diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 32b1ade9..ce3fe36f 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -81,7 +81,6 @@ def test_column_operations(): min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") code = ColumnOperations(stencil_factory) - print("initalized the class") code(data, max_value, max_index, min_value, min_index) assert max_value.field[:] == np.max(data.field[:], axis=2) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 1498402d..de7723ce 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -187,6 +187,5 @@ def test_new_args_gives_different_buffer(allocator, backend, first_args, second_ @pytest.mark.parametrize("allocator, backend", [["ones", "cupy"]], indirect=True) def test_mpi_unsafe_allocator_exception(backend, allocator): BUFFER_CACHE.clear() - print(allocator) with pytest.raises(RuntimeError): Buffer.pop_from_cache(allocator, shape=(10, 10, 10), dtype=float) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index dab27cb3..9147a4a8 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -137,7 +137,6 @@ def test_halo_update_only_communicate_on_gpu(backend, gpu_communicators): # We expect no np calls and several cp calls global N_ZEROS_CALLS # noqa: F824 global ... is unused - print(f"Results {N_ZEROS_CALLS}") assert N_ZEROS_CALLS[cp.zeros] > 0 assert N_ZEROS_CALLS[np.zeros] == 0 diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 7e3385e7..ed647e2c 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -165,12 +165,11 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): """A list of quantities whose values are 42.42 in the computational domain and 1 outside of it.""" sz = _shape_length(shape) - print(f"{shape} {sz}") data = np.arange(0, sz, dtype=dtype).reshape(shape) if "gtc" not in gt4py_backend: # should also test code if gt4py_backend is unset gt4py_backend = None - quantity = Quantity( + return Quantity( data, dims=dims, units=units, @@ -178,7 +177,6 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): extent=extent, gt4py_backend=gt4py_backend, ) - return quantity @pytest.fixture(params=[-0, -1, -2, -3]) From 9708cace5367e071eee724fa9cf4a408f4f0fdfb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 5 Nov 2025 09:50:48 +0100 Subject: [PATCH 20/43] refactor: make GridSizer an abstract base class (#306) `GridSizer` is de-facto already a base class with abstract methods `get_origin()`, `get_extent()`, and `get_shape()`. This PR just formalizes that intent. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/initialization/grid_sizer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ndsl/initialization/grid_sizer.py b/ndsl/initialization/grid_sizer.py index f3347e27..684fa5ab 100644 --- a/ndsl/initialization/grid_sizer.py +++ b/ndsl/initialization/grid_sizer.py @@ -1,9 +1,10 @@ +from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @dataclass -class GridSizer: +class GridSizer(ABC): nx: int """Length of the x compute dimension for produced arrays.""" ny: int @@ -15,11 +16,11 @@ class GridSizer: data_dimensions: dict[str, int] """Name/Lengths pair of any non-x/y/z dimensions, such as land or radiation dimensions.""" - def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() + @abstractmethod + def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]: ... - def get_extent(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() + @abstractmethod + def get_extent(self, dims: Sequence[str]) -> tuple[int, ...]: ... - def get_shape(self, dims: Sequence[str]) -> tuple[int, ...]: - raise NotImplementedError() + @abstractmethod + def get_shape(self, dims: Sequence[str]) -> tuple[int, ...]: ... From baadd98e5849864290af06dc693f3028999a7251 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 5 Nov 2025 14:52:39 +0100 Subject: [PATCH 21/43] refactor: directly use gt_storage in QuantityFactory (#307) In the past, `QuantityFactory` would allow not only allocating with gt4py storage objects, but also directly from `numpy` or `cupy`. This ability was removed in PR https://github.com/NOAA-GFDL/NDSL/pull/228. With that removal comes the opportunity to streamline allocation in `QuantityFactory`, removing the need for a `Allocator` class in the middle. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/initialization/allocator.py | 45 +++++++++----------------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 66332ec8..a8b3efbc 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -13,26 +13,6 @@ from ndsl.quantity import Quantity, QuantityHaloSpec -class _Allocator: - def __init__(self, backend: str) -> None: - """ - Initialize an object that provides gt4py storage objects for zeros(), ones(), and empty(). - - Args: - backend: GT4Py backend name used for performance-optimized allocation. - """ - self.backend = backend - - def empty(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.empty(*args, backend=self.backend, **kwargs) - - def ones(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.ones(*args, backend=self.backend, **kwargs) - - def zeros(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.zeros(*args, backend=self.backend, **kwargs) - - class QuantityFactory: def __init__(self, sizer: GridSizer, *, backend: str) -> None: """ @@ -45,8 +25,6 @@ def __init__(self, sizer: GridSizer, *, backend: str) -> None: self.sizer = sizer self.backend = backend - self._allocator = _Allocator(self.backend) - def update_data_dimensions( self, data_dimension_descriptions: dict[str, int], @@ -111,7 +89,7 @@ def empty( Equivalent to `numpy.empty`""" return self._allocate( - self._allocator.empty, dims, units, dtype, allow_mismatch_float_precision + gt_storage.empty, dims, units, dtype, allow_mismatch_float_precision ) def zeros( @@ -126,7 +104,7 @@ def zeros( Equivalent to `numpy.zeros`""" return self._allocate( - self._allocator.zeros, dims, units, dtype, allow_mismatch_float_precision + gt_storage.zeros, dims, units, dtype, allow_mismatch_float_precision ) def ones( @@ -141,7 +119,7 @@ def ones( Equivalent to `numpy.ones`""" return self._allocate( - self._allocator.ones, dims, units, dtype, allow_mismatch_float_precision + gt_storage.ones, dims, units, dtype, allow_mismatch_float_precision ) def full( @@ -157,7 +135,7 @@ def full( Equivalent to `numpy.full`""" quantity = self._allocate( - self._allocator.empty, + gt_storage.empty, dims, units, dtype, @@ -235,12 +213,15 @@ def _allocate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) ] - try: - data = allocator( - shape, dtype=dtype, aligned_index=origin, dimensions=dimensions - ) - except TypeError: - data = allocator(shape, dtype=dtype) + + data = allocator( + shape, + dtype=dtype, + aligned_index=origin, + dimensions=dimensions, + backend=self.backend, + ) + return Quantity( data, dims=dims, From 67ecc9063fa4f577dd5cda2db8ad01055ae5c488 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Nov 2025 12:52:53 -0500 Subject: [PATCH 22/43] [Feature] Schedule Tree: refine transient (#304) * Fix axis merge * Remove debug print * Refine transients + utests * Lint * Revert to deactivating the experimental stree work * Use context manager for `_INTERNAL__SCHEDULE_TREE_OPTIMIZATION` * Typo * Clean refine transients code * Derive common strides layout from backend Refactor code to make re-sizing more compact in main algorithm Fix bad recursion Add todo list and verbose state of optimization * Lint * Remove `transient` to `State` lifetime - keep PR on target * Lint --- ndsl/dsl/dace/orchestration.py | 37 ++- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- .../dace/stree/optimizations/axis_merge.py | 7 +- .../stree/optimizations/refine_transients.py | 225 ++++++++++++++++++ tests/stree_optimizer/test_optimization.py | 102 ++++++-- 5 files changed, 342 insertions(+), 32 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/refine_transients.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 4d959d15..a5c278ac 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -16,6 +16,7 @@ from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.transformation.auto.auto_optimize import make_transients_persistent +from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass from gt4py import storage @@ -34,7 +35,11 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations import ( + AxisIterator, + CartesianAxisMerge, + CartesianRefineTransients, +) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -48,9 +53,6 @@ _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = False """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" -_INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] -"""INTERNAL: Default schedule passes for CPU. To be replaced with proper configuration.""" - def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -126,7 +128,7 @@ def _simplify( # We disable ScalarToSymbolPromotion because it might push symbols onto edges # that DaCe itself can't parse anymore later, e.g. casts, inlined function # calls or (complicated) field accesses. - skip=["ScalarToSymbolPromotion"], + skip={"ScalarToSymbolPromotion"}, ).apply_pass(sdfg, {}) @@ -157,14 +159,37 @@ def _build_sdfg( _simplify(sdfg) if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): + # Break all loops into uni-dimensional loops to simplify optimizations + sdfg.apply_transformations_repeated(MapExpansion, validate=True) stree = sdfg.as_schedule_tree() with DaCeProgress(config, "Schedule Tree: optimization"): if config.is_gpu_backend(): GPUPipeline().run(stree) else: - CPUPipeline(passes=_INTERNAL__SCHEDULE_TREE_PASSES).run(stree) + passes = [] + + if config.get_backend() == "dace:cpu_kfirst": + passes.extend( + [ + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianAxisMerge(AxisIterator._K), + CartesianRefineTransients(config.get_backend()), + ] + ) + else: + passes.extend( + [ + CartesianAxisMerge(AxisIterator._K), + CartesianAxisMerge(AxisIterator._I), + CartesianAxisMerge(AxisIterator._J), + CartesianRefineTransients(config.get_backend()), + ] + ) + CPUPipeline(passes=passes).run(stree) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): sdfg = stree.as_sdfg(skip={"ScalarToSymbolPromotion"}) diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 47c764b3..8e371ee9 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,5 @@ from .axis_merge import AxisIterator, CartesianAxisMerge +from .refine_transients import CartesianRefineTransients -__all__ = ["AxisIterator", "CartesianAxisMerge"] +__all__ = ["AxisIterator", "CartesianAxisMerge", "CartesianRefineTransients"] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 262a6021..74fe9e02 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -221,7 +221,7 @@ def _push_tasklet_down( return 0 # Tasklet is a callback next_index = list_index(nodes, the_tasklet) - if next_index == len(nodes): + if next_index == len(nodes) - 1: return 0 # Last node - done next_node = nodes[next_index + 1] @@ -323,7 +323,10 @@ def _map_overcompute_merge( nodes: list[stree.ScheduleTreeNode], ) -> int: if _last_node(nodes, the_map): - return 0 + merged = 0 + for child in the_map.children: + merged += self._merge_node(child, the_map.children) + return merged next_node = _get_next_node(nodes, the_map) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py new file mode 100644 index 00000000..967d6923 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import warnings +from types import TracebackType + +import dace.data +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator + + +def _change_index_of_tuple( + old_tuple: tuple[int, ...], index: int, value: int = 1 +) -> tuple[int, ...]: + """Return a copy of the given tuple with `old_tuple[index]` being replaced by `value`. + + Args: + old_tuple: to be copied + index: at which index to replace a value + value: to replace `old_tuple[index]` + """ + new_list = list(old_tuple) + new_list[index] = value + return tuple(new_list) + + +def _reduce_cartesian_axes_size_to_1( + transient_map_access: set[stree.nodes.MapEntry], + transient_data: dace.data.Data, + ijk_order: tuple[int, int, int], +) -> bool: + """Reduce dimension size of transient to 1 if their are accessed only + in a single Map for the cartesian dimensions""" + refined = False + for axis in AxisIterator: + access_in_map_count = 0 + for map_entry in transient_map_access: + if axis.as_str() in map_entry.params[0]: + access_in_map_count += 1 + + if access_in_map_count != 1: + continue + + # This transient is used in exactly one single-Axis map + # therefore this dimension can be removed. BUT we are not truly + # removing it, we are reducing it to 1 to not have to deal + # with different slicing. + transient_data.shape = _change_index_of_tuple( + transient_data.shape, + axis.as_cartesian_index(), + value=1, + ) + transient_data.set_strides_from_layout(*ijk_order) + refined = True + + return refined + + +class _CartesianMapNesting: + def __init__( + self, + cartesian_current_map_nesting: list[stree.nodes.MapEntry | None], + node: stree.MapScope, + ) -> None: + self._cartesian_current_map_nesting = cartesian_current_map_nesting + self._node = node + + def __enter__(self) -> None: + if AxisIterator._I.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[0] = self._node.node + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = self._node.node + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = self._node.node + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if AxisIterator._I.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[0] = None + elif AxisIterator._J.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[1] = None + elif AxisIterator._K.value[0] in self._node.node.params[0]: + self._cartesian_current_map_nesting[2] = None + + +class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor): + """Collect all access of transient arrays per Maps.""" + + def __init__(self) -> None: + self.transient_map_access: dict[str, set[stree.nodes.MapEntry]] = {} + self._cartesian_current_map_nesting: list[stree.nodes.MapEntry | None] = [ + None, + None, + None, + ] + + def __str__(self) -> str: + return "CartesianCollectMaps" + + def visit_MapScope(self, node: stree.MapScope) -> None: + if len(node.node.params) > 1: + ndsl_log.debug( + "Can't apply CartesianRefineTransients, require unidimensional Maps" + ) + return + + with _CartesianMapNesting(self._cartesian_current_map_nesting, node): + for child in node.children: + self.visit(child) + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for memlet in [*node.input_memlets(), *node.output_memlets()]: + if self.containers[memlet.data].transient: + for map_entry in self._cartesian_current_map_nesting: + if map_entry is not None: + self.transient_map_access[memlet.data].add(map_entry) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for name, data in self.containers.items(): + if data.transient: + self.transient_map_access[name] = set() + + for child in node.children: + self.visit(child) + + +class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): + """Rebuild memlets from containers to ensure they are scope to the right size.""" + + def __str__(self) -> str: + return "RefineTransientAxis" + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + for name, memlet in node.in_memlets.items(): + if self.containers[memlet.data].transient: + node.in_memlets[name] = memlet.from_array( + memlet.data, self.containers[memlet.data] + ) + + for name, memlet in node.out_memlets.items(): + if self.containers[memlet.data].transient: + node.out_memlets[name] = memlet.from_array( + memlet.data, self.containers[memlet.data] + ) + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self.containers = node.containers + for child in node.children: + self.visit(child) + + +class CartesianRefineTransients(stree.ScheduleNodeTransformer): + """Refine (reduce dimensionality) of transients based on their true use in + the cartesian dimensions. + + + It can do: + - Looking at usage of a transient in a cartesian axis (e.g. loop over a + cartesian axis) it will reduce that axis to 1 if it exists in _only one_. + It should but cannot do/will bug if: + - Dataflow analysis on the axis to prevent reducing an axis to one where + the transient is used with offset, leading to faulty numerics + - Using the dataflow above, we can reduce the dimensions to the correct lowest + size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element + buffer) + - Current action when detecting a valide candidate is to reduce the size of the dimension + to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce + the cache access significantly. This also has been implemented to _not_ deal with offset/slicing + downstream impact of removing an axis. Nevertheless the xis should be removed if it's not + used. + + More tests: + - Test for dataflow with offset + - Test for I/J refine but not in K + - Test for J refine but not in I or K + - Test with dataflow: if/else, while, etc. + - Test with ForScope (FORWARD/BACKWARD) instead of Map + """ + + def __init__(self, backend: str) -> None: + warnings.warn( + "CartesianRefineTransients is a WIP. It's usage is *severaly* limited" + "and will most likely lead to bad numerics. Check the docs, check utest.", + UserWarning, + stacklevel=2, + ) + + if backend in ["dace:cpu_kfirst"]: + self.ijk_order = (2, 1, 0) + elif backend in ["dace:cpu", "dace:gpu"]: + self.ijk_order = (1, 0, 2) + else: + raise NotImplementedError( + "[Schedule Tree Opt] CartesianRefineTransient not implemented for " + f"backend {backend}" + ) + + def __str__(self) -> str: + return "CartesianRefineTransients" + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + collect_map = CollectTransientAccessInCartesianMaps() + collect_map.visit(node) + + # Remove Axis + refined_transient = 0 + for name, data in node.containers.items(): + if not data.transient: + continue + refined = _reduce_cartesian_axes_size_to_1( + collect_map.transient_map_access[name], + data, + self.ijk_order, + ) + refined_transient += 1 if refined else 0 + + RebuildMemletsFromContainers().visit(node) + + ndsl_log.debug(f"🚀 {refined_transient} Transient refined") diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index 6ed29479..fd08a1cb 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_optimization.py @@ -1,11 +1,31 @@ -from ndsl import StencilFactory, orchestrate +from types import TracebackType + +import dace + +import ndsl.dsl.dace.orchestration as orch +from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField +class StreeOptimization: + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + + def stencil_A(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = in_field @@ -33,20 +53,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField) -> None: self.stencil_B(in_field, out_field) -def test_stree_roundtrip_no_opt() -> None: - """Dev Note: - - The below code successfully merges top level K loop (2 loops) - How do we test it?! Running doesn't test merging and the compilation - is a near-black box. We could reach in the `dace_config.compiled_sdfg` - cache but it's keyed on the dace.program and if we can reach the program - well we can reach the SDFG and turn it into an stree for verification - Should we run orchestration "by hand"? - Can we intercept the `stree` ? After all we just want to check that! - - Test is deactivated for now""" - - return +def test_stree_merge_maps() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend="dace:cpu" @@ -56,14 +63,63 @@ def test_stree_roundtrip_no_opt() -> None: in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - # Temporarily flip the internal switch - import ndsl.dsl.dace.orchestration as orch + with StreeOptimization(): + code(in_qty, out_qty) - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True - orch._INTERNAL__SCHEDULE_TREE_PASSES = [CartesianAxisMerge(AxisIterator._K)] + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] + all_maps = [ + (me, state) + for me, state in sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + + assert len(all_maps) == 3 + assert (out_qty.field[:] == 4).all() + + +class LocalRefineableCode(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(stencil_factory.config.dace_config) + self.stencil_A = stencil_factory.from_dims_halo( + func=stencil_A, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + + def __call__(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil_A(in_field, self.tmp) + self.stencil_A(self.tmp, out_field) + + +def test_stree_roundtrip_transient_is_refined() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu" + ) + + code = LocalRefineableCode(stencil_factory, quantity_factory) + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - code(in_qty, out_qty) + with StreeOptimization(): + code(in_qty, out_qty) - assert (out_qty.field[:] == 4).all() + assert ( + len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) + == 1 + ) + sdfg = list( + stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() + )[0] - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + for array in sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) From 72159540a84a692ede7484c73489a5ae28da81a8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 6 Nov 2025 14:15:02 +0100 Subject: [PATCH 23/43] build: gt4py update (upcasting in cast operations) (#310) This PR updates GT4Py to bring the fix for upcasting inside cast operations from GT4py to NDSL. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 97994b86..9ac1aa85 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 97994b8692f46483aaad3c2afbde363d350ae565 +Subproject commit 9ac1aa857f6df5160e79617ddca90f6a807f85d1 From 74a144d1a6d3d03d0657c79e37c674dac60dc3aa Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 7 Nov 2025 16:49:20 +0100 Subject: [PATCH 24/43] build: gt4py update (precision of global constants) (#313) This PR updates GT4Py in NDSL to bring up a PR that fixes the precision of global constants. So far, we'd discard any type annotation on global constants and just use the default literal precision instead. With this change, we respect potential type annotations on global constants. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 9ac1aa85..7d123536 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 9ac1aa857f6df5160e79617ddca90f6a807f85d1 +Subproject commit 7d1235363594766aa57bc259b94208091137c32e From 32eacf671bdce6a7e0c0d5620755ca9048d077c0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 10 Nov 2025 15:01:18 +0100 Subject: [PATCH 25/43] refactor: Quantity constructor: `gt4py_backend` -> `backend` (#312) * refactor: force kwargs in ctor of Quantity/Local Force keyword arguments for optional arguments to those constructors. This will facilitate the `gt4py_backen` -> `backend` transition. * refactor: prefer `backend` over `gt4py_backend` in Quantity --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/comm/communicator.py | 10 +-- ndsl/dsl/ndsl_runtime.py | 2 +- ndsl/grid/generation.py | 24 +++--- ndsl/grid/helper.py | 12 +-- ndsl/initialization/allocator.py | 2 +- ndsl/quantity/local.py | 24 ++++-- ndsl/quantity/metadata.py | 5 +- ndsl/quantity/quantity.py | 114 ++++++++++++++++++--------- tests/mpi/test_mpi_all_reduce_sum.py | 12 +-- tests/quantity/test_local.py | 41 ++++++++++ tests/quantity/test_quantity.py | 67 ++++++++++++++++ tests/quantity/test_storage.py | 8 +- tests/quantity/test_transpose.py | 10 ++- tests/test_halo_data_transformer.py | 2 +- tests/test_partitioner.py | 2 +- 15 files changed, 250 insertions(+), 85 deletions(-) create mode 100644 tests/quantity/test_local.py diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 6eee4514..d1c1205f 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -108,7 +108,7 @@ def _create_all_reduce_quantity( units=input_metadata.units, origin=input_metadata.origin, extent=input_metadata.extent, - gt4py_backend=input_metadata.gt4py_backend, + backend=input_metadata.backend, allow_mismatch_float_precision=False, ) return all_reduce_quantity @@ -228,7 +228,7 @@ def _get_gather_recv_quantity( units=send_metadata.units, origin=tuple([0 for dim in send_metadata.dims]), extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -241,7 +241,7 @@ def _get_scatter_recv_quantity( send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore dims=send_metadata.dims, units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -841,7 +841,7 @@ def _get_gather_recv_quantity( units=metadata.units, origin=(0,) + tuple([0 for dim in metadata.dims]), extent=global_extent, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -861,7 +861,7 @@ def _get_scatter_recv_quantity( metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore dims=metadata.dims[1:], units=metadata.units, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 065cc298..c798ded4 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -123,6 +123,6 @@ def make_local( units=quantity.units, origin=quantity.origin, extent=quantity.extent, - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, ) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 7fd6eb7d..7296fa76 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -551,7 +551,7 @@ def lon(self): origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -563,7 +563,7 @@ def lat(self) -> Quantity: origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -575,7 +575,7 @@ def lon_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -587,7 +587,7 @@ def lat_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1551,7 +1551,7 @@ def rarea(self) -> Quantity: origin=self.area.origin, extent=self.area.extent, units="m^-2", - gt4py_backend=self.area.gt4py_backend, + backend=self.area.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1566,7 +1566,7 @@ def rarea_c(self) -> Quantity: origin=self.area_c.origin, extent=self.area_c.extent, units="m^-2", - gt4py_backend=self.area_c.gt4py_backend, + backend=self.area_c.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1582,7 +1582,7 @@ def rdx(self) -> Quantity: origin=self.dx.origin, extent=self.dx.extent, units="m^-1", - gt4py_backend=self.dx.gt4py_backend, + backend=self.dx.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1598,7 +1598,7 @@ def rdy(self) -> Quantity: origin=self.dy.origin, extent=self.dy.extent, units="m^-1", - gt4py_backend=self.dy.gt4py_backend, + backend=self.dy.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1614,7 +1614,7 @@ def rdxa(self) -> Quantity: origin=self.dxa.origin, extent=self.dxa.extent, units="m^-1", - gt4py_backend=self.dxa.gt4py_backend, + backend=self.dxa.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1630,7 +1630,7 @@ def rdya(self) -> Quantity: origin=self.dya.origin, extent=self.dya.extent, units="m^-1", - gt4py_backend=self.dya.gt4py_backend, + backend=self.dya.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1646,7 +1646,7 @@ def rdxc(self) -> Quantity: origin=self.dxc.origin, extent=self.dxc.extent, units="m^-1", - gt4py_backend=self.dxc.gt4py_backend, + backend=self.dxc.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1662,7 +1662,7 @@ def rdyc(self) -> Quantity: origin=self.dyc.origin, extent=self.dyc.extent, units="m^-1", - gt4py_backend=self.dyc.gt4py_backend, + backend=self.dyc.backend, number_of_halo_points=N_HALO_DEFAULT, ) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index dd612b22..d907e49a 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -186,7 +186,7 @@ def p_interface(self) -> Quantity: p_interface_data, dims=[Z_INTERFACE_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._p_interface @@ -203,7 +203,7 @@ def p(self) -> Quantity: p_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.p_interface.gt4py_backend, + backend=self.p_interface.backend, number_of_halo_points=self.p_interface.metadata.n_halo, ) return self._p @@ -220,7 +220,7 @@ def dp(self) -> Quantity: dp_ref_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._dp_ref @@ -230,7 +230,7 @@ def ptop(self) -> Float: """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if self.ak.gt4py_backend is not None and is_gpu_backend(self.ak.gt4py_backend): + if self.ak.backend is not None and is_gpu_backend(self.ak.backend): return Float(self.ak.view[0].get()) else: return Float(self.ak.view[0]) @@ -382,7 +382,7 @@ def _fC_from_data(data, lat: Quantity) -> Quantity: dims=lat.dims, origin=lat.origin, extent=lat.extent, - gt4py_backend=lat.gt4py_backend, + backend=lat.backend, number_of_halo_points=lat.metadata.n_halo, ) @@ -824,7 +824,7 @@ def split_quantity_along_last_dim(quantity: Quantity) -> list[Quantity]: units=quantity.units, origin=quantity.origin[:-1], extent=quantity.extent[:-1], - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, number_of_halo_points=quantity.metadata.n_halo, ) ) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index a8b3efbc..f36cd02e 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -228,7 +228,7 @@ def _allocate( units=units, origin=origin, extent=extent, - gt4py_backend=self.backend, + backend=self.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, number_of_halo_points=self.sizer.n_halo, ) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 910999ab..af75242d 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +import warnings +from collections.abc import Sequence +from typing import Any import dace import numpy as np @@ -20,21 +22,31 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, + backend: str | None = None, ): + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + super().__init__( data, dims, units, - origin, - extent, - gt4py_backend, - allow_mismatch_float_precision, + origin=origin, + extent=extent, + allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=backend, ) - self._transient = True def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 7e7b4f16..45a14445 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -30,7 +30,9 @@ class QuantityMetadata: dtype: type "dtype of the data in the ndarray-like object" gt4py_backend: str | None = None - "backend to use for gt4py storages" + "Deprecated. Use backend instead." + backend: str | None = None + "GT4Py backend name. Used for performance optimal data allocation." @property def dim_lengths(self) -> dict[str, int]: @@ -57,6 +59,7 @@ def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None: metadata_copy.data_type = self.data_type metadata_copy.dtype = self.dtype metadata_copy.gt4py_backend = self.gt4py_backend + metadata_copy.backend = self.backend @dataclasses.dataclass diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 289fa644..4b902a72 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -32,33 +32,44 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, number_of_halo_points: int = 0, + backend: str | None = None, ): """Initialize a Quantity. Args: - data (_type_): ndarray-like object containing the underlying data - dims (Sequence[str]): dimension names for each axis - units (str): units of the quantity - origin (Sequence[int] | None, optional): first point in data within the - computational domain. Defaults to None. - extent (Sequence[int] | None, optional): number of points along each axis + data: ndarray-like object containing the underlying data + dims: dimension names for each axis + units: units of the quantity + origin: first point in data within the + computational domain. Defaults to None. + extent: number of points along each axis within the computational domain. Defaults to None. - gt4py_backend (str | None, optional): backend to use for gt4py storages, - if not given this will be derived from a Storage - if given as the data argument. Defaults to None. - allow_mismatch_float_precision (bool, optional): allow for precision that is + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. - number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + number_of_halo_points: Number of halo points used. Defaults to 0. + backend: GT4Py backend name. If given, we check that the data is + allocated in a performance optimal way for that backend. Raises: ValueError: Data-type mismatch between configuration and input-data TypeError: Typing of the data that does not fit """ + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -80,6 +91,11 @@ def __init__( if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. + warnings.warn( + "Usage of basic data in Quantities is deprecated. Please use it with a numpy or cuppy ndarray instead.", + DeprecationWarning, + stacklevel=2, + ) data = np.asarray(data) if not isinstance(data, (np.ndarray, cupy.ndarray)): @@ -87,8 +103,10 @@ def __init__( f"Only supports numpy.ndarray and cupy.ndarray, got {type(data)}" ) - if gt4py_backend is not None: - gt4py_backend_cls = gt_backend.from_name(gt4py_backend) + _validate_quantity_property_lengths(data.shape, dims, origin, extent) + + if backend is not None: + gt4py_backend_cls = gt_backend.from_name(backend) is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] dimensions: tuple[str | int, ...] = tuple( @@ -104,21 +122,25 @@ def __init__( ] ) - self._data = ( - data - if is_optimal_layout(data, dimensions) - else self._initialize_data( + if is_optimal_layout(data, dimensions): + self._data = data + else: + warnings.warn( + f"Suboptimal data layout found. Copying data to optimally align for backend '{backend}'.", + UserWarning, + stacklevel=2, + ) + self._data = gt_storage.from_array( data, - origin=origin, - gt4py_backend=gt4py_backend, + data.dtype, + backend=backend, + aligned_index=origin, dimensions=dimensions, ) - ) else: - # We have no info about the gt4py_backend, so just assign it. + # We have no info about the gt4py backend, so just assign it. self._data = data - _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), @@ -127,7 +149,8 @@ def __init__( units=units, data_type=type(self._data), dtype=data.dtype, - gt4py_backend=gt4py_backend, + backend=backend, + gt4py_backend=backend, ) self._attrs = {} # type: ignore[var-annotated] self._compute_domain_view = BoundedArrayView( @@ -138,10 +161,12 @@ def __init__( def from_data_array( cls, data_array: xr.DataArray, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, number_of_halo_points: int = 0, + backend: str | None = None, ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -150,12 +175,25 @@ def from_data_array( data_array origin: first point in data within the computational domain extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points: Number of halo points used. Defaults to 0. + backend: GT4Py backend name. If given, we check that the data is + allocated in a performance optimal way for that backend. """ if "units" not in data_array.attrs: raise ValueError("need units attribute to create Quantity from DataArray") + + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + return cls( data_array.values, cast(tuple[str], data_array.dims), @@ -163,7 +201,7 @@ def from_data_array( origin=origin, extent=extent, number_of_halo_points=number_of_halo_points, - gt4py_backend=gt4py_backend, + backend=backend, ) def to_netcdf( @@ -221,17 +259,6 @@ def sel(self, **kwargs: slice | int) -> np.ndarray: """ return self.view[tuple(kwargs.get(dim, slice(None, None)) for dim in self.dims)] - def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: tuple): # type: ignore - """Allocates an ndarray with optimal memory layout, and copies the data over.""" - storage = gt_storage.from_array( - data, - data.dtype, - backend=gt4py_backend, - aligned_index=origin, - dimensions=dimensions, - ) - return storage - @property def metadata(self) -> QuantityMetadata: return self._metadata @@ -243,8 +270,17 @@ def units(self) -> str: @property def gt4py_backend(self) -> str | None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) return self.metadata.gt4py_backend + @property + def backend(self) -> str | None: + return self.metadata.backend + @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -385,8 +421,8 @@ def transpose( units=self.units, origin=_transpose_sequence(self.origin, transpose_order), extent=_transpose_sequence(self.extent, transpose_order), - gt4py_backend=self.gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=self.backend, ) transposed._attrs = self._attrs return transposed diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 6cab1023..52b02dad 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -58,7 +58,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="Some 1D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5)], dtype=Float) @@ -68,7 +68,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -78,7 +78,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) @@ -98,7 +98,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="New 1D unit", - gt4py_backend=backend, + backend=backend, origin=(8,), extent=(7,), ) @@ -110,7 +110,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -120,7 +120,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) communicator.all_reduce( testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py new file mode 100644 index 00000000..859bb009 --- /dev/null +++ b/tests/quantity/test_local.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +from ndsl import Local + + +def test_local_descriptor_is_transient() -> None: + nx = 5 + shape = (nx,) + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend="debug", + ) + array = local.__descriptor__() + assert array.transient + + +def test_local_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert local.backend == backend + + # make sure we are backwards compatible (for now) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert local.gt4py_backend == backend diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index dccfa94f..3286ce0f 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from ndsl import Quantity from ndsl.quantity.bounds import _shift_slice @@ -289,3 +290,69 @@ def test_data_setter(): # Expected fail: new array is not even an array with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): quantity.data = "meh" + + +def test_constructor_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + quantity = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we are backwards compatible (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_from_data_array_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + np_data = np.empty(shape) + data_array = xr.DataArray(data=np_data, attrs={"units": "n/a"}) + quantity = Quantity.from_data_array( + data_array, + origin=(0,), + extent=(nx,), + number_of_halo_points=0, + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we don't assign gt4py_backend anymore (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_assign_basic_data_is_deprecated() -> None: + nx = 5 + backend = "debug" + with pytest.deprecated_call( + match="Usage of basic data in Quantities is deprecated" + ): + quantity = Quantity( + data=[0, 1, 2, 3, 4], + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend=backend, + allow_mismatch_float_precision=True, + ) + + # make sure we can still use it (for now) + for i in range(5): + assert quantity.data[i] == i diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index bc39f61f..7fbd1a04 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -72,7 +72,7 @@ def test_modifying_numpy_data_modifies_view_and_field(): extent=shape, dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) quantity.data[0, 0] = 1 @@ -99,7 +99,7 @@ def test_data_and_field_access_right_full_array_and_compute_domain(): extent=(5, 5), dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) # Write compute domain - test data is written with the offset @@ -139,7 +139,7 @@ def test_accessing_data_does_not_break_view( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] @@ -158,6 +158,6 @@ def test_numpy_data_becomes_cupy_with_gpu_backend( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) assert isinstance(quantity.data, cp.ndarray) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 745a8cd9..88653676 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -165,7 +165,13 @@ def param_product(*param_lists): ) @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_transpose( - quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy + quantity: Quantity, + target_dims, + final_data, + final_dims, + final_origin, + final_extent, + numpy, ): result = quantity.transpose(target_dims) numpy.testing.assert_array_equal(result.data, final_data) @@ -173,7 +179,7 @@ def test_transpose( assert result.origin == final_origin assert result.extent == final_extent assert result.units == quantity.units - assert result.gt4py_backend == quantity.gt4py_backend + assert result.backend == quantity.backend @pytest.mark.parametrize( diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index ed647e2c..2d34567a 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -175,7 +175,7 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): units=units, origin=origin, extent=extent, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index cc0a105e..4d20e709 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -993,7 +993,7 @@ def test_subtile_extent_with_tile_dimensions( cubedsphere_expected, ): data_array = np.zeros((tile_extent)) - quantity = Quantity(data_array, array_dims, "dimensionless", [0, 0, 0, 0]) + quantity = Quantity(data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0]) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) cubedsphere_partitioner = CubedSpherePartitioner(tile_partitioner) From a76b671e73d929d32086881ec60c3d0fd8fa6932 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 10 Nov 2025 15:05:13 +0100 Subject: [PATCH 26/43] refactor: prepare `ZarrMonitor` for upcomming `Comm` changes (#315) * refactor: ZarrMonitor: you'll have to bring your own comm objects * ci: run unit tests with optional zarr dependency --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .github/workflows/unit_tests.yaml | 2 +- ndsl/monitor/zarr_monitor.py | 7 ++++ pyproject.toml | 1 + tests/test_zarr_monitor.py | 61 +++++++++++++++++++------------ 4 files changed, 46 insertions(+), 25 deletions(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index bd90f66c..ef9fdcea 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -32,7 +32,7 @@ jobs: run: pip3 install mpich - name: Install Python packages - run: pip3 install .[test] + run: pip3 install .[test,zarr] - name: Run serial-cpu tests run: coverage run --rcfile=pyproject.toml -m pytest tests diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 31e73c62..8d2469ba 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import datetime, timedelta from typing import TypeVar @@ -90,6 +91,7 @@ def __init__( self, store: str | zarr.storage.MutableMapping, partitioner: Partitioner, + *, mode: str = "w", mpi_comm: Comm | None = None, ) -> None: @@ -103,6 +105,11 @@ def __init__( use a dummy comm object that works in single-core mode. """ if mpi_comm is None: + warnings.warn( + "`mpi_comm` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) mpi_comm = DummyComm() if mpi_comm.Get_rank() == 0: diff --git a/pyproject.toml b/pyproject.toml index f05bb3b2..a562bfb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dev = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "pre-commit", "flake8-pyprojec docs = ["mkdocs-material", "mkdocstrings[python]", "mkdocs-exclude"] extras = ["ndsl[docs]", "ndsl[demos]", "ndsl[test]", "ndsl[dev]"] test = ["pytest", "coverage"] +zarr = ["zarr<3"] [project.scripts] ndsl-serialbox_to_netcdf = "ndsl.stencils.testing.serialbox_to_netcdf:entry_point" diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index b3847ee4..be3e7650 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -7,7 +7,7 @@ import pytest import xarray as xr -from ndsl import CubedSpherePartitioner, DummyComm, Quantity, TilePartitioner +from ndsl import CubedSpherePartitioner, LocalComm, MPIComm, Quantity, TilePartitioner from ndsl.constants import ( X_DIM, X_DIMS, @@ -91,10 +91,11 @@ def cube_partitioner(tile_partitioner): @pytest.fixture(params=["empty", "one_var_2d", "one_var_3d", "two_vars"]) -def base_state(request, nz, ny, nx, numpy): +def base_state(request, nz, ny, nx, numpy) -> dict: if request.param == "empty": return {} - elif request.param == "one_var_2d": + + if request.param == "one_var_2d": return { "var1": Quantity( numpy.ones([ny, nx]), @@ -102,7 +103,8 @@ def base_state(request, nz, ny, nx, numpy): units="m", ) } - elif request.param == "one_var_3d": + + if request.param == "one_var_3d": return { "var1": Quantity( numpy.ones([nz, ny, nx]), @@ -110,7 +112,8 @@ def base_state(request, nz, ny, nx, numpy): units="m", ) } - elif request.param == "two_vars": + + if request.param == "two_vars": return { "var1": Quantity( numpy.ones([ny, nx]), @@ -123,8 +126,8 @@ def base_state(request, nz, ny, nx, numpy): units="degK", ), } - else: - raise NotImplementedError() + + raise NotImplementedError() @pytest.fixture @@ -139,10 +142,17 @@ def state_list(base_state, n_times, start_time, time_step, numpy): return state_list +@requires_zarr +def test_mpi_comm_will_be_required(cube_partitioner): + with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: + with pytest.deprecated_call(match="`mpi_comm` will be a required argument"): + ZarrMonitor(tempdir, cube_partitioner) + + @requires_zarr def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: - monitor = ZarrMonitor(tempdir, cube_partitioner) + monitor = ZarrMonitor(tempdir, cube_partitioner, mpi_comm=MPIComm()) for state in state_list: monitor.store(state) validate_store(state_list, tempdir, numpy, start_time) @@ -152,7 +162,7 @@ def test_monitor_file_store(state_list, cube_partitioner, numpy, start_time): @requires_zarr def validate_xarray_can_open(dirname): # just checking there are no crashes, validate_group checks data - xr.open_zarr(dirname) + xr.open_zarr(dirname, consolidated=False) @requires_zarr @@ -239,8 +249,8 @@ def test_monitor_file_store_multi_rank_state( ZarrMonitor( store, partitioner, - "w", - mpi_comm=DummyComm( + mode="w", + mpi_comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) @@ -337,14 +347,17 @@ def _assert_no_nulls(dataset: xr.Dataset): @requires_zarr def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale): store = {} + buffer = {} # initialize store - monitor = ZarrMonitor(store, cube_partitioner) + monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) zero_quantity = Quantity(numpy.zeros([10, 10]), dims=("y", "x"), units="m") monitor.store({"var": zero_quantity}) # open w/o dask using chunks=None - dataset = xr.open_zarr(store, chunks=None, mask_and_scale=mask_and_scale) + dataset = xr.open_zarr( + store, chunks=None, mask_and_scale=mask_and_scale, consolidated=False + ) _assert_no_nulls(dataset.sel(tile=0)) @@ -354,14 +367,15 @@ def test_values_preserved(cube_partitioner, numpy): units = "m" store = {} + buffer = {} # initialize store - monitor = ZarrMonitor(store, cube_partitioner) + monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) quantity = Quantity(numpy.random.uniform(size=(10, 10)), dims=dims, units=units) monitor.store({"var": quantity}) # open w/o dask using chunks=None - dataset = xr.open_zarr(store, chunks=None) + dataset = xr.open_zarr(store, chunks=None, consolidated=False) numpy.testing.assert_array_almost_equal( dataset["var"][0, 0, :, :].values, quantity.data ) @@ -388,7 +402,7 @@ def test_monitor_file_store_inconsistent_calendars( state_list_with_inconsistent_calendars, cube_partitioner, numpy ): with tempfile.TemporaryDirectory(suffix=".zarr") as tempdir: - monitor = ZarrMonitor(tempdir, cube_partitioner) + monitor = ZarrMonitor(tempdir, cube_partitioner, mpi_comm=MPIComm()) initial_state, final_state = state_list_with_inconsistent_calendars monitor.store(initial_state) with pytest.raises(ValueError, match="Calendar type"): @@ -406,29 +420,28 @@ def test_monitor_file_store_inconsistent_calendars( ) def diag(request, numpy): dims = request.param - diag = Quantity( + return Quantity( numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m" ) - return diag def _transpose(quantity, dims_2d, dims_3d): if len(quantity.dims) == 2: return quantity.transpose(dims_2d) - elif len(quantity.dims) == 3: + + if len(quantity.dims) == 3: return quantity.transpose(dims_3d) @pytest.fixture(scope="function") def zarr_store(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("diags.zarr") - store = zarr.storage.DirectoryStore(tmpdir) - return store + return zarr.storage.DirectoryStore(tmpdir) @pytest.fixture(scope="function") def zarr_monitor_single_rank(zarr_store, cube_partitioner): - return ZarrMonitor(zarr_store, cube_partitioner) + return ZarrMonitor(zarr_store, cube_partitioner, mpi_comm=MPIComm()) @requires_zarr @@ -440,7 +453,7 @@ def test_transposed_diags_write_across_ranks(diag, cube_partitioner, zarr_store) monitor = ZarrMonitor( zarr_store, cube_partitioner, - mpi_comm=DummyComm( + mpi_comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) @@ -497,6 +510,6 @@ def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2 = copy.deepcopy(diag) diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) - diag_3 = Quantity(data=diag.values, dims=diag.dims, units="not_m") + diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m") with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3}) From cbfa2a3f585506aa864eddb0adef18f17b59b859 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 10 Nov 2025 09:32:01 -0500 Subject: [PATCH 27/43] Introduce a `single_code_path` flag in the DaCeConfig that forces a single cache to be built. (#311) --- ndsl/dsl/caches/cache_location.py | 13 +++++++++++++ ndsl/dsl/dace/dace_config.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 1c1e7ec8..87d608dd 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -5,7 +5,20 @@ def identify_code_path( rank: int, partitioner: Partitioner, + single_code_path: bool, ) -> FV3CodePath: + """Determine which code path your rank will hit. + + If single_code_path is True, single_code_path is True, + only one code path exists (case of doubly periodic grid). + If single_code_path is False, we are in the case of the + cube-sphere and we will look at our position on the tile.""" + + # Doubly-periodic or single tile grid + if single_code_path: + return FV3CodePath.All + + # Cube-sphere if partitioner.layout == (1, 1): return FV3CodePath.All elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 66d43f17..e9aa9bba 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -102,6 +102,9 @@ def _determine_compiling_ranks( 15 -> 8 """ + if config._single_code_path: + return config.my_rank == 0 + # Tile 0 compiles if partitioner.tile_index(config.my_rank) != 0: return False @@ -147,6 +150,7 @@ def __init__( tile_nz: int = 0, orchestration: DaCeOrchestration | None = None, time: bool = False, + single_code_path: bool = False, ): """Specialize the DaCe configuration for NDSL use. @@ -163,8 +167,11 @@ def __init__( orchestration: orchestration mode from DaCeOrchestration time: trigger performance collection, available to user with `performance_collector` + single_codepath: code is expected to be the same on every rank (case + of column-physics) and therefore can be compiled once """ + self._single_code_path = single_code_path # Recording SDFG loaded for fast re-access # ToDo: DaceConfig becomes a bit more than a read-only config # with this. Should be refactored into a DaceExecutor carrying a config @@ -331,7 +338,11 @@ def __init__( if communicator: self.my_rank = communicator.rank self.rank_size = communicator.comm.Get_size() - self.code_path = identify_code_path(self.my_rank, communicator.partitioner) + self.code_path = identify_code_path( + self.my_rank, + communicator.partitioner, + self._single_code_path, + ) self.layout = communicator.partitioner.layout self.do_compile = ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE From 66d351536e95216072a95f2a9fc657e0b3afb88a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 12 Nov 2025 14:23:35 +0100 Subject: [PATCH 28/43] refactor: Deprecate optional backend argument to Quantity/Local (#314) Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/comm/communicator.py | 1 + ndsl/monitor/netcdf_monitor.py | 2 + ndsl/quantity/local.py | 9 +++- ndsl/quantity/quantity.py | 35 +++++++++++--- tests/dsl/test_stencil.py | 12 +++-- tests/quantity/test_boundary.py | 3 ++ tests/quantity/test_deepcopy.py | 3 ++ tests/quantity/test_local.py | 17 ++++++- tests/quantity/test_quantity.py | 77 +++++++++++++++++++++++------- tests/quantity/test_storage.py | 4 +- tests/quantity/test_transpose.py | 5 +- tests/quantity/test_view.py | 34 +++++++++++++ tests/test_basic_operations.py | 10 ++++ tests/test_caching_comm.py | 1 + tests/test_cube_scatter_gather.py | 1 + tests/test_halo_update.py | 24 ++-------- tests/test_halo_update_ranks.py | 1 + tests/test_netcdf_monitor.py | 4 ++ tests/test_partitioner.py | 4 +- tests/test_sync_shared_boundary.py | 4 ++ tests/test_tile_scatter.py | 8 ++++ tests/test_tile_scatter_gather.py | 1 + 22 files changed, 206 insertions(+), 54 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index d1c1205f..983a35b2 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -326,6 +326,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): # dims=quantity.dims, units=quantity.units, allow_mismatch_float_precision=True, + backend=quantity.backend, ) if recv_state is not None and name in recv_state: tile_quantity = self.gather( diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index e2cb417f..a95c532e 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -21,6 +21,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int): self._dims = initial.dims self._units = initial.units self._i_time = 1 + self._backend = initial.backend def append(self, quantity: Quantity) -> None: # Allow mismatch precision here since this is I/O @@ -37,6 +38,7 @@ def data(self) -> Quantity: dims=("time",) + tuple(self._dims), units=self._units, allow_mismatch_float_precision=True, + backend=self._backend, ) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index af75242d..438def1b 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -23,11 +23,11 @@ def __init__( dims: Sequence[str], units: str, *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, - backend: str | None = None, ): if gt4py_backend is not None: warnings.warn( @@ -38,6 +38,13 @@ def __init__( if backend is None: backend = gt4py_backend + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( data, dims, diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 4b902a72..300f6376 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -33,12 +33,12 @@ def __init__( dims: Sequence[str], units: str, *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, number_of_halo_points: int = 0, - backend: str | None = None, ): """Initialize a Quantity. @@ -46,6 +46,8 @@ def __init__( data: ndarray-like object containing the underlying data dims: dimension names for each axis units: units of the quantity + backend: GT4Py backend name. We ensure that the data is allocated in a + performance optimal way for that backend and copy if necessary. origin: first point in data within the computational domain. Defaults to None. extent: number of points along each axis @@ -54,8 +56,6 @@ def __init__( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we check that the data is - allocated in a performance optimal way for that backend. Raises: ValueError: Data-type mismatch between configuration and input-data @@ -70,6 +70,13 @@ def __init__( if backend is None: backend = gt4py_backend + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -179,8 +186,9 @@ def from_data_array( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we check that the data is - allocated in a performance optimal way for that backend. + backend: GT4Py backend name. If given, we allocate data in a performance + optimal way for this backend. Overrides any potentially saved `backend` + in `data.attrs["backend"]`. """ if "units" not in data_array.attrs: raise ValueError("need units attribute to create Quantity from DataArray") @@ -201,7 +209,7 @@ def from_data_array( origin=origin, extent=extent, number_of_halo_points=number_of_halo_points, - backend=backend, + backend=_resolve_backend(data_array, backend), ) def to_netcdf( @@ -283,7 +291,7 @@ def backend(self) -> str | None: @property def attrs(self) -> dict: - return dict(**self._attrs, units=self._metadata.units) + return dict(**self._attrs, units=self.units, backend=self.backend) @property def dims(self) -> tuple[str, ...]: @@ -495,3 +503,16 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: f"unexpected type {type(item)}" ) return tuple(return_list) + + +def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: + if backend is not None: + # Forced backend name takes precedence + return backend + + # If backend name was serialized with data, take this one + if "backend" in data.attrs: + return data.attrs["backend"] + + # else, fall back to assume python-based layout. + return "debug" diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 4daa401c..1f7f3836 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -108,7 +108,9 @@ def test_domain_size_comparison( domain: tuple[int], call_count: int, ): - quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent) + quantity = Quantity( + np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug" + ) stencil = FrozenStencil( copy_stencil, origin=(0, 0, 0), @@ -147,7 +149,9 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None: def test_stencil_2D_temporaries() -> None: domain = (2, 2, 5) - quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) stencil = FrozenStencil( two_dim_temporaries_stencil, origin=(0, 0, 0), @@ -164,7 +168,9 @@ def test_stencil_2D_temporaries() -> None: ) def test_validation_call_count(iterations: tuple[int]): domain = (2, 2, 5) - quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) stencil_config = StencilConfig( compilation_config=CompilationConfig(backend="numpy", rebuild=True) ) diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index a4f8e812..7ba2eddb 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -36,6 +36,7 @@ def test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) for side in ( WEST, @@ -71,6 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), + backend="debug", ) for side in ( WEST, @@ -109,6 +111,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), + backend="debug", ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index d6b5c7cb..f0c7bb13 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -14,6 +14,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -31,6 +32,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 @@ -46,6 +48,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index 859bb009..bb6027f5 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -4,7 +4,7 @@ from ndsl import Local -def test_local_descriptor_is_transient() -> None: +def test_dace_data_descriptor_is_transient() -> None: nx = 5 shape = (nx,) local = Local( @@ -19,7 +19,7 @@ def test_local_descriptor_is_transient() -> None: assert array.transient -def test_local_gt4py_backend_is_deprecated() -> None: +def test_gt4py_backend_is_deprecated() -> None: nx = 5 shape = (nx,) backend = "debug" @@ -39,3 +39,16 @@ def test_local_gt4py_backend_is_deprecated() -> None: # make sure we are backwards compatible (for now) with pytest.deprecated_call(match="gt4py_backend is deprecated"): assert local.gt4py_backend == backend + + +def test_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 3286ce0f..ef94b45e 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -62,7 +62,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_smaller_data_raises(data, origin, extent, dims, units): @@ -72,25 +74,55 @@ def test_smaller_data_raises(data, origin, extent, dims, units): except IndexError: pass else: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): Quantity( - small_data, origin=origin, extent=extent, dims=dims, units=units + small_data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend="debug", ) def test_smaller_dims_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent, dims=dims[:-1], units=units) + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): + Quantity( + data, + origin=origin, + extent=extent, + dims=dims[:-1], + units=units, + backend="debug", + ) def test_smaller_origin_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin[:-1], extent=extent, dims=dims, units=units) + with pytest.raises(ValueError, match="received .* origins for .* dimensions: .*"): + Quantity( + data, + origin=origin[:-1], + extent=extent, + dims=dims, + units=units, + backend="debug", + ) def test_smaller_extent_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent[:-1], dims=dims, units=units) + with pytest.raises(ValueError, match="received .* extents for .* dimensions: .*"): + Quantity( + data, + origin=origin, + extent=extent[:-1], + dims=dims, + units=units, + backend="debug", + ) def test_data_change_affects_quantity(data, quantity, numpy): @@ -229,20 +261,15 @@ def test_shift_slice(slice_in, shift, extent, slice_out): @pytest.mark.parametrize( "quantity", [ + Quantity(np.array(5), dims=[], units="", backend="debug"), Quantity( - np.array(5), - dims=[], - units="", - ), - Quantity( - np.array([1, 2, 3]), - dims=["dimension"], - units="degK", + np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug" ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", + backend="debug", ), Quantity( np.random.randn(8, 6, 6), @@ -250,6 +277,7 @@ def test_shift_slice(slice_in, shift, extent, slice_out): units="km", origin=(2, 2, 2), extent=(4, 2, 2), + backend="debug", ), ], ) @@ -265,7 +293,7 @@ def test_to_data_array(quantity): def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug") # After allocation - field and data are the same (origin is 0) assert quantity.data.shape == quantity.field.shape @@ -356,3 +384,16 @@ def test_assign_basic_data_is_deprecated() -> None: # make sure we can still use it (for now) for i in range(5): assert quantity.data[i] == i + + +def test_constructor_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 7fbd1a04..6d8dc4a4 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -53,7 +53,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_numpy(quantity, backend): diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 88653676..b7ceb0f8 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -86,6 +86,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, extent=initial_extent, + backend="debug", ) @@ -218,7 +219,9 @@ def test_transpose_invalid_cases( def test_transpose_retains_attrs(numpy): - quantity = Quantity(numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string") + quantity = Quantity( + numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", backend="debug" + ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) assert transposed.attrs == quantity.attrs diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index 73245093..0ce44cfe 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -183,6 +183,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -216,6 +217,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -742,6 +744,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -754,6 +757,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 0, @@ -766,6 +770,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -778,6 +783,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 1, @@ -798,6 +804,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -816,6 +823,7 @@ def test_southwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -834,6 +842,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -846,6 +855,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -858,6 +868,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -870,6 +881,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -890,6 +902,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -908,6 +921,7 @@ def test_southeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -926,6 +940,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -938,6 +953,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -950,6 +966,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -962,6 +979,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -974,6 +992,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -994,6 +1013,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1012,6 +1032,7 @@ def test_northwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1030,6 +1051,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1042,6 +1064,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 8, @@ -1054,6 +1077,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1066,6 +1090,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1078,6 +1103,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 5, @@ -1098,6 +1124,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1116,6 +1143,7 @@ def test_northeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1134,6 +1162,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -1146,6 +1175,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 0), slice(0, 0)), 4, @@ -1158,6 +1188,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1178,6 +1209,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1198,6 +1230,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), + backend="debug", ), (0,), np.array([6, 7, 8]), @@ -1216,6 +1249,7 @@ def test_interior(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) if len(view_slice) == len(quantity.dims): # skip if not transposed_result = transposed_quantity.view.interior[view_slice[::-1]] diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 0d707240..74916015 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -128,12 +128,14 @@ def test_copy(): data=np.zeros([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.ones([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) copy(f_in=infield.data, f_out=outfield.data) @@ -148,18 +150,21 @@ def test_adjustmentfactor(): data=np.full(shape=[20, 20], fill_value=2.0), dims=[X_DIM, Y_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=4.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) adfact(factor=factorfield.data, f_out=outfield.data) @@ -173,12 +178,14 @@ def test_setvalue(): data=np.zeros(shape=[20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) setvalue(f_out=outfield.data, value=2.0) @@ -193,18 +200,21 @@ def test_adjustdivide(): data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=1.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) addiv(factor=factorfield.data, f_out=outfield.data) diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index d2d7f64e..cafecaf8 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -30,6 +30,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, + backend="debug", ) for _ in range(n_ranks) ] diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 236e1eb9..b71f54e1 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -169,6 +169,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="numpy", ) diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index ca76ab8b..21137c17 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -319,11 +319,7 @@ def depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -356,11 +352,7 @@ def tile_depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -500,11 +492,7 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -521,11 +509,7 @@ def zeros_quantity_tile_list( for _rank in range(single_tile_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 6ceb4886..b50f03d1 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -126,6 +126,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 19614486..a19cf193 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -39,6 +39,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" + backend = "debug" nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) @@ -74,6 +75,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) @@ -87,6 +89,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store(state) @@ -100,6 +103,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 4d20e709..5b8836ac 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -993,7 +993,9 @@ def test_subtile_extent_with_tile_dimensions( cubedsphere_expected, ): data_array = np.zeros((tile_extent)) - quantity = Quantity(data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0]) + quantity = Quantity( + data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0], backend="debug" + ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) cubedsphere_partitioner = CubedSpherePartitioner(tile_partitioner) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 7db5a621..be281194 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -81,6 +81,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = numpy.empty((2, 3), dtype=dtype) y_data[:] = rank @@ -90,6 +91,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -147,6 +149,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -155,6 +158,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index d768bb15..ccbe1cdb 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -36,11 +36,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), } @@ -80,16 +82,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), } @@ -137,6 +142,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -144,6 +150,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -151,6 +158,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 60ef583a..e8bfa861 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -150,6 +150,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="debug", ) From 51cae6fb44becc52239c2ee479b5ed5109968b4a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 13 Nov 2025 12:03:55 +0100 Subject: [PATCH 29/43] refactor: remove DummyComm as alias to LocalComm (#319) There's no need for this alias. We thus replace all occurrences for the alias with the underlying `LocalComm` directly. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- docs/docstrings/testing/dummy_comm.md | 3 --- mkdocs.yml | 1 - ndsl/__init__.py | 2 -- ndsl/testing/dummy_comm.py | 1 - tests/mpi/test_mpi_mock.py | 24 ++++++++++++------------ tests/test_cube_scatter_gather.py | 4 ++-- tests/test_g2g_communication.py | 6 +++--- tests/test_halo_update.py | 6 +++--- tests/test_halo_update_ranks.py | 4 ++-- tests/test_legacy_restart.py | 8 ++++---- tests/test_netcdf_monitor.py | 4 ++-- tests/test_sync_shared_boundary.py | 4 ++-- tests/test_tile_scatter.py | 4 ++-- tests/test_tile_scatter_gather.py | 4 ++-- 14 files changed, 34 insertions(+), 41 deletions(-) delete mode 100644 docs/docstrings/testing/dummy_comm.md delete mode 100644 ndsl/testing/dummy_comm.py diff --git a/docs/docstrings/testing/dummy_comm.md b/docs/docstrings/testing/dummy_comm.md deleted file mode 100644 index 538c01d2..00000000 --- a/docs/docstrings/testing/dummy_comm.md +++ /dev/null @@ -1,3 +0,0 @@ -# dummy_comm - -::: testing.dummy_comm diff --git a/mkdocs.yml b/mkdocs.yml index 4c385bd1..24aa01e0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -103,7 +103,6 @@ nav: - "tridiag": docstrings/stencils/tridiag.md - testing: - "comparison": docstrings/testing/comparison.md - - "dummy_comm": docstrings/testing/dummy_comm.md - "perturbation": docstrings/testing/perturbation.md - viz: - "cube_sphere": docstrings/viz/cube_sphere.md diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 1c445feb..3c2a018c 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -28,7 +28,6 @@ from .performance.report import Experiment, Report, TimeReport from .quantity import Local, Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import -from .testing.dummy_comm import DummyComm from .types import Allocator from .utils import MetaEnumStr @@ -79,7 +78,6 @@ "Quantity", "FieldBundle", "FieldBundleType", - "DummyComm", "Allocator", "MetaEnumStr", "State", diff --git a/ndsl/testing/dummy_comm.py b/ndsl/testing/dummy_comm.py deleted file mode 100644 index f3e93817..00000000 --- a/ndsl/testing/dummy_comm.py +++ /dev/null @@ -1 +0,0 @@ -from ndsl.comm.local_comm import LocalComm as DummyComm # noqa diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 6436af02..852b5da8 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl import DummyComm +from ndsl import LocalComm from ndsl.buffer import recv_buffer from ndsl.comm.local_comm import ConcurrencyError from tests.mpi.mpi_comm import MPI @@ -33,11 +33,11 @@ def send_recv(comm, numpy): data = numpy.asarray([rank], dtype=numpy.int32) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -50,11 +50,11 @@ def send_recv_big_data(comm, numpy): data = numpy.ones([5, 3, 96], dtype=numpy.float64) * rank if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -96,11 +96,11 @@ def send_f_contiguous_buffer(comm, numpy): data = numpy.random.uniform(size=[2, 3]).T if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") comm.Recv(data, source=rank - 1) return data @@ -115,7 +115,7 @@ def send_non_contiguous_buffer(comm, numpy): recv_buffer = numpy.zeros([4, 2, 3]) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: @@ -132,7 +132,7 @@ def send_subarray(comm, numpy): recv_buffer = numpy.zeros([2, 2, 2]) if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data[1:-1, 1:-1, 1:-1], dest=rank + 1) if rank > 0: @@ -151,11 +151,11 @@ def recv_to_subarray(comm, numpy): return_value = recv_buffer if rank < size - 1: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"sending data from {rank} to {rank + 1}") comm.Send(data, dest=rank + 1) if rank > 0: - if isinstance(comm, DummyComm): + if isinstance(comm, LocalComm): print(f"receiving data from {rank - 1} to {rank}") try: comm.Recv(recv_buffer[1:-1, 1:-1, 1:-1], source=rank - 1) @@ -255,7 +255,7 @@ def dummy_list(total_ranks): return_list = [] for rank in range(total_ranks): return_list.append( - DummyComm(rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer) + LocalComm(rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer) ) return return_list diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index b71f54e1..5fe4f5b9 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -6,7 +6,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -100,7 +100,7 @@ def communicator_list(layout): for rank in range(total_ranks): return_list.append( CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), timer=Timer(), ) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 9147a4a8..b7af3e12 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -12,7 +12,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -56,7 +56,7 @@ def cpu_communicators(cube_partitioner): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), force_cpu=True, @@ -74,7 +74,7 @@ def gpu_communicators(cube_partitioner): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index 21137c17..e4ca02df 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -6,8 +6,8 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, HaloUpdater, + LocalComm, Quantity, TileCommunicator, TilePartitioner, @@ -204,7 +204,7 @@ def communicator_list(cube_partitioner: CubedSpherePartitioner): for rank in range(total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, @@ -222,7 +222,7 @@ def tile_communicator_list(tile_partitioner): for rank in range(total_ranks): return_list.append( TileCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=tile_partitioner, diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index b50f03d1..4a101f37 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -3,7 +3,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -103,7 +103,7 @@ def communicator_list(cube_partitioner, total_ranks): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, diff --git a/tests/test_legacy_restart.py b/tests/test_legacy_restart.py index 65514747..3e4ecd03 100644 --- a/tests/test_legacy_restart.py +++ b/tests/test_legacy_restart.py @@ -10,7 +10,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -38,7 +38,7 @@ def get_c12_restart_state_list(layout, only_names, tracer_properties): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) @@ -148,7 +148,7 @@ def test_open_c12_restart_empty_to_state_without_crashing(layout): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) @@ -190,7 +190,7 @@ def test_open_c12_restart_to_allocated_state_without_crashing(layout): communicator_list = [] for rank in range(total_ranks): communicator = CubedSphereCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), CubedSpherePartitioner(TilePartitioner(layout)), ) communicator_list.append(communicator) diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index a19cf193..c4cbe575 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,7 +10,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, NetCDFMonitor, Quantity, TilePartitioner, @@ -54,7 +54,7 @@ def test_monitor_store_multi_rank_state( for rank in range(total_ranks): communicator = CubedSphereCommunicator( partitioner=partitioner, - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), ) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index be281194..27db0048 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -3,7 +3,7 @@ from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - DummyComm, + LocalComm, Quantity, TilePartitioner, ) @@ -56,7 +56,7 @@ def communicator_list(cube_partitioner, total_ranks): for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=cube_partitioner, diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index ccbe1cdb..9ab00eba 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,6 +1,6 @@ import pytest -from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner +from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM @@ -20,7 +20,7 @@ def get_tile_communicator_list(partitioner): for rank in range(total_ranks): tile_communicator_list.append( TileCommunicator( - comm=DummyComm( + comm=LocalComm( rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer ), partitioner=partitioner, diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index e8bfa861..4e87eb65 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -3,7 +3,7 @@ import pytest -from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner +from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import ( HORIZONTAL_DIMS, X_DIM, @@ -84,7 +84,7 @@ def communicator_list(layout): for rank in range(total_ranks): return_list.append( TileCommunicator( - DummyComm(rank, total_ranks, shared_buffer), + LocalComm(rank, total_ranks, shared_buffer), TilePartitioner(layout), ) ) From 9c84e7512a0977ba70383b54d10572df42ba9f56 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 13 Nov 2025 12:04:38 +0100 Subject: [PATCH 30/43] Deprecate `CopyCornersXY` (#317) `CopyCornersXY` are replaced with `CopyCornersX` and `CopyCornersY` in PyFV3. The class is currently unused and will be removed after the next release of NDSL. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/stencils/corners.py | 8 ++++++++ tests/stencils/test_stencils.py | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ndsl/stencils/corners.py b/ndsl/stencils/corners.py index 4e21f8b1..d3fd2c0d 100644 --- a/ndsl/stencils/corners.py +++ b/ndsl/stencils/corners.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Sequence from gt4py.cartesian import gtscript @@ -28,6 +29,13 @@ def __init__( y_field: 3D gt4py storage to use for y-differenceable field (x-differenceable field uses same memory as base field) """ + warnings.warn( + "Usage of CopyCornersXY is deprecated and will be removed in the next release. " + "Use `CopyCornersX` and `CopyCornersY` in PyFV3 for a more future-proof " + "implementation of the corner code.", + DeprecationWarning, + stacklevel=2, + ) grid_indexing = stencil_factory.grid_indexing origin, domain = grid_indexing.get_origin_domain( dims=dims, halos=(grid_indexing.n_halo, grid_indexing.n_halo) diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index ce3fe36f..73802a23 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -1,27 +1,22 @@ import numpy as np +import pytest -from ndsl import StencilFactory +from ndsl import QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.stencils import CopyCornersXY from ndsl.stencils.column_operations import column_max, column_min -nx = 1 -ny = 1 -nz = 10 -nhalo = 0 -backend = "dace:cpu" - -stencil_factory, quantity_factory = get_factories_single_tile( - nx, ny, nz, nhalo, backend -) +@pytest.fixture +def boilerplate() -> tuple[StencilFactory, QuantityFactory]: + return get_factories_single_tile(nx=1, ny=1, nz=10, nhalo=0, backend="dace:cpu") class ColumnOperations: def __init__(self, stencil_factory: StencilFactory): - grid_indexing = stencil_factory.grid_indexing def column_max_stencil( data: FloatField, max_value: FloatFieldIJ, max_index: FloatFieldIJ @@ -60,7 +55,8 @@ def __call__( self._column_min_stencil(data, min_value, min_index) -def test_column_operations(): +def test_column_operations(boilerplate): + stencil_factory, quantity_factory = boilerplate data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") data.field[:] = [ 47.3821, @@ -87,3 +83,10 @@ def test_column_operations(): assert max_index.field[:] == np.argmax(data.field[:], axis=2) assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2) assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2) + + +def test_CopyCornersXY_deprecation(boilerplate) -> None: + stencil_factory, _ = boilerplate + + with pytest.deprecated_call(match="Usage of CopyCornersXY is deprecated"): + CopyCornersXY(stencil_factory, [X_DIM, Y_DIM, Z_DIM], None) From 29bbbad35cc8815f623985691c9a5f97844cefa3 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 13 Nov 2025 16:45:26 +0100 Subject: [PATCH 31/43] refactor: Deprecate `NullComm` in favor of `MPIComm` and `LocalComm` (#318) * unrelated: fix typo in warning message * refactor: change NullComm -> MPIComm in boilerplate This adds a test that the MPI communicator only has one rank if a single-tile setup is requested. * refactor: deprecate NullComm `NullComm` can be replaced with either `LocalComm` or `MPIComm`. --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/boilerplate.py | 11 ++- ndsl/comm/null_comm.py | 7 ++ ndsl/dsl/dace/dace_config.py | 4 +- .../stree/optimizations/refine_transients.py | 4 +- tests/dsl/test_compilation_config.py | 10 +-- tests/grid/test_eta.py | 64 ++--------------- tests/mpi/__init__.py | 6 ++ tests/mpi/mpi_comm.py | 6 -- tests/mpi/test_decomposition.py | 18 +++++ tests/mpi/test_eta.py | 68 +++++++++++++++++++ tests/mpi/test_mpi_all_reduce_sum.py | 2 +- tests/mpi/test_mpi_halo_update.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/test_caching_comm.py | 32 +++++---- tests/test_decomposition.py | 20 +----- tests/test_null_comm.py | 14 ++-- 16 files changed, 155 insertions(+), 115 deletions(-) delete mode 100644 tests/mpi/mpi_comm.py create mode 100644 tests/mpi/test_decomposition.py create mode 100644 tests/mpi/test_eta.py diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index 6d485887..c6fb82f8 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -3,7 +3,7 @@ DaceConfig, DaCeOrchestration, GridIndexing, - NullComm, + MPIComm, QuantityFactory, RunMode, StencilConfig, @@ -54,6 +54,13 @@ def _get_factories( ) if topology == "tile": + mpi_comm = MPIComm() + if mpi_comm.Get_size() != 1: + raise ValueError( + "Single tile topology requested with an MPI communicator of size " + f"{mpi_comm.Get_size()} > 1. Re-configure MPI to run on only one rank." + ) + partitioner = TilePartitioner((1, 1)) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -63,7 +70,7 @@ def _get_factories( layout=partitioner.layout, tile_partitioner=partitioner, ) - comm = TileCommunicator(comm=NullComm(0, 1, 42), partitioner=partitioner) + comm = TileCommunicator(comm=mpi_comm, partitioner=partitioner) else: raise NotImplementedError(f"Topology {topology} is not implemented.") diff --git a/ndsl/comm/null_comm.py b/ndsl/comm/null_comm.py index 2d67c78f..bee499ce 100644 --- a/ndsl/comm/null_comm.py +++ b/ndsl/comm/null_comm.py @@ -1,4 +1,5 @@ import copy +import warnings from collections.abc import Mapping from typing import Any, TypeVar, cast @@ -33,6 +34,12 @@ def __init__(self, rank: int, total_ranks: int, fill_value: T = default_fill_val fill_value: fill halos with this value when performing halo updates. """ + warnings.warn( + "NullComm is deprecated and will be removed with the next version of NDSL. " + "Use MPIComm or LocalComm instead.", + DeprecationWarning, + stacklevel=2, + ) self.rank = rank self.total_ranks = total_ranks self._fill_value = fill_value diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index e9aa9bba..600a12c3 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -8,8 +8,8 @@ from dace.frontend.python.parser import DaceProgram from gt4py.cartesian.config import GT4PY_COMPILE_OPT_LEVEL +from ndsl import LocalComm from ndsl.comm.communicator import Communicator -from ndsl.comm.null_comm import NullComm from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath @@ -180,7 +180,7 @@ def __init__( PerformanceCollector( "InternalOrchestrationTimer", comm=( - communicator.comm if communicator is not None else NullComm(0, 6, 0) + LocalComm(0, 6, {}) if communicator is None else communicator.comm ), ) if time diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 967d6923..49e0917b 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -169,7 +169,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): - Using the dataflow above, we can reduce the dimensions to the correct lowest size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element buffer) - - Current action when detecting a valide candidate is to reduce the size of the dimension + - Current action when detecting a valid candidate is to reduce the size of the dimension to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce the cache access significantly. This also has been implemented to _not_ deal with offset/slicing downstream impact of removing an axis. Nevertheless the xis should be removed if it's not @@ -185,7 +185,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): def __init__(self, backend: str) -> None: warnings.warn( - "CartesianRefineTransients is a WIP. It's usage is *severaly* limited" + "CartesianRefineTransients is a WIP. It's usage is *severely* limited " "and will most likely lead to bad numerics. Check the docs, check utest.", UserWarning, stacklevel=2, diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 326eb222..9f9c165c 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -7,7 +7,7 @@ CompilationConfig, CubedSphereCommunicator, CubedSpherePartitioner, - NullComm, + LocalComm, RunMode, TilePartitioner, ) @@ -34,7 +34,7 @@ def test_check_communicator_valid( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6))))) ) - comm = NullComm(rank=0, total_ranks=size) + comm = LocalComm(rank=0, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -52,7 +52,7 @@ def test_check_communicator_invalid( nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode ) -> None: partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny))) - comm = NullComm(rank=0, total_ranks=nx * ny * 6) + comm = LocalComm(rank=0, total_ranks=nx * ny * 6, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -90,7 +90,7 @@ def test_get_decomposition_info_from_comm( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6)))) ) - comm = NullComm(rank=rank, total_ranks=size) + comm = LocalComm(rank=rank, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig(use_minimal_caching=True, run_mode=RunMode.Run) ( @@ -130,7 +130,7 @@ def test_determine_compiling_equivalent( TilePartitioner((sqrt(size / 6), sqrt(size / 6))) ) comm = unittest.mock.MagicMock() - comm = NullComm(rank=rank, total_ranks=size) + comm = LocalComm(rank=rank, total_ranks=size, buffer_dict={}) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) assert ( config.determine_compiling_equivalent(rank, cubed_sphere_comm.partitioner) diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index d50e2fd9..c4fab0eb 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -1,13 +1,11 @@ from pathlib import Path -import numpy as np import pytest -import xarray as xr from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, - NullComm, + LocalComm, QuantityFactory, SubtileGridSizer, TilePartitioner, @@ -23,58 +21,6 @@ """ -@pytest.mark.parametrize("levels", [79, 91]) -def test_set_hybrid_pressure_coefficients_correct(levels): - """ - This test checks to see that the ak and bk arrays are read-in correctly and are - stored as expected. Both values of km=79 and km=91 are tested and both tests are - expected to pass with the stored ak and bk values agreeing with the values read-in - directly from the NetCDF file. - """ - - eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" - eta_data = xr.open_dataset(eta_file) - - backend = "numpy" - - layout = (1, 1) - - nz = levels - ny = 48 - nx = 48 - nhalo = 3 - - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) - - sizer = SubtileGridSizer.from_tile_params( - nx_tile=nx, - ny_tile=ny, - nz=nz, - n_halo=nhalo, - layout=layout, - tile_partitioner=partitioner.tile, - tile_rank=communicator.tile.rank, - ) - - quantity_factory = QuantityFactory(sizer, backend=backend) - - metric_terms = MetricTerms( - quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file - ) - - ak_results = metric_terms.ak.data - bk_results = metric_terms.bk.data - ak_answers, bk_answers = eta_data["ak"].values, eta_data["bk"].values - - assert ak_answers.size == ak_results.size, "Unexpected size of bk" - assert bk_answers.size == bk_results.size, "Unexpected size of ak" - - assert np.array_equal(ak_answers, ak_results), "Unexpected value of ak" - assert np.array_equal(bk_answers, bk_results), "Unexpected value of bk" - - def test_set_hybrid_pressure_coefficients_nofile(): """ This test checks to see that the program fails when the eta_file is not specified @@ -94,7 +40,9 @@ def test_set_hybrid_pressure_coefficients_nofile(): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) + communicator = CubedSphereCommunicator( + LocalComm(rank=0, total_ranks=6, buffer_dict={}), partitioner + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -137,7 +85,9 @@ def test_set_hybrid_pressure_coefficients_not_mono(): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(NullComm(rank=0, total_ranks=6), partitioner) + communicator = CubedSphereCommunicator( + LocalComm(rank=0, total_ranks=6, buffer_dict={}), partitioner + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, diff --git a/tests/mpi/__init__.py b/tests/mpi/__init__.py index e69de29b..ad92926d 100644 --- a/tests/mpi/__init__.py +++ b/tests/mpi/__init__.py @@ -0,0 +1,6 @@ +from ndsl.comm.mpi import MPI + + +if MPI.COMM_WORLD.Get_size() == 1: + # not run as a parallel test, disable MPI tests + MPI = None diff --git a/tests/mpi/mpi_comm.py b/tests/mpi/mpi_comm.py deleted file mode 100644 index ad92926d..00000000 --- a/tests/mpi/mpi_comm.py +++ /dev/null @@ -1,6 +0,0 @@ -from ndsl.comm.mpi import MPI - - -if MPI.COMM_WORLD.Get_size() == 1: - # not run as a parallel test, disable MPI tests - MPI = None diff --git a/tests/mpi/test_decomposition.py b/tests/mpi/test_decomposition.py new file mode 100644 index 00000000..94fee856 --- /dev/null +++ b/tests/mpi/test_decomposition.py @@ -0,0 +1,18 @@ +from unittest.mock import MagicMock + +import pytest + +from ndsl import MPIComm +from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles +from tests.mpi import MPI + + +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") +def test_unblock_waiting_tiles(): + comm = MPIComm() + compilation_config = MagicMock(compiling_equivalent=0) + + if comm.Get_rank() == 0: + unblock_waiting_tiles(comm._comm) + else: + block_waiting_for_compilation(comm._comm, compilation_config) diff --git a/tests/mpi/test_eta.py b/tests/mpi/test_eta.py new file mode 100644 index 00000000..5c09cbf1 --- /dev/null +++ b/tests/mpi/test_eta.py @@ -0,0 +1,68 @@ +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr + +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + MPIComm, + QuantityFactory, + SubtileGridSizer, + TilePartitioner, +) +from ndsl.grid import MetricTerms +from tests.mpi import MPI + + +@pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") +@pytest.mark.parametrize("levels", [79, 91]) +def test_set_hybrid_pressure_coefficients_correct(levels): + """ + This test checks to see that the ak and bk arrays are read-in correctly and are + stored as expected. Both values of km=79 and km=91 are tested and both tests are + expected to pass with the stored ak and bk values agreeing with the values read-in + directly from the NetCDF file. + """ + + eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" + eta_data = xr.open_dataset(eta_file) + + backend = "numpy" + + layout = (1, 1) + + nz = levels + ny = 48 + nx = 48 + nhalo = 3 + + partitioner = CubedSpherePartitioner(TilePartitioner(layout)) + communicator = CubedSphereCommunicator(MPIComm(), partitioner) + + sizer = SubtileGridSizer.from_tile_params( + nx_tile=nx, + ny_tile=ny, + nz=nz, + n_halo=nhalo, + layout=layout, + tile_partitioner=partitioner.tile, + tile_rank=communicator.tile.rank, + ) + + quantity_factory = QuantityFactory(sizer, backend=backend) + + metric_terms = MetricTerms( + quantity_factory=quantity_factory, communicator=communicator, eta_file=eta_file + ) + + ak_results = metric_terms.ak.data + bk_results = metric_terms.bk.data + ak_answers, bk_answers = eta_data["ak"].values, eta_data["bk"].values + + assert ak_answers.size == ak_results.size, "Unexpected size of bk" + assert bk_answers.size == bk_results.size, "Unexpected size of ak" + + assert np.array_equal(ak_answers, ak_results), "Unexpected value of ak" + assert np.array_equal(bk_answers, bk_results), "Unexpected value of bk" diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 52b02dad..34a3b0dd 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -10,7 +10,7 @@ from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.mpi import MPIComm from ndsl.dsl.typing import Float -from tests.mpi.mpi_comm import MPI +from tests.mpi import MPI @pytest.fixture diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 76aca71e..6558a22c 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -27,7 +27,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from tests.mpi.mpi_comm import MPI +from tests.mpi import MPI @pytest.fixture diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 852b5da8..7007d5d5 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -4,7 +4,7 @@ from ndsl import LocalComm from ndsl.buffer import recv_buffer from ndsl.comm.local_comm import ConcurrencyError -from tests.mpi.mpi_comm import MPI +from tests.mpi import MPI worker_function_list = [] diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index cafecaf8..bdbab4cd 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -8,7 +8,6 @@ CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, - NullComm, Quantity, TilePartitioner, ) @@ -76,29 +75,38 @@ def perform_serial_halo_updates( def test_Recv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 - comm.Recv(recvbuf, source=0) - assert len(comm._data.received_buffers) == 1 - assert comm._data.received_buffers[0].shape == shape + + if comm.Get_rank() == 0: + comm.bcast(np.random.randn(*shape)) + else: + comm.Recv(recvbuf, source=0) + assert len(comm._data.received_buffers) == 1 + assert comm._data.received_buffers[0].shape == shape def test_Irecv_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.received_buffers) == 0 - req = comm.Irecv(recvbuf, source=0) - assert len(comm._data.received_buffers) == 0 - req.wait() - assert len(comm._data.received_buffers) == 1 - assert comm._data.received_buffers[0].shape == shape + + if comm.Get_rank() == 0: + comm.Isend(np.random.randn(*shape), dest=2) + + if comm.Get_rank() == 2: + req = comm.Irecv(recvbuf, source=0) + assert len(comm._data.received_buffers) == 0 + req.wait() + assert len(comm._data.received_buffers) == 1 + assert comm._data.received_buffers[0].shape == shape def test_bcast_inserts_data(): - comm = CachingCommWriter(comm=NullComm(rank=0, total_ranks=6)) + comm = CachingCommWriter(comm=LocalComm(rank=0, total_ranks=6, buffer_dict={})) shape = (12, 12) recvbuf = np.random.randn(*shape) assert len(comm._data.bcast_objects) == 0 diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index 2160e23e..2077a609 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -6,13 +6,10 @@ from ndsl import CubedSpherePartitioner, TilePartitioner from ndsl.comm.decomposition import ( - block_waiting_for_compilation, build_cache_path, check_cached_path_exists, determine_rank_is_compiling, - unblock_waiting_tiles, ) -from tests.mpi.mpi_comm import MPI @pytest.mark.parametrize( @@ -35,7 +32,7 @@ def test_determine_rank_is_compiling( def test_check_cached_path_exists(): with pytest.raises(RuntimeError): - check_cached_path_exists("notarealpath") + check_cached_path_exists("not/a/real/path") def test_check_cached_path_exists_working(): @@ -67,18 +64,3 @@ def test_build_cache_path( ) _, rank_str = build_cache_path(compilation_config) assert rank_str == target_rank_str - - -@pytest.mark.skipif( - MPI is None, - reason="pytest is not run in parallel", -) -def test_unblock_waiting_tiles(): - comm = MPI.COMM_WORLD - compilation_config = unittest.mock.MagicMock(compiling_equivalent=0) - rank = comm.Get_rank() - size = comm.Get_size() - if rank != 0: - block_waiting_for_compilation(comm, compilation_config) - if rank == 0: - unblock_waiting_tiles(comm) diff --git a/tests/test_null_comm.py b/tests/test_null_comm.py index 74065f67..9ed19d9c 100644 --- a/tests/test_null_comm.py +++ b/tests/test_null_comm.py @@ -1,3 +1,5 @@ +import pytest + from ndsl import ( CubedSphereCommunicator, CubedSpherePartitioner, @@ -7,10 +9,8 @@ def test_can_create_cube_communicator(): - rank = 2 - total_ranks = 24 - mpi_comm = NullComm(rank, total_ranks) - layout = (2, 2) - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - communicator = CubedSphereCommunicator(mpi_comm, partitioner) - communicator.tile.partitioner + with pytest.deprecated_call(match="NullComm is deprecated"): + mpi_comm = NullComm(rank=2, total_ranks=24) + partitioner = CubedSpherePartitioner(TilePartitioner(layout=(2, 2))) + communicator = CubedSphereCommunicator(mpi_comm, partitioner) + assert communicator.tile.partitioner From 1098794b37ed9a8204b5ac5afc7cfef717657900 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 13 Nov 2025 21:28:29 +0100 Subject: [PATCH 32/43] build: gt4py update (self-assignment in serial vertical loops) (#316) This updates the gt4py dependency to bring up the fix that allows self-assignment with offset reads in K for serial (e.g. FORWARD/BACKWARD) vertical loops. See https://github.com/GridTools/gt4py/pull/2388 (in particular the test cases for details on what is allowed and what not). Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 7d123536..751e64c1 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 7d1235363594766aa57bc259b94208091137c32e +Subproject commit 751e64c192afaf53720aafa8fb841cd66291ca9f From 18919a820a259b49ff1726021612b3f6f295d546 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 13 Nov 2025 23:12:18 +0100 Subject: [PATCH 33/43] refactor: specify backend when allocating a Quantity (#320) This PR is a follow-up from https://github.com/NOAA-GFDL/NDSL/pull/314 and adds the soon to be required `backend` parameter to constructor calls of `Quantity`. I missed a couple ones because PRs were merged in parallel, e.g. re-enabling the `ZarrMonitor` tests. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- ndsl/quantity/field_bundle.py | 1 + tests/mpi/test_mpi_halo_update.py | 15 +++------------ tests/test_zarr_monitor.py | 27 +++++++++++++++++---------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ndsl/quantity/field_bundle.py b/ndsl/quantity/field_bundle.py index f33ae8a6..c66b338f 100644 --- a/ndsl/quantity/field_bundle.py +++ b/ndsl/quantity/field_bundle.py @@ -95,6 +95,7 @@ def __getattr__(self, name: str) -> Quantity: units=self._quantity.units, origin=self._quantity.origin[:-1], extent=self._quantity.extent[:-1], + backend=self._quantity.backend, ) return self._per_name_view[name] diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 6558a22c..45278cc5 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -271,14 +271,9 @@ def depth_quantity( data[tuple(pos)] = numpy.nan pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan - quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + return Quantity( + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) - return quantity @pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") @@ -320,11 +315,7 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): outside of it.""" data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return quantity diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index be3e7650..b0943ffb 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -98,9 +98,7 @@ def base_state(request, nz, ny, nx, numpy) -> dict: if request.param == "one_var_2d": return { "var1": Quantity( - numpy.ones([ny, nx]), - dims=("y", "x"), - units="m", + numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" ) } @@ -110,20 +108,20 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="m", + backend="debug", ) } if request.param == "two_vars": return { "var1": Quantity( - numpy.ones([ny, nx]), - dims=("y", "x"), - units="m", + numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" ), "var2": Quantity( numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="degK", + backend="debug", ), } @@ -263,6 +261,7 @@ def test_monitor_file_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend="debug", ), } monitor_list[rank].store(state) @@ -351,7 +350,9 @@ def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) - zero_quantity = Quantity(numpy.zeros([10, 10]), dims=("y", "x"), units="m") + zero_quantity = Quantity( + numpy.zeros([10, 10]), dims=("y", "x"), units="m", backend="debug" + ) monitor.store({"var": zero_quantity}) # open w/o dask using chunks=None @@ -371,7 +372,9 @@ def test_values_preserved(cube_partitioner, numpy): # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) - quantity = Quantity(numpy.random.uniform(size=(10, 10)), dims=dims, units=units) + quantity = Quantity( + numpy.random.uniform(size=(10, 10)), dims=dims, units=units, backend="debug" + ) monitor.store({"var": quantity}) # open w/o dask using chunks=None @@ -421,7 +424,10 @@ def test_monitor_file_store_inconsistent_calendars( def diag(request, numpy): dims = request.param return Quantity( - numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m" + numpy.ones([size + 2 for size in range(len(dims))]), + dims=dims, + units="m", + backend="debug", ) @@ -495,6 +501,7 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): numpy.ones([size + 2 for size in range(len(diag.dims))]), dims=new_dims, units="m", + backend="debug", ) with pytest.raises(ValueError) as excinfo: zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) @@ -510,6 +517,6 @@ def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2 = copy.deepcopy(diag) diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) - diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m") + diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m", backend="debug") with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3}) From e8177b7b7b022e30c7b2a5a01dc9674b7681c594 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 14 Nov 2025 07:59:45 -0500 Subject: [PATCH 34/43] [Translate test] Compute the percentage of changing grid points that error (#322) * Add `inputs` to MultiModalFloat metric Compute the percentage of changing grid points that errored * Lint --- ndsl/stencils/testing/test_translate.py | 11 +++++++++++ ndsl/testing/comparison.py | 22 ++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index a350e76f..d8c615e2 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -258,6 +258,11 @@ def test_sequential_savepoint( output_data = gt_utils.asarray(output[varname]) if multimodal_metric: metric = MultiModalFloatMetric( + input_values=( + original_input_data[varname] + if varname in original_input_data.keys() + else None + ), reference_values=ref_data, computed_values=output_data, absolute_eps_override=case.testobj.mmr_absolute_eps, @@ -391,6 +396,7 @@ def test_parallel_savepoint( if not case.exists: pytest.skip(f"Data at rank {case.grid.rank} does not exists") input_data = dataset_to_dict(case.ds_in) + original_input_data = copy.deepcopy(input_data) # run python version of functionality output = case.testobj.compute_parallel(input_data, communicator) out_vars = set(case.testobj.outputs.keys()) @@ -414,6 +420,11 @@ def test_parallel_savepoint( output_data = gt_utils.asarray(output[varname]) if multimodal_metric: metric = MultiModalFloatMetric( + input_values=( + original_input_data[varname] + if varname in original_input_data.keys() + else None + ), reference_values=ref_data[varname][0], computed_values=output_data, absolute_eps_override=case.testobj.mmr_absolute_eps, diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 44b4077b..3891060e 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -224,6 +224,7 @@ class MultiModalFloatMetric(BaseMetric): def __init__( self, + input_values: np.ndarray | None, reference_values: np.ndarray, computed_values: np.ndarray, absolute_eps_override: float = -1, @@ -256,6 +257,13 @@ def __init__( self.check = np.all(self.success) self.sort_report = sort_report + if input_values is not None: + self.number_changing_values = ( + (input_values != reference_values).flatten().shape[0] + ) + else: + self.number_changing_values = None + def _compute_all_metrics( self, ) -> npt.NDArray[np.bool_]: @@ -329,9 +337,19 @@ def report(self, file_path: str | None = None) -> list[str]: # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) full_count = len(self.references.flatten()) - failures_pct = round(100.0 * (bad_indices_count / full_count), 2) + failures_of_all_grid_points_pct = round( + 100.0 * (bad_indices_count / full_count), 2 + ) + if self.number_changing_values is not None: + failures_of_changing_gridpoint_pct = round( + 100.0 * (bad_indices_count / self.number_changing_values), 2 + ) + report_local_failures = f"Failures (changing grid points) ({bad_indices_count}/{self.number_changing_values}) ({failures_of_changing_gridpoint_pct}%)\n" + else: + report_local_failures = "" report = [ - f"All failures ({bad_indices_count}/{full_count}) ({failures_pct}%),\n", + f"{report_local_failures}" + f"Failures (all grid points) ({bad_indices_count}/{full_count}) ({failures_of_all_grid_points_pct}%)\n", f"Index Computed Reference " f"{'🔶 ' if not self.absolute_eps.is_default else ''}Absolute E(<{self.absolute_eps.value:.2e}) " f"{'🔶 ' if not self.relative_fraction.is_default else ''}Relative E(<{self.relative_fraction.value * 100:.2e}%) " From da1d2e7e031f601027e0fcc2c562ed60f8168dcf Mon Sep 17 00:00:00 2001 From: Janice Kim Date: Fri, 14 Nov 2025 13:12:18 -0500 Subject: [PATCH 35/43] Removing --no_legacy_namelist flag (#323) --- ndsl/stencils/testing/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 2ed29e37..1e5f41f7 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -75,12 +75,6 @@ def pytest_addoption(parser: pytest.Parser) -> None: default=1, help="How many indices of failures to print from worst to best. Default to 1.", ) - parser.addoption( - "--no_legacy_namelist", - action="store_true", - default=False, - help="Temporary flag introduced as part of NDSL issue #64. No functionality. Soon to be removed.", - ) parser.addoption( "--grid", action="store", From b4d109e2c07311bb207996489849ca073e6a1766 Mon Sep 17 00:00:00 2001 From: Charles Kropiewnicki <79879064+CharlesKrop@users.noreply.github.com> Date: Fri, 14 Nov 2025 13:51:51 -0500 Subject: [PATCH 36/43] Added new functions: column_min_ddim & column_max_ddim and cooresponding test (#324) Functionality is the same as column_min/max, but separate functions are needed to handle cases with off grid data dimensions --- ndsl/stencils/column_operations.py | 50 +++++++++++++++++ tests/stencils/test_stencils.py | 90 ++++++++++++++++++++++++++++-- 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/ndsl/stencils/column_operations.py b/ndsl/stencils/column_operations.py index 21838ea2..25ede5de 100644 --- a/ndsl/stencils/column_operations.py +++ b/ndsl/stencils/column_operations.py @@ -28,6 +28,31 @@ def column_max(field, start_index, end_index): return field.at(K=max_index), max_index +@typing.no_type_check +@function +def column_max_ddim(field, ddim, start_index, end_index): + """ + Find the maximum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [max value, index of max value] + """ + max_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=max_index, ddim=[ddim]) + if new > old: + max_index = level + level += 1 + + return field.at(K=max_index, ddim=[ddim]), max_index + + @typing.no_type_check @function def column_min(field, start_index, end_index): @@ -51,3 +76,28 @@ def column_min(field, start_index, end_index): level += 1 return field.at(K=min_index), min_index + + +@typing.no_type_check +@function +def column_min_ddim(field, ddim, start_index, end_index): + """ + Find the minimum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [min value, index of min value] + """ + min_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=min_index, ddim=[ddim]) + if new < old: + min_index = level + level += 1 + + return field.at(K=min_index, ddim=[ddim]), min_index diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 73802a23..48d2232a 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -5,9 +5,17 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, computation, interval -from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size from ndsl.stencils import CopyCornersXY -from ndsl.stencils.column_operations import column_max, column_min +from ndsl.stencils.column_operations import ( + column_max, + column_max_ddim, + column_min, + column_min_ddim, +) + + +FloatField_ddim = set_4d_field_size(2, Float) @pytest.fixture @@ -26,6 +34,14 @@ def column_max_stencil( with computation(FORWARD), interval(0, 1): max_value, max_index = column_max(data, 0, k_end) + def column_max_ddim_stencil( + data: FloatField_ddim, max_value: FloatFieldIJ, max_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + max_value, max_index = column_max_ddim(data, 1, 0, k_end) + def column_min_stencil( data: FloatField, min_value: FloatFieldIJ, min_index: FloatFieldIJ ): @@ -34,14 +50,30 @@ def column_min_stencil( with computation(FORWARD), interval(0, 1): min_value, min_index = column_min(data, 5, k_end) + def column_min_ddim_stencil( + data: FloatField_ddim, min_value: FloatFieldIJ, min_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + min_value, min_index = column_min_ddim(data, 1, 5, k_end) + self._column_max_stencil = stencil_factory.from_dims_halo( func=column_max_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) + self._column_max_ddim_stencil = stencil_factory.from_dims_halo( + func=column_max_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) self._column_min_stencil = stencil_factory.from_dims_halo( func=column_min_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) + self._column_min_ddim_stencil = stencil_factory.from_dims_halo( + func=column_min_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) def __call__( self, @@ -50,13 +82,21 @@ def __call__( max_index: FloatFieldIJ, min_value: FloatFieldIJ, min_index: FloatFieldIJ, + data_ddim: FloatField_ddim, + max_value_ddim: FloatFieldIJ, + max_index_ddim: FloatFieldIJ, + min_value_ddim: FloatFieldIJ, + min_index_ddim: FloatFieldIJ, ): self._column_max_stencil(data, max_value, max_index) + self._column_max_ddim_stencil(data_ddim, max_value_ddim, max_index_ddim) self._column_min_stencil(data, min_value, min_index) + self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim) def test_column_operations(boilerplate): stencil_factory, quantity_factory = boilerplate + quantity_factory.add_data_dimensions({"ddim": 2}) data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") data.field[:] = [ 47.3821, @@ -71,19 +111,61 @@ def test_column_operations(boilerplate): 7.3504, ] + data_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, "ddim"], "n/a") + data_ddim.field[:] = [ + [ + [ + [47.3821, 27.4825], + [2.9157, 93.1242], + [88.6034, 14.6347], + [71.9275, 58.2094], + [53.1412, 6.2369], + [19.4783, 71.5457], + [94.2258, 42.3091], + [36.8099, 89.7718], + [64.0175, 63.1910], + [7.3504, 3.4991], + ] + ] + ] + max_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") max_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") min_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") code = ColumnOperations(stencil_factory) - code(data, max_value, max_index, min_value, min_index) - + code( + data, + max_value, + max_index, + min_value, + min_index, + data_ddim, + max_value_ddim, + max_index_ddim, + min_value_ddim, + min_index_ddim, + ) + + # 3d field tests assert max_value.field[:] == np.max(data.field[:], axis=2) assert max_index.field[:] == np.argmax(data.field[:], axis=2) assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2) assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2) + # 4d field tests + assert max_value_ddim.field[:] == np.max(data_ddim.field[:, :, :, 1], axis=2) + assert max_index_ddim.field[:] == np.argmax(data_ddim.field[:, :, :, 1], axis=2) + assert min_value_ddim.field[:] == np.min(data_ddim.field[:, :, 5:, 1], axis=2) + assert min_index_ddim.field[:] == 5 + np.argmin( + data_ddim.field[:, :, 5:, 1], axis=2 + ) + def test_CopyCornersXY_deprecation(boilerplate) -> None: stencil_factory, _ = boilerplate From 3812e02120f02a93d0321735b5837eab46a4bb2b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 19 Nov 2025 08:48:56 -0500 Subject: [PATCH 37/43] [Optimization/Experimental] Better `AxisMerge` for column physics (#325) * Add `CleanUpScheduleTree` pass to prep for merge * Decluter axis merge logs, expose new pass * Verbose Pipeline passes (with temporary stree saves) * Deactivaete IF_SCOPE push, remove attempt to keep merging if next nodes not a MapScope * Docs of TODO * Draft of more extended testing * Fix `CartesianRefineTransients` for non-array * Some lint * Clean up the Tree of ForScope.loop_range * Utest: group test under a single orchestrated class, add missing feature and expected failures --- ndsl/dsl/dace/orchestration.py | 3 + ndsl/dsl/dace/stree/optimizations/__init__.py | 8 +- .../dace/stree/optimizations/axis_merge.py | 40 ++-- .../dace/stree/optimizations/clean_tree.py | 69 +++++++ .../stree/optimizations/refine_transients.py | 8 +- ndsl/dsl/dace/stree/pipeline.py | 9 +- tests/stree_optimizer/test_optimization.py | 190 ++++++++++++++---- 7 files changed, 267 insertions(+), 60 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/clean_tree.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index a5c278ac..049a1095 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -39,6 +39,7 @@ AxisIterator, CartesianAxisMerge, CartesianRefineTransients, + CleanUpScheduleTree, ) from ndsl.dsl.dace.utils import ( DaCeProgress, @@ -174,6 +175,7 @@ def _build_sdfg( if config.get_backend() == "dace:cpu_kfirst": passes.extend( [ + CleanUpScheduleTree(), CartesianAxisMerge(AxisIterator._I), CartesianAxisMerge(AxisIterator._J), CartesianAxisMerge(AxisIterator._K), @@ -183,6 +185,7 @@ def _build_sdfg( else: passes.extend( [ + CleanUpScheduleTree(), CartesianAxisMerge(AxisIterator._K), CartesianAxisMerge(AxisIterator._I), CartesianAxisMerge(AxisIterator._J), diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 8e371ee9..73497f93 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,5 +1,11 @@ from .axis_merge import AxisIterator, CartesianAxisMerge +from .clean_tree import CleanUpScheduleTree from .refine_transients import CartesianRefineTransients -__all__ = ["AxisIterator", "CartesianAxisMerge", "CartesianRefineTransients"] +__all__ = [ + "AxisIterator", + "CartesianAxisMerge", + "CartesianRefineTransients", + "CleanUpScheduleTree", +] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 74fe9e02..1ee2ff70 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -20,6 +20,10 @@ ) +# Buggy passes that should work +PUSH_IFSCOPE_DOWNWARD = False + + def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a map over the given axis.""" map_parameter = node.node.params @@ -171,7 +175,7 @@ def __init__( self.eager = eager def __str__(self) -> str: - return f"CartesianAxisMerge({self.axis.name})" + return f"CartesianAxisMerge_{self.axis.name}" def _merge_node( self, @@ -187,7 +191,7 @@ def _merge_node( if isinstance(node, stree.MapScope): return self._map_overcompute_merge(node, nodes) - elif isinstance(node, stree.IfScope): + elif PUSH_IFSCOPE_DOWNWARD and isinstance(node, stree.IfScope): return self._push_ifelse_down(node, nodes) elif isinstance(node, stree.TaskletNode): return self._push_tasklet_down(node, nodes) @@ -231,7 +235,6 @@ def _push_tasklet_down( merged = self._merge_node(next_node, nodes) # Attempt to push the tasklet in the next map - ndsl_log.debug(" Push tasklet down into next map") next_node = nodes[next_index + 1] if isinstance(next_node, stree.MapScope): next_node.children.insert(0, the_tasklet) @@ -292,7 +295,6 @@ def _push_ifelse_down( return merged # We are good to go - swap it all - ndsl_log.debug(f" Push IF {the_if.condition.as_string} down") inner_if_map = the_if.children[0] # Swap IF & maps @@ -322,7 +324,10 @@ def _map_overcompute_merge( the_map: stree.MapScope, nodes: list[stree.ScheduleTreeNode], ) -> int: - if _last_node(nodes, the_map): + # End of nodes OR + # Not the right axis + # --> recurse + if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) @@ -330,13 +335,9 @@ def _map_overcompute_merge( next_node = _get_next_node(nodes, the_map) - # If the next node is not a MapScope - recurse + # Next node is not a MapScope - no merge if not isinstance(next_node, stree.MapScope): - merged = self._merge_node(next_node, nodes) - new_next_node = _get_next_node(nodes, the_map) - if new_next_node == next_node: - return merged - return merged + self._merge_node(the_map, nodes) + return 0 # Attempt to merge consecutive maps if not _can_merge_axis_maps(the_map, next_node, self.axis): @@ -357,10 +358,6 @@ def _map_overcompute_merge( ] ) - ndsl_log.debug( - f" Merge {self.axis.name} map: {first_range} ⋃ {second_range} -> {merged_range}" - ) - # push IfScope down if children are just maps axis_as_str = the_map.node.params[0] first_map = InsertOvercomputationGuard( @@ -421,11 +418,17 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # in the tasklet... # NormalizeAxisSymbol(self.axis).visit(node) + # TODO: we are meging single axis, we could prefix those runs by moving + # if scope down inside the map if it has the proper axis, preparing + # for a better merging scope. If we can't nerge, we can revert this + # orep step + overall_merged = 0 + passes_apply = 0 i = 0 while True: i += 1 - ndsl_log.debug(f"🔥 Merge attempt #{i}") + # ndsl_log.debug(f"🔥 Merge attempt #{i}") previous_children = copy.deepcopy(node.children) try: merged = self._merge(node) @@ -438,10 +441,11 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # If we didn't merge, we revert the children # to the previous state if merged == 0: - ndsl_log.debug("🥹 No merges, revert!") + # ndsl_log.debug("🥹 No merges, revert!") node.children = previous_children break + passes_apply += 1 ndsl_log.debug( - f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged" + f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged in {passes_apply} passes" ) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py new file mode 100644 index 00000000..41056ec4 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl import ndsl_log + + +class CleanUpScheduleTree(stree.ScheduleNodeTransformer): + """Clean up unused nodes, or nodes barrying further optimizations.""" + + def __init__(self) -> None: + self.cleaned_state_boundaries = 0 + + def __str__(self) -> str: + return "CleanUpScheduleTree" + + def _remove_state_boundaries_from_my_childs( + self, node: stree.ScheduleTreeScope + ) -> None: + to_remove = [ + child + for child in node.children + if isinstance(child, stree.StateBoundaryNode) + ] + for to_remove_child in to_remove: + self.cleaned_state_boundaries += 1 + node.children.remove(to_remove_child) + + def visit_WhileScope(self, node: stree.WhileScope) -> stree.WhileScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_ForScope(self, node: stree.ForScope) -> stree.ForScope: + self._remove_state_boundaries_from_my_childs(node) + + # We might have inherited a proper `loop_range` from the SDFG + # but the data (sdfg) it relies on is no longer valid. + node.header.loop_range = lambda: None + + for child in node.children: + self.visit(child) + + return node + + def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_IfScope(self, node: stree.IfScope) -> stree.IfScope: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + return node + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + self._remove_state_boundaries_from_my_childs(node) + for child in node.children: + self.visit(child) + + ndsl_log.debug( + f"Clean up StateBoundary : {self.cleaned_state_boundaries} nodes" + ) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 49e0917b..f0174e17 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -115,7 +115,8 @@ def visit_MapScope(self, node: stree.MapScope) -> None: def visit_TaskletNode(self, node: stree.TaskletNode) -> None: for memlet in [*node.input_memlets(), *node.output_memlets()]: - if self.containers[memlet.data].transient: + data = self.containers[memlet.data] + if data.transient and isinstance(data, dace.data.Array): for map_entry in self._cartesian_current_map_nesting: if map_entry is not None: self.transient_map_access[memlet.data].add(map_entry) @@ -123,7 +124,7 @@ def visit_TaskletNode(self, node: stree.TaskletNode) -> None: def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.containers = node.containers for name, data in self.containers.items(): - if data.transient: + if data.transient and isinstance(data, dace.data.Array): self.transient_map_access[name] = set() for child in node.children: @@ -163,6 +164,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): It can do: - Looking at usage of a transient in a cartesian axis (e.g. loop over a cartesian axis) it will reduce that axis to 1 if it exists in _only one_. + It should but cannot do/will bug if: - Dataflow analysis on the axis to prevent reducing an axis to one where the transient is used with offset, leading to faulty numerics @@ -211,7 +213,7 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: # Remove Axis refined_transient = 0 for name, data in node.containers.items(): - if not data.transient: + if not (data.transient and isinstance(data, dace.data.Array)): continue refined = _reduce_cartesian_axes_size_to_1( collect_map.transient_map_access[name], diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 10fb77cd..702af7fe 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -3,6 +3,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.logging import ndsl_log_on_rank_0 class StreePipeline(ABC): @@ -42,10 +43,14 @@ def run( stree: stree.ScheduleTreeRoot, verbose: bool = False, ) -> stree.ScheduleTreeRoot: - for p in self.passes: + for i, p in enumerate(self.passes): if verbose: - print(f"[Stree OPT] {p}") + path = f"pass{i}_{p}.txt" + ndsl_log_on_rank_0.info(f"[Stree OPT] {p} (saving {path} after)") p.visit(stree) + if verbose: + with open(path, "w") as f: + f.write(stree.as_string()) return stree diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_optimization.py index fd08a1cb..9cf6ecd6 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_optimization.py @@ -6,10 +6,24 @@ from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField +def _get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: + """Get the Precompiled SDFG from the dace config dict where they are cached post + compilation and flush the cache in order for next build to re-use the function.""" + sdfg_repo = stencil_factory.config.dace_config.loaded_precompiled_SDFG + + if len(sdfg_repo.values()) != 1: + raise RuntimeError("Failure to compile SDFG") + sdfg = list(sdfg_repo.values())[0] + + sdfg_repo.clear() + + return sdfg + + class StreeOptimization: def __init__(self) -> None: pass @@ -26,61 +40,171 @@ def __exit__( orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False -def stencil_A(in_field: FloatField, out_field: FloatField) -> None: +def copy_stencil(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): - out_field = in_field + out_field = in_field + 1 -def stencil_B(in_field: FloatField, out_field: FloatField) -> None: +def copy_stencil_with_self_assign(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): - out_field = out_field + in_field * 3 + out_field = out_field + in_field + 2 -class TriviallyMergeableCode: - def __init__(self, stencil_factory: StencilFactory) -> None: - orchestrate(obj=self, config=stencil_factory.config.dace_config) - self.stencil_A = stencil_factory.from_dims_halo( - func=stencil_A, +def copy_stencil_with_forward_K(in_field: FloatField, out_field: FloatField) -> None: + with computation(FORWARD), interval(...): + out_field = in_field + 3 + + +def copy_stencil_with_different_intervals( + in_field: FloatField, + out_field: FloatField, +) -> None: + with computation(PARALLEL), interval(1, None): + out_field = in_field + 5 + + +def copy_stencil_with_buffer_read_offset_in_K( + in_field: FloatField, out_field: FloatField, buffer: FloatField +) -> None: + with computation(PARALLEL), interval(1, None): + buffer = in_field + 6 + + with computation(PARALLEL), interval(1, None): + out_field = buffer[K - 1] + 7 + + +class OrchestratedCode: + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ) -> None: + orchestratable_methods = [ + "trivial_merge", + "missing_merge_of_forscope_and_map", + "overcompute_merge", + "block_merge_when_depandencies_is_found", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self.copy_stencil = stencil_factory.from_dims_halo( + func=copy_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.copy_stencil_with_forward_K = stencil_factory.from_dims_halo( + func=copy_stencil_with_forward_K, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.copy_stencil_with_buffer_read_offset_in_K = stencil_factory.from_dims_halo( + func=copy_stencil_with_buffer_read_offset_in_K, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - self.stencil_B = stencil_factory.from_dims_halo( - func=stencil_B, + self.copy_stencil_with_different_intervals = stencil_factory.from_dims_halo( + func=copy_stencil_with_different_intervals, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - def __call__(self, in_field: FloatField, out_field: FloatField) -> None: - self.stencil_A(in_field, out_field) - self.stencil_B(in_field, out_field) + self._buffer = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], units="") + + def trivial_merge( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.copy_stencil(in_field, out_field) + self.copy_stencil(in_field, out_field) + + def missing_merge_of_forscope_and_map( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.copy_stencil(in_field, out_field) + self.copy_stencil_with_forward_K(in_field, out_field) + self.copy_stencil(in_field, out_field) + + def block_merge_when_depandencies_is_found( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.copy_stencil(in_field, out_field) + self.copy_stencil_with_buffer_read_offset_in_K( + in_field, out_field, self._buffer + ) + + def overcompute_merge( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.copy_stencil(in_field, out_field) + self.copy_stencil_with_different_intervals(in_field, out_field) def test_stree_merge_maps() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu" + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" ) - code = TriviallyMergeableCode(stencil_factory) + code = OrchestratedCode(stencil_factory, quantity_factory) in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") with StreeOptimization(): - code(in_qty, out_qty) + # Trivial merge + code.trivial_merge(in_qty, out_qty) + precompiled_sdfg = _get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] - assert ( - len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) - == 1 - ) - sdfg = list( - stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() - )[0] + assert len(all_maps) == 3 + assert (out_qty.field[:] == 2).all() + + # Merge IJ - but do not merge K map & for (missing feature) + code.missing_merge_of_forscope_and_map(in_qty, out_qty) + sdfg = _get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) - for me, state in sdfg.sdfg.all_nodes_recursive() + for me, state in sdfg.all_nodes_recursive() if isinstance(me, dace.nodes.MapEntry) ] + assert len(all_maps) == 4 # 2 IJ + 2 Ks + all_loop_guard_state = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.SDFGState) and me.name.startswith("loop_guard") + ] + assert len(all_loop_guard_state) == 1 # 1 For loop + # Overcompute merge in K - we merge and introduce an If guard + code.overcompute_merge(in_qty, out_qty) + sdfg = _get_SDFG_and_purge(stencil_factory).sdfg + all_maps = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] assert len(all_maps) == 3 - assert (out_qty.field[:] == 4).all() + + # Forbid merging when data dependancy is detected + code.block_merge_when_depandencies_is_found(in_qty, out_qty) + sdfg = _get_SDFG_and_purge(stencil_factory).sdfg + all_maps = [ + (me, state) + for me, state in sdfg.all_nodes_recursive() + if isinstance(me, dace.nodes.MapEntry) + ] + assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) class LocalRefineableCode(NDSLRuntime): @@ -89,7 +213,7 @@ def __init__( ) -> None: super().__init__(stencil_factory.config.dace_config) self.stencil_A = stencil_factory.from_dims_halo( - func=stencil_A, + func=copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) @@ -112,14 +236,8 @@ def test_stree_roundtrip_transient_is_refined() -> None: with StreeOptimization(): code(in_qty, out_qty) - assert ( - len(stencil_factory.config.dace_config.loaded_precompiled_SDFG.values()) - == 1 - ) - sdfg = list( - stencil_factory.config.dace_config.loaded_precompiled_SDFG.values() - )[0] + precompiled_sdfg = _get_SDFG_and_purge(stencil_factory) - for array in sdfg.sdfg.arrays.values(): + for array in precompiled_sdfg.sdfg.arrays.values(): if array.transient: assert array.shape == (1, 1, 1) From 7b80a19424d3d32f190e1bd3a8ee8950b4ac2ba2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 20 Nov 2025 12:27:43 -0500 Subject: [PATCH 38/43] [Feature/Experimental] Stree Refine Transient optimization pass: data dimensions and proper unit tests (#327) * Rename test for axis merge * Properly refine fields with data dimensions Fix indexing in memlets properly * utest: coverage of all implemented tests * Clean up timing print of orchestration * Lint * Fix bad reference to in/out memlets, remopve dead code, better code * Share test infrastructure, rename stencils * Lint * Better naming in utest stencils --- .../stree/optimizations/refine_transients.py | 51 +++++-- ndsl/dsl/dace/utils.py | 2 +- tests/stree_optimizer/__init__.py | 0 tests/stree_optimizer/sdfg_stree_tools.py | 36 +++++ .../{test_optimization.py => test_merge.py} | 125 ++++----------- .../stree_optimizer/test_transient_refine.py | 144 ++++++++++++++++++ 6 files changed, 248 insertions(+), 110 deletions(-) create mode 100644 tests/stree_optimizer/__init__.py create mode 100644 tests/stree_optimizer/sdfg_stree_tools.py rename tests/stree_optimizer/{test_optimization.py => test_merge.py} (52%) create mode 100644 tests/stree_optimizer/test_transient_refine.py diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index f0174e17..342d676d 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -51,7 +51,25 @@ def _reduce_cartesian_axes_size_to_1( axis.as_cartesian_index(), value=1, ) - transient_data.set_strides_from_layout(*ijk_order) + + # Assume 3D cartesian! + if len(transient_data.shape) < 3: + warnings.warn( + f"Potential non-3D array: {transient_data}, skipping.", + UserWarning, + stacklevel=2, + ) + return refined + if len(transient_data.shape) == 3: + layout = [*ijk_order] + else: + data_dim_count = len(transient_data.shape) - 3 + layout = [dim + data_dim_count for dim in ijk_order] + [ + i - 1 for i in range(data_dim_count, 0, -1) + ] + + transient_data.set_strides_from_layout(*layout) + transient_data.lifetime = dace.dtypes.AllocationLifetime.State refined = True return refined @@ -138,17 +156,24 @@ def __str__(self) -> str: return "RefineTransientAxis" def visit_TaskletNode(self, node: stree.TaskletNode) -> None: - for name, memlet in node.in_memlets.items(): - if self.containers[memlet.data].transient: - node.in_memlets[name] = memlet.from_array( - memlet.data, self.containers[memlet.data] - ) - - for name, memlet in node.out_memlets.items(): - if self.containers[memlet.data].transient: - node.out_memlets[name] = memlet.from_array( - memlet.data, self.containers[memlet.data] - ) + for memlet in [*node.output_memlets(), *node.input_memlets()]: + array = self.containers[memlet.data] + if array.transient: + replace_cartesian_access = {} + if len(array.shape) >= 1 and array.shape[0] == 1: + replace_cartesian_access[AxisIterator._I.as_str()] = 0 + if len(array.shape) >= 2 and array.shape[1] == 1: + replace_cartesian_access[AxisIterator._J.as_str()] = 0 + if len(array.shape) >= 3 and array.shape[2] == 1: + # Workaround because the iterator can be `__k_0` instead of `__k` + axis = None + for axis_symbol in memlet.free_symbols: + if axis_symbol.startswith(AxisIterator._K.as_str()): + axis = axis_symbol + break + if axis: + replace_cartesian_access[axis] = 0 + memlet.replace(replace_cartesian_access) def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.containers = node.containers @@ -176,6 +201,8 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): the cache access significantly. This also has been implemented to _not_ deal with offset/slicing downstream impact of removing an axis. Nevertheless the xis should be removed if it's not used. + - It only knows how to deal with 3D cartesian and 3D cartesian + data dimensions. Anything else will + fail `_reduce_cartesian_axes_size_to_1` calculation More tests: - Test for dataflow with offset diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index cb35503e..ae6b1d80 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -32,7 +32,7 @@ def __enter__(self) -> None: def __exit__(self, _type, _val, _traceback) -> None: # type: ignore elapsed = time.time() - self.start - ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed}s.") + ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed:.2f}s.") def _is_ref(sd: dace.sdfg.SDFG, aname: str) -> bool: diff --git a/tests/stree_optimizer/__init__.py b/tests/stree_optimizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stree_optimizer/sdfg_stree_tools.py b/tests/stree_optimizer/sdfg_stree_tools.py new file mode 100644 index 00000000..d5a07154 --- /dev/null +++ b/tests/stree_optimizer/sdfg_stree_tools.py @@ -0,0 +1,36 @@ +from types import TracebackType + +import dace + +import ndsl.dsl.dace.orchestration as orch +from ndsl import StencilFactory + + +def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: + """Get the Precompiled SDFG from the dace config dict where they are cached post + compilation and flush the cache in order for next build to re-use the function.""" + sdfg_repo = stencil_factory.config.dace_config.loaded_precompiled_SDFG + + if len(sdfg_repo.values()) != 1: + raise RuntimeError("Failure to compile SDFG") + sdfg = list(sdfg_repo.values())[0] + + sdfg_repo.clear() + + return sdfg + + +class StreeOptimization: + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False diff --git a/tests/stree_optimizer/test_optimization.py b/tests/stree_optimizer/test_merge.py similarity index 52% rename from tests/stree_optimizer/test_optimization.py rename to tests/stree_optimizer/test_merge.py index 9cf6ecd6..8fff866d 100644 --- a/tests/stree_optimizer/test_optimization.py +++ b/tests/stree_optimizer/test_merge.py @@ -1,61 +1,30 @@ -from types import TracebackType - import dace -import ndsl.dsl.dace.orchestration as orch -from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField +from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge -def _get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: - """Get the Precompiled SDFG from the dace config dict where they are cached post - compilation and flush the cache in order for next build to re-use the function.""" - sdfg_repo = stencil_factory.config.dace_config.loaded_precompiled_SDFG - - if len(sdfg_repo.values()) != 1: - raise RuntimeError("Failure to compile SDFG") - sdfg = list(sdfg_repo.values())[0] - - sdfg_repo.clear() - - return sdfg - - -class StreeOptimization: - def __init__(self) -> None: - pass - - def __enter__(self) -> None: - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False - -def copy_stencil(in_field: FloatField, out_field: FloatField) -> None: +def stencil(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = in_field + 1 -def copy_stencil_with_self_assign(in_field: FloatField, out_field: FloatField) -> None: +def stencil_with_self_assign(in_field: FloatField, out_field: FloatField) -> None: with computation(PARALLEL), interval(...): out_field = out_field + in_field + 2 -def copy_stencil_with_forward_K(in_field: FloatField, out_field: FloatField) -> None: +def stencil_with_forward_K(in_field: FloatField, out_field: FloatField) -> None: with computation(FORWARD), interval(...): out_field = in_field + 3 -def copy_stencil_with_different_intervals( +def stencil_with_different_intervals( in_field: FloatField, out_field: FloatField, ) -> None: @@ -63,7 +32,7 @@ def copy_stencil_with_different_intervals( out_field = in_field + 5 -def copy_stencil_with_buffer_read_offset_in_K( +def stencil_with_buffer_read_offset_in_K( in_field: FloatField, out_field: FloatField, buffer: FloatField ) -> None: with computation(PARALLEL), interval(1, None): @@ -92,20 +61,20 @@ def __init__( method_to_orchestrate=method, ) - self.copy_stencil = stencil_factory.from_dims_halo( - func=copy_stencil, + self.stencil = stencil_factory.from_dims_halo( + func=stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - self.copy_stencil_with_forward_K = stencil_factory.from_dims_halo( - func=copy_stencil_with_forward_K, + self.stencil_with_forward_K = stencil_factory.from_dims_halo( + func=stencil_with_forward_K, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - self.copy_stencil_with_buffer_read_offset_in_K = stencil_factory.from_dims_halo( - func=copy_stencil_with_buffer_read_offset_in_K, + self.stencil_with_buffer_read_offset_in_K = stencil_factory.from_dims_halo( + func=stencil_with_buffer_read_offset_in_K, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - self.copy_stencil_with_different_intervals = stencil_factory.from_dims_halo( - func=copy_stencil_with_different_intervals, + self.stencil_with_different_intervals = stencil_factory.from_dims_halo( + func=stencil_with_different_intervals, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) @@ -116,35 +85,33 @@ def trivial_merge( in_field: FloatField, out_field: FloatField, ) -> None: - self.copy_stencil(in_field, out_field) - self.copy_stencil(in_field, out_field) + self.stencil(in_field, out_field) + self.stencil(in_field, out_field) def missing_merge_of_forscope_and_map( self, in_field: FloatField, out_field: FloatField, ) -> None: - self.copy_stencil(in_field, out_field) - self.copy_stencil_with_forward_K(in_field, out_field) - self.copy_stencil(in_field, out_field) + self.stencil(in_field, out_field) + self.stencil_with_forward_K(in_field, out_field) + self.stencil(in_field, out_field) def block_merge_when_depandencies_is_found( self, in_field: FloatField, out_field: FloatField, ) -> None: - self.copy_stencil(in_field, out_field) - self.copy_stencil_with_buffer_read_offset_in_K( - in_field, out_field, self._buffer - ) + self.stencil(in_field, out_field) + self.stencil_with_buffer_read_offset_in_K(in_field, out_field, self._buffer) def overcompute_merge( self, in_field: FloatField, out_field: FloatField, ) -> None: - self.copy_stencil(in_field, out_field) - self.copy_stencil_with_different_intervals(in_field, out_field) + self.stencil(in_field, out_field) + self.stencil_with_different_intervals(in_field, out_field) def test_stree_merge_maps() -> None: @@ -160,7 +127,7 @@ def test_stree_merge_maps() -> None: with StreeOptimization(): # Trivial merge code.trivial_merge(in_qty, out_qty) - precompiled_sdfg = _get_SDFG_and_purge(stencil_factory) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() @@ -172,7 +139,7 @@ def test_stree_merge_maps() -> None: # Merge IJ - but do not merge K map & for (missing feature) code.missing_merge_of_forscope_and_map(in_qty, out_qty) - sdfg = _get_SDFG_and_purge(stencil_factory).sdfg + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) for me, state in sdfg.all_nodes_recursive() @@ -188,7 +155,7 @@ def test_stree_merge_maps() -> None: # Overcompute merge in K - we merge and introduce an If guard code.overcompute_merge(in_qty, out_qty) - sdfg = _get_SDFG_and_purge(stencil_factory).sdfg + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) for me, state in sdfg.all_nodes_recursive() @@ -198,46 +165,10 @@ def test_stree_merge_maps() -> None: # Forbid merging when data dependancy is detected code.block_merge_when_depandencies_is_found(in_qty, out_qty) - sdfg = _get_SDFG_and_purge(stencil_factory).sdfg + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) for me, state in sdfg.all_nodes_recursive() if isinstance(me, dace.nodes.MapEntry) ] assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) - - -class LocalRefineableCode(NDSLRuntime): - def __init__( - self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory - ) -> None: - super().__init__(stencil_factory.config.dace_config) - self.stencil_A = stencil_factory.from_dims_halo( - func=copy_stencil, - compute_dims=[X_DIM, Y_DIM, Z_DIM], - ) - self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) - - def __call__(self, in_field: Quantity, out_field: Quantity) -> None: - self.stencil_A(in_field, self.tmp) - self.stencil_A(self.tmp, out_field) - - -def test_stree_roundtrip_transient_is_refined() -> None: - domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu" - ) - - code = LocalRefineableCode(stencil_factory, quantity_factory) - in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") - out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") - - with StreeOptimization(): - code(in_qty, out_qty) - - precompiled_sdfg = _get_SDFG_and_purge(stencil_factory) - - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1) diff --git a/tests/stree_optimizer/test_transient_refine.py b/tests/stree_optimizer/test_transient_refine.py new file mode 100644 index 00000000..16314b8d --- /dev/null +++ b/tests/stree_optimizer/test_transient_refine.py @@ -0,0 +1,144 @@ +from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval +from ndsl.dsl.typing import Float, FloatField + +from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge + + +DATADIM_SIZE = 8 +DDIM_NAME = "DDIM" +DDIM_TYPE = Field[IJK, (Float, (DATADIM_SIZE))] + + +def stencil(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = in_field + 1 + + +def stencil_with_K_offset(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(0, -1): + out_field = in_field[K + 1] + 2 + + +def stencil_with_JK_offset(in_field: FloatField, out_field: FloatField) -> None: + with computation(PARALLEL), interval(...): + out_field = in_field[J + 1, K + 1] + 3 + + +def stencil_with_ddim(in_field: DDIM_TYPE, out_field: DDIM_TYPE) -> None: + with computation(PARALLEL), interval(...): + n = 0 + while n < DATADIM_SIZE: + out_field[0, 0, 0][n] = in_field[0, 0, 0][n] + 4 + n = n + 1 + + +class TransientRefineableCode(NDSLRuntime): + def __init__( + self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory + ) -> None: + super().__init__(stencil_factory.config.dace_config) + orchestratable_methods = [ + "refine_to_scalar", + "refine_to_K_buffer", + "refine_to_JK_buffer", + "do_not_refine_datadims", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + self.stencil = stencil_factory.from_dims_halo( + func=stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_K_offset = stencil_factory.from_dims_halo( + func=stencil_with_K_offset, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_JK_offset = stencil_factory.from_dims_halo( + func=stencil_with_JK_offset, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.stencil_with_ddim = stencil_factory.from_dims_halo( + func=stencil_with_ddim, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.tmp = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM]) + self.tmp_ddim = self.make_local( + quantity_factory, [X_DIM, Y_DIM, Z_DIM, DDIM_NAME] + ) + + def refine_to_scalar(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil(self.tmp, out_field) + + def refine_to_K_buffer(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil_with_K_offset(self.tmp, out_field) + + def refine_to_JK_buffer(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil(in_field, self.tmp) + self.stencil_with_JK_offset(self.tmp, out_field) + + def do_not_refine_datadims(self, in_field: Quantity, out_field: Quantity) -> None: + self.stencil_with_ddim(in_field, self.tmp_ddim) + self.stencil_with_ddim(self.tmp_ddim, out_field) + + +def test_stree_roundtrip_transient_is_refined() -> None: + domain = (3, 3, 4) + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") + out_qty = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "") + + quantity_factory.add_data_dimensions({DDIM_NAME: DATADIM_SIZE}) + in_qty_ddim = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM, DDIM_NAME], "") + out_qty_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, DDIM_NAME], "") + + code = TransientRefineableCode(stencil_factory, quantity_factory) + + with StreeOptimization(): + # Refine to scalar + code.refine_to_scalar(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) + + # Refine cartesian axis to buffers + # IJ merges - K is a buffer + code.refine_to_K_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + 1, + domain[2] + 1, # Quantity are domain size + 1 + ) + + # I merges - JK buffer + code.refine_to_JK_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + domain[1] + 1, # Quantity are domain size + 1 + domain[2] + 1, + ) + + # Refine to remaining data dimensions + code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1 From e965beacce99517e9b441072a70b0bb53b5a1e65 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Nov 2025 10:14:44 -0500 Subject: [PATCH 39/43] [Update] GT4Py & DaCe updated to 2025.11.25 state of `main` (#330) * DaCe update: fix networkx dependency breaking with 3.6 * GT4Py: Runtime interval bounds in `debug` --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index 1033dfcf..4a9f4602 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 1033dfcf9d118856d82c6ee8d6f6cfacec662335 +Subproject commit 4a9f46027147a52e2b0ac9eedeb101c3ab27d0bf diff --git a/external/gt4py b/external/gt4py index 751e64c1..cdf2ce72 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 751e64c192afaf53720aafa8fb841cd66291ca9f +Subproject commit cdf2ce7230f4c55c1977f197207d99d320d6fcee From 91296247c6ff3471a96bb3d01e9e956a5e0affa0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Nov 2025 12:52:21 -0500 Subject: [PATCH 40/43] [Tool] Best Guess Netcdfs diff (#177) * Best guess netcdfs compare * Add FieldBundle to debugger * lint * Move executable to `pyproject` * Lint --- ndsl/debug/debugger.py | 3 + ndsl/quantity/quantity.py | 2 + ndsl/stencils/testing/best_guess_diff.py | 194 +++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 200 insertions(+) create mode 100644 ndsl/stencils/testing/best_guess_diff.py diff --git a/ndsl/debug/debugger.py b/ndsl/debug/debugger.py index b9de8631..14037428 100644 --- a/ndsl/debug/debugger.py +++ b/ndsl/debug/debugger.py @@ -9,6 +9,7 @@ from ndsl.logging import ndsl_log from ndsl.quantity import Quantity +from ndsl.quantity.field_bundle import FieldBundle @dataclasses.dataclass @@ -87,6 +88,8 @@ def save_as_dataset(self, data_as_dict: dict, savename: str, is_in: bool) -> Non data_arrays[f"{name}.{field.name}"] = self._to_xarray( getattr(data, field.name), field.name ) + elif isinstance(data, FieldBundle): + data_arrays[name] = data.quantity.field_as_xarray else: data_arrays[name] = self._to_xarray(data, name) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 300f6376..5890fcfe 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -216,6 +216,8 @@ def to_netcdf( self, path: str, name: str = "var", rank: int = -1, all_data: bool = False ) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: + if rank < 0: + rank = MPI.COMM_WORLD.Get_rank() if all_data: self.data_as_xarray.to_dataset(name=name).to_netcdf( f"{path}__r{rank}.nc4" diff --git a/ndsl/stencils/testing/best_guess_diff.py b/ndsl/stencils/testing/best_guess_diff.py new file mode 100644 index 00000000..7bb0ca82 --- /dev/null +++ b/ndsl/stencils/testing/best_guess_diff.py @@ -0,0 +1,194 @@ +import argparse +import pathlib + +import numpy as np +import xarray as xr +import yaml + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + "Attempt to diff two NetCDFs with similar data." + "Differences that can be reconciled are strict domain vs halo, variable name mapping." + "The program will report on assumptions taken." + ) + parser.add_argument( + "netcdf_A", + type=str, + help="path of first NetCDFs, named A in the logs.", + ) + parser.add_argument( + "netcdf_B", + type=str, + help="path of second NetCDFs, named B in the logs.", + ) + parser.add_argument( + "-nm", + "--name_mapping", + type=str, + help="[Optional] Yaml file describing the mapping of the names.", + ) + parser.add_argument( + "-ha", + "--halo", + type=int, + default=3, + help="[Optional] Halo size if any, default to 3.", + ) + return parser + + +def main( + netcdf_A: str, + netcdf_B: str, + name_mapping: str | None = None, + halo: int = 3, +) -> None: + A = xr.open_dataset(netcdf_A) + B = xr.open_dataset(netcdf_B) + name_map = {} + if name_mapping is not None: + with open(name_mapping) as f: + name_map = yaml.safe_load(f) + + dataset = {} + for name_A, data_A in A.items(): + print(f"Best guess for {name_A} from A:") + # Resolve name + resolved_name_B = None + if name_A not in B.keys(): + if name_A not in name_map.keys(): + print(" [Failed] name can't be found in B nor in name mapping") + else: + resolved_name_B = name_map[name_A] + else: + resolved_name_B = name_A + + if resolved_name_B is None: + continue + print(f" [Hyp] use {resolved_name_B} for B") + + # Resolve domain size + data_B = B[resolved_name_B] + if len(data_A.shape) >= 5: + print( + " [Hyp] A data dims are >= 5, assuming savepoints/rank are the firt two and going A[0, 0, ::]" + ) + data_A = data_A[0, 0, ::] + if len(data_B.shape) >= 5: + print( + " [Hyp] B data dims are >= 5, assuming savepoints/rank are the firt two and going B[0, 0, ::]" + ) + data_B = data_B[0, 0, ::] + + if len(data_A.shape) != len(data_B.shape): + print( + f" [Failed] A is shape {len(data_A.shape)}, B is shape {len(data_B.shape)}: can't reconcile." + ) + continue + + # - Assume we have 0 == I, 1 == J now + resolved_I = None + A_uses_halo = False + B_uses_halo = False + if data_A.shape[0] != data_B.shape[0]: + if data_A.shape[0] < data_B.shape[0]: + if data_B.shape[0] - 2 * halo != data_A.shape[0]: + print( + f" [Failed] B in dim I is too big, even with halo substracted {data_B.shape[0]} (halo: {halo})" + ) + else: + B_uses_halo = True + resolved_I = data_A.shape[0] + else: + if data_A.shape[0] - 2 * halo != data_B.shape[0]: + print( + f" [Failed] A in dim I is too big, even with halo substracted {data_A.shape[0]} (halo: {halo})" + ) + else: + A_uses_halo = True + resolved_I = data_B.shape[0] + else: + resolved_I = data_A.shape[0] + + if resolved_I is None: + continue + + print(f" [Hyp] Using {resolved_I} as I dim size") + + resolved_J = None + if data_A.shape[1] != data_B.shape[1]: + if data_A.shape[1] < data_B.shape[1]: + if data_B.shape[1] - 2 * halo != data_A.shape[1]: + print( + f" [Failed] B in dim J is too big, even with halo substracted {data_B.shape[1]} (halo: {halo})" + ) + else: + resolved_J = data_A.shape[1] + else: + if data_A.shape[1] - 2 * halo != data_B.shape[1]: + print( + f" [Failed] A in dim J is too big, even with halo substracted {data_A.shape[1]} (halo: {halo})" + ) + else: + resolved_J = data_B.shape[1] + else: + resolved_J = data_A.shape[1] + + if resolved_J is None: + continue + + print(f" [Hyp] Using {resolved_J} as J dim size") + + # - Assume 2 == K + if data_A.shape[2] != data_B.shape[2]: + print( + f" [Failed] Can't reconcile K dim: A ({data_A.shape[2]}) != B ({data_B.shape[2]})" + ) + continue + resolved_K = data_A.shape[2] + + print(f" [Hyp] Using {resolved_K} as K dim size") + + # We should now be ready to diff + if A_uses_halo: + data_A = data_A[halo:-halo, halo:-halo, ::] + if B_uses_halo: + data_B = data_B[halo:-halo, halo:-halo, ::] + + dims = [f"D{i}_{s}" for i, s in enumerate(data_A.shape)] + absolute_diff = data_A.data - data_B.data + dataset[name_A] = xr.DataArray( + absolute_diff, + dims=dims, + ) + + # ULP diffs + max_values = np.maximum( + np.absolute(data_A.data.flatten()), np.absolute(data_B.data.flatten()) + ) + ulp_diff = np.divide(np.abs(absolute_diff.flatten()), np.spacing(max_values)) + ulp_diff = np.sort(ulp_diff) + dataset[f"ulp_{name_A}"] = xr.DataArray( + ulp_diff, + dims=[f"GP_{ulp_diff.shape[0]}"], + ) + + print(" [Success]") + + xr.Dataset(dataset).to_netcdf(f"best_guest_diff_{pathlib.Path(netcdf_A).stem}.nc4") + + +def entry_point() -> None: + parser = get_parser() + args = parser.parse_args() + main( + netcdf_A=args.netcdf_A, + netcdf_B=args.netcdf_B, + name_mapping=args.name_mapping, + halo=args.halo, + ) + + +if __name__ == "__main__": + entry_point() diff --git a/pyproject.toml b/pyproject.toml index a562bfb5..8bf3ab7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ test = ["pytest", "coverage"] zarr = ["zarr<3"] [project.scripts] +best_guess_diff = "ndsl.stencils.testing.best_guess_diff:entry_point" ndsl-serialbox_to_netcdf = "ndsl.stencils.testing.serialbox_to_netcdf:entry_point" [project.urls] From 57a1ba4d19fdfa838a949b0c6afc196721a5d0dd Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 26 Nov 2025 07:57:41 -0500 Subject: [PATCH 41/43] Update `gt4py` to capture improvement to user error (#331) --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index cdf2ce72..a7429094 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit cdf2ce7230f4c55c1977f197207d99d320d6fcee +Subproject commit a7429094d7dd9418a2e3c1e57b2e9c783d79250d From a7898db80439bfdaaa139c6353b0686db7db9978 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 26 Nov 2025 11:29:35 -0500 Subject: [PATCH 42/43] [Rework/Experimental] Refine Transient v2: `Ranges` for all! (#328) * Rework the `RefineTransient` to use `Range` - simpler, cleaner and more robust. Also props us for a better refine * Remove unused code --------- Co-authored-by: Tobias Wicky-Pfund --- .../stree/optimizations/refine_transients.py | 240 +++++++++--------- 1 file changed, 114 insertions(+), 126 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 342d676d..e24788bc 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,7 +1,6 @@ from __future__ import annotations import warnings -from types import TracebackType import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree @@ -25,125 +24,99 @@ def _change_index_of_tuple( return tuple(new_list) -def _reduce_cartesian_axes_size_to_1( - transient_map_access: set[stree.nodes.MapEntry], +def _reduce_cartesian_axis_size_to_1( + axis: AxisIterator, + transient_map_reads: dace.subsets.Range | None, + transient_map_writes: dace.subsets.Range | None, transient_data: dace.data.Data, ijk_order: tuple[int, int, int], ) -> bool: - """Reduce dimension size of transient to 1 if their are accessed only - in a single Map for the cartesian dimensions""" - refined = False - for axis in AxisIterator: - access_in_map_count = 0 - for map_entry in transient_map_access: - if axis.as_str() in map_entry.params[0]: - access_in_map_count += 1 - - if access_in_map_count != 1: - continue - - # This transient is used in exactly one single-Axis map - # therefore this dimension can be removed. BUT we are not truly - # removing it, we are reducing it to 1 to not have to deal - # with different slicing. - transient_data.shape = _change_index_of_tuple( - transient_data.shape, - axis.as_cartesian_index(), - value=1, - ) - - # Assume 3D cartesian! - if len(transient_data.shape) < 3: - warnings.warn( - f"Potential non-3D array: {transient_data}, skipping.", - UserWarning, - stacklevel=2, - ) - return refined - if len(transient_data.shape) == 3: - layout = [*ijk_order] - else: - data_dim_count = len(transient_data.shape) - 3 - layout = [dim + data_dim_count for dim in ijk_order] + [ - i - 1 for i in range(data_dim_count, 0, -1) - ] + """Reduce dimension size of transient to 1 if all access (reads and writes) + are atomic""" - transient_data.set_strides_from_layout(*layout) - transient_data.lifetime = dace.dtypes.AllocationLifetime.State - refined = True - - return refined + # Dev Note: Better dataflow analysis would look at exactly + # what's goin on here! + # Assume 3D cartesian! + if len(transient_data.shape) < 3: + warnings.warn( + f"Potential non-3D array: {transient_data}, skipping.", + UserWarning, + stacklevel=2, + ) + return False + + read_write_range: dace.subsets.Range = dace.subsets.union( + transient_map_reads, transient_map_writes + ) + + if read_write_range is None: + return False + + if read_write_range.size()[axis.as_cartesian_index()] != 1: + return False + + # This transient read and write access is done on exactly one element + # therefore this dimension can be removed. BUT we are not truly + # removing it, we are reducing it to 1 to not have to deal + # with different slicing. + transient_data.shape = _change_index_of_tuple( + transient_data.shape, + axis.as_cartesian_index(), + value=1, + ) + + if len(transient_data.shape) == 3: + layout = [*ijk_order] + else: + data_dim_count = len(transient_data.shape) - 3 + layout = [dim + data_dim_count for dim in ijk_order] + [ + i - 1 for i in range(data_dim_count, 0, -1) + ] -class _CartesianMapNesting: - def __init__( - self, - cartesian_current_map_nesting: list[stree.nodes.MapEntry | None], - node: stree.MapScope, - ) -> None: - self._cartesian_current_map_nesting = cartesian_current_map_nesting - self._node = node - - def __enter__(self) -> None: - if AxisIterator._I.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[0] = self._node.node - elif AxisIterator._J.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[1] = self._node.node - elif AxisIterator._K.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[2] = self._node.node - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - if AxisIterator._I.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[0] = None - elif AxisIterator._J.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[1] = None - elif AxisIterator._K.value[0] in self._node.node.params[0]: - self._cartesian_current_map_nesting[2] = None + transient_data.set_strides_from_layout(*layout) + transient_data.lifetime = dace.dtypes.AllocationLifetime.State + return True -class CollectTransientAccessInCartesianMaps(stree.ScheduleNodeVisitor): - """Collect all access of transient arrays per Maps.""" +class CollectTransientRangeAccess(stree.ScheduleNodeVisitor): + """Unionize all transient arrays access into a single Range.""" def __init__(self) -> None: - self.transient_map_access: dict[str, set[stree.nodes.MapEntry]] = {} - self._cartesian_current_map_nesting: list[stree.nodes.MapEntry | None] = [ - None, - None, - None, - ] + # Map access is a `list` instead of a `set` because we want to double count + # access that are in/out as two access on the axis. + self.transients_range_writes: dict[str, dace.subsets.Range | None] = {} + self.transients_range_reads: dict[str, dace.subsets.Range | None] = {} def __str__(self) -> str: return "CartesianCollectMaps" - def visit_MapScope(self, node: stree.MapScope) -> None: - if len(node.node.params) > 1: - ndsl_log.debug( - "Can't apply CartesianRefineTransients, require unidimensional Maps" - ) - return - - with _CartesianMapNesting(self._cartesian_current_map_nesting, node): - for child in node.children: - self.visit(child) - - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: - for memlet in [*node.input_memlets(), *node.output_memlets()]: + def _record_access( + self, + memlets: stree.MemletSet, + recording_set: dict[str, dace.subsets.Range | None], + ) -> None: + for memlet in memlets: data = self.containers[memlet.data] if data.transient and isinstance(data, dace.data.Array): - for map_entry in self._cartesian_current_map_nesting: - if map_entry is not None: - self.transient_map_access[memlet.data].add(map_entry) + if not isinstance(memlet.subset, dace.subsets.Range): + raise NotImplementedError( + "Memlet refining only works with Range subsets" + ) + recording_set[memlet.data] = dace.subsets.union( + recording_set[memlet.data], memlet.subset + ) + + def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + self._record_access(node.input_memlets(), self.transients_range_writes) + self._record_access(node.output_memlets(), self.transients_range_reads) def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.containers = node.containers for name, data in self.containers.items(): if data.transient and isinstance(data, dace.data.Array): - self.transient_map_access[name] = set() + self.transients_range_writes[name] = None + self.transients_range_reads[name] = None for child in node.children: self.visit(child) @@ -152,28 +125,28 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): """Rebuild memlets from containers to ensure they are scope to the right size.""" + def __init__(self, refined_arrays: set[str]) -> None: + self._refined_arrays = refined_arrays + def __str__(self) -> str: return "RefineTransientAxis" def visit_TaskletNode(self, node: stree.TaskletNode) -> None: for memlet in [*node.output_memlets(), *node.input_memlets()]: + if memlet.data not in self._refined_arrays: + continue array = self.containers[memlet.data] if array.transient: - replace_cartesian_access = {} - if len(array.shape) >= 1 and array.shape[0] == 1: - replace_cartesian_access[AxisIterator._I.as_str()] = 0 - if len(array.shape) >= 2 and array.shape[1] == 1: - replace_cartesian_access[AxisIterator._J.as_str()] = 0 - if len(array.shape) >= 3 and array.shape[2] == 1: - # Workaround because the iterator can be `__k_0` instead of `__k` - axis = None - for axis_symbol in memlet.free_symbols: - if axis_symbol.startswith(AxisIterator._K.as_str()): - axis = axis_symbol - break - if axis: - replace_cartesian_access[axis] = 0 - memlet.replace(replace_cartesian_access) + if not isinstance(memlet.subset, dace.subsets.Range): + raise NotImplementedError( + "Memlet refining only works with Range subsets" + ) + + # Reduce "refined" dimension to a single element, effectively + # eliminating it. + for index, _ in enumerate(memlet.subset.ranges): + if array.shape[index] == 1: + memlet.subset.ranges[index] = (0, 0, 1) def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.containers = node.containers @@ -188,14 +161,15 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): It can do: - Looking at usage of a transient in a cartesian axis (e.g. loop over a - cartesian axis) it will reduce that axis to 1 if it exists in _only one_. + cartesian axis) it will reduce that axis to 1 if all access are atomic + (exactly _one_ element of the array is ever worked on) It should but cannot do/will bug if: - - Dataflow analysis on the axis to prevent reducing an axis to one where - the transient is used with offset, leading to faulty numerics - - Using the dataflow above, we can reduce the dimensions to the correct lowest + - If the transient is _written_ before being _read_ this won't catch it (not its job), but we could + - With better dataflow analysis, we can reduce the dimensions to the correct lowest size needed on the axis (e.g. transient[K] and transient[K+1], requires a 2-element - buffer) + buffer), instead of the defensive _no refine_ strategy used now. We have _most_ of the + info in the `Range` - Current action when detecting a valid candidate is to reduce the size of the dimension to 1, rather than removing it. This will effectively, if generic compilers do their job, reduce the cache access significantly. This also has been implemented to _not_ deal with offset/slicing @@ -210,6 +184,12 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): - Test for J refine but not in I or K - Test with dataflow: if/else, while, etc. - Test with ForScope (FORWARD/BACKWARD) instead of Map + + Coding traps: + - We reduce the "refined" dimensions to 1 which is functionally eliminating it. This is solid. In the + case of the one we can't eliminate we don't do anything. We could find the "smallest buffer size" needed + and reduce the local dimension to it. BUT if we do this, we have to take into account the offset into + memory (e.g. halo) for the `RebuildMemletsFromContainers`! """ def __init__(self, backend: str) -> None: @@ -230,11 +210,13 @@ def __init__(self, backend: str) -> None: f"backend {backend}" ) + self.refined_array: set[str] = set() + def __str__(self) -> str: return "CartesianRefineTransients" def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: - collect_map = CollectTransientAccessInCartesianMaps() + collect_map = CollectTransientRangeAccess() collect_map.visit(node) # Remove Axis @@ -242,13 +224,19 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: for name, data in node.containers.items(): if not (data.transient and isinstance(data, dace.data.Array)): continue - refined = _reduce_cartesian_axes_size_to_1( - collect_map.transient_map_access[name], - data, - self.ijk_order, - ) + refined = False + for axis in AxisIterator: + refined |= _reduce_cartesian_axis_size_to_1( + axis, + collect_map.transients_range_reads[name], + collect_map.transients_range_writes[name], + data, + self.ijk_order, + ) + refined_transient += 1 if refined else 0 + self.refined_array.add(name) - RebuildMemletsFromContainers().visit(node) + RebuildMemletsFromContainers(self.refined_array).visit(node) ndsl_log.debug(f"🚀 {refined_transient} Transient refined") From abcc3cdd70da765d9fe8b2f4b183dd38bb043d75 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 26 Nov 2025 13:27:19 -0500 Subject: [PATCH 43/43] [Fix] [Translate] Update API for parallel test when using `MultiModalMetric` (#332) * Remove old options for `MultiModalFloatMetric` * Defensive programming: bail out if we can't measure the ref vs input diff --- ndsl/stencils/testing/test_translate.py | 2 -- ndsl/testing/comparison.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index d8c615e2..8ba789cb 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -430,8 +430,6 @@ def test_parallel_savepoint( absolute_eps_override=case.testobj.mmr_absolute_eps, relative_fraction_override=case.testobj.mmr_relative_fraction, ulp_override=case.testobj.mmr_ulp, - ignore_near_zero_errors=ignore_near_zero, - near_zero=case.testobj.near_zero, sort_report=case.sort_report, ) else: diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 3891060e..c1a5b1ef 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -257,7 +257,9 @@ def __init__( self.check = np.all(self.success) self.sort_report = sort_report - if input_values is not None: + # We might have sliced outputs in the translate test. Rather than funnel the slicing + # all the way down, we bail out if we can measure input vs reference + if input_values is not None and input_values.shape == reference_values.shape: self.number_changing_values = ( (input_values != reference_values).flatten().shape[0] )