From 4f2e148f047443afd73b27361ec13f730d393842 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 15 Oct 2025 13:55:16 +0200 Subject: [PATCH 1/2] store nhalo in the quantity --- ndsl/initialization/allocator.py | 1 + ndsl/quantity/metadata.py | 2 ++ ndsl/quantity/quantity.py | 46 +++++++++++++++++++------------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index a4b55b7f..b16fb1e8 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -209,6 +209,7 @@ def _allocate( extent=extent, gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, + number_of_halo_points=self.sizer.n_halo, ) def get_quantity_halo_spec( diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index d7ddba0f..f5ec484f 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -17,6 +17,8 @@ class QuantityMetadata: "the start of the computational domain" extent: Tuple[int, ...] "the shape of the computational domain" + n_halo: int + "Number of halo-points used in the horizontal" dims: Tuple[str, ...] "names of each dimension" units: str diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 29869829..9fc1a6a4 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast +from typing import Any, Iterable, Sequence, Tuple, Union, cast import dace import matplotlib.pyplot as plt @@ -31,26 +31,31 @@ def __init__( data, dims: Sequence[str], units: str, - origin: Optional[Sequence[int]] = None, - extent: Optional[Sequence[int]] = None, - gt4py_backend: Union[str, None] = None, + origin: Sequence[int] | None = None, + extent: Sequence[int] | None = None, + gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, + number_of_halo_points: int = 0, ): - """ - Initialize a Quantity. + """Initialize a Quantity. Args: - data: ndarray-like object containing the underlying data - dims: dimension names for each axis - units: units of the quantity - origin: first point in data within the computational domain - extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception. Will raise - a TypeError if this is given with a gt4py storage type as data - """ + data (_type_): ndarray-like object containing the underlying data + dims (Sequence[str]): dimension names for each axis + units (str): units of the quantity + origin (Sequence[int] | None, optional): first point in data within the + computational domain. Defaults to None. + extent (Sequence[int] | None, optional): number of points along each axis + within the computational domain. Defaults to None. + gt4py_backend (str | None, optional): _description_. Defaults to None. + allow_mismatch_float_precision (bool, optional): allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + Raises: + ValueError: Data-type mismatch between configuration and input-data + TypeError: Typing of the data that does not fit + """ if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -107,8 +112,6 @@ def __init__( ) ) else: - if data is None: - raise TypeError("requires 'data' to be passed") # We have no info about the gt4py_backend, so just assign it. self._data = data @@ -116,6 +119,7 @@ def __init__( self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), + n_halo=number_of_halo_points, dims=tuple(dims), units=units, data_type=type(self._data), @@ -134,6 +138,7 @@ def from_data_array( origin: Sequence[int] = None, extent: Sequence[int] = None, gt4py_backend: Union[str, None] = None, + number_of_halo_points: int = 0, ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -154,6 +159,7 @@ def from_data_array( data_array.attrs["units"], origin=origin, extent=extent, + number_of_halo_points=number_of_halo_points, gt4py_backend=gt4py_backend, ) @@ -169,6 +175,10 @@ def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> No ) def halo_spec(self, n_halo: int) -> QuantityHaloSpec: + # This is a preliminary check to see if this is ever triggered. + # If not, we can remove it down the line and change the call signature. + if n_halo != self._metadata.n_halo: + warnings.warn("Found inconsistency with number of halo points in Quantity") return QuantityHaloSpec( n_halo, self.data.strides, From f96091da0a9eaf4566283c5c43b70e096aa8bc12 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 15 Oct 2025 16:26:38 +0200 Subject: [PATCH 2/2] reviewer's comments --- ndsl/quantity/quantity.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 9fc1a6a4..464f2534 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -47,7 +47,9 @@ def __init__( computational domain. Defaults to None. extent (Sequence[int] | None, optional): number of points along each axis within the computational domain. Defaults to None. - gt4py_backend (str | None, optional): _description_. Defaults to None. + gt4py_backend (str | None, optional): backend to use for gt4py storages, + if not given this will be derived from a Storage + if given as the data argument. Defaults to None. allow_mismatch_float_precision (bool, optional): allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. @@ -178,7 +180,10 @@ def halo_spec(self, n_halo: int) -> QuantityHaloSpec: # This is a preliminary check to see if this is ever triggered. # If not, we can remove it down the line and change the call signature. if n_halo != self._metadata.n_halo: - warnings.warn("Found inconsistency with number of halo points in Quantity") + warnings.warn( + "Found inconsistency with number of halo points in Quantity:" + + f"{n_halo} vs {self._metadata.n_halo}" + ) return QuantityHaloSpec( n_halo, self.data.strides,