Skip to content

Commit

Permalink
MAINT: _like creation functions: preserve mask of input
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Dec 2, 2024
1 parent 5511a21 commit ec6527d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
12 changes: 9 additions & 3 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None):
return MArray(data, mask=mask)
mod.asarray = asarray

creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack',
'full', 'full_like', 'linspace', 'ones', 'ones_like',
'zeros', 'zeros_like']
creation_functions = ['arange', 'empty', 'eye', 'from_dlpack',
'full', 'linspace', 'ones', 'zeros']
creation_functions_like = ['empty_like', 'full_like', 'ones_like', 'zeros_like']
# handled with array manipulation functions
creation_manip_functions = ['tril', 'triu', 'meshgrid']
for name in creation_functions:
Expand All @@ -229,6 +229,12 @@ def fun(*args, name=name, **kwargs):
return MArray(data)
setattr(mod, name, fun)

for name in creation_functions_like:
def fun(x, /, *args, name=name, **kwargs):
data = getattr(xp, name)(getattr(x, 'data', x), *args, **kwargs)
return MArray(data, mask=getattr(x, 'mask', False))
setattr(mod, name, fun)

## Data Type Functions and Data Types ##
dtype_fun_names = ['can_cast', 'finfo', 'iinfo', 'isdtype', 'result_type']
dtype_names = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16',
Expand Down
20 changes: 20 additions & 0 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,26 @@ def test_creation(f_name, args, kwargs, dtype, xp, seed=None):
np.testing.assert_equal(res.mask, np.full(ref.shape, False), strict=True)


@pytest.mark.parametrize('f_name',
['empty_like', 'zeros_like', 'ones_like', 'full_like'])
@pytest.mark.parametrize("dtype", dtypes_all + [None])
@pytest.mark.parametrize('xp', xps)
def test_creation_like(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
f_mxp = getattr(mxp, f_name)
f_np = getattr(np, f_name) # np.ma doesn't have full_like
args = (2,) if f_name == "full_like" else ()
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
res = f_mxp(marrays[0], *args, dtype=getattr(xp, str(dtype), None))
ref = f_np(masked_arrays[0], *args, dtype=dtype)
if f_name.startswith('empty'):
assert res.data.shape == ref.shape
np.testing.assert_equal(res.mask, ref.mask)
else:
ref = np.ma.masked_array(ref, mask=masked_arrays[0].mask)
assert_equal(res, ref, xp=xp, seed=seed)


@pytest.mark.parametrize('f_name', ['tril', 'triu'])
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
Expand Down

0 comments on commit ec6527d

Please sign in to comment.