Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Callable, Sequence
from functools import wraps
from typing import Any
Expand All @@ -6,9 +7,10 @@
import numpy.typing as npt
from gt4py import storage as gt_storage

from ndsl import xumpy
from ndsl.config.backend import Backend
from ndsl.constants import N_HALO_DEFAULT
from ndsl.dsl.typing import DTypes, Float
from ndsl.dsl.typing import Float
from ndsl.logging import ndsl_log
from ndsl.optional_imports import cupy as cp

Expand Down Expand Up @@ -49,19 +51,16 @@ def wrapper(*args, **kwargs) -> Any:
return inner


def _mask_to_dimensions(
mask: tuple[bool, ...], shape: Sequence[int]
) -> list[str | int]:
def _mask_to_dimensions(mask: tuple[bool, ...], shape: Sequence[int]) -> list[str]:
assert len(mask) >= 3
dimensions: list[str | int] = []
dimensions: list[str] = []
for i, axis in enumerate(("I", "J", "K")):
if mask[i]:
dimensions.append(axis)
if len(mask) > 3:
for i in range(3, len(mask)):
dimensions.append(str(shape[i]))
offset = int(sum(mask))
dimensions.extend(shape[offset:])
if mask[i]:
dimensions.append(str(shape[i]))
return dimensions


Expand All @@ -86,7 +85,7 @@ def make_storage_data(
origin: tuple[int, ...] = origin,
*,
backend: Backend,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
mask: tuple[bool, ...] | None = None,
start: tuple[int, ...] = (0, 0, 0),
dummy: tuple[int, ...] | None = None,
Expand Down Expand Up @@ -205,12 +204,12 @@ def _make_storage_data_1d(
axis: int = 2,
read_only: bool = True,
*,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
backend: Backend,
) -> npt.NDArray:
# axis refers to a repeated axis, dummy refers to a singleton axis
axis = min(axis, len(shape) - 1)
buffer = zeros(shape[axis], dtype=dtype, backend=backend)
buffer = xumpy.zeros(shape[axis], backend, dtype)
if dummy:
axis = list(set((0, 1, 2)).difference(dummy))[0]

Expand Down Expand Up @@ -242,7 +241,7 @@ def _make_storage_data_2d(
axis: int = 2,
read_only: bool = True,
*,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
backend: Backend,
) -> npt.NDArray:
# axis refers to which axis should be repeated (when making a full 3d data),
Expand All @@ -256,7 +255,7 @@ def _make_storage_data_2d(

start1, start2 = start[0:2]
size1, size2 = data.shape
buffer = zeros(shape2d, dtype=dtype, backend=backend)
buffer = xumpy.zeros(shape2d, backend, dtype)
buffer[start1 : start1 + size1, start2 : start2 + size2] = asarray(
data, type(buffer)
)
Expand All @@ -276,12 +275,12 @@ def _make_storage_data_3d(
shape: tuple[int, ...],
start: tuple[int, ...] = (0, 0, 0),
*,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
backend: Backend,
) -> npt.NDArray:
istart, jstart, kstart = start
isize, jsize, ksize = data.shape
buffer = zeros(shape, dtype=dtype, backend=backend)
buffer = xumpy.zeros(shape, backend, dtype)
buffer[
istart : istart + isize,
jstart : jstart + jsize,
Expand All @@ -295,12 +294,12 @@ def _make_storage_data_Nd(
shape: tuple[int, ...],
start: tuple[int, ...] | None = None,
*,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
backend: Backend,
) -> npt.NDArray:
if start is None:
start = tuple([0] * data.ndim)
buffer = zeros(shape, dtype=dtype, backend=backend)
buffer = xumpy.zeros(shape, backend, dtype)
idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))])
buffer[idx] = asarray(data, type(buffer))
return buffer
Expand All @@ -311,7 +310,7 @@ def make_storage_from_shape(
origin: tuple[int, ...] = origin,
*,
backend: Backend,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
mask: tuple[bool, ...] | None = None,
) -> npt.NDArray:
"""Create a new gt4py storage of a given shape filled with zeros.
Expand All @@ -333,12 +332,16 @@ def make_storage_from_shape(
)
3) q_out = utils.make_storage_from_shape(q_in.shape, origin,)
"""
if not mask:
if mask is None:
n_dims = len(shape)
if n_dims == 1:
mask = (False, False, True) # Assume 1D is a k-field
elif n_dims == 2:
mask = (True, True, False) # Assume 2D is an ij-field
elif n_dims < 3:
raise NotImplementedError(f"Unexpected number of dimensions {n_dims}.")
else:
mask = (n_dims * (True,)) + ((3 - n_dims) * (False,))
mask = n_dims * (True,)
storage = gt_storage.zeros(
shape,
dtype,
Expand All @@ -359,7 +362,7 @@ def make_storage_dict(
axis: int = 2,
*,
backend: Backend,
dtype: DTypes = Float,
dtype: npt.DTypeLike = Float,
) -> dict[str, npt.NDArray]:
assert names is not None, "for 4d variable storages, specify a list of names"
if shape is None:
Expand Down Expand Up @@ -447,9 +450,12 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None):


def zeros(shape, dtype=Float, *, backend: Backend):
storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray
xp = cp if cp and storage_type is cp.ndarray else np
return xp.zeros(shape, dtype=dtype)
warnings.warn(
"gt4py_utils.zeros() is deprecated. Use `zeros()` from `ndsl.xumpy` instead.",
category=DeprecationWarning,
stacklevel=2,
)
return xumpy.zeros(shape, backend, dtype)


def sum(array, axis=None, dtype=Float, out=None, keepdims=False):
Expand Down