Skip to content

Commit

Permalink
Abstractions for read-only arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 28, 2024
1 parent ee25aae commit 3394cb6
Show file tree
Hide file tree
Showing 5 changed files with 447 additions and 7 deletions.
6 changes: 3 additions & 3 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
Expand Down
257 changes: 254 additions & 3 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
"""
from __future__ import annotations

import operator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from typing import Callable, Literal, Optional, Union, Any
from ._typing import Array, Device

import sys
Expand Down Expand Up @@ -91,7 +92,7 @@ def is_cupy_array(x):
import cupy as cp

# TODO: Should we reject ndarray subclasses?
return isinstance(x, (cp.ndarray, cp.generic))
return isinstance(x, cp.ndarray)

def is_torch_array(x):
"""
Expand Down Expand Up @@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x
return x.to_device(device, stream=stream)


def size(x):
"""
Return the total number of elements of x.
Expand All @@ -801,6 +803,253 @@ def size(x):
return None
return math.prod(x.shape)


def is_writeable_array(x) -> bool:
    """
    Tell whether item assignment on ``x`` is expected to work.

    Returns
    -------
    bool
        ``x.flags.writeable`` when ``x`` is a NumPy array; False when ``x`` is
        a JAX array or a pydata/sparse array (``x.__setitem__`` is expected to
        raise for those); True for every other input.
    """
    if is_numpy_array(x):
        # NumPy tracks writeability per array; defer to its own flag.
        return x.flags.writeable
    # Anything that is neither JAX nor pydata/sparse is assumed writeable.
    return not (is_jax_array(x) or is_pydata_sparse_array(x))


def _is_fancy_index(idx) -> bool:
    """Return True if any component of ``idx`` is a list, a tuple or an array.

    A non-tuple ``idx`` is treated as a one-element tuple, so that every
    component of a multi-axis index is inspected individually.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    return any(
        isinstance(i, (list, tuple)) or is_array_api_obj(i)
        for i in idx
    )


# Sentinel meaning "argument not supplied". It is compared by identity, so it
# can be told apart from any real value a caller may pass (including None).
_undef = object()


class at:
    """
    Update operations for read-only arrays.

    This implements ``jax.numpy.ndarray.at`` for all backends.

    Keyword arguments are passed verbatim to backends that support the
    `ndarray.at` method; e.g. you may pass ``indices_are_sorted=True`` to JAX;
    they are quietly ignored for backends that don't support them.

    Additionally, this introduces support for the `copy` keyword for all
    backends:

    None
        The array parameter *may* be modified in place if it is possible and
        beneficial for performance. You should not reuse it after calling this
        function.
    True
        Ensure that the inputs are not modified. This is the default.
    False
        Raise ValueError if a copy cannot be avoided.

    Examples
    --------
    Given either of these equivalent expressions::

        x = at(x)[1].add(2, copy=None)
        x = at(x, 1).add(2, copy=None)

    If x is a JAX array, they are the same as::

        x = x.at[1].add(2)

    If x is a read-only numpy array, they are the same as::

        x = x.copy()
        x[1] += 2

    Otherwise, they are the same as::

        x[1] += 2

    Warning
    -------
    When you use copy=None, you should always immediately overwrite
    the parameter array::

        x = at(x, 0).set(2, copy=None)

    The anti-pattern below must be avoided, as it will result in different
    behaviour on read-only versus writeable arrays::

        x = xp.asarray([0, 0, 0])
        y = at(x, 0).set(2, copy=None)
        z = at(x, 1).set(3, copy=None)

    In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and
    ``z == [0, 3, 0]`` when x is read-only, whereas ``x == y == z == [2, 3, 0]``
    when x is writeable!

    Warning
    -------
    The behaviour of update methods when the index is an array of integers
    which contains multiple occurrences of the same index is undefined;
    e.g. ``at(x, [0, 0]).set(2)``

    Note
    ----
    `sparse <https://sparse.pydata.org/>`_ is not supported by update methods
    yet.

    See Also
    --------
    `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
    """

    __slots__ = ("x", "idx")

    def __init__(self, x, idx=_undef, /):
        # idx stays _undef until provided here or through __getitem__.
        self.x = x
        self.idx = idx

    def __getitem__(self, idx):
        """
        Allow for the alternate syntax ``at(x)[start:stop:step]``,
        which looks prettier than ``at(x, slice(start, stop, step))``
        and feels more intuitive coming from the JAX documentation.

        Stores ``idx`` on this object and returns the object itself.

        Raises
        ------
        ValueError
            If an index has already been set.
        """
        if self.idx is not _undef:
            raise ValueError("Index has already been set")
        self.idx = idx
        return self

    def _common(
        self,
        at_op: str,
        y=_undef,
        copy: bool | None | Literal["_force_false"] = True,
        **kwargs,
    ):
        """Perform common preprocessing.

        Parameters
        ----------
        at_op
            Name of the method to call on ``x.at[idx]`` for JAX arrays.
        y
            Right-hand operand; omitted from the JAX call when not supplied.
        copy
            True, False or None as described in the class docstring, or the
            private value ``"_force_false"``, which skips the copy without
            performing the writeability check that ``copy=False`` implies.
        **kwargs
            Forwarded verbatim to the JAX ``at`` method; dropped otherwise.

        Returns
        -------
        If the operation can be resolved by at[], (return value, None)
        Otherwise, (None, preprocessed x)

        Raises
        ------
        TypeError
            If no index has been set.
        ValueError
            If ``copy=False`` and ``x`` is not writeable or is a Dask array,
            or if ``copy`` is not one of the accepted values.
        """
        if self.idx is _undef:
            raise TypeError(
                "Index has not been set.\n"
                "Usage: either\n"
                "    at(x, idx).set(value)\n"
                "or\n"
                "    at(x)[idx].set(value)\n"
                "(same for all other methods)."
            )

        x = self.x

        # Normalise `copy` to a plain bool, validating it along the way.
        if copy is False:
            if not is_writeable_array(x) or is_dask_array(x):
                raise ValueError("Cannot modify parameter in place")
        elif copy is None:
            # Copy only when writing in place is not possible.
            copy = not is_writeable_array(x)
        elif copy == "_force_false":
            copy = False
        elif copy is not True:
            raise ValueError(f"Invalid value for copy: {copy!r}")

        if is_jax_array(x):
            # Use JAX's at[]
            at_ = x.at[self.idx]
            args = (y,) if y is not _undef else ()
            return getattr(at_, at_op)(*args, **kwargs), None

        # Emulate at[] behaviour for non-JAX arrays
        if copy:
            # FIXME We blindly expect the output of x.copy() to be always writeable.
            # This holds true for read-only numpy arrays, but not necessarily for
            # other backends.
            xp = array_namespace(x)
            x = xp.asarray(x, copy=True)

        return None, x

    def get(self, **kwargs):
        """
        Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
        that the output is either a copy or a view; it also allows passing
        keyword arguments to the backend.

        Raises
        ------
        ValueError
            If ``copy=False`` is requested together with a fancy index, or in
            any of the cases in which ``copy`` is rejected for update methods.
        TypeError
            If no index has been set.
        """
        # __getitem__ with a fancy index always returns a copy.
        # Avoid an unnecessary double copy.
        # If copy is forced to False, raise.
        if _is_fancy_index(self.idx):
            if kwargs.get("copy", True) is False:
                # ValueError, as documented for copy=False in the class
                # docstring and as raised by _common for the same condition.
                raise ValueError(
                    "Indexing a numpy array with a fancy index always "
                    "results in a copy"
                )
            # Skip copy inside _common, even if array is not writeable
            kwargs["copy"] = "_force_false"  # type: ignore

        res, x = self._common("get", **kwargs)
        if res is not None:
            return res
        return x[self.idx]

    def set(self, y, /, **kwargs):
        """Apply ``x[idx] = y`` and return the updated array"""
        res, x = self._common("set", y, **kwargs)
        if res is not None:
            return res
        x[self.idx] = y
        return x

    def _iop(
        self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
    ):
        """x[idx] += y or equivalent in-place operation on a subset of x

        which is the same as saying
            x[idx] = x[idx] + y
        Note that this is not the same as
            operator.iadd(x[idx], y)
        Consider for example when x is a numpy array and idx is a fancy index, which
        triggers a deep copy on __getitem__.
        """
        res, x = self._common(at_op, y, **kwargs)
        if res is not None:
            return res
        x[self.idx] = elwise_op(x[self.idx], y)
        return x

    def add(self, y, /, **kwargs):
        """Apply ``x[idx] += y`` and return the updated array"""
        return self._iop("add", operator.add, y, **kwargs)

    def subtract(self, y, /, **kwargs):
        """Apply ``x[idx] -= y`` and return the updated array"""
        return self._iop("subtract", operator.sub, y, **kwargs)

    def multiply(self, y, /, **kwargs):
        """Apply ``x[idx] *= y`` and return the updated array"""
        return self._iop("multiply", operator.mul, y, **kwargs)

    def divide(self, y, /, **kwargs):
        """Apply ``x[idx] /= y`` and return the updated array"""
        return self._iop("divide", operator.truediv, y, **kwargs)

    def power(self, y, /, **kwargs):
        """Apply ``x[idx] **= y`` and return the updated array"""
        return self._iop("power", operator.pow, y, **kwargs)

    def min(self, y, /, **kwargs):
        """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
        xp = array_namespace(self.x)
        # Convert y with the namespace of x before handing it to xp.minimum.
        y = xp.asarray(y)
        return self._iop("min", xp.minimum, y, **kwargs)

    def max(self, y, /, **kwargs):
        """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
        xp = array_namespace(self.x)
        # Convert y with the namespace of x before handing it to xp.maximum.
        y = xp.asarray(y)
        return self._iop("max", xp.maximum, y, **kwargs)


__all__ = [
"array_namespace",
"device",
Expand All @@ -821,8 +1070,10 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"size",
"to_device",
"at",
]

_all_ignore = ['sys', 'math', 'inspect', 'warnings']
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']
3 changes: 3 additions & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ instead, which would be wrapped.
.. autofunction:: device
.. autofunction:: to_device
.. autofunction:: size
.. autoclass:: at(array[, index])
:members:

Inspection Helpers
------------------
Expand All @@ -51,6 +53,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
Loading

0 comments on commit 3394cb6

Please sign in to comment.