Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: be more careful to only access data attribute of masked arrays #81

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_namespace(xp):
class MArray:

def __init__(self, data, mask=None):
data = xp.asarray(getattr(data, '_data', data))
data = xp.asarray(_get_data(data))
mask = (xp.zeros(data.shape, dtype=xp.bool) if mask is None
else xp.asarray(mask, dtype=xp.bool))
if mask.shape != data.shape: # avoid copy if possible
Expand Down Expand Up @@ -99,7 +99,7 @@ def __array_namespace__(self, api_version=None):

def _call_super_method(self, method_name, *args, **kwargs):
method = getattr(self.data, method_name)
args = [getattr(arg, 'data', arg) for arg in args]
args = [_get_data(arg) for arg in args]
return method(*args, **kwargs)

def _validate_key(self, key):
Expand All @@ -122,7 +122,7 @@ def __getitem__(self, key):
def __setitem__(self, key, other):
key = self._validate_key(key)
self.mask[key] = getattr(other, 'mask', False)
return self.data.__setitem__(key, getattr(other, 'data', other))
return self.data.__setitem__(key, _get_data(other))

def __iter__(self):
return iter(self.data)
Expand Down Expand Up @@ -245,7 +245,7 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):
if device is not None:
raise NotImplementedError("`device` argument is not implemented")

data = getattr(obj, 'data', obj)
data = _get_data(obj)
data = xp.asarray(data, dtype=dtype, device=device, copy=copy)

mask = (getattr(obj, 'mask', xp.full(data.shape, False))
Expand All @@ -268,16 +268,15 @@ def fun(*args, name=name, **kwargs):

for name in creation_functions_like:
def fun(x, /, *args, name=name, **kwargs):
data = getattr(xp, name)(getattr(x, 'data', x), *args, **kwargs)
data = getattr(xp, name)(_get_data(x), *args, **kwargs)
return MArray(data, mask=getattr(x, 'mask', False))
setattr(mod, name, fun)

## Data Type Functions and Data Types ##
dtype_fun_names = ['can_cast', 'finfo', 'iinfo', 'result_type']
for name in dtype_fun_names:
def fun(*args, name=name, **kwargs):
args = [(getattr(arg, 'data') if hasattr(arg, 'mask') else arg)
for arg in args]
args = [_get_data(arg) for arg in args]
return getattr(xp, name)(*args, **kwargs)
setattr(mod, name, fun)

Expand Down Expand Up @@ -316,7 +315,7 @@ def astype(x, dtype, /, *, copy=True, device=None):
def fun(*args, name=name, **kwargs):
masks = [arg.mask for arg in args if hasattr(arg, 'mask')]
masks = xp.broadcast_arrays(*masks)
args = [getattr(arg, 'data', arg) for arg in args]
args = [_get_data(arg) for arg in args]
out = getattr(xp, name)(*args, **kwargs)
return MArray(out, mask=xp.any(xp.stack(masks), axis=0))
setattr(mod, name, fun)
Expand All @@ -327,14 +326,14 @@ def clip(x, /, min=None, max=None):
masks = [arg.mask for arg in args if hasattr(arg, 'mask')]
masks = xp.broadcast_arrays(*masks)
mask = xp.any(xp.stack(masks), axis=0)
datas = [getattr(arg, 'data', arg) for arg in args]
datas = [_get_data(arg) for arg in args]
data = xp.clip(datas[0], min=datas[1], max=datas[2])
return MArray(data, mask)
mod.clip = clip

## Indexing Functions
def take(x, indices, /, *, axis=None):
indices_data = getattr(indices, 'data', indices)
indices_data = _get_data(indices)
indices_mask = getattr(indices, 'mask', False)
indices_data[indices_mask] = 0 # ensure valid index
data = xp.take(x.data, indices_data, axis=axis)
Expand All @@ -343,7 +342,7 @@ def take(x, indices, /, *, axis=None):
mod.take = take

def take_along_axis(x, indices, /, *, axis=-1):
indices_data = getattr(indices, 'data', indices)
indices_data = _get_data(indices)
indices_mask = getattr(indices, 'mask', False)
indices_data[indices_mask] = 0 # ensure valid index
data = xp.take_along_axis(x.data, indices_data, axis=axis)
Expand Down Expand Up @@ -644,3 +643,9 @@ def _xinfo(x):
return binfo(min=False, max=True)
else:
return xp.finfo(x.dtype)


def _get_data(x):
    """Return the underlying data of a masked array, or `x` itself.

    An object is treated as a masked array only when it exposes a ``mask``
    attribute; in that case its ``data`` attribute is returned. Anything
    else is passed through unchanged.
    """
    # Keying on `mask` rather than `data` is deliberate: the intent is to
    # unwrap an MArray or NumPy masked array *without* pulling the
    # memoryview out of a plain NumPy array, etc.
    if not hasattr(x, 'mask'):
        return x
    return x.data
Loading