diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index a4c99c8a..e466cd9e 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -210,6 +210,7 @@ def _allocate( extent=extent, gt4py_backend=self._backend(), allow_mismatch_float_precision=allow_mismatch_float_precision, + number_of_halo_points=self.sizer.n_halo, ) def get_quantity_halo_spec( diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 8383ffe9..7e7b4f16 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -19,6 +19,8 @@ class QuantityMetadata: "the start of the computational domain" extent: tuple[int, ...] "the shape of the computational domain" + n_halo: int + "Number of halo-points used in the horizontal" dims: tuple[str, ...] "names of each dimension" units: str diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 6eba3572..bffaf884 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -36,22 +36,29 @@ def __init__( extent: Sequence[int] | None = None, gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, + number_of_halo_points: int = 0, ): - """ - Initialize a Quantity. + """Initialize a Quantity. Args: - data: ndarray-like object containing the underlying data - dims: dimension names for each axis - units: units of the quantity - origin: first point in data within the computational domain - extent: number of points along each axis within the computational domain - gt4py_backend: backend to use for gt4py storages, if not given this will - be derived from a Storage if given as the data argument, otherwise the - storage attribute is disabled and will raise an exception. Will raise - a TypeError if this is given with a gt4py storage type as data - """ + data (_type_): ndarray-like object containing the underlying data + dims (Sequence[str]): dimension names for each axis + units (str): units of the quantity + origin (Sequence[int] | None, optional): first point in data within the + computational domain. Defaults to None. + extent (Sequence[int] | None, optional): number of points along each axis + within the computational domain. Defaults to None. + gt4py_backend (str | None, optional): backend to use for gt4py storages, + if not given this will be derived from a Storage + if given as the data argument. Defaults to None. + allow_mismatch_float_precision (bool, optional): allow for precision that is + not the simulation-wide default configuration. Defaults to False. + number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + Raises: + ValueError: Data-type mismatch between configuration and input-data + TypeError: Typing of the data that does not fit + """ if ( not allow_mismatch_float_precision and is_float(data.dtype) @@ -108,8 +115,6 @@ def __init__( ) ) else: - if data is None: - raise TypeError("requires 'data' to be passed") # We have no info about the gt4py_backend, so just assign it. self._data = data @@ -117,6 +122,7 @@ def __init__( self._metadata = QuantityMetadata( origin=_ensure_int_tuple(origin, "origin"), extent=_ensure_int_tuple(extent, "extent"), + n_halo=number_of_halo_points, dims=tuple(dims), units=units, data_type=type(self._data), @@ -135,6 +141,7 @@ def from_data_array( origin: Sequence[int] | None = None, extent: Sequence[int] | None = None, gt4py_backend: str | None = None, + number_of_halo_points: int = 0, ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -155,6 +162,7 @@ def from_data_array( data_array.attrs["units"], origin=origin, extent=extent, + number_of_halo_points=number_of_halo_points, gt4py_backend=gt4py_backend, ) @@ -172,6 +180,13 @@ def to_netcdf( ) def halo_spec(self, n_halo: int) -> QuantityHaloSpec: + # This is a preliminary check to see if this is ever triggered. + # If not, we can remove it down the line and change the call signature. + if n_halo != self._metadata.n_halo: + warnings.warn( + "Found inconsistency with number of halo points in Quantity:" + + f"{n_halo} vs {self._metadata.n_halo}" + ) return QuantityHaloSpec( n_halo, self.data.strides,