From febfda9ecf3fe1a78717fda267cbfa11ba4692a8 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:27:08 -0400 Subject: [PATCH 1/8] Pass `dtype` down in allocator utils (gt4py_utils) --- ndsl/dsl/gt4py_utils.py | 48 +++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 31c60ca7..2b2d72a3 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import gt4py import numpy as np @@ -151,16 +151,42 @@ def make_storage_data( if n_dims == 1: data = _make_storage_data_1d( - data, shape, start, dummy, axis, read_only, backend=backend + data, + shape, + start, + dummy, + axis, + read_only, + dtype=dtype, + backend=backend, ) elif n_dims == 2: data = _make_storage_data_2d( - data, shape, start, dummy, axis, read_only, backend=backend + data, + shape, + start, + dummy, + axis, + read_only, + dtype=dtype, + backend=backend, ) elif n_dims >= 4: - data = _make_storage_data_Nd(data, shape, start, backend=backend) + data = _make_storage_data_Nd( + data, + shape, + start, + dtype=dtype, + backend=backend, + ) else: - data = _make_storage_data_3d(data, shape, start, backend=backend) + data = _make_storage_data_3d( + data, + shape, + start, + dtype=dtype, + backend=backend, + ) storage = gt4py.storage.from_array( data, @@ -180,11 +206,12 @@ def _make_storage_data_1d( axis: int = 2, read_only: bool = True, *, + dtype: DTypes = Float, backend: str, ) -> Field: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) - buffer = zeros(shape[axis], backend=backend) + buffer = zeros(shape[axis], dtype=dtype, backend=backend) if dummy: axis = list(set((0, 1, 2)).difference(dummy))[0] @@ -216,6 +243,7 @@ def _make_storage_data_2d( axis: int = 2, read_only: bool = True, *, + dtype: DTypes = Float, backend: str, ) -> Field: # axis refers to which axis should be repeated (when making a full 3d data), @@ -229,7 +257,7 @@ def _make_storage_data_2d( start1, start2 = start[0:2] size1, size2 = data.shape - buffer = zeros(shape2d, backend=backend) + buffer = zeros(shape2d, dtype=dtype, backend=backend) buffer[start1 : start1 + size1, start2 : start2 + size2] = asarray( data, type(buffer) ) @@ -249,11 +277,12 @@ def _make_storage_data_3d( shape: Tuple[int, ...], start: Tuple[int, ...] = (0, 0, 0), *, + dtype: DTypes = Float, backend: str, ) -> Field: istart, jstart, kstart = start isize, jsize, ksize = data.shape - buffer = zeros(shape, backend=backend) + buffer = zeros(shape, dtype=dtype, backend=backend) buffer[ istart : istart + isize, jstart : jstart + jsize, @@ -267,11 +296,12 @@ def _make_storage_data_Nd( shape: Tuple[int, ...], start: Tuple[int, ...] = None, *, + dtype: DTypes = Float, backend: str, ) -> Field: if start is None: start = tuple([0] * data.ndim) - buffer = zeros(shape, backend=backend) + buffer = zeros(shape, dtype=dtype, backend=backend) idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))]) buffer[idx] = asarray(data, type(buffer)) return buffer From 4112225436002fbec18e816d5565bb4a6b8acc82 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:28:07 -0400 Subject: [PATCH 2/8] Allow coriolis forces to be read in --- ndsl/grid/helper.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 6cb6a374..81092f85 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -332,13 +332,21 @@ def __init__( vertical_data: VerticalGridData, contravariant_data: ContravariantGridData, angle_data: AngleGridData, + fc=None, + fc_agrid=None, ): self._horizontal_data = horizontal_data self._vertical_data = vertical_data self._contravariant_data = contravariant_data self._angle_data = angle_data - self._fC = None - self._fC_agrid = None + if fc is not None: + self._fC = GridData._fC_from_data(fc, horizontal_data.lat) + else: + self._fc = None + if fc_agrid is not None: + self._fC_agrid = GridData._fC_from_data(fc_agrid, horizontal_data.lat) + else: + self._fC_agrid = None @classmethod def new_from_metric_terms(cls, metric_terms: MetricTerms): @@ -369,9 +377,7 @@ def lat_agrid(self) -> Quantity: return self._horizontal_data.lat_agrid @staticmethod - def _fC_from_lat(lat: Quantity) -> Quantity: - np = lat.np - data = 2.0 * constants.OMEGA * np.sin(lat.data) + def _fC_from_data(data, lat: Quantity) -> Quantity: return Quantity( data, units="1/s", @@ -381,6 +387,12 @@ def _fC_from_lat(lat: Quantity) -> Quantity: gt4py_backend=lat.gt4py_backend, ) + @staticmethod + def _fC_from_lat(lat: Quantity) -> Quantity: + np = lat.np + data = Float(2.0) * constants.OMEGA * np.sin(lat.data, dtype=Float) + return GridData._fC_from_data(data, lat) + @property def fC(self): """Coriolis parameter at cell corners""" From c710d225a9739986f3582fe646e9697156c7f0dc Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:28:22 -0400 Subject: [PATCH 3/8] Edge factors are always 64-bit --- ndsl/grid/generation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 172f4d53..c32ceb3f 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -3331,12 +3331,12 @@ def _calculate_edge_factors(self): self._np, ) - edge_w = quantity_cast_to_model_float(self.quantity_factory, edge_w_64) - edge_e = quantity_cast_to_model_float(self.quantity_factory, edge_e_64) - edge_s = quantity_cast_to_model_float(self.quantity_factory, edge_s_64) - edge_n = quantity_cast_to_model_float(self.quantity_factory, edge_n_64) - - return edge_w, edge_e, edge_s, edge_n + return ( + edge_w_64, + edge_e_64, + edge_s_64, + edge_n_64, + ) def _calculate_edge_a2c_vect_factors(self): edge_vect_s_64 = self.quantity_factory.zeros( From 322f25e6fa9c1dd48ec5fbc4fb1921921f2b77e9 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:31:54 -0400 Subject: [PATCH 4/8] Quantity QOL --- ndsl/quantity/quantity.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 33d72a44..4f80fff1 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -153,9 +153,9 @@ def from_data_array( gt4py_backend=gt4py_backend, ) - def to_netcdf(self, name: str, rank: int = -1) -> None: + def to_netcdf(self, path: str, name="var", rank: int = -1) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: - self.data_array.to_netcdf(f"{name}__r{rank}.nc4") + self.data_array.to_dataset(name=name).to_netcdf(f"{path}__r{rank}.nc4") def halo_spec(self, n_halo: int) -> QuantityHaloSpec: return QuantityHaloSpec( @@ -260,8 +260,16 @@ def extent(self) -> Tuple[int, ...]: return self.metadata.extent @property - def data_array(self) -> xr.DataArray: - return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) + def data_array(self, full_data=False) -> xr.DataArray: + """Returns an Xarray.DataArray of the view (domain) + + Args: + full_data: Return the entire data (halo included) instead of the view + """ + if full_data: + return xr.DataArray(self.data[:], dims=self.dims, attrs=self.attrs) + else: + return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) @property def np(self) -> NumpyModule: From 7861f94d278b80ca2184b6e129f8fbc13c42cd5f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:32:01 -0400 Subject: [PATCH 5/8] Make sure to pass `dtype` to load the grid cleanly --- ndsl/stencils/testing/translate.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index f2d71e51..8132d290 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -384,6 +384,7 @@ def _edge_vector_storage(self, varname, axis, max_shape): shape=buffer.shape, backend=self.backend, mask=mask, + dtype=self.data[varname].dtype, ) def _make_composite_vvar_storage(self, varname, data3d, shape): @@ -397,6 +398,7 @@ def _make_composite_vvar_storage(self, varname, data3d, shape): shape=buffer.shape, origin=(1, 1, 0), backend=self.backend, + dtype=self.data[varname].dtype, ) def make_grid_storage(self, pygrid): @@ -418,6 +420,7 @@ def make_grid_storage(self, pygrid): (shape[0], shape[1], 3), origin=(0, 0, 0), backend=self.backend, + dtype=self.data[key].dtype, ) for key, axis in TranslateGrid.edge_var_axis.items(): if key in self.data: @@ -428,6 +431,7 @@ def make_grid_storage(self, pygrid): axis=axis, read_only=True, backend=self.backend, + dtype=self.data[key].dtype, ) for key, axis in TranslateGrid.edge_vect_axis.items(): if key in self.data: @@ -451,6 +455,7 @@ def make_grid_storage(self, pygrid): start=origin, read_only=True, backend=self.backend, + dtype=value.dtype, ) def python_grid(self): From ca7bbe26098c76a0d84e7c1579d2defcc73a4910 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:34:00 -0400 Subject: [PATCH 6/8] Translate grid: load coriolis forces, area 64 is 64-bit --- ndsl/stencils/testing/grid.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index fbee10a5..6d4b4b38 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -7,7 +7,6 @@ from ndsl.constants import N_HALO_DEFAULT, X_DIM, Y_DIM, Z_DIM from ndsl.dsl import gt4py_utils as utils from ndsl.dsl.stencil import GridIndexing -from ndsl.dsl.typing import Float from ndsl.grid.generation import GridDefinitions from ndsl.grid.helper import ( AngleGridData, @@ -505,7 +504,12 @@ def grid_data(self) -> "GridData": data = getattr(self, name) assert data is not None - quantity = self.quantity_factory.zeros(dims=dims, units=units, dtype=Float) + quantity = self.quantity_factory.zeros( + dims=dims, + units=units, + dtype=data.dtype, + allow_mismatch_float_precision=True, + ) if len(quantity.shape) == 3: quantity.data[:] = data[:, :, : quantity.shape[2]] elif len(quantity.shape) == 2: @@ -549,6 +553,7 @@ def grid_data(self) -> "GridData": data=self.area_64, dims=GridDefinitions.area.dims, units=GridDefinitions.area.units, + allow_mismatch_float_precision=True, ), rarea=self.quantity_factory.from_array( data=self.rarea, @@ -810,6 +815,8 @@ def grid_data(self) -> "GridData": vertical_data=vertical, contravariant_data=contravariant, angle_data=angle, + fc=self.fC, + fc_agrid=self.f0, ) return self._grid_data From eee4afc86ad4fa0f70d3801166e9e205f62a6b1f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:38:39 -0400 Subject: [PATCH 7/8] Bad merge --- ndsl/dsl/gt4py_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index 2b2d72a3..3e45837b 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import gt4py import numpy as np @@ -140,9 +140,7 @@ def make_storage_data( default_mask = (True, True, False) shape = (1, shape[axis]) else: - default_mask = tuple( - [i == axis for i in range(max_dim)] - ) # type: ignore + default_mask = tuple([i == axis for i in range(max_dim)]) # type: ignore elif dummy or axis != 2: default_mask = (True, True, True) else: From e1053fe0b8ab437df4bd7153bf1dacf87c631056 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 25 Mar 2025 13:46:25 -0400 Subject: [PATCH 8/8] Typo --- ndsl/grid/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 81092f85..1a82d053 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -342,7 +342,7 @@ def __init__( if fc is not None: self._fC = GridData._fC_from_data(fc, horizontal_data.lat) else: - self._fc = None + self._fC = None if fc_agrid is not None: self._fC_agrid = GridData._fC_from_data(fc_agrid, horizontal_data.lat) else: