diff --git a/marray/__init__.py b/marray/__init__.py index 0cb3078..4e3a906 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -4,6 +4,7 @@ __version__ = "0.0.4" +import types, sys import dataclasses def masked_array(xp): @@ -195,10 +196,8 @@ def info(x): else: return xp.finfo(x.dtype) - class module: - pass - - mod = module() + mod = types.ModuleType('mxp') + sys.modules['mxp'] = mod mod.MaskedArray = MaskedArray diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index 78d314e..d633517 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -684,6 +684,13 @@ def test_sorting(f_name, descending, stable, dtype, xp=strict, seed=None): ref = np.ma.MaskedArray(ref_data, mask=ref_mask) assert_equal(res, ref, seed) + +def test_import(xp=np): + mxp = marray.masked_array(xp) + from mxp import asarray + asarray(10, mask=True) + + def test_test(): seed = 149020664425889521094089537542803361848 # test_statistical_array('argmin', True, seed=seed)