From d592f9a61d700697709ab7d127d005476aa6a9f7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:28:09 +0100 Subject: [PATCH] refactor: Deprecate optional backend argument to Quantity/Local --- ndsl/comm/communicator.py | 1 + ndsl/monitor/netcdf_monitor.py | 2 + ndsl/quantity/local.py | 9 +++- ndsl/quantity/quantity.py | 35 +++++++++++--- tests/dsl/test_stencil.py | 12 +++-- tests/quantity/test_boundary.py | 3 ++ tests/quantity/test_deepcopy.py | 3 ++ tests/quantity/test_local.py | 17 ++++++- tests/quantity/test_quantity.py | 77 +++++++++++++++++++++++------- tests/quantity/test_storage.py | 4 +- tests/quantity/test_transpose.py | 5 +- tests/quantity/test_view.py | 34 +++++++++++++ tests/test_basic_operations.py | 10 ++++ tests/test_caching_comm.py | 1 + tests/test_cube_scatter_gather.py | 1 + tests/test_halo_update.py | 24 ++-------- tests/test_halo_update_ranks.py | 1 + tests/test_netcdf_monitor.py | 4 ++ tests/test_partitioner.py | 4 +- tests/test_sync_shared_boundary.py | 4 ++ tests/test_tile_scatter.py | 8 ++++ tests/test_tile_scatter_gather.py | 1 + 22 files changed, 206 insertions(+), 54 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index d1c1205f..983a35b2 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -326,6 +326,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): # dims=quantity.dims, units=quantity.units, allow_mismatch_float_precision=True, + backend=quantity.backend, ) if recv_state is not None and name in recv_state: tile_quantity = self.gather( diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index e2cb417f..a95c532e 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -21,6 +21,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int): self._dims = initial.dims self._units = initial.units self._i_time = 1 + self._backend = initial.backend def append(self, quantity: Quantity) -> None: # Allow mismatch precision here since this is I/O @@ -37,6 +38,7 @@ def data(self) -> Quantity: dims=("time",) + tuple(self._dims), units=self._units, allow_mismatch_float_precision=True, + backend=self._backend, ) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index af75242d..438def1b 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -23,11 +23,11 @@ def __init__( dims: Sequence[str], units: str, *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, - backend: str | None = None, ): if gt4py_backend is not None: warnings.warn( @@ -38,6 +38,13 @@ def __init__( if backend is None: backend = gt4py_backend + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( data, dims, diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 4b902a72..300f6376 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -33,12 +33,12 @@ def __init__( dims: Sequence[str], units: str, *, + backend: str | None = None, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, number_of_halo_points: int = 0, - backend: str | None = None, ): """Initialize a Quantity. @@ -46,6 +46,8 @@ def __init__( data: ndarray-like object containing the underlying data dims: dimension names for each axis units: units of the quantity + backend: GT4Py backend name. We ensure that the data is allocated in a + performance optimal way for that backend and copy if necessary. origin: first point in data within the computational domain. Defaults to None. extent: number of points along each axis @@ -54,8 +56,6 @@ def __init__( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we check that the data is - allocated in a performance optimal way for that backend. Raises: ValueError: Data-type mismatch between configuration and input-data @@ -70,6 +70,13 @@ def __init__( if backend is None: backend = gt4py_backend + if backend is None: + warnings.warn( + "`backend` will be a required argument starting with the next version of NDSL.", + DeprecationWarning, + stacklevel=2, + ) + if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -179,8 +186,9 @@ def from_data_array( allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points: Number of halo points used. Defaults to 0. - backend: GT4Py backend name. If given, we check that the data is - allocated in a performance optimal way for that backend. + backend: GT4Py backend name. If given, we allocate data in a performance + optimal way for this backend. Overrides any potentially saved `backend` + in `data.attrs["backend"]`. """ if "units" not in data_array.attrs: raise ValueError("need units attribute to create Quantity from DataArray") @@ -201,7 +209,7 @@ def from_data_array( origin=origin, extent=extent, number_of_halo_points=number_of_halo_points, - backend=backend, + backend=_resolve_backend(data_array, backend), ) def to_netcdf( @@ -283,7 +291,7 @@ def backend(self) -> str | None: @property def attrs(self) -> dict: - return dict(**self._attrs, units=self._metadata.units) + return dict(**self._attrs, units=self.units, backend=self.backend) @property def dims(self) -> tuple[str, ...]: @@ -495,3 +503,16 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple: f"unexpected type {type(item)}" ) return tuple(return_list) + + +def _resolve_backend(data: xr.DataArray, backend: str | None) -> str: + if backend is not None: + # Forced backend name takes precedence + return backend + + # If backend name was serialized with data, take this one + if "backend" in data.attrs: + return data.attrs["backend"] + + # else, fall back to assume python-based layout. + return "debug" diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 4daa401c..1f7f3836 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -108,7 +108,9 @@ def test_domain_size_comparison( domain: tuple[int], call_count: int, ): - quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent) + quantity = Quantity( + np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug" + ) stencil = FrozenStencil( copy_stencil, origin=(0, 0, 0), @@ -147,7 +149,9 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None: def test_stencil_2D_temporaries() -> None: domain = (2, 2, 5) - quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) stencil = FrozenStencil( two_dim_temporaries_stencil, origin=(0, 0, 0), @@ -164,7 +168,9 @@ def test_stencil_2D_temporaries() -> None: ) def test_validation_call_count(iterations: tuple[int]): domain = (2, 2, 5) - quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain) + quantity = Quantity( + np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug" + ) stencil_config = StencilConfig( compilation_config=CompilationConfig(backend="numpy", rebuild=True) ) diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index a4f8e812..7ba2eddb 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -36,6 +36,7 @@ def test_boundary_data_1_by_1_array_1_halo(): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) for side in ( WEST, @@ -71,6 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy): units="m", origin=(1, 1, 1), extent=(1, 1, 1), + backend="debug", ) for side in ( WEST, @@ -109,6 +111,7 @@ def test_boundary_data_2_by_2_array_2_halo(): units="m", origin=(2, 2), extent=(2, 2), + backend="debug", ) for side in ( WEST, diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index d6b5c7cb..f0c7bb13 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -14,6 +14,7 @@ def test_deepcopy_copy_is_editable_by_view(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) # assertion below is only valid if we're overwriting the entire data through view @@ -31,6 +32,7 @@ def test_deepcopy_copy_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 @@ -46,6 +48,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data(): extent=(nx, ny, nz), dims=["x", "y", "z"], units="", + backend="debug", ) quantity_copy = copy.deepcopy(quantity) quantity_copy.data[:] = 1.0 diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py index 859bb009..bb6027f5 100644 --- a/tests/quantity/test_local.py +++ b/tests/quantity/test_local.py @@ -4,7 +4,7 @@ from ndsl import Local -def test_local_descriptor_is_transient() -> None: +def test_dace_data_descriptor_is_transient() -> None: nx = 5 shape = (nx,) local = Local( @@ -19,7 +19,7 @@ def test_local_descriptor_is_transient() -> None: assert array.transient -def test_local_gt4py_backend_is_deprecated() -> None: +def test_gt4py_backend_is_deprecated() -> None: nx = 5 shape = (nx,) backend = "debug" @@ -39,3 +39,16 @@ def test_local_gt4py_backend_is_deprecated() -> None: # make sure we are backwards compatible (for now) with pytest.deprecated_call(match="gt4py_backend is deprecated"): assert local.gt4py_backend == backend + + +def test_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 3286ce0f..ef94b45e 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -62,7 +62,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_smaller_data_raises(data, origin, extent, dims, units): @@ -72,25 +74,55 @@ def test_smaller_data_raises(data, origin, extent, dims, units): except IndexError: pass else: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): Quantity( - small_data, origin=origin, extent=extent, dims=dims, units=units + small_data, + origin=origin, + extent=extent, + dims=dims, + units=units, + backend="debug", ) def test_smaller_dims_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent, dims=dims[:-1], units=units) + with pytest.raises( + ValueError, match="received .* dimension names for .* dimensions: .*" + ): + Quantity( + data, + origin=origin, + extent=extent, + dims=dims[:-1], + units=units, + backend="debug", + ) def test_smaller_origin_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin[:-1], extent=extent, dims=dims, units=units) + with pytest.raises(ValueError, match="received .* origins for .* dimensions: .*"): + Quantity( + data, + origin=origin[:-1], + extent=extent, + dims=dims, + units=units, + backend="debug", + ) def test_smaller_extent_raises(data, origin, extent, dims, units): - with pytest.raises(ValueError): - Quantity(data, origin=origin, extent=extent[:-1], dims=dims, units=units) + with pytest.raises(ValueError, match="received .* extents for .* dimensions: .*"): + Quantity( + data, + origin=origin, + extent=extent[:-1], + dims=dims, + units=units, + backend="debug", + ) def test_data_change_affects_quantity(data, quantity, numpy): @@ -229,20 +261,15 @@ def test_shift_slice(slice_in, shift, extent, slice_out): @pytest.mark.parametrize( "quantity", [ + Quantity(np.array(5), dims=[], units="", backend="debug"), Quantity( - np.array(5), - dims=[], - units="", - ), - Quantity( - np.array([1, 2, 3]), - dims=["dimension"], - units="degK", + np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug" ), Quantity( np.random.randn(3, 2, 4), dims=["dim1", "dim_2", "dimension_3"], units="m", + backend="debug", ), Quantity( np.random.randn(8, 6, 6), @@ -250,6 +277,7 @@ def test_shift_slice(slice_in, shift, extent, slice_out): units="km", origin=(2, 2, 2), extent=(4, 2, 2), + backend="debug", ), ], ) @@ -265,7 +293,7 @@ def test_to_data_array(quantity): def test_data_setter(): - quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug") # After allocation - field and data are the same (origin is 0) assert quantity.data.shape == quantity.field.shape @@ -356,3 +384,16 @@ def test_assign_basic_data_is_deprecated() -> None: # make sure we can still use it (for now) for i in range(5): assert quantity.data[i] == i + + +def test_constructor_backend_will_be_required() -> None: + nx = 5 + shape = (nx,) + with pytest.deprecated_call(match="`backend` will be a required argument"): + local = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + ) diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 7fbd1a04..6d8dc4a4 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -53,7 +53,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype): @pytest.fixture def quantity(data, origin, extent, dims, units): - return Quantity(data, origin=origin, extent=extent, dims=dims, units=units) + return Quantity( + data, origin=origin, extent=extent, dims=dims, units=units, backend="debug" + ) def test_numpy(quantity, backend): diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 88653676..b7ceb0f8 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -86,6 +86,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent): units="unit_string", origin=initial_origin, extent=initial_extent, + backend="debug", ) @@ -218,7 +219,9 @@ def test_transpose_invalid_cases( def test_transpose_retains_attrs(numpy): - quantity = Quantity(numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string") + quantity = Quantity( + numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string", backend="debug" + ) quantity._attrs = {"long_name": "500 mb height"} transposed = quantity.transpose(["y", "x"]) assert transposed.attrs == quantity.attrs diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index 73245093..0ce44cfe 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -183,6 +183,7 @@ def quantity(request): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -216,6 +217,7 @@ def test_many_indices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) ], ) @@ -742,6 +744,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -754,6 +757,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 0, @@ -766,6 +770,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 0), slice(-1, 0)), np.array([[0]]), @@ -778,6 +783,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 1, @@ -798,6 +804,7 @@ def test_many_slices_raises(quantity, view_name): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[1, 2, 3], [6, 7, 8]]), @@ -816,6 +823,7 @@ def test_southwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -834,6 +842,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -846,6 +855,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -858,6 +868,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -870,6 +881,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -890,6 +902,7 @@ def test_southwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -908,6 +921,7 @@ def test_southeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.southeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -926,6 +940,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -938,6 +953,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, -1), 6, @@ -950,6 +966,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(-1, 0)), np.array([[6]]), @@ -962,6 +979,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 4, @@ -974,6 +992,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 3, @@ -994,6 +1013,7 @@ def test_southeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[6, 7, 8], [11, 12, 13]]), @@ -1012,6 +1032,7 @@ def test_northwest(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northwest[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1030,6 +1051,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1042,6 +1064,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 8, @@ -1054,6 +1077,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 1), slice(0, 1)), np.array([[8]]), @@ -1066,6 +1090,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, -1), 4, @@ -1078,6 +1103,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (-1, 0), 5, @@ -1098,6 +1124,7 @@ def test_northwest(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(-1, 2)), np.array([[7, 8, 9], [12, 13, 14]]), @@ -1116,6 +1143,7 @@ def test_northeast(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) transposed_result = transposed_quantity.view.northeast[view_slice[::-1]] if isinstance(reference, quantity.np.ndarray): @@ -1134,6 +1162,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (0, 0), 4, @@ -1146,6 +1175,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(0, 0), slice(0, 0)), 4, @@ -1158,6 +1188,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ), (slice(-1, 1), slice(-1, 1)), np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), @@ -1178,6 +1209,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(2, 2), extent=(1, 1), + backend="debug", ), (slice(-2, 0), slice(0, 1)), np.array([[2, 3], [7, 8], [12, 13]]), @@ -1198,6 +1230,7 @@ def test_northeast(quantity, view_slice, reference): units="m", origin=(1, 1), extent=(3, 3), + backend="debug", ), (0,), np.array([6, 7, 8]), @@ -1216,6 +1249,7 @@ def test_interior(quantity, view_slice, reference): units=quantity.units, origin=quantity.origin[::-1], extent=quantity.extent[::-1], + backend=quantity.backend, ) if len(view_slice) == len(quantity.dims): # skip if not transposed_result = transposed_quantity.view.interior[view_slice[::-1]] diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 0d707240..74916015 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -128,12 +128,14 @@ def test_copy(): data=np.zeros([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.ones([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) copy(f_in=infield.data, f_out=outfield.data) @@ -148,18 +150,21 @@ def test_adjustmentfactor(): data=np.full(shape=[20, 20], fill_value=2.0), dims=[X_DIM, Y_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=4.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) adfact(factor=factorfield.data, f_out=outfield.data) @@ -173,12 +178,14 @@ def test_setvalue(): data=np.zeros(shape=[20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) setvalue(f_out=outfield.data, value=2.0) @@ -193,18 +200,21 @@ def test_adjustdivide(): data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) outfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) testfield = Quantity( data=np.full(shape=[20, 20, 79], fill_value=1.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", + backend=backend, ) addiv(factor=factorfield.data, f_out=outfield.data) diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index d2d7f64e..cafecaf8 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -30,6 +30,7 @@ def test_halo_update_integration(): units="", origin=origin, extent=extent, + backend="debug", ) for _ in range(n_ranks) ] diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 236e1eb9..b71f54e1 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -169,6 +169,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="numpy", ) diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index ca76ab8b..21137c17 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -319,11 +319,7 @@ def depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -356,11 +352,7 @@ def tile_depth_quantity_list( pos[i] = origin[i] + extent[i] + n_outside - 1 data[tuple(pos)] = numpy.nan quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) return_list.append(quantity) return return_list @@ -500,11 +492,7 @@ def zeros_quantity_list(total_ranks, dims, units, origin, extent, shape, numpy, for _rank in range(total_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) @@ -521,11 +509,7 @@ def zeros_quantity_tile_list( for _rank in range(single_tile_ranks): data = numpy.ones(shape, dtype=dtype) quantity = Quantity( - data, - dims=dims, - units=units, - origin=origin, - extent=extent, + data, dims=dims, units=units, origin=origin, extent=extent, backend="debug" ) quantity.view[:] = 0.0 return_list.append(quantity) diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 6ceb4886..b50f03d1 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -126,6 +126,7 @@ def rank_quantity_list(total_ranks, numpy, dtype): units="m", origin=(1, 1), extent=(1, 1), + backend="debug", ) quantity_list.append(quantity) return quantity_list diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 19614486..a19cf193 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -39,6 +39,7 @@ def test_monitor_store_multi_rank_state( layout, nt, time_chunk_size, tmpdir, shape, ny_rank_add, nx_rank_add, dims, numpy ): units = "m" + backend = "debug" nz, ny, nx = shape ny_rank = int(ny / layout[0] + ny_rank_add) nx_rank = int(nx / layout[1] + nx_rank_add) @@ -74,6 +75,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) @@ -87,6 +89,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store(state) @@ -100,6 +103,7 @@ def test_monitor_store_multi_rank_state( numpy.ones([nz, ny_rank, nx_rank]), dims=dims, units=units, + backend=backend, ), } monitor_list[rank].store_constant(state) diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 4d20e709..5b8836ac 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -993,7 +993,9 @@ def test_subtile_extent_with_tile_dimensions( cubedsphere_expected, ): data_array = np.zeros((tile_extent)) - quantity = Quantity(data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0]) + quantity = Quantity( + data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0], backend="debug" + ) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) cubedsphere_partitioner = CubedSpherePartitioner(tile_partitioner) diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 7db5a621..be281194 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -81,6 +81,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = numpy.empty((2, 3), dtype=dtype) y_data[:] = rank @@ -90,6 +91,7 @@ def rank_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list @@ -147,6 +149,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(3, 2), + backend="debug", ) y_data = 6 * total_ranks + numpy.array([[0, 1, 2], [3, 4, 5]]) + 6 * rank y_quantity = Quantity( @@ -155,6 +158,7 @@ def counting_quantity_list(total_ranks, numpy, dtype, units=units): units=units, origin=(0, 0), extent=(2, 3), + backend="debug", ) quantity_list.append((x_quantity, y_quantity)) return quantity_list diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index d768bb15..ccbe1cdb 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -36,11 +36,13 @@ def test_interface_state_two_by_two_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0] + 1, layout[1] + 1]), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), "pos_i": Quantity( numpy.empty([layout[0] + 1, layout[1] + 1], dtype=numpy.int32), dims=[Y_INTERFACE_DIM, X_INTERFACE_DIM], units="dimensionless", + backend="debug", ), } @@ -80,16 +82,19 @@ def test_centered_state_one_item_per_rank_scatter_tile(layout, numpy): numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0], layout[1]]), dims=[Y_DIM, X_DIM], units="dimensionless", + backend="debug", ), } @@ -137,6 +142,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_j": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -144,6 +150,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), "rank_pos_i": Quantity( numpy.empty([layout[0] + 2 * n_halo, layout[1] + 2 * n_halo]), @@ -151,6 +158,7 @@ def test_centered_state_one_item_per_rank_with_halo_scatter_tile(layout, n_halo, units="dimensionless", origin=(n_halo, n_halo), extent=extent, + backend="debug", ), } diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 60ef583a..e8bfa861 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -150,6 +150,7 @@ def get_quantity(dims, units, extent, n_halo, numpy): units, origin=tuple(origin), extent=tuple(extent), + backend="debug", )