diff --git a/.gitignore b/.gitignore
index 3de4f1e8..82e989c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,15 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# jetbrains ide stuff
+*.iml
+.idea/
+
+# vscode ide stuff
+*.code-workspace
+.history/
+.vscode/
+
+# project scratch dir
+/scratch/
diff --git a/ndindex/__init__.py b/ndindex/__init__.py
index 37667c1b..d17a9f85 100644
--- a/ndindex/__init__.py
+++ b/ndindex/__init__.py
@@ -1,8 +1,8 @@
 __all__ = []
 
-from .ndindex import ndindex
+from .ndindex import parse_index, ndindex
 
-__all__ += ['ndindex']
+__all__ += ['parse_index', 'ndindex']
 
 from .slice import Slice
 
diff --git a/ndindex/ndindex.py b/ndindex/ndindex.py
index cf7688fb..feebfeec 100644
--- a/ndindex/ndindex.py
+++ b/ndindex/ndindex.py
@@ -1,3 +1,5 @@
+import ast
 import inspect
 import numbers
+import sys
 
@@ -65,6 +67,124 @@ def __init__(self, f):
     def __get__(self, obj, owner):
         return self.f(owner)
 
+
+def parse_index(node_or_string):
+    """
+    Parse an index expression into an ndindex object without calling eval().
+
+    ``node_or_string`` is either a string holding the text that would go
+    between the square brackets of a subscript (e.g. ``'3:15, -5, ...'``), or
+    an ``ast.Expression``, an ``ast.Subscript``, or the index node itself.
+
+    Only a limited subset of index syntax is accepted: optionally signed
+    numeric literals, ``None``, ``...``/``Ellipsis``, ``a:b:c`` slices,
+    ``slice(...)`` calls without keyword arguments, lists, and a single level
+    of tuple. The converted value is passed to ``ndindex()`` and that result
+    is returned.
+
+    Raises ``ValueError`` for unsupported syntax, including a tuple placed
+    directly inside a tuple index. A string that is not valid Python syntax
+    raises ``SyntaxError`` from ``ast.parse``.
+    """
+    # Before Python 3.8 the parser emitted dedicated node classes for literals.
+    # They later became deprecated aliases of ast.Constant and were removed in
+    # Python 3.14, so they are only consulted where the parser still emits them.
+    legacy_literals = sys.version_info < (3, 8)
+    # ast.Index and ast.ExtSlice are only produced by parsers older than Python
+    # 3.9; look them up defensively in case they are removed from the module.
+    index_node = getattr(ast, 'Index', None)
+    ext_slice_node = getattr(ast, 'ExtSlice', None)
+
+    if isinstance(node_or_string, str):
+        source = 'dummy[{}]'.format(node_or_string.lstrip(" \t"))
+        wrapped = ast.parse(source, mode='eval').body
+        # Input such as '1][2' closes the wrapper early and parses as
+        # dummy[1][2]. Require the outermost node to subscript the wrapper name
+        # itself so such input is rejected instead of silently truncated.
+        if not (isinstance(wrapped, ast.Subscript)
+                and isinstance(wrapped.value, ast.Name)
+                and wrapped.value.id == 'dummy'):
+            raise ValueError(f'malformed node or string: {node_or_string!r}')
+        node_or_string = wrapped.slice
+    else:
+        if isinstance(node_or_string, ast.Expression):
+            node_or_string = node_or_string.body
+        if isinstance(node_or_string, ast.Subscript):
+            node_or_string = node_or_string.slice
+
+    def _malformed(node):
+        # ast.dump() only accepts AST nodes; fall back to the type name so that
+        # a non-AST argument still produces a ValueError.
+        detail = ast.dump(node) if isinstance(node, ast.AST) else type(node).__name__
+        return ValueError(f'malformed node or string: {node!r}, {detail!r}')
+
+    def _nested_tuple(node):
+        return ValueError(
+            f'tuples inside of tuple indices are not supported: {node!r}, {ast.dump(node)!r}')
+
+    def _literal(node):
+        # Return (True, value) for a literal-constant node, else (False, None).
+        if isinstance(node, ast.Constant):
+            return True, node.value
+        if legacy_literals:  # pragma: no cover
+            if isinstance(node, ast.Num):
+                return True, node.n
+            if isinstance(node, ast.NameConstant):
+                return True, node.value
+            if isinstance(node, ast.Ellipsis):
+                return True, ...
+        return False, None
+
+    def _number(node):
+        is_literal, value = _literal(node)
+        if is_literal and isinstance(value, (int, float, complex)):
+            return value
+        raise _malformed(node)
+
+    def _signed_number(node):
+        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
+            value = _number(node.operand)
+            return +value if isinstance(node.op, ast.UAdd) else -value
+        return _number(node)
+
+    def _convert(node, allow_tuple):
+        # allow_tuple is threaded through the recursion instead of living in
+        # shared mutable state, so sibling tuples such as the rows of
+        # [(1, 2), (3, 4)] are not mistaken for nested tuple indices.
+        if isinstance(node, ast.Tuple):
+            if not allow_tuple:
+                raise _nested_tuple(node)
+            return tuple(_convert(elt, False) for elt in node.elts)
+        if isinstance(node, ast.List):
+            return [_convert(elt, True) for elt in node.elts]
+        if isinstance(node, ast.Slice):
+            return slice(*(
+                None if part is None else _convert(part, allow_tuple)
+                for part in (node.lower, node.upper, node.step)
+            ))
+        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
+                and node.func.id == 'slice' and not node.keywords):
+            # support for parsing slices written out as 'slice(...)' objects
+            return slice(*(_convert(arg, allow_tuple) for arg in node.args))
+        is_literal, value = _literal(node)
+        if is_literal and (value is None or value is ...):
+            # support for literal None, eg 'slice(None, ...)', and for the
+            # three dot '...' ellipsis syntax
+            return value
+        if isinstance(node, ast.Name) and node.id == 'Ellipsis':
+            # support for 'Ellipsis' ellipsis syntax
+            return ...
+        if index_node is not None and isinstance(node, index_node):  # pragma: no cover
+            return _convert(node.value, allow_tuple)
+        if ext_slice_node is not None and isinstance(node, ext_slice_node):  # pragma: no cover
+            if not allow_tuple:
+                raise _nested_tuple(node)
+            return tuple(_convert(dim, False) for dim in node.dims)
+        return _signed_number(node)
+
+    return ndindex(_convert(node_or_string, True))
+
+
 class NDIndex:
     """
     Represents an index into an nd-array (i.e., a numpy array).
diff --git a/ndindex/tests/helpers.py b/ndindex/tests/helpers.py
index 93d336b3..aeda6631 100644
--- a/ndindex/tests/helpers.py
+++ b/ndindex/tests/helpers.py
@@ -24,12 +24,14 @@ def prod(seq):
     return reduce(mul, seq, 1)
 
+positive_ints = integers(1, 10)
 nonnegative_ints = integers(0, 10)
 negative_ints = integers(-10, -1)
 ints = lambda: one_of(negative_ints, nonnegative_ints)
+ints_nonzero = lambda: one_of(negative_ints, positive_ints)
 
 def slices(start=one_of(none(), ints()), stop=one_of(none(), ints()),
-           step=one_of(none(), ints())):
+           step=one_of(none(), ints_nonzero())):
     return builds(slice, start, stop, step)
 
 ellipses = lambda: just(...)
 
diff --git a/ndindex/tests/test_ndindex.py b/ndindex/tests/test_ndindex.py
index a8ea8ed7..4cd05e33 100644
--- a/ndindex/tests/test_ndindex.py
+++ b/ndindex/tests/test_ndindex.py
@@ -1,17 +1,36 @@
+import ast
 import inspect
+import sys
 
 import numpy as np
-
-from hypothesis import given, example, settings
-
+from hypothesis import example, given, settings
+from hypothesis.strategies import one_of
 from pytest import raises, warns
 
-from ..ndindex import ndindex, asshape
+from ..ndindex import ndindex, parse_index, asshape
 from ..integer import Integer
 from ..ellipsis import ellipsis
 from ..integerarray import IntegerArray
 from ..tuple import Tuple
-from .helpers import ndindices, check_same, assert_equal
+from .helpers import ndindices, check_same, assert_equal, ellipses, ints, slices, tuples, _doesnt_raise
+
+Tuples = tuples(one_of(
+    ellipses(),
+    ints(),
+    slices(),
+)).filter(_doesnt_raise)
+
+ndindexStrs = one_of(
+    ellipses(),
+    ints(),
+    slices(),
+    Tuples,
+).map(str)
+
+class _Dummy:
+    def __getitem__(self, x):
+        return x
+_dummy = _Dummy()
 
 @given(ndindices)
 def test_eq(idx):
@@ -77,6 +96,66 @@ def test_ndindex_invalid():
 def test_ndindex_ellipsis():
     raises(IndexError, lambda: ndindex(ellipsis))
 
+
+@example('3')
+@example('-3')
+@example('...')
+@example('Ellipsis')
+@example('+3')
+@example('3:4')
+@example('3:-4')
+@example('3, 5, 14, 1')
+@example('3, -5, 14, -1')
+@example('3:15, 5, 14:99, 1')
+@example('3:15, -5, 14:-99, 1')
+@example(':15, -5, 14:-99:3, 1')
+@example('3:15, -5, [1,2,3], :')
+@example('slice(None)')
+@example('slice(None, None)')
+@example('slice(None, None, None)')
+@example('slice(14)')
+@example('slice(12, 14)')
+@example('slice(12, 72, 14)')
+@example('slice(-12, -72, 14)')
+@example('3:15, -5, slice(-12, -72, 14), Ellipsis')
+@example('..., 3:15, -5, slice(-12, -72, 14)')
+@given(ndindexStrs)
+def test_parse_index_hypothesis(ixStr):
+    assert eval(f'_dummy[{ixStr}]') == parse_index(ixStr)
+
+def test_parse_index_malformed_raise():
+    # we don't allow the bitwise not unary op
+    with raises(ValueError):
+        parse_index('~3')
+
+def test_parse_index_nested_tuple_raise():
+    # we don't allow tuples within tuple indices
+    with raises(ValueError):
+        # in cpy38 and earlier this parses as an ast.Index wrapping an ast.Tuple, afterwards as a bare ast.Tuple
+        parse_index('..., -5, slice(12, -14), (1,2,3)')
+
+    with raises(ValueError):
+        # in cpy38 and earlier this parses as an ast.ExtSlice containing an ast.Tuple
+        parse_index('3:15, -5, :, (1,2,3)')
+
+def test_parse_index_ensure_coverage():
+    # ensure full coverage, regardless of cpy version and accompanying changes to the ast grammar
+    nodes = [ast.Constant(7)]
+    if sys.version_info < (3, 8):
+        # from cpy38 on, ast.Num is only a deprecated alias of ast.Constant, and it is removed in cpy314
+        nodes.append(ast.Num(7))
+    if hasattr(ast, 'Index'):
+        # ast.Index was removed from ast grammar in cpy39; only the deprecated class remains
+        nodes.append(ast.Index(ast.Constant(7)))
+    for node in nodes:
+        assert parse_index(node) == 7
+
+    if hasattr(ast, 'ExtSlice'):
+        # ast.ExtSlice was removed from ast grammar in cpy39; only the deprecated class remains
+        dims = (ast.Constant(7), ast.Constant(7), ast.Constant(7))
+        assert parse_index(ast.ExtSlice(dims)) == (7, 7, 7)
+
+
 def test_signature():
     sig = inspect.signature(Integer)
     assert sig.parameters.keys() == {'idx'}
diff --git a/setup.py b/setup.py
index 1b6035b6..135cf003 100644
--- a/setup.py
+++ b/setup.py
@@ -20,8 +20,11 @@
         "sympy",
     ],
     tests_require=[
-        'pytest',
-        'hypothesis',
+        "pytest",
+        "pytest-cov",
+        "pytest-flakes",
+        "pytest-tornasync",
+        "hypothesis",
     ],
     classifiers=[
         "Programming Language :: Python :: 3",