Skip to content

Commit

Permalink
ENH: __array_namespace__: add method
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Dec 2, 2024
1 parent ec6527d commit 782d6f6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
18 changes: 13 additions & 5 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ def size(self):
def mask(self):
return self._mask

def call_super_method(self, method_name, *args, **kwargs):
def __array_namespace__(self, api_version=None):
if api_version is None or api_version == '2023.12':
return mod
else:
message = (f"MArray interface for Array API version '{api_version}' "
"is not implemented.")
raise NotImplementedError(message)

def _call_super_method(self, method_name, *args, **kwargs):
method = getattr(self.data, method_name)
args = [getattr(arg, 'data', arg) for arg in args]
return method(*args, **kwargs)
Expand Down Expand Up @@ -143,15 +151,15 @@ def to_device(self, device, /, *, stream=None):
+ ['__ceil__'])
for name in unary_names:
def fun(self, name=name):
data = self.call_super_method(name)
data = self._call_super_method(name)
return MArray(data, self.mask)
setattr(MArray, name, fun)

# Methods that return the result of a unary operation as a Python scalar
unary_names_py = ['__bool__', '__complex__', '__float__', '__index__', '__int__']
for name in unary_names_py:
def fun(self, name=name):
return self.call_super_method(name)
return self._call_super_method(name)
setattr(MArray, name, fun)

# Methods that return the result of an elementwise binary operation
Expand All @@ -166,7 +174,7 @@ def fun(self, name=name):
for name in binary_names + rbinary_names:
def fun(self, other, name=name):
mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask
data = self.call_super_method(name, other)
data = self._call_super_method(name, other)
return MArray(data, mask)
setattr(MArray, name, fun)

Expand All @@ -179,7 +187,7 @@ def fun(self, other, name=name, **kwargs):
if hasattr(other, 'mask'):
# self.mask |= other.mask doesn't work because mask has no setter
self.mask.__ior__(other.mask)
self.call_super_method(name, other)
self._call_super_method(name, other)
return self
setattr(MArray, name, fun)

Expand Down
21 changes: 18 additions & 3 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,17 @@ def test_sorting(f_name, descending, stable, dtype, xp, seed=None):
assert_equal(res, ref, xp=xp, seed=seed)


@pytest.mark.parametrize('xp', xps)
def test_array_namespace(xp):
mxp = marray.get_namespace(xp)
x = mxp.asarray([1, 2, 3])
assert x.__array_namespace__() is mxp
assert x.__array_namespace__("2023.12") is mxp
message = "MArray interface for Array API version 'shrubbery'..."
with pytest.raises(NotImplementedError, match=message):
x.__array_namespace__("shrubbery")


@pytest.mark.parametrize('xp', xps)
def test_import(xp):
mxp = marray.get_namespace(xp)
Expand All @@ -825,13 +836,17 @@ def test_import(xp):
# To do:
# - Indexing (same behavior as indexing data and mask separately)
# - Set functions (see https://github.com/mdhaber/marray/issues/28)
# - __array_namespace__
# - improve test_rarray_binary
# - improve test_statistical_array
# - improve test_meshgrid - need more inputs
# - investigate asarray - is copy respected?
# - investigate test_sorting - what about uint dtypes?
# - investigate failing_test

# def failing_test():
# seed = 87597311899020256922680472523907945305
# test_array_binary(array_binary[0], dtype=np.float32, xp=np, seed=seed)

def test_test():
seed = 313948498256289901944532431191982951352
test_attributes('uint8', strict, seed=seed)
seed = 8759731189902025692268047252390794530
test_array_binary(array_binary[0], dtype=np.float32, xp=np, seed=seed)

0 comments on commit 782d6f6

Please sign in to comment.