From 6241cbc04a23624e3295a5829454bc5e640dbd3c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 3 Mar 2026 15:19:35 +0100 Subject: [PATCH] refactor: use xumpy for allocation in gt4py_utils --- ndsl/dsl/gt4py_utils.py | 54 +++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index e4dcb532..8f944d0f 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Callable, Sequence from functools import wraps from typing import Any @@ -6,9 +7,10 @@ import numpy.typing as npt from gt4py import storage as gt_storage +from ndsl import xumpy from ndsl.config.backend import Backend from ndsl.constants import N_HALO_DEFAULT -from ndsl.dsl.typing import DTypes, Float +from ndsl.dsl.typing import Float from ndsl.logging import ndsl_log from ndsl.optional_imports import cupy as cp @@ -49,19 +51,16 @@ def wrapper(*args, **kwargs) -> Any: return inner -def _mask_to_dimensions( - mask: tuple[bool, ...], shape: Sequence[int] -) -> list[str | int]: +def _mask_to_dimensions(mask: tuple[bool, ...], shape: Sequence[int]) -> list[str]: assert len(mask) >= 3 - dimensions: list[str | int] = [] + dimensions: list[str] = [] for i, axis in enumerate(("I", "J", "K")): if mask[i]: dimensions.append(axis) if len(mask) > 3: for i in range(3, len(mask)): - dimensions.append(str(shape[i])) - offset = int(sum(mask)) - dimensions.extend(shape[offset:]) + if mask[i]: + dimensions.append(str(shape[i])) return dimensions @@ -86,7 +85,7 @@ def make_storage_data( origin: tuple[int, ...] = origin, *, backend: Backend, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, mask: tuple[bool, ...] | None = None, start: tuple[int, ...] = (0, 0, 0), dummy: tuple[int, ...] | None = None, @@ -205,12 +204,12 @@ def _make_storage_data_1d( axis: int = 2, read_only: bool = True, *, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, backend: Backend, ) -> npt.NDArray: # axis refers to a repeated axis, dummy refers to a singleton axis axis = min(axis, len(shape) - 1) - buffer = zeros(shape[axis], dtype=dtype, backend=backend) + buffer = xumpy.zeros(shape[axis], backend, dtype) if dummy: axis = list(set((0, 1, 2)).difference(dummy))[0] @@ -242,7 +241,7 @@ def _make_storage_data_2d( axis: int = 2, read_only: bool = True, *, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, backend: Backend, ) -> npt.NDArray: # axis refers to which axis should be repeated (when making a full 3d data), @@ -256,7 +255,7 @@ def _make_storage_data_2d( start1, start2 = start[0:2] size1, size2 = data.shape - buffer = zeros(shape2d, dtype=dtype, backend=backend) + buffer = xumpy.zeros(shape2d, backend, dtype) buffer[start1 : start1 + size1, start2 : start2 + size2] = asarray( data, type(buffer) ) @@ -276,12 +275,12 @@ def _make_storage_data_3d( shape: tuple[int, ...], start: tuple[int, ...] = (0, 0, 0), *, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, backend: Backend, ) -> npt.NDArray: istart, jstart, kstart = start isize, jsize, ksize = data.shape - buffer = zeros(shape, dtype=dtype, backend=backend) + buffer = xumpy.zeros(shape, backend, dtype) buffer[ istart : istart + isize, jstart : jstart + jsize, @@ -295,12 +294,12 @@ def _make_storage_data_Nd( shape: tuple[int, ...], start: tuple[int, ...] | None = None, *, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, backend: Backend, ) -> npt.NDArray: if start is None: start = tuple([0] * data.ndim) - buffer = zeros(shape, dtype=dtype, backend=backend) + buffer = xumpy.zeros(shape, backend, dtype) idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))]) buffer[idx] = asarray(data, type(buffer)) return buffer @@ -311,7 +310,7 @@ def make_storage_from_shape( origin: tuple[int, ...] = origin, *, backend: Backend, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, mask: tuple[bool, ...] | None = None, ) -> npt.NDArray: """Create a new gt4py storage of a given shape filled with zeros. @@ -333,12 +332,16 @@ def make_storage_from_shape( ) 3) q_out = utils.make_storage_from_shape(q_in.shape, origin,) """ - if not mask: + if mask is None: n_dims = len(shape) if n_dims == 1: mask = (False, False, True) # Assume 1D is a k-field + elif n_dims == 2: + mask = (True, True, False) # Assume 2D is an ij-field + elif n_dims < 3: + raise NotImplementedError(f"Unexpected number of dimensions {n_dims}.") else: - mask = (n_dims * (True,)) + ((3 - n_dims) * (False,)) + mask = n_dims * (True,) storage = gt_storage.zeros( shape, dtype, @@ -359,7 +362,7 @@ def make_storage_dict( axis: int = 2, *, backend: Backend, - dtype: DTypes = Float, + dtype: npt.DTypeLike = Float, ) -> dict[str, npt.NDArray]: assert names is not None, "for 4d variable storages, specify a list of names" if shape is None: @@ -447,9 +450,12 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None): def zeros(shape, dtype=Float, *, backend: Backend): - storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray - xp = cp if cp and storage_type is cp.ndarray else np - return xp.zeros(shape, dtype=dtype) + warnings.warn( + "gt4py_utils.zeros() is deprecated. Use `zeros()` from `ndsl.xumpy` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + return xumpy.zeros(shape, backend, dtype) def sum(array, axis=None, dtype=Float, out=None, keepdims=False):