From 8587043b1ba08598208c813cb6d9bd84c2827448 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 7 Nov 2025 11:29:10 +0100 Subject: [PATCH 1/2] refactor: force kwargs in ctor of Quantity/Local Force keyword arguments for optional arguments to those constructors. This will facilitate the `gt4py_backen` -> `backend` transition. --- ndsl/quantity/local.py | 10 +++++----- ndsl/quantity/quantity.py | 1 + tests/test_partitioner.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index 910999ab..a9e34772 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -20,6 +20,7 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, @@ -29,12 +30,11 @@ def __init__( data, dims, units, - origin, - extent, - gt4py_backend, - allow_mismatch_float_precision, + origin=origin, + extent=extent, + gt4py_backend=gt4py_backend, + allow_mismatch_float_precision=allow_mismatch_float_precision, ) - self._transient = True def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 289fa644..3263971a 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -32,6 +32,7 @@ def __init__( data: np.ndarray | cupy.ndarray, dims: Sequence[str], units: str, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index cc0a105e..4d20e709 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -993,7 +993,7 @@ def test_subtile_extent_with_tile_dimensions( cubedsphere_expected, ): data_array = np.zeros((tile_extent)) - quantity = Quantity(data_array, array_dims, "dimensionless", [0, 0, 0, 0]) + quantity = Quantity(data_array, array_dims, "dimensionless", origin=[0, 0, 0, 0]) tile_partitioner = TilePartitioner(layout, edge_interior_ratio) cubedsphere_partitioner = CubedSpherePartitioner(tile_partitioner) From 19d9509b34367e30cd563fa64492c2d31743f03c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:20:37 +0100 Subject: [PATCH 2/2] refactor: prefer `backend` over `gt4py_backend` in Quantity --- ndsl/comm/communicator.py | 10 +-- ndsl/dsl/ndsl_runtime.py | 2 +- ndsl/grid/generation.py | 24 +++--- ndsl/grid/helper.py | 12 +-- ndsl/initialization/allocator.py | 2 +- ndsl/quantity/local.py | 16 +++- ndsl/quantity/metadata.py | 5 +- ndsl/quantity/quantity.py | 113 ++++++++++++++++++--------- tests/mpi/test_mpi_all_reduce_sum.py | 12 +-- tests/quantity/test_local.py | 41 ++++++++++ tests/quantity/test_quantity.py | 67 ++++++++++++++++ tests/quantity/test_storage.py | 8 +- tests/quantity/test_transpose.py | 10 ++- tests/test_halo_data_transformer.py | 2 +- 14 files changed, 244 insertions(+), 80 deletions(-) create mode 100644 tests/quantity/test_local.py diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 6eee4514..d1c1205f 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -108,7 +108,7 @@ def _create_all_reduce_quantity( units=input_metadata.units, origin=input_metadata.origin, extent=input_metadata.extent, - gt4py_backend=input_metadata.gt4py_backend, + backend=input_metadata.backend, allow_mismatch_float_precision=False, ) return all_reduce_quantity @@ -228,7 +228,7 @@ def _get_gather_recv_quantity( units=send_metadata.units, origin=tuple([0 for dim in send_metadata.dims]), extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -241,7 +241,7 @@ def _get_scatter_recv_quantity( send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore dims=send_metadata.dims, units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, + backend=send_metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -841,7 +841,7 @@ def _get_gather_recv_quantity( units=metadata.units, origin=(0,) + tuple([0 for dim in metadata.dims]), extent=global_extent, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity @@ -861,7 +861,7 @@ def _get_scatter_recv_quantity( metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore dims=metadata.dims[1:], units=metadata.units, - gt4py_backend=metadata.gt4py_backend, + backend=metadata.backend, allow_mismatch_float_precision=True, ) return recv_quantity diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 065cc298..c798ded4 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -123,6 +123,6 @@ def make_local( units=quantity.units, origin=quantity.origin, extent=quantity.extent, - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, ) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 7fd6eb7d..7296fa76 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -551,7 +551,7 @@ def lon(self): origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -563,7 +563,7 @@ def lat(self) -> Quantity: origin=self.grid.origin[:2], extent=self.grid.extent[:2], units=self.grid.units, - gt4py_backend=self.grid.gt4py_backend, + backend=self.grid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -575,7 +575,7 @@ def lon_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -587,7 +587,7 @@ def lat_agrid(self) -> Quantity: origin=self.agrid.origin[:2], extent=self.agrid.extent[:2], units=self.agrid.units, - gt4py_backend=self.agrid.gt4py_backend, + backend=self.agrid.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1551,7 +1551,7 @@ def rarea(self) -> Quantity: origin=self.area.origin, extent=self.area.extent, units="m^-2", - gt4py_backend=self.area.gt4py_backend, + backend=self.area.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1566,7 +1566,7 @@ def rarea_c(self) -> Quantity: origin=self.area_c.origin, extent=self.area_c.extent, units="m^-2", - gt4py_backend=self.area_c.gt4py_backend, + backend=self.area_c.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1582,7 +1582,7 @@ def rdx(self) -> Quantity: origin=self.dx.origin, extent=self.dx.extent, units="m^-1", - gt4py_backend=self.dx.gt4py_backend, + backend=self.dx.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1598,7 +1598,7 @@ def rdy(self) -> Quantity: origin=self.dy.origin, extent=self.dy.extent, units="m^-1", - gt4py_backend=self.dy.gt4py_backend, + backend=self.dy.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1614,7 +1614,7 @@ def rdxa(self) -> Quantity: origin=self.dxa.origin, extent=self.dxa.extent, units="m^-1", - gt4py_backend=self.dxa.gt4py_backend, + backend=self.dxa.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1630,7 +1630,7 @@ def rdya(self) -> Quantity: origin=self.dya.origin, extent=self.dya.extent, units="m^-1", - gt4py_backend=self.dya.gt4py_backend, + backend=self.dya.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1646,7 +1646,7 @@ def rdxc(self) -> Quantity: origin=self.dxc.origin, extent=self.dxc.extent, units="m^-1", - gt4py_backend=self.dxc.gt4py_backend, + backend=self.dxc.backend, number_of_halo_points=N_HALO_DEFAULT, ) @@ -1662,7 +1662,7 @@ def rdyc(self) -> Quantity: origin=self.dyc.origin, extent=self.dyc.extent, units="m^-1", - gt4py_backend=self.dyc.gt4py_backend, + backend=self.dyc.backend, number_of_halo_points=N_HALO_DEFAULT, ) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index dd612b22..d907e49a 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -186,7 +186,7 @@ def p_interface(self) -> Quantity: p_interface_data, dims=[Z_INTERFACE_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._p_interface @@ -203,7 +203,7 @@ def p(self) -> Quantity: p_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.p_interface.gt4py_backend, + backend=self.p_interface.backend, number_of_halo_points=self.p_interface.metadata.n_halo, ) return self._p @@ -220,7 +220,7 @@ def dp(self) -> Quantity: dp_ref_data, dims=[Z_DIM], units="Pa", - gt4py_backend=self.ak.gt4py_backend, + backend=self.ak.backend, number_of_halo_points=self.ak.metadata.n_halo, ) return self._dp_ref @@ -230,7 +230,7 @@ def ptop(self) -> Float: """Top of atmosphere pressure (Pa)""" if self.bk.view[0] != 0: raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0") - if self.ak.gt4py_backend is not None and is_gpu_backend(self.ak.gt4py_backend): + if self.ak.backend is not None and is_gpu_backend(self.ak.backend): return Float(self.ak.view[0].get()) else: return Float(self.ak.view[0]) @@ -382,7 +382,7 @@ def _fC_from_data(data, lat: Quantity) -> Quantity: dims=lat.dims, origin=lat.origin, extent=lat.extent, - gt4py_backend=lat.gt4py_backend, + backend=lat.backend, number_of_halo_points=lat.metadata.n_halo, ) @@ -824,7 +824,7 @@ def split_quantity_along_last_dim(quantity: Quantity) -> list[Quantity]: units=quantity.units, origin=quantity.origin[:-1], extent=quantity.extent[:-1], - gt4py_backend=quantity.gt4py_backend, + backend=quantity.backend, number_of_halo_points=quantity.metadata.n_halo, ) ) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index a8b3efbc..f36cd02e 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -228,7 +228,7 @@ def _allocate( units=units, origin=origin, extent=extent, - gt4py_backend=self.backend, + backend=self.backend, allow_mismatch_float_precision=allow_mismatch_float_precision, number_of_halo_points=self.sizer.n_halo, ) diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index a9e34772..af75242d 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +import warnings +from collections.abc import Sequence +from typing import Any import dace import numpy as np @@ -25,15 +27,25 @@ def __init__( extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, + backend: str | None = None, ): + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + super().__init__( data, dims, units, origin=origin, extent=extent, - gt4py_backend=gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=backend, ) def __descriptor__(self) -> Any: diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 7e7b4f16..45a14445 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -30,7 +30,9 @@ class QuantityMetadata: dtype: type "dtype of the data in the ndarray-like object" gt4py_backend: str | None = None - "backend to use for gt4py storages" + "Deprecated. Use backend instead." + backend: str | None = None + "GT4Py backend name. Used for performance optimal data allocation." @property def dim_lengths(self) -> dict[str, int]: @@ -57,6 +59,7 @@ def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None: metadata_copy.data_type = self.data_type metadata_copy.dtype = self.dtype metadata_copy.gt4py_backend = self.gt4py_backend + metadata_copy.backend = self.backend @dataclasses.dataclass diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 3263971a..4b902a72 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -38,28 +38,38 @@ def __init__( gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, number_of_halo_points: int = 0, + backend: str | None = None, ): """Initialize a Quantity. Args: - data (_type_): ndarray-like object containing the underlying data - dims (Sequence[str]): dimension names for each axis - units (str): units of the quantity - origin (Sequence[int] | None, optional): first point in data within the - computational domain. Defaults to None. - extent (Sequence[int] | None, optional): number of points along each axis + data: ndarray-like object containing the underlying data + dims: dimension names for each axis + units: units of the quantity + origin: first point in data within the + computational domain. Defaults to None. + extent: number of points along each axis within the computational domain. Defaults to None. - gt4py_backend (str | None, optional): backend to use for gt4py storages, - if not given this will be derived from a Storage - if given as the data argument. Defaults to None. - allow_mismatch_float_precision (bool, optional): allow for precision that is + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is not the simulation-wide default configuration. Defaults to False. - number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + number_of_halo_points: Number of halo points used. Defaults to 0. + backend: GT4Py backend name. If given, we check that the data is + allocated in a performance optimal way for that backend. Raises: ValueError: Data-type mismatch between configuration and input-data TypeError: Typing of the data that does not fit """ + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -81,6 +91,11 @@ def __init__( if isinstance(data, (int, float, list)): # If converting basic data, use a numpy ndarray. + warnings.warn( + "Usage of basic data in Quantities is deprecated. Please use it with a numpy or cuppy ndarray instead.", + DeprecationWarning, + stacklevel=2, + ) data = np.asarray(data) if not isinstance(data, (np.ndarray, cupy.ndarray)): @@ -88,8 +103,10 @@ def __init__( f"Only supports numpy.ndarray and cupy.ndarray, got {type(data)}" ) - if gt4py_backend is not None: - gt4py_backend_cls = gt_backend.from_name(gt4py_backend) + _validate_quantity_property_lengths(data.shape, dims, origin, extent) + + if backend is not None: + gt4py_backend_cls = gt_backend.from_name(backend) is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"] dimensions: tuple[str | int, ...] = tuple( @@ -105,21 +122,25 @@ def __init__( ] ) - self._data = ( - data - if is_optimal_layout(data, dimensions) - else self._initialize_data( + if is_optimal_layout(data, dimensions): + self._data = data + else: + warnings.warn( + f"Suboptimal data layout found. Copying data to optimally align for backend '{backend}'.", + UserWarning, + stacklevel=2, + ) + self._data = gt_storage.from_array( data, - origin=origin, - gt4py_backend=gt4py_backend, + data.dtype, + backend=backend, + aligned_index=origin, dimensions=dimensions, ) - ) else: - # We have no info about the gt4py_backend, so just assign it. + # We have no info about the gt4py backend, so just assign it. self._data = data - _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), @@ -128,7 +149,8 @@ def __init__( units=units, data_type=type(self._data), dtype=data.dtype, - gt4py_backend=gt4py_backend, + backend=backend, + gt4py_backend=backend, ) self._attrs = {} # type: ignore[var-annotated] self._compute_domain_view = BoundedArrayView( @@ -139,10 +161,12 @@ def __init__( def from_data_array( cls, data_array: xr.DataArray, + *, origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, number_of_halo_points: int = 0, + backend: str | None = None, ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -151,12 +175,25 @@ def from_data_array( data_array origin: first point in data within the computational domain extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception + gt4py_backend: deprecated, use `backend` instead. + allow_mismatch_float_precision: allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points: Number of halo points used. Defaults to 0. + backend: GT4Py backend name. If given, we check that the data is + allocated in a performance optimal way for that backend. """ if "units" not in data_array.attrs: raise ValueError("need units attribute to create Quantity from DataArray") + + if gt4py_backend is not None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) + if backend is None: + backend = gt4py_backend + return cls( data_array.values, cast(tuple[str], data_array.dims), @@ -164,7 +201,7 @@ def from_data_array( origin=origin, extent=extent, number_of_halo_points=number_of_halo_points, - gt4py_backend=gt4py_backend, + backend=backend, ) def to_netcdf( @@ -222,17 +259,6 @@ def sel(self, **kwargs: slice | int) -> np.ndarray: """ return self.view[tuple(kwargs.get(dim, slice(None, None)) for dim in self.dims)] - def _initialize_data(self, data, origin, gt4py_backend: str, dimensions: tuple): # type: ignore - """Allocates an ndarray with optimal memory layout, and copies the data over.""" - storage = gt_storage.from_array( - data, - data.dtype, - backend=gt4py_backend, - aligned_index=origin, - dimensions=dimensions, - ) - return storage - @property def metadata(self) -> QuantityMetadata: return self._metadata @@ -244,8 +270,17 @@ def units(self) -> str: @property def gt4py_backend(self) -> str | None: + warnings.warn( + "gt4py_backend is deprecated. Use `backend` instead.", + DeprecationWarning, + stacklevel=2, + ) return self.metadata.gt4py_backend + @property + def backend(self) -> str | None: + return self.metadata.backend + @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -386,8 +421,8 @@ def transpose( units=self.units, origin=_transpose_sequence(self.origin, transpose_order), extent=_transpose_sequence(self.extent, transpose_order), - gt4py_backend=self.gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, + backend=self.backend, ) transposed._attrs = self._attrs return transposed diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 6cab1023..52b02dad 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -58,7 +58,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="Some 1D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5)], dtype=Float) @@ -68,7 +68,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -78,7 +78,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) @@ -98,7 +98,7 @@ def test_all_reduce(communicator): data=base_array, dims=["K"], units="New 1D unit", - gt4py_backend=backend, + backend=backend, origin=(8,), extent=(7,), ) @@ -110,7 +110,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J"], units="Some 2D unit", - gt4py_backend=backend, + backend=backend, ) base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) @@ -120,7 +120,7 @@ def test_all_reduce(communicator): data=base_array, dims=["I", "J", "K"], units="Some 3D unit", - gt4py_backend=backend, + backend=backend, ) communicator.all_reduce( testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out diff --git a/tests/quantity/test_local.py b/tests/quantity/test_local.py new file mode 100644 index 00000000..859bb009 --- /dev/null +++ b/tests/quantity/test_local.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +from ndsl import Local + + +def test_local_descriptor_is_transient() -> None: + nx = 5 + shape = (nx,) + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend="debug", + ) + array = local.__descriptor__() + assert array.transient + + +def test_local_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + local = Local( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert local.backend == backend + + # make sure we are backwards compatible (for now) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert local.gt4py_backend == backend diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index dccfa94f..3286ce0f 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from ndsl import Quantity from ndsl.quantity.bounds import _shift_slice @@ -289,3 +290,69 @@ def test_data_setter(): # Expected fail: new array is not even an array with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): quantity.data = "meh" + + +def test_constructor_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + quantity = Quantity( + data=np.empty(shape), + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we are backwards compatible (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_from_data_array_with_gt4py_backend_is_deprecated() -> None: + nx = 5 + shape = (nx,) + backend = "debug" + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + np_data = np.empty(shape) + data_array = xr.DataArray(data=np_data, attrs={"units": "n/a"}) + quantity = Quantity.from_data_array( + data_array, + origin=(0,), + extent=(nx,), + number_of_halo_points=0, + gt4py_backend=backend, + ) + + # make sure we assign backend + assert quantity.backend == backend + + # make sure we don't assign gt4py_backend anymore (on the QuantityMetadata) + with pytest.deprecated_call(match="gt4py_backend is deprecated"): + assert quantity.gt4py_backend == backend + + +def test_assign_basic_data_is_deprecated() -> None: + nx = 5 + backend = "debug" + with pytest.deprecated_call( + match="Usage of basic data in Quantities is deprecated" + ): + quantity = Quantity( + data=[0, 1, 2, 3, 4], + origin=(0,), + extent=(nx,), + dims=("dim_X",), + units="n/a", + backend=backend, + allow_mismatch_float_precision=True, + ) + + # make sure we can still use it (for now) + for i in range(5): + assert quantity.data[i] == i diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index bc39f61f..7fbd1a04 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -72,7 +72,7 @@ def test_modifying_numpy_data_modifies_view_and_field(): extent=shape, dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) quantity.data[0, 0] = 1 @@ -99,7 +99,7 @@ def test_data_and_field_access_right_full_array_and_compute_domain(): extent=(5, 5), dims=["dim1", "dim2"], units="units", - gt4py_backend="numpy", + backend="numpy", ) assert np.all(quantity.data == 0) # Write compute domain - test data is written with the offset @@ -139,7 +139,7 @@ def test_accessing_data_does_not_break_view( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] @@ -158,6 +158,6 @@ def test_numpy_data_becomes_cupy_with_gpu_backend( extent=extent, dims=dims, units=units, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, ) assert isinstance(quantity.data, cp.ndarray) diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index 745a8cd9..88653676 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -165,7 +165,13 @@ def param_product(*param_lists): ) @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_transpose( - quantity, target_dims, final_data, final_dims, final_origin, final_extent, numpy + quantity: Quantity, + target_dims, + final_data, + final_dims, + final_origin, + final_extent, + numpy, ): result = quantity.transpose(target_dims) numpy.testing.assert_array_equal(result.data, final_data) @@ -173,7 +179,7 @@ def test_transpose( assert result.origin == final_origin assert result.extent == final_extent assert result.units == quantity.units - assert result.gt4py_backend == quantity.gt4py_backend + assert result.backend == quantity.backend @pytest.mark.parametrize( diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index ed647e2c..2d34567a 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -175,7 +175,7 @@ def quantity(dims, units, origin, extent, shape, dtype, gt4py_backend): units=units, origin=origin, extent=extent, - gt4py_backend=gt4py_backend, + backend=gt4py_backend, )