diff --git a/marray/__init__.py b/marray/__init__.py index 5ad4f71..c2a8f7e 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -218,9 +218,9 @@ def asarray(obj, /, *, mask=None, dtype=None, device=None, copy=None): return MArray(data, mask=mask) mod.asarray = asarray - creation_functions = ['arange', 'empty', 'empty_like', 'eye', 'from_dlpack', - 'full', 'full_like', 'linspace', 'ones', 'ones_like', - 'zeros', 'zeros_like'] + creation_functions = ['arange', 'empty', 'eye', 'from_dlpack', + 'full', 'linspace', 'ones', 'zeros'] + creation_functions_like = ['empty_like', 'full_like', 'ones_like', 'zeros_like'] # handled with array manipulation functions creation_manip_functions = ['tril', 'triu', 'meshgrid'] for name in creation_functions: @@ -229,6 +229,12 @@ def fun(*args, name=name, **kwargs): return MArray(data) setattr(mod, name, fun) + for name in creation_functions_like: + def fun(x, /, *args, name=name, **kwargs): + data = getattr(xp, name)(getattr(x, 'data', x), *args, **kwargs) + return MArray(data, mask=getattr(x, 'mask', False)) + setattr(mod, name, fun) + ## Data Type Functions and Data Types ## dtype_fun_names = ['can_cast', 'finfo', 'iinfo', 'isdtype', 'result_type'] dtype_names = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index 757ce38..a57e9a8 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -592,6 +592,26 @@ def test_creation(f_name, args, kwargs, dtype, xp, seed=None): np.testing.assert_equal(res.mask, np.full(ref.shape, False), strict=True) +@pytest.mark.parametrize('f_name', + ['empty_like', 'zeros_like', 'ones_like', 'full_like']) +@pytest.mark.parametrize("dtype", dtypes_all + [None]) +@pytest.mark.parametrize('xp', xps) +def test_creation_like(f_name, dtype, xp, seed=None): + mxp = marray.get_namespace(xp) + f_mxp = getattr(mxp, f_name) + f_np = getattr(np, f_name) # np.ma doesn't have full_like + args = (2,) if f_name == "full_like" else () + marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) + res = f_mxp(marrays[0], *args, dtype=getattr(xp, str(dtype), None)) + ref = f_np(masked_arrays[0], *args, dtype=dtype) + if f_name.startswith('empty'): + assert res.data.shape == ref.shape + np.testing.assert_equal(res.mask, ref.mask) + else: + ref = np.ma.masked_array(ref, mask=masked_arrays[0].mask) + assert_equal(res, ref, xp=xp, seed=seed) + + @pytest.mark.parametrize('f_name', ['tril', 'triu']) @pytest.mark.parametrize('dtype', dtypes_all) @pytest.mark.parametrize('xp', xps)