Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def make_storage_data(
default_mask = (True, True, False)
shape = (1, shape[axis])
else:
default_mask = tuple(
[i == axis for i in range(max_dim)]
) # type: ignore
default_mask = tuple([i == axis for i in range(max_dim)]) # type: ignore
elif dummy or axis != 2:
default_mask = (True, True, True)
else:
Expand All @@ -151,16 +149,42 @@ def make_storage_data(

if n_dims == 1:
data = _make_storage_data_1d(
data, shape, start, dummy, axis, read_only, backend=backend
data,
shape,
start,
dummy,
axis,
read_only,
dtype=dtype,
backend=backend,
)
elif n_dims == 2:
data = _make_storage_data_2d(
data, shape, start, dummy, axis, read_only, backend=backend
data,
shape,
start,
dummy,
axis,
read_only,
dtype=dtype,
backend=backend,
)
elif n_dims >= 4:
data = _make_storage_data_Nd(data, shape, start, backend=backend)
data = _make_storage_data_Nd(
data,
shape,
start,
dtype=dtype,
backend=backend,
)
else:
data = _make_storage_data_3d(data, shape, start, backend=backend)
data = _make_storage_data_3d(
data,
shape,
start,
dtype=dtype,
backend=backend,
)

storage = gt4py.storage.from_array(
data,
Expand All @@ -180,11 +204,12 @@ def _make_storage_data_1d(
axis: int = 2,
read_only: bool = True,
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
# axis refers to a repeated axis, dummy refers to a singleton axis
axis = min(axis, len(shape) - 1)
buffer = zeros(shape[axis], backend=backend)
buffer = zeros(shape[axis], dtype=dtype, backend=backend)
if dummy:
axis = list(set((0, 1, 2)).difference(dummy))[0]

Expand Down Expand Up @@ -216,6 +241,7 @@ def _make_storage_data_2d(
axis: int = 2,
read_only: bool = True,
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
# axis refers to which axis should be repeated (when making a full 3d data),
Expand All @@ -229,7 +255,7 @@ def _make_storage_data_2d(

start1, start2 = start[0:2]
size1, size2 = data.shape
buffer = zeros(shape2d, backend=backend)
buffer = zeros(shape2d, dtype=dtype, backend=backend)
buffer[start1 : start1 + size1, start2 : start2 + size2] = asarray(
data, type(buffer)
)
Expand All @@ -249,11 +275,12 @@ def _make_storage_data_3d(
shape: Tuple[int, ...],
start: Tuple[int, ...] = (0, 0, 0),
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
istart, jstart, kstart = start
isize, jsize, ksize = data.shape
buffer = zeros(shape, backend=backend)
buffer = zeros(shape, dtype=dtype, backend=backend)
buffer[
istart : istart + isize,
jstart : jstart + jsize,
Expand All @@ -267,11 +294,12 @@ def _make_storage_data_Nd(
shape: Tuple[int, ...],
start: Tuple[int, ...] = None,
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
if start is None:
start = tuple([0] * data.ndim)
buffer = zeros(shape, backend=backend)
buffer = zeros(shape, dtype=dtype, backend=backend)
idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))])
buffer[idx] = asarray(data, type(buffer))
return buffer
Expand Down
12 changes: 6 additions & 6 deletions ndsl/grid/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,12 +3331,12 @@ def _calculate_edge_factors(self):
self._np,
)

edge_w = quantity_cast_to_model_float(self.quantity_factory, edge_w_64)
edge_e = quantity_cast_to_model_float(self.quantity_factory, edge_e_64)
edge_s = quantity_cast_to_model_float(self.quantity_factory, edge_s_64)
edge_n = quantity_cast_to_model_float(self.quantity_factory, edge_n_64)

return edge_w, edge_e, edge_s, edge_n
return (
edge_w_64,
edge_e_64,
edge_s_64,
edge_n_64,
)

def _calculate_edge_a2c_vect_factors(self):
edge_vect_s_64 = self.quantity_factory.zeros(
Expand Down
22 changes: 17 additions & 5 deletions ndsl/grid/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,21 @@ def __init__(
vertical_data: VerticalGridData,
contravariant_data: ContravariantGridData,
angle_data: AngleGridData,
fc=None,
fc_agrid=None,
):
self._horizontal_data = horizontal_data
self._vertical_data = vertical_data
self._contravariant_data = contravariant_data
self._angle_data = angle_data
self._fC = None
self._fC_agrid = None
if fc is not None:
self._fC = GridData._fC_from_data(fc, horizontal_data.lat)
else:
self._fC = None
if fc_agrid is not None:
self._fC_agrid = GridData._fC_from_data(fc_agrid, horizontal_data.lat)
else:
self._fC_agrid = None

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms):
Expand Down Expand Up @@ -369,9 +377,7 @@ def lat_agrid(self) -> Quantity:
return self._horizontal_data.lat_agrid

@staticmethod
def _fC_from_lat(lat: Quantity) -> Quantity:
np = lat.np
data = 2.0 * constants.OMEGA * np.sin(lat.data)
def _fC_from_data(data, lat: Quantity) -> Quantity:
return Quantity(
data,
units="1/s",
Expand All @@ -381,6 +387,12 @@ def _fC_from_lat(lat: Quantity) -> Quantity:
gt4py_backend=lat.gt4py_backend,
)

@staticmethod
def _fC_from_lat(lat: Quantity) -> Quantity:
np = lat.np
data = Float(2.0) * constants.OMEGA * np.sin(lat.data, dtype=Float)
return GridData._fC_from_data(data, lat)

@property
def fC(self):
"""Coriolis parameter at cell corners"""
Expand Down
16 changes: 12 additions & 4 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def from_data_array(
gt4py_backend=gt4py_backend,
)

def to_netcdf(self, name: str, rank: int = -1) -> None:
def to_netcdf(self, path: str, name="var", rank: int = -1) -> None:
if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank:
self.data_array.to_netcdf(f"{name}__r{rank}.nc4")
self.data_array.to_dataset(name=name).to_netcdf(f"{path}__r{rank}.nc4")

def halo_spec(self, n_halo: int) -> QuantityHaloSpec:
return QuantityHaloSpec(
Expand Down Expand Up @@ -260,8 +260,16 @@ def extent(self) -> Tuple[int, ...]:
return self.metadata.extent

@property
def data_array(self) -> xr.DataArray:
return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs)
def data_array(self, full_data=False) -> xr.DataArray:
"""Returns an Xarray.DataArray of the view (domain)

Args:
full_data: Return the entire data (halo included) instead of the view
"""
if full_data:
return xr.DataArray(self.data[:], dims=self.dims, attrs=self.attrs)
else:
return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs)

@property
def np(self) -> NumpyModule:
Expand Down
11 changes: 9 additions & 2 deletions ndsl/stencils/testing/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ndsl.constants import N_HALO_DEFAULT, X_DIM, Y_DIM, Z_DIM
from ndsl.dsl import gt4py_utils as utils
from ndsl.dsl.stencil import GridIndexing
from ndsl.dsl.typing import Float
from ndsl.grid.generation import GridDefinitions
from ndsl.grid.helper import (
AngleGridData,
Expand Down Expand Up @@ -505,7 +504,12 @@ def grid_data(self) -> "GridData":
data = getattr(self, name)
assert data is not None

quantity = self.quantity_factory.zeros(dims=dims, units=units, dtype=Float)
quantity = self.quantity_factory.zeros(
dims=dims,
units=units,
dtype=data.dtype,
allow_mismatch_float_precision=True,
)
if len(quantity.shape) == 3:
quantity.data[:] = data[:, :, : quantity.shape[2]]
elif len(quantity.shape) == 2:
Expand Down Expand Up @@ -549,6 +553,7 @@ def grid_data(self) -> "GridData":
data=self.area_64,
dims=GridDefinitions.area.dims,
units=GridDefinitions.area.units,
allow_mismatch_float_precision=True,
),
rarea=self.quantity_factory.from_array(
data=self.rarea,
Expand Down Expand Up @@ -810,6 +815,8 @@ def grid_data(self) -> "GridData":
vertical_data=vertical,
contravariant_data=contravariant,
angle_data=angle,
fc=self.fC,
fc_agrid=self.f0,
)
return self._grid_data

Expand Down
5 changes: 5 additions & 0 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def _edge_vector_storage(self, varname, axis, max_shape):
shape=buffer.shape,
backend=self.backend,
mask=mask,
dtype=self.data[varname].dtype,
)

def _make_composite_vvar_storage(self, varname, data3d, shape):
Expand All @@ -397,6 +398,7 @@ def _make_composite_vvar_storage(self, varname, data3d, shape):
shape=buffer.shape,
origin=(1, 1, 0),
backend=self.backend,
dtype=self.data[varname].dtype,
)

def make_grid_storage(self, pygrid):
Expand All @@ -418,6 +420,7 @@ def make_grid_storage(self, pygrid):
(shape[0], shape[1], 3),
origin=(0, 0, 0),
backend=self.backend,
dtype=self.data[key].dtype,
)
for key, axis in TranslateGrid.edge_var_axis.items():
if key in self.data:
Expand All @@ -428,6 +431,7 @@ def make_grid_storage(self, pygrid):
axis=axis,
read_only=True,
backend=self.backend,
dtype=self.data[key].dtype,
)
for key, axis in TranslateGrid.edge_vect_axis.items():
if key in self.data:
Expand All @@ -451,6 +455,7 @@ def make_grid_storage(self, pygrid):
start=origin,
read_only=True,
backend=self.backend,
dtype=value.dtype,
)

def python_grid(self):
Expand Down