Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): #
dims=quantity.dims,
units=quantity.units,
allow_mismatch_float_precision=True,
backend=quantity.backend,
)
if recv_state is not None and name in recv_state:
tile_quantity = self.gather(
Expand Down
2 changes: 2 additions & 0 deletions ndsl/monitor/netcdf_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int):
self._dims = initial.dims
self._units = initial.units
self._i_time = 1
self._backend = initial.backend

def append(self, quantity: Quantity) -> None:
# Allow mismatch precision here since this is I/O
Expand All @@ -37,6 +38,7 @@ def data(self) -> Quantity:
dims=("time",) + tuple(self._dims),
units=self._units,
allow_mismatch_float_precision=True,
backend=self._backend,
)


Expand Down
9 changes: 8 additions & 1 deletion ndsl/quantity/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(
dims: Sequence[str],
units: str,
*,
backend: str | None = None,
origin: Sequence[int] | None = None,
extent: Sequence[int] | None = None,
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
backend: str | None = None,
):
if gt4py_backend is not None:
warnings.warn(
Expand All @@ -38,6 +38,13 @@ def __init__(
if backend is None:
backend = gt4py_backend

if backend is None:
warnings.warn(
"`backend` will be a required argument starting with the next version of NDSL.",
DeprecationWarning,
stacklevel=2,
)

super().__init__(
data,
dims,
Expand Down
35 changes: 28 additions & 7 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ def __init__(
dims: Sequence[str],
units: str,
*,
backend: str | None = None,
origin: Sequence[int] | None = None,
extent: Sequence[int] | None = None,
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
number_of_halo_points: int = 0,
backend: str | None = None,
):
"""Initialize a Quantity.

Args:
data: ndarray-like object containing the underlying data
dims: dimension names for each axis
units: units of the quantity
backend: GT4Py backend name. We ensure that the data is allocated in a
performance optimal way for that backend and copy if necessary.
origin: first point in data within the
computational domain. Defaults to None.
extent: number of points along each axis
Expand All @@ -54,8 +56,6 @@ def __init__(
allow_mismatch_float_precision: allow for precision that is
not the simulation-wide default configuration. Defaults to False.
number_of_halo_points: Number of halo points used. Defaults to 0.
backend: GT4Py backend name. If given, we check that the data is
allocated in a performance optimal way for that backend.

Raises:
ValueError: Data-type mismatch between configuration and input-data
Expand All @@ -70,6 +70,13 @@ def __init__(
if backend is None:
backend = gt4py_backend

if backend is None:
warnings.warn(
"`backend` will be a required argument starting with the next version of NDSL.",
DeprecationWarning,
stacklevel=2,
)

if (
not allow_mismatch_float_precision
and is_float(data.dtype)
Expand Down Expand Up @@ -179,8 +186,9 @@ def from_data_array(
allow_mismatch_float_precision: allow for precision that is
not the simulation-wide default configuration. Defaults to False.
number_of_halo_points: Number of halo points used. Defaults to 0.
backend: GT4Py backend name. If given, we check that the data is
allocated in a performance optimal way for that backend.
backend: GT4Py backend name. If given, we allocate data in a performance
optimal way for this backend. Overrides any potentially saved `backend`
in `data.attrs["backend"]`.
"""
if "units" not in data_array.attrs:
raise ValueError("need units attribute to create Quantity from DataArray")
Expand All @@ -201,7 +209,7 @@ def from_data_array(
origin=origin,
extent=extent,
number_of_halo_points=number_of_halo_points,
backend=backend,
backend=_resolve_backend(data_array, backend),
)

def to_netcdf(
Expand Down Expand Up @@ -283,7 +291,7 @@ def backend(self) -> str | None:

@property
def attrs(self) -> dict:
return dict(**self._attrs, units=self._metadata.units)
return dict(**self._attrs, units=self.units, backend=self.backend)

@property
def dims(self) -> tuple[str, ...]:
Expand Down Expand Up @@ -495,3 +503,16 @@ def _ensure_int_tuple(arg: Sequence, arg_name: str) -> tuple:
f"unexpected type {type(item)}"
)
return tuple(return_list)


def _resolve_backend(data: xr.DataArray, backend: str | None) -> str:
if backend is not None:
# Forced backend name takes precedence
return backend

# If backend name was serialized with data, take this one
if "backend" in data.attrs:
return data.attrs["backend"]

# else, fall back to assume python-based layout.
return "debug"
12 changes: 9 additions & 3 deletions tests/dsl/test_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def test_domain_size_comparison(
domain: tuple[int],
call_count: int,
):
quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent)
quantity = Quantity(
np.zeros(extent), dimensions, "n/a", extent=extent, backend="debug"
)
stencil = FrozenStencil(
copy_stencil,
origin=(0, 0, 0),
Expand Down Expand Up @@ -147,7 +149,9 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None:

def test_stencil_2D_temporaries() -> None:
domain = (2, 2, 5)
quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain)
quantity = Quantity(
np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug"
)
stencil = FrozenStencil(
two_dim_temporaries_stencil,
origin=(0, 0, 0),
Expand All @@ -164,7 +168,9 @@ def test_stencil_2D_temporaries() -> None:
)
def test_validation_call_count(iterations: tuple[int]):
domain = (2, 2, 5)
quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain)
quantity = Quantity(
np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, backend="debug"
)
stencil_config = StencilConfig(
compilation_config=CompilationConfig(backend="numpy", rebuild=True)
)
Expand Down
3 changes: 3 additions & 0 deletions tests/quantity/test_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_boundary_data_1_by_1_array_1_halo():
units="m",
origin=(1, 1),
extent=(1, 1),
backend="debug",
)
for side in (
WEST,
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy):
units="m",
origin=(1, 1, 1),
extent=(1, 1, 1),
backend="debug",
)
for side in (
WEST,
Expand Down Expand Up @@ -109,6 +111,7 @@ def test_boundary_data_2_by_2_array_2_halo():
units="m",
origin=(2, 2),
extent=(2, 2),
backend="debug",
)
for side in (
WEST,
Expand Down
3 changes: 3 additions & 0 deletions tests/quantity/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_deepcopy_copy_is_editable_by_view():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
# assertion below is only valid if we're overwriting the entire data through view
Expand All @@ -31,6 +32,7 @@ def test_deepcopy_copy_is_editable_by_data():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
quantity_copy.data[:] = 1.0
Expand All @@ -46,6 +48,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
quantity_copy.data[:] = 1.0
Expand Down
17 changes: 15 additions & 2 deletions tests/quantity/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ndsl import Local


def test_local_descriptor_is_transient() -> None:
def test_dace_data_descriptor_is_transient() -> None:
nx = 5
shape = (nx,)
local = Local(
Expand All @@ -19,7 +19,7 @@ def test_local_descriptor_is_transient() -> None:
assert array.transient


def test_local_gt4py_backend_is_deprecated() -> None:
def test_gt4py_backend_is_deprecated() -> None:
nx = 5
shape = (nx,)
backend = "debug"
Expand All @@ -39,3 +39,16 @@ def test_local_gt4py_backend_is_deprecated() -> None:
# make sure we are backwards compatible (for now)
with pytest.deprecated_call(match="gt4py_backend is deprecated"):
assert local.gt4py_backend == backend


def test_backend_will_be_required() -> None:
nx = 5
shape = (nx,)
with pytest.deprecated_call(match="`backend` will be a required argument"):
local = Local(
data=np.empty(shape),
origin=(0,),
extent=(nx,),
dims=("dim_X",),
units="n/a",
)
77 changes: 59 additions & 18 deletions tests/quantity/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype):

@pytest.fixture
def quantity(data, origin, extent, dims, units):
return Quantity(data, origin=origin, extent=extent, dims=dims, units=units)
return Quantity(
data, origin=origin, extent=extent, dims=dims, units=units, backend="debug"
)


def test_smaller_data_raises(data, origin, extent, dims, units):
Expand All @@ -72,25 +74,55 @@ def test_smaller_data_raises(data, origin, extent, dims, units):
except IndexError:
pass
else:
with pytest.raises(ValueError):
with pytest.raises(
ValueError, match="received .* dimension names for .* dimensions: .*"
):
Quantity(
small_data, origin=origin, extent=extent, dims=dims, units=units
small_data,
origin=origin,
extent=extent,
dims=dims,
units=units,
backend="debug",
)


def test_smaller_dims_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin, extent=extent, dims=dims[:-1], units=units)
with pytest.raises(
ValueError, match="received .* dimension names for .* dimensions: .*"
):
Quantity(
data,
origin=origin,
extent=extent,
dims=dims[:-1],
units=units,
backend="debug",
)


def test_smaller_origin_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin[:-1], extent=extent, dims=dims, units=units)
with pytest.raises(ValueError, match="received .* origins for .* dimensions: .*"):
Quantity(
data,
origin=origin[:-1],
extent=extent,
dims=dims,
units=units,
backend="debug",
)


def test_smaller_extent_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin, extent=extent[:-1], dims=dims, units=units)
with pytest.raises(ValueError, match="received .* extents for .* dimensions: .*"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Quantity(
data,
origin=origin,
extent=extent[:-1],
dims=dims,
units=units,
backend="debug",
)


def test_data_change_affects_quantity(data, quantity, numpy):
Expand Down Expand Up @@ -229,27 +261,23 @@ def test_shift_slice(slice_in, shift, extent, slice_out):
@pytest.mark.parametrize(
"quantity",
[
Quantity(np.array(5), dims=[], units="", backend="debug"),
Quantity(
np.array(5),
dims=[],
units="",
),
Quantity(
np.array([1, 2, 3]),
dims=["dimension"],
units="degK",
np.array([1, 2, 3]), dims=["dimension"], units="degK", backend="debug"
),
Quantity(
np.random.randn(3, 2, 4),
dims=["dim1", "dim_2", "dimension_3"],
units="m",
backend="debug",
),
Quantity(
np.random.randn(8, 6, 6),
dims=["dim1", "dim_2", "dimension_3"],
units="km",
origin=(2, 2, 2),
extent=(4, 2, 2),
backend="debug",
),
],
)
Expand All @@ -265,7 +293,7 @@ def test_to_data_array(quantity):


def test_data_setter():
quantity = Quantity(np.ones((5,)), dims=["dim1"], units="")
quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", backend="debug")

# After allocation - field and data are the same (origin is 0)
assert quantity.data.shape == quantity.field.shape
Expand Down Expand Up @@ -356,3 +384,16 @@ def test_assign_basic_data_is_deprecated() -> None:
# make sure we can still use it (for now)
for i in range(5):
assert quantity.data[i] == i


def test_constructor_backend_will_be_required() -> None:
nx = 5
shape = (nx,)
with pytest.deprecated_call(match="`backend` will be a required argument"):
local = Quantity(
data=np.empty(shape),
origin=(0,),
extent=(nx,),
dims=("dim_X",),
units="n/a",
)
Loading