diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 66332ec8..a8b3efbc 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -13,26 +13,6 @@ from ndsl.quantity import Quantity, QuantityHaloSpec -class _Allocator: - def __init__(self, backend: str) -> None: - """ - Initialize an object that provides gt4py storage objects for zeros(), ones(), and empty(). - - Args: - backend: GT4Py backend name used for performance-optimized allocation. - """ - self.backend = backend - - def empty(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.empty(*args, backend=self.backend, **kwargs) - - def ones(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.ones(*args, backend=self.backend, **kwargs) - - def zeros(self, *args: Any, **kwargs: Any) -> np.ndarray: - return gt_storage.zeros(*args, backend=self.backend, **kwargs) - - class QuantityFactory: def __init__(self, sizer: GridSizer, *, backend: str) -> None: """ @@ -45,8 +25,6 @@ def __init__(self, sizer: GridSizer, *, backend: str) -> None: self.sizer = sizer self.backend = backend - self._allocator = _Allocator(self.backend) - def update_data_dimensions( self, data_dimension_descriptions: dict[str, int], @@ -111,7 +89,7 @@ def empty( Equivalent to `numpy.empty`""" return self._allocate( - self._allocator.empty, dims, units, dtype, allow_mismatch_float_precision + gt_storage.empty, dims, units, dtype, allow_mismatch_float_precision ) def zeros( @@ -126,7 +104,7 @@ def zeros( Equivalent to `numpy.zeros`""" return self._allocate( - self._allocator.zeros, dims, units, dtype, allow_mismatch_float_precision + gt_storage.zeros, dims, units, dtype, allow_mismatch_float_precision ) def ones( @@ -141,7 +119,7 @@ def ones( Equivalent to `numpy.ones`""" return self._allocate( - self._allocator.ones, dims, units, dtype, allow_mismatch_float_precision + gt_storage.ones, dims, units, dtype, allow_mismatch_float_precision ) def full( @@ -157,7 +135,7 @@ def full( Equivalent to `numpy.full`""" quantity = self._allocate( - self._allocator.empty, + gt_storage.empty, dims, units, dtype, @@ -235,12 +213,15 @@ def _allocate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) ] - try: - data = allocator( - shape, dtype=dtype, aligned_index=origin, dimensions=dimensions - ) - except TypeError: - data = allocator(shape, dtype=dtype) + + data = allocator( + shape, + dtype=dtype, + aligned_index=origin, + dimensions=dimensions, + backend=self.backend, + ) + return Quantity( data, dims=dims,