diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index f1facfe4..acfee07f 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -53,11 +53,14 @@ def wrapper(*args, **kwargs) -> Any: def _mask_to_dimensions( mask: Tuple[bool, ...], shape: Sequence[int] ) -> List[Union[str, int]]: - assert len(mask) == 3 + assert len(mask) >= 3 dimensions: List[Union[str, int]] = [] for i, axis in enumerate(("I", "J", "K")): if mask[i]: dimensions.append(axis) + if len(mask) > 3: + for i in range(3, len(mask)): + dimensions.append(str(shape[i])) offset = int(sum(mask)) dimensions.extend(shape[offset:]) return dimensions @@ -154,6 +157,8 @@ def make_storage_data( data = _make_storage_data_2d( data, shape, start, dummy, axis, read_only, backend=backend ) + elif n_dims >= 4: + data = _make_storage_data_Nd(data, shape, start, backend=backend) else: data = _make_storage_data_3d(data, shape, start, backend=backend) @@ -257,6 +262,21 @@ def _make_storage_data_3d( return buffer +def _make_storage_data_Nd( + data: Field, + shape: Tuple[int, ...], + start: Tuple[int, ...] = None, + *, + backend: str, +) -> Field: + if start is None: + start = tuple([0] * data.ndim) + buffer = zeros(shape, backend=backend) + idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))]) + buffer[idx] = asarray(data, type(buffer)) + return buffer + + def make_storage_from_shape( shape: Tuple[int, ...], origin: Tuple[int, ...] = origin, @@ -310,6 +330,7 @@ def make_storage_dict( axis: int = 2, *, backend: str, + dtype: DTypes = Float, ) -> Dict[str, "Field"]: assert names is not None, "for 4d variable storages, specify a list of names" if shape is None: @@ -324,6 +345,7 @@ def make_storage_dict( dummy=dummy, axis=axis, backend=backend, + dtype=dtype, ) return data_dict diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 3d9d20f9..2584deb1 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -5,7 +5,7 @@ import ndsl.dsl.gt4py_utils as utils from ndsl.dsl.stencil import StencilFactory -from ndsl.dsl.typing import Field, Float # noqa: F401 +from ndsl.dsl.typing import Field, Float, Int # noqa: F401 from ndsl.quantity import Quantity from ndsl.stencils.testing.grid import Grid # type: ignore @@ -113,6 +113,12 @@ def make_storage_data( elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1: use_shape[1] = 1 start = (int(istart), int(jstart), int(kstart)) + if "float" in str(array.dtype): + dtype = Float + elif "int" in str(array.dtype): + dtype = Int + else: + dtype = array.dtype if names_4d: return utils.make_storage_dict( array, @@ -123,8 +129,12 @@ def make_storage_data( axis=axis, names=names_4d, backend=self.stencil_factory.backend, + dtype=dtype, ) else: + if len(array.shape) == 4: + start = (int(istart), int(jstart), int(kstart), 0) # type: ignore + use_shape.append(array.shape[-1]) return utils.make_storage_data( array, tuple(use_shape), @@ -134,6 +144,7 @@ def make_storage_data( axis=axis, read_only=read_only, backend=self.stencil_factory.backend, + dtype=dtype, ) def storage_vars(self): @@ -159,7 +170,7 @@ def collect_start_indices(self, datashape, varinfo): kstart = self.get_index_from_info(varinfo, "kstart", 0) return istart, jstart, kstart - def make_storage_data_input_vars(self, inputs, storage_vars=None): + def make_storage_data_input_vars(self, inputs, storage_vars=None, dict_4d=True): inputs_in = {**inputs} inputs_out = {} if storage_vars is None: @@ -185,7 +196,7 @@ def make_storage_data_input_vars(self, inputs, storage_vars=None): ) names_4d = None - if len(inputs_in[serialname].shape) == 4: + if (len(inputs_in[serialname].shape) == 4) and dict_4d: names_4d = info.get("names_4d", utils.tracer_variables) dummy_axes = info.get("dummy_axes", None)