Skip to content

Commit

Permalink
TST: improve parametrization style
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Feb 1, 2025
1 parent aad3961 commit 422a683
Showing 1 changed file with 54 additions and 39 deletions.
93 changes: 54 additions & 39 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
dtypes_all = dtypes_boolean + dtypes_integral + dtypes_real + dtypes_complex


def get_arrays(n_arrays, *, dtype, xp, ndim=(1, 4), seed=None):
def get_arrays(n_arrays, *, dtype, xp, shape=None, ndim=(1, 4), seed=None):
xpm = marray._get_namespace(xp)

entropy = np.random.SeedSequence(seed).entropy
rng = np.random.default_rng(entropy)

ndim = rng.integers(*ndim) if isinstance(ndim, tuple) else ndim
shape = rng.integers(1, 20, size=ndim)
shape = rng.integers(1, 20, size=ndim) if shape is None else np.asarray(shape)

datas = []
masks = []
Expand Down Expand Up @@ -106,18 +106,28 @@ def get_rtol(dtype, xp):
return 0


arithmetic_unary = [lambda x: +x, lambda x: -x, abs]
arithmetic_methods_unary = [lambda x: x.__abs__(), lambda x: x.__neg__(),
lambda x: x.__pos__()]
arithmetic_binary = [lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y,
lambda x, y: x / y, lambda x, y: x // y, lambda x, y: x % y,
lambda x, y: x ** y]
arithmetic_methods_binary = [lambda x, y: x.__add__(y), lambda x, y: x.__floordiv__(y),
lambda x, y: x.__mod__(y), lambda x, y: x.__mul__(y),
lambda x, y: x.__pow__(y), lambda x, y: x.__sub__(y),
lambda x, y: x.__truediv__(y)]
array_binary = [lambda x, y: x @ y, operator.matmul, operator.__matmul__]
array_methods_binary = [lambda x, y: x.__matmul__(y)]
arithmetic_unary = {"+x": lambda x: +x, "-x": lambda x: -x, "abs": abs}
arithmetic_methods_unary = {"x.__abs__": lambda x: x.__abs__(),
"x.__neg__": lambda x: x.__neg__(),
"x.__pos__": lambda x: x.__pos__()}
arithmetic_binary = {"x + y": lambda x, y: x + y,
"x - y": lambda x, y: x - y,
"x * y": lambda x, y: x * y,
"x / y": lambda x, y: x / y,
"x // y": lambda x, y: x // y,
"x % y": lambda x, y: x % y,
"x ** y": lambda x, y: x ** y}
arithmetic_methods_binary = {"x.__add__(y)": lambda x, y: x.__add__(y),
"x.__floordiv__(y)": lambda x, y: x.__floordiv__(y),
"x.__mod__(y)": lambda x, y: x.__mod__(y),
"x.__mul__(y)": lambda x, y: x.__mul__(y),
"x.__pow__(y)": lambda x, y: x.__pow__(y),
"x.__sub__(y)": lambda x, y: x.__sub__(y),
"x.__truediv__(y)": lambda x, y: x.__truediv__(y)}
array_binary = {"x @ y": lambda x, y: x @ y,
"operator.matmul": operator.matmul,
"opterator.__matmul__": operator.__matmul__}
array_methods_binary = {"x.__matmul__(y)": lambda x, y: x.__matmul__(y)}
bitwise_unary = {'bitwise_invert': lambda x: ~x}
bitwise_methods_unary = {'bitwise_invert': lambda x: x.__invert__()}
bitwise_binary = {'bitwise_and': lambda x, y: x & y,
Expand All @@ -131,19 +141,25 @@ def get_rtol(dtype, xp):
'bitwise_right_shift': lambda x, y: x.__rshift__(y),
'bitwise_xor': lambda x, y: x.__xor__(y)}


scalar_conversions = {bool: True, int: 10, float: 1.5, complex: 1.5 + 2.5j}

# tested in test_dlpack
# __dlpack__, __dlpack_device__, to_device
# tested in test_indexing
# __getitem__, __index__, __setitem__,

comparison_binary = [lambda x, y: x < y, lambda x, y: x <= y, lambda x, y: x > y,
lambda x, y: x >= y, lambda x, y: x == y , lambda x, y: x != y]
comparison_methods_binary = [lambda x, y: x.__eq__(y), lambda x, y: x.__ge__(y),
lambda x, y: x.__gt__(y), lambda x, y: x.__le__(y),
lambda x, y: x.__lt__(y), lambda x, y: x.__ne__(y)]
comparison_binary = {"x < y": lambda x, y: x < y,
"x <= y": lambda x, y: x <= y,
"x > y": lambda x, y: x > y,
"x >= y": lambda x, y: x >= y,
"x == y": lambda x, y: x == y ,
"x != y": lambda x, y: x != y}
comparison_methods_binary = {"x.__eq__(y)": lambda x, y: x.__eq__(y),
"x.__ge__(y)": lambda x, y: x.__ge__(y),
"x.__gt__(y)": lambda x, y: x.__gt__(y),
"x.__le__(y)": lambda x, y: x.__le__(y),
"x.__lt__(y)": lambda x, y: x.__lt__(y),
"x.__ne__(y)": lambda x, y: x.__ne__(y)}

def iadd(x, y): x += y
def isub(x, y): x -= y
Expand Down Expand Up @@ -185,10 +201,11 @@ def irshift(x, y): x >>= y
utility_array = ['all', 'any']


@pytest.mark.parametrize("f", arithmetic_unary[:1] + arithmetic_methods_unary)
@pytest.mark.parametrize("f_name, f",
(arithmetic_unary | arithmetic_methods_unary).items())
@pytest.mark.parametrize('dtype', dtypes_numeric)
@pytest.mark.parametrize('xp', xps)
def test_arithmetic_unary(f, dtype, xp, seed=None):
def test_arithmetic_unary(f_name, f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0])
ref = f(masked_arrays[0])
Expand All @@ -205,11 +222,12 @@ def test_arithmetic_unary(f, dtype, xp, seed=None):
]


@pytest.mark.parametrize("f", arithmetic_binary + arithmetic_methods_binary)
@pytest.mark.parametrize("f_name, f",
(arithmetic_binary | arithmetic_methods_binary).items())
@pytest.mark.parametrize('dtype', dtypes_numeric)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=arithetic_binary_exceptions)
def test_arithmetic_binary(f, dtype, xp, seed=None):
def test_arithmetic_binary(f_name, f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0], marrays[1])
ref_data = f(masked_arrays[0].data, masked_arrays[1].data)
Expand All @@ -218,11 +236,11 @@ def test_arithmetic_binary(f, dtype, xp, seed=None):
assert_equal(res, ref, seed=seed, xp=xp)


@pytest.mark.parametrize("f", array_binary + array_methods_binary)
@pytest.mark.parametrize("f_name, f", (array_binary | array_methods_binary).items())
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=["Only numeric dtypes are allowed in matmul"])
def test_array_binary(f, dtype, xp, seed=None):
def test_array_binary(f_name, f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0], marrays[0].mT)
x = masked_arrays[0].data
Expand All @@ -234,12 +252,10 @@ def test_array_binary(f, dtype, xp, seed=None):
assert_allclose(res, ref, seed=seed, xp=xp, rtol=get_rtol(dtype, xp))


@pytest.mark.parametrize("f_name_fun", itertools.chain(bitwise_unary.items(),
bitwise_methods_unary.items()))
@pytest.mark.parametrize("f_name, f", (bitwise_unary | bitwise_methods_unary).items())
@pytest.mark.parametrize("dtype", dtypes_integral + dtypes_boolean)
@pytest.mark.parametrize('xp', xps)
def test_bitwise_unary(f_name_fun, dtype, xp, seed=None):
f_name, f = f_name_fun
def test_bitwise_unary(f_name, f, dtype, xp, seed=None):
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)

Expand All @@ -252,14 +268,12 @@ def test_bitwise_unary(f_name_fun, dtype, xp, seed=None):
assert_equal(res, ref, xp=xp, seed=seed)


@pytest.mark.parametrize("f_name_fun", itertools.chain(bitwise_binary.items(),
bitwise_methods_binary.items()))
@pytest.mark.parametrize("f_name, f", (bitwise_binary | bitwise_methods_binary).items())
@pytest.mark.parametrize("dtype", dtypes_integral + dtypes_boolean)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=["is only defined for x2 >= 0",
"Only integer dtypes are allowed in "])
def test_bitwise_binary(f_name_fun, dtype, xp, seed=None):
f_name, f = f_name_fun
def test_bitwise_binary(f_name, f, dtype, xp, seed=None):
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)

Expand Down Expand Up @@ -366,11 +380,12 @@ def test_dlpack(dtype, xp, seed=None):
marrays[0].to_device(mxp.__array_namespace_info__().default_device())


@pytest.mark.parametrize("f", comparison_binary + comparison_methods_binary)
@pytest.mark.parametrize("f_name, f",
(comparison_binary | comparison_methods_binary).items())
@pytest.mark.parametrize("dtype", dtypes_all)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=["Only real numeric dtypes are allowed in"])
def test_comparison_binary(f, dtype, xp, seed=None):
def test_comparison_binary(f_name, f, dtype, xp, seed=None):
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
res = f(marrays[0], marrays[1])
ref = f(masked_arrays[0], masked_arrays[1])
Expand Down Expand Up @@ -427,7 +442,7 @@ def test_inplace_array_binary(f, dtype, xp, seed=None):
assert_allclose(a, ref, xp=xp, seed=seed)


@pytest.mark.parametrize("f", arithmetic_binary)
@pytest.mark.parametrize("f_name, f", arithmetic_binary.items())
@pytest.mark.parametrize("dtype", dtypes_all)
@pytest.mark.parametrize('xp', xps)
@pytest.mark.parametrize('type_', ["array", "scalar"])
Expand All @@ -437,7 +452,7 @@ def test_inplace_array_binary(f, dtype, xp, seed=None):
"Only floating-point dtypes are allowed",
"Integers to negative integer powers are not allowed",
"numpy boolean subtract, the `-` operator, is not supported"])
def test_rarithmetic_binary(f, dtype, xp, type_, seed=None):
def test_rarithmetic_binary(f_name, f, dtype, xp, type_, seed=None):
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
if type_ == "array":
arg1a = marrays[0].data
Expand Down Expand Up @@ -1050,7 +1065,7 @@ def test_signature_docs():

def test_gh33():
# See https://github.com/mdhaber/marray/issues/33
test_array_binary(array_binary[0], dtype='float32', xp=np, seed=566)
test_array_binary(*list(array_binary.items())[0], dtype='float32', xp=np, seed=566)


def test_test():
Expand Down

0 comments on commit 422a683

Please sign in to comment.