diff --git a/.github/workflows/fv3_translate_tests.yaml b/.github/workflows/fv3_translate_tests.yaml index f28e5ad1..7d863e59 100644 --- a/.github/workflows/fv3_translate_tests.yaml +++ b/.github/workflows/fv3_translate_tests.yaml @@ -10,7 +10,10 @@ on: jobs: fv3_translate_tests: - uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop + # TODO + # restore once NDSL 2026.02.00 is released and pyFV3 is updated. + # uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop + uses: romanc/pyFV3/.github/workflows/translate.yaml@noop with: component_trigger: true component_name: NDSL diff --git a/.github/workflows/pace_tests.yaml b/.github/workflows/pace_tests.yaml index ea3d40b3..2e9ae1f3 100644 --- a/.github/workflows/pace_tests.yaml +++ b/.github/workflows/pace_tests.yaml @@ -10,7 +10,10 @@ on: jobs: pace_main_tests: - uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop + # TODO + # restore once NDSL 2026.02.00 is released and pace is updated. + # uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop + uses: romanc/pace/.github/workflows/main_unit_tests.yaml@noop with: component_trigger: true component_name: NDSL diff --git a/.github/workflows/shield_tests.yaml b/.github/workflows/shield_tests.yaml index 53ba510b..c5fa7c13 100644 --- a/.github/workflows/shield_tests.yaml +++ b/.github/workflows/shield_tests.yaml @@ -10,7 +10,10 @@ on: jobs: shield_translate_tests: - uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop + # TODO + # restore once NDSL 2026.02.00 is released and pySHiELD is updated. 
+ # uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop + uses: romanc/pySHiELD/.github/workflows/translate.yaml@noop with: component_trigger: true component_name: NDSL diff --git a/examples/mpi/zarr_monitor.py b/examples/mpi/zarr_monitor.py index 7a630438..15b8196c 100644 --- a/examples/mpi/zarr_monitor.py +++ b/examples/mpi/zarr_monitor.py @@ -12,6 +12,7 @@ TilePartitioner, ZarrMonitor, ) +from ndsl.config import backend_python from ndsl.constants import I_DIM, J_DIM, K_DIM @@ -20,9 +21,9 @@ def get_example_state(time): sizer = SubtileGridSizer( - nx=48, ny=48, nz=70, n_halo=3, data_dimensions={}, backend="debug" + nx=48, ny=48, nz=70, n_halo=3, data_dimensions={}, backend=backend_python ) - allocator = QuantityFactory(sizer, np) + allocator = QuantityFactory(sizer, backend=backend_python) air_temperature = allocator.zeros([I_DIM, J_DIM, K_DIM], units="degK") air_temperature.view[:] = np.random.randn(*air_temperature.extent) return {"time": time, "air_temperature": air_temperature} diff --git a/ndsl/__init__.py b/ndsl/__init__.py index dd5f21e3..97839190 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,20 +1,14 @@ +# isort:skip_file from . 
import dsl # isort:skip from .logging import ndsl_log # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm from .comm.partitioner import CubedSpherePartitioner, TilePartitioner +from .config.backend import Backend from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath -from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from .dsl.dace.orchestration import orchestrate, orchestrate_function -from .dsl.dace.utils import ( - ArrayReport, - DaCeProgress, - MaxBandwidthBenchmarkProgram, - StorageReport, -) -from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .quantity import Quantity from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig @@ -25,14 +19,25 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Local, LocalState, Quantity, State +from .quantity import Local, LocalState, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .types import Allocator from .utils import MetaEnumStr +from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.dace.utils import ( + ArrayReport, + DaCeProgress, + MaxBandwidthBenchmarkProgram, + StorageReport, +) +from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration +from .dsl.dace.orchestration import orchestrate, orchestrate_function + __all__ = [ "dsl", + "Backend", "CubedSphereCommunicator", "TileCommunicator", "LocalComm", diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index 1e94e3c7..5d504251 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -1,6 +1,7 @@ import 
warnings from ndsl import ( + Backend, CompilationConfig, DaceConfig, DaCeOrchestration, @@ -14,6 +15,7 @@ TileCommunicator, TilePartitioner, ) +from ndsl.config.backend import backend_cpu, backend_python def _get_factories( @@ -21,7 +23,7 @@ def _get_factories( ny: int, nz: int, nhalo: int, - backend: str, + backend: Backend, orchestration: DaCeOrchestration, topology: str, ) -> tuple[StencilFactory, QuantityFactory]: @@ -91,14 +93,14 @@ def get_factories_single_tile_orchestrated( ny: int, nz: int, nhalo: int, - backend: str = "dace:cpu", + backend: Backend = backend_cpu, *, orchestration_mode: DaCeOrchestration | None = None, ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for orchestrated code on a single tile topology.""" - if backend is not None and not backend.startswith("dace"): - raise ValueError("Only `dace:*` backends can be orchestrated.") + if backend is not None and not backend.is_orchestrated(): + raise ValueError(f"Only `orch:*` backends can be orchestrated, got {backend}.") return _get_factories( nx=nx, @@ -112,15 +114,17 @@ def get_factories_single_tile_orchestrated( def get_factories_single_tile( - nx: int, ny: int, nz: int, nhalo: int, backend: str = "numpy" + nx: int, ny: int, nz: int, nhalo: int, backend: Backend = backend_python ) -> tuple[StencilFactory, QuantityFactory]: """Build the pair of (StencilFactory, QuantityFactory) for stencils on a single tile topology.""" + if not isinstance(backend, Backend): + raise RuntimeError(f"Backend {backend} is not of class Backend") return _get_factories( nx=nx, ny=ny, nz=nz, nhalo=nhalo, backend=backend, - orchestration=DaCeOrchestration.Python, + orchestration=DaCeOrchestration.BuildAndRun, topology="tile", ) diff --git a/ndsl/config/__init__.py b/ndsl/config/__init__.py new file mode 100644 index 00000000..7bb6b243 --- /dev/null +++ b/ndsl/config/__init__.py @@ -0,0 +1,22 @@ +from .backend import ( + Backend, + BackendFramework, + BackendLoopOrder, 
+ BackendStrategy, + BackendTargetDevice, + backend_cpu, + backend_gpu, + backend_python, +) + + +__all__ = [ + "Backend", + "BackendFramework", + "BackendStrategy", + "BackendTargetDevice", + "BackendLoopOrder", + "backend_python", + "backend_cpu", + "backend_gpu", +] diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py new file mode 100644 index 00000000..709c31d5 --- /dev/null +++ b/ndsl/config/backend.py @@ -0,0 +1,209 @@ +from enum import Enum +from typing import Final + +import gt4py.cartesian.backend as gt_backend + + +class BackendStrategy(Enum): + """Strategy for the code execution""" + + STENCIL = "st" + ORCHESTRATION = "orch" + + +class BackendTargetDevice(Enum): + """Target device""" + + CPU = "cpu" + GPU = "gpu" + + +class BackendFramework(Enum): + """Main lower-level framework (or language) backend relies on""" + + GRIDTOOLS = "gt" + DACE = "dace" + PYTHON = "python" + NUMPY = "numpy" + + +class BackendLoopOrder(Enum): + """Cartesian loop order generated by the Backend""" + + IJK = "IJK" + IKJ = "IKJ" + JKI = "JKI" + JIK = "JIK" + KIJ = "KIJ" + KJI = "KJI" + + +_NDSL_TO_GT4PY_BACKEND_NAMING = { + "st:python:cpu:IJK": "debug", + "st:numpy:cpu:IJK": "numpy", + "st:gt:cpu:IJK": "gt:cpu_kfirst", + "st:gt:cpu:KJI": "gt:cpu_ifirst", + "st:gt:gpu:KJI": "gt:gpu", + "st:dace:cpu:IJK": "dace:cpu_kfirst", + "orch:dace:cpu:IJK": "dace:cpu_kfirst", + "st:dace:cpu:KIJ": "dace:cpu", + "orch:dace:cpu:KIJ": "dace:cpu", + "st:dace:cpu:KJI": "dace:cpu_KJI", + "orch:dace:cpu:KJI": "dace:cpu_KJI", + "st:dace:gpu:KJI": "dace:gpu", + "orch:dace:gpu:KJI": "dace:gpu", +} +"""Internal: match the NDSL backend names with the GT4Py names""" + +_FORTRAN_LOOP_LAYOUT = (2, 1, 0) +"""Fortran is a column-first (or stride-first) memory system, +which in the internal gt4py loop layout means I (or axis[0]) has +the higher value, e.g. "higher importance to run first": + +for k # Layout=0 + for j # Layout=1 + for i # Layout=2 +""" + + +class Backend: + """Backend for NDSL. 
+ + The backend is a string concatenating information on the intent of the user + for a given execution separated by a ':'. + + It describes to NDSL the strategy, device and framework to be used + on the frontend code. Additionally, it gives a hint toward the macro-strategy + for loop ordering (IJK, KJI, etc.) or a more broad intent (debug, numpy). + + For convenience, shortcuts are given to the most common needs ( + `backend_python`, `backend_cpu`, `backend_gpu`). + """ + + def __init__(self, ndsl_backend: str) -> None: + # Checks for existence and form + if ndsl_backend not in _NDSL_TO_GT4PY_BACKEND_NAMING: + raise ValueError( + f"Unknown {ndsl_backend}, options are {list(_NDSL_TO_GT4PY_BACKEND_NAMING.keys())}" + ) + parts = ndsl_backend.split(":") + if len(parts) != 4: + raise ValueError(f"Backend {ndsl_backend} is ill-formed.") + + # Breakdown and save into internal parameters + self._humanly_readable = ndsl_backend + self._strategy = BackendStrategy(parts[0].lower()) + self._framework = BackendFramework(parts[1].lower()) + self._device = BackendTargetDevice(parts[2].lower()) + self._loop_order = BackendLoopOrder(parts[3].upper()) + + # Check it exists in GT4Py + try: + gt4py_backend = gt_backend.from_name(self.as_gt4py()) + except ValueError: + raise ValueError( + f"NDSL backend {ndsl_backend} does not have a working GT4Py version. " + f"GT4Py backend {self.as_gt4py()} is not registered. " + f"Contact the team." + ) + + # Check GPU capacity + if ( + self._device == BackendTargetDevice.GPU + and gt4py_backend.storage_info["device"] != "gpu" + ): + raise ValueError( + f"NDSL backend requested ({self._humanly_readable}) tagets GPU," + f"but requests a non-GPU backend from GT4Py ({self.as_gt4py()})." 
+ ) + + def __str__(self) -> str: + return self.as_humanly_readable() + + def __repr__(self) -> str: + return self.as_humanly_readable() + + def __eq__(self, other: object) -> bool: + if isinstance(other, str): + other = Backend(other) + if not isinstance(other, Backend): + raise NotImplementedError( + f"Backend equality operator for {type(other)} is not implemented." + ) + return self._humanly_readable == other._humanly_readable + + def __hash__(self) -> int: + return hash(self._humanly_readable) + + @staticmethod + def python() -> "Backend": + """Default backend for quick iterative work.""" + return backend_python + + @staticmethod + def cpu() -> "Backend": + """Default performance backend targeting CPU devices.""" + return backend_cpu + + @staticmethod + def gpu() -> "Backend": + """Default performance backend targeting GPU devices.""" + return backend_gpu + + @property + def device(self) -> BackendTargetDevice: + return self._device + + @property + def framework(self) -> BackendFramework: + return self._framework + + @property + def loop_order(self) -> BackendLoopOrder: + return self._loop_order + + def as_gt4py(self) -> str: + """Given an NDSL backend, give back a GT4Py equivalent""" + return _NDSL_TO_GT4PY_BACKEND_NAMING[self._humanly_readable] + + def as_humanly_readable(self) -> str: + return self._humanly_readable + + def as_safe_for_path(self) -> str: + return self._humanly_readable.replace(":", "_") + + def as_layout_map(self) -> tuple[int, ...]: + loop_order_as_string = self._loop_order.value + return tuple( + len(loop_order_as_string) - 1 - loop_order_as_string.index(axis) + for axis in "IJK" + ) + + def is_orchestrated(self) -> bool: + return self._strategy == BackendStrategy.ORCHESTRATION + + def is_stencil(self) -> bool: + return self._strategy == BackendStrategy.STENCIL + + def is_gpu_backend(self) -> bool: + return self._device == BackendTargetDevice.GPU + + def is_fortran_aligned(self) -> bool: + """Check that the standard 3D field on cartesian 
axis is memory-aligned with Fortran + striding.""" + + # Dev NOTE: this probably should live as an accessor directly on the + # storage_info or layout_info of GT4Py, rather than stacked up on NDSL + return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name( + self.as_gt4py() + ).storage_info["layout_map"](("I", "J", "K")) + + +backend_python: Final[Backend] = Backend("st:python:cpu:IJK") +"""Default backend for quick iterative work.""" + +backend_cpu: Final[Backend] = Backend("orch:dace:cpu:IJK") +"""Default performance backend targeting CPU device.""" + +backend_gpu: Final[Backend] = Backend("orch:dace:gpu:KJI") +"""Default performance backend targeting GPU device.""" diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index 155583a0..5434abdf 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -2,6 +2,7 @@ from dace.sdfg import SDFG from gt4py.cartesian import config as gt_config +from ndsl.config import Backend from ndsl.dsl.caches.cache_location import get_cache_directory, get_cache_fullpath from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.logging import ndsl_log @@ -23,7 +24,10 @@ def build_info_filepath() -> str: def write_build_info( - sdfg: SDFG, layout: tuple[int, int], resolution_per_tile: list[int], backend: str + sdfg: SDFG, + layout: tuple[int, int], + resolution_per_tile: list[int], + backend: Backend, ) -> None: """Write down all relevant information on the build to identify it at load time.""" @@ -86,7 +90,7 @@ def get_sdfg_path( build_info_file.readline() # Read in build_backend = build_info_file.readline().rstrip() - if config.get_backend() != build_backend: + if config.get_backend() != Backend(build_backend): raise RuntimeError( f"SDFG build for {build_backend}, {config._backend} has been asked" ) @@ -106,17 +110,20 @@ def get_sdfg_path( return sdfg_dir_path -def set_distributed_caches(config: DaceConfig) -> None: - """In Run mode, check required file then point current rank cache to source cache""" 
+def set_distributed_caches(config: DaceConfig, force_build: bool = False) -> None: + """In Run mode, check required file then point current rank cache to source cache. + + Optional: force build regardless of backend or orchestration mode. + """ # Execute specific initialization per orchestration state - orchestration_mode = config.get_orchestrate() - if orchestration_mode == DaCeOrchestration.Python: + if not config.get_backend().is_orchestrated() and not force_build: return # Check that we have all the file we need to early out in case # of issues. - if orchestration_mode == DaCeOrchestration.Run: + orchestration_mode = config.get_orchestrate() + if orchestration_mode == DaCeOrchestration.Run and not force_build: import os cache_directory = get_cache_fullpath(config.code_path) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 3e85050b..32dfdb9f 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -11,9 +11,9 @@ from ndsl import LocalComm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath -from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import get_precision from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector @@ -132,23 +132,21 @@ class DaCeOrchestration(enum.Enum): """ Orchestration mode for DaCe - Python: python orchestration - Build: compile & save SDFG only BuildAndRun: compile & save SDFG, then run + Build: compile & save SDFG only Run: load from .so and run, will fail if .so is not available """ - Python = 0 + BuildAndRun = 0 Build = 1 - BuildAndRun = 2 - Run = 3 + Run = 2 class DaceConfig: def __init__( self, communicator: Communicator | None, - backend: str, + backend: Backend, tile_nx: int = 0, tile_nz: int = 0, 
orchestration: DaCeOrchestration | None = None, @@ -174,6 +172,7 @@ def __init__( of column-physics) and therefore can be compiled once """ + self._backend = backend self._single_code_path = single_code_path # Recording SDFG loaded for fast re-access # ToDo: DaceConfig becomes a bit more than a read-only config @@ -194,11 +193,11 @@ def __init__( # We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` # backend that would signify both the `CPU|GPU` split and the orchestration mode if orchestration is None: - fv3_dacemode_env_var = os.getenv("FV3_DACEMODE", "Python") + fv3_dacemode_env_var = os.getenv("FV3_DACEMODE", "BuildAndRun") # The below condition guards against defining empty FV3_DACEMODE and # awkward behavior of os.getenv returning "" even when not defined if fv3_dacemode_env_var is None or fv3_dacemode_env_var == "": - fv3_dacemode_env_var = "Python" + fv3_dacemode_env_var = "BuildAndRun" self._orchestrate = DaCeOrchestration[fv3_dacemode_env_var] else: self._orchestrate = orchestration @@ -331,7 +330,6 @@ def __init__( if dace_conf_to_kill is not None: Path(dace_conf_to_kill).unlink(missing_ok=True) - self._backend = backend self.tile_resolution = [tile_nx, tile_nx, tile_nz] from ndsl.dsl.dace.build import set_distributed_caches @@ -358,19 +356,13 @@ def __init__( set_distributed_caches(self) - if self.is_dace_orchestrated() and "dace" not in self._backend: - raise RuntimeError( - "DaceConfig: orchestration can only be leveraged " - f"with the `dace:*` backends, not with {self._backend}." 
- ) - def is_dace_orchestrated(self) -> bool: - return self._orchestrate != DaCeOrchestration.Python + return self._backend.is_orchestrated() def is_gpu_backend(self) -> bool: - return is_gpu_backend(self._backend) + return self._backend.is_gpu_backend() - def get_backend(self) -> str: + def get_backend(self) -> Backend: return self._backend def get_orchestrate(self) -> DaCeOrchestration: @@ -382,7 +374,7 @@ def get_sync_debug(self) -> bool: def as_dict(self) -> dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), - "_backend": self._backend, + "_backend": self._backend.as_humanly_readable(), "my_rank": self.my_rank, "rank_size": self.rank_size, "layout": self.layout, @@ -393,7 +385,7 @@ def as_dict(self) -> dict[str, Any]: def from_dict(cls, data: dict) -> Self: config = cls( None, - backend=data["_backend"], + backend=Backend(data["_backend"]), orchestration=DaCeOrchestration[data["_orchestrate"]], ) config.my_rank = data["my_rank"] diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index e2340a7a..4037ca7f 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -18,10 +18,11 @@ from dace.transformation.auto.auto_optimize import make_transients_persistent from dace.transformation.dataflow import MapExpansion from dace.transformation.helpers import get_parent_map -from gt4py import storage +from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI +from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, @@ -82,7 +83,10 @@ def _download_results_from_dace( return None backend = config.get_backend() - return [storage.from_array(result, backend=backend) for result in dace_result] + return [ + gt_storage.from_array(result, backend=backend.as_gt4py()) + for result in dace_result + ] def 
_to_gpu(sdfg: SDFG) -> None: @@ -169,7 +173,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): passes = [] - if backend_name == "dace:cpu_kfirst": + if backend_name.loop_order == BackendLoopOrder.IJK: passes.extend( [ CleanUpScheduleTree(), @@ -179,7 +183,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - elif backend_name in ["dace:cpu_KJI", "dace:gpu"]: + elif backend_name.loop_order == BackendLoopOrder.KJI: passes.extend( [ CleanUpScheduleTree(), @@ -189,7 +193,7 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) - else: + elif backend_name.loop_order == BackendLoopOrder.KIJ: passes.extend( [ CleanUpScheduleTree(), @@ -199,6 +203,10 @@ def _build_sdfg( CartesianRefineTransients(backend_name), ] ) + else: + raise NotImplementedError( + f"Loop order {backend_name.loop_order} has no schedule tree pipeline" + ) CPUPipeline(passes=passes).run(stree, verbose=True) with DaCeProgress(config, "Schedule Tree: go back to SDFG"): diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 64650762..7b788e32 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,10 +1,9 @@ -from __future__ import annotations - import warnings import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree +from ndsl.config import Backend, BackendFramework from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator from ndsl.logging import ndsl_log @@ -29,7 +28,7 @@ def _reduce_cartesian_axis_size_to_1( transient_map_reads: dace.subsets.Range | None, transient_map_writes: dace.subsets.Range | None, transient_data: dace.data.Data, - ijk_order: tuple[int, int, int], + layout_map: tuple[int, ...], ) -> bool: """Reduce dimension size of transient to 1 if all access (reads and writes) are atomic""" @@ -67,10 +66,10 @@ def _reduce_cartesian_axis_size_to_1( ) if 
len(transient_data.shape) == 3: - layout = [*ijk_order] + layout = [*layout_map] else: data_dim_count = len(transient_data.shape) - 3 - layout = [dim + data_dim_count for dim in ijk_order] + [ + layout = [dim + data_dim_count for dim in layout_map] + [ i - 1 for i in range(data_dim_count, 0, -1) ] @@ -241,7 +240,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): memory (e.g. halo) for the `RebuildMemletsFromContainers`! """ - def __init__(self, backend: str) -> None: + def __init__(self, backend: Backend) -> None: warnings.warn( "CartesianRefineTransients is a WIP. It's usage is *severely* limited " "and will most likely lead to bad numerics. Check the docs, check utest.", @@ -249,17 +248,11 @@ def __init__(self, backend: str) -> None: stacklevel=2, ) - if backend in ["dace:cpu_kfirst"]: - self.ijk_order = (2, 1, 0) - elif backend in ["dace:gpu", "dace:cpu_KJI"]: - self.ijk_order = (0, 1, 2) - elif backend in ["dace:cpu"]: - self.ijk_order = (1, 2, 0) - else: + if not backend.is_orchestrated() or backend.framework != BackendFramework.DACE: raise NotImplementedError( f"[Schedule Tree Opt] CartesianRefineTransient not implemented for backend {backend}" ) - + self.layout_map = backend.as_layout_map() self.refined_array: set[str] = set() def __str__(self) -> str: @@ -292,7 +285,7 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: collect_map.transients_range_reads[name], collect_map.transients_range_writes[name], data, - self.ijk_order, + self.layout_map, ) refined_transient += 1 if refined else 0 diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index ae6b1d80..b57213aa 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -8,11 +8,12 @@ from dace.transformation.helpers import get_parent_map from gt4py.cartesian.gtscript import PARALLEL, computation, interval +import ndsl.xumpy as xp +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import 
CompilationConfig, FrozenStencil, StencilConfig from ndsl.dsl.typing import Float, FloatField from ndsl.logging import ndsl_log -from ndsl.optional_imports import cupy as cp class DaCeProgress: @@ -187,7 +188,7 @@ def copy_kernel(q_in: FloatField, q_out: FloatField) -> None: class MaxBandwidthBenchmarkProgram: - def __init__(self, size: Any, backend: str) -> None: + def __init__(self, size: Any, backend: Backend) -> None: from ndsl.dsl.dace.orchestration import DaCeOrchestration, orchestrate dace_config = DaceConfig( @@ -211,7 +212,7 @@ def __call__(self, A: Any, B: Any, n: int) -> None: def kernel_theoretical_timing( sdfg: dace.sdfg.SDFG, *, - backend: str, + backend: Backend, hardware_bw_in_GB_s: float | None = None, ) -> dict[str, float]: """Compute a lower timing bound for kernels with the following hypothesis: @@ -229,12 +230,8 @@ def kernel_theoretical_timing( f" arrays at {Float} precision..." ) bench = MaxBandwidthBenchmarkProgram(size, backend) - if backend == "dace:gpu": - A = cp.ones(size, dtype=Float) - B = cp.ones(size, dtype=Float) - else: - A = np.ones(size, dtype=Float) - B = np.ones(size, dtype=Float) + A = xp.ones(size, backend, dtype=Float) + B = xp.ones(size, backend, dtype=Float) n = 1000 m = 4 dt = [] @@ -259,7 +256,7 @@ def kernel_theoretical_timing( if hardware_bw_in_GB_s else "Given hardware bandwidth" ) - print(f"{label}: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s") + print(f"{label}: {bandwidth_in_bytes_s / (1024 * 1024 * 1024)} GB/s") allmaps = [ (me, state) @@ -342,7 +339,7 @@ def report_kernel_theoretical_timing( def kernel_theoretical_timing_from_path( sdfg_path: str, - backend: str, + backend: Backend, hardware_bw_in_GB_s: float | None = None, output_format: str | None = None, ) -> str: diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 009ed324..39afa30d 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Callable, Sequence from functools 
import wraps from typing import Any @@ -5,8 +6,8 @@ import numpy as np import numpy.typing as npt from gt4py import storage as gt_storage -from gt4py.cartesian import backend as gt_backend +from ndsl.config.backend import Backend from ndsl.constants import N_HALO_DEFAULT from ndsl.dsl.typing import DTypes, Float from ndsl.logging import ndsl_log @@ -85,7 +86,7 @@ def make_storage_data( shape: tuple[int, ...] | None = None, origin: tuple[int, ...] = origin, *, - backend: str, + backend: Backend, dtype: DTypes = Float, mask: tuple[bool, ...] | None = None, start: tuple[int, ...] = (0, 0, 0), @@ -107,7 +108,7 @@ def make_storage_data( start: Starting points for slices in data copies dummy: Dummy axes axis: Axis for 2D to 3D arrays - backend: gt4py backend to use + backend: current backend in use Returns: Field[..., dtype]: New storage @@ -190,7 +191,7 @@ def make_storage_data( storage = gt_storage.from_array( data, dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=_translate_origin(origin, mask), dimensions=_mask_to_dimensions(mask, data.shape), ) @@ -206,7 +207,7 @@ def _make_storage_data_1d( read_only: bool = True, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) @@ -243,7 +244,7 @@ def _make_storage_data_2d( read_only: bool = True, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: # axis refers to which axis should be repeated (when making a full 3d data), # dummy refers to a singleton axis @@ -277,7 +278,7 @@ def _make_storage_data_3d( start: tuple[int, ...] = (0, 0, 0), *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: istart, jstart, kstart = start isize, jsize, ksize = data.shape @@ -296,7 +297,7 @@ def _make_storage_data_Nd( start: tuple[int, ...] 
| None = None, *, dtype: DTypes = Float, - backend: str, + backend: Backend, ) -> npt.NDArray: if start is None: start = tuple([0] * data.ndim) @@ -310,7 +311,7 @@ def make_storage_from_shape( shape: tuple[int, ...], origin: tuple[int, ...] = origin, *, - backend: str, + backend: Backend, dtype: DTypes = Float, mask: tuple[bool, ...] | None = None, ) -> npt.NDArray: @@ -342,7 +343,7 @@ def make_storage_from_shape( storage = gt_storage.zeros( shape, dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=_translate_origin(origin, mask), dimensions=_mask_to_dimensions(mask, shape), ) @@ -358,7 +359,7 @@ def make_storage_dict( names: list[str] | None = None, axis: int = 2, *, - backend: str, + backend: Backend, dtype: DTypes = Float, ) -> dict[str, npt.NDArray]: assert names is not None, "for 4d variable storages, specify a list of names" @@ -379,7 +380,7 @@ def make_storage_dict( return data_dict -def storage_dict(st_dict, names, shape, origin, *, backend: str): +def storage_dict(st_dict, names, shape, origin, *, backend: Backend): for name in names: st_dict[name] = make_storage_from_shape(shape, origin, backend=backend) @@ -446,35 +447,27 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None): return cp.asarray(array, dtype, order) -def is_gpu_backend(backend: str) -> bool: - return gt_backend.from_name(backend).storage_info["device"] == "gpu" - - -_FORTRAN_LOOP_LAYOUT = (2, 1, 0) -"""Fortran is a column-first (or stride-first) memory system, -which in the internal gt4py loop layout means I (or axis[0]) has -the higher value, e.g. 
"higher importance to run first": - -for k # Layout=0 - for j # Layout=1 - for i # Layout=2 - -""" - - -def backend_is_fortran_aligned(backend: str) -> bool: - """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran - striding.""" +def is_gpu_backend(backend: Backend) -> bool: + warnings.warn( + "Function `gt4py_utils.is_gpu_backend` is deprecated, please use `Backend.is_gpu_backend()`", + category=DeprecationWarning, + stacklevel=2, + ) + return backend.is_gpu_backend() - # Dev NOTE: this is used in interfacing with NDSL (e.g. GEOS.) - return _FORTRAN_LOOP_LAYOUT == gt_backend.from_name(backend).storage_info[ - "layout_map" - ](("I", "J", "K")) +def backend_is_fortran_aligned(backend: Backend) -> bool: + warnings.warn( + "Function `gt4py_utils.backend_is_fortran_aligned` is deprecated " + "please use `Backend.backend_is_fortran_aligned()`", + category=DeprecationWarning, + stacklevel=2, + ) + return backend.is_fortran_aligned() -def zeros(shape, dtype=Float, *, backend: str): - storage_type = cp.ndarray if is_gpu_backend(backend) else np.ndarray +def zeros(shape, dtype=Float, *, backend: Backend): + storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray xp = cp if cp and storage_type is cp.ndarray else np return xp.zeros(shape, dtype=dtype) @@ -543,8 +536,8 @@ def stack(tup, axis: int = 0, out=None): return xp.stack(array_tup, axis, out) -def device_sync(backend: str) -> None: - if cp and is_gpu_backend(backend): +def device_sync(backend: Backend) -> None: + if cp and backend.is_gpu_backend(): cp.cuda.Device(0).synchronize() diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 7cf27a64..df40e59c 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -20,6 +20,7 @@ from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI +from ndsl.config.backend import Backend, BackendFramework from ndsl.constants 
import ( I_DIM, I_DIMS, @@ -182,7 +183,7 @@ def __init__( comm=comm, ) compilation_config = CompilationConfig( - backend="numpy", + backend=Backend.python(), rebuild=stencil_config.compilation_config.rebuild, validate_args=stencil_config.compilation_config.validate_args, format_source=True, @@ -316,7 +317,10 @@ def __init__( # NOTE: this is also down in `dace/build.py` for orchestration # This is still needed for non-orchestrated used of DaCe. # A better build system would take care of BOTH of those at the same time - if "dace" in self.stencil_config.compilation_config.backend: + if ( + BackendFramework.DACE + == self.stencil_config.compilation_config.backend.framework + ): dace.Config.set( "default_build_folder", value="{gt_root}/{gt_cache}/dacecache".format( @@ -986,7 +990,7 @@ def __init__( self.comm = comm @property - def backend(self) -> str: + def backend(self) -> Backend: return self.config.compilation_config.backend def from_origin_domain( diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index f90db27a..43eb08e0 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -6,14 +6,13 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from typing import Any, Self -from gt4py.cartesian.backend import from_name as check_backend_existence from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend, BackendTargetDevice, backend_python from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from ndsl.dsl.gt4py_utils import is_gpu_backend class RunMode(enum.Enum): @@ -32,7 +31,7 @@ class RunMode(enum.Enum): class CompilationConfig: def __init__( self, - backend: str = "numpy", + backend: Backend = backend_python, rebuild: bool = True, validate_args: bool = True, 
format_source: bool = False, @@ -41,10 +40,8 @@ def __init__( use_minimal_caching: bool = False, communicator: Communicator | None = None, ) -> None: - if "gpu" not in backend and device_sync is True: + if backend.device is BackendTargetDevice.CPU and device_sync is True: raise RuntimeError("Device sync is true on a CPU based backend") - # GT4Py backend args - check_backend_existence(backend) self.backend = backend self.rebuild = rebuild self.validate_args = validate_args @@ -145,7 +142,7 @@ def get_decomposition_info_from_comm( def as_dict(self) -> dict[str, Any]: return { - "backend": self.backend, + "backend": self.backend.as_humanly_readable(), "rebuild": self.rebuild, "validate_args": self.validate_args, "format_source": self.format_source, @@ -157,7 +154,9 @@ def as_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict) -> Self: instance = cls( - backend=data.get("backend", "numpy"), + backend=Backend( + data.get("backend", Backend.python().as_humanly_readable()) + ), rebuild=data.get("rebuild", False), validate_args=data.get("validate_args", True), format_source=data.get("format_source", False), @@ -196,7 +195,7 @@ def __init__( else DaceConfig( communicator=None, backend=self.compilation_config.backend, - orchestration=DaCeOrchestration.Python, + orchestration=DaCeOrchestration.Run, ) ) self.backend_opts = { @@ -206,12 +205,12 @@ def __init__( self._hash = self._compute_hash() @property - def backend(self) -> str: + def backend(self) -> Backend: return self.compilation_config.backend def _compute_hash(self) -> int: md5 = hashlib.md5() - md5.update(self.compilation_config.backend.encode()) + md5.update(self.compilation_config.backend.as_gt4py().encode()) for attr in ( self.compilation_config.rebuild, self.compilation_config.validate_args, @@ -239,7 +238,7 @@ def stencil_kwargs( self, *, func: Callable[..., None], skip_passes: Iterable[str] = () ) -> dict: kwargs = { - "backend": self.compilation_config.backend, + "backend": 
self.compilation_config.backend.as_gt4py(), "rebuild": self.compilation_config.rebuild, "name": func.__module__ + "." + func.__name__, **self.backend_opts, @@ -254,7 +253,7 @@ def stencil_kwargs( @property def is_gpu_backend(self) -> bool: - return is_gpu_backend(self.compilation_config.backend) + return self.compilation_config.backend.is_gpu_backend() @classmethod def _get_oir_pipeline(cls, skip_passes: Sequence[str]) -> OirPipeline: diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 919e16c2..dfaf98de 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -9,6 +9,7 @@ from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.communicator import Communicator +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_INTERFACE_DIM, @@ -488,7 +489,7 @@ def from_tile_sizing( npy: int, npz: int, communicator: Communicator, - backend: str, + backend: Backend, grid_type: int = 0, dx_const: float = 1000.0, dy_const: float = 1000.0, diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 398978bc..b4eca405 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -11,7 +11,7 @@ # TODO: if we can remove translate tests in favor of checkpointer tests, # we can remove this "disallowed" import (ndsl.util does not depend on ndsl.dsl) -from ndsl.dsl.gt4py_utils import is_gpu_backend, split_cartesian_into_storages +from ndsl.dsl.gt4py_utils import split_cartesian_into_storages from ndsl.dsl.typing import Float from ndsl.grid.generation import MetricTerms from ndsl.initialization.allocator import QuantityFactory @@ -230,10 +230,9 @@ def ptop(self) -> Float: """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if self.ak.backend is not None and is_gpu_backend(self.ak.backend): + if self.ak.backend is not None and self.ak.backend.is_gpu_backend(): return Float(self.ak.view[0].get()) - else: - return Float(self.ak.view[0]) + return 
Float(self.ak.view[0]) @dataclasses.dataclass(frozen=True) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 6355e192..110a09f1 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -6,6 +6,7 @@ import numpy as np from gt4py import storage as gt_storage +from ndsl.config import Backend from ndsl.constants import SPATIAL_DIMS from ndsl.dsl.typing import Float from ndsl.initialization import GridSizer @@ -13,13 +14,13 @@ class QuantityFactory: - def __init__(self, sizer: GridSizer, *, backend: str) -> None: + def __init__(self, sizer: GridSizer, *, backend: Backend) -> None: """ - Initialize a QuantityFactory from a GridSizer and a GT4Py backend name. + Initialize a QuantityFactory from a GridSizer and a NDSL backend name. Args: sizer: GridSizer object that determines the array sizes. - backend: GT4Py backend name used for performance-optimized allocation. + backend: NDSL backend name used for performance-optimized allocation. 
""" self.sizer = sizer self.backend = backend @@ -199,7 +200,7 @@ def _allocate( dtype=dtype, aligned_index=origin, dimensions=dimensions, - backend=self.backend, + backend=self.backend.as_gt4py(), ) return Quantity( diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index 0c520353..4a257080 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -3,8 +3,8 @@ import ndsl.constants as constants from ndsl.comm.partitioner import TilePartitioner +from ndsl.config import Backend from ndsl.constants import N_HALO_DEFAULT -from ndsl.dsl.gt4py_utils import backend_is_fortran_aligned from ndsl.initialization.grid_sizer import GridSizer @@ -16,11 +16,11 @@ def __init__( nz: int, n_halo: int, data_dimensions: dict[str, int], - backend: str, + backend: Backend, ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) - fortran_style_memory = backend_is_fortran_aligned(backend) + fortran_style_memory = backend.is_fortran_aligned() self._pad_non_interface_dimensions = not fortran_style_memory @classmethod @@ -32,7 +32,7 @@ def from_tile_params( n_halo: int, layout: tuple[int, int], *, - backend: str, + backend: Backend, data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, @@ -45,12 +45,12 @@ def from_tile_params( nz: number of vertical levels n_halo: number of halo points layout: (y, x) number of ranks along tile edges - backend: backend name + backend: current backend in use data_dimensions: lengths of any non-x/y/z dimensions, such as land or radiation dimensions tile_partitioner (optional): partitioner object for the tile. By default, a TilePartitioner is created with the given layout - tile_rank (optional): rank of this subtile. 
+ tile_rank (optional): rank of this subtile """ if data_dimensions is None: data_dimensions = {} @@ -85,18 +85,18 @@ def from_namelist( tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, *, - backend: str, + backend: Backend, ) -> Self: """Create a SubtileGridSizer from a Fortran namelist. Args: namelist: A namelist for the fv3gfs fortran model tile_partitioner (optional): a partitioner to use for segmenting the tile. - By default, a TilePartitioner is used. + By default, a TilePartitioner is used tile_rank (optional): current rank on tile. Default is 0. Only matters if different ranks have different domain shapes. If tile_partitioner - is a TilePartitioner, this argument does not matter. - backend: backend name + is a TilePartitioner, this argument does not matter + backend: current backend in use """ if "fv_core_nml" in namelist.keys(): layout = namelist["fv_core_nml"]["layout"] diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index c6ec0d97..cd091f3f 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -7,6 +7,7 @@ import numpy as np from ndsl.comm.comm_abc import Comm +from ndsl.config import Backend from ndsl.optional_imports import cupy as cp from ndsl.performance.report import ( Report, @@ -29,13 +30,13 @@ def collect_performance(self) -> None: ... def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: ... def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: ... 
@classmethod @@ -77,7 +78,7 @@ def collect_performance(self) -> None: self.timestep_timer.reset() def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: if self.comm.Get_rank() == 0: git_hash = "None" @@ -114,7 +115,7 @@ def write_out_rank_0( def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: @@ -162,13 +163,13 @@ def collect_performance(self) -> None: def write_out_performance( self, - backend: str, + backend: Backend, is_orchestrated: bool, dt_atmos: float, ) -> None: pass def write_out_rank_0( - self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str + self, backend: Backend, is_orchestrated: bool, dt_atmos: float, sim_status: str ) -> None: pass diff --git a/ndsl/performance/report.py b/ndsl/performance/report.py index a1ccd1da..06b48a22 100644 --- a/ndsl/performance/report.py +++ b/ndsl/performance/report.py @@ -8,6 +8,7 @@ import numpy as np from ndsl.comm.comm_abc import Comm +from ndsl.config import Backend @dataclasses.dataclass @@ -41,7 +42,7 @@ def __post_init__(self) -> None: def get_experiment_info( experiment_name: str, time_step: int, - backend: str, + backend: Backend, git_hash: str, is_orchestrated: bool, ) -> Experiment: @@ -52,7 +53,7 @@ def get_experiment_info( git_hash=git_hash, timestamp=datetime.now().strftime("%d/%m/%Y %H:%M:%S"), timesteps=time_step, - backend=f"{orchestration}/{backend}", + backend=f"{orchestration}/{backend.as_safe_for_path()}", ) return experiment @@ -92,7 +93,8 @@ def gather_timing_data( comm.Gather(sendbuf, recvbuf, root=0) if is_root: timing_info[timer_name] = TimeReport( - hits=0, times=copy.deepcopy(recvbuf.tolist()) # type: ignore[union-attr] # (recvbuf is defined on root rank) + hits=0, + times=copy.deepcopy(recvbuf.tolist()), # type: ignore[union-attr] # (recvbuf is defined on root rank) ) 
return timing_info @@ -132,7 +134,7 @@ def get_sypd(timing_info: dict[str, TimeReport], dt_atmos: float) -> float: def collect_data_and_write_to_file( time_step: int, - backend: str, + backend: Backend, is_orchestrated: bool, git_hash: str, comm: Comm, diff --git a/ndsl/performance/tools.py b/ndsl/performance/tools.py index 80ed377f..463f5ecf 100644 --- a/ndsl/performance/tools.py +++ b/ndsl/performance/tools.py @@ -1,5 +1,6 @@ import click +from ndsl.config import Backend from ndsl.dsl.dace.utils import ( kernel_theoretical_timing_from_path, memory_static_analysis_from_path, @@ -41,7 +42,7 @@ "--backend", required=False, type=click.STRING, - default="dace:gpu", + default="st:dace:gpu:KJI", ) def command_line( action: str, @@ -60,7 +61,7 @@ def command_line( print( kernel_theoretical_timing_from_path( sdfg_path, - backend=backend, + backend=Backend(backend), hardware_bw_in_GB_s=( None if hardware_bw_in_gb_s == 0 else hardware_bw_in_gb_s ), diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 79f360c6..8b6b17ee 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -4,6 +4,7 @@ import dace import numpy as np +from ndsl.config import Backend from ndsl.optional_imports import cupy from ndsl.quantity import Quantity @@ -22,7 +23,7 @@ def __init__( dims: Sequence[str], units: str, *, - backend: str, + backend: Backend, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, allow_mismatch_float_precision: bool = False, diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index a8ce919c..409c0ca0 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -5,6 +5,7 @@ import numpy as np +from ndsl.config.backend import Backend from ndsl.optional_imports import cupy from ndsl.types import NumpyModule @@ -29,8 +30,8 @@ class QuantityMetadata: "ndarray-like type used to store the data." dtype: type "dtype of the data in the ndarray-like object." - backend: str - "GT4Py backend name. 
Used for performance optimal data allocation." + backend: Backend + "NDSL backend. Used for performance optimal data allocation." @property def dim_lengths(self) -> dict[str, int]: diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 67d437c2..71fa89d3 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -13,6 +13,7 @@ import ndsl.constants as constants from ndsl.comm.mpi import MPI +from ndsl.config.backend import Backend from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy from ndsl.quantity.bounds import BoundedArrayView @@ -33,7 +34,7 @@ def __init__( dims: Sequence[str], units: str, *, - backend: str, + backend: Backend, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, allow_mismatch_float_precision: bool = False, @@ -45,7 +46,7 @@ def __init__( data: ndarray-like object containing the underlying data dims: dimension names for each axis units: units of the quantity - backend: GT4Py backend name. We ensure that the data is allocated in a + backend: current backend in use. We ensure that the data is allocated in a performance optimal way for that backend and copy if necessary. origin: first point in data within the computational domain. Defaults to None. 
@@ -86,7 +87,7 @@ def __init__( _validate_quantity_property_lengths(data.shape, dims, origin, extent) - gt4py_backend_cls = gt_backend.from_name(backend) + gt4py_backend_cls = gt_backend.from_name(backend.as_gt4py()) is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] device = gt4py_backend_cls.storage_info["device"] @@ -123,7 +124,7 @@ def __init__( self._data = gt_storage.from_array( data, data.dtype, - backend=backend, + backend=backend.as_gt4py(), aligned_index=origin, dimensions=dimensions, ) @@ -151,7 +152,7 @@ def from_data_array( origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, number_of_halo_points: int = 0, - backend: str | None = None, + backend: Backend | None = None, allow_mismatch_float_precision: bool = False, ) -> Quantity: """ @@ -164,7 +165,7 @@ def from_data_array( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we allocate data in a performance + backend: current backend in use. If given, we allocate data in a performance optimal way for this backend. Overrides any potentially saved `backend` in `data.attrs["backend"]`. 
""" @@ -248,12 +249,14 @@ def units(self) -> str: return self.metadata.units @property - def backend(self) -> str: + def backend(self) -> Backend: return self.metadata.backend @property def attrs(self) -> dict: - return dict(**self._attrs, units=self.units, backend=self.backend) + return dict( + **self._attrs, units=self.units, backend=self.backend.as_humanly_readable() + ) @property def dims(self) -> tuple[str, ...]: @@ -486,14 +489,18 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: return tuple(return_list) -def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: +def _resolve_backend(data: xr.DataArray, backend: Backend | None) -> Backend: if backend is not None: # Forced backend name takes precedence return backend # If backend name was serialized with data, take this one if "backend" in data.attrs: - return data.attrs["backend"] + if not isinstance(data.attrs["backend"], str): + raise ValueError( + f"Quantity.attrs 'backend' must be a string, found {data.attrs['backend']}" + ) + return Backend(data.attrs["backend"]) # else, fall back to assume python-based layout. 
- return "debug" + return Backend.python() diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 1e5f41f7..213c025e 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -17,6 +17,7 @@ ) from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.stencils.testing.grid import Grid # type: ignore from ndsl.stencils.testing.parallel_translate import ParallelTranslate @@ -33,7 +34,7 @@ def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( "--backend", action="store", - default="numpy", + default="st:python:cpu:numpy", help="Backend to execute the test with, can only be one.", ) parser.addoption( @@ -232,7 +233,7 @@ def get_savepoint_restriction(metafunc: Any) -> int | None: return int(svpt) if svpt else None -def get_config(backend: str, communicator: Communicator | None) -> StencilConfig: +def get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig: stencil_config = StencilConfig( compilation_config=CompilationConfig( backend=backend, rebuild=False, validate_args=True @@ -248,10 +249,11 @@ def get_config(backend: str, communicator: Communicator | None) -> StencilConfig def sequential_savepoint_cases( metafunc: Any, data_path: Path, namelist_filename: Path, *, backend: str ) -> list[SavepointCase]: + ndsl_backend = Backend(backend) savepoint_names = get_sequential_savepoint_names(metafunc, data_path) namelist = load_f90nml(namelist_filename) grid_params = grid_params_from_f90nml(namelist) - stencil_config = get_config(backend, None) + stencil_config = get_config(ndsl_backend, None) ranks = get_ranks(metafunc, grid_params["layout"]) savepoint_to_replay = get_savepoint_restriction(metafunc) grid_mode = metafunc.config.getoption("grid") @@ -265,7 +267,7 @@ def sequential_savepoint_cases( savepoint_to_replay, 
stencil_config, namelist, - backend, + ndsl_backend, data_path, grid_mode, topology_mode, @@ -280,7 +282,7 @@ def _savepoint_cases( savepoint_to_replay: int | None, stencil_config: StencilConfig, namelist: Namelist, - backend: str, + backend: Backend, data_path: Path, grid_mode: str, topology_mode: str, @@ -349,7 +351,7 @@ def _savepoint_cases( def compute_grid_data( grid: Grid, grid_params: dict, - backend: str, + backend: Backend, layout: tuple[int, int], topology_mode: str, ) -> None: @@ -371,13 +373,14 @@ def parallel_savepoint_cases( backend: str, comm: Comm, ) -> list[SavepointCase]: + ndsl_backend = Backend(backend) namelist = load_f90nml(namelist_filename) grid_params = grid_params_from_f90nml(namelist) topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") communicator = get_communicator(comm, grid_params["layout"], topology_mode) - stencil_config = get_config(backend, communicator) + stencil_config = get_config(ndsl_backend, communicator) savepoint_names = get_parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = get_savepoint_restriction(metafunc) @@ -388,7 +391,7 @@ def parallel_savepoint_cases( savepoint_to_replay, stencil_config, namelist, - backend, + ndsl_backend, data_path, grid_mode, topology_mode, diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 12bbbe8a..26de91c5 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -3,6 +3,7 @@ import numpy as np from ndsl.comm.partitioner import TilePartitioner +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM, N_HALO_DEFAULT from ndsl.dsl import gt4py_utils as utils from ndsl.dsl.stencil import GridIndexing @@ -35,7 +36,7 @@ class Grid: # grid.ie == npx + halo - 2 @classmethod - def _make(cls, npx, npy, npz, layout, rank, backend): + def _make(cls, npx, npy, 
npz, layout, rank, backend: Backend): shape_params = { "npx": npx, "npy": npy, @@ -59,13 +60,13 @@ def _make(cls, npx, npy, npz, layout, rank, backend): return cls(indices, shape_params, rank, layout, backend, local_indices=True) @classmethod - def from_namelist(cls, namelist, rank, backend): + def from_namelist(cls, namelist, rank, backend: Backend): return cls._make( namelist.npx, namelist.npy, namelist.npz, namelist.layout, rank, backend ) @classmethod - def with_data_from_namelist(cls, namelist, communicator, backend): + def with_data_from_namelist(cls, namelist, communicator, backend: Backend): grid = cls.from_namelist(namelist, communicator.rank, backend) grid.make_grid_data( npx=namelist.npx, @@ -82,7 +83,7 @@ def __init__( shape_params, rank, layout, - backend, + backend: Backend, data_fields: dict | None = None, local_indices=False, ): @@ -838,7 +839,7 @@ def driver_grid_data(self) -> DriverGridData: def set_grid_data(self, grid_data: GridData): self._grid_data = grid_data - def make_grid_data(self, npx, npy, npz, communicator, backend): + def make_grid_data(self, npx, npy, npz, communicator, backend: Backend): metric_terms = MetricTerms.from_tile_sizing( npx=npx, npy=npy, npz=npz, communicator=communicator, backend=backend ) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index bb3bfe6c..c7f7fa22 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -9,6 +9,7 @@ from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.config import Backend from ndsl.dsl import gt4py_utils as gt_utils from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig @@ -31,7 +32,7 @@ def platform(): return "docker" if in_docker else "metal" -def process_override(threshold_overrides, testobj, 
test_name, backend): +def process_override(threshold_overrides, testobj, test_name, backend: Backend): override = threshold_overrides.get(test_name, None) if override is not None: for spec in override: @@ -145,7 +146,7 @@ def _get_thresholds(compute_function, input_data) -> None: ) def test_sequential_savepoint( case: SavepointCase, - backend, + backend: str, print_failures, failure_stride, subtests, @@ -158,11 +159,12 @@ def test_sequential_savepoint( raise ValueError( f"No translate object available for savepoint {case.savepoint_name}." ) + ndsl_backend = Backend(backend) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend=backend), + compilation_config=CompilationConfig(backend=ndsl_backend), dace_config=DaceConfig( communicator=None, - backend=backend, + backend=ndsl_backend, ), ) # Reduce error threshold for GPU @@ -171,7 +173,7 @@ def test_sequential_savepoint( case.testobj.near_zero = max(case.testobj.near_zero, GPU_NEAR_ZERO) if threshold_overrides is not None: process_override( - threshold_overrides, case.testobj, case.savepoint_name, backend + threshold_overrides, case.testobj, case.savepoint_name, ndsl_backend ) if case.testobj.skip_test: return @@ -330,7 +332,7 @@ def get_tile_communicator(comm, layout): ) def test_parallel_savepoint( case: SavepointCase, - backend, + backend: str, print_failures, failure_stride, subtests, @@ -340,6 +342,7 @@ def test_parallel_savepoint( multimodal_metric, xy_indices=True, ): + ndsl_backend = Backend(backend) mpi_comm = MPIComm() if mpi_comm.Get_size() % 6 != 0: layout = ( @@ -358,10 +361,10 @@ def test_parallel_savepoint( f"No translate object available for savepoint {case.savepoint_name}" ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend=backend), + compilation_config=CompilationConfig(backend=ndsl_backend), dace_config=DaceConfig( communicator=communicator, - backend=backend, + backend=ndsl_backend, ), ) # Increase minimum error threshold for GPU @@ -370,7 
+373,7 @@ def test_parallel_savepoint( case.testobj.near_zero = max(case.testobj.near_zero, GPU_NEAR_ZERO) if threshold_overrides is not None: process_override( - threshold_overrides, case.testobj, case.savepoint_name, backend + threshold_overrides, case.testobj, case.savepoint_name, ndsl_backend ) if case.testobj.skip_test: return diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 8cf78282..45d4a8c2 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -5,6 +5,7 @@ import numpy.typing as npt import ndsl.dsl.gt4py_utils as utils +from ndsl.config import Backend from ndsl.dsl.stencil import StencilFactory from ndsl.optional_imports import cupy as cp from ndsl.quantity import Quantity @@ -22,7 +23,7 @@ def read_serialized_data(serializer, savepoint, variable): return data -def pad_field_in_j(field, nj: int, backend: str): +def pad_field_in_j(field, nj: int, backend: Backend): utils.device_sync(backend) outfield = utils.tile(field[:, 0, :], (nj, 1, 1)).transpose(1, 0, 2) return outfield @@ -358,7 +359,7 @@ class TranslateGrid: # 6---2---7 @classmethod - def new_from_serialized_data(cls, serializer, rank, layout, backend): + def new_from_serialized_data(cls, serializer, rank, layout, backend: Backend): grid_savepoint = serializer.get_savepoint("Grid-Info")[0] grid_data = {} grid_fields = serializer.fields_at_savepoint(grid_savepoint) @@ -366,7 +367,7 @@ def new_from_serialized_data(cls, serializer, rank, layout, backend): grid_data[field] = read_serialized_data(serializer, grid_savepoint, field) return cls(grid_data, rank, layout, backend=backend) - def __init__(self, inputs, rank, layout, *, backend: str): + def __init__(self, inputs, rank, layout, *, backend: Backend): self.backend = backend self.indices = {} self.shape_params = {} diff --git a/ndsl/xumpy/alloc.py b/ndsl/xumpy/alloc.py index a6a98804..dc3fd0b2 100644 --- a/ndsl/xumpy/alloc.py +++ b/ndsl/xumpy/alloc.py @@ -1,7 +1,9 @@ +from 
typing import Sequence, SupportsIndex + import numpy as np import numpy.typing as npt -from ndsl.dsl.gt4py_utils import is_gpu_backend +from ndsl.config import Backend from ndsl.dsl.typing import Float from ndsl.optional_imports import cupy as cp @@ -9,53 +11,56 @@ if cp is None: cp = np +# Taking a page from cupy's playbook to have tuple & ndarray +_ShapeLike = SupportsIndex | Sequence[SupportsIndex] + def zeros( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.zeros(shape, dtype=dtype) return np.zeros(shape, dtype=dtype) def ones( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.ones(shape, dtype=dtype) return np.ones(shape, dtype=dtype) def empty( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.empty(shape, dtype=dtype) return np.empty(shape, dtype=dtype) def full( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, value: np.ScalarType, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): return cp.full(shape, value, dtype=dtype) return np.full(shape, value, dtype=dtype) def random( - shape: tuple[int, ...], - backend: str, + shape: _ShapeLike, + backend: Backend, dtype: npt.DTypeLike = Float, ) -> np.ndarray | cp.ndarray: - if is_gpu_backend(backend): + if backend.is_gpu_backend(): cp.random.rand(*shape) return np.random.rand(*shape) diff --git a/tests/config/test_backend.py b/tests/config/test_backend.py new file mode 100644 index 00000000..a632b2c9 --- /dev/null +++ 
b/tests/config/test_backend.py @@ -0,0 +1,31 @@ +import pytest + +from ndsl import Backend + + +def test_backend_building(): + Backend("st:python:cpu:IJK") + Backend("st:numpy:cpu:IJK") + Backend("st:gt:cpu:IJK") + Backend("st:gt:cpu:KJI") + Backend("st:gt:gpu:KJI") + Backend("st:dace:cpu:IJK") + Backend("orch:dace:cpu:IJK") + Backend("st:dace:cpu:KIJ") + Backend("orch:dace:cpu:KIJ") + Backend("st:dace:cpu:KJI") + Backend("orch:dace:cpu:KJI") + Backend("st:dace:gpu:KJI") + Backend("orch:dace:gpu:KJI") + + unknown_backend = "bad:name:good:number" + with pytest.raises(ValueError, match=f"Unknown {unknown_backend}, options are .*"): + Backend(unknown_backend) + + +def test_backend_operators(): + backend_A = Backend("st:numpy:cpu:IJK") + backend_B = Backend("st:numpy:cpu:IJK") + + assert backend_A == backend_B + assert not (backend_A != backend_B) diff --git a/tests/conftest.py b/tests/conftest.py index 5cb1c5ff..76a4d89a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from ndsl.config import Backend from ndsl.optional_imports import cupy @@ -13,18 +14,18 @@ def backend(request): @pytest.fixture -def gt4py_backend(backend): +def ndsl_backend(backend: str): if backend == "numpy": - return "numpy" + return Backend("st:numpy:cpu:IJK") if backend == "cupy": - return "gt:gpu" + return Backend("st:dace:gpu:KJI") - return None + raise ValueError(f"Test backend {backend} cannot be translated into Backend") @pytest.fixture -def numpy(backend): +def numpy(backend: str): if backend == "numpy": return np diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index ca715c3d..b2558860 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -5,6 +5,7 @@ from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import 
Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField @@ -134,7 +135,7 @@ class TestStreeMergeMapsIJK: def factories(self) -> Factories: domain = (3, 3, 4) return get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:IJK") ) @pytest.fixture @@ -257,7 +258,7 @@ class TestStreeMergeMapsKJI: def factories(self) -> Factories: domain = (3, 3, 4) return get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI" + domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:KJI") ) @pytest.fixture diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index 17480dda..677790bc 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -1,5 +1,6 @@ from ndsl import StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -29,7 +30,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField): def test_stree_roundtrip_no_opt(): domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) code = TriviallyMergeableCode(stencil_factory) diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index ffaee9db..9795957a 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ 
b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -1,5 +1,6 @@ from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval from ndsl.dsl.typing import Float, FloatField @@ -92,7 +93,7 @@ def do_not_refine_datadims(self, in_field: Quantity, out_field: Quantity) -> Non def test_stree_roundtrip_transient_is_refined() -> None: domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") diff --git a/tests/dsl/orchestration/test_call.py b/tests/dsl/orchestration/test_call.py index 2c902138..11e4c757 100644 --- a/tests/dsl/orchestration/test_call.py +++ b/tests/dsl/orchestration/test_call.py @@ -2,6 +2,7 @@ from ndsl import NDSLRuntime, StencilFactory from ndsl.boilerplate import get_factories_single_tile_orchestrated +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM, Float from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval @@ -76,7 +77,7 @@ def test_default_types_are_compiletime(): def test_dace_call_argument_caching(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - 5, 5, 2, 0, backend="dace:cpu_kfirst" + 5, 5, 2, 0, backend=Backend.cpu() ) dconfig = stencil_factory.config.dace_config diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index 9c697143..5eaba654 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -16,6 +16,7 @@ StencilFactory, ) from ndsl.comm.mpi import MPI +from ndsl.config import Backend from ndsl.dsl.dace.orchestration 
import orchestrate from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval from ndsl.dsl.stencil import CompareToNumpyStencil, FrozenStencil @@ -28,7 +29,7 @@ def _stencil(inp: Field[float], out: Field[float]): def _build_stencil( - backend: str, orchestrated: DaCeOrchestration + backend: Backend, orchestrated: DaCeOrchestration | None ) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: # Make stencil and verify it ran grid_indexing = GridIndexing( @@ -55,7 +56,7 @@ def _build_stencil( class OrchestratedProgram: - def __init__(self, backend, orchestration: DaCeOrchestration): + def __init__(self, backend: Backend, orchestration: DaCeOrchestration | None): self.stencil, grid_indexing, stencil_config = _build_stencil( backend, orchestration ) @@ -69,7 +70,7 @@ def __call__(self): def test_relocatability_orchestration() -> None: # Compile on default - p0 = OrchestratedProgram("dace:cpu", DaCeOrchestration.BuildAndRun) + p0 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.BuildAndRun) p0() expected_cache_dir = ( @@ -85,8 +86,7 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: gt_config.cache_settings["root_path"] = tmpdir # Compile in temporary directory that is only available in this test session. 
- backend = "dace:cpu" - p1 = OrchestratedProgram(backend, DaCeOrchestration.BuildAndRun) + p1 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.BuildAndRun) p1() expected_cache_dir = ( @@ -102,14 +102,14 @@ def test_relocatability_orchestration_tmpdir(tmpdir) -> None: relocated_path = tmpdir / ".my_relocated_cache_path" shutil.copytree(tmpdir, relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path - p2 = OrchestratedProgram(backend, DaCeOrchestration.Run) + p2 = OrchestratedProgram(Backend.cpu(), DaCeOrchestration.Run) p2() # Generate a file exists error to check for bad path bogus_path = "./nope/not_at_all/not_happening" gt_config.cache_settings["root_path"] = bogus_path with pytest.raises(RuntimeError): - OrchestratedProgram(backend, DaCeOrchestration.Run) + OrchestratedProgram(Backend.cpu(), DaCeOrchestration.Run) def test_relocatability() -> None: @@ -119,11 +119,11 @@ def test_relocatability() -> None: gt_config.cache_settings["root_path"] = Path.cwd() # Compile on default - backend = "dace:cpu" - p0 = OrchestratedProgram(backend, DaCeOrchestration.Python) + backend = Backend("st:dace:cpu:KIJ") + p0 = OrchestratedProgram(backend, None) p0() - backend_sanitized = backend.replace(":", "") + backend_sanitized = backend.as_gt4py().replace(":", "") python_version = f"py{sys.version_info[0]}{sys.version_info[1]}" expected_cache_path = ( Path.cwd() @@ -146,7 +146,7 @@ def test_relocatability_tmpdir(tmpdir) -> None: # Compile in another directory backend = "dace:cpu" - p1 = OrchestratedProgram(backend, DaCeOrchestration.Python) + p1 = OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) p1() backend_sanitized = backend.replace(":", "") @@ -169,7 +169,7 @@ def test_relocatability_tmpdir(tmpdir) -> None: shutil.copytree(tmpdir / ".gt_cache_000000", relocated_path, dirs_exist_ok=False) gt_config.cache_settings["root_path"] = relocated_path - p2 = OrchestratedProgram(backend, DaCeOrchestration.Python) + p2 = 
OrchestratedProgram(Backend("st:dace:cpu:KIJ"), None) p2() relocated_cache_path = ( diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 9f9c165c..3ce99399 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -11,13 +11,15 @@ RunMode, TilePartitioner, ) +from ndsl.config import Backend -def test_safety_checks(): - with pytest.raises(RuntimeError): - CompilationConfig(backend="numpy", device_sync=True) - with pytest.raises(RuntimeError): - CompilationConfig(backend="gt:cpu_ifirst", device_sync=True) +@pytest.mark.parametrize("backend", (Backend.python(), Backend("st:gt:cpu:KJI"))) +def test_safety_checks(backend: Backend) -> None: + with pytest.raises( + RuntimeError, match="Device sync is true on a CPU based backend" + ): + CompilationConfig(backend=backend, device_sync=True) @pytest.mark.parametrize( @@ -141,7 +143,7 @@ def test_determine_compiling_equivalent( def test_as_dict() -> None: config = CompilationConfig() asdict = config.as_dict() - assert asdict["backend"] == "numpy" + assert asdict["backend"] == Backend.python() assert asdict["rebuild"] is True assert asdict["validate_args"] is True assert asdict["format_source"] is False @@ -154,7 +156,7 @@ def test_as_dict() -> None: def test_from_dict() -> None: specification_dict = {} config = CompilationConfig.from_dict(specification_dict) - assert config.backend == "numpy" + assert config.backend == Backend.python() assert config.rebuild is False assert config.validate_args is True assert config.format_source is False @@ -162,9 +164,9 @@ def test_from_dict() -> None: assert config.run_mode == RunMode.BuildAndRun assert config.use_minimal_caching is False - specification_dict["backend"] = "gt:gpu" + specification_dict["backend"] = "st:gt:gpu:KJI" config = CompilationConfig.from_dict(specification_dict) - assert config.backend == "gt:gpu" + assert config.backend == Backend("st:gt:gpu:KJI") specification_dict["rebuild"] = True 
config = CompilationConfig.from_dict(specification_dict) diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index e89b69c2..8003937f 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -2,6 +2,7 @@ from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner from ndsl.comm.partitioner import Partitioner +from ndsl.config import Backend from ndsl.dsl.dace.dace_config import _determine_compiling_ranks from ndsl.dsl.dace.orchestration import orchestrate, orchestrate_function @@ -18,7 +19,7 @@ def foo() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) wrapped = orchestrate_function(config=dace_config)(foo) @@ -36,8 +37,8 @@ def foo() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", - orchestration=DaCeOrchestration.Python, + backend=Backend("st:dace:cpu:KIJ"), + orchestration=None, ) wrapped = orchestrate_function(config=dace_config)(foo) with unittest.mock.patch( @@ -50,7 +51,7 @@ def foo() -> None: def test_orchestrate_calls_dace() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) @@ -72,8 +73,8 @@ def foo(self) -> None: def test_orchestrate_does_not_call_dace() -> None: dace_config = DaceConfig( communicator=None, - backend="gtc:dace", - orchestration=DaCeOrchestration.Python, + backend=Backend("st:dace:cpu:KIJ"), + orchestration=None, ) class A: @@ -94,7 +95,7 @@ def foo(self) -> None: def test_orchestrate_distributed_build() -> None: dummy_dace_config = DaceConfig( communicator=None, - backend="gtc:dace", + backend=Backend("orch:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun, ) diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index 18aac43b..a8fbdc78 100644 --- a/tests/dsl/test_skip_passes.py +++ 
b/tests/dsl/test_skip_passes.py @@ -13,6 +13,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -24,7 +25,7 @@ def stencil_definition(a: FloatField): def test_skip_passes_becomes_oir_pipeline() -> None: - backend = "numpy" + backend = Backend.python() dace_config = DaceConfig(None, backend) config = StencilConfig( compilation_config=CompilationConfig(backend=backend), dace_config=dace_config diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index edfab60a..83866df7 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -11,6 +11,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_INTERFACE_DIM, @@ -44,7 +45,7 @@ def test_timing_collector() -> None: east_edge=True, ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend="numpy", rebuild=True) + compilation_config=CompilationConfig(Backend.python(), rebuild=True) ) stencil_factory = StencilFactory(stencil_config, grid_indexing) @@ -117,7 +118,7 @@ def test_domain_size_comparison( call_count: int, ): quantity = Quantity( - np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug" + np.zeros(extent), dimensions, "n/a", extent=extent, backend=Backend.python() ) stencil = FrozenStencil( copy_stencil, @@ -158,7 +159,11 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None: def test_stencil_2D_temporaries() -> None: domain = (2, 2, 5) quantity = Quantity( - np.zeros(domain), [I_DIM, J_DIM, K_DIM], "n/a", extent=domain, backend="debug" + np.zeros(domain), + [I_DIM, J_DIM, K_DIM], + "n/a", + extent=domain, + backend=Backend.python(), ) stencil = FrozenStencil( two_dim_temporaries_stencil, @@ -177,10 +182,14 @@ def test_stencil_2D_temporaries() -> None: def test_validation_call_count(iterations: tuple[int]): domain = (2, 2, 5) 
quantity = Quantity( - np.zeros(domain), [I_DIM, J_DIM, K_DIM], "n/a", extent=domain, backend="debug" + np.zeros(domain), + [I_DIM, J_DIM, K_DIM], + "n/a", + extent=domain, + backend=Backend.python(), ) stencil_config = StencilConfig( - compilation_config=CompilationConfig(backend="numpy", rebuild=True) + compilation_config=CompilationConfig(Backend.python(), rebuild=True) ) stencil = FrozenStencil( copy_stencil, diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 0124dd0f..9351724b 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,15 +1,16 @@ import pytest from ndsl import CompilationConfig, DaceConfig, StencilConfig +from ndsl.config.backend import Backend, backend_python @pytest.mark.parametrize("validate_args", [True, False]) @pytest.mark.parametrize("rebuild", [True, False]) @pytest.mark.parametrize("format_source", [True, False]) @pytest.mark.parametrize("compare_to_numpy", [True, False]) -@pytest.mark.parametrize("backend", ["numpy", "gt:gpu"]) +@pytest.mark.parametrize("backend", [Backend.python(), Backend("st:gt:gpu:KJI")]) def test_same_config_equal( - backend: str, + backend: Backend, rebuild: bool, validate_args: bool, format_source: bool, @@ -45,14 +46,14 @@ def test_same_config_equal( assert config == same_config -def test_different_backend_not_equal() -> None: - backend: str = "numpy" - rebuild: bool = True - validate_args: bool = True - format_source: bool = True - device_sync: bool = False - compare_to_numpy: bool = True - +def test_different_backend_not_equal( + backend: Backend = backend_python, + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( communicator=None, backend=backend, @@ -71,7 +72,7 @@ def test_different_backend_not_equal() -> None: different_config = StencilConfig( compilation_config=CompilationConfig( - backend="debug", + 
Backend("st:numpy:cpu:IJK"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -83,14 +84,14 @@ def test_different_backend_not_equal() -> None: assert config != different_config -def test_different_rebuild_not_equal() -> None: - backend: str = "numpy" - rebuild: bool = True - validate_args: bool = True - format_source: bool = True - device_sync: bool = False - compare_to_numpy: bool = True - +def test_different_rebuild_not_equal( + backend: Backend = backend_python, + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( communicator=None, backend=backend, @@ -130,11 +131,11 @@ def test_different_device_sync_not_equal() -> None: dace_config = DaceConfig( communicator=None, - backend="gt:gpu", + backend=Backend("st:gt:gpu:KJI"), ) config = StencilConfig( compilation_config=CompilationConfig( - backend="gt:gpu", + Backend("st:gt:gpu:KJI"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -146,7 +147,7 @@ def test_different_device_sync_not_equal() -> None: different_config = StencilConfig( compilation_config=CompilationConfig( - backend="gt:gpu", + Backend("st:gt:gpu:KJI"), rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -158,14 +159,14 @@ def test_different_device_sync_not_equal() -> None: assert config != different_config -def test_different_validate_args_not_equal() -> None: - backend: str = "numpy" - rebuild: bool = True - validate_args: bool = True - format_source: bool = True - device_sync: bool = False - compare_to_numpy: bool = True - +def test_different_validate_args_not_equal( + backend: Backend = backend_python, + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig( None, backend, @@ -196,14 +197,14 @@ def 
test_different_validate_args_not_equal() -> None: assert config != different_config -def test_different_format_source_not_equal() -> None: - backend: str = "numpy" - rebuild: bool = True - validate_args: bool = True - format_source: bool = True - device_sync: bool = False - compare_to_numpy: bool = True - +def test_different_format_source_not_equal( + backend: Backend = backend_python, + rebuild: bool = True, + validate_args: bool = True, + format_source: bool = True, + device_sync: bool = False, + compare_to_numpy: bool = True, +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -232,13 +233,14 @@ def test_different_format_source_not_equal() -> None: @pytest.mark.parametrize("compare_to_numpy", [True, False]) -def test_different_compare_to_numpy_not_equal(compare_to_numpy: bool) -> None: - backend: str = "numpy" - device_sync: bool = False - format_source: bool = True - rebuild: bool = True - validate_args: bool = False - +def test_different_compare_to_numpy_not_equal( + compare_to_numpy: bool, + backend: Backend = backend_python, + device_sync: bool = False, + format_source: bool = True, + rebuild: bool = True, + validate_args: bool = False, +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 394caccc..379fcb99 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -9,6 +9,7 @@ StencilConfig, StencilFactory, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, horizontal, interval, region from ndsl.dsl.gt4py_utils import make_storage_from_shape @@ -16,7 +17,7 @@ from ndsl.dsl.typing import Field, FloatField -BACKENDS = ["numpy", "dace:cpu"] +BACKENDS = [Backend.python(), Backend("st:dace:cpu:KIJ")] 
def copy_stencil(q_in: FloatField, q_out: FloatField): @@ -39,7 +40,7 @@ def add_1_in_region_stencil(q_in: FloatField, q_out: FloatField): q_out = q_in + 1.0 -def setup_data_vars(backend: str) -> tuple[Field, Field]: +def setup_data_vars(backend: Backend) -> tuple[Field, Field]: shape = (7, 7, 3) q = make_storage_from_shape(shape, backend=backend) q[:] = 1.0 @@ -48,7 +49,7 @@ def setup_data_vars(backend: str) -> tuple[Field, Field]: return q, q_ref -def get_stencil_factory(backend: str) -> StencilFactory: +def get_stencil_factory(backend: Backend) -> StencilFactory: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( @@ -72,7 +73,7 @@ def get_stencil_factory(backend: str) -> StencilFactory: @pytest.mark.parametrize("backend", BACKENDS) -def test_get_stencils_with_varied_bounds(backend: str) -> None: +def test_get_stencils_with_varied_bounds(backend: Backend) -> None: origins = [(2, 2, 0), (1, 1, 0)] domains = [(1, 1, 3), (2, 2, 3)] factory = get_stencil_factory(backend) @@ -92,7 +93,7 @@ def test_get_stencils_with_varied_bounds(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) -def test_get_stencils_with_varied_bounds_and_regions(backend: str) -> None: +def test_get_stencils_with_varied_bounds_and_regions(backend: Backend) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 0)] domains = [(1, 1, 3), (2, 2, 3)] @@ -113,7 +114,7 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) -def test_stencil_vertical_bounds(backend: str) -> None: +def test_stencil_vertical_bounds(backend: Backend) -> None: factory = get_stencil_factory(backend) origins = [(3, 3, 0), (2, 2, 1)] domains = [(1, 1, 3), (2, 2, 4)] @@ -133,7 +134,7 @@ def test_stencil_vertical_bounds(backend: str) -> None: @pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) def 
test_stencil_factory_numpy_comparison_from_dims_halo( - backend: str, enabled: bool + backend: Backend, enabled: bool ) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( @@ -170,7 +171,7 @@ def test_stencil_factory_numpy_comparison_from_dims_halo( @pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("enabled", [True, False]) def test_stencil_factory_numpy_comparison_from_origin_domain( - backend: str, enabled: bool + backend: Backend, enabled: bool ) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( @@ -203,7 +204,9 @@ def test_stencil_factory_numpy_comparison_from_origin_domain( @pytest.mark.parametrize("backend", BACKENDS) -def test_stencil_factory_numpy_comparison_runs_without_exceptions(backend: str) -> None: +def test_stencil_factory_numpy_comparison_runs_without_exceptions( + backend: Backend, +) -> None: dace_config = DaceConfig(communicator=None, backend=backend) config = StencilConfig( compilation_config=CompilationConfig( diff --git a/tests/dsl/test_stencil_tables.py b/tests/dsl/test_stencil_tables.py index 0b31826a..e0cafba8 100644 --- a/tests/dsl/test_stencil_tables.py +++ b/tests/dsl/test_stencil_tables.py @@ -1,5 +1,5 @@ +import gt4py.storage as gt_storage import numpy as np -from gt4py.storage import ones, zeros from ndsl import ( CompilationConfig, @@ -11,6 +11,7 @@ StencilFactory, orchestrate, ) +from ndsl.config import Backend from ndsl.dsl.gt4py import FORWARD, PARALLEL, Field, GlobalTable, computation, interval from ndsl.dsl.stencil import CompareToNumpyStencil from tests.dsl import utils @@ -24,7 +25,7 @@ def _stencil(inp: GlobalTable[np.int32, (5,)], out: Field[np.float64]) -> None: def _build_stencil( - backend: str, orchestrated: DaCeOrchestration + backend: Backend, orchestrated: DaCeOrchestration ) -> tuple[FrozenStencil | CompareToNumpyStencil, GridIndexing, StencilConfig]: # Make stencil and verify it ran grid_indexing = 
GridIndexing( @@ -51,15 +52,19 @@ def _build_stencil( class OrchestratedProgram: - def __init__(self, backend, orchestration: DaCeOrchestration): + def __init__(self, backend: Backend, orchestration: DaCeOrchestration): self.stencil, grid_indexing, stencil_config = _build_stencil( backend, orchestration ) orchestrate(obj=self, config=stencil_config.dace_config) - self.inp = ones(shape=(5,), dtype=np.int32, backend=backend) + self.inp = gt_storage.ones( + shape=(5,), dtype=np.int32, backend=backend.as_gt4py() + ) self.inp[1] = 42 - self.out = utils.make_storage(zeros, grid_indexing, stencil_config, dtype=float) + self.out = utils.make_storage( + gt_storage.zeros, grid_indexing, stencil_config, dtype=float + ) def __call__(self): self.stencil(self.inp, self.out) @@ -67,7 +72,7 @@ def __call__(self): def test_stecil_with_table_orchestrated() -> None: program = OrchestratedProgram( - backend="dace:cpu", orchestration=DaCeOrchestration.BuildAndRun + Backend("st:dace:cpu:KIJ"), orchestration=DaCeOrchestration.BuildAndRun ) # run the orchestrated stencil diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index bad1a1c7..80070e7e 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -12,6 +12,7 @@ Quantity, StencilConfig, ) +from ndsl.config.backend import Backend from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape from ndsl.dsl.stencil import _convert_quantities_to_storage @@ -36,8 +37,8 @@ def get_stencil_config( *, - backend: str, - orchestration: DaCeOrchestration = DaCeOrchestration.Python, + backend: Backend, + orchestration: DaCeOrchestration = DaCeOrchestration.BuildAndRun, **kwargs, ) -> StencilConfig: dace_config = DaceConfig(None, backend=backend, orchestration=orchestration) @@ -81,19 +82,19 @@ def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] 
= ( "field_info, origin, field_origins", [ pytest.param( - {"a": MockFieldInfo(axes=("I"))}, + {"a": MockFieldInfo(axes=("I",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (1,)}, id="single_field_I", ), pytest.param( - {"a": MockFieldInfo(axes=("J"))}, + {"a": MockFieldInfo(axes=("J",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (2,)}, id="single_field_J", ), pytest.param( - {"a": MockFieldInfo(axes=("K"))}, + {"a": MockFieldInfo(axes=("K",))}, (1, 2, 3), {"_all_": (1, 2, 3), "a": (3,)}, id="single_field_K", @@ -111,7 +112,7 @@ def __init__(self, *, axes: tuple[str, ...] = (), data_dims: tuple[int, ...] = ( id="single_field_origin_mapping", ), pytest.param( - {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I"))}, + {"a": MockFieldInfo(axes=("I", "J", "K")), "b": MockFieldInfo(axes=("I",))}, {"_all_": (1, 2, 3), "a": (1, 2, 3)}, {"_all_": (1, 2, 3), "a": (1, 2, 3), "b": (1,)}, id="two_fields_update_origin_mapping", @@ -154,10 +155,10 @@ def copy_stencil(q_in: FloatField, q_out: FloatField): @pytest.mark.parametrize("validate_args", [True, False]) def test_copy_frozen_stencil(validate_args: bool) -> None: - backend: str = "numpy" - rebuild: bool = False - format_source: bool = False - device_sync: bool = False + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False config = get_stencil_config( backend=backend, @@ -182,10 +183,10 @@ def test_copy_frozen_stencil(validate_args: bool) -> None: def test_frozen_stencil_raises_if_given_origin() -> None: - backend: str = "numpy" - rebuild: bool = False - format_source: bool = False - device_sync: bool = False + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False # only guaranteed when validating args config = get_stencil_config( @@ -209,10 +210,10 @@ def test_frozen_stencil_raises_if_given_origin() -> None: def test_frozen_stencil_raises_if_given_domain() -> None: - backend: str = "numpy" - rebuild: bool = False - format_source: bool = 
False - device_sync: bool = False + backend = Backend.python() + rebuild = False + format_source = False + device_sync = False # only guaranteed when validating args config = get_stencil_config( @@ -245,8 +246,10 @@ def test_frozen_stencil_kwargs_passed_to_init( format_source: bool, device_sync: bool, ) -> None: + backend = Backend.python() + config = get_stencil_config( - backend="numpy", + backend=backend, rebuild=rebuild, validate_args=validate_args, format_source=format_source, @@ -296,7 +299,7 @@ def field_after_parameter_stencil(q_in: FloatField, param: float, q_out: FloatFi def test_frozen_field_after_parameter() -> None: config = get_stencil_config( - backend="numpy", + backend=Backend.python(), rebuild=False, validate_args=False, format_source=False, @@ -311,20 +314,19 @@ def test_frozen_field_after_parameter() -> None: ) -@pytest.mark.parametrize("backend", ("numpy", "gt:gpu")) -def test_backend_options( - backend: str, - rebuild: bool = True, - validate_args: bool = True, -) -> None: +@pytest.mark.parametrize("backend", (Backend.python(), Backend("st:gt:gpu:KJI"))) +def test_backend_options(backend: Backend) -> None: + rebuild = True + validate_args = True + expected_options = { - "numpy": { - "backend": "numpy", + Backend.python(): { + "backend": "debug", "rebuild": True, "format_source": False, "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, - "gt:gpu": { + "st:gt:gpu:KJI": { "backend": "gt:gpu", "rebuild": True, "device_sync": False, @@ -340,9 +342,10 @@ def test_backend_options( assert actual == expected -def test_illegal_backend_options(): - with pytest.raises(ValueError): - get_stencil_config(backend="illegal") +def test_illegal_backend_options() -> None: + unknown_backend = "bad:back:end:now" + with pytest.raises(ValueError, match=f"Unknown {unknown_backend}, options are:*"): + get_stencil_config(backend=Backend(unknown_backend)) def get_mock_quantity(): diff --git a/tests/dsl/utils.py b/tests/dsl/utils.py index 5e8535e5..32c880cb 100644 
--- a/tests/dsl/utils.py +++ b/tests/dsl/utils.py @@ -10,7 +10,7 @@ def make_storage( aligned_index=(0, 0, 0), ): return func( - backend=stencil_config.compilation_config.backend, + backend=stencil_config.compilation_config.backend.as_gt4py(), shape=grid_indexing.domain, dtype=dtype, aligned_index=aligned_index, diff --git a/tests/grid/test_eta.py b/tests/grid/test_eta.py index 38e6c75b..764edc4f 100755 --- a/tests/grid/test_eta.py +++ b/tests/grid/test_eta.py @@ -10,6 +10,7 @@ SubtileGridSizer, TilePartitioner, ) +from ndsl.config import Backend from ndsl.grid import MetricTerms @@ -29,7 +30,7 @@ def test_set_hybrid_pressure_coefficients_nofile(): eta_file = Path("NULL") - backend = "numpy" + backend = Backend.python() layout = (1, 1) @@ -75,7 +76,7 @@ def test_set_hybrid_pressure_coefficients_not_mono(): eta_file = str(Path.cwd()) + "/tests/data/eta/non_mono_eta79.nc" - backend = "numpy" + backend = Backend.python() layout = (1, 1) diff --git a/tests/mpi/test_eta.py b/tests/mpi/test_eta.py index 9c2332a4..428d7fce 100644 --- a/tests/mpi/test_eta.py +++ b/tests/mpi/test_eta.py @@ -12,6 +12,7 @@ SubtileGridSizer, TilePartitioner, ) +from ndsl.config import Backend from ndsl.grid import MetricTerms @@ -28,7 +29,7 @@ def test_set_hybrid_pressure_coefficients_correct(levels): eta_file = Path.cwd() / "tests" / "data" / "eta" / f"eta{levels}.nc" eta_data = xr.open_dataset(eta_file) - backend = "numpy" + backend = Backend.python() layout = (1, 1) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 395dfbdf..a3f7f455 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -9,6 +9,7 @@ ) from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.mpi import MPI, MPIComm +from ndsl.config import Backend from ndsl.dsl.typing import Float @@ -46,8 +47,10 @@ def communicator(cube_partitioner: CubedSpherePartitioner) -> CubedSphereCommuni @pytest.mark.parallel 
-@pytest.mark.parametrize("backend", ["dace:cpu", "gt:cpu_kfirst", "numpy"]) -def test_all_reduce(backend: str, communicator: CubedSphereCommunicator) -> None: +@pytest.mark.parametrize( + "backend", [Backend("st:dace:cpu:KIJ"), Backend("st:gt:cpu:IJK"), Backend.python()] +) +def test_all_reduce(backend: Backend, communicator: CubedSphereCommunicator) -> None: base_array = np.array([i for i in range(5)], dtype=Float) testQuantity_1D = Quantity( diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 80805f22..11b91a75 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -10,6 +10,7 @@ ) from ndsl.comm._boundary_utils import get_boundary_slice from ndsl.comm.mpi import MPI, MPIComm +from ndsl.config import Backend from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -268,7 +269,12 @@ def depth_quantity( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan return Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) @@ -311,7 +317,12 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype): outside of it.""" data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) quantity.view[:] = 0.0 return quantity diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index 88439345..02da5384 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -3,6 +3,7 @@ from ndsl import Quantity from ndsl.comm._boundary_utils import _shift_boundary_slice, get_boundary_slice +from ndsl.config import Backend from ndsl.constants import ( EAST, I_DIM, @@ -35,7 +36,7 @@ def 
test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ) for side in ( WEST, @@ -71,7 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), - backend="debug", + backend=Backend.python(), ) for side in ( WEST, @@ -109,7 +110,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), - backend="debug", + backend=Backend.python(), ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index f0c7bb13..6baa5016 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -4,6 +4,7 @@ import numpy as np from ndsl import Quantity +from ndsl.config import Backend def test_deepcopy_copy_is_editable_by_view(): @@ -14,7 +15,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -32,7 +33,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 @@ -48,7 +49,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", - backend="debug", + backend=Backend.python(), ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index 46414c38..4be0e9ee 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -12,6 +12,7 @@ StencilFactory, ) from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.typing import Float @@ 
-25,7 +26,7 @@ def test_dace_data_descriptor_is_transient() -> None: extent=(nx,), dims=("dim_X",), units="n/a", - backend="debug", + backend=Backend.python(), ) array = local.__descriptor__() assert array.transient @@ -115,7 +116,7 @@ def __call__(self) -> None: def test_proper_initialization() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, backend="debug" + 3, 3, 5, 0, Backend.python() ) the_code = TheCode(stencil_factory, quantity_factory) assert the_code.check_local_right_after_init() @@ -123,7 +124,7 @@ def test_proper_initialization() -> None: def test_forbidden_access_to_locals() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - 3, 3, 5, 0, backend="debug" + 3, 3, 5, 0, Backend.python() ) the_code = TheCode(stencil_factory, quantity_factory) @@ -155,7 +156,7 @@ def test_forbidden_access_to_locals() -> None: def test_local_state_as_regular_state() -> None: - _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, backend="debug") + _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, Backend.python()) with pytest.raises( RuntimeError, match="LocalState allocated outside of NDSLRuntime: forbidden", @@ -177,7 +178,9 @@ def test_local_state_as_regular_state() -> None: def test_nested_local_state_as_regular_state() -> None: - _, quantity_factory = get_factories_single_tile(3, 3, 5, 0, backend="debug") + _, quantity_factory = get_factories_single_tile( + 3, 3, 5, 0, backend=Backend.python() + ) with pytest.warns(UserWarning, match="LocalState is allocated as a regular State"): nested = NestedLocals.make_as_state(quantity_factory) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 9e5df3ae..fc04ddfb 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -2,6 +2,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.quantity.bounds import _shift_slice @@ -59,7 +60,12 @@ def data(n_halo, extent_1d, 
n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): return Quantity( - data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend=Backend.python(), ) @@ -79,7 +85,7 @@ def test_smaller_data_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend="debug", + backend=Backend.python(), ) @@ -93,7 +99,7 @@ def test_smaller_dims_raises(data, origin, extent, dims, units): extent=extent, dims=dims[:-1], units=units, - backend="debug", + backend=Backend.python(), ) @@ -105,7 +111,7 @@ def test_smaller_origin_raises(data, origin, extent, dims, units): extent=extent, dims=dims, units=units, - backend="debug", + backend=Backend.python(), ) @@ -117,7 +123,7 @@ def test_smaller_extent_raises(data, origin, extent, dims, units): extent=extent[:-1], dims=dims, units=units, - backend="debug", + backend=Backend.python(), ) @@ -264,15 +270,18 @@ def test_shift_slice( @pytest.mark.parametrize( "quantity", [ - Quantity(np.array(5), dims=[], units="", backend="debug"), + Quantity(np.array(5), dims=[], units="", backend=Backend.python()), Quantity( - np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug" + np.array([1, 2, 3]), + dims=["dimension"], + units="degK", + backend=Backend.python(), ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", - backend="debug", + backend=Backend.python(), ), Quantity( np.random.randn(8, 6, 6), @@ -280,7 +289,7 @@ def test_shift_slice( units="km", origin=(2, 2, 2), extent=(4, 2, 2), - backend="debug", + backend=Backend.python(), ), ], ) @@ -296,7 +305,9 @@ def test_to_data_array(quantity): def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug") + quantity = Quantity( + np.ones((5,)), dims=["dim1"], units="", backend=Backend.python() + ) # After allocation - field and data are the same (origin 
is 0) assert quantity.data.shape == quantity.field.shape diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index 3c123446..c40cd9f8 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -5,6 +5,7 @@ from ndsl import Quantity, State from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM, Float @@ -59,7 +60,7 @@ class InnerB: def test_basic_state(tmpdir): K_size = 3 _, quantity_factory = get_factories_single_tile( - 5, 5, K_size, 0, backend="dace:cpu_KJI" + 5, 5, K_size, 0, Backend("st:dace:cpu:KJI") ) # Test allocator @@ -140,7 +141,7 @@ class InnerB: def test_state_ddim(): _, quantity_factory = get_factories_single_tile( - 5, 5, 3, 0, backend="dace:cpu_kfirst" + 5, 5, 3, 0, backend=Backend("st:dace:cpu:IJK") ) # Test allocator diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 39ea2126..13a23852 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -2,6 +2,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.optional_imports import cupy as cp @@ -54,7 +55,12 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): return Quantity( - data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend=Backend.python(), ) @@ -74,7 +80,7 @@ def test_modifying_numpy_data_modifies_view_and_field(): extent=shape, dims=["dim1", "dim2"], units="units", - backend="numpy", + backend=Backend.python(), ) assert np.all(quantity.data == 0) quantity.data[0, 0] = 1 @@ -101,7 +107,7 @@ def test_data_and_field_access_right_full_array_and_compute_domain(): extent=(5, 5), dims=["dim1", "dim2"], units="units", - backend="numpy", + backend=Backend.python(), ) assert np.all(quantity.data == 0) # Write 
compute domain - test data is written with the offset @@ -130,7 +136,7 @@ def test_field_exists(quantity, backend): def test_accessing_data_does_not_break_view( - data, origin, extent, dims, units, gt4py_backend + data, origin, extent, dims, units, ndsl_backend ): quantity = Quantity( data, @@ -138,7 +144,7 @@ def test_accessing_data_does_not_break_view( extent=extent, dims=dims, units=units, - backend=gt4py_backend, + backend=ndsl_backend, ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] @@ -154,6 +160,6 @@ def test_numpy_data_becomes_cupy_with_gpu_backend(data, origin, extent, dims, un extent=extent, dims=dims, units=units, - backend="gt:gpu", + backend=Backend("st:dace:gpu:KIJ"), ) assert isinstance(quantity.data, cp.ndarray) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 1e89f269..d24305fe 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -1,6 +1,7 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_DIMS, @@ -86,7 +87,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, extent=initial_extent, - backend="debug", + backend=Backend.python(), ) @@ -219,7 +220,10 @@ def test_transpose_invalid_cases( def test_transpose_retains_attrs(numpy): quantity = Quantity( - numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", backend="debug" + numpy.random.randn(3, 4), + dims=["x", "y"], + units="unit_string", + backend=Backend.python(), ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index b7a6875b..49602fd3 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -2,15 +2,14 @@ import pytest from ndsl import Quantity +from ndsl.config import Backend from 
ndsl.constants import I_DIM, J_DIM @pytest.fixture def quantity(request): return Quantity( - request.param[0], - dims=request.param[1], - units="units", + request.param[0], dims=request.param[1], units="units", backend=Backend.python() ) @@ -183,7 +182,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ) ], ) @@ -217,7 +216,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ) ], ) @@ -744,7 +743,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (0, 0), 4, @@ -757,7 +756,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, -1), 0, @@ -770,7 +769,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -783,7 +782,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, 0), 1, @@ -804,7 +803,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -842,7 +841,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, 0), 4, @@ -855,7 +854,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (0, -1), 6, @@ -868,7 +867,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + 
backend=Backend.python(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -881,7 +880,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, -1), 3, @@ -902,7 +901,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -940,7 +939,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, 0), 4, @@ -953,7 +952,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (0, -1), 6, @@ -966,7 +965,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -979,7 +978,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, 0), 4, @@ -992,7 +991,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, -1), 3, @@ -1013,7 +1012,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1051,7 +1050,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, -1), 4, @@ -1064,7 +1063,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (0, 0), 8, @@ -1077,7 
+1076,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1090,7 +1089,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, -1), 4, @@ -1103,7 +1102,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (-1, 0), 5, @@ -1124,7 +1123,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1162,7 +1161,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (0, 0), 4, @@ -1175,7 +1174,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(0, 0), slice(0, 0)), 4, @@ -1188,7 +1187,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1209,7 +1208,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), - backend="debug", + backend=Backend.python(), ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1230,7 +1229,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), - backend="debug", + backend=Backend.python(), ), (0,), np.array([6, 7, 8]), diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 17d521e7..7a76cac5 100644 --- a/tests/stencils/test_stencils.py +++ 
b/tests/stencils/test_stencils.py @@ -3,6 +3,7 @@ from ndsl import QuantityFactory, StencilFactory from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size @@ -19,7 +20,13 @@ @pytest.fixture def boilerplate() -> tuple[StencilFactory, QuantityFactory]: - return get_factories_single_tile(nx=1, ny=1, nz=10, nhalo=0, backend="dace:cpu") + return get_factories_single_tile( + nx=1, + ny=1, + nz=10, + nhalo=0, + backend=Backend("st:dace:cpu:KIJ"), + ) class ColumnOperations: diff --git a/tests/test_boilerplate.py b/tests/test_boilerplate.py index 417387ae..83ab77e4 100644 --- a/tests/test_boilerplate.py +++ b/tests/test_boilerplate.py @@ -2,6 +2,7 @@ import pytest from ndsl import QuantityFactory, StencilFactory +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -44,8 +45,8 @@ def test_boilerplate_import_numpy(): ) # Ensure backend is propagated to StencilFactory and QuantityFactory - assert stencil_factory.backend == "numpy" - assert quantity_factory.backend == "numpy" + assert stencil_factory.backend == Backend.python() + assert quantity_factory.backend == Backend.python() _copy_ops(stencil_factory, quantity_factory) @@ -63,9 +64,8 @@ def test_boilerplate_import_orchestrated_cpu(): ) # Ensure backend is propagated to StencilFactory and QuantityFactory - assert stencil_factory.backend == "dace:cpu" - assert quantity_factory.backend == "dace:cpu" - + assert stencil_factory.backend == Backend("orch:dace:cpu:IJK") + assert quantity_factory.backend == Backend("orch:dace:cpu:IJK") _copy_ops(stencil_factory, quantity_factory) @@ -74,5 +74,5 @@ def test_boilerplate_non_dace_based_orchestration_raises(): with pytest.raises(ValueError, 
match="Only .* backends can be orchestrated."): get_factories_single_tile_orchestrated( - nx=5, ny=5, nz=2, nhalo=1, backend="numpy" + nx=5, ny=5, nz=2, nhalo=1, backend=Backend.python() ) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index dc63635d..25863d1a 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -5,7 +5,7 @@ @pytest.fixture -def contiguous_array(numpy, backend): +def contiguous_array(numpy, backend: str): if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") array = numpy.empty([3, 4, 5]) @@ -65,7 +65,7 @@ def test_recvbuf_no_buffer(allocator, contiguous_array): assert recvbuf is contiguous_array -def test_buffer_cache_appends(allocator, backend): +def test_buffer_cache_appends(allocator, backend: str): """ Test buffer with the same key are appended while not in use for potential reuse """ @@ -91,7 +91,7 @@ def test_buffer_cache_appends(allocator, backend): assert len(BUFFER_CACHE[first_buffer._key]) == 2 -def test_buffer_reuse(allocator, backend): +def test_buffer_reuse(allocator, backend: str): """Test we reuse the buffer when available instead of reallocating one""" if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") @@ -119,7 +119,7 @@ def test_buffer_reuse(allocator, backend): Buffer.push_to_cache(repop_buffer) -def test_cacheline_differentiation(allocator, backend): +def test_cacheline_differentiation(allocator, backend: str): """Test allocation with different keys creates different cache lines""" if backend == "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") @@ -169,7 +169,9 @@ def test_cacheline_differentiation(allocator, backend): pytest.param(((10, 10, 10), float), ((10, 10, 5), float), id="different_shape"), ], ) -def test_new_args_gives_different_buffer(allocator, backend, first_args, second_args): +def test_new_args_gives_different_buffer( + allocator, backend: str, first_args, second_args +): if backend 
== "gt4py_cupy": pytest.skip("gt4py gpu backend cannot produce contiguous arrays") BUFFER_CACHE.clear() diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index e8b415c2..93cd9fd6 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -12,6 +12,7 @@ TilePartitioner, ) from ndsl.comm import CachingCommReader, CachingCommWriter +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM @@ -29,7 +30,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, - backend="debug", + backend=Backend.python(), ) for _ in range(n_ranks) ] diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index a7eada45..e8af4ee2 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -10,6 +10,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import ( HORIZONTAL_DIMS, I_DIM, @@ -163,7 +164,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), - backend="numpy", + backend=Backend.python(), ) diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index 6a04aeb3..e188c02f 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -3,6 +3,7 @@ import pytest from ndsl import GridSizer, QuantityFactory, SubtileGridSizer +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_INTERFACE_DIM, @@ -65,7 +66,7 @@ def namelist(nx_tile, ny_tile, nz, layout): def sizer( request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths ) -> GridSizer: - backend = "numpy" # original utest case + backend = Backend.python() # original utest case if request.param == "from_tile_params": return SubtileGridSizer.from_tile_params( nx_tile=nx_tile, @@ -196,7 +197,7 @@ def test_subtile_dimension_sizer_shape(sizer, dim_case): def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, 
backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.zeros(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -207,7 +208,7 @@ def test_allocator_zeros(numpy, sizer, dim_case, units, dtype): def test_allocator_ones(numpy, sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.ones(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -218,7 +219,7 @@ def test_allocator_ones(numpy, sizer, dim_case, units, dtype): def test_allocator_empty(sizer, dim_case, units, dtype): - allocator = QuantityFactory(sizer, backend="numpy") + allocator = QuantityFactory(sizer, backend=Backend.python()) quantity = allocator.empty(dim_case.dims, units, dtype=dtype) assert quantity.units == units assert quantity.dims == dim_case.dims @@ -228,7 +229,7 @@ def test_allocator_empty(sizer, dim_case, units, dtype): def test_allocator_data_dimensions_operations(sizer): - quantity_factory = QuantityFactory(sizer, backend="numpy") + quantity_factory = QuantityFactory(sizer, backend=Backend.python()) quantity_factory.add_data_dimensions({"D0": 11}) assert "D0" in quantity_factory.sizer.data_dimensions.keys() assert quantity_factory.sizer.data_dimensions["D0"] == 11 @@ -254,7 +255,7 @@ def test_pad_non_interface_dimensions(): n_halo=0, layout=(layout_xy, layout_xy), data_dimensions={"some_dim": dd}, - backend="numpy", # original utest case + backend=Backend.python(), # original utest case ) padded_shape = padded_grid_sizer.get_shape([I_DIM, J_DIM, K_DIM, "some_dim"]) assert padded_shape[0] == nx // layout_xy + 1 @@ -269,7 +270,7 @@ def test_pad_non_interface_dimensions(): n_halo=0, layout=(layout_xy, layout_xy), data_dimensions={"some_dim": dd}, - backend="dace:cpu_KJI", # Fortran-friendly backend + 
backend=Backend("st:dace:cpu:KJI"), # Fortran-friendly backend ) non_padded_shape = non_padded_grid_sizer.get_shape( [I_DIM, J_DIM, K_DIM, "some_dim"] diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index e72bf438..ba899f31 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -16,6 +16,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.optional_imports import cupy as cp from ndsl.performance import Timer @@ -130,7 +131,7 @@ def test_halo_update_only_communicate_on_gpu( units="m", origin=(3, 3, 1), extent=(3, 3, 1), - backend="gt:gpu", + backend=Backend("st:gt:gpu:KJI"), ) halo_updater_list = [] for communicator in gpu_communicators: @@ -162,7 +163,7 @@ def test_halo_update_communicate_though_cpu( units="m", origin=(3, 3, 0), extent=(3, 3, 0), - backend="numpy", + backend=Backend("st:numpy:cpu:IJK"), ) halo_updater_list = [] for communicator in cpu_communicators: diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index d1ea9c09..d9460659 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -7,6 +7,7 @@ from ndsl import HaloExchangeSpec, Quantity from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils +from ndsl.config import Backend from ndsl.constants import ( EAST, I_DIM, @@ -152,7 +153,7 @@ def _shape_length(shape: Tuple[int]) -> int: @pytest.fixture -def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): +def quantity(dims, units, origin, extent, shape, dtype, ndsl_backend: Backend): """A list of quantities whose values are 42.42 in the computational domain and 1 outside of it.""" sz = _shape_length(shape) @@ -164,7 +165,7 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): units=units, origin=origin, extent=extent, - backend=gt4py_backend, + backend=ndsl_backend, ) diff --git 
a/tests/test_halo_update.py b/tests/test_halo_update.py index 4f820f7e..97a8ed5c 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -14,6 +14,7 @@ ) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.config import Backend from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -303,7 +304,12 @@ def depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) return_list.append(quantity) return return_list @@ -336,7 +342,12 @@ def tile_depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) return_list.append(quantity) return return_list @@ -476,7 +487,12 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -493,7 +509,12 @@ def zeros_quantity_tile_list( for _rank in range(single_tile_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" + data, + dims=dims, + units=units, + origin=origin, + extent=extent, + backend=Backend.python(), ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 
ad086000..372cc67d 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_INTERFACE_DIM, @@ -126,7 +127,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), - backend="debug", + backend=Backend.python(), ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index 0331bd05..67e4f226 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -7,6 +7,7 @@ get_factories_single_tile, get_factories_single_tile_orchestrated, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField @@ -52,7 +53,7 @@ def run(self, A: Any, B: Any) -> None: def test_runtime_make_local() -> None: stencil_factory, quantity_factory = get_factories_single_tile( - nx=5, ny=5, nz=3, nhalo=0, backend="numpy" + nx=5, ny=5, nz=3, nhalo=0, backend=Backend.python() ) A_ = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a") @@ -73,7 +74,7 @@ def test_runtime_make_local() -> None: def test_runtime_has_orchestrated_call() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - nx=5, ny=5, nz=3, nhalo=0, backend="dace:cpu_kfirst" + nx=5, ny=5, nz=3, nhalo=0, backend=Backend.cpu() ) A_ = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a") B_ = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a") @@ -89,7 +90,7 @@ def test_runtime_has_orchestrated_call() -> None: def test_runtime_does_not_orchestrate_when_call_is_not_present() -> None: stencil_factory, _ = get_factories_single_tile_orchestrated( - nx=5, ny=5, nz=3, nhalo=0, backend="dace:cpu_kfirst" + nx=5, ny=5, nz=3, nhalo=0, 
backend=Backend.cpu() ) code = Code_NoCall(stencil_factory) diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index bda6b55c..41e0760e 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -15,6 +15,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, I_INTERFACE_DIM, J_DIM, J_INTERFACE_DIM, K_DIM @@ -40,7 +41,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" - backend = "debug" + backend = Backend.python() nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index e082f6f0..35b12d36 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -7,6 +7,7 @@ get_tile_index, tile_extent_from_rank_metadata, ) +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_INTERFACE_DIM, @@ -976,7 +977,11 @@ def test_subtile_extent_with_tile_dimensions( ): data_array = np.zeros((tile_extent)) quantity = Quantity( - data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0], backend="debug" + data_array, + array_dims, + "dimensionless", + origin=[0, 0, 0, 0], + backend=Backend.python(), ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index de4429a6..c3f0a884 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.config import Backend from ndsl.constants import I_DIM, I_INTERFACE_DIM, J_DIM, J_INTERFACE_DIM from ndsl.performance import Timer @@ -81,7 +82,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend="debug", + backend=Backend.python(), ) y_data = numpy.empty((2, 3), dtype=dtype) 
y_data[:] = rank @@ -91,7 +92,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend="debug", + backend=Backend.python(), ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -149,7 +150,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), - backend="debug", + backend=Backend.python(), ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -158,7 +159,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), - backend="debug", + backend=Backend.python(), ) quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index 6abdb9dd..48153ade 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,6 +1,7 @@ import pytest from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner +from ndsl.config import Backend from ndsl.constants import I_DIM, I_INTERFACE_DIM, J_DIM, J_INTERFACE_DIM @@ -36,13 +37,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), dims=[J_INTERFACE_DIM, I_INTERFACE_DIM], units="dimensionless", - backend="debug", + backend=Backend.python(), ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[J_INTERFACE_DIM, I_INTERFACE_DIM], units="dimensionless", - backend="debug", + backend=Backend.python(), ), } @@ -82,19 +83,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend="debug", + backend=Backend.python(), ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend="debug", + backend=Backend.python(), ), 
"rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[J_DIM, I_DIM], units="dimensionless", - backend="debug", + backend=Backend.python(), ), } @@ -142,7 +143,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.python(), ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -150,7 +151,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.python(), ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -158,7 +159,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, - backend="debug", + backend=Backend.python(), ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 6d7a19e5..4338d1b4 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -4,6 +4,7 @@ import pytest from ndsl import LocalComm, Quantity, TileCommunicator, TilePartitioner +from ndsl.config import Backend from ndsl.constants import ( HORIZONTAL_DIMS, I_DIM, @@ -144,7 +145,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), - backend="debug", + backend=Backend.python(), ) diff --git a/tests/test_xumpy.py b/tests/test_xumpy.py index eb61cdde..cc198bb7 100644 --- a/tests/test_xumpy.py +++ b/tests/test_xumpy.py @@ -1,35 +1,38 @@ import numpy as np import ndsl.xumpy as xp +from ndsl.config import Backend shape = (2, 2, 5) def test_xumpy_alloc(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.python()) assert rand_array.shape == shape - (rand_array != xp.random(shape, backend="debug")).all() + 
assert (rand_array != xp.random(shape, Backend.python())).all() - assert (np.ones(shape) == xp.ones(shape, backend="debug")).all() - assert (np.zeros(shape) == xp.zeros(shape, backend="debug")).all() - assert (np.full(shape, 42.42) == xp.full(shape, value=42.42, backend="debug")).all() + assert (np.ones(shape) == xp.ones(shape, Backend.python())).all() + assert (np.zeros(shape) == xp.zeros(shape, Backend.python())).all() + assert ( + np.full(shape, 42.42) == xp.full(shape, value=42.42, backend=Backend.python()) + ).all() def test_xumpy_minmax(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.python()) assert (np.max(rand_array, axis=1) == xp.max(rand_array, axis=1)).all() assert (np.min(rand_array, axis=1) == xp.min(rand_array, axis=1)).all() - out_buffer = xp.empty(shape, backend="debug") + out_buffer = xp.empty(shape, Backend.python()) xp.max_on_horizontal_plane(rand_array, out_buffer) assert (np.max(rand_array, axis=(0, 1)) == out_buffer).all() def test_xumpy_counts(): - rand_array = xp.random(shape, backend="debug") + rand_array = xp.random(shape, Backend.python()) rand_array[1, 1, :] = 0 assert np.count_nonzero(rand_array) == xp.count_nonzero(rand_array) diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index 6a44934a..68b5f07e 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -8,6 +8,7 @@ import xarray as xr from ndsl import CubedSpherePartitioner, LocalComm, MPIComm, Quantity, TilePartitioner +from ndsl.config import Backend from ndsl.constants import ( I_DIM, I_DIMS, @@ -93,7 +94,10 @@ def base_state(request, nz, ny, nx, numpy) -> dict: if request.param == "one_var_2d": return { "var1": Quantity( - numpy.ones([ny, nx]), dims=(J_DIM, X_DIM), units="m", backend="debug" + numpy.ones([ny, nx]), + dims=(J_DIM, X_DIM), + units="m", + backend=Backend.python(), ) } @@ -103,20 +107,23 @@ def base_state(request, nz, ny, nx, numpy) -> dict: numpy.ones([nz, ny, nx]), dims=(K_DIM, 
J_DIM, I_DIM), units="m", - backend="debug", + backend=Backend.python(), ) } if request.param == "two_vars": return { "var1": Quantity( - numpy.ones([ny, nx]), dims=(J_DIM, I_DIM), units="m", backend="debug" + numpy.ones([ny, nx]), + dims=(J_DIM, I_DIM), + units="m", + backend=Backend.python(), ), "var2": Quantity( numpy.ones([nz, ny, nx]), dims=(K_DIM, J_DIM, I_DIM), units="degK", - backend="debug", + backend=Backend.python(), ), } @@ -249,7 +256,7 @@ def test_monitor_file_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, - backend="debug", + backend=Backend.python(), ), } monitor_list[rank].store(state) @@ -339,7 +346,10 @@ def test_open_zarr_without_nans(cube_partitioner, numpy, mask_and_scale): # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) zero_quantity = Quantity( - numpy.zeros([10, 10]), dims=(J_DIM, I_DIM), units="m", backend="debug" + numpy.zeros([10, 10]), + dims=(J_DIM, I_DIM), + units="m", + backend=Backend.python(), ) monitor.store({"var": zero_quantity}) @@ -361,7 +371,10 @@ def test_values_preserved(cube_partitioner, numpy): # initialize store monitor = ZarrMonitor(store, cube_partitioner, mpi_comm=LocalComm(0, 1, buffer)) quantity = Quantity( - numpy.random.uniform(size=(10, 10)), dims=dims, units=units, backend="debug" + numpy.random.uniform(size=(10, 10)), + dims=dims, + units=units, + backend=Backend.python(), ) monitor.store({"var": quantity}) @@ -415,7 +428,7 @@ def diag(request, numpy): numpy.ones([size + 2 for size in range(len(dims))]), dims=dims, units="m", - backend="debug", + backend=Backend.python(), ) @@ -489,7 +502,7 @@ def test_diags_fail_different_dim_set(diag, numpy, zarr_monitor_single_rank): numpy.ones([size + 2 for size in range(len(diag.dims))]), dims=new_dims, units="m", - backend="debug", + backend=Backend.python(), ) with pytest.raises(ValueError) as excinfo: zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) @@ -505,6 +518,8 @@ 
def test_diags_only_consistent_units_attrs_required(diag, zarr_monitor_single_ra diag_2 = copy.deepcopy(diag) diag_2._attrs.update({"some_non_units_attrs": 9.0}) zarr_monitor_single_rank.store({"time": time_2, "a": diag_2}) - diag_3 = Quantity(data=diag.view[:], dims=diag.dims, units="not_m", backend="debug") + diag_3 = Quantity( + data=diag.view[:], dims=diag.dims, units="not_m", backend=Backend.python() + ) with pytest.raises(ValueError): zarr_monitor_single_rank.store({"time": time_3, "a": diag_3})