diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index f23f1e7..dfaabc5 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -20,14 +20,14 @@ dtypes_all = dtypes_boolean + dtypes_integral + dtypes_real + dtypes_complex -def get_arrays(n_arrays, *, dtype, xp, ndim=(1, 4), seed=None): +def get_arrays(n_arrays, *, dtype, xp, shape=None, ndim=(1, 4), seed=None): xpm = marray._get_namespace(xp) entropy = np.random.SeedSequence(seed).entropy rng = np.random.default_rng(entropy) ndim = rng.integers(*ndim) if isinstance(ndim, tuple) else ndim - shape = rng.integers(1, 20, size=ndim) + shape = rng.integers(1, 20, size=ndim) if shape is None else np.asarray(shape) datas = [] masks = [] @@ -106,18 +106,28 @@ def get_rtol(dtype, xp): return 0 -arithmetic_unary = [lambda x: +x, lambda x: -x, abs] -arithmetic_methods_unary = [lambda x: x.__abs__(), lambda x: x.__neg__(), - lambda x: x.__pos__()] -arithmetic_binary = [lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, - lambda x, y: x / y, lambda x, y: x // y, lambda x, y: x % y, - lambda x, y: x ** y] -arithmetic_methods_binary = [lambda x, y: x.__add__(y), lambda x, y: x.__floordiv__(y), - lambda x, y: x.__mod__(y), lambda x, y: x.__mul__(y), - lambda x, y: x.__pow__(y), lambda x, y: x.__sub__(y), - lambda x, y: x.__truediv__(y)] -array_binary = [lambda x, y: x @ y, operator.matmul, operator.__matmul__] -array_methods_binary = [lambda x, y: x.__matmul__(y)] +arithmetic_unary = {"+x": lambda x: +x, "-x": lambda x: -x, "abs": abs} +arithmetic_methods_unary = {"x.__abs__": lambda x: x.__abs__(), + "x.__neg__": lambda x: x.__neg__(), + "x.__pos__": lambda x: x.__pos__()} +arithmetic_binary = {"x + y": lambda x, y: x + y, + "x - y": lambda x, y: x - y, + "x * y": lambda x, y: x * y, + "x / y": lambda x, y: x / y, + "x // y": lambda x, y: x // y, + "x % y": lambda x, y: x % y, + "x ** y": lambda x, y: x ** y} +arithmetic_methods_binary = {"x.__add__(y)": lambda x, y: x.__add__(y), + "x.__floordiv__(y)": lambda x, y: x.__floordiv__(y), + "x.__mod__(y)": lambda x, y: x.__mod__(y), + "x.__mul__(y)": lambda x, y: x.__mul__(y), + "x.__pow__(y)": lambda x, y: x.__pow__(y), + "x.__sub__(y)": lambda x, y: x.__sub__(y), + "x.__truediv__(y)": lambda x, y: x.__truediv__(y)} +array_binary = {"x @ y": lambda x, y: x @ y, + "operator.matmul": operator.matmul, + "opterator.__matmul__": operator.__matmul__} +array_methods_binary = {"x.__matmul__(y)": lambda x, y: x.__matmul__(y)} bitwise_unary = {'bitwise_invert': lambda x: ~x} bitwise_methods_unary = {'bitwise_invert': lambda x: x.__invert__()} bitwise_binary = {'bitwise_and': lambda x, y: x & y, @@ -131,7 +141,6 @@ def get_rtol(dtype, xp): 'bitwise_right_shift': lambda x, y: x.__rshift__(y), 'bitwise_xor': lambda x, y: x.__xor__(y)} - scalar_conversions = {bool: True, int: 10, float: 1.5, complex: 1.5 + 2.5j} # tested in test_dlpack @@ -139,11 +148,18 @@ def get_rtol(dtype, xp): # tested in test_indexing # __getitem__, __index__, __setitem__, -comparison_binary = [lambda x, y: x < y, lambda x, y: x <= y, lambda x, y: x > y, - lambda x, y: x >= y, lambda x, y: x == y , lambda x, y: x != y] -comparison_methods_binary = [lambda x, y: x.__eq__(y), lambda x, y: x.__ge__(y), - lambda x, y: x.__gt__(y), lambda x, y: x.__le__(y), - lambda x, y: x.__lt__(y), lambda x, y: x.__ne__(y)] +comparison_binary = {"x < y": lambda x, y: x < y, + "x <= y": lambda x, y: x <= y, + "x > y": lambda x, y: x > y, + "x >= y": lambda x, y: x >= y, + "x == y": lambda x, y: x == y , + "x != y": lambda x, y: x != y} +comparison_methods_binary = {"x.__eq__(y)": lambda x, y: x.__eq__(y), + "x.__ge__(y)": lambda x, y: x.__ge__(y), + "x.__gt__(y)": lambda x, y: x.__gt__(y), + "x.__le__(y)": lambda x, y: x.__le__(y), + "x.__lt__(y)": lambda x, y: x.__lt__(y), + "x.__ne__(y)": lambda x, y: x.__ne__(y)} def iadd(x, y): x += y def isub(x, y): x -= y @@ -185,10 +201,11 @@ def irshift(x, y): x >>= y utility_array = ['all', 'any'] -@pytest.mark.parametrize("f", arithmetic_unary[:1] + arithmetic_methods_unary) +@pytest.mark.parametrize("f_name, f", + (arithmetic_unary | arithmetic_methods_unary).items()) @pytest.mark.parametrize('dtype', dtypes_numeric) @pytest.mark.parametrize('xp', xps) -def test_arithmetic_unary(f, dtype, xp, seed=None): +def test_arithmetic_unary(f_name, f, dtype, xp, seed=None): marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) res = f(marrays[0]) ref = f(masked_arrays[0]) @@ -205,11 +222,12 @@ def test_arithmetic_unary(f, dtype, xp, seed=None): ] -@pytest.mark.parametrize("f", arithmetic_binary + arithmetic_methods_binary) +@pytest.mark.parametrize("f_name, f", + (arithmetic_binary | arithmetic_methods_binary).items()) @pytest.mark.parametrize('dtype', dtypes_numeric) @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=arithetic_binary_exceptions) -def test_arithmetic_binary(f, dtype, xp, seed=None): +def test_arithmetic_binary(f_name, f, dtype, xp, seed=None): marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) res = f(marrays[0], marrays[1]) ref_data = f(masked_arrays[0].data, masked_arrays[1].data) @@ -218,11 +236,11 @@ def test_arithmetic_binary(f, dtype, xp, seed=None): assert_equal(res, ref, seed=seed, xp=xp) -@pytest.mark.parametrize("f", array_binary + array_methods_binary) +@pytest.mark.parametrize("f_name, f", (array_binary | array_methods_binary).items()) @pytest.mark.parametrize('dtype', dtypes_all) @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=["Only numeric dtypes are allowed in matmul"]) -def test_array_binary(f, dtype, xp, seed=None): +def test_array_binary(f_name, f, dtype, xp, seed=None): marrays, masked_arrays, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, xp=xp, seed=seed) res = f(marrays[0], marrays[0].mT) x = masked_arrays[0].data @@ -234,12 +252,10 @@ def test_array_binary(f, dtype, xp, seed=None): assert_allclose(res, ref, seed=seed, xp=xp, rtol=get_rtol(dtype, xp)) -@pytest.mark.parametrize("f_name_fun", itertools.chain(bitwise_unary.items(), - bitwise_methods_unary.items())) +@pytest.mark.parametrize("f_name, f", (bitwise_unary | bitwise_methods_unary).items()) @pytest.mark.parametrize("dtype", dtypes_integral + dtypes_boolean) @pytest.mark.parametrize('xp', xps) -def test_bitwise_unary(f_name_fun, dtype, xp, seed=None): - f_name, f = f_name_fun +def test_bitwise_unary(f_name, f, dtype, xp, seed=None): mxp = marray._get_namespace(xp) marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) @@ -252,14 +268,12 @@ def test_bitwise_unary(f_name_fun, dtype, xp, seed=None): assert_equal(res, ref, xp=xp, seed=seed) -@pytest.mark.parametrize("f_name_fun", itertools.chain(bitwise_binary.items(), - bitwise_methods_binary.items())) +@pytest.mark.parametrize("f_name, f", (bitwise_binary | bitwise_methods_binary).items()) @pytest.mark.parametrize("dtype", dtypes_integral + dtypes_boolean) @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=["is only defined for x2 >= 0", "Only integer dtypes are allowed in "]) -def test_bitwise_binary(f_name_fun, dtype, xp, seed=None): - f_name, f = f_name_fun +def test_bitwise_binary(f_name, f, dtype, xp, seed=None): mxp = marray._get_namespace(xp) marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) @@ -366,11 +380,12 @@ def test_dlpack(dtype, xp, seed=None): marrays[0].to_device(mxp.__array_namespace_info__().default_device()) -@pytest.mark.parametrize("f", comparison_binary + comparison_methods_binary) +@pytest.mark.parametrize("f_name, f", + (comparison_binary | comparison_methods_binary).items()) @pytest.mark.parametrize("dtype", dtypes_all) @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=["Only real numeric dtypes are allowed in"]) -def test_comparison_binary(f, dtype, xp, seed=None): +def test_comparison_binary(f_name, f, dtype, xp, seed=None): marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) res = f(marrays[0], marrays[1]) ref = f(masked_arrays[0], masked_arrays[1]) @@ -427,7 +442,7 @@ def test_inplace_array_binary(f, dtype, xp, seed=None): assert_allclose(a, ref, xp=xp, seed=seed) -@pytest.mark.parametrize("f", arithmetic_binary) +@pytest.mark.parametrize("f_name, f", arithmetic_binary.items()) @pytest.mark.parametrize("dtype", dtypes_all) @pytest.mark.parametrize('xp', xps) @pytest.mark.parametrize('type_', ["array", "scalar"]) @@ -437,7 +452,7 @@ def test_inplace_array_binary(f, dtype, xp, seed=None): "Only floating-point dtypes are allowed", "Integers to negative integer powers are not allowed", "numpy boolean subtract, the `-` operator, is not supported"]) -def test_rarithmetic_binary(f, dtype, xp, type_, seed=None): +def test_rarithmetic_binary(f_name, f, dtype, xp, type_, seed=None): marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed) if type_ == "array": arg1a = marrays[0].data @@ -1050,7 +1065,7 @@ def test_signature_docs(): def test_gh33(): # See https://github.com/mdhaber/marray/issues/33 - test_array_binary(array_binary[0], dtype='float32', xp=np, seed=566) + test_array_binary(*list(array_binary.items())[0], dtype='float32', xp=np, seed=566) def test_test():