diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index fe1c5a5d..3b9c5fd9 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -3,11 +3,12 @@ from typing import Any import numpy as np +import numpy.typing as npt from gt4py import storage as gt_storage from gt4py.cartesian import backend as gt_backend from ndsl.constants import N_HALO_DEFAULT -from ndsl.dsl.typing import DTypes, Field, Float +from ndsl.dsl.typing import DTypes, Float from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp @@ -80,7 +81,7 @@ def _translate_origin(origin: Sequence[int], mask: tuple[bool, ...]) -> Sequence def make_storage_data( - data: Field, + data: npt.NDArray, shape: tuple[int, ...] | None = None, origin: tuple[int, ...] = origin, *, @@ -92,7 +93,7 @@ def make_storage_data( axis: int = 2, max_dim: int = 3, read_only: bool = True, -) -> Field: +) -> npt.NDArray: """Create a new gt4py storage from the given data. Args: @@ -197,7 +198,7 @@ def make_storage_data( def _make_storage_data_1d( - data: Field, + data: npt.NDArray, shape: tuple[int, ...], start: tuple[int, ...] = (0, 0, 0), dummy: tuple[int, ...] | None = None, @@ -206,7 +207,7 @@ def _make_storage_data_1d( *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) buffer = zeros(shape[axis], dtype=dtype, backend=backend) @@ -234,7 +235,7 @@ def _make_storage_data_1d( def _make_storage_data_2d( - data: Field, + data: npt.NDArray, shape: tuple[int, ...], start: tuple[int, ...] = (0, 0, 0), dummy: tuple[int, ...] | None = None, @@ -243,7 +244,7 @@ def _make_storage_data_2d( *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: # axis refers to which axis should be repeated (when making a full 3d data), # dummy refers to a singleton axis do_reshape = dummy or axis != 2 @@ -271,13 +272,13 @@ def _make_storage_data_2d( def _make_storage_data_3d( - data: Field, + data: npt.NDArray, shape: tuple[int, ...], start: tuple[int, ...] = (0, 0, 0), *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: istart, jstart, kstart = start isize, jsize, ksize = data.shape buffer = zeros(shape, dtype=dtype, backend=backend) @@ -290,13 +291,13 @@ def _make_storage_data_3d( def _make_storage_data_Nd( - data: Field, + data: npt.NDArray, shape: tuple[int, ...], start: tuple[int, ...] | None = None, *, dtype: DTypes = Float, backend: str, -) -> Field: +) -> npt.NDArray: if start is None: start = tuple([0] * data.ndim) buffer = zeros(shape, dtype=dtype, backend=backend) @@ -312,7 +313,7 @@ def make_storage_from_shape( backend: str, dtype: DTypes = Float, mask: tuple[bool, ...] | None = None, -) -> Field: +) -> npt.NDArray: """Create a new gt4py storage of a given shape filled with zeros. Args: @@ -349,7 +350,7 @@ def make_storage_from_shape( def make_storage_dict( - data: Field, + data: npt.NDArray, shape: tuple[int, ...] | None = None, origin: tuple[int, ...] = origin, start: tuple[int, ...] = (0, 0, 0), @@ -359,11 +360,11 @@ def make_storage_dict( *, backend: str, dtype: DTypes = Float, -) -> dict[str, "Field"]: +) -> dict[str, npt.NDArray]: assert names is not None, "for 4d variable storages, specify a list of names" if shape is None: shape = data.shape - data_dict: dict[str, Field] = dict() + data_dict: dict[str, npt.NDArray] = dict() for i in range(data.shape[3]): data_dict[names[i]] = make_storage_data( squeeze(data[:, :, :, i]), diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 8358c1f9..8cf78282 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -2,10 +2,10 @@ from typing import Any import numpy as np +import numpy.typing as npt import ndsl.dsl.gt4py_utils as utils from ndsl.dsl.stencil import StencilFactory -from ndsl.dsl.typing import Field, Float, Int # noqa: F401 from ndsl.optional_imports import cupy as cp from ndsl.quantity import Quantity from ndsl.stencils.testing.grid import Grid # type: ignore @@ -34,10 +34,10 @@ def as_numpy( def _convert(value: Quantity | np.ndarray) -> np.ndarray: if isinstance(value, Quantity): return value.data - elif cp is not None and isinstance(value, cp.ndarray): - return cp.asnumpy(value) elif isinstance(value, np.ndarray): return value + elif cp is not None and isinstance(value, cp.ndarray): + return cp.asnumpy(value) else: raise TypeError(f"Unrecognized value type: {type(value)}") @@ -135,7 +135,7 @@ def make_storage_data( names_4d: list[str] | None = None, read_only: bool = False, full_shape: bool = False, - ) -> Field: + ) -> dict[str, npt.NDArray] | npt.NDArray: """Copy input data into a gt4py.storage with given shape. `array` is copied. Takes care of the device upload if necessary. @@ -201,7 +201,10 @@ def collect_start_indices(self, datashape, varinfo): return istart, jstart, kstart def make_storage_data_input_vars( - self, inputs, storage_vars=None, dict_4d=True + self, + inputs, + storage_vars=None, + dict_4d=True, ) -> None: """From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionary to update inputs to their configured shape. @@ -292,7 +295,18 @@ def slice_output(self, inputs, out_data=None) -> dict[str, Any]: ) out[serialname] = var4d else: - slice_tuple = self.grid.slice_dict(ds, len(data_result.shape)) + # Get slice for data dimensions (after original 3D) + if len(data_result.shape) > 3: + data_dims_slice = tuple( + [slice(0, ddim_end) for ddim_end in data_result.shape[3:]] + ) + else: + data_dims_slice = () + # Slice combine the expected cartesian and data_dims + cartesian_slice = self.grid.slice_dict( + ds, min(len(data_result.shape), 3) + ) + slice_tuple = cartesian_slice + data_dims_slice out[serialname] = np.squeeze(data_result[slice_tuple]) if "kaxis" in info: out[serialname] = np.moveaxis(out[serialname], 2, info["kaxis"])