Skip to content
7 changes: 5 additions & 2 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def pad_field_in_j(field, nj: int, backend: Backend):
def as_numpy(
value: dict[str, Any] | Quantity | np.ndarray,
) -> np.ndarray | dict[str, np.ndarray]:
def _convert(value: Quantity | np.ndarray) -> np.ndarray:
def _convert(value: Any) -> np.ndarray:
if isinstance(value, Quantity):
return value.data
elif isinstance(value, np.ndarray):
Expand Down Expand Up @@ -74,10 +74,13 @@ def __init__(
self.out_vars: dict[str, Any] = {}
self.write_vars: list = []
self.grid = grid
self.maxshape: tuple[int, ...] = grid.domain_shape_full(add=(1, 1, 1))
self.ordered_input_vars = None
self.ignore_near_zero_errors: dict[str, Any] = {}
self.skip_test = skip_test
if self.stencil_factory.backend.is_fortran_aligned():
self.maxshape = self.grid.domain_shape_full()
else:
self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1))

def extra_data_load(self, data_loader: DataLoader):
pass
Expand Down
11 changes: 2 additions & 9 deletions ndsl/xumpy/alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import numpy.typing as npt
from numpy._typing import _SupportsDType

from ndsl.config import Backend
from ndsl.dsl.typing import Float
Expand All @@ -14,12 +13,6 @@

# Taking a page from cupy's playbook to have tuple & ndarray
_ShapeLike = SupportsIndex | Sequence[SupportsIndex]
_DTypeLikeFloat32 = (
np.dtype[np.float32] | _SupportsDType[np.dtype[np.float32]] | type[np.float32]
)
_DTypeLikeFloat64 = (
np.dtype[np.float64] | _SupportsDType[np.dtype[np.float64]] | type[np.float64]
)


def zeros(
Expand Down Expand Up @@ -55,7 +48,7 @@ def empty(
def full(
shape: _ShapeLike,
backend: Backend,
value: np.ScalarType,
value: npt.DTypeLike,
dtype: npt.DTypeLike = Float,
) -> np.ndarray | cp.ndarray:
if backend.is_gpu_backend():
Expand All @@ -66,7 +59,7 @@ def full(
def random(
shape: _ShapeLike,
backend: Backend,
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64 = Float, # type: ignore [valid-type]
dtype: np.floating = Float,
) -> np.ndarray | cp.ndarray:
if backend.is_gpu_backend():
gen = cp.random.default_rng()
Expand Down