From e8f0f7f94d70efeca2b4f14a87a72f2560fe599d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 26 Jan 2026 13:02:50 -0500 Subject: [PATCH 01/47] Introducing Backend and updating downstream NDSL code --- ndsl/__init__.py | 2 + ndsl/boilerplate.py | 12 +- ndsl/config/__init__.py | 7 + ndsl/config/backend.py | 126 ++++++++++++++++++ ndsl/dsl/dace/build.py | 12 +- ndsl/dsl/dace/dace_config.py | 26 ++-- ndsl/dsl/dace/orchestration.py | 14 +- .../stree/optimizations/refine_transients.py | 13 +- ndsl/dsl/dace/utils.py | 19 ++- ndsl/dsl/gt4py_utils.py | 44 +++--- ndsl/dsl/stencil.py | 10 +- ndsl/dsl/stencil_config.py | 21 +-- ndsl/grid/generation.py | 3 +- ndsl/grid/helper.py | 4 +- ndsl/initialization/allocator.py | 8 +- ndsl/initialization/subtile_grid_sizer.py | 7 +- ndsl/performance/collector.py | 13 +- ndsl/performance/report.py | 10 +- ndsl/performance/tools.py | 5 +- ndsl/quantity/local.py | 3 +- ndsl/quantity/metadata.py | 5 +- ndsl/quantity/quantity.py | 15 ++- ndsl/stencils/testing/conftest.py | 19 +-- ndsl/stencils/testing/grid.py | 11 +- ndsl/stencils/testing/test_translate.py | 23 ++-- ndsl/stencils/testing/translate.py | 7 +- ndsl/xumpy/alloc.py | 37 ++--- tests/test_ndsl_runtime.py | 9 +- 28 files changed, 325 insertions(+), 160 deletions(-) create mode 100644 ndsl/config/__init__.py create mode 100644 ndsl/config/backend.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index dd5f21e3..41bd7bc6 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -4,6 +4,7 @@ from .comm.local_comm import LocalComm from .comm.mpi import MPIComm from .comm.partitioner import CubedSpherePartitioner, TilePartitioner +from .config.backend import Backend from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration @@ -33,6 +34,7 @@ __all__ = [ "dsl", + "Backend", "CubedSphereCommunicator", "TileCommunicator", "LocalComm", diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py 
index 1e94e3c7..c04b0dfd 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -1,6 +1,7 @@ import warnings from ndsl import ( + Backend, CompilationConfig, DaceConfig, DaCeOrchestration, @@ -14,6 +15,7 @@ TileCommunicator, TilePartitioner, ) +from ndsl.config.backend import _BACKEND_PERFORMANCE_CPU, _BACKEND_PYTHON def _get_factories( @@ -21,7 +23,7 @@ def _get_factories( ny: int, nz: int, nhalo: int, - backend: str, + backend: Backend, orchestration: DaCeOrchestration, topology: str, ) -> tuple[StencilFactory, QuantityFactory]: @@ -91,13 +93,13 @@ def get_factories_single_tile_orchestrated( ny: int, nz: int, nhalo: int, - backend: str = "dace:cpu", + backend: Backend = _BACKEND_PERFORMANCE_CPU, *, orchestration_mode: DaCeOrchestration | None = None, ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for orchestrated code on a single tile topology.""" - if backend is not None and not backend.startswith("dace"): + if backend is not None and not backend.is_orchestrated(): raise ValueError("Only `dace:*` backends can be orchestrated.") return _get_factories( @@ -112,7 +114,7 @@ def get_factories_single_tile_orchestrated( def get_factories_single_tile( - nx: int, ny: int, nz: int, nhalo: int, backend: str = "numpy" + nx: int, ny: int, nz: int, nhalo: int, backend: Backend = _BACKEND_PYTHON ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for stencils on a single tile topology.""" return _get_factories( @@ -121,6 +123,6 @@ def get_factories_single_tile( nz=nz, nhalo=nhalo, backend=backend, - orchestration=DaCeOrchestration.Python, + orchestration=DaCeOrchestration.BuildAndRun, topology="tile", ) diff --git a/ndsl/config/__init__.py b/ndsl/config/__init__.py new file mode 100644 index 00000000..51e071c1 --- /dev/null +++ b/ndsl/config/__init__.py @@ -0,0 +1,7 @@ +from .backend import Backend, BackendFramework, BackendTargetDevice + +__all__ = [ + "Backend", + 
"BackendFramework", + "BackendTargetDevice", +] diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py new file mode 100644 index 00000000..07fdbd39 --- /dev/null +++ b/ndsl/config/backend.py @@ -0,0 +1,126 @@ +from __future__ import annotations +from enum import Enum +import gt4py.cartesian.backend as gt_backend + + +class BackendStrategy(Enum): + """Strategy for the code execution""" + + STENCIL = "st" + ORCHESTRATION = "orch" + + +class BackendTargetDevice(Enum): + """Targeted device""" + + CPU = "cpu" + GPU = "gpu" + + +class BackendFramework(Enum): + """Main framework (or language) backend relies on""" + + GRIDTOOLS = "gt" + DACE = "dace" + PYTHON = "python" + + +_NDSL_TO_GT4PY_BACKEND_NAMING = { + "st:python:cpu:debug": "debug", + "st:python:cpu:numpy": "numpy", + "st:gt:cpu:IJK": "gt:cpu_kfirst", + "st:gt:cpu:KJI": "gt:cpu_ifirst", + "st:gt:gpu:KJI": "gt:gpu", + "orch:dace:cpu:IJK": "dace:cpu_kfirst", + "orch:dace:cpu:KIJ": "dace:cpu", + "orch:dace:cpu:KJI": "dace:cpu_KJI", + "orch:dace:gpu:KJI": "dace:gpu", +} +"""Internal: match the NDSL backend names with the GT4Py names""" + + +class Backend: + """Backend for NDSL""" + + def __init__(self, ndsl_backend: str) -> None: + parts = ndsl_backend.split(":") + if len(parts) != 4: + raise ValueError(f"Backend {ndsl_backend} is ill-formed.") + self._humanly_readable = ndsl_backend + # Split into internal parameters + self._strategy = BackendStrategy(parts[0].lower()) + self._framework = BackendFramework(parts[1].lower()) + self._device = BackendTargetDevice(parts[2].lower()) + self._loop_order = parts[3] + + # Check GPU capacity + if ( + self._device == BackendTargetDevice.GPU + and gt_backend.from_name(self.as_gt4py()).storage_info["device"] != "gpu" + ): + raise ValueError( + f"Coding error: NDSL backend requested {self._humanly_readable} " + f"translate to non-GPU {self.as_gt4py()} GT4Py backend" + ) + + def __str__(self) -> str: + return self.as_humanly_readable() + + def __repr__(self) -> str: + return 
self.as_humanly_readable() + + @staticmethod + def debug() -> Backend: + return Backend("st:python:cpu:debug") + + @staticmethod + def python() -> Backend: + return Backend("st:python:cpu:numpy") + + @staticmethod + def performance_cpu() -> Backend: + return Backend("orch:dace:cpu:IJK") + + @staticmethod + def hybrid_fortran_cpu() -> Backend: + return Backend("orch:dace:cpu:KJI") + + @staticmethod + def performance_gpu() -> Backend: + return Backend("orch:dace:gpu:KJI") + + @property + def device(self) -> BackendTargetDevice: + return self._device + + @property + def framework(self) -> BackendFramework: + return self._framework + + def as_gt4py(self) -> str: + if self._humanly_readable in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): + return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] + raise ValueError( + f"Backend {self._humanly_readable} cannot be translate to GT4Py" + ) + + def as_humanly_readable(self) -> str: + return self._humanly_readable + + def as_safe_for_path(self) -> str: + return self._humanly_readable.replace(":", "_") + + def is_orchestrated(self) -> bool: + return self._strategy == BackendStrategy.ORCHESTRATION + + def is_stencil(self) -> bool: + return self._strategy == BackendStrategy.STENCIL + + def is_gpu_backend(self) -> bool: + return self._device == BackendTargetDevice.GPU + + +# Those two internal values are used for default parameters values +# as it is bad practice to call a function in default argument value +_BACKEND_PERFORMANCE_CPU = Backend.performance_cpu() +_BACKEND_PYTHON = Backend.python() diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index 155583a0..ef1bcc90 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -2,6 +2,7 @@ from dace.sdfg import SDFG from gt4py.cartesian import config as gt_config +from ndsl.config import Backend from ndsl.dsl.caches.cache_location import get_cache_directory, get_cache_fullpath from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.logging 
import ndsl_log @@ -23,7 +24,10 @@ def build_info_filepath() -> str: def write_build_info( - sdfg: SDFG, layout: tuple[int, int], resolution_per_tile: list[int], backend: str + sdfg: SDFG, + layout: tuple[int, int], + resolution_per_tile: list[int], + backend: Backend, ) -> None: """Write down all relevant information on the build to identify it at load time.""" @@ -86,7 +90,7 @@ def get_sdfg_path( build_info_file.readline() # Read in build_backend = build_info_file.readline().rstrip() - if config.get_backend() != build_backend: + if config.get_backend() != Backend(build_backend): raise RuntimeError( f"SDFG build for {build_backend}, {config._backend} has been asked" ) @@ -110,12 +114,12 @@ def set_distributed_caches(config: DaceConfig) -> None: """In Run mode, check required file then point current rank cache to source cache""" # Execute specific initialization per orchestration state - orchestration_mode = config.get_orchestrate() - if orchestration_mode == DaCeOrchestration.Python: + if not config.get_backend().is_orchestrated(): return # Check that we have all the file we need to early out in case # of issues. 
+ orchestration_mode = config.get_orchestrate() if orchestration_mode == DaCeOrchestration.Run: import os diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 2be2b940..deaf2bda 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -7,12 +7,11 @@ import dace.config from gt4py.cartesian.config import GT4PY_COMPILE_OPT_LEVEL -from ndsl import LocalComm +from ndsl import Backend, LocalComm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath -from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import get_precision from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector @@ -137,17 +136,16 @@ class DaCeOrchestration(enum.Enum): Run: load from .so and run, will fail if .so is not available """ - Python = 0 + BuildAndRun = 0 Build = 1 - BuildAndRun = 2 - Run = 3 + Run = 2 class DaceConfig: def __init__( self, communicator: Communicator | None, - backend: str, + backend: Backend, tile_nx: int = 0, tile_nz: int = 0, orchestration: DaCeOrchestration | None = None, @@ -193,11 +191,11 @@ def __init__( # We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` # backend that would signify both the `CPU|GPU` split and the orchestration mode if orchestration is None: - fv3_dacemode_env_var = os.getenv("FV3_DACEMODE", "Python") + fv3_dacemode_env_var = os.getenv("FV3_DACEMODE", "BuildAndRun") # The below condition guards against defining empty FV3_DACEMODE and # awkward behavior of os.getenv returning "" even when not defined if fv3_dacemode_env_var is None or fv3_dacemode_env_var == "": - fv3_dacemode_env_var = "Python" + fv3_dacemode_env_var = "BuildAndRun" self._orchestrate = DaCeOrchestration[fv3_dacemode_env_var] else: self._orchestrate = orchestration @@ -359,19 
+357,13 @@ def __init__( set_distributed_caches(self) - if self.is_dace_orchestrated() and "dace" not in self._backend: - raise RuntimeError( - "DaceConfig: orchestration can only be leveraged " - f"with the `dace:*` backends, not with {self._backend}." - ) - def is_dace_orchestrated(self) -> bool: - return self._orchestrate != DaCeOrchestration.Python + return self._backend.is_orchestrated() def is_gpu_backend(self) -> bool: - return is_gpu_backend(self._backend) + return self._backend.is_gpu_backend() - def get_backend(self) -> str: + def get_backend(self) -> Backend: return self._backend def get_orchestrate(self) -> DaCeOrchestration: diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 98e44542..b59ac04f 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -19,7 +19,7 @@ from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map from dace.transformation.passes.simplify import SimplifyPass -from gt4py import storage +from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI @@ -83,7 +83,10 @@ def _download_results_from_dace( return None backend = config.get_backend() - return [storage.from_array(result, backend=backend) for result in dace_result] + return [ + gt_storage.from_array(result, backend=backend.as_gt4py()) + for result in dace_result + ] def _to_gpu(sdfg: SDFG) -> None: @@ -173,7 +176,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): passes = [] - if backend_name == "dace:cpu_kfirst": + if backend_name.as_humanly_readable() == "orch:dace:cpu:IJK": passes.extend( [ CleanUpScheduleTree(), @@ -183,7 +186,10 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - elif backend_name in ["dace:cpu_KJI", "dace:gpu"]: + elif backend_name.as_humanly_readable() in [ + "orch:dace:cpu:KJI", + "orch:dace:gpu:KJI", + ]: 
passes.extend( [ CleanUpScheduleTree(), diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 05224897..998a2a48 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -5,7 +5,7 @@ import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl import ndsl_log +from ndsl import Backend, ndsl_log from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator @@ -241,7 +241,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): memory (e.g. halo) for the `RebuildMemletsFromContainers`! """ - def __init__(self, backend: str) -> None: + def __init__(self, backend: Backend) -> None: warnings.warn( "CartesianRefineTransients is a WIP. It's usage is *severely* limited " "and will most likely lead to bad numerics. Check the docs, check utest.", @@ -249,11 +249,14 @@ def __init__(self, backend: str) -> None: stacklevel=2, ) - if backend in ["dace:cpu_kfirst"]: + if backend.as_humanly_readable() in ["orch:dace:cpu:IJK"]: self.ijk_order = (2, 1, 0) - elif backend in ["dace:gpu", "dace:cpu_KJI"]: + elif backend.as_humanly_readable() in [ + "orch:dace:gpu:KJI", + "orch:dace:cpu:KJI", + ]: self.ijk_order = (0, 1, 2) - elif backend in ["dace:cpu"]: + elif backend.as_humanly_readable() in ["orch:dace:cpu:KIJ"]: self.ijk_order = (1, 2, 0) else: raise NotImplementedError( diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index ae6b1d80..b57213aa 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -8,11 +8,12 @@ from dace.transformation.helpers import get_parent_map from gt4py.cartesian.gtscript import PARALLEL, computation, interval +import ndsl.xumpy as xp +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, FrozenStencil, StencilConfig from ndsl.dsl.typing import Float, FloatField 
from ndsl.logging import ndsl_log -from ndsl.optional_imports import cupy as cp class DaCeProgress: @@ -187,7 +188,7 @@ def copy_kernel(q_in: FloatField, q_out: FloatField) -> None: class MaxBandwidthBenchmarkProgram: - def __init__(self, size: Any, backend: str) -> None: + def __init__(self, size: Any, backend: Backend) -> None: from ndsl.dsl.dace.orchestration import DaCeOrchestration, orchestrate dace_config = DaceConfig( @@ -211,7 +212,7 @@ def __call__(self, A: Any, B: Any, n: int) -> None: def kernel_theoretical_timing( sdfg: dace.sdfg.SDFG, *, - backend: str, + backend: Backend, hardware_bw_in_GB_s: float | None = None, ) -> dict[str, float]: """Compute a lower timing bound for kernels with the following hypothesis: @@ -229,12 +230,8 @@ def kernel_theoretical_timing( f" arrays at {Float} precision..." ) bench = MaxBandwidthBenchmarkProgram(size, backend) - if backend == "dace:gpu": - A = cp.ones(size, dtype=Float) - B = cp.ones(size, dtype=Float) - else: - A = np.ones(size, dtype=Float) - B = np.ones(size, dtype=Float) + A = xp.ones(size, backend, dtype=Float) + B = xp.ones(size, backend, dtype=Float) n = 1000 m = 4 dt = [] @@ -259,7 +256,7 @@ def kernel_theoretical_timing( if hardware_bw_in_GB_s else "Given hardware bandwidth" ) - print(f"{label}: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s") + print(f"{label}: {bandwidth_in_bytes_s / (1024 * 1024 * 1024)} GB/s") allmaps = [ (me, state) @@ -342,7 +339,7 @@ def report_kernel_theoretical_timing( def kernel_theoretical_timing_from_path( sdfg_path: str, - backend: str, + backend: Backend, hardware_bw_in_GB_s: float | None = None, output_format: str | None = None, ) -> str: diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 86eb0bbc..6dc6a4a6 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -7,6 +7,7 @@ from gt4py import storage as gt_storage from gt4py.cartesian import backend as gt_backend +from ndsl.config.backend import Backend from ndsl.constants import N_HALO_DEFAULT 
from ndsl.dsl.typing import DTypes, Float from ndsl.logging import ndsl_log @@ -85,7 +86,7 @@ def make_storage_data( shape: tuple[int, ...] | None = None, origin: tuple[int, ...] = origin, *, - backend: str, + backend: Backend, dtype: DTypes = Float, mask: tuple[bool, ...] | None = None, start: tuple[int, ...] = (0, 0, 0), @@ -190,7 +191,7 @@ def make_storage_data( storage = gt_storage.from_array( data, dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=_translate_origin(origin, mask), dimensions=_mask_to_dimensions(mask, data.shape), ) @@ -206,7 +207,7 @@ def _make_storage_data_1d( read_only: bool = True, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) @@ -243,7 +244,7 @@ def _make_storage_data_2d( read_only: bool = True, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: # axis refers to which axis should be repeated (when making a full 3d data), # dummy refers to a singleton axis @@ -277,7 +278,7 @@ def _make_storage_data_3d( start: tuple[int, ...] = (0, 0, 0), *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: istart, jstart, kstart = start isize, jsize, ksize = data.shape @@ -296,7 +297,7 @@ def _make_storage_data_Nd( start: tuple[int, ...] | None = None, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: if start is None: start = tuple([0] * data.ndim) @@ -310,7 +311,7 @@ def make_storage_from_shape( shape: tuple[int, ...], origin: tuple[int, ...] = origin, *, - backend: str, + backend: Backend, dtype: DTypes = Float, mask: tuple[bool, ...] 
| None = None, ) -> npt.NDArray: @@ -342,7 +343,7 @@ def make_storage_from_shape( storage = gt_storage.zeros( shape, dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=_translate_origin(origin, mask), dimensions=_mask_to_dimensions(mask, shape), ) @@ -358,7 +359,7 @@ def make_storage_dict( names: list[str] | None = None, axis: int = 2, *, - backend: str, + backend: Backend, dtype: DTypes = Float, ) -> dict[str, npt.NDArray]: assert names is not None, "for 4d variable storages, specify a list of names" @@ -379,7 +380,7 @@ def make_storage_dict( return data_dict -def storage_dict(st_dict, names, shape, origin, *, backend: str): +def storage_dict(st_dict, names, shape, origin, *, backend: Backend): for name in names: st_dict[name] = make_storage_from_shape(shape, origin, backend=backend) @@ -446,10 +447,6 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None): return cp.asarray(array, dtype, order) -def is_gpu_backend(backend: str) -> bool: - return gt_backend.from_name(backend).storage_info["device"] == "gpu" - - _FORTRAN_LOOP_LAYOUT = (2, 1, 0) """Fortran is a column-first (or stride-first) memory system, which in the internal gt4py loop layout means I (or axis[0]) has @@ -462,19 +459,22 @@ def is_gpu_backend(backend: str) -> bool: """ -def backend_is_fortran_aligned(backend: str) -> bool: +def backend_is_fortran_aligned(backend: Backend | None) -> bool: """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran striding.""" # Dev NOTE: this is used in interfacing with NDSL (e.g. GEOS.) 
- return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name(backend).storage_info[ - "layout_map" - ](("I", "J", "K")) + if not backend: + return False + + return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name( + backend.as_gt4py() + ).storage_info["layout_map"](("I", "J", "K")) -def zeros(shape, dtype=Float, *, backend: str): - storage_type = cp.ndarray if is_gpu_backend(backend) else np.ndarray +def zeros(shape, dtype=Float, *, backend: Backend): + storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray xp = cp if cp and storage_type is cp.ndarray else np return xp.zeros(shape, dtype=dtype) @@ -543,8 +543,8 @@ def stack(tup, axis: int = 0, out=None): return xp.stack(array_tup, axis, out) -def device_sync(backend: str) -> None: - if cp and is_gpu_backend(backend): +def device_sync(backend: Backend) -> None: + if cp and backend.is_gpu_backend(): cp.cuda.Device(0).synchronize() diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index e6682ec3..2ba025e1 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -20,6 +20,7 @@ from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI +from ndsl.config.backend import Backend, BackendFramework from ndsl.constants import ( X_DIM, X_DIMS, @@ -182,7 +183,7 @@ def __init__( comm=comm, ) compilation_config = CompilationConfig( - backend="numpy", + backend=Backend.python(), rebuild=stencil_config.compilation_config.rebuild, validate_args=stencil_config.compilation_config.validate_args, format_source=True, @@ -316,7 +317,10 @@ def __init__( # NOTE: this is also down in `dace/build.py` for orchestration # This is still needed for non-orchestrated used of DaCe. 
# A better build system would take care of BOTH of those at the same time - if "dace" in self.stencil_config.compilation_config.backend: + if ( + BackendFramework.DACE + == self.stencil_config.compilation_config.backend.framework + ): dace.Config.set( "default_build_folder", value="{gt_root}/{gt_cache}/dacecache".format( @@ -986,7 +990,7 @@ def __init__( self.comm = comm @property - def backend(self) -> str: + def backend(self) -> Backend: return self.config.compilation_config.backend def from_origin_domain( diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index f90db27a..0ed74e92 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -6,14 +6,15 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from typing import Any, Self -from gt4py.cartesian.backend import from_name as check_backend_existence +import gt4py.cartesian.backend as gt_backend from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend, BackendTargetDevice +from ndsl.config.backend import _BACKEND_PYTHON from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from ndsl.dsl.gt4py_utils import is_gpu_backend class RunMode(enum.Enum): @@ -32,7 +33,7 @@ class RunMode(enum.Enum): class CompilationConfig: def __init__( self, - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = True, validate_args: bool = True, format_source: bool = False, @@ -41,10 +42,10 @@ def __init__( use_minimal_caching: bool = False, communicator: Communicator | None = None, ) -> None: - if "gpu" not in backend and device_sync is True: + if backend.device is BackendTargetDevice.CPU and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") - # GT4Py backend args - 
check_backend_existence(backend) + # GT4Py backend check - expect GT4Py to raise if the backend doesn't exist + gt_backend.from_name(backend.as_gt4py()) self.backend = backend self.rebuild = rebuild self.validate_args = validate_args @@ -196,7 +197,7 @@ def __init__( else DaceConfig( communicator=None, backend=self.compilation_config.backend, - orchestration=DaCeOrchestration.Python, + orchestration=DaCeOrchestration.Run, ) ) self.backend_opts = { @@ -206,12 +207,12 @@ def __init__( self._hash = self._compute_hash() @property - def backend(self) -> str: + def backend(self) -> Backend: return self.compilation_config.backend def _compute_hash(self) -> int: md5 = hashlib.md5() - md5.update(self.compilation_config.backend.encode()) + md5.update(self.compilation_config.backend.as_gt4py().encode()) for attr in ( self.compilation_config.rebuild, self.compilation_config.validate_args, @@ -254,7 +255,7 @@ def stencil_kwargs( @property def is_gpu_backend(self) -> bool: - return is_gpu_backend(self.compilation_config.backend) + return self.compilation_config.backend.is_gpu_backend() @classmethod def _get_oir_pipeline(cls, skip_passes: Sequence[str]) -> OirPipeline: diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index a2581451..012e260a 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -9,6 +9,7 @@ from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.communicator import Communicator +from ndsl.config import Backend from ndsl.constants import ( N_HALO_DEFAULT, PI, @@ -488,7 +489,7 @@ def from_tile_sizing( npy: int, npz: int, communicator: Communicator, - backend: str, + backend: Backend, grid_type: int = 0, dx_const: float = 1000.0, dy_const: float = 1000.0, diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index d907e49a..903d71c8 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -11,7 +11,7 @@ # TODO: if we can remove translate tests in favor of checkpointer tests, # we can remove this "disallowed" import 
(ndsl.util does not depend on ndsl.dsl) -from ndsl.dsl.gt4py_utils import is_gpu_backend, split_cartesian_into_storages +from ndsl.dsl.gt4py_utils import split_cartesian_into_storages from ndsl.dsl.typing import Float from ndsl.grid.generation import MetricTerms from ndsl.initialization.allocator import QuantityFactory @@ -230,7 +230,7 @@ def ptop(self) -> Float: """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if self.ak.backend is not None and is_gpu_backend(self.ak.backend): + if self.ak.backend is not None and self.ak.backend.is_gpu_backend(): return Float(self.ak.view[0].get()) else: return Float(self.ak.view[0]) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 6b01d449..813c1551 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -6,6 +6,7 @@ import numpy as np from gt4py import storage as gt_storage +from ndsl.config import Backend from ndsl.constants import SPATIAL_DIMS from ndsl.dsl.typing import Float from ndsl.initialization import GridSizer @@ -13,13 +14,13 @@ class QuantityFactory: - def __init__(self, sizer: GridSizer, *, backend: str) -> None: + def __init__(self, sizer: GridSizer, *, backend: Backend) -> None: """ - Initialize a QuantityFactory from a GridSizer and a GT4Py backend name. + Initialize a QuantityFactory from a GridSizer and a NDSL backend name. Args: sizer: GridSizer object that determines the array sizes. - backend: GT4Py backend name used for performance-optimized allocation. + backend: NDSL backend name used for performance-optimized allocation. 
""" self.sizer = sizer self.backend = backend @@ -225,7 +226,6 @@ def get_quantity_halo_spec( dims: dimensionality of the data n_halo: number of halo points to update, defaults to self.n_halo dtype: data type of the data - backend: gt4py backend to use """ # TEMPORARY: we do a nasty temporary allocation here to read in the hardware diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index e6781e3c..5674e46a 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -3,6 +3,7 @@ import ndsl.constants as constants from ndsl.comm.partitioner import TilePartitioner +from ndsl.config import Backend from ndsl.constants import N_HALO_DEFAULT from ndsl.dsl.gt4py_utils import backend_is_fortran_aligned from ndsl.initialization.grid_sizer import GridSizer @@ -16,7 +17,7 @@ def __init__( nz: int, n_halo: int, data_dimensions: dict[str, int], - backend: str, + backend: Backend | None = None, ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) @@ -32,7 +33,7 @@ def from_tile_params( n_halo: int, layout: tuple[int, int], *, - backend: str, + backend: Backend | None, data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, @@ -85,7 +86,7 @@ def from_namelist( tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, *, - backend: str, + backend: Backend | None = None, ) -> Self: """Create a SubtileGridSizer from a Fortran namelist. diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index c6ec0d97..cd091f3f 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -7,6 +7,7 @@ import numpy as np from ndsl.comm.comm_abc import Comm +from ndsl.config import Backend from ndsl.optional_imports import cupy as cp from ndsl.performance.report import ( Report, @@ -29,13 +30,13 @@ def collect_performance(self) -> None: ... 
def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: ... def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: ... @classmethod @@ -77,7 +78,7 @@ def collect_performance(self) -> None: self.timestep_timer.reset() def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: if self.comm.Get_rank() == 0: git_hash = "None" @@ -114,7 +115,7 @@ def write_out_rank_0( def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: @@ -162,13 +163,13 @@ def collect_performance(self) -> None: def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: pass def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: pass diff --git a/ndsl/performance/report.py b/ndsl/performance/report.py index a1ccd1da..06b48a22 100644 --- a/ndsl/performance/report.py +++ b/ndsl/performance/report.py @@ -8,6 +8,7 @@ import numpy as np from ndsl.comm.comm_abc import Comm +from ndsl.config import Backend @dataclasses.dataclass @@ -41,7 +42,7 @@ def __post_init__(self) -> None: def get_experiment_info( experiment_name: str, time_step: int, - backend: str, + backend: Backend, git_hash: str, is_orchestrated: bool, ) -> Experiment: @@ -52,7 +53,7 @@ def get_experiment_info( git_hash=git_hash, timestamp=datetime.now().strftime("%d/%m/%Y %H:%M:%S"), timesteps=time_step, - backend=f"{orchestration}/{backend}", + backend=f"{orchestration}/{backend.as_safe_for_path()}", ) return experiment @@ -92,7 +93,8 @@ def 
gather_timing_data( comm.Gather(sendbuf, recvbuf, root=0) if is_root: timing_info[timer_name] = TimeReport( - hits=0, times=copy.deepcopy(recvbuf.tolist()) # type: ignore[union-attr] # (recvbuf is defined on root rank) + hits=0, + times=copy.deepcopy(recvbuf.tolist()), # type: ignore[union-attr] # (recvbuf is defined on root rank) ) return timing_info @@ -132,7 +134,7 @@ def get_sypd(timing_info: dict[str, TimeReport], dt_atmos: float) -> float: def collect_data_and_write_to_file( time_step: int, - backend: str, + backend: Backend, is_orchestrated: bool, git_hash: str, comm: Comm, diff --git a/ndsl/performance/tools.py b/ndsl/performance/tools.py index 80ed377f..463f5ecf 100644 --- a/ndsl/performance/tools.py +++ b/ndsl/performance/tools.py @@ -1,5 +1,6 @@ import click +from ndsl.config import Backend from ndsl.dsl.dace.utils import ( kernel_theoretical_timing_from_path, memory_static_analysis_from_path, @@ -41,7 +42,7 @@ "--backend", required=False, type=click.STRING, - default="dace:gpu", + default="st:dace:gpu:KJI", ) def command_line( action: str, @@ -60,7 +61,7 @@ def command_line( print( kernel_theoretical_timing_from_path( sdfg_path, - backend=backend, + backend=Backend(backend), hardware_bw_in_GB_s=( None if hardware_bw_in_gb_s == 0 else hardware_bw_in_gb_s ), diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 79f360c6..8b6b17ee 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -4,6 +4,7 @@ import dace import numpy as np +from ndsl.config import Backend from ndsl.optional_imports import cupy from ndsl.quantity import Quantity @@ -22,7 +23,7 @@ def __init__( dims: Sequence[str], units: str, *, - backend: str, + backend: Backend, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, allow_mismatch_float_precision: bool = False, diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index d53a08e9..7cf16a12 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -5,6 +5,7 @@ 
import numpy as np +from ndsl.config.backend import Backend from ndsl.optional_imports import cupy from ndsl.types import NumpyModule @@ -29,8 +30,8 @@ class QuantityMetadata: "ndarray-like type used to store the data." dtype: type "dtype of the data in the ndarray-like object." - backend: str - "GT4Py backend name. Used for performance optimal data allocation." + backend: Backend + "NDSL backend name. Used for performance optimal data allocation." @property def dim_lengths(self) -> dict[str, int]: diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index e31ac123..545b4a34 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -13,6 +13,7 @@ import ndsl.constants as constants from ndsl.comm.mpi import MPI +from ndsl.config.backend import Backend from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy from ndsl.quantity.bounds import BoundedArrayView @@ -33,7 +34,7 @@ def __init__( dims: Sequence[str], units: str, *, - backend: str, + backend: Backend, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, allow_mismatch_float_precision: bool = False, @@ -86,7 +87,7 @@ def __init__( _validate_quantity_property_lengths(data.shape, dims, origin, extent) - gt4py_backend_cls = gt_backend.from_name(backend) + gt4py_backend_cls = gt_backend.from_name(backend.as_gt4py()) is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] device = gt4py_backend_cls.storage_info["device"] @@ -123,7 +124,7 @@ def __init__( self._data = gt_storage.from_array( data, data.dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=origin, dimensions=dimensions, ) @@ -151,7 +152,7 @@ def from_data_array( origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, number_of_halo_points: int = 0, - backend: str | None = None, + backend: Backend | None = None, allow_mismatch_float_precision: bool = False, ) -> Quantity: """ @@ -249,7 +250,7 @@ def units(self) -> str: return 
self.metadata.units @property - def backend(self) -> str: + def backend(self) -> Backend: return self.metadata.backend @property @@ -487,7 +488,7 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: return tuple(return_list) -def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: +def _resolve_backend(data: xr.DataArray, backend: Backend | None) -> Backend: if backend is not None: # Forced backend name takes precedence return backend @@ -497,4 +498,4 @@ def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: return data.attrs["backend"] # else, fall back to assume python-based layout. - return "debug" + return Backend.debug() diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 1e5f41f7..213c025e 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -17,6 +17,7 @@ ) from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.stencils.testing.grid import Grid # type: ignore from ndsl.stencils.testing.parallel_translate import ParallelTranslate @@ -33,7 +34,7 @@ def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( "--backend", action="store", - default="numpy", + default="st:python:cpu:numpy", help="Backend to execute the test with, can only be one.", ) parser.addoption( @@ -232,7 +233,7 @@ def get_savepoint_restriction(metafunc: Any) -> int | None: return int(svpt) if svpt else None -def get_config(backend: str, communicator: Communicator | None) -> StencilConfig: +def get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig: stencil_config = StencilConfig( compilation_config=CompilationConfig( backend=backend, rebuild=False, validate_args=True @@ -248,10 +249,11 @@ def get_config(backend: str, communicator: Communicator | None) -> StencilConfig def 
sequential_savepoint_cases( metafunc: Any, data_path: Path, namelist_filename: Path, *, backend: str ) -> list[SavepointCase]: + ndsl_backend = Backend(backend) savepoint_names = get_sequential_savepoint_names(metafunc, data_path) namelist = load_f90nml(namelist_filename) grid_params = grid_params_from_f90nml(namelist) - stencil_config = get_config(backend, None) + stencil_config = get_config(ndsl_backend, None) ranks = get_ranks(metafunc, grid_params["layout"]) savepoint_to_replay = get_savepoint_restriction(metafunc) grid_mode = metafunc.config.getoption("grid") @@ -265,7 +267,7 @@ def sequential_savepoint_cases( savepoint_to_replay, stencil_config, namelist, - backend, + ndsl_backend, data_path, grid_mode, topology_mode, @@ -280,7 +282,7 @@ def _savepoint_cases( savepoint_to_replay: int | None, stencil_config: StencilConfig, namelist: Namelist, - backend: str, + backend: Backend, data_path: Path, grid_mode: str, topology_mode: str, @@ -349,7 +351,7 @@ def _savepoint_cases( def compute_grid_data( grid: Grid, grid_params: dict, - backend: str, + backend: Backend, layout: tuple[int, int], topology_mode: str, ) -> None: @@ -371,13 +373,14 @@ def parallel_savepoint_cases( backend: str, comm: Comm, ) -> list[SavepointCase]: + ndsl_backend = Backend(backend) namelist = load_f90nml(namelist_filename) grid_params = grid_params_from_f90nml(namelist) topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") communicator = get_communicator(comm, grid_params["layout"], topology_mode) - stencil_config = get_config(backend, communicator) + stencil_config = get_config(ndsl_backend, communicator) savepoint_names = get_parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = get_savepoint_restriction(metafunc) @@ -388,7 +391,7 @@ def parallel_savepoint_cases( savepoint_to_replay, stencil_config, namelist, - backend, + 
ndsl_backend, data_path, grid_mode, topology_mode, diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 2ae34a85..13d5db2d 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -3,6 +3,7 @@ import numpy as np from ndsl.comm.partitioner import TilePartitioner +from ndsl.config import Backend from ndsl.constants import N_HALO_DEFAULT, X_DIM, Y_DIM, Z_DIM from ndsl.dsl import gt4py_utils as utils from ndsl.dsl.stencil import GridIndexing @@ -35,7 +36,7 @@ class Grid: # grid.ie == npx + halo - 2 @classmethod - def _make(cls, npx, npy, npz, layout, rank, backend): + def _make(cls, npx, npy, npz, layout, rank, backend: Backend): shape_params = { "npx": npx, "npy": npy, @@ -59,13 +60,13 @@ def _make(cls, npx, npy, npz, layout, rank, backend): return cls(indices, shape_params, rank, layout, backend, local_indices=True) @classmethod - def from_namelist(cls, namelist, rank, backend): + def from_namelist(cls, namelist, rank, backend: Backend): return cls._make( namelist.npx, namelist.npy, namelist.npz, namelist.layout, rank, backend ) @classmethod - def with_data_from_namelist(cls, namelist, communicator, backend): + def with_data_from_namelist(cls, namelist, communicator, backend: Backend): grid = cls.from_namelist(namelist, communicator.rank, backend) grid.make_grid_data( npx=namelist.npx, @@ -82,7 +83,7 @@ def __init__( shape_params, rank, layout, - backend, + backend: Backend, data_fields: dict | None = None, local_indices=False, ): @@ -838,7 +839,7 @@ def driver_grid_data(self) -> DriverGridData: def set_grid_data(self, grid_data: GridData): self._grid_data = grid_data - def make_grid_data(self, npx, npy, npz, communicator, backend): + def make_grid_data(self, npx, npy, npz, communicator, backend: Backend): metric_terms = MetricTerms.from_tile_sizing( npx=npx, npy=npy, npz=npz, communicator=communicator, backend=backend ) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py 
index b0b36931..240c49b5 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -9,6 +9,7 @@ from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.config import Backend from ndsl.dsl import gt4py_utils as gt_utils from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig @@ -31,7 +32,7 @@ def platform(): return "docker" if in_docker else "metal" -def process_override(threshold_overrides, testobj, test_name, backend): +def process_override(threshold_overrides, testobj, test_name, backend: Backend): override = threshold_overrides.get(test_name, None) if override is not None: for spec in override: @@ -42,7 +43,7 @@ def process_override(threshold_overrides, testobj, test_name, backend): matches = [ spec for spec in override - if spec["backend"] == backend and spec["platform"] == platform() + if Backend(spec["backend"]) == backend and spec["platform"] == platform() ] if len(matches) == 1: match = matches[0] @@ -145,7 +146,7 @@ def _get_thresholds(compute_function, input_data) -> None: ) def test_sequential_savepoint( case: SavepointCase, - backend, + backend: str, print_failures, failure_stride, subtests, @@ -158,11 +159,12 @@ def test_sequential_savepoint( raise ValueError( f"No translate object available for savepoint {case.savepoint_name}." 
) + ndsl_backend = Backend(backend) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend=backend), + compilation_config=CompilationConfig(backend=ndsl_backend), dace_config=DaceConfig( communicator=None, - backend=backend, + backend=ndsl_backend, ), ) # Reduce error threshold for GPU @@ -171,7 +173,7 @@ def test_sequential_savepoint( case.testobj.near_zero = max(case.testobj.near_zero, GPU_NEAR_ZERO) if threshold_overrides is not None: process_override( - threshold_overrides, case.testobj, case.savepoint_name, backend + threshold_overrides, case.testobj, case.savepoint_name, ndsl_backend ) if case.testobj.skip_test: return @@ -328,7 +330,7 @@ def get_tile_communicator(comm, layout): ) def test_parallel_savepoint( case: SavepointCase, - backend, + backend: str, print_failures, failure_stride, subtests, @@ -338,6 +340,7 @@ def test_parallel_savepoint( multimodal_metric, xy_indices=True, ): + ndsl_backend = Backend(backend) mpi_comm = MPIComm() if mpi_comm.Get_size() % 6 != 0: layout = ( @@ -356,10 +359,10 @@ def test_parallel_savepoint( f"No translate object available for savepoint {case.savepoint_name}" ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend=backend), + compilation_config=CompilationConfig(backend=ndsl_backend), dace_config=DaceConfig( communicator=communicator, - backend=backend, + backend=ndsl_backend, ), ) # Increase minimum error threshold for GPU @@ -368,7 +371,7 @@ def test_parallel_savepoint( case.testobj.near_zero = max(case.testobj.near_zero, GPU_NEAR_ZERO) if threshold_overrides is not None: process_override( - threshold_overrides, case.testobj, case.savepoint_name, backend + threshold_overrides, case.testobj, case.savepoint_name, ndsl_backend ) if case.testobj.skip_test: return diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 8cf78282..45d4a8c2 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -5,6 +5,7 @@ 
import numpy.typing as npt import ndsl.dsl.gt4py_utils as utils +from ndsl.config import Backend from ndsl.dsl.stencil import StencilFactory from ndsl.optional_imports import cupy as cp from ndsl.quantity import Quantity @@ -22,7 +23,7 @@ def read_serialized_data(serializer, savepoint, variable): return data -def pad_field_in_j(field, nj: int, backend: str): +def pad_field_in_j(field, nj: int, backend: Backend): utils.device_sync(backend) outfield = utils.tile(field[:, 0, :], (nj, 1, 1)).transpose(1, 0, 2) return outfield @@ -358,7 +359,7 @@ class TranslateGrid: # 6---2---7 @classmethod - def new_from_serialized_data(cls, serializer, rank, layout, backend): + def new_from_serialized_data(cls, serializer, rank, layout, backend: Backend): grid_savepoint = serializer.get_savepoint("Grid-Info")[0] grid_data = {} grid_fields = serializer.fields_at_savepoint(grid_savepoint) @@ -366,7 +367,7 @@ def new_from_serialized_data(cls, serializer, rank, layout, backend): grid_data[field] = read_serialized_data(serializer, grid_savepoint, field) return cls(grid_data, rank, layout, backend=backend) - def __init__(self, inputs, rank, layout, *, backend: str): + def __init__(self, inputs, rank, layout, *, backend: Backend): self.backend = backend self.indices = {} self.shape_params = {} diff --git a/ndsl/xumpy/alloc.py b/ndsl/xumpy/alloc.py index a6a98804..dc3fd0b2 100644 --- a/ndsl/xumpy/alloc.py +++ b/ndsl/xumpy/alloc.py @@ -1,7 +1,9 @@ +from typing import Sequence, SupportsIndex + import numpy as np import numpy.typing as npt -from ndsl.dsl.gt4py_utils import is_gpu_backend +from ndsl.config import Backend from ndsl.dsl.typing import Float from ndsl.optional_imports import cupy as cp @@ -9,53 +11,56 @@ if cp is None: cp = np +# Taking a page from cupy's playbook to have tuple & ndarray +_ShapeLike = SupportsIndex | Sequence[SupportsIndex] + def zeros( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | 
cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.zeros(shape, dtype=dtype) return np.zeros(shape, dtype=dtype) def ones( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.ones(shape, dtype=dtype) return np.ones(shape, dtype=dtype) def empty( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.empty(shape, dtype=dtype) return np.empty(shape, dtype=dtype) def full( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, value: np.ScalarType, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.full(shape, value, dtype=dtype) return np.full(shape, value, dtype=dtype) def random( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): cp.random.rand(*shape) return np.random.rand(*shape) diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index beb16517..a256ead4 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -7,6 +7,7 @@ get_factories_single_tile, get_factories_single_tile_orchestrated, ) +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -51,9 +52,7 @@ def run(self, A: Any, B: Any) -> None: def test_runtime_make_local() -> None: - stencil_factory, quantity_factory = get_factories_single_tile( - 5, 5, 3, 0, backend="numpy" - ) + stencil_factory, quantity_factory = get_factories_single_tile(5, 5, 
3, 0) A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") @@ -73,7 +72,7 @@ def test_runtime_make_local() -> None: def test_runtime_has_orchestracted_call() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 3, 0, backend="dace:cpu_kfirst" + 5, 5, 3, 0, backend=Backend.performance_cpu() ) A_ = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") @@ -89,7 +88,7 @@ def test_runtime_has_orchestracted_call() -> None: def test_runtime_does_not_orchestrate_when_call_is_not_present() -> None: stencil_factory, _ = get_factories_single_tile_orchestrated( - 5, 5, 3, 0, backend="dace:cpu_kfirst" + 5, 5, 3, 0, backend=Backend.performance_cpu() ) code = Code_NoCall(stencil_factory) From 706a120ffe3552864f43e5fa691818265fd6f07d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 26 Jan 2026 13:03:08 -0500 Subject: [PATCH 02/47] Lint new files --- ndsl/config/__init__.py | 1 + ndsl/config/backend.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/ndsl/config/__init__.py b/ndsl/config/__init__.py index 51e071c1..9236dc7d 100644 --- a/ndsl/config/__init__.py +++ b/ndsl/config/__init__.py @@ -1,5 +1,6 @@ from .backend import Backend, BackendFramework, BackendTargetDevice + __all__ = [ "Backend", "BackendFramework", diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 07fdbd39..810f6ff0 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -1,5 +1,7 @@ from __future__ import annotations + from enum import Enum + import gt4py.cartesian.backend as gt_backend From 24d61aaf855b1c5bad893eb293e7de1ebd80eeb6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 26 Jan 2026 13:21:28 -0500 Subject: [PATCH 03/47] Some docs --- ndsl/config/backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/config/backend.py 
b/ndsl/config/backend.py index 810f6ff0..eac9d13c 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -125,4 +125,6 @@ def is_gpu_backend(self) -> bool: # Those two internal values are used for default parameters values # as it is bad practice to call a function in default argument value _BACKEND_PERFORMANCE_CPU = Backend.performance_cpu() +"""Internal: cache performance CPU. Please use Backend.performance_cpu().""" _BACKEND_PYTHON = Backend.python() +"""Internal: cache Python backend. Please use Backend.python().""" From 359882a1d1a08c3aaa0d80be18a64b55209c4fff Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 26 Jan 2026 14:24:18 -0500 Subject: [PATCH 04/47] Re-order init to go around the circular dependency (and skip isort on everything) --- ndsl/__init__.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 41bd7bc6..97839190 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,3 +1,4 @@ +# isort:skip_file from . 
import dsl # isort:skip from .logging import ndsl_log # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator @@ -7,15 +8,7 @@ from .config.backend import Backend from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath -from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from .dsl.dace.orchestration import orchestrate, orchestrate_function -from .dsl.dace.utils import ( - ArrayReport, - DaCeProgress, - MaxBandwidthBenchmarkProgram, - StorageReport, -) -from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .quantity import Quantity from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig @@ -26,11 +19,21 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Local, LocalState, Quantity, State +from .quantity import Local, LocalState, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .types import Allocator from .utils import MetaEnumStr +from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.dace.utils import ( + ArrayReport, + DaCeProgress, + MaxBandwidthBenchmarkProgram, + StorageReport, +) +from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration +from .dsl.dace.orchestration import orchestrate, orchestrate_function + __all__ = [ "dsl", From 52d5f5919373267d5dc2ade62a2a8e3f2df47a4b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 26 Jan 2026 15:56:28 -0500 Subject: [PATCH 05/47] Add backend exists check Add stencil dace backends --- ndsl/config/backend.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 
eac9d13c..5594ed35 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -33,9 +33,13 @@ class BackendFramework(Enum): "st:gt:cpu:IJK": "gt:cpu_kfirst", "st:gt:cpu:KJI": "gt:cpu_ifirst", "st:gt:gpu:KJI": "gt:gpu", + "st:dace:cpu:IJK": "dace:cpu_kfirst", "orch:dace:cpu:IJK": "dace:cpu_kfirst", + "st:dace:cpu:KIJ": "dace:cpu", "orch:dace:cpu:KIJ": "dace:cpu", + "st:dace:cpu:KJI": "dace:cpu_KJI", "orch:dace:cpu:KJI": "dace:cpu_KJI", + "st:dace:gpu:KJI": "dace:gpu", "orch:dace:gpu:KJI": "dace:gpu", } """Internal: match the NDSL backend names with the GT4Py names""" @@ -45,6 +49,10 @@ class Backend: """Backend for NDSL""" def __init__(self, ndsl_backend: str) -> None: + if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): + raise ValueError( + f"Unknown {ndsl_backend}, options are {list(_NDSL_TO_GT4PY_BACKEND_NAMING.keys())}" + ) parts = ndsl_backend.split(":") if len(parts) != 4: raise ValueError(f"Backend {ndsl_backend} is ill-formed.") From b84994853e08e7ac93548963df14832b5045bab9 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 09:52:57 -0500 Subject: [PATCH 06/47] Add Backend equal and hash operators --- ndsl/config/backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 5594ed35..b2d61ff5 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -79,6 +79,12 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.as_humanly_readable() + def __eq__(self, other: object) -> bool: + return isinstance(other, Backend) and self._humanly_readable == other._humanly_readable + + def __hash__(self) -> int: + return hash(self._humanly_readable) + @staticmethod def debug() -> Backend: return Backend("st:python:cpu:debug") From ba7909b0a0b2b34bff7af1c4014bf1bd5e265517 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 09:53:23 -0500 Subject: [PATCH 07/47] Fix forwarding of NDSL backend into GT4Py --- ndsl/dsl/dace/dace_config.py | 2 +- 
ndsl/dsl/stencil_config.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index deaf2bda..2a793c54 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -171,6 +171,7 @@ def __init__( of column-physics) and therefore can be compiled once """ + self._backend = backend self._single_code_path = single_code_path # Recording SDFG loaded for fast re-access # ToDo: DaceConfig becomes a bit more than a read-only config @@ -330,7 +331,6 @@ def __init__( except OSError: pass - self._backend = backend self.tile_resolution = [tile_nx, tile_nx, tile_nz] from ndsl.dsl.dace.build import set_distributed_caches diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 0ed74e92..b615625e 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -158,7 +158,7 @@ def as_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict) -> Self: instance = cls( - backend=data.get("backend", "numpy"), + backend=data.get("backend", Backend.python()), rebuild=data.get("rebuild", False), validate_args=data.get("validate_args", True), format_source=data.get("format_source", False), @@ -240,7 +240,7 @@ def stencil_kwargs( self, *, func: Callable[..., None], skip_passes: Iterable[str] = () ) -> dict: kwargs = { - "backend": self.compilation_config.backend, + "backend": self.compilation_config.backend.as_gt4py(), "rebuild": self.compilation_config.rebuild, "name": func.__module__ + "." 
+ func.__name__, **self.backend_opts, From 3ece5dd60fe6b53306e1604fab7e710f547d4c23 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 09:53:29 -0500 Subject: [PATCH 08/47] Fix forwarding of NDSL backend into GT4Py --- ndsl/initialization/allocator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 813c1551..edbc1511 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -200,7 +200,7 @@ def _allocate( dtype=dtype, aligned_index=origin, dimensions=dimensions, - backend=self.backend, + backend=self.backend.as_gt4py(), ) return Quantity( From c61cd1e5034e05d712f10f0f7563d7163c7318eb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 09:53:44 -0500 Subject: [PATCH 09/47] Fix forwarding of NDSL backend into GT4Py --- tests/dsl/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/utils.py b/tests/dsl/utils.py index 5e8535e5..32c880cb 100644 --- a/tests/dsl/utils.py +++ b/tests/dsl/utils.py @@ -10,7 +10,7 @@ def make_storage( aligned_index=(0, 0, 0), ): return func( - backend=stencil_config.compilation_config.backend, + backend=stencil_config.compilation_config.backend.as_gt4py(), shape=grid_indexing.domain, dtype=dtype, aligned_index=aligned_index, From 18cf9892a325ce42a3b216b6861d18f9c38e171a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 09:56:15 -0500 Subject: [PATCH 10/47] Update all tests --- tests/conftest.py | 11 ++-- tests/dsl/orchestration/test_call.py | 3 +- tests/dsl/test_caches.py | 20 +++--- tests/dsl/test_compilation_config.py | 9 +-- tests/dsl/test_dace_config.py | 15 ++--- tests/dsl/test_skip_passes.py | 3 +- tests/dsl/test_stencil.py | 11 ++-- tests/dsl/test_stencil_config.py | 23 +++---- tests/dsl/test_stencil_factory.py | 21 ++++--- tests/dsl/test_stencil_tables.py | 17 +++-- tests/dsl/test_stencil_wrapper.py | 33 +++++----- 
tests/grid/test_eta.py | 5 +- tests/mpi/test_eta.py | 3 +- tests/mpi/test_mpi_all_reduce_sum.py | 7 ++- tests/mpi/test_mpi_halo_update.py | 15 ++++- tests/quantity/test_boundary.py | 7 ++- tests/quantity/test_deepcopy.py | 7 ++- tests/quantity/test_local.py | 9 +-- tests/quantity/test_quantity.py | 29 ++++++--- tests/quantity/test_state.py | 5 +- tests/quantity/test_storage.py | 20 +++--- tests/quantity/test_transpose.py | 8 ++- tests/quantity/test_view.py | 63 +++++++++---------- tests/stencils/test_stencils.py | 10 ++- tests/stree_optimizer/test_merge.py | 5 +- tests/stree_optimizer/test_pipeline.py | 3 +- .../stree_optimizer/test_transient_refine.py | 3 +- tests/test_boilerplate.py | 7 ++- tests/test_buffer.py | 12 ++-- tests/test_caching_comm.py | 3 +- tests/test_cube_scatter_gather.py | 3 +- tests/test_dimension_sizer.py | 15 ++--- tests/test_g2g_communication.py | 9 ++- tests/test_halo_data_transformer.py | 5 +- tests/test_halo_update.py | 29 +++++++-- tests/test_halo_update_ranks.py | 3 +- tests/test_netcdf_monitor.py | 3 +- tests/test_partitioner.py | 7 ++- tests/test_sync_shared_boundary.py | 9 +-- tests/test_tile_scatter.py | 17 ++--- tests/test_tile_scatter_gather.py | 3 +- tests/test_xumpy.py | 19 +++--- tests/test_zarr_monitor.py | 32 +++++++--- 43 files changed, 331 insertions(+), 210 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 12e2b88c..df3c2d5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from ndsl.config import Backend from ndsl.optional_imports import cupy @@ -17,14 +18,14 @@ def backend(request): @pytest.fixture -def gt4py_backend(backend): +def ndsl_backend(backend: str) -> Backend: if backend in ("numpy"): - return "numpy" + return Backend.python() if backend in ("cupy"): - return "gt:gpu" + return Backend("st:dace:gpu:KJI") - return None + raise ValueError(f"Test backend {backend} cannot be translated into Backend") @pytest.fixture @@ -33,7 +34,7 @@ def 
fast(pytestconfig): @pytest.fixture -def numpy(backend): +def numpy(backend: str): if backend == "numpy": return np elif backend == "cupy": diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index 4bbced15..df0703ac 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -2,6 +2,7 @@ from ndsl import NDSLRuntime, QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval @@ -84,7 +85,7 @@ def test_default_types_are_compiletime(): def test_dace_call_argument_caching(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 2, 0, backend="dace:cpu_kfirst" + 5, 5, 2, 0, backend=Backend.performance_cpu() ) dconfig = stencil_factory.config.dace_config diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 5b8ef1d7..239febce 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -16,6 +16,7 @@ StencilFactory, ) from ndsl.comm.mpi import MPI +from ndsl.config import Backend from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval from ndsl.dsl.stencil import CompareToNumpyStencil, FrozenStencil @@ -28,7 +29,7 @@ def _stencil(inp: Field[float], out: Field[float]): def _build_stencil( - backend: str, orchestrated: DaCeOrchestration + backend: Backend, orchestrated: DaCeOrchestration | None ) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: # Make stencil and verify it ran grid_indexing = GridIndexing( @@ -55,7 +56,7 @@ def _build_stencil( class OrchestratedProgram: - def __init__(self, backend, orchestration: DaCeOrchestration): + def __init__(self, backend: Backend, orchestration: DaCeOrchestration | 
None): self.stencil, grid_indexing, stencil_config = _build_stencil( backend, orchestration ) @@ -72,7 +73,7 @@ def __call__(self): ) def test_relocatability_orchestration() -> None: # Compile on default - p0 = OrchestratedProgram("dace:cpu", DaCeOrchestration.BuildAndRun) + p0 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.BuildAndRun) p0() expected_cache_dir = ( @@ -91,8 +92,7 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: gt_config.cache_settings["root_path"] = tmpdir # Compile in temporary directory that is only available in this test session. - backend = "dace:cpu" - p1 = OrchestratedProgram(backend, DaCeOrchestration.BuildAndRun) + p1 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.BuildAndRun) p1() expected_cache_dir = ( @@ -108,14 +108,14 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: relocated_path = tmpdir / ".my_relocated_cache_path" shutil.copytree(tmpdir, relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path - p2 = OrchestratedProgram(backend, DaCeOrchestration.Run) + p2 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.Run) p2() # Generate a file exists error to check for bad path bogus_path = "./nope/not_at_all/not_happening" gt_config.cache_settings["root_path"] = bogus_path with pytest.raises(RuntimeError): - OrchestratedProgram(backend, DaCeOrchestration.Run) + OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.Run) @pytest.mark.skipif( @@ -129,7 +129,7 @@ def test_relocatability() -> None: # Compile on default backend = "dace:cpu" - p0 = OrchestratedProgram(backend, DaCeOrchestration.Python) + p0 = OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) p0() backend_sanitized = backend.replace(":", "") @@ -158,7 +158,7 @@ def test_relocatability_tmpdir(tmpdir) -> None: # Compile in another directory backend = "dace:cpu" - p1 = OrchestratedProgram(backend, DaCeOrchestration.Python) + p1 = 
OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) p1() backend_sanitized = backend.replace(":", "") @@ -181,7 +181,7 @@ def test_relocatability_tmpdir(tmpdir) -> None: shutil.copytree(tmpdir / ".gt_cache_000000", relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path - p2 = OrchestratedProgram(backend, DaCeOrchestration.Python) + p2 = OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) p2() relocated_cache_path = ( diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 9f9c165c..5822635e 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -11,13 +11,14 @@ RunMode, TilePartitioner, ) +from ndsl.config import Backend def test_safety_checks(): with pytest.raises(RuntimeError): - CompilationConfig(backend="numpy", device_sync=True) + CompilationConfig(Backend.python(), device_sync=True) with pytest.raises(RuntimeError): - CompilationConfig(backend="gt:cpu_ifirst", device_sync=True) + CompilationConfig(Backend("st:gt:cpu:KJI"), device_sync=True) @pytest.mark.parametrize( @@ -162,9 +163,9 @@ def test_from_dict() -> None: assert config.run_mode == RunMode.BuildAndRun assert config.use_minimal_caching is False - specification_dict["backend"] = "gt:gpu" + specification_dict["backend"] = Backend("st:gt:gpu:KJI") config = CompilationConfig.from_dict(specification_dict) - assert config.backend == "gt:gpu" + assert config.backend == Backend("st:gt:gpu:KJI") specification_dict["rebuild"] = True config = CompilationConfig.from_dict(specification_dict) diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index e89b69c2..8003937f 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -2,6 +2,7 @@ from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import 
_determine_compiling_ranks from ndsl.dsl.dace.orchestration import orchestrate, orchestrate_function @@ -18,7 +19,7 @@ def foo() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) wrapped = orchestrate_function(config=dace_config)(foo) @@ -36,8 +37,8 @@ def foo() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", - orchestration=DaCeOrchestration.Python, + backend=Backend("st:dace:cpu:KIJ"), + orchestration=None, ) wrapped = orchestrate_function(config=dace_config)(foo) with unittest.mock.patch( @@ -50,7 +51,7 @@ def foo() -> None: def test_orchestrate_calls_dace() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) @@ -72,8 +73,8 @@ def foo(self) -> None: def test_orchestrate_does_not_call_dace() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", - orchestration=DaCeOrchestration.Python, + backend=Backend("st:dace:cpu:KIJ"), + orchestration=None, ) class A: @@ -94,7 +95,7 @@ def foo(self) -> None: def test_orchestrate_distributed_build() -> None: dummy_dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index 6ffb0c95..870abac1 100644 --- a/tests/dsl/test_skip_passes.py +++ b/tests/dsl/test_skip_passes.py @@ -13,6 +13,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -24,7 +25,7 @@ def stencil_definition(a: FloatField): def test_skip_passes_becomes_oir_pipeline() -> None: - backend = "numpy" + backend = Backend.python() dace_config = DaceConfig(None, 
backend) config = StencilConfig( compilation_config=CompilationConfig(backend=backend), dace_config=dace_config diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 1f7f3836..60383137 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -11,6 +11,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, computation, interval from ndsl.dsl.typing import ( BoolFieldIJ, @@ -36,7 +37,7 @@ def test_timing_collector() -> None: east_edge=True, ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend="numpy", rebuild=True) + compilation_config=CompilationConfig(Backend.python(), rebuild=True) ) stencil_factory = StencilFactory(stencil_config, grid_indexing) @@ -109,7 +110,7 @@ def test_domain_size_comparison( call_count: int, ): quantity = Quantity( - np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug" + np.zeros(extent), dimensions, "n/a", extent=extent, backend=Backend.debug() ) stencil = FrozenStencil( copy_stencil, @@ -150,7 +151,7 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None: def test_stencil_2D_temporaries() -> None: domain = (2, 2, 5) quantity = Quantity( - np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend=Backend.debug() ) stencil = FrozenStencil( two_dim_temporaries_stencil, @@ -169,10 +170,10 @@ def test_stencil_2D_temporaries() -> None: def test_validation_call_count(iterations: tuple[int]): domain = (2, 2, 5) quantity = Quantity( - np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend=Backend.debug() ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend="numpy", rebuild=True) + compilation_config=CompilationConfig(Backend.python(), rebuild=True) ) stencil = FrozenStencil( copy_stencil, diff --git 
a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 5cdf79dc..97ecbfa1 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,15 +1,16 @@ import pytest from ndsl import CompilationConfig, DaceConfig, StencilConfig +from ndsl.config.backend import _BACKEND_PYTHON, Backend @pytest.mark.parametrize("validate_args", [True, False]) @pytest.mark.parametrize("rebuild", [True, False]) @pytest.mark.parametrize("format_source", [True, False]) @pytest.mark.parametrize("compare_to_numpy", [True, False]) -@pytest.mark.parametrize("backend", ["numpy", "gt:gpu"]) +@pytest.mark.parametrize("backend", [Backend.python(), Backend("st:gt:gpu:KJI")]) def test_same_config_equal( - backend: str, + backend: Backend, rebuild: bool, validate_args: bool, format_source: bool, @@ -46,7 +47,7 @@ def test_same_config_equal( def test_different_backend_not_equal( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -71,7 +72,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - backend="debug", + Backend.debug(), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -84,7 +85,7 @@ def test_different_backend_not_equal( def test_different_rebuild_not_equal( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -130,11 +131,11 @@ def test_different_device_sync_not_equal( ) -> None: dace_config = DaceConfig( communicator=None, - backend="gt:gpu", + backend=Backend("st:gt:gpu:KJI"), ) config = StencilConfig( compilation_config=CompilationConfig( - backend="gt:gpu", + Backend("st:gt:gpu:KJI"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -146,7 +147,7 @@ def test_different_device_sync_not_equal( different_config = StencilConfig( 
compilation_config=CompilationConfig( - backend="gt:gpu", + Backend("st:gt:gpu:KJI"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -159,7 +160,7 @@ def test_different_device_sync_not_equal( def test_different_validate_args_not_equal( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -197,7 +198,7 @@ def test_different_validate_args_not_equal( def test_different_format_source_not_equal( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -234,7 +235,7 @@ def test_different_format_source_not_equal( @pytest.mark.parametrize("compare_to_numpy", [True, False]) def test_different_compare_to_numpy_not_equal( compare_to_numpy: bool, - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, device_sync: bool = False, format_source: bool = True, rebuild: bool = True, diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index ce9de962..c5acc9eb 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -9,6 +9,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, horizontal, interval, region from ndsl.dsl.gt4py_utils import make_storage_from_shape @@ -16,7 +17,7 @@ from ndsl.dsl.typing import Field, FloatField -BACKENDS = ["numpy", "dace:cpu"] +BACKENDS = [Backend.python(), Backend("st:dace:cpu:KIJ")] def copy_stencil(q_in: FloatField, q_out: FloatField): @@ -39,7 +40,7 @@ def add_1_in_region_stencil(q_in: FloatField, q_out: FloatField): q_out = q_in + 1.0 -def setup_data_vars(backend: str) -> tuple[Field, Field]: +def setup_data_vars(backend: Backend) -> tuple[Field, Field]: shape = (7, 7, 3) q = make_storage_from_shape(shape, backend=backend) q[:] = 1.0 @@ -48,7 +49,7 @@ def 
setup_data_vars(backend: str) -> tuple[Field, Field]: return q, q_ref -def get_stencil_factory(backend: str) -> StencilFactory: +def get_stencil_factory(backend: Backend) -> StencilFactory: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -72,7 +73,7 @@ def get_stencil_factory(backend: str) -> StencilFactory: @pytest.mark.parametrize("backend", BACKENDS) -def test_get_stencils_with_varied_bounds(backend: str) -> None: +def test_get_stencils_with_varied_bounds(backend: Backend) -> None: origins = [(2, 2, 0), (1, 1, 0)] domains = [(1, 1, 3), (2, 2, 3)] factory = get_stencil_factory(backend) @@ -92,7 +93,7 @@ def test_get_stencils_with_varied_bounds(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) -def test_get_stencils_with_varied_bounds_and_regions(backend: str) -> None: +def test_get_stencils_with_varied_bounds_and_regions(backend: Backend) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 0)] domains = [(1, 1, 3), (2, 2, 3)] @@ -113,7 +114,7 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) -def test_stencil_vertical_bounds(backend: str) -> None: +def test_stencil_vertical_bounds(backend: Backend) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 1)] domains = [(1, 1, 3), (2, 2, 4)] @@ -133,7 +134,7 @@ def test_stencil_vertical_bounds(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) def test_stencil_factory_numpy_comparison_from_dims_halo( - backend: str, enabled: bool + backend: Backend, enabled: bool ) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( @@ -170,7 +171,7 @@ def test_stencil_factory_numpy_comparison_from_dims_halo( @pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) def 
test_stencil_factory_numpy_comparison_from_origin_domain( - backend: str, enabled: bool + backend: Backend, enabled: bool ) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( @@ -203,7 +204,9 @@ def test_stencil_factory_numpy_comparison_from_origin_domain( @pytest.mark.parametrize("backend", BACKENDS) -def test_stencil_factory_numpy_comparison_runs_without_exceptions(backend: str) -> None: +def test_stencil_factory_numpy_comparison_runs_without_exceptions( + backend: Backend, +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( diff --git a/tests/dsl/test_stencil_tables.py b/tests/dsl/test_stencil_tables.py index 0b31826a..e0cafba8 100644 --- a/tests/dsl/test_stencil_tables.py +++ b/tests/dsl/test_stencil_tables.py @@ -1,5 +1,5 @@ +import gt4py.storage as gt_storage import numpy as np -from gt4py.storage import ones, zeros from ndsl import ( CompilationConfig, @@ -11,6 +11,7 @@ StencilFactory, orchestrate, ) +from ndsl.config import Backend from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, GlobalTable, computation, interval from ndsl.dsl.stencil import CompareToNumpyStencil from tests.dsl import utils @@ -24,7 +25,7 @@ def _stencil(inp: GlobalTable[np.int32, (5,)], out: Field[np.float64]) -> None: def _build_stencil( - backend: str, orchestrated: DaCeOrchestration + backend: Backend, orchestrated: DaCeOrchestration ) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: # Make stencil and verify it ran grid_indexing = GridIndexing( @@ -51,15 +52,19 @@ def _build_stencil( class OrchestratedProgram: - def __init__(self, backend, orchestration: DaCeOrchestration): + def __init__(self, backend: Backend, orchestration: DaCeOrchestration): self.stencil, grid_indexing, stencil_config = _build_stencil( backend, orchestration ) orchestrate(obj=self, config=stencil_config.dace_config) - self.inp = ones(shape=(5,), 
dtype=np.int32, backend=backend) + self.inp = gt_storage.ones( + shape=(5,), dtype=np.int32, backend=backend.as_gt4py() + ) self.inp[1] = 42 - self.out = utils.make_storage(zeros, grid_indexing, stencil_config, dtype=float) + self.out = utils.make_storage( + gt_storage.zeros, grid_indexing, stencil_config, dtype=float + ) def __call__(self): self.stencil(self.inp, self.out) @@ -67,7 +72,7 @@ def __call__(self): def test_stecil_with_table_orchestrated() -> None: program = OrchestratedProgram( - backend="dace:cpu", orchestration=DaCeOrchestration.BuildAndRun + Backend("st:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun ) # run the orchestrated stencil diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 49b7b2a9..37605c82 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -12,6 +12,7 @@ Quantity, StencilConfig, ) +from ndsl.config.backend import _BACKEND_PYTHON, Backend from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage @@ -36,8 +37,8 @@ def get_stencil_config( *, - backend: str, - orchestration: DaCeOrchestration = DaCeOrchestration.Python, + backend: Backend, + orchestration: DaCeOrchestration = DaCeOrchestration.BuildAndRun, **kwargs, ) -> StencilConfig: dace_config = DaceConfig(None, backend=backend, orchestration=orchestration) @@ -81,19 +82,19 @@ def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] 
= ( "field_info, origin, field_origins", [ pytest.param( - {"a": MockFieldInfo(axes=("I"))}, + {"a": MockFieldInfo(axes=("I",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1,)}, id="single_field_I", ), pytest.param( - {"a": MockFieldInfo(axes=("J"))}, + {"a": MockFieldInfo(axes=("J",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (2,)}, id="single_field_J", ), pytest.param( - {"a": MockFieldInfo(axes=("K"))}, + {"a": MockFieldInfo(axes=("K",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (3,)}, id="single_field_K", @@ -111,7 +112,7 @@ def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] = ( id="single_field_origin_mapping", ), pytest.param( - {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I"))}, + {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I",))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3), "b": (1,)}, id="two_fields_update_origin_mapping", @@ -155,7 +156,7 @@ def copy_stencil(q_in: FloatField, q_out: FloatField): @pytest.mark.parametrize("validate_args", [True, False]) def test_copy_frozen_stencil( validate_args: bool, - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -183,7 +184,7 @@ def test_copy_frozen_stencil( def test_frozen_stencil_raises_if_given_origin( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -210,7 +211,7 @@ def test_frozen_stencil_raises_if_given_origin( def test_frozen_stencil_raises_if_given_domain( - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -245,7 +246,7 @@ def test_frozen_stencil_kwargs_passed_to_init( validate_args: bool, format_source: bool, device_sync: bool, - backend: str = "numpy", + backend: Backend = _BACKEND_PYTHON, ): config = get_stencil_config( 
backend=backend, @@ -298,7 +299,7 @@ def field_after_parameter_stencil(q_in: FloatField, param: float, q_out: FloatFi def test_frozen_field_after_parameter() -> None: config = get_stencil_config( - backend="numpy", + backend=Backend.python(), rebuild=False, validate_args=False, format_source=False, @@ -313,20 +314,20 @@ def test_frozen_field_after_parameter() -> None: ) -@pytest.mark.parametrize("backend", ("numpy", "gt:gpu")) +@pytest.mark.parametrize("backend", (Backend.python(), Backend("st:gt:gpu:KJI"))) def test_backend_options( - backend: str, + backend: Backend, rebuild: bool = True, validate_args: bool = True, ) -> None: expected_options = { - "numpy": { + Backend.python(): { "backend": "numpy", "rebuild": True, "format_source": False, "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, - "gt:gpu": { + "st:gt:gpu:KJI": { "backend": "gt:gpu", "rebuild": True, "device_sync": False, @@ -344,7 +345,7 @@ def test_backend_options( def test_illegal_backend_options(): with pytest.raises(ValueError): - get_stencil_config(backend="illegal") + get_stencil_config(backend=Backend("bad:back:end:now")) def get_mock_quantity(): diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index 38e6c75b..764edc4f 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -10,6 +10,7 @@ SubtileGridSizer, TilePartitioner, ) +from ndsl.config import Backend from ndsl.grid import MetricTerms @@ -29,7 +30,7 @@ def test_set_hybrid_pressure_coefficients_nofile(): eta_file = Path("NULL") - backend = "numpy" + backend = Backend.python() layout = (1, 1) @@ -75,7 +76,7 @@ def test_set_hybrid_pressure_coefficients_not_mono(): eta_file = str(Path.cwd()) + "/tests/data/eta/non_mono_eta79.nc" - backend = "numpy" + backend = Backend.python() layout = (1, 1) diff --git a/tests/mpi/test_eta.py b/tests/mpi/test_eta.py index b4b8a0a8..ad035419 100644 --- a/tests/mpi/test_eta.py +++ b/tests/mpi/test_eta.py @@ -12,6 +12,7 @@ SubtileGridSizer, TilePartitioner, ) +from ndsl.config 
import Backend from ndsl.grid import MetricTerms from tests.mpi import MPI @@ -29,7 +30,7 @@ def test_set_hybrid_pressure_coefficients_correct(levels): eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" eta_data = xr.open_dataset(eta_file) - backend = "numpy" + backend = Backend.python() layout = (1, 1) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 34a3b0dd..6627ae2e 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -9,6 +9,7 @@ ) from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.mpi import MPIComm +from ndsl.config import Backend from ndsl.dsl.typing import Float from tests.mpi import MPI @@ -49,7 +50,11 @@ def communicator(cube_partitioner): @pytest.mark.skipif(MPI is None, reason="pytest is not run in parallel") def test_all_reduce(communicator): - backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] + backends = [ + Backend("st:dace:cpu:KIJ"), + Backend("st:gt:cpu:IJK"), + Backend.python(), + ] for backend in backends: base_array = np.array([i for i in range(5)], dtype=Float) diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 45278cc5..e395e74f 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -10,6 +10,7 @@ ) from ndsl.comm._boundary_utils import get_boundary_slice from ndsl.comm.mpi import MPIComm +from ndsl.config import Backend from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -272,7 +273,12 @@ def depth_quantity( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan return Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) @@ -315,7 +321,12 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): outside of it.""" data = numpy.ones(shape, dtype=dtype) quantity = 
Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) quantity.view[:] = 0.0 return quantity diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index 7ba2eddb..8ced20fa 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -3,6 +3,7 @@ from ndsl import Quantity from ndsl.comm._boundary_utils import _shift_boundary_slice, get_boundary_slice +from ndsl.config import Backend from ndsl.constants import ( EAST, NORTH, @@ -36,7 +37,7 @@ def test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ) for side in ( WEST, @@ -72,7 +73,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), - backend="debug", + backend=Backend.debug(), ) for side in ( WEST, @@ -111,7 +112,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), - backend="debug", + backend=Backend.debug(), ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index f0c7bb13..3498f326 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -4,6 +4,7 @@ import numpy as np from ndsl import Quantity +from ndsl.config import Backend def test_deepcopy_copy_is_editable_by_view(): @@ -14,7 +15,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.debug(), ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -32,7 +33,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.debug(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 
1.0 @@ -48,7 +49,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.debug(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index e5f53be3..86fc062a 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -12,6 +12,7 @@ StencilFactory, ) from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float @@ -25,7 +26,7 @@ def test_dace_data_descriptor_is_transient() -> None: extent=(nx,), dims=("dim_X",), units="n/a", - backend="debug", + backend=Backend.debug(), ) array = local.__descriptor__() assert array.transient @@ -88,7 +89,7 @@ def __call__(self) -> None: def test_proper_initialization() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, backend="debug" + 3, 3, 5, 0, Backend.debug() ) the_code = TheCode(stencil_factory, quantity_factory) assert the_code.check_local_right_after_init() @@ -96,7 +97,7 @@ def test_proper_initialization() -> None: def test_forbidden_access_to_locals() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, backend="debug" + 3, 3, 5, 0, Backend.debug() ) the_code = TheCode(stencil_factory, quantity_factory) @@ -128,7 +129,7 @@ def test_forbidden_access_to_locals() -> None: def test_local_state_as_regular_state() -> None: - _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, backend="debug") + _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, Backend.debug()) with pytest.raises( RuntimeError, match="LocalState allocated outside of NDSLRuntime: forbidden", diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index e2c00ddd..b4f81333 100644 --- a/tests/quantity/test_quantity.py +++ 
b/tests/quantity/test_quantity.py @@ -2,6 +2,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.quantity.bounds import _shift_slice @@ -62,7 +63,12 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): return Quantity( - data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend=Backend.debug(), ) @@ -82,7 +88,7 @@ def test_smaller_data_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend="debug", + backend=Backend.debug(), ) @@ -96,7 +102,7 @@ def test_smaller_dims_raises(data, origin, extent, dims, units): extent=extent, dims=dims[:-1], units=units, - backend="debug", + backend=Backend.debug(), ) @@ -108,7 +114,7 @@ def test_smaller_origin_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend="debug", + backend=Backend.debug(), ) @@ -120,7 +126,7 @@ def test_smaller_extent_raises(data, origin, extent, dims, units): extent=extent[:-1], dims=dims, units=units, - backend="debug", + backend=Backend.debug(), ) @@ -260,15 +266,18 @@ def test_shift_slice(slice_in, shift, extent, slice_out): @pytest.mark.parametrize( "quantity", [ - Quantity(np.array(5), dims=[], units="", backend="debug"), + Quantity(np.array(5), dims=[], units="", backend=Backend.debug()), Quantity( - np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug" + np.array([1, 2, 3]), + dims=["dimension"], + units="degK", + backend=Backend.debug(), ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", - backend="debug", + backend=Backend.debug(), ), Quantity( np.random.randn(8, 6, 6), @@ -276,7 +285,7 @@ def test_shift_slice(slice_in, shift, extent, slice_out): units="km", origin=(2, 2, 2), extent=(4, 2, 2), - backend="debug", + backend=Backend.debug(), ), ], ) @@ -292,7 +301,7 @@ def 
test_to_data_array(quantity): def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug") + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend=Backend.debug()) # After allocation - field and data are the same (origin is 0) assert quantity.data.shape == quantity.field.shape diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index d001a744..3dbc0cab 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -5,6 +5,7 @@ from ndsl import Quantity, State from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, Float @@ -59,7 +60,7 @@ class InnerB: def test_basic_state(tmpdir): K_size = 3 _, quantity_factory = get_factories_single_tile( - 5, 5, K_size, 0, backend="dace:cpu_KJI" + 5, 5, K_size, 0, Backend("st:dace:cpu:KJI") ) # Test allocator @@ -140,7 +141,7 @@ class InnerB: def test_state_ddim(): _, quantity_factory = get_factories_single_tile( - 5, 5, 3, 0, backend="dace:cpu_kfirst" + 5, 5, 3, 0, backend=Backend("st:dace:cpu:IJK") ) # Test allocator diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 6d8dc4a4..952f4bff 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -2,6 +2,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.optional_imports import cupy as cp @@ -54,7 +55,12 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): return Quantity( - data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend=Backend.debug(), ) @@ -74,7 +80,7 @@ def test_modifying_numpy_data_modifies_view_and_field(): extent=shape, dims=["dim1", "dim2"], units="units", - backend="numpy", + backend=Backend.python(), ) assert 
np.all(quantity.data == 0) quantity.data[0, 0] = 1 @@ -101,7 +107,7 @@ def test_data_and_field_access_right_full_array_and_compute_domain(): extent=(5, 5), dims=["dim1", "dim2"], units="units", - backend="numpy", + backend=Backend.python(), ) assert np.all(quantity.data == 0) # Write compute domain - test data is written with the offset @@ -133,7 +139,7 @@ def test_field_exists(quantity, backend): @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_accessing_data_does_not_break_view( - data, origin, extent, dims, units, gt4py_backend + data, origin, extent, dims, units, ndsl_backend ): quantity = Quantity( data, @@ -141,7 +147,7 @@ def test_accessing_data_does_not_break_view( extent=extent, dims=dims, units=units, - backend=gt4py_backend, + backend=ndsl_backend, ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] @@ -151,7 +157,7 @@ def test_accessing_data_does_not_break_view( # run using cupy backend even though unused, to mark this as a "gpu" test @pytest.mark.parametrize("backend", ["cupy"], indirect=True) def test_numpy_data_becomes_cupy_with_gpu_backend( - data, origin, extent, dims, units, gt4py_backend + data, origin, extent, dims, units, ndsl_backend ): cpu_data = np.zeros(data.shape) quantity = Quantity( @@ -160,6 +166,6 @@ def test_numpy_data_becomes_cupy_with_gpu_backend( extent=extent, dims=dims, units=units, - backend=gt4py_backend, + backend=ndsl_backend, ) assert isinstance(quantity.data, cp.ndarray) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index b7ceb0f8..278e3522 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -1,6 +1,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.constants import ( X_DIM, X_DIMS, @@ -86,7 +87,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, 
extent=initial_extent, - backend="debug", + backend=Backend.debug(), ) @@ -220,7 +221,10 @@ def test_transpose_invalid_cases( def test_transpose_retains_attrs(numpy): quantity = Quantity( - numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", backend="debug" + numpy.random.randn(3, 4), + dims=["x", "y"], + units="unit_string", + backend=Backend.debug(), ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index 0ce44cfe..979dc733 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -2,15 +2,14 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM @pytest.fixture def quantity(request): return Quantity( - request.param[0], - dims=request.param[1], - units="units", + request.param[0], dims=request.param[1], units="units", backend=Backend.debug() ) @@ -183,7 +182,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ) ], ) @@ -217,7 +216,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ) ], ) @@ -744,7 +743,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (0, 0), 4, @@ -757,7 +756,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, -1), 0, @@ -770,7 +769,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -783,7 +782,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, 0), 1, @@ 
-804,7 +803,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -842,7 +841,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, 0), 4, @@ -855,7 +854,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (0, -1), 6, @@ -868,7 +867,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -881,7 +880,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, -1), 3, @@ -902,7 +901,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -940,7 +939,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, 0), 4, @@ -953,7 +952,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (0, -1), 6, @@ -966,7 +965,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -979,7 +978,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, 0), 4, @@ -992,7 +991,7 @@ def test_southeast(quantity, view_slice, reference): 
units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, -1), 3, @@ -1013,7 +1012,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1051,7 +1050,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, -1), 4, @@ -1064,7 +1063,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (0, 0), 8, @@ -1077,7 +1076,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1090,7 +1089,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, -1), 4, @@ -1103,7 +1102,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (-1, 0), 5, @@ -1124,7 +1123,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1162,7 +1161,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (0, 0), 4, @@ -1175,7 +1174,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(0, 0), slice(0, 0)), 4, @@ -1188,7 +1187,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + 
backend=Backend.debug(), ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1209,7 +1208,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1230,7 +1229,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), - backend="debug", + backend=Backend.debug(), ), (0,), np.array([6, 7, 8]), diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 194a956e..1ca121b5 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -3,6 +3,7 @@ from ndsl import QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size @@ -19,12 +20,17 @@ @pytest.fixture def boilerplate() -> tuple[StencilFactory, QuantityFactory]: - return get_factories_single_tile(nx=1, ny=1, nz=10, nhalo=0, backend="dace:cpu") + return get_factories_single_tile( + nx=1, + ny=1, + nz=10, + nhalo=0, + backend=Backend("st:dace:cpu:KIJ"), + ) class ColumnOperations: def __init__(self, stencil_factory: StencilFactory): - def column_max_stencil( data: FloatField, max_value: FloatFieldIJ, max_index: FloatFieldIJ ): diff --git a/tests/stree_optimizer/test_merge.py b/tests/stree_optimizer/test_merge.py index 56016806..4c0f2b2f 100644 --- a/tests/stree_optimizer/test_merge.py +++ b/tests/stree_optimizer/test_merge.py @@ -2,6 +2,7 @@ from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, 
computation, interval from ndsl.dsl.typing import FloatField @@ -127,7 +128,7 @@ def push_non_cartesian_for( def test_stree_merge_maps_IJK() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend.performance_cpu() ) code = OrchestratedCode(stencil_factory, quantity_factory) @@ -206,7 +207,7 @@ def test_stree_merge_maps_IJK() -> None: def test_stree_merge_maps_KJI() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI" + domain[0], domain[1], domain[2], 0, Backend("st:dace:cpu:KJI") ) code = OrchestratedCode(stencil_factory, quantity_factory) diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/stree_optimizer/test_pipeline.py index 6d4c6f74..ce6fb23e 100644 --- a/tests/stree_optimizer/test_pipeline.py +++ b/tests/stree_optimizer/test_pipeline.py @@ -1,5 +1,6 @@ from ndsl import StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -28,7 +29,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField): def test_stree_roundtrip_no_opt(): domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend.performance_cpu() ) code = TriviallyMergeableCode(stencil_factory) diff --git a/tests/stree_optimizer/test_transient_refine.py b/tests/stree_optimizer/test_transient_refine.py index 2709aea5..d151ef91 100644 --- a/tests/stree_optimizer/test_transient_refine.py +++ b/tests/stree_optimizer/test_transient_refine.py @@ 
-1,5 +1,6 @@ from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval from ndsl.dsl.typing import Float, FloatField @@ -93,7 +94,7 @@ def do_not_refine_datadims(self, in_field: Quantity, out_field: Quantity) -> Non def test_stree_roundtrip_transient_is_refined() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend.performance_cpu() ) in_qty = quantity_factory.ones([X_DIM, Y_DIM, Z_DIM], "") diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index 0e00d459..d1288106 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -2,6 +2,7 @@ import pytest from ndsl import QuantityFactory, StencilFactory +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -44,8 +45,8 @@ def test_boilerplate_import_numpy(): ) # Ensure backend is propagated to StencilFactory and QuantityFactory - assert stencil_factory.backend == "numpy" - assert quantity_factory.backend == "numpy" + assert stencil_factory.backend == Backend.python() + assert quantity_factory.backend == Backend.python() _copy_ops(stencil_factory, quantity_factory) @@ -74,5 +75,5 @@ def test_boilerplate_non_dace_based_orchestration_raises(): with pytest.raises(ValueError, match="Only .* backends can be orchestrated."): get_factories_single_tile_orchestrated( - nx=5, ny=5, nz=2, nhalo=1, backend="numpy" + nx=5, ny=5, nz=2, nhalo=1, backend=Backend.python() ) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index de7723ce..a706ea2d 100644 --- 
a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -5,7 +5,7 @@ @pytest.fixture -def contiguous_array(numpy, backend): +def contiguous_array(numpy, backend: str): if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") array = numpy.empty([3, 4, 5]) @@ -65,7 +65,7 @@ def test_recvbuf_no_buffer(allocator, contiguous_array): assert recvbuf is contiguous_array -def test_buffer_cache_appends(allocator, backend): +def test_buffer_cache_appends(allocator, backend: str): """ Test buffer with the same key are appended while not in use for potential reuse """ @@ -91,7 +91,7 @@ def test_buffer_cache_appends(allocator, backend): assert len(BUFFER_CACHE[first_buffer._key]) == 2 -def test_buffer_reuse(allocator, backend): +def test_buffer_reuse(allocator, backend: str): """Test we reuse the buffer when available instead of reallocating one""" if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") @@ -119,7 +119,7 @@ def test_buffer_reuse(allocator, backend): Buffer.push_to_cache(repop_buffer) -def test_cacheline_differentiation(allocator, backend): +def test_cacheline_differentiation(allocator, backend: str): """Test allocation with different keys creates different cache lines""" if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") @@ -169,7 +169,9 @@ def test_cacheline_differentiation(allocator, backend): pytest.param(((10, 10, 10), float), ((10, 10, 5), float), id="different_shape"), ], ) -def test_new_args_gives_different_buffer(allocator, backend, first_args, second_args): +def test_new_args_gives_different_buffer( + allocator, backend: str, first_args, second_args +): if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") BUFFER_CACHE.clear() diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index bdbab4cd..d19a6724 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -12,6 
+12,7 @@ TilePartitioner, ) from ndsl.comm import CachingCommReader, CachingCommWriter +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM @@ -29,7 +30,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, - backend="debug", + backend=Backend.debug(), ) for _ in range(n_ranks) ] diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 5fe4f5b9..e7c4ae6e 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -10,6 +10,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import ( HORIZONTAL_DIMS, TILE_DIM, @@ -169,7 +170,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), - backend="numpy", + backend=Backend.python(), ) diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index 52c4274e..cef35b24 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -3,6 +3,7 @@ import pytest from ndsl import GridSizer, QuantityFactory, SubtileGridSizer +from ndsl.config import Backend from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -69,7 +70,7 @@ def namelist(nx_tile, ny_tile, nz, layout): def sizer( request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths ) -> GridSizer: - backend = "numpy" # original utest case + backend = Backend.python() # original utest case if request.param == "from_tile_params": return SubtileGridSizer.from_tile_params( nx_tile=nx_tile, @@ -203,7 +204,7 @@ def test_subtile_dimension_sizer_shape(sizer, dim_case): def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.zeros(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -214,7 +215,7 @@ def test_allocator_zeros(numpy, sizer, 
dim_case, units, dtype): def test_allocator_ones(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.ones(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -225,7 +226,7 @@ def test_allocator_ones(numpy, sizer, dim_case, units, dtype): def test_allocator_empty(sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.empty(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -235,7 +236,7 @@ def test_allocator_empty(sizer, dim_case, units, dtype): def test_allocator_data_dimensions_operations(sizer): - quantity_factory = QuantityFactory(sizer, backend="numpy") + quantity_factory = QuantityFactory(sizer, backend=Backend.python()) quantity_factory.add_data_dimensions({"D0": 11}) assert "D0" in quantity_factory.sizer.data_dimensions.keys() assert quantity_factory.sizer.data_dimensions["D0"] == 11 @@ -261,7 +262,7 @@ def test_pad_non_interface_dimensions(): n_halo=0, layout=(layout_xy, layout_xy), data_dimensions={"some_dim": dd}, - backend="numpy", # original utest case + backend=Backend.python(), # original utest case ) padded_shape = padded_grid_sizer.get_shape([X_DIM, Y_DIM, Z_DIM, "some_dim"]) assert padded_shape[0] == nx // layout_xy + 1 @@ -276,7 +277,7 @@ def test_pad_non_interface_dimensions(): n_halo=0, layout=(layout_xy, layout_xy), data_dimensions={"some_dim": dd}, - backend="dace:cpu_KJI", # Fortran-friendly backend + backend=Backend("st:dace:cpu:KJI"), # Fortran-friendly backend ) non_padded_shape = non_padded_grid_sizer.get_shape( [X_DIM, Y_DIM, Z_DIM, "some_dim"] diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index b7af3e12..1d263c4a 100644 --- a/tests/test_g2g_communication.py 
+++ b/tests/test_g2g_communication.py @@ -16,6 +16,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.optional_imports import cupy as cp from ndsl.performance import Timer @@ -35,7 +36,7 @@ def ranks_per_tile(layout): @pytest.fixture -def total_ranks(ranks_per_tile): +def total_ranks(ranks_per_tile) -> int: return 6 * ranks_per_tile @@ -50,7 +51,7 @@ def cube_partitioner(tile_partitioner): @pytest.fixture -def cpu_communicators(cube_partitioner): +def cpu_communicators(cube_partitioner, total_ranks): shared_buffer = {} return_list = [] for rank in range(cube_partitioner.total_ranks): @@ -68,7 +69,7 @@ def cpu_communicators(cube_partitioner): @pytest.fixture -def gpu_communicators(cube_partitioner): +def gpu_communicators(cube_partitioner, total_ranks): shared_buffer = {} return_list = [] for rank in range(cube_partitioner.total_ranks): @@ -127,6 +128,7 @@ def test_halo_update_only_communicate_on_gpu(backend, gpu_communicators): units="m", origin=(3, 3, 1), extent=(3, 3, 1), + backend=Backend("st:gt:gpu:KJI"), ) halo_updater_list = [] for communicator in gpu_communicators: @@ -156,6 +158,7 @@ def test_halo_update_communicate_though_cpu(backend, cpu_communicators): units="m", origin=(3, 3, 0), extent=(3, 3, 0), + backend=Backend("st:gt:gpu:KJI"), ) halo_updater_list = [] for communicator in cpu_communicators: diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 91e10d78..bcf0783e 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -7,6 +7,7 @@ from ndsl import HaloExchangeSpec, Quantity from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils +from ndsl.config import Backend from ndsl.constants import ( EAST, NORTH, @@ -161,7 +162,7 @@ def _shape_length(shape: Tuple[int]) -> int: @pytest.fixture -def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): +def quantity(dims, units, origin, 
extent, shape, dtype, ndsl_backend: Backend): """A list of quantities whose values are 42.42 in the computational domain and 1 outside of it.""" sz = _shape_length(shape) @@ -173,7 +174,7 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): units=units, origin=origin, extent=extent, - backend=gt4py_backend, + backend=ndsl_backend, ) diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index e4ca02df..8aa24d95 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -14,6 +14,7 @@ ) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.config import Backend from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -319,7 +320,12 @@ def depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) return_list.append(quantity) return return_list @@ -352,7 +358,12 @@ def tile_depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) return_list.append(quantity) return return_list @@ -492,7 +503,12 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -509,7 +525,12 @@ def zeros_quantity_tile_list( for _rank in range(single_tile_ranks): data = 
numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.debug(), ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 4a101f37..999b839b 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -126,7 +127,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.debug(), ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index c4cbe575..edddb492 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -15,6 +15,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend logger = logging.getLogger(__name__) @@ -39,7 +40,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" - backend = "debug" + backend = Backend.debug() nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 5ee9aa7e..00662bfe 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -7,6 +7,7 @@ get_tile_index, tile_extent_from_rank_metadata, ) +from ndsl.config import Backend from ndsl.constants import ( TILE_DIM, X_DIM, @@ -984,7 +985,11 @@ def test_subtile_extent_with_tile_dimensions( ): data_array = np.zeros((tile_extent)) quantity = Quantity( - data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0], backend="debug" + data_array, + array_dims, + "dimensionless", + origin=[0, 0, 0, 0], 
+ backend=Backend.debug(), ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 27db0048..935a3781 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM from ndsl.performance import Timer @@ -81,7 +82,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend="debug", + backend=Backend.debug(), ) y_data = numpy.empty((2, 3), dtype=dtype) y_data[:] = rank @@ -91,7 +92,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend="debug", + backend=Backend.debug(), ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -149,7 +150,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend="debug", + backend=Backend.debug(), ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -158,7 +159,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend="debug", + backend=Backend.debug(), ) quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index 9ab00eba..86ed4c17 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,6 +1,7 @@ import pytest from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner +from ndsl.config import Backend from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM @@ -36,13 +37,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), 
dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", - backend="debug", + backend=Backend.debug(), ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", - backend="debug", + backend=Backend.debug(), ), } @@ -82,19 +83,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", - backend="debug", + backend=Backend.debug(), ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", - backend="debug", + backend=Backend.debug(), ), "rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", - backend="debug", + backend=Backend.debug(), ), } @@ -142,7 +143,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.debug(), ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -150,7 +151,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.debug(), ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -158,7 +159,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.debug(), ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 4e87eb65..8ad0313c 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -4,6 +4,7 @@ import pytest from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner +from ndsl.config import Backend from 
ndsl.constants import ( HORIZONTAL_DIMS, X_DIM, @@ -150,7 +151,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), - backend="debug", + backend=Backend.debug(), ) diff --git a/tests/test_xumpy.py b/tests/test_xumpy.py index eb61cdde..80585894 100644 --- a/tests/test_xumpy.py +++ b/tests/test_xumpy.py @@ -1,35 +1,38 @@ import numpy as np import ndsl.xumpy as xp +from ndsl.config import Backend shape = (2, 2, 5) def test_xumpy_alloc(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.debug()) assert rand_array.shape == shape - (rand_array != xp.random(shape, backend="debug")).all() + (rand_array != xp.random(shape, Backend.debug())).all() - assert (np.ones(shape) == xp.ones(shape, backend="debug")).all() - assert (np.zeros(shape) == xp.zeros(shape, backend="debug")).all() - assert (np.full(shape, 42.42) == xp.full(shape, value=42.42, backend="debug")).all() + assert (np.ones(shape) == xp.ones(shape, Backend.debug())).all() + assert (np.zeros(shape) == xp.zeros(shape, Backend.debug())).all() + assert ( + np.full(shape, 42.42) == xp.full(shape, value=42.42, backend=Backend.debug()) + ).all() def test_xumpy_minmax(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.debug()) assert (np.max(rand_array, axis=1) == xp.max(rand_array, axis=1)).all() assert (np.min(rand_array, axis=1) == xp.min(rand_array, axis=1)).all() - out_buffer = xp.empty(shape, backend="debug") + out_buffer = xp.empty(shape, Backend.debug()) xp.max_on_horizontal_plane(rand_array, out_buffer) assert (np.max(rand_array, axis=(0, 1)) == out_buffer).all() def test_xumpy_counts(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.debug()) rand_array[1, 1, :] = 0 assert np.count_nonzero(rand_array) == xp.count_nonzero(rand_array) diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index e584da22..3692bdc5 100644 --- 
a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -8,6 +8,7 @@ import xarray as xr from ndsl import CubedSpherePartitioner, LocalComm, MPIComm, Quantity, TilePartitioner +from ndsl.config import Backend from ndsl.constants import ( X_DIM, X_DIMS, @@ -98,7 +99,10 @@ def base_state(request, nz, ny, nx, numpy) -> dict: if request.param == "one_var_2d": return { "var1": Quantity( - numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" + numpy.ones([ny, nx]), + dims=("y", "x"), + units="m", + backend=Backend.debug(), ) } @@ -108,20 +112,23 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="m", - backend="debug", + backend=Backend.debug(), ) } if request.param == "two_vars": return { "var1": Quantity( - numpy.ones([ny, nx]), dims=("y", "x"), units="m", backend="debug" + numpy.ones([ny, nx]), + dims=("y", "x"), + units="m", + backend=Backend.debug(), ), "var2": Quantity( numpy.ones([nz, ny, nx]), dims=("z", "y", "x"), units="degK", - backend="debug", + backend=Backend.debug(), ), } @@ -254,7 +261,7 @@ def test_monitor_file_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, - backend="debug", + backend=Backend.debug(), ), } monitor_list[rank].store(state) @@ -344,7 +351,7 @@ def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) zero_quantity = Quantity( - numpy.zeros([10, 10]), dims=("y", "x"), units="m", backend="debug" + numpy.zeros([10, 10]), dims=("y", "x"), units="m", backend=Backend.debug() ) monitor.store({"var": zero_quantity}) @@ -366,7 +373,10 @@ def test_values_preserved(cube_partitioner, numpy): # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) quantity = Quantity( - numpy.random.uniform(size=(10, 10)), dims=dims, units=units, backend="debug" + 
numpy.random.uniform(size=(10, 10)), + dims=dims, + units=units, + backend=Backend.debug(), ) monitor.store({"var": quantity}) @@ -420,7 +430,7 @@ def diag(request, numpy): numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m", - backend="debug", + backend=Backend.debug(), ) @@ -494,7 +504,7 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): numpy.ones([size + 2 for size in range(len(diag.dims))]), dims=new_dims, units="m", - backend="debug", + backend=Backend.debug(), ) with pytest.raises(ValueError) as excinfo: zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) @@ -510,6 +520,8 @@ def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2 = copy.deepcopy(diag) diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) - diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m", backend="debug") + diag_3 = Quantity( + data=diag.view[:], dims=diag.dims, units="not_m", backend=Backend.debug() + ) with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3}) From d4be82b621c49522c1f66888582e9fb9daffe4ef Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 11:48:26 -0500 Subject: [PATCH 11/47] Save Backend as string in Quantity.attrs --- ndsl/quantity/quantity.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 545b4a34..4d888226 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -255,7 +255,9 @@ def backend(self) -> Backend: @property def attrs(self) -> dict: - return dict(**self._attrs, units=self.units, backend=self.backend) + return dict( + **self._attrs, units=self.units, backend=self.backend.as_humanly_readable() + ) @property def dims(self) -> tuple[str, ...]: @@ -495,7 +497,11 @@ def _resolve_backend(data: xr.DataArray, backend: Backend | None) -> 
Backend: # If backend name was serialized with data, take this one if "backend" in data.attrs: - return data.attrs["backend"] + if not isinstance(data.attrs["backend"], str): + raise ValueError( + f"Quantity.attrs 'backend' must be a string, found {data.attrs['backend']}" + ) + return Backend(data.attrs["backend"]) # else, fall back to assume python-based layout. return Backend.debug() From 2fb0de5b8619f63bf31efc6451d2265ddae98b82 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 12:04:10 -0500 Subject: [PATCH 12/47] Fix bad backend given to stree merge test --- ndsl/boilerplate.py | 2 +- tests/stree_optimizer/test_merge.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index c04b0dfd..6e953471 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -100,7 +100,7 @@ def get_factories_single_tile_orchestrated( """Build the pair of (StencilFactory, QuantityFactory) for orchestrated code on a single tile topology.""" if backend is not None and not backend.is_orchestrated(): - raise ValueError("Only `dace:*` backends can be orchestrated.") + raise ValueError(f"Only `orch:*` backends can be orchestrated, got {backend}.") return _get_factories( nx=nx, diff --git a/tests/stree_optimizer/test_merge.py b/tests/stree_optimizer/test_merge.py index 4c0f2b2f..7e0c0fb1 100644 --- a/tests/stree_optimizer/test_merge.py +++ b/tests/stree_optimizer/test_merge.py @@ -207,7 +207,7 @@ def test_stree_merge_maps_IJK() -> None: def test_stree_merge_maps_KJI() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, Backend("st:dace:cpu:KJI") + domain[0], domain[1], domain[2], 0, Backend("orch:dace:cpu:KJI") ) code = OrchestratedCode(stencil_factory, quantity_factory) From ee7e872dd74a673b06c5f877febef4446ae0017c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 12:33:11 -0500 Subject: [PATCH 
13/47] Introduce short name shortcuts --- ndsl/config/backend.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index b2d61ff5..1c9e82d1 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -44,11 +44,22 @@ class BackendFramework(Enum): } """Internal: match the NDSL backend names with the GT4Py names""" +_NDSL_SHORT_TO_LONG_NAME = { + "numpy": "st:python:cpu:numpy", + "debug": "st:python:cpu:debug", +} +"""Internal: match short name to long name form""" + class Backend: """Backend for NDSL""" def __init__(self, ndsl_backend: str) -> None: + # Swap short form to long form + if ndsl_backend in _NDSL_SHORT_TO_LONG_NAME.keys(): + ndsl_backend = _NDSL_SHORT_TO_LONG_NAME[ndsl_backend] + + # Checks for existence and form if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): raise ValueError( f"Unknown {ndsl_backend}, options are {list(_NDSL_TO_GT4PY_BACKEND_NAMING.keys())}" @@ -56,8 +67,9 @@ def __init__(self, ndsl_backend: str) -> None: parts = ndsl_backend.split(":") if len(parts) != 4: raise ValueError(f"Backend {ndsl_backend} is ill-formed.") + + # Breakdown and save into internal parameters self._humanly_readable = ndsl_backend - # Split into internal parameters self._strategy = BackendStrategy(parts[0].lower()) self._framework = BackendFramework(parts[1].lower()) self._device = BackendTargetDevice(parts[2].lower()) From 305e78528b8b3218be085f7eb8038d3873aa6ec9 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 14:22:38 -0500 Subject: [PATCH 14/47] Lint --- tests/test_boilerplate.py | 4 +++- tests/test_zarr_monitor.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index 50da4d25..2784fa9d 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -15,7 +15,9 @@ def _copy_ops(stencil_factory: StencilFactory, quantity_factory: QuantityFactory 
qty_in.view[:] = np.indices( dimensions=quantity_factory.sizer.get_extent([I_DIM, J_DIM, K_DIM]), dtype=np.float64, - ).sum(axis=0) # Value of each entry is sum of the I and J index at each point + ).sum( + axis=0 + ) # Value of each entry is sum of the I and J index at each point # Define a stencil def copy_stencil(input_field: FloatField, output_field: FloatField): diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index 2f68de48..84a19497 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -339,9 +339,9 @@ def _assert_no_nulls(dataset: xr.Dataset): number_of_null = dataset["var"].isnull().sum().item() total_size = dataset["var"].size - assert number_of_null == 0, ( - f"Number of nulls {number_of_null}. Size of data {total_size}" - ) + assert ( + number_of_null == 0 + ), f"Number of nulls {number_of_null}. Size of data {total_size}" @pytest.mark.parametrize("mask_and_scale", [True, False]) From e7e4481323b182c2bf6fc8adfb3446ffcef3e2a2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 14:55:30 -0500 Subject: [PATCH 15/47] Fix translate: Backend is already properly built --- ndsl/stencils/testing/test_translate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 240c49b5..768d1f07 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -43,7 +43,7 @@ def process_override(threshold_overrides, testobj, test_name, backend: Backend): matches = [ spec for spec in override - if Backend(spec["backend"]) == backend and spec["platform"] == platform() + if spec["backend"] == backend and spec["platform"] == platform() ] if len(matches) == 1: match = matches[0] From e1cdc6e57b9a95c2dfaa75873751ac57cfa99f36 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 15:17:30 -0500 Subject: [PATCH 16/47] Add concatenation operations ---
ndsl/config/backend.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 1c9e82d1..fbb5c2a2 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +from typing import Any import gt4py.cartesian.backend as gt_backend @@ -97,6 +98,18 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return hash(self._humanly_readable) + def __add__(self, other: str) -> Any: + """Concatenation operators""" + if isinstance(other, Backend): + raise TypeError("OperationError: Backend cannot add to another Backend") + return str(self) + other + + def __radd__(self, other: str) -> Any: + """Concatenation operators""" + if isinstance(other, Backend): + raise TypeError("OperationError: Backend cannot add to another Backend") + return other + str(self) + @staticmethod def debug() -> Backend: return Backend("st:python:cpu:debug") From e68bcfc65349c5c1571d708fd08df2eb6a0eda66 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 15:23:41 -0500 Subject: [PATCH 17/47] Deprecate inlined `is_gpu_backend` check --- ndsl/dsl/gt4py_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 951d4f4b..6d50250a 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Callable, Sequence from functools import wraps from typing import Any @@ -447,6 +448,15 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None): return cp.asarray(array, dtype, order) +def is_gpu_backend(backend: Backend) -> bool: + warnings.warn( + "Function `gt4py_utils.is_gpu_backend` is deprecated, please use `Backend.is_gpu_backend()`", + category=DeprecationWarning, + stacklevel=2, + ) + return backend.is_gpu_backend() + + _FORTRAN_LOOP_LAYOUT = (2, 1, 0) """Fortran is a column-first (or stride-first) 
memory system, which in the internal gt4py loop layout means I (or axis[0]) has From 03b4c5056dfee1d52656478cf90bbe0252399e27 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 27 Jan 2026 16:17:03 -0500 Subject: [PATCH 18/47] Documentation --- ndsl/config/backend.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index fbb5c2a2..d16f2fc5 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -14,14 +14,14 @@ class BackendStrategy(Enum): class BackendTargetDevice(Enum): - """Targeted device""" + """Target device""" CPU = "cpu" GPU = "gpu" class BackendFramework(Enum): - """Main framework (or language) backend relies on""" + """Main lower-level framework (or language) backend relies on""" GRIDTOOLS = "gt" DACE = "dace" @@ -48,12 +48,26 @@ class BackendFramework(Enum): _NDSL_SHORT_TO_LONG_NAME = { "numpy": "st:python:cpu:numpy", "debug": "st:python:cpu:debug", + "performance_cpu": "orch:dace:cpu:IJK", + "hybrid_fortran_cpu": "orch:dace:cpu:KJI", + "performance_gpu": "orch:dace:gpu:KJI", } """Internal: match short name to long name form""" class Backend: - """Backend for NDSL""" + """Backend for NDSL. + + The backend is a string concatenating information on the intent of the user + for a given execution seperated by a ':'. + + It describes to NDSL the strategy, device and framwork to be used + on the frontend code. Additionaly, it gives a hint toward the macro-strategy + for loop ordering (IJK, KJI, etc.) or a more broad intent (debug, numpy). + + For convenience, shorcuts are given to the most common needs (see internal + `_NDSL_SHORT_TO_LONG_NAME`). 
+ """ def __init__(self, ndsl_backend: str) -> None: # Swap short form to long form @@ -112,23 +126,23 @@ def __radd__(self, other: str) -> Any: @staticmethod def debug() -> Backend: - return Backend("st:python:cpu:debug") + return Backend("debug") @staticmethod def python() -> Backend: - return Backend("st:python:cpu:numpy") + return Backend("numpy") @staticmethod def performance_cpu() -> Backend: - return Backend("orch:dace:cpu:IJK") + return Backend("performance_cpu") @staticmethod def hybrid_fortran_cpu() -> Backend: - return Backend("orch:dace:cpu:KJI") + return Backend("hybrid_fortran_cpu") @staticmethod def performance_gpu() -> Backend: - return Backend("orch:dace:gpu:KJI") + return Backend("performance_gpu") @property def device(self) -> BackendTargetDevice: @@ -139,6 +153,7 @@ def framework(self) -> BackendFramework: return self._framework def as_gt4py(self) -> str: + """Given a NDSL backend, give back a GT4Py equivalent""" if self._humanly_readable in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] raise ValueError( @@ -161,7 +176,7 @@ def is_gpu_backend(self) -> bool: return self._device == BackendTargetDevice.GPU -# Those two internal values are used for default parameters values +# Those two internal values are used for default parameters values in functions/methods # as it is bad practice to call a function in default argument value _BACKEND_PERFORMANCE_CPU = Backend.performance_cpu() """Internal: cache performance CPU. 
Please use Backend.performance_cpu().""" From bc5590ea8186b0475d3722b6d3f35af5d2ad67fe Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 12:51:00 -0500 Subject: [PATCH 19/47] Rework shortcuts for Final[Backend] variable --- ndsl/boilerplate.py | 6 +-- ndsl/config/__init__.py | 14 +++++- ndsl/config/backend.py | 50 +++++++------------ ndsl/dsl/stencil_config.py | 5 +- ndsl/quantity/quantity.py | 2 +- tests/dsl/orchestration/test_call.py | 2 +- tests/dsl/test_caches.py | 8 +-- tests/stree_optimizer/test_merge.py | 2 +- tests/stree_optimizer/test_pipeline.py | 2 +- .../stree_optimizer/test_transient_refine.py | 2 +- tests/test_ndsl_runtime.py | 4 +- 11 files changed, 46 insertions(+), 51 deletions(-) diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index 6e953471..51a672a8 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -15,7 +15,7 @@ TileCommunicator, TilePartitioner, ) -from ndsl.config.backend import _BACKEND_PERFORMANCE_CPU, _BACKEND_PYTHON +from ndsl.config.backend import backend_cpu, backend_python def _get_factories( @@ -93,7 +93,7 @@ def get_factories_single_tile_orchestrated( ny: int, nz: int, nhalo: int, - backend: Backend = _BACKEND_PERFORMANCE_CPU, + backend: Backend = backend_cpu, *, orchestration_mode: DaCeOrchestration | None = None, ) -> tuple[StencilFactory, QuantityFactory]: @@ -114,7 +114,7 @@ def get_factories_single_tile_orchestrated( def get_factories_single_tile( - nx: int, ny: int, nz: int, nhalo: int, backend: Backend = _BACKEND_PYTHON + nx: int, ny: int, nz: int, nhalo: int, backend: Backend = backend_python ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for stencils on a single tile topology.""" return _get_factories( diff --git a/ndsl/config/__init__.py b/ndsl/config/__init__.py index 9236dc7d..f0328407 100644 --- a/ndsl/config/__init__.py +++ b/ndsl/config/__init__.py @@ -1,8 +1,20 @@ -from .backend import Backend, BackendFramework, 
BackendTargetDevice +from .backend import ( + Backend, + BackendFramework, + BackendStrategy, + BackendTargetDevice, + backend_cpu, + backend_gpu, + backend_python, +) __all__ = [ "Backend", "BackendFramework", + "BackendStrategy", "BackendTargetDevice", + "backend_python", + "backend_cpu", + "backend_gpu", ] diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index d16f2fc5..34ff0cb5 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Final import gt4py.cartesian.backend as gt_backend @@ -45,15 +45,6 @@ class BackendFramework(Enum): } """Internal: match the NDSL backend names with the GT4Py names""" -_NDSL_SHORT_TO_LONG_NAME = { - "numpy": "st:python:cpu:numpy", - "debug": "st:python:cpu:debug", - "performance_cpu": "orch:dace:cpu:IJK", - "hybrid_fortran_cpu": "orch:dace:cpu:KJI", - "performance_gpu": "orch:dace:gpu:KJI", -} -"""Internal: match short name to long name form""" - class Backend: """Backend for NDSL. @@ -65,15 +56,11 @@ class Backend: on the frontend code. Additionaly, it gives a hint toward the macro-strategy for loop ordering (IJK, KJI, etc.) or a more broad intent (debug, numpy). - For convenience, shorcuts are given to the most common needs (see internal - `_NDSL_SHORT_TO_LONG_NAME`). + For convenience, shorcuts are given to the most common needs ( + `backend_python`, `backend_cpu`, `backend_gpu`). 
""" def __init__(self, ndsl_backend: str) -> None: - # Swap short form to long form - if ndsl_backend in _NDSL_SHORT_TO_LONG_NAME.keys(): - ndsl_backend = _NDSL_SHORT_TO_LONG_NAME[ndsl_backend] - # Checks for existence and form if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): raise ValueError( @@ -124,24 +111,19 @@ def __radd__(self, other: str) -> Any: raise TypeError("OperationError: Backend cannot add to another Backend") return other + str(self) - @staticmethod - def debug() -> Backend: - return Backend("debug") - @staticmethod def python() -> Backend: - return Backend("numpy") + """Default backend for quick iterative work.""" + return Backend("debug") @staticmethod - def performance_cpu() -> Backend: + def cpu() -> Backend: + """Default performance backend targeting CPU device""" return Backend("performance_cpu") @staticmethod - def hybrid_fortran_cpu() -> Backend: - return Backend("hybrid_fortran_cpu") - - @staticmethod - def performance_gpu() -> Backend: + def gpu() -> Backend: + """Default performance backend targeting GPU device""" return Backend("performance_gpu") @property @@ -176,9 +158,11 @@ def is_gpu_backend(self) -> bool: return self._device == BackendTargetDevice.GPU -# Those two internal values are used for default parameters values in functions/methods -# as it is bad practice to call a function in default argument value -_BACKEND_PERFORMANCE_CPU = Backend.performance_cpu() -"""Internal: cache performance CPU. Please use Backend.performance_cpu().""" -_BACKEND_PYTHON = Backend.python() -"""Internal: cache performance CPU. 
Please use Backend.python().""" +backend_python: Final[Backend] = Backend("st:python:cpu:debug") +"""Default backend for quick iterative work.""" + +backend_cpu: Final[Backend] = Backend("orch:dace:cpu:IJK") +"""Default performance backend targeting CPU device""" + +backend_gpu: Final[Backend] = Backend("orch:dace:gpu:KJI") +"""Default performance backend targeting GPU device""" diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index b615625e..a4cf4b6a 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -12,8 +12,7 @@ from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches from ndsl.comm.partitioner import Partitioner -from ndsl.config import Backend, BackendTargetDevice -from ndsl.config.backend import _BACKEND_PYTHON +from ndsl.config import Backend, BackendTargetDevice, backend_python from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration @@ -33,7 +32,7 @@ class RunMode(enum.Enum): class CompilationConfig: def __init__( self, - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, format_source: bool = False, diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index c6a0f482..cc331a70 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -504,4 +504,4 @@ def _resolve_backend(data: xr.DataArray, backend: Backend | None) -> Backend: return Backend(data.attrs["backend"]) # else, fall back to assume python-based layout. 
- return Backend.debug() + return Backend.python() diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index 99969880..165c7342 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -85,7 +85,7 @@ def test_default_types_are_compiletime(): def test_dace_call_argument_caching(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 2, 0, backend=Backend.performance_cpu() + 5, 5, 2, 0, backend=Backend.cpu() ) dconfig = stencil_factory.config.dace_config diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 239febce..50176d77 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -73,7 +73,7 @@ def __call__(self): ) def test_relocatability_orchestration() -> None: # Compile on default - p0 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.BuildAndRun) + p0 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.BuildAndRun) p0() expected_cache_dir = ( @@ -92,7 +92,7 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: gt_config.cache_settings["root_path"] = tmpdir # Compile in temporary directory that is only available in this test session. 
- p1 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.BuildAndRun) + p1 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.BuildAndRun) p1() expected_cache_dir = ( @@ -108,14 +108,14 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: relocated_path = tmpdir / ".my_relocated_cache_path" shutil.copytree(tmpdir, relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path - p2 = OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.Run) + p2 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.Run) p2() # Generate a file exists error to check for bad path bogus_path = "./nope/not_at_all/not_happening" gt_config.cache_settings["root_path"] = bogus_path with pytest.raises(RuntimeError): - OrchestratedProgram(Backend.performance_cpu(), DaCeOrchestration.Run) + OrchestratedProgram(Backend.cpu(), DaCeOrchestration.Run) @pytest.mark.skipif( diff --git a/tests/stree_optimizer/test_merge.py b/tests/stree_optimizer/test_merge.py index 26df44e2..791ab7cf 100644 --- a/tests/stree_optimizer/test_merge.py +++ b/tests/stree_optimizer/test_merge.py @@ -128,7 +128,7 @@ def push_non_cartesian_for( def test_stree_merge_maps_IJK() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend=Backend.performance_cpu() + domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) code = OrchestratedCode(stencil_factory, quantity_factory) diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/stree_optimizer/test_pipeline.py index 39f57769..9cab38f0 100644 --- a/tests/stree_optimizer/test_pipeline.py +++ b/tests/stree_optimizer/test_pipeline.py @@ -29,7 +29,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField): def test_stree_roundtrip_no_opt(): domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, 
backend=Backend.performance_cpu() + domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) code = TriviallyMergeableCode(stencil_factory) diff --git a/tests/stree_optimizer/test_transient_refine.py b/tests/stree_optimizer/test_transient_refine.py index d9b53c68..1e550d65 100644 --- a/tests/stree_optimizer/test_transient_refine.py +++ b/tests/stree_optimizer/test_transient_refine.py @@ -94,7 +94,7 @@ def do_not_refine_datadims(self, in_field: Quantity, out_field: Quantity) -> Non def test_stree_roundtrip_transient_is_refined() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend=Backend.performance_cpu() + domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index 58bb86ef..643381dd 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -72,7 +72,7 @@ def test_runtime_make_local() -> None: def test_runtime_has_orchestracted_call() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 3, 0, backend=Backend.performance_cpu() + 5, 5, 3, 0, backend=Backend.cpu() ) A_ = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a") @@ -88,7 +88,7 @@ def test_runtime_has_orchestracted_call() -> None: def test_runtime_does_not_orchestrate_when_call_is_not_present() -> None: stencil_factory, _ = get_factories_single_tile_orchestrated( - 5, 5, 3, 0, backend=Backend.performance_cpu() + 5, 5, 3, 0, backend=Backend.cpu() ) code = Code_NoCall(stencil_factory) From d0381a8d562a5bb2c10672c8b88d7e40ef34e0bb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 13:09:42 -0500 Subject: [PATCH 20/47] Move `is_fortran_asligned` to Backend + Lint --- ndsl/config/backend.py | 36 
+++++++++++++++++++++++++----------- ndsl/dsl/dace/dace_config.py | 3 +-- ndsl/dsl/gt4py_utils.py | 32 ++++++++------------------------ 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 34ff0cb5..b8d3a0fc 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -45,6 +45,17 @@ class BackendFramework(Enum): } """Internal: match the NDSL backend names with the GT4Py names""" +_FORTRAN_LOOP_LAYOUT = (2, 1, 0) +"""Fortran is a column-first (or stride-first) memory system, +which in the internal gt4py loop layout means I (or axis[0]) has +the higher value, e.g. "higher importance to run first": + +for k # Layout=0 + for j # Layout=1 + for i # Layout=2 + +""" + class Backend: """Backend for NDSL. @@ -52,8 +63,8 @@ class Backend: The backend is a string concatenating information on the intent of the user for a given execution seperated by a ':'. - It describes to NDSL the strategy, device and framwork to be used - on the frontend code. Additionaly, it gives a hint toward the macro-strategy + It describes to NDSL the strategy, device and framework to be used + on the frontend code. Additionally, it gives a hint toward the macro-strategy for loop ordering (IJK, KJI, etc.) or a more broad intent (debug, numpy). 
For convenience, shorcuts are given to the most common needs ( @@ -62,7 +73,7 @@ class Backend: def __init__(self, ndsl_backend: str) -> None: # Checks for existence and form - if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): + if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING: raise ValueError( f"Unknown {ndsl_backend}, options are {list(_NDSL_TO_GT4PY_BACKEND_NAMING.keys())}" ) @@ -83,8 +94,8 @@ def __init__(self, ndsl_backend: str) -> None: and gt_backend.from_name(self.as_gt4py()).storage_info["device"] != "gpu" ): raise ValueError( - f"Coding error: NDSL backend requested {self._humanly_readable} " - f"translate to non-GPU {self.as_gt4py()} GT4Py backend" + f"NDSL backend requested ({self._humanly_readable}) tagets GPU," + f"but requests a non-GPU backend from GT4Py ({self.as_gt4py()})." ) def __str__(self) -> str: @@ -135,12 +146,8 @@ def framework(self) -> BackendFramework: return self._framework def as_gt4py(self) -> str: - """Given a NDSL backend, give back a GT4Py equivalent""" - if self._humanly_readable in _NDSL_TO_GT4PY_BACKEND_NAMING.keys(): - return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] - raise ValueError( - f"Backend {self._humanly_readable} cannot be translate to GT4Py" - ) + """Given an NDSL backend, give back a GT4Py equivalent""" + return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] def as_humanly_readable(self) -> str: return self._humanly_readable @@ -157,6 +164,13 @@ def is_stencil(self) -> bool: def is_gpu_backend(self) -> bool: return self._device == BackendTargetDevice.GPU + def is_fortran_aligned(self) -> bool: + """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran + striding.""" + return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name( + self.as_gt4py() + ).storage_info["layout_map"](("I", "J", "K")) + backend_python: Final[Backend] = Backend("st:python:cpu:debug") """Default backend for quick iterative work.""" diff --git a/ndsl/dsl/dace/dace_config.py 
b/ndsl/dsl/dace/dace_config.py index 2a793c54..1828f898 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -130,9 +130,8 @@ class DaCeOrchestration(enum.Enum): """ Orchestration mode for DaCe - Python: python orchestration - Build: compile & save SDFG only BuildAndRun: compile & save SDFG, then run + Build: compile & save SDFG only Run: load from .so and run, will fail if .so is not available """ diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 6d50250a..7372af0b 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -457,30 +457,14 @@ def is_gpu_backend(backend: Backend) -> bool: return backend.is_gpu_backend() -_FORTRAN_LOOP_LAYOUT = (2, 1, 0) -"""Fortran is a column-first (or stride-first) memory system, -which in the internal gt4py loop layout means I (or axis[0]) has -the higher value, e.g. "higher importance to run first": - -for k # Layout=0 - for j # Layout=1 - for i # Layout=2 - -""" - - -def backend_is_fortran_aligned(backend: Backend | None) -> bool: - """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran - striding.""" - - # Dev NOTE: this is used in interfacing with NDSL (e.g. GEOS.) 
- - if not backend: - return False - - return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name( - backend.as_gt4py() - ).storage_info["layout_map"](("I", "J", "K")) +def backend_is_fortran_aligned(backend: Backend) -> bool: + warnings.warn( + "Function `gt4py_utils.backend_is_fortran_aligned` is deprecated " + "please use `Backend.backend_is_fortran_aligned()`", + category=DeprecationWarning, + stacklevel=2, + ) + return backend.is_fortran_aligned() def zeros(shape, dtype=Float, *, backend: Backend): From 998b3f29d208fc80d86527d1edb2f19a4d4266e2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 13:25:51 -0500 Subject: [PATCH 21/47] Remove wrong default to `| None` --- examples/mpi/zarr_monitor.py | 5 +++-- ndsl/dsl/gt4py_utils.py | 1 - ndsl/initialization/subtile_grid_sizer.py | 6 +++--- ndsl/quantity/quantity.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/mpi/zarr_monitor.py b/examples/mpi/zarr_monitor.py index 7a630438..15b8196c 100644 --- a/examples/mpi/zarr_monitor.py +++ b/examples/mpi/zarr_monitor.py @@ -12,6 +12,7 @@ TilePartitioner, ZarrMonitor, ) +from ndsl.config import backend_python from ndsl.constants import I_DIM, J_DIM, K_DIM @@ -20,9 +21,9 @@ def get_example_state(time): sizer = SubtileGridSizer( - nx=48, ny=48, nz=70, n_halo=3, data_dimensions={}, backend="debug" + nx=48, ny=48, nz=70, n_halo=3, data_dimensions={}, backend=backend_python ) - allocator = QuantityFactory(sizer, np) + allocator = QuantityFactory(sizer, backend=backend_python) air_temperature = allocator.zeros([I_DIM, J_DIM, K_DIM], units="degK") air_temperature.view[:] = np.random.randn(*air_temperature.extent) return {"time": time, "air_temperature": air_temperature} diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 7372af0b..fcc51117 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -6,7 +6,6 @@ import numpy as np import numpy.typing as npt from gt4py import storage as gt_storage -from 
gt4py.cartesian import backend as gt_backend from ndsl.config.backend import Backend from ndsl.constants import N_HALO_DEFAULT diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index 1310cbf2..4effe131 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -17,7 +17,7 @@ def __init__( nz: int, n_halo: int, data_dimensions: dict[str, int], - backend: Backend | None = None, + backend: Backend, ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) @@ -33,7 +33,7 @@ def from_tile_params( n_halo: int, layout: tuple[int, int], *, - backend: Backend | None, + backend: Backend, data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, @@ -86,7 +86,7 @@ def from_namelist( tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, *, - backend: Backend | None = None, + backend: Backend, ) -> Self: """Create a SubtileGridSizer from a Fortran namelist. diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index cc331a70..9d44326e 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -165,7 +165,7 @@ def from_data_array( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we allocate data in a performance + backend: NDSL backend name. If given, we allocate data in a performance optimal way for this backend. Overrides any potentially saved `backend` in `data.attrs["backend"]`. 
""" From 5d83835d3db7676fef09bf7a402e51d3888f1088 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 13:30:12 -0500 Subject: [PATCH 22/47] Simplify the G2G comms test --- tests/test_g2g_communication.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index c352ab68..6736708a 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -35,11 +35,6 @@ def ranks_per_tile(layout): return layout[0] * layout[1] -@pytest.fixture -def total_ranks(ranks_per_tile) -> int: - return 6 * ranks_per_tile - - @pytest.fixture def tile_partitioner(layout): return TilePartitioner(layout) @@ -51,14 +46,16 @@ def cube_partitioner(tile_partitioner): @pytest.fixture -def cpu_communicators(cube_partitioner, total_ranks): +def cpu_communicators(cube_partitioner: CubedSpherePartitioner): shared_buffer = {} return_list = [] for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( comm=LocalComm( - rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer + rank=rank, + total_ranks=cube_partitioner.total_ranks, + buffer_dict=shared_buffer, ), force_cpu=True, partitioner=cube_partitioner, @@ -69,14 +66,16 @@ def cpu_communicators(cube_partitioner, total_ranks): @pytest.fixture -def gpu_communicators(cube_partitioner, total_ranks): +def gpu_communicators(cube_partitioner: CubedSpherePartitioner): shared_buffer = {} return_list = [] for rank in range(cube_partitioner.total_ranks): return_list.append( CubedSphereCommunicator( comm=LocalComm( - rank=rank, total_ranks=total_ranks, buffer_dict=shared_buffer + rank=rank, + total_ranks=cube_partitioner.total_ranks, + buffer_dict=shared_buffer, ), partitioner=cube_partitioner, force_cpu=False, From 31a8586808152831f894c2cb83a9db95cba3c323 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 13:43:39 -0500 Subject: [PATCH 23/47] Better checking of 
expected Error --- tests/dsl/test_compilation_config.py | 8 ++++++-- tests/dsl/test_stencil_wrapper.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 5822635e..b83e611e 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -15,9 +15,13 @@ def test_safety_checks(): - with pytest.raises(RuntimeError): + with pytest.raises( + RuntimeError, match="Device sync is true on a CPU based backend" + ): CompilationConfig(Backend.python(), device_sync=True) - with pytest.raises(RuntimeError): + with pytest.raises( + RuntimeError, match="Device sync is true on a CPU based backend" + ): CompilationConfig(Backend("st:gt:cpu:KJI"), device_sync=True) diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 37605c82..ca35d08f 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -345,7 +345,10 @@ def test_backend_options( def test_illegal_backend_options(): with pytest.raises(ValueError): - get_stencil_config(backend=Backend("bad:back:end:now")) + get_stencil_config( + backend=Backend("bad:back:end:now"), + match="Backend bad:back:end:now is not registered. Valid options are:*", + ) def get_mock_quantity(): From 8e5dcfebf4ed305d1f3f49a8fecf00cd4bf0f562 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 14:03:31 -0500 Subject: [PATCH 24/47] Lint --- ndsl/config/backend.py | 2 +- ndsl/dsl/dace/dace_config.py | 3 ++- ndsl/dsl/dace/stree/optimizations/refine_transients.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index b8d3a0fc..d053491b 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -61,7 +61,7 @@ class Backend: """Backend for NDSL. The backend is a string concatenating information on the intent of the user - for a given execution seperated by a ':'. 
+ for a given execution separated by a ':'. It describes to NDSL the strategy, device and framework to be used on the frontend code. Additionally, it gives a hint toward the macro-strategy diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 1828f898..bb700eeb 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -7,9 +7,10 @@ import dace.config from gt4py.cartesian.config import GT4PY_COMPILE_OPT_LEVEL -from ndsl import Backend, LocalComm +from ndsl import LocalComm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.typing import get_precision diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 998a2a48..0bcd9451 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -5,7 +5,8 @@ import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl import Backend, ndsl_log +from ndsl import ndsl_log +from ndsl.config import Backend from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator From 09ae4fe49e0916d6bf2329c11f4eb836b860c77c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 14:07:10 -0500 Subject: [PATCH 25/47] Properly use global backend shortcuts within API shortcuts Update all wrong usage of `Backend.debug()` --- ndsl/config/backend.py | 6 +-- tests/dsl/test_stencil.py | 6 +-- tests/dsl/test_stencil_config.py | 2 +- tests/mpi/test_mpi_halo_update.py | 4 +- tests/quantity/test_boundary.py | 6 +-- tests/quantity/test_deepcopy.py | 6 +-- tests/quantity/test_local.py | 8 ++-- tests/quantity/test_quantity.py | 28 +++++++------- tests/quantity/test_storage.py | 2 +- tests/quantity/test_transpose.py | 4 
+- tests/quantity/test_view.py | 60 +++++++++++++++--------------- tests/test_caching_comm.py | 2 +- tests/test_halo_update.py | 8 ++-- tests/test_halo_update_ranks.py | 2 +- tests/test_netcdf_monitor.py | 2 +- tests/test_partitioner.py | 2 +- tests/test_sync_shared_boundary.py | 8 ++-- tests/test_tile_scatter.py | 16 ++++---- tests/test_tile_scatter_gather.py | 2 +- tests/test_xumpy.py | 16 ++++---- tests/test_zarr_monitor.py | 26 ++++++------- 21 files changed, 109 insertions(+), 107 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index d053491b..12d35246 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -125,17 +125,17 @@ def __radd__(self, other: str) -> Any: @staticmethod def python() -> Backend: """Default backend for quick iterative work.""" - return Backend("debug") + return backend_python @staticmethod def cpu() -> Backend: """Default performance backend targeting CPU device""" - return Backend("performance_cpu") + return backend_cpu @staticmethod def gpu() -> Backend: """Default performance backend targeting GPU device""" - return Backend("performance_gpu") + return backend_gpu @property def device(self) -> BackendTargetDevice: diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index bbe34fe5..83866df7 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -118,7 +118,7 @@ def test_domain_size_comparison( call_count: int, ): quantity = Quantity( - np.zeros(extent), dimensions, "n/a", extent=extent, backend=Backend.debug() + np.zeros(extent), dimensions, "n/a", extent=extent, backend=Backend.python() ) stencil = FrozenStencil( copy_stencil, @@ -163,7 +163,7 @@ def test_stencil_2D_temporaries() -> None: [I_DIM, J_DIM, K_DIM], "n/a", extent=domain, - backend=Backend.debug(), + backend=Backend.python(), ) stencil = FrozenStencil( two_dim_temporaries_stencil, @@ -186,7 +186,7 @@ def test_validation_call_count(iterations: tuple[int]): [I_DIM, J_DIM, K_DIM], "n/a", extent=domain, - 
backend=Backend.debug(), + backend=Backend.python(), ) stencil_config = StencilConfig( compilation_config=CompilationConfig(Backend.python(), rebuild=True) diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 97ecbfa1..47207304 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -72,7 +72,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - Backend.debug(), + Backend.python(), rebuild=rebuild, validate_args=validate_args, format_source=format_source, diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index f1e6f575..27252a72 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -278,7 +278,7 @@ def depth_quantity( units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -326,7 +326,7 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) quantity.view[:] = 0.0 return quantity diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index 61d52d5e..943dc48c 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -37,7 +37,7 @@ def test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ) for side in ( WEST, @@ -73,7 +73,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), - backend=Backend.debug(), + backend=Backend.python(), ) for side in ( WEST, @@ -112,7 +112,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), - backend=Backend.debug(), + backend=Backend.python(), ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index 
3498f326..6baa5016 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -15,7 +15,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend=Backend.debug(), + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -33,7 +33,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend=Backend.debug(), + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 @@ -49,7 +49,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend=Backend.debug(), + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index 5a6bd276..4ef63710 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -26,7 +26,7 @@ def test_dace_data_descriptor_is_transient() -> None: extent=(nx,), dims=("dim_X",), units="n/a", - backend=Backend.debug(), + backend=Backend.python(), ) array = local.__descriptor__() assert array.transient @@ -89,7 +89,7 @@ def __call__(self) -> None: def test_proper_initialization() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, Backend.debug() + 3, 3, 5, 0, Backend.python() ) the_code = TheCode(stencil_factory, quantity_factory) assert the_code.check_local_right_after_init() @@ -97,7 +97,7 @@ def test_proper_initialization() -> None: def test_forbidden_access_to_locals() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, Backend.debug() + 3, 3, 5, 0, Backend.python() ) the_code = TheCode(stencil_factory, quantity_factory) @@ -129,7 +129,7 @@ def test_forbidden_access_to_locals() -> None: def 
test_local_state_as_regular_state() -> None: - _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, Backend.debug()) + _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, Backend.python()) with pytest.raises( RuntimeError, match="LocalState allocated outside of NDSLRuntime: forbidden", diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index bd9ebc5a..e81c1177 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -68,7 +68,7 @@ def quantity(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -88,7 +88,7 @@ def test_smaller_data_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -102,7 +102,7 @@ def test_smaller_dims_raises(data, origin, extent, dims, units): extent=extent, dims=dims[:-1], units=units, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -114,7 +114,7 @@ def test_smaller_origin_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -126,7 +126,7 @@ def test_smaller_extent_raises(data, origin, extent, dims, units): extent=extent[:-1], dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -271,18 +271,18 @@ def test_shift_slice( @pytest.mark.parametrize( "quantity", [ - Quantity(np.array(5), dims=[], units="", backend=Backend.debug()), + Quantity(np.array(5), dims=[], units="", backend=Backend.python()), Quantity( np.array([1, 2, 3]), dims=["dimension"], units="degK", - backend=Backend.debug(), + backend=Backend.python(), ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", - backend=Backend.debug(), + backend=Backend.python(), ), Quantity( np.random.randn(8, 6, 6), @@ -290,7 +290,7 @@ def test_shift_slice( units="km", origin=(2, 2, 2), extent=(4, 2, 
2), - backend=Backend.debug(), + backend=Backend.python(), ), ], ) @@ -300,13 +300,15 @@ def test_to_data_array(quantity): assert quantity.field_as_xarray.shape == quantity.extent np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert ( - quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data - ), "data memory address is not equal" + assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( + "data memory address is not equal" + ) def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend=Backend.debug()) + quantity = Quantity( + np.ones((5,)), dims=["dim1"], units="", backend=Backend.python() + ) # After allocation - field and data are the same (origin is 0) assert quantity.data.shape == quantity.field.shape diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 952f4bff..e414598d 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -60,7 +60,7 @@ def quantity(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 629730f1..dd168753 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -87,7 +87,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, extent=initial_extent, - backend=Backend.debug(), + backend=Backend.python(), ) @@ -224,7 +224,7 @@ def test_transpose_retains_attrs(numpy): numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", - backend=Backend.debug(), + backend=Backend.python(), ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index 
2808e9c8..49602fd3 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -9,7 +9,7 @@ @pytest.fixture def quantity(request): return Quantity( - request.param[0], dims=request.param[1], units="units", backend=Backend.debug() + request.param[0], dims=request.param[1], units="units", backend=Backend.python() ) @@ -182,7 +182,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ) ], ) @@ -216,7 +216,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ) ], ) @@ -743,7 +743,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (0, 0), 4, @@ -756,7 +756,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, -1), 0, @@ -769,7 +769,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -782,7 +782,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, 0), 1, @@ -803,7 +803,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -841,7 +841,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, 0), 4, @@ -854,7 +854,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), 
), (0, -1), 6, @@ -867,7 +867,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -880,7 +880,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, -1), 3, @@ -901,7 +901,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -939,7 +939,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, 0), 4, @@ -952,7 +952,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (0, -1), 6, @@ -965,7 +965,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -978,7 +978,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, 0), 4, @@ -991,7 +991,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, -1), 3, @@ -1012,7 +1012,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1050,7 +1050,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + 
backend=Backend.python(), ), (-1, -1), 4, @@ -1063,7 +1063,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (0, 0), 8, @@ -1076,7 +1076,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1089,7 +1089,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, -1), 4, @@ -1102,7 +1102,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (-1, 0), 5, @@ -1123,7 +1123,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1161,7 +1161,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (0, 0), 4, @@ -1174,7 +1174,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(0, 0), slice(0, 0)), 4, @@ -1187,7 +1187,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1208,7 +1208,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1229,7 +1229,7 @@ def test_northeast(quantity, 
view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), - backend=Backend.debug(), + backend=Backend.python(), ), (0,), np.array([6, 7, 8]), diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index 4057962c..93cd9fd6 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -30,7 +30,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) for _ in range(n_ranks) ] diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index cf39d355..7b284095 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -325,7 +325,7 @@ def depth_quantity_list( units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) return_list.append(quantity) return return_list @@ -363,7 +363,7 @@ def tile_depth_quantity_list( units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) return_list.append(quantity) return return_list @@ -508,7 +508,7 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -530,7 +530,7 @@ def zeros_quantity_tile_list( units=units, origin=origin, extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index b3a4bc7c..372cc67d 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -127,7 +127,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), - backend=Backend.debug(), + backend=Backend.python(), ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 
b8a881de..41e0760e 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -41,7 +41,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" - backend = Backend.debug() + backend = Backend.python() nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index ae2ee199..f59a3332 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -989,7 +989,7 @@ def test_subtile_extent_with_tile_dimensions( array_dims, "dimensionless", origin=[0, 0, 0, 0], - backend=Backend.debug(), + backend=Backend.python(), ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 223ed815..c3f0a884 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -82,7 +82,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend=Backend.debug(), + backend=Backend.python(), ) y_data = numpy.empty((2, 3), dtype=dtype) y_data[:] = rank @@ -92,7 +92,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend=Backend.debug(), + backend=Backend.python(), ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -150,7 +150,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend=Backend.debug(), + backend=Backend.python(), ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -159,7 +159,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend=Backend.debug(), + backend=Backend.python(), ) 
quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index 134d079e..48153ade 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -37,13 +37,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), dims=[J_INTERFACE_DIM, I_INTERFACE_DIM], units="dimensionless", - backend=Backend.debug(), + backend=Backend.python(), ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[J_INTERFACE_DIM, I_INTERFACE_DIM], units="dimensionless", - backend=Backend.debug(), + backend=Backend.python(), ), } @@ -83,19 +83,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend=Backend.debug(), + backend=Backend.python(), ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend=Backend.debug(), + backend=Backend.python(), ), "rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend=Backend.debug(), + backend=Backend.python(), ), } @@ -143,7 +143,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -151,7 +151,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -159,7 +159,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, 
units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend=Backend.debug(), + backend=Backend.python(), ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index e47d4180..8d5f90b7 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -151,7 +151,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), - backend=Backend.debug(), + backend=Backend.python(), ) diff --git a/tests/test_xumpy.py b/tests/test_xumpy.py index 80585894..f907dbc6 100644 --- a/tests/test_xumpy.py +++ b/tests/test_xumpy.py @@ -8,31 +8,31 @@ def test_xumpy_alloc(): - rand_array = xp.random(shape, Backend.debug()) + rand_array = xp.random(shape, Backend.python()) assert rand_array.shape == shape - (rand_array != xp.random(shape, Backend.debug())).all() + (rand_array != xp.random(shape, Backend.python())).all() - assert (np.ones(shape) == xp.ones(shape, Backend.debug())).all() - assert (np.zeros(shape) == xp.zeros(shape, Backend.debug())).all() + assert (np.ones(shape) == xp.ones(shape, Backend.python())).all() + assert (np.zeros(shape) == xp.zeros(shape, Backend.python())).all() assert ( - np.full(shape, 42.42) == xp.full(shape, value=42.42, backend=Backend.debug()) + np.full(shape, 42.42) == xp.full(shape, value=42.42, backend=Backend.python()) ).all() def test_xumpy_minmax(): - rand_array = xp.random(shape, Backend.debug()) + rand_array = xp.random(shape, Backend.python()) assert (np.max(rand_array, axis=1) == xp.max(rand_array, axis=1)).all() assert (np.min(rand_array, axis=1) == xp.min(rand_array, axis=1)).all() - out_buffer = xp.empty(shape, Backend.debug()) + out_buffer = xp.empty(shape, Backend.python()) xp.max_on_horizontal_plane(rand_array, out_buffer) assert (np.max(rand_array, axis=(0, 1)) == out_buffer).all() def test_xumpy_counts(): - rand_array = xp.random(shape, Backend.debug()) + rand_array = xp.random(shape, Backend.python()) rand_array[1, 1, 
:] = 0 assert np.count_nonzero(rand_array) == xp.count_nonzero(rand_array) diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index 84a19497..04af0bdb 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -104,7 +104,7 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([ny, nx]), dims=(J_DIM, X_DIM), units="m", - backend=Backend.debug(), + backend=Backend.python(), ) } @@ -114,7 +114,7 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([nz, ny, nx]), dims=(K_DIM, J_DIM, I_DIM), units="m", - backend=Backend.debug(), + backend=Backend.python(), ) } @@ -124,13 +124,13 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([ny, nx]), dims=(J_DIM, I_DIM), units="m", - backend=Backend.debug(), + backend=Backend.python(), ), "var2": Quantity( numpy.ones([nz, ny, nx]), dims=(K_DIM, J_DIM, I_DIM), units="degK", - backend=Backend.debug(), + backend=Backend.python(), ), } @@ -263,7 +263,7 @@ def test_monitor_file_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ), } monitor_list[rank].store(state) @@ -339,9 +339,9 @@ def _assert_no_nulls(dataset: xr.Dataset): number_of_null = dataset["var"].isnull().sum().item() total_size = dataset["var"].size - assert ( - number_of_null == 0 - ), f"Number of nulls {number_of_null}. Size of data {total_size}" + assert number_of_null == 0, ( + f"Number of nulls {number_of_null}. 
Size of data {total_size}" + ) @pytest.mark.parametrize("mask_and_scale", [True, False]) @@ -356,7 +356,7 @@ def test_open_zarr_without_nans(cube_partitioner, numpy, backend, mask_and_scale numpy.zeros([10, 10]), dims=(J_DIM, I_DIM), units="m", - backend=Backend.debug(), + backend=Backend.python(), ) monitor.store({"var": zero_quantity}) @@ -381,7 +381,7 @@ def test_values_preserved(cube_partitioner, numpy): numpy.random.uniform(size=(10, 10)), dims=dims, units=units, - backend=Backend.debug(), + backend=Backend.python(), ) monitor.store({"var": quantity}) @@ -435,7 +435,7 @@ def diag(request, numpy): numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m", - backend=Backend.debug(), + backend=Backend.python(), ) @@ -509,7 +509,7 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): numpy.ones([size + 2 for size in range(len(diag.dims))]), dims=new_dims, units="m", - backend=Backend.debug(), + backend=Backend.python(), ) with pytest.raises(ValueError) as excinfo: zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) @@ -526,7 +526,7 @@ def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) diag_3 = Quantity( - data=diag.view[:], dims=diag.dims, units="not_m", backend=Backend.debug() + data=diag.view[:], dims=diag.dims, units="not_m", backend=Backend.python() ) with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3}) From b66c2bde19b8376d4daed52810193038f2500602 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 14:07:30 -0500 Subject: [PATCH 26/47] Lint --- tests/quantity/test_quantity.py | 6 +++--- tests/test_zarr_monitor.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index e81c1177..b0ae282d 100644 --- 
a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -300,9 +300,9 @@ def test_to_data_array(quantity): assert quantity.field_as_xarray.shape == quantity.extent np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( - "data memory address is not equal" - ) + assert ( + quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data + ), "data memory address is not equal" def test_data_setter(): diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index 04af0bdb..4e67cf73 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -339,9 +339,9 @@ def _assert_no_nulls(dataset: xr.Dataset): number_of_null = dataset["var"].isnull().sum().item() total_size = dataset["var"].size - assert number_of_null == 0, ( - f"Number of nulls {number_of_null}. Size of data {total_size}" - ) + assert ( + number_of_null == 0 + ), f"Number of nulls {number_of_null}. 
Size of data {total_size}" @pytest.mark.parametrize("mask_and_scale", [True, False]) From 32b43723c18cd2273b15d0783f57f5a00270161e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 14:56:16 -0500 Subject: [PATCH 27/47] Fix unit tests --- tests/dsl/test_stencil_config.py | 14 +++++++------- tests/dsl/test_stencil_wrapper.py | 12 ++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 47207304..55a1efb8 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,7 +1,7 @@ import pytest from ndsl import CompilationConfig, DaceConfig, StencilConfig -from ndsl.config.backend import _BACKEND_PYTHON, Backend +from ndsl.config.backend import backend_python, Backend @pytest.mark.parametrize("validate_args", [True, False]) @@ -47,7 +47,7 @@ def test_same_config_equal( def test_different_backend_not_equal( - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -72,7 +72,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - Backend.python(), + Backend("st:python:cpu:numpy"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -85,7 +85,7 @@ def test_different_backend_not_equal( def test_different_rebuild_not_equal( - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -160,7 +160,7 @@ def test_different_device_sync_not_equal( def test_different_validate_args_not_equal( - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -198,7 +198,7 @@ def test_different_validate_args_not_equal( def test_different_format_source_not_equal( - backend: Backend = 
_BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, format_source: bool = True, @@ -235,7 +235,7 @@ def test_different_format_source_not_equal( @pytest.mark.parametrize("compare_to_numpy", [True, False]) def test_different_compare_to_numpy_not_equal( compare_to_numpy: bool, - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, device_sync: bool = False, format_source: bool = True, rebuild: bool = True, diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index ca35d08f..62d0c1a2 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -12,7 +12,7 @@ Quantity, StencilConfig, ) -from ndsl.config.backend import _BACKEND_PYTHON, Backend +from ndsl.config.backend import backend_python, Backend from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage @@ -156,7 +156,7 @@ def copy_stencil(q_in: FloatField, q_out: FloatField): @pytest.mark.parametrize("validate_args", [True, False]) def test_copy_frozen_stencil( validate_args: bool, - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -184,7 +184,7 @@ def test_copy_frozen_stencil( def test_frozen_stencil_raises_if_given_origin( - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -211,7 +211,7 @@ def test_frozen_stencil_raises_if_given_origin( def test_frozen_stencil_raises_if_given_domain( - backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, rebuild: bool = False, format_source: bool = False, device_sync: bool = False, @@ -246,7 +246,7 @@ def test_frozen_stencil_kwargs_passed_to_init( validate_args: bool, format_source: bool, device_sync: bool, - 
backend: Backend = _BACKEND_PYTHON, + backend: Backend = backend_python, ): config = get_stencil_config( backend=backend, @@ -322,7 +322,7 @@ def test_backend_options( ) -> None: expected_options = { Backend.python(): { - "backend": "numpy", + "backend": "debug", "rebuild": True, "format_source": False, "name": "tests.dsl.test_stencil_wrapper.copy_stencil", From 5e6dbdb8ffc38cb62dcb4d11eb2452bf1596dc4f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 28 Jan 2026 14:56:32 -0500 Subject: [PATCH 28/47] Lint --- tests/dsl/test_stencil_config.py | 2 +- tests/dsl/test_stencil_wrapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 55a1efb8..a59c8008 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,7 +1,7 @@ import pytest from ndsl import CompilationConfig, DaceConfig, StencilConfig -from ndsl.config.backend import backend_python, Backend +from ndsl.config.backend import Backend, backend_python @pytest.mark.parametrize("validate_args", [True, False]) diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 62d0c1a2..ad88a002 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -12,7 +12,7 @@ Quantity, StencilConfig, ) -from ndsl.config.backend import backend_python, Backend +from ndsl.config.backend import Backend, backend_python from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage From b18bc5eb90b17f217f4625b6bc4f1ddeaa19e052 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 2 Feb 2026 12:07:15 -0500 Subject: [PATCH 29/47] Added guardrail for Boilerplate code --- ndsl/boilerplate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index 51a672a8..5d504251 100644 --- 
a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -117,6 +117,8 @@ def get_factories_single_tile( nx: int, ny: int, nz: int, nhalo: int, backend: Backend = backend_python ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for stencils on a single tile topology.""" + if not isinstance(backend, Backend): + raise RuntimeError(f"Backend {backend} is not of class Backend") return _get_factories( nx=nx, ny=ny, From dd8dfc814d3c26beb051d4e1e521a143891ac8a9 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 18 Feb 2026 16:48:19 -0500 Subject: [PATCH 30/47] Expose loop_order, rework python backends Fix __eq__ operator Some unit tests --- ndsl/config/backend.py | 42 +++++++++++-------- ndsl/dsl/dace/orchestration.py | 13 +++--- .../stree/optimizations/refine_transients.py | 12 +++--- ndsl/quantity/metadata.py | 2 +- tests/dsl/test_caches.py | 6 +-- tests/test_backend.py | 31 ++++++++++++++ tests/test_xumpy.py | 2 +- 7 files changed, 72 insertions(+), 36 deletions(-) create mode 100644 tests/test_backend.py diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 12d35246..d409cde6 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Final +from typing import Final import gt4py.cartesian.backend as gt_backend @@ -29,8 +29,8 @@ class BackendFramework(Enum): _NDSL_TO_GT4PY_BACKEND_NAMING = { - "st:python:cpu:debug": "debug", - "st:python:cpu:numpy": "numpy", + "st:python:cpu:IJK": "debug", + "st:numpy:cpu:IJK": "numpy", "st:gt:cpu:IJK": "gt:cpu_kfirst", "st:gt:cpu:KJI": "gt:cpu_ifirst", "st:gt:gpu:KJI": "gt:gpu", @@ -86,7 +86,7 @@ def __init__(self, ndsl_backend: str) -> None: self._strategy = BackendStrategy(parts[0].lower()) self._framework = BackendFramework(parts[1].lower()) self._device = BackendTargetDevice(parts[2].lower()) - self._loop_order = parts[3] + self._loop_order = parts[3].upper() # 
Check GPU capacity if ( @@ -105,23 +105,15 @@ def __repr__(self) -> str: return self.as_humanly_readable() def __eq__(self, other: object) -> bool: - return self._humanly_readable == self._humanly_readable + if not isinstance(other, Backend): + raise NotImplementedError( + f"Backend equality operator for {type(other)} is not implemented" + ) + return self._humanly_readable == other._humanly_readable def __hash__(self) -> int: return hash(self._humanly_readable) - def __add__(self, other: str) -> Any: - """Concatenation operators""" - if isinstance(other, Backend): - raise TypeError("OperationError: Backend cannot add to another Backend") - return str(self) + other - - def __radd__(self, other: str) -> Any: - """Concatenation operators""" - if isinstance(other, Backend): - raise TypeError("OperationError: Backend cannot add to another Backend") - return other + str(self) - @staticmethod def python() -> Backend: """Default backend for quick iterative work.""" @@ -145,6 +137,10 @@ def device(self) -> BackendTargetDevice: def framework(self) -> BackendFramework: return self._framework + @property + def loop_order(self) -> str: + return self._loop_order + def as_gt4py(self) -> str: """Given an NDSL backend, give back a GT4Py equivalent""" return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] @@ -155,6 +151,13 @@ def as_humanly_readable(self) -> str: def as_safe_for_path(self) -> str: return self._humanly_readable.replace(":", "_") + def as_layout_map(self) -> tuple[int, ...]: + if self._loop_order in ["numpy", "debug"]: + return (0, 1, 2) + return tuple( + len(self._loop_order) - 1 - self._loop_order.index(axis) for axis in "IJK" + ) + def is_orchestrated(self) -> bool: return self._strategy == BackendStrategy.ORCHESTRATION @@ -167,12 +170,15 @@ def is_gpu_backend(self) -> bool: def is_fortran_aligned(self) -> bool: """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran striding.""" + + # Dev NOTE: this probably should live as an 
accessor directly on the + # storage_info or layout_info of GT4Py, rather than stacked up on NDSL return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name( self.as_gt4py() ).storage_info["layout_map"](("I", "J", "K")) -backend_python: Final[Backend] = Backend("st:python:cpu:debug") +backend_python: Final[Backend] = Backend("st:python:cpu:IJK") """Default backend for quick iterative work.""" backend_cpu: Final[Backend] = Backend("orch:dace:cpu:IJK") diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 184acbc4..abb7d2f1 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -172,7 +172,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): passes = [] - if backend_name.as_humanly_readable() == "orch:dace:cpu:IJK": + if backend_name.loop_order == "IJK": passes.extend( [ CleanUpScheduleTree(), @@ -182,10 +182,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - elif backend_name.as_humanly_readable() in [ - "orch:dace:cpu:KJI", - "orch:dace:gpu:KJI", - ]: + elif backend_name.loop_order in "KJI": passes.extend( [ CleanUpScheduleTree(), @@ -195,7 +192,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - else: + elif backend_name.loop_order in "KIJ": passes.extend( [ CleanUpScheduleTree(), @@ -205,6 +202,10 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) + else: + raise NotImplementedError( + f"Loop order {backend_name.loop_order()} has no schedule tree pipeline" + ) CPUPipeline(passes=passes).run(stree, verbose=True) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index ca303b47..9a8ab6ea 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -5,9 +5,9 @@ import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree 
+from ndsl.config import Backend from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator from ndsl.logging import ndsl_log -from ndsl.config import Backend def _change_index_of_tuple( @@ -250,15 +250,13 @@ def __init__(self, backend: Backend) -> None: stacklevel=2, ) - if backend.as_humanly_readable() in ["orch:dace:cpu:IJK"]: - self.ijk_order = (2, 1, 0) - elif backend.as_humanly_readable() in [ + if backend.as_humanly_readable() in [ + "orch:dace:cpu:IJK", "orch:dace:gpu:KJI", "orch:dace:cpu:KJI", + "orch:dace:cpu:KIJ", ]: - self.ijk_order = (0, 1, 2) - elif backend.as_humanly_readable() in ["orch:dace:cpu:KIJ"]: - self.ijk_order = (1, 2, 0) + self.ijk_order = backend.as_layout_map() else: raise NotImplementedError( f"[Schedule Tree Opt] CartesianRefineTransient not implemented for backend {backend}" diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index c7528815..409c0ca0 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -31,7 +31,7 @@ class QuantityMetadata: dtype: type "dtype of the data in the ndarray-like object." backend: Backend - "NDSL backend name. Used for performance optimal data allocation." + "NDSL backend. Used for performance optimal data allocation." 
@property def dim_lengths(self) -> dict[str, int]: diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 50176d77..d877a863 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -128,11 +128,11 @@ def test_relocatability() -> None: gt_config.cache_settings["root_path"] = Path.cwd() # Compile on default - backend = "dace:cpu" - p0 = OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) + backend = Backend("st:dace:cpu:KIJ") + p0 = OrchestratedProgram(backend, None) p0() - backend_sanitized = backend.replace(":", "") + backend_sanitized = backend.as_humanly_readable().replace(":", "") python_version = f"py{sys.version_info[0]}{sys.version_info[1]}" expected_cache_path = ( Path.cwd() diff --git a/tests/test_backend.py b/tests/test_backend.py new file mode 100644 index 00000000..061eb430 --- /dev/null +++ b/tests/test_backend.py @@ -0,0 +1,31 @@ +import pytest +from ndsl import Backend + + +def test_backend_building(): + Backend("st:python:cpu:IJK") + Backend("st:numpy:cpu:IJK") + Backend("st:gt:cpu:IJK") + Backend("st:gt:cpu:KJI") + Backend("st:gt:gpu:KJI") + Backend("st:dace:cpu:IJK") + Backend("orch:dace:cpu:IJK") + Backend("st:dace:cpu:KIJ") + Backend("orch:dace:cpu:KIJ") + Backend("st:dace:cpu:KJI") + Backend("orch:dace:cpu:KJI") + Backend("st:dace:gpu:KJI") + Backend("orch:dace:gpu:KJI") + + with pytest.raises(ValueError): + Backend("bad:name:good:number") + + +def test_backend_operators(): + backend_A = Backend("st:numpy:cpu:IJK") + backend_B = Backend("st:numpy:cpu:IJK") + + assert backend_A == backend_B + assert not (backend_A != backend_B) + assert backend_A is backend_B + assert not (backend_A is not backend_B) diff --git a/tests/test_xumpy.py b/tests/test_xumpy.py index f907dbc6..cc198bb7 100644 --- a/tests/test_xumpy.py +++ b/tests/test_xumpy.py @@ -10,7 +10,7 @@ def test_xumpy_alloc(): rand_array = xp.random(shape, Backend.python()) assert rand_array.shape == shape - (rand_array != xp.random(shape, 
Backend.python())).all() + assert (rand_array != xp.random(shape, Backend.python())).all() assert (np.ones(shape) == xp.ones(shape, Backend.python())).all() assert (np.zeros(shape) == xp.zeros(shape, Backend.python())).all() From d02683e336e08d71dfd96acc2d2cf480c994e5ec Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 18 Feb 2026 17:09:40 -0500 Subject: [PATCH 31/47] Fix tests --- ndsl/config/backend.py | 3 +++ tests/dsl/test_compilation_config.py | 4 ++-- tests/dsl/test_stencil_config.py | 2 +- tests/quantity/test_local.py | 4 +++- tests/test_backend.py | 2 -- tests/test_boilerplate.py | 9 +++------ 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index d409cde6..2e114e05 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -26,6 +26,7 @@ class BackendFramework(Enum): GRIDTOOLS = "gt" DACE = "dace" PYTHON = "python" + NUMPY = "numpy" _NDSL_TO_GT4PY_BACKEND_NAMING = { @@ -105,6 +106,8 @@ def __repr__(self) -> str: return self.as_humanly_readable() def __eq__(self, other: object) -> bool: + if isinstance(other, str): + other = Backend(other) if not isinstance(other, Backend): raise NotImplementedError( f"Backend equality operator for {type(other)} is not implemented" diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index b83e611e..ed73db25 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -146,7 +146,7 @@ def test_determine_compiling_equivalent( def test_as_dict() -> None: config = CompilationConfig() asdict = config.as_dict() - assert asdict["backend"] == "numpy" + assert asdict["backend"] == Backend.python() assert asdict["rebuild"] is True assert asdict["validate_args"] is True assert asdict["format_source"] is False @@ -159,7 +159,7 @@ def test_as_dict() -> None: def test_from_dict() -> None: specification_dict = {} config = CompilationConfig.from_dict(specification_dict) - assert 
config.backend == "numpy" + assert config.backend == Backend.python() assert config.rebuild is False assert config.validate_args is True assert config.format_source is False diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index a59c8008..03f357d2 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -72,7 +72,7 @@ def test_different_backend_not_equal( different_config = StencilConfig( compilation_config=CompilationConfig( - Backend("st:python:cpu:numpy"), + Backend("st:numpy:cpu:IJK"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index 3202651c..4be0e9ee 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -178,7 +178,9 @@ def test_local_state_as_regular_state() -> None: def test_nested_local_state_as_regular_state() -> None: - _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, backend="debug") + _, quantity_factory = get_factories_single_tile( + 3, 3, 5, 0, backend=Backend.python() + ) with pytest.warns(UserWarning, match="LocalState is allocated as a regular State"): nested = NestedLocals.make_as_state(quantity_factory) diff --git a/tests/test_backend.py b/tests/test_backend.py index 061eb430..8a28286d 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -27,5 +27,3 @@ def test_backend_operators(): assert backend_A == backend_B assert not (backend_A != backend_B) - assert backend_A is backend_B - assert not (backend_A is not backend_B) diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index 2784fa9d..e2cc44a0 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -15,9 +15,7 @@ def _copy_ops(stencil_factory: StencilFactory, quantity_factory: QuantityFactory qty_in.view[:] = np.indices( dimensions=quantity_factory.sizer.get_extent([I_DIM, J_DIM, K_DIM]), dtype=np.float64, - ).sum( - axis=0 - ) # Value of each 
entry is sum of the I and J index at each point + ).sum(axis=0) # Value of each entry is sum of the I and J index at each point # Define a stencil def copy_stencil(input_field: FloatField, output_field: FloatField): @@ -64,9 +62,8 @@ def test_boilerplate_import_orchestrated_cpu(): ) # Ensure backend is propagated to StencilFactory and QuantityFactory - assert stencil_factory.backend == "dace:cpu" - assert quantity_factory.backend == "dace:cpu" - + assert stencil_factory.backend == Backend("orch:dace:cpu:IJK") + assert quantity_factory.backend == Backend("orch:dace:cpu:IJK") _copy_ops(stencil_factory, quantity_factory) From b0b695d09249f09b1007b61be2e0136f7284b08e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 18 Feb 2026 17:12:58 -0500 Subject: [PATCH 32/47] Fix test caches --- tests/dsl/test_caches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index d877a863..d5d142d5 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -132,7 +132,7 @@ def test_relocatability() -> None: p0 = OrchestratedProgram(backend, None) p0() - backend_sanitized = backend.as_humanly_readable().replace(":", "") + backend_sanitized = backend.as_gt4py().replace(":", "") python_version = f"py{sys.version_info[0]}{sys.version_info[1]}" expected_cache_path = ( Path.cwd() From d3c0b17e1aa0a4312488771219dd2f541f5d741d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 18 Feb 2026 17:15:05 -0500 Subject: [PATCH 33/47] More fix to tests --- ndsl/dsl/dace/orchestration.py | 2 +- ndsl/dsl/dace/stree/optimizations/refine_transients.py | 2 +- tests/test_backend.py | 1 + tests/test_boilerplate.py | 4 +++- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index abb7d2f1..0bc01ab4 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -204,7 +204,7 @@ def _build_sdfg( ) else: raise 
NotImplementedError( - f"Loop order {backend_name.loop_order()} has no schedule tree pipeline" + f"Loop order {backend_name.loop_order} has no schedule tree pipeline" ) CPUPipeline(passes=passes).run(stree, verbose=True) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 9a8ab6ea..4a5f9928 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -30,7 +30,7 @@ def _reduce_cartesian_axis_size_to_1( transient_map_reads: dace.subsets.Range | None, transient_map_writes: dace.subsets.Range | None, transient_data: dace.data.Data, - ijk_order: tuple[int, int, int], + ijk_order: tuple[int, ...], ) -> bool: """Reduce dimension size of transient to 1 if all access (reads and writes) are atomic""" diff --git a/tests/test_backend.py b/tests/test_backend.py index 8a28286d..678af2af 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,4 +1,5 @@ import pytest + from ndsl import Backend diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index e2cc44a0..83ab77e4 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -15,7 +15,9 @@ def _copy_ops(stencil_factory: StencilFactory, quantity_factory: QuantityFactory qty_in.view[:] = np.indices( dimensions=quantity_factory.sizer.get_extent([I_DIM, J_DIM, K_DIM]), dtype=np.float64, - ).sum(axis=0) # Value of each entry is sum of the I and J index at each point + ).sum( + axis=0 + ) # Value of each entry is sum of the I and J index at each point # Define a stencil def copy_stencil(input_field: FloatField, output_field: FloatField): From 1837d965aa21c09943f04aa84ab6b3c984ea7933 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 19 Feb 2026 09:21:17 -0500 Subject: [PATCH 34/47] PR updates: move test in config/, made loop_order an enum --- ndsl/config/__init__.py | 2 + ndsl/config/backend.py | 47 +++++++++++++------ 
ndsl/dsl/dace/orchestration.py | 7 +-- .../stree/optimizations/refine_transients.py | 23 +++------ ndsl/dsl/gt4py_utils.py | 2 +- ndsl/dsl/stencil_config.py | 3 -- ndsl/grid/helper.py | 3 +- ndsl/initialization/subtile_grid_sizer.py | 13 +++-- ndsl/quantity/quantity.py | 4 +- tests/{ => config}/test_backend.py | 0 .../dace/stree/optimizations/test_merge.py | 4 +- tests/dsl/test_compilation_config.py | 9 ++-- 12 files changed, 61 insertions(+), 56 deletions(-) rename tests/{ => config}/test_backend.py (100%) diff --git a/ndsl/config/__init__.py b/ndsl/config/__init__.py index f0328407..7bb6b243 100644 --- a/ndsl/config/__init__.py +++ b/ndsl/config/__init__.py @@ -1,6 +1,7 @@ from .backend import ( Backend, BackendFramework, + BackendLoopOrder, BackendStrategy, BackendTargetDevice, backend_cpu, @@ -14,6 +15,7 @@ "BackendFramework", "BackendStrategy", "BackendTargetDevice", + "BackendLoopOrder", "backend_python", "backend_cpu", "backend_gpu", diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 2e114e05..3803dd41 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from enum import Enum from typing import Final @@ -29,6 +27,17 @@ class BackendFramework(Enum): NUMPY = "numpy" +class BackendLoopOrder(Enum): + """Cartesian loop order generated by the Backend""" + + IJK = "IJK" + IKJ = "IKJ" + JKI = "JKI" + JIK = "JIK" + KIJ = "KIJ" + KJI = "KJI" + + _NDSL_TO_GT4PY_BACKEND_NAMING = { "st:python:cpu:IJK": "debug", "st:numpy:cpu:IJK": "numpy", @@ -87,12 +96,22 @@ def __init__(self, ndsl_backend: str) -> None: self._strategy = BackendStrategy(parts[0].lower()) self._framework = BackendFramework(parts[1].lower()) self._device = BackendTargetDevice(parts[2].lower()) - self._loop_order = parts[3].upper() + self._loop_order = BackendLoopOrder(parts[3].upper()) + + # Check it exists in GT4Py + try: + gt4py_backend = gt_backend.from_name(self.as_gt4py()) + except ValueError: + raise ValueError( + 
f"NDSL backend {ndsl_backend} does not have a working GT4Py version. " + f"GT4Py backend {self.as_gt4py()} is not registered. " + f"Contact the team." + ) # Check GPU capacity if ( self._device == BackendTargetDevice.GPU - and gt_backend.from_name(self.as_gt4py()).storage_info["device"] != "gpu" + and gt4py_backend.storage_info["device"] != "gpu" ): raise ValueError( f"NDSL backend requested ({self._humanly_readable}) tagets GPU," @@ -110,7 +129,7 @@ def __eq__(self, other: object) -> bool: other = Backend(other) if not isinstance(other, Backend): raise NotImplementedError( - f"Backend equality operator for {type(other)} is not implemented" + f"Backend equality operator for {type(other)} is not implemented." ) return self._humanly_readable == other._humanly_readable @@ -118,18 +137,18 @@ def __hash__(self) -> int: return hash(self._humanly_readable) @staticmethod - def python() -> Backend: + def python() -> "Backend": """Default backend for quick iterative work.""" return backend_python @staticmethod - def cpu() -> Backend: - """Default performance backend targeting CPU device""" + def cpu() -> "Backend": + """Default performance backend targeting CPU devices.""" return backend_cpu @staticmethod - def gpu() -> Backend: - """Default performance backend targeting GPU device""" + def gpu() -> "Backend": + """Default performance backend targeting GPU devices.""" return backend_gpu @property @@ -141,7 +160,7 @@ def framework(self) -> BackendFramework: return self._framework @property - def loop_order(self) -> str: + def loop_order(self) -> BackendLoopOrder: return self._loop_order def as_gt4py(self) -> str: @@ -155,10 +174,10 @@ def as_safe_for_path(self) -> str: return self._humanly_readable.replace(":", "_") def as_layout_map(self) -> tuple[int, ...]: - if self._loop_order in ["numpy", "debug"]: - return (0, 1, 2) + loop_order_as_string = self._loop_order.value return tuple( - len(self._loop_order) - 1 - self._loop_order.index(axis) for axis in "IJK" + 
len(loop_order_as_string) - 1 - loop_order_as_string.index(axis) + for axis in "IJK" ) def is_orchestrated(self) -> bool: diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 0bc01ab4..4037ca7f 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -22,6 +22,7 @@ import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI +from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, @@ -172,7 +173,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): passes = [] - if backend_name.loop_order == "IJK": + if backend_name.loop_order == BackendLoopOrder.IJK: passes.extend( [ CleanUpScheduleTree(), @@ -182,7 +183,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - elif backend_name.loop_order in "KJI": + elif backend_name.loop_order == BackendLoopOrder.KJI: passes.extend( [ CleanUpScheduleTree(), @@ -192,7 +193,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - elif backend_name.loop_order in "KIJ": + elif backend_name.loop_order == BackendLoopOrder.KIJ: passes.extend( [ CleanUpScheduleTree(), diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 4a5f9928..7b788e32 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import warnings import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl.config import Backend +from ndsl.config import Backend, BackendFramework from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator from ndsl.logging import ndsl_log @@ -30,7 +28,7 @@ def _reduce_cartesian_axis_size_to_1( transient_map_reads: 
dace.subsets.Range | None, transient_map_writes: dace.subsets.Range | None, transient_data: dace.data.Data, - ijk_order: tuple[int, ...], + layout_map: tuple[int, ...], ) -> bool: """Reduce dimension size of transient to 1 if all access (reads and writes) are atomic""" @@ -68,10 +66,10 @@ def _reduce_cartesian_axis_size_to_1( ) if len(transient_data.shape) == 3: - layout = [*ijk_order] + layout = [*layout_map] else: data_dim_count = len(transient_data.shape) - 3 - layout = [dim + data_dim_count for dim in ijk_order] + [ + layout = [dim + data_dim_count for dim in layout_map] + [ i - 1 for i in range(data_dim_count, 0, -1) ] @@ -250,18 +248,11 @@ def __init__(self, backend: Backend) -> None: stacklevel=2, ) - if backend.as_humanly_readable() in [ - "orch:dace:cpu:IJK", - "orch:dace:gpu:KJI", - "orch:dace:cpu:KJI", - "orch:dace:cpu:KIJ", - ]: - self.ijk_order = backend.as_layout_map() - else: + if not backend.is_orchestrated() or backend.framework != BackendFramework.DACE: raise NotImplementedError( f"[Schedule Tree Opt] CartesianRefineTransient not implemented for backend {backend}" ) - + self.layout_map = backend.as_layout_map() self.refined_array: set[str] = set() def __str__(self) -> str: @@ -294,7 +285,7 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: collect_map.transients_range_reads[name], collect_map.transients_range_writes[name], data, - self.ijk_order, + self.layout_map, ) refined_transient += 1 if refined else 0 diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index fcc51117..39afa30d 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -108,7 +108,7 @@ def make_storage_data( start: Starting points for slices in data copies dummy: Dummy axes axis: Axis for 2D to 3D arrays - backend: gt4py backend to use + backend: current backend in use Returns: Field[..., dtype]: New storage diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index a4cf4b6a..d2d8180f 100644 --- 
a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -6,7 +6,6 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from typing import Any, Self -import gt4py.cartesian.backend as gt_backend from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.communicator import Communicator @@ -43,8 +42,6 @@ def __init__( ) -> None: if backend.device is BackendTargetDevice.CPU and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") - # GT4Py backend check - expect GT4Py to raise if the backend doesn't exist - gt_backend.from_name(backend.as_gt4py()) self.backend = backend self.rebuild = rebuild self.validate_args = validate_args diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index b4048672..b4eca405 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -232,8 +232,7 @@ def ptop(self) -> Float: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") if self.ak.backend is not None and self.ak.backend.is_gpu_backend(): return Float(self.ak.view[0].get()) - else: - return Float(self.ak.view[0]) + return Float(self.ak.view[0]) @dataclasses.dataclass(frozen=True) diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index 4effe131..4a257080 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -5,7 +5,6 @@ from ndsl.comm.partitioner import TilePartitioner from ndsl.config import Backend from ndsl.constants import N_HALO_DEFAULT -from ndsl.dsl.gt4py_utils import backend_is_fortran_aligned from ndsl.initialization.grid_sizer import GridSizer @@ -21,7 +20,7 @@ def __init__( ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) - fortran_style_memory = backend_is_fortran_aligned(backend) + fortran_style_memory = backend.is_fortran_aligned() self._pad_non_interface_dimensions = not fortran_style_memory @classmethod @@ -46,12 +45,12 @@ def 
from_tile_params( nz: number of vertical levels n_halo: number of halo points layout: (y, x) number of ranks along tile edges - backend: backend name + backend: current backend in use data_dimensions: lengths of any non-x/y/z dimensions, such as land or radiation dimensions tile_partitioner (optional): partitioner object for the tile. By default, a TilePartitioner is created with the given layout - tile_rank (optional): rank of this subtile. + tile_rank (optional): rank of this subtile """ if data_dimensions is None: data_dimensions = {} @@ -93,11 +92,11 @@ def from_namelist( Args: namelist: A namelist for the fv3gfs fortran model tile_partitioner (optional): a partitioner to use for segmenting the tile. - By default, a TilePartitioner is used. + By default, a TilePartitioner is used tile_rank (optional): current rank on tile. Default is 0. Only matters if different ranks have different domain shapes. If tile_partitioner - is a TilePartitioner, this argument does not matter. - backend: backend name + is a TilePartitioner, this argument does not matter + backend: current backend in use """ if "fv_core_nml" in namelist.keys(): layout = namelist["fv_core_nml"]["layout"] diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 2ce5cdba..71fa89d3 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -46,7 +46,7 @@ def __init__( data: ndarray-like object containing the underlying data dims: dimension names for each axis units: units of the quantity - backend: GT4Py backend name. We ensure that the data is allocated in a + backend: current backend in use. We ensure that the data is allocated in a performance optimal way for that backend and copy if necessary. origin: first point in data within the computational domain. Defaults to None. @@ -165,7 +165,7 @@ def from_data_array( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. 
number_of_halo_points: Number of halo points used. Defaults to 0. - backend: NDSL backend name. If given, we allocate data in a performance + backend: current backend in use. If given, we allocate data in a performance optimal way for this backend. Overrides any potentially saved `backend` in `data.attrs["backend"]`. """ diff --git a/tests/test_backend.py b/tests/config/test_backend.py similarity index 100% rename from tests/test_backend.py rename to tests/config/test_backend.py diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 4e914582..b2558860 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -135,7 +135,7 @@ class TestStreeMergeMapsIJK: def factories(self) -> Factories: domain = (3, 3, 4) return get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend=Backend.cpu() + domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:IJK") ) @pytest.fixture @@ -258,7 +258,7 @@ class TestStreeMergeMapsKJI: def factories(self) -> Factories: domain = (3, 3, 4) return get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend=Backend.cpu() + domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:KJI") ) @pytest.fixture diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index ed73db25..e1a9c0fe 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -14,15 +14,12 @@ from ndsl.config import Backend -def test_safety_checks(): +@pytest.mark.parametrize("backend", (Backend.python(), Backend("st:gt:cpu:KJI"))) +def test_safety_checks(backend: Backend) -> None: with pytest.raises( RuntimeError, match="Device sync is true on a CPU based backend" ): - CompilationConfig(Backend.python(), device_sync=True) - with pytest.raises( - RuntimeError, match="Device sync is true on a CPU based backend" - 
): - CompilationConfig(Backend("st:gt:cpu:KJI"), device_sync=True) + CompilationConfig(backend=backend, device_sync=True) @pytest.mark.parametrize( From ffa6e39bb39e06c535739268d0107a805d4c6a6e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 19 Feb 2026 11:57:57 -0500 Subject: [PATCH 35/47] Add a `force_build` option on the `set_distributed_caches` for fortran model integration --- ndsl/config/backend.py | 4 ++-- ndsl/dsl/dace/build.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 3803dd41..447522ab 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -204,7 +204,7 @@ def is_fortran_aligned(self) -> bool: """Default backend for quick iterative work.""" backend_cpu: Final[Backend] = Backend("orch:dace:cpu:IJK") -"""Default performance backend targeting CPU device""" +"""Default performance backend targeting CPU device.""" backend_gpu: Final[Backend] = Backend("orch:dace:gpu:KJI") -"""Default performance backend targeting GPU device""" +"""Default performance backend targeting GPU device.""" diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index ef1bcc90..5434abdf 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -110,17 +110,20 @@ def get_sdfg_path( return sdfg_dir_path -def set_distributed_caches(config: DaceConfig) -> None: - """In Run mode, check required file then point current rank cache to source cache""" +def set_distributed_caches(config: DaceConfig, force_build: bool = False) -> None: + """In Run mode, check required file then point current rank cache to source cache. + + Optional: force build regardless of backend or orchestration mode. + """ # Execute specific initialization per orchestration state - if not config.get_backend().is_orchestrated(): + if not config.get_backend().is_orchestrated() and not force_build: return # Check that we have all the file we need to early out in case # of issues.
orchestration_mode = config.get_orchestrate() - if orchestration_mode == DaCeOrchestration.Run: + if orchestration_mode == DaCeOrchestration.Run and not force_build: import os cache_directory = get_cache_fullpath(config.code_path) From eec242d5ef23666264b19f0f9e154a8675acbae1 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 19 Feb 2026 16:01:50 -0500 Subject: [PATCH 36/47] Lint --- tests/test_g2g_communication.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 27155c44..ba899f31 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -175,6 +175,6 @@ def test_halo_update_communicate_though_cpu( # We expect several np calls and several cp calls global N_ZEROS_CALLS # noqa: F824 global ... is unused assert N_ZEROS_CALLS[np.zeros] > 0 - assert len(N_ZEROS_CALLS) == 1 or N_ZEROS_CALLS[cp.zeros] == 0, ( - "no calls to cupy.zeros logged" - ) + assert ( + len(N_ZEROS_CALLS) == 1 or N_ZEROS_CALLS[cp.zeros] == 0 + ), "no calls to cupy.zeros logged" From d27017768dbaa2ba62a1c4592251e6196be62927 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 19 Feb 2026 16:09:20 -0500 Subject: [PATCH 37/47] [TMP CI] Shift `pyFV3` translate to update branch --- .github/workflows/fv3_translate_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fv3_translate_tests.yaml b/.github/workflows/fv3_translate_tests.yaml index f28e5ad1..5299170b 100644 --- a/.github/workflows/fv3_translate_tests.yaml +++ b/.github/workflows/fv3_translate_tests.yaml @@ -10,7 +10,7 @@ on: jobs: fv3_translate_tests: - uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop + uses: floriandeconinck/pyFV3/.github/workflows/translate.yaml@update/2026.02.00 with: component_trigger: true component_name: NDSL From 08891fc977c169c48a1265bac2e3af636847dc08 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 20 Feb 2026 
09:55:36 -0500 Subject: [PATCH 38/47] Space change to trigger CI --- ndsl/config/backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 447522ab..709c31d5 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -63,7 +63,6 @@ class BackendLoopOrder(Enum): for k # Layout=0 for j # Layout=1 for i # Layout=2 - """ From 4cc245cd5847231bc46e7436342575aeac7f6803 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 09:25:18 -0500 Subject: [PATCH 39/47] Pull on test branch for pySHiELD CI --- .github/workflows/shield_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/shield_tests.yaml b/.github/workflows/shield_tests.yaml index 53ba510b..811e778b 100644 --- a/.github/workflows/shield_tests.yaml +++ b/.github/workflows/shield_tests.yaml @@ -10,7 +10,7 @@ on: jobs: shield_translate_tests: - uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop + uses: floriandeconinck/pySHiELD/.github/workflows/translate.yaml@update/2026.02.00 with: component_trigger: true component_name: NDSL From 5cdb59963dfe318b9f6828aa36a32668789a9a62 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 11:03:08 -0500 Subject: [PATCH 40/47] Update `CompilationConfig` to turn backend as a string into `Backend` --- ndsl/dsl/stencil_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index d2d8180f..086b3874 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -142,7 +142,7 @@ def get_decomposition_info_from_comm( def as_dict(self) -> dict[str, Any]: return { - "backend": self.backend, + "backend": self.backend.as_humanly_readable, "rebuild": self.rebuild, "validate_args": self.validate_args, "format_source": self.format_source, @@ -154,7 +154,9 @@ def as_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict) -> Self: 
instance = cls( - backend=data.get("backend", Backend.python()), + backend=Backend( + data.get("backend", Backend.python().as_humanly_readable()) + ), rebuild=data.get("rebuild", False), validate_args=data.get("validate_args", True), format_source=data.get("format_source", False), From 2d940a7a6fd60037ac8597b949fd2a4201c68cf3 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 11:11:39 -0500 Subject: [PATCH 41/47] Typo --- ndsl/dsl/stencil_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 086b3874..43eb08e0 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -142,7 +142,7 @@ def get_decomposition_info_from_comm( def as_dict(self) -> dict[str, Any]: return { - "backend": self.backend.as_humanly_readable, + "backend": self.backend.as_humanly_readable(), "rebuild": self.rebuild, "validate_args": self.validate_args, "format_source": self.format_source, From 8bf39c84495c9ef49c21417cec5e30e1bfc11756 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 11:24:20 -0500 Subject: [PATCH 42/47] Fix tests for CompilationConfig --- tests/dsl/test_compilation_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index e1a9c0fe..3ce99399 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -164,7 +164,7 @@ def test_from_dict() -> None: assert config.run_mode == RunMode.BuildAndRun assert config.use_minimal_caching is False - specification_dict["backend"] = Backend("st:gt:gpu:KJI") + specification_dict["backend"] = "st:gt:gpu:KJI" config = CompilationConfig.from_dict(specification_dict) assert config.backend == Backend("st:gt:gpu:KJI") From 5c13eaed2f15893522a8f0c89f92395b232003aa Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 12:17:00 -0500 Subject: [PATCH 43/47] Fix 
backend save in dace_config --- ndsl/dsl/dace/dace_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 29e8b361..b53d6675 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -374,7 +374,7 @@ def get_sync_debug(self) -> bool: def as_dict(self) -> dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), - "_backend": self._backend, + "_backend": self._backend.as_humanly_readable(), "my_rank": self.my_rank, "rank_size": self.rank_size, "layout": self.layout, From 40f0a2dcc3a0784930d0400561a307c55563ec6f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 12:19:08 -0500 Subject: [PATCH 44/47] Fix dace config load --- ndsl/dsl/dace/dace_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index b53d6675..32dfdb9f 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -385,7 +385,7 @@ def as_dict(self) -> dict[str, Any]: def from_dict(cls, data: dict) -> Self: config = cls( None, - backend=data["_backend"], + backend=Backend(data["_backend"]), orchestration=DaCeOrchestration[data["_orchestrate"]], ) config.my_rank = data["my_rank"] From ebcb258d0afd4622aaa69007372d64ba42614397 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 24 Feb 2026 13:10:06 -0500 Subject: [PATCH 45/47] [TMP CI] Update Pace branch --- .github/workflows/pace_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pace_tests.yaml b/.github/workflows/pace_tests.yaml index ea3d40b3..bae1c5ab 100644 --- a/.github/workflows/pace_tests.yaml +++ b/.github/workflows/pace_tests.yaml @@ -10,7 +10,7 @@ on: jobs: pace_main_tests: - uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop + uses: floriandeconinck/pace/.github/workflows/main_unit_tests.yaml@update/2026.02.00 with: component_trigger: 
true component_name: NDSL From c321fd6aaaac9700bd3ddb40516df7a849b471a7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 4 Mar 2026 08:45:08 +0100 Subject: [PATCH 46/47] tests: simple cleanup as code review --- tests/config/test_backend.py | 5 ++- tests/dsl/test_stencil_wrapper.py | 61 +++++++++++++++---------------- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/config/test_backend.py b/tests/config/test_backend.py index 678af2af..a632b2c9 100644 --- a/tests/config/test_backend.py +++ b/tests/config/test_backend.py @@ -18,8 +18,9 @@ def test_backend_building(): Backend("st:dace:gpu:KJI") Backend("orch:dace:gpu:KJI") - with pytest.raises(ValueError): - Backend("bad:name:good:number") + unknown_backend = "bad:name:good:number" + with pytest.raises(ValueError, match=f"Unknown {unknown_backend}, options are .*"): + Backend(unknown_backend) def test_backend_operators(): diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 4112e9dd..80070e7e 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -12,7 +12,7 @@ Quantity, StencilConfig, ) -from ndsl.config.backend import Backend, backend_python +from ndsl.config.backend import Backend from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage @@ -154,13 +154,12 @@ def copy_stencil(q_in: FloatField, q_out: FloatField): @pytest.mark.parametrize("validate_args", [True, False]) -def test_copy_frozen_stencil( - validate_args: bool, - backend: Backend = backend_python, - rebuild: bool = False, - format_source: bool = False, - device_sync: bool = False, -) -> None: +def test_copy_frozen_stencil(validate_args: bool) -> None: + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False + config = get_stencil_config( backend=backend, rebuild=rebuild, @@ -183,12 +182,12 @@ def 
test_copy_frozen_stencil( np.testing.assert_array_equal(q_in, q_out) -def test_frozen_stencil_raises_if_given_origin( - backend: Backend = backend_python, - rebuild: bool = False, - format_source: bool = False, - device_sync: bool = False, -) -> None: +def test_frozen_stencil_raises_if_given_origin() -> None: + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False + # only guaranteed when validating args config = get_stencil_config( backend=backend, @@ -210,12 +209,12 @@ def test_frozen_stencil_raises_if_given_origin( stencil(q_in, q_out, origin=(0, 0, 0)) -def test_frozen_stencil_raises_if_given_domain( - backend: Backend = backend_python, - rebuild: bool = False, - format_source: bool = False, - device_sync: bool = False, -) -> None: +def test_frozen_stencil_raises_if_given_domain() -> None: + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False + # only guaranteed when validating args config = get_stencil_config( backend=backend, @@ -246,8 +245,9 @@ def test_frozen_stencil_kwargs_passed_to_init( validate_args: bool, format_source: bool, device_sync: bool, - backend: Backend = backend_python, ) -> None: + backend = Backend.python() + config = get_stencil_config( backend=backend, rebuild=rebuild, @@ -315,11 +315,10 @@ def test_frozen_field_after_parameter() -> None: @pytest.mark.parametrize("backend", (Backend.python(), Backend("st:gt:gpu:KJI"))) -def test_backend_options( - backend: Backend, - rebuild: bool = True, - validate_args: bool = True, -) -> None: +def test_backend_options(backend: Backend) -> None: + rebuild = True + validate_args = True + expected_options = { Backend.python(): { "backend": "debug", @@ -343,12 +342,10 @@ def test_backend_options( assert actual == expected -def test_illegal_backend_options(): - with pytest.raises(ValueError): - get_stencil_config( - backend=Backend("bad:back:end:now"), - match="Backend bad:back:end:now is not registered. 
Valid options are:*", - ) +def test_illegal_backend_options() -> None: + unknown_backend = "bad:back:end:now" + with pytest.raises(ValueError, match=f"Unknown {unknown_backend}, options are:*"): + get_stencil_config(backend=Backend(unknown_backend)) def get_mock_quantity(): From 01e8717d5fd467819f4956c47e4a275447ed4e52 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 4 Mar 2026 10:53:46 +0100 Subject: [PATCH 47/47] tmp ci: point to noop translate test in hooks Required status checks require us to run these exact hooks (admins might have powers that I don't have). However, what we can do is point to any fork and just have a noop workflow there. This way we can merge and release. Once released and updated, we can restore the default hooks. --- .github/workflows/fv3_translate_tests.yaml | 5 ++++- .github/workflows/pace_tests.yaml | 5 ++++- .github/workflows/shield_tests.yaml | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/workflows/fv3_translate_tests.yaml b/.github/workflows/fv3_translate_tests.yaml index 5299170b..7d863e59 100644 --- a/.github/workflows/fv3_translate_tests.yaml +++ b/.github/workflows/fv3_translate_tests.yaml @@ -10,7 +10,10 @@ on: jobs: fv3_translate_tests: - uses: floriandeconinck/pyFV3/.github/workflows/translate.yaml@update/2026.02.00 + # TODO + # restore once NDSL 2026.02.00 is released and pyFV3 is updated. + # uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop + uses: romanc/pyFV3/.github/workflows/translate.yaml@noop with: component_trigger: true component_name: NDSL diff --git a/.github/workflows/pace_tests.yaml b/.github/workflows/pace_tests.yaml index bae1c5ab..2e9ae1f3 100644 --- a/.github/workflows/pace_tests.yaml +++ b/.github/workflows/pace_tests.yaml @@ -10,7 +10,10 @@ on: jobs: pace_main_tests: - uses: floriandeconinck/pace/.github/workflows/main_unit_tests.yaml@update/2026.02.00 + # TODO + # restore once NDSL 2026.02.00 is released and pace is updated.
+ # uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop + uses: romanc/pace/.github/workflows/main_unit_tests.yaml@noop with: component_trigger: true component_name: NDSL diff --git a/.github/workflows/shield_tests.yaml b/.github/workflows/shield_tests.yaml index 811e778b..c5fa7c13 100644 --- a/.github/workflows/shield_tests.yaml +++ b/.github/workflows/shield_tests.yaml @@ -10,7 +10,10 @@ on: jobs: shield_translate_tests: - uses: floriandeconinck/pySHiELD/.github/workflows/translate.yaml@update/2026.02.00 + # TODO + # restore once NDSL 2026.02.00 is released and pySHiELD is updated. + # uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop + uses: romanc/pySHiELD/.github/workflows/translate.yaml@noop with: component_trigger: true component_name: NDSL