Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from gt4py import storage as gt_storage
from gt4py.cartesian import backend as gt_backend

from ndsl.constants import N_HALO_DEFAULT
from ndsl.dsl.typing import DTypes, Field, Float
from ndsl.dsl.typing import DTypes, Float
from ndsl.logging import ndsl_log
from ndsl.optional_imports import cupy as cp

Expand Down Expand Up @@ -80,7 +81,7 @@ def _translate_origin(origin: Sequence[int], mask: tuple[bool, ...]) -> Sequence


def make_storage_data(
data: Field,
data: npt.NDArray,
Comment thread
romanc marked this conversation as resolved.
shape: tuple[int, ...] | None = None,
origin: tuple[int, ...] = origin,
*,
Expand All @@ -92,7 +93,7 @@ def make_storage_data(
axis: int = 2,
max_dim: int = 3,
read_only: bool = True,
) -> Field:
) -> npt.NDArray:
"""Create a new gt4py storage from the given data.

Args:
Expand Down Expand Up @@ -197,7 +198,7 @@ def make_storage_data(


def _make_storage_data_1d(
data: Field,
data: npt.NDArray,
shape: tuple[int, ...],
start: tuple[int, ...] = (0, 0, 0),
dummy: tuple[int, ...] | None = None,
Expand All @@ -206,7 +207,7 @@ def _make_storage_data_1d(
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
) -> npt.NDArray:
# axis refers to a repeated axis, dummy refers to a singleton axis
axis = min(axis, len(shape) - 1)
buffer = zeros(shape[axis], dtype=dtype, backend=backend)
Expand Down Expand Up @@ -234,7 +235,7 @@ def _make_storage_data_1d(


def _make_storage_data_2d(
data: Field,
data: npt.NDArray,
shape: tuple[int, ...],
start: tuple[int, ...] = (0, 0, 0),
dummy: tuple[int, ...] | None = None,
Expand All @@ -243,7 +244,7 @@ def _make_storage_data_2d(
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
) -> npt.NDArray:
# axis refers to which axis should be repeated (when making a full 3d data),
# dummy refers to a singleton axis
do_reshape = dummy or axis != 2
Expand Down Expand Up @@ -271,13 +272,13 @@ def _make_storage_data_2d(


def _make_storage_data_3d(
data: Field,
data: npt.NDArray,
shape: tuple[int, ...],
start: tuple[int, ...] = (0, 0, 0),
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
) -> npt.NDArray:
istart, jstart, kstart = start
isize, jsize, ksize = data.shape
buffer = zeros(shape, dtype=dtype, backend=backend)
Expand All @@ -290,13 +291,13 @@ def _make_storage_data_3d(


def _make_storage_data_Nd(
data: Field,
data: npt.NDArray,
shape: tuple[int, ...],
start: tuple[int, ...] | None = None,
*,
dtype: DTypes = Float,
backend: str,
) -> Field:
) -> npt.NDArray:
if start is None:
start = tuple([0] * data.ndim)
buffer = zeros(shape, dtype=dtype, backend=backend)
Expand All @@ -312,7 +313,7 @@ def make_storage_from_shape(
backend: str,
dtype: DTypes = Float,
mask: tuple[bool, ...] | None = None,
) -> Field:
) -> npt.NDArray:
"""Create a new gt4py storage of a given shape filled with zeros.

Args:
Expand Down Expand Up @@ -349,7 +350,7 @@ def make_storage_from_shape(


def make_storage_dict(
data: Field,
data: npt.NDArray,
shape: tuple[int, ...] | None = None,
origin: tuple[int, ...] = origin,
start: tuple[int, ...] = (0, 0, 0),
Expand All @@ -359,11 +360,11 @@ def make_storage_dict(
*,
backend: str,
dtype: DTypes = Float,
) -> dict[str, "Field"]:
) -> dict[str, npt.NDArray]:
assert names is not None, "for 4d variable storages, specify a list of names"
if shape is None:
shape = data.shape
data_dict: dict[str, Field] = dict()
data_dict: dict[str, npt.NDArray] = dict()
for i in range(data.shape[3]):
data_dict[names[i]] = make_storage_data(
squeeze(data[:, :, :, i]),
Expand Down
26 changes: 20 additions & 6 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any

import numpy as np
import numpy.typing as npt

import ndsl.dsl.gt4py_utils as utils
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Field, Float, Int # noqa: F401
from ndsl.optional_imports import cupy as cp
from ndsl.quantity import Quantity
from ndsl.stencils.testing.grid import Grid # type: ignore
Expand Down Expand Up @@ -34,10 +34,10 @@ def as_numpy(
def _convert(value: Quantity | np.ndarray) -> np.ndarray:
if isinstance(value, Quantity):
return value.data
elif cp is not None and isinstance(value, cp.ndarray):
return cp.asnumpy(value)
elif isinstance(value, np.ndarray):
return value
elif cp is not None and isinstance(value, cp.ndarray):
return cp.asnumpy(value)
else:
raise TypeError(f"Unrecognized value type: {type(value)}")

Expand Down Expand Up @@ -135,7 +135,7 @@ def make_storage_data(
names_4d: list[str] | None = None,
read_only: bool = False,
full_shape: bool = False,
) -> Field:
) -> dict[str, npt.NDArray] | npt.NDArray:
"""Copy input data into a gt4py.storage with given shape.

`array` is copied. Takes care of the device upload if necessary.
Expand Down Expand Up @@ -201,7 +201,10 @@ def collect_start_indices(self, datashape, varinfo):
return istart, jstart, kstart

def make_storage_data_input_vars(
self, inputs, storage_vars=None, dict_4d=True
self,
inputs,
storage_vars=None,
dict_4d=True,
) -> None:
"""From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionary to update inputs to
their configured shape.
Expand Down Expand Up @@ -292,7 +295,18 @@ def slice_output(self, inputs, out_data=None) -> dict[str, Any]:
)
out[serialname] = var4d
else:
slice_tuple = self.grid.slice_dict(ds, len(data_result.shape))
# Get slice for data dimensions (after original 3D)
if len(data_result.shape) > 3:
data_dims_slice = tuple(
[slice(0, ddim_end) for ddim_end in data_result.shape[3:]]
)
else:
data_dims_slice = ()
# Slice combine the expected cartesian and data_dims
cartesian_slice = self.grid.slice_dict(
ds, min(len(data_result.shape), 3)
)
slice_tuple = cartesian_slice + data_dims_slice
out[serialname] = np.squeeze(data_result[slice_tuple])
if "kaxis" in info:
out[serialname] = np.moveaxis(out[serialname], 2, info["kaxis"])
Expand Down