Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array API backend #317

Merged
merged 9 commits into from
Nov 23, 2023
Merged
77 changes: 43 additions & 34 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
float64,
)
from cubed.array_api.linear_algebra_functions import matmul
from cubed.backend_array_api import namespace as nxp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not simply import namespace as np? This would minimize the diff.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would, but I wanted to signal that this may not be regular NumPy. By using nxp here, it makes it easy to search for np to find places that are still using regular NumPy.

from cubed.core.array import CoreArray
from cubed.core.ops import elemwise
from cubed.utils import memory_repr
Expand Down Expand Up @@ -118,54 +119,54 @@ def T(self):
def __neg__(self, /):
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __neg__")
return elemwise(np.negative, self, dtype=self.dtype)
return elemwise(nxp.negative, self, dtype=self.dtype)

def __pos__(self, /):
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __pos__")
return elemwise(np.positive, self, dtype=self.dtype)
return elemwise(nxp.positive, self, dtype=self.dtype)

def __add__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
return elemwise(np.add, self, other, dtype=result_type(self, other))
return elemwise(nxp.add, self, other, dtype=result_type(self, other))

def __sub__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
if other is NotImplemented:
return other
return elemwise(np.subtract, self, other, dtype=result_type(self, other))
return elemwise(nxp.subtract, self, other, dtype=result_type(self, other))

def __mul__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
return elemwise(np.multiply, self, other, dtype=result_type(self, other))
return elemwise(nxp.multiply, self, other, dtype=result_type(self, other))

def __truediv__(self, other, /):
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
if other is NotImplemented:
return other
return elemwise(np.divide, self, other, dtype=result_type(self, other))
return elemwise(nxp.divide, self, other, dtype=result_type(self, other))

def __floordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, self, other, dtype=result_type(self, other))
return elemwise(nxp.floor_divide, self, other, dtype=result_type(self, other))

def __mod__(self, other, /):
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, self, other, dtype=result_type(self, other))
return elemwise(nxp.remainder, self, other, dtype=result_type(self, other))

def __pow__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
return elemwise(np.power, self, other, dtype=result_type(self, other))
return elemwise(nxp.pow, self, other, dtype=result_type(self, other))

# Array Operators

Expand All @@ -180,75 +181,79 @@ def __matmul__(self, other, /):
def __invert__(self, /):
if self.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
return elemwise(np.invert, self, dtype=self.dtype)
return elemwise(nxp.bitwise_invert, self, dtype=self.dtype)

def __and__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_and, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_and, self, other, dtype=result_type(self, other))

def __or__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_or, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_or, self, other, dtype=result_type(self, other))

def __xor__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_xor, self, other, dtype=result_type(self, other))
return elemwise(nxp.bitwise_xor, self, other, dtype=result_type(self, other))

def __lshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
return elemwise(np.left_shift, self, other, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_left_shift, self, other, dtype=result_type(self, other)
)

def __rshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
if other is NotImplemented:
return other
return elemwise(np.right_shift, self, other, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_right_shift, self, other, dtype=result_type(self, other)
)

# Comparison Operators

def __eq__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
return elemwise(np.equal, self, other, dtype=np.bool_)
return elemwise(nxp.equal, self, other, dtype=np.bool_)

def __ge__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ge__")
if other is NotImplemented:
return other
return elemwise(np.greater_equal, self, other, dtype=np.bool_)
return elemwise(nxp.greater_equal, self, other, dtype=np.bool_)

def __gt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__gt__")
if other is NotImplemented:
return other
return elemwise(np.greater, self, other, dtype=np.bool_)
return elemwise(nxp.greater, self, other, dtype=np.bool_)

def __le__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__le__")
if other is NotImplemented:
return other
return elemwise(np.less_equal, self, other, dtype=np.bool_)
return elemwise(nxp.less_equal, self, other, dtype=np.bool_)

def __lt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__lt__")
if other is NotImplemented:
return other
return elemwise(np.less, self, other, dtype=np.bool_)
return elemwise(nxp.less, self, other, dtype=np.bool_)

def __ne__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
return elemwise(np.not_equal, self, other, dtype=np.bool_)
return elemwise(nxp.not_equal, self, other, dtype=np.bool_)

# Reflected Operators

Expand All @@ -258,43 +263,43 @@ def __radd__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
if other is NotImplemented:
return other
return elemwise(np.add, other, self, dtype=result_type(self, other))
return elemwise(nxp.add, other, self, dtype=result_type(self, other))

def __rsub__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
if other is NotImplemented:
return other
return elemwise(np.subtract, other, self, dtype=result_type(self, other))
return elemwise(nxp.subtract, other, self, dtype=result_type(self, other))

def __rmul__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
if other is NotImplemented:
return other
return elemwise(np.multiply, other, self, dtype=result_type(self, other))
return elemwise(nxp.multiply, other, self, dtype=result_type(self, other))

def __rtruediv__(self, other, /):
other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
if other is NotImplemented:
return other
return elemwise(np.divide, other, self, dtype=result_type(self, other))
return elemwise(nxp.divide, other, self, dtype=result_type(self, other))

def __rfloordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, other, self, dtype=result_type(self, other))
return elemwise(nxp.floor_divide, other, self, dtype=result_type(self, other))

def __rmod__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, other, self, dtype=result_type(self, other))
return elemwise(nxp.remainder, other, self, dtype=result_type(self, other))

def __rpow__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
return elemwise(np.power, other, self, dtype=result_type(self, other))
return elemwise(nxp.pow, other, self, dtype=result_type(self, other))

# (Reflected) Array Operators

Expand All @@ -310,31 +315,35 @@ def __rand__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_and, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_and, other, self, dtype=result_type(self, other))

def __ror__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_or, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_or, other, self, dtype=result_type(self, other))

def __rxor__(self, other, /):
other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
if other is NotImplemented:
return other
return elemwise(np.bitwise_xor, other, self, dtype=result_type(self, other))
return elemwise(nxp.bitwise_xor, other, self, dtype=result_type(self, other))

def __rlshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
if other is NotImplemented:
return other
return elemwise(np.left_shift, other, self, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_left_shift, other, self, dtype=result_type(self, other)
)

def __rrshift__(self, other, /):
other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
if other is NotImplemented:
return other
return elemwise(np.right_shift, other, self, dtype=result_type(self, other))
return elemwise(
nxp.bitwise_right_shift, other, self, dtype=result_type(self, other)
)

# Methods

Expand All @@ -347,7 +356,7 @@ def __abs__(self, /):
dtype = float64
else:
dtype = self.dtype
return elemwise(np.abs, self, dtype=dtype)
return elemwise(nxp.abs, self, dtype=dtype)

def __array_namespace__(self, /, *, api_version=None):
if api_version is not None and not api_version.startswith("2021."):
Expand Down
18 changes: 10 additions & 8 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from typing import TYPE_CHECKING, Iterable, List

import numpy as np
from zarr.util import normalize_shape

from cubed.backend_array_api import namespace as nxp
from cubed.core import Plan, gensym, map_blocks
from cubed.core.ops import map_direct
from cubed.core.plan import new_temp_path
Expand All @@ -20,9 +22,9 @@ def arange(
) -> "Array":
if stop is None:
start, stop = 0, start
num = int(max(np.ceil((stop - start) / step), 0))
num = int(max(math.ceil((stop - start) / step), 0))
if dtype is None:
dtype = np.arange(start, stop, step * num if num else step).dtype
dtype = nxp.arange(start, stop, step * num if num else step).dtype
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]
numblocks = len(chunks[0])
Expand All @@ -43,10 +45,10 @@ def arange(


def _arange(a, size, start, stop, step):
i = a[0]
i = int(a[0])
blockstart = start + (i * size * step)
blockstop = start + ((i + 1) * size * step)
return np.arange(blockstart, min(blockstop, stop), step)
return nxp.arange(blockstart, min(blockstop, stop), step)


def asarray(
Expand All @@ -64,7 +66,7 @@ def asarray(
return asarray(a.data)
elif not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = np.asarray(a, dtype=dtype)
a = nxp.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype

Expand Down Expand Up @@ -133,9 +135,9 @@ def _eye(x, *arrays, k=None, chunksize=None, block_id=None):
i, j = block_id
bk = (j - i) * chunksize
if bk - chunksize <= k <= bk + chunksize:
return np.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype)
return nxp.eye(x.shape[0], x.shape[1], k=k - bk, dtype=x.dtype)
else:
return np.zeros_like(x)
return nxp.zeros_like(x)


def full(
Expand Down Expand Up @@ -225,7 +227,7 @@ def _linspace(x, *arrays, size, start, step, endpoint, linspace_dtype, block_id=
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)
return np.linspace(
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)

Expand Down
3 changes: 2 additions & 1 deletion cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from numpy.array_api._typing import Dtype

from cubed.backend_array_api import namespace as nxp
from cubed.core import CoreArray, map_blocks

from .dtypes import (
Expand All @@ -25,7 +26,7 @@ def astype(x, dtype, /, *, copy=True):


def _astype(a, astype_dtype):
return a.astype(astype_dtype)
return nxp.astype(a, astype_dtype)


def can_cast(from_, to, /):
Expand Down
Loading