diff --git a/requirements.txt b/requirements.txt index 5186c748c2..859975d19c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ numpy fasteners +kenjutsu>=0.4.2 diff --git a/requirements_dev.txt b/requirements_dev.txt index e1b4d1fea0..ebb85799d2 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,6 +6,7 @@ coverage==4.3.4 Cython==0.25.2 fasteners==0.14.1 flake8==3.3.0 +kenjutsu==0.4.2 mccabe==0.6.1 monotonic==1.2 nose==1.3.7 diff --git a/requirements_rtfd.txt b/requirements_rtfd.txt index 86091ed956..4c739fc596 100644 --- a/requirements_rtfd.txt +++ b/requirements_rtfd.txt @@ -5,3 +5,4 @@ numpydoc mock numpy cython +kenjutsu>=0.4.2 diff --git a/setup.py b/setup.py index ac134523f1..52b9c6802c 100644 --- a/setup.py +++ b/setup.py @@ -161,6 +161,7 @@ def run_setup(with_extensions): install_requires=[ 'numpy>=1.7', 'fasteners', + 'kenjutsu>=0.4.2', ], ext_modules=ext_modules, cmdclass=cmdclass, diff --git a/zarr/core.py b/zarr/core.py index b1f7303620..58fda61935 100644 --- a/zarr/core.py +++ b/zarr/core.py @@ -1,11 +1,24 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division +import collections import operator import itertools +try: + irange = xrange +except NameError: + irange = range + +try: + from itertools import map as imap +except ImportError: + imap = map + import numpy as np +from kenjutsu.format import split_indices +from kenjutsu.measure import len_slices from zarr.util import is_total_slice, normalize_array_selection, \ get_chunk_range, human_readable_size, normalize_resize_args, \ @@ -449,12 +462,27 @@ def __getitem__(self, item): selection = normalize_array_selection(item, self._shape) # determine output array shape - out_shape = tuple(s.stop - s.start for s in selection - if isinstance(s, slice)) + out_shape = len_slices(selection) # setup output array out = np.empty(out_shape, dtype=self._dtype, order=self._order) + # Find where sequences of indices are. + seqs_locs = imap(lambda v: isinstance(v, collections.Sequence), selection) + seqs_locs = itertools.compress(irange(len(selection)), seqs_locs) + seqs_locs = list(seqs_locs) + + # Retrieve each index individually and return the result. + if seqs_locs: + assert len(seqs_locs) == 1 + seq_loc = seqs_locs[0] + out_swap = out.swapaxes(0, seq_loc) + for i, each_selection in enumerate(split_indices(selection)): + each_out = out_swap[i][None].swapaxes(0, seq_loc) + each_out[...] = self.__getitem__(each_selection) + + return out + # determine indices of chunks overlapping the selection chunk_range = get_chunk_range(selection, self._chunks) @@ -571,10 +599,8 @@ def __setitem__(self, item, value): selection = normalize_array_selection(item, self._shape) # check value shape - expected_shape = tuple( - s.stop - s.start for s in selection - if isinstance(s, slice) - ) + expected_shape = len_slices(selection) + if np.isscalar(value): pass elif expected_shape != value.shape: @@ -582,6 +608,26 @@ def __setitem__(self, item, value): % (str(expected_shape), str(value.shape))) + # Find where sequences of indices are. + seqs_locs = imap(lambda v: isinstance(v, collections.Sequence), selection) + seqs_locs = itertools.compress(irange(len(selection)), seqs_locs) + seqs_locs = list(seqs_locs) + + # Set each index individually and return the result. + if seqs_locs: + assert len(seqs_locs) == 1 + seq_loc = seqs_locs[0] + if not np.isscalar(value): + value = value.swapaxes(0, seq_loc) + for i, each_selection in enumerate(split_indices(selection)): + each_value = value + if not np.isscalar(value): + each_value = value[i][None] + each_value = each_value.swapaxes(0, seq_loc) + self.__setitem__(each_selection, each_value) + + return + # determine indices of chunks overlapping the selection chunk_range = get_chunk_range(selection, self._chunks) diff --git a/zarr/tests/test_core.py b/zarr/tests/test_core.py index 670a295f2e..5dfd15574b 100644 --- a/zarr/tests/test_core.py +++ b/zarr/tests/test_core.py @@ -132,6 +132,11 @@ def test_array_1d(self): # single item eq(a[0], z[0]) eq(a[-1], z[-1]) + # index selection + assert_array_equal(a[(0,), ...], z[(0,), ...]) + assert_array_equal(a[..., (0,)], z[..., (0,)]) + assert_array_equal(a[(1,0,2), ...], z[(1,0,2), ...]) + assert_array_equal(a[..., (1,0,2)], z[..., (1,0,2)]) # check partial assignment b = np.arange(1e5, 2e5) @@ -215,11 +220,24 @@ def test_array_2d(self): assert_array_equal(a[:110, :3], z[:110, :3]) assert_array_equal(a[190:310, 3:7], z[190:310, 3:7]) assert_array_equal(a[-110:, -3:], z[-110:, -3:]) + assert_array_equal(a[0, ...], z[0, ...]) + assert_array_equal(a[..., 0], z[..., 0]) + assert_array_equal(a[10:20, ...], z[10:20, ...]) + assert_array_equal(a[..., 3:7], z[..., 3:7]) # single item assert_array_equal(a[0], z[0]) assert_array_equal(a[-1], z[-1]) eq(a[0, 0], z[0, 0]) eq(a[-1, -1], z[-1, -1]) + # index selection + assert_array_equal(a[(0,), ...], z[(0,), ...]) + assert_array_equal(a[..., (0,)], z[..., (0,)]) + assert_array_equal(a[(1,0,2), ...], z[(1,0,2), ...]) + assert_array_equal(a[..., (1,0,2)], z[..., (1,0,2)]) + + # illegal index selection + with self.assertRaises(ValueError) as e: + z[(0,), (1,0,2)] # check partial assignment b = np.arange(10000, 20000).reshape((1000, 10)) @@ -269,6 +287,21 @@ def test_array_2d_partial(self): eq(-1, z[0, 0]) eq(-1, z[2, 2]) eq(-1, z[-1, -1]) + # check multiple indices assignment + d = np.arange(z.shape[0]) + d = np.concatenate(2*[d[..., None]], axis=-1) + z[:, (0,)] = -1 + assert_array_equal(-1, z[:, (0,)]) + z[:, (2,1)] = -1 + assert_array_equal(-1, z[:, (2,1)]) + z[:, (0,)] = d[:, (0,)] + assert_array_equal(d[:, (0,)], z[:, (0,)]) + z[:, (2,1)] = d + assert_array_equal(d, z[:, (2,1)]) + + # check illegal index assignment + with self.assertRaises(ValueError) as e: + z[(0,), (1,0,2)] = -1 def test_array_order(self): diff --git a/zarr/tests/test_util.py b/zarr/tests/test_util.py index fe4d7aaf05..cfe86d228b 100644 --- a/zarr/tests/test_util.py +++ b/zarr/tests/test_util.py @@ -51,7 +51,9 @@ def test_is_total_slice(): assert_true(is_total_slice(slice(None), (100,))) assert_true(is_total_slice(slice(0, 100), (100,))) assert_false(is_total_slice(slice(0, 50), (100,))) - assert_false(is_total_slice(slice(0, 100, 2), (100,))) + assert_false(is_total_slice(([1, 0, 2],), (100,))) + assert_true(is_total_slice(([1, 0, 2],), (3,))) + assert_true(is_total_slice(([1, 0, 2, 0],), (3,))) # 2D assert_true(is_total_slice(Ellipsis, (100, 100))) @@ -61,7 +63,10 @@ def test_is_total_slice(): assert_false(is_total_slice((slice(0, 100), slice(0, 50)), (100, 100))) assert_false(is_total_slice((slice(0, 50), slice(0, 100)), (100, 100))) assert_false(is_total_slice((slice(0, 50), slice(0, 50)), (100, 100))) - assert_false(is_total_slice((slice(0, 100, 2), slice(0, 100)), (100, 100))) + assert_false(is_total_slice(([1, 0, 2],), (100, 100))) + assert_false(is_total_slice((slice(None), [1, 0, 2]), (100, 100))) + assert_true(is_total_slice((slice(None), [1, 0, 2]), (100, 3))) + assert_true(is_total_slice((slice(None), [1, 0, 2, 0]), (100, 3))) with assert_raises(TypeError): is_total_slice('foo', (100,)) @@ -80,12 +85,12 @@ def test_normalize_axis_selection(): normalize_axis_selection(-1000, 100) # slice - eq(slice(0, 100), normalize_axis_selection(slice(None), 100)) - eq(slice(0, 100), normalize_axis_selection(slice(None, 100), 100)) - eq(slice(0, 100), normalize_axis_selection(slice(0, None), 100)) - eq(slice(0, 100), normalize_axis_selection(slice(0, 1000), 100)) - eq(slice(99, 100), normalize_axis_selection(slice(-1, None), 100)) - eq(slice(98, 99), normalize_axis_selection(slice(-2, -1), 100)) + eq(slice(0, 100, 1), normalize_axis_selection(slice(None), 100)) + eq(slice(0, 100, 1), normalize_axis_selection(slice(None, 100), 100)) + eq(slice(0, 100, 1), normalize_axis_selection(slice(0, None), 100)) + eq(slice(0, 100, 1), normalize_axis_selection(slice(0, 1000), 100)) + eq(slice(99, 100, 1), normalize_axis_selection(slice(-1, None), 100)) + eq(slice(98, 99, 1), normalize_axis_selection(slice(-2, -1), 100)) with assert_raises(IndexError): normalize_axis_selection(slice(100, None), 100) with assert_raises(IndexError): @@ -108,29 +113,32 @@ def test_normalize_array_selection(): eq((0,), normalize_array_selection(0, (100,))) # 1D, slice - eq((slice(0, 100),), normalize_array_selection(Ellipsis, (100,))) - eq((slice(0, 100),), normalize_array_selection(slice(None), (100,))) - eq((slice(0, 100),), normalize_array_selection(slice(None, 100), (100,))) - eq((slice(0, 100),), normalize_array_selection(slice(0, None), (100,))) + eq((slice(0, 100, 1),), normalize_array_selection(Ellipsis, (100,))) + eq((slice(0, 100, 1),), normalize_array_selection(slice(None), (100,))) + eq( + (slice(0, 100, 1),), + normalize_array_selection(slice(None, 100), (100,)) + ) + eq((slice(0, 100, 1),), normalize_array_selection(slice(0, None), (100,))) # 2D, single item eq((0, 0), normalize_array_selection((0, 0), (100, 100))) eq((99, 1), normalize_array_selection((-1, 1), (100, 100))) # 2D, single col/row - eq((0, slice(0, 100)), normalize_array_selection((0, slice(None)), - (100, 100))) - eq((0, slice(0, 100)), normalize_array_selection((0,), - (100, 100))) - eq((slice(0, 100), 0), normalize_array_selection((slice(None), 0), - (100, 100))) + eq((0, slice(0, 100, 1)), normalize_array_selection((0, slice(None)), + (100, 100))) + eq((0, slice(0, 100, 1)), normalize_array_selection((0,), + (100, 100))) + eq((slice(0, 100, 1), 0), normalize_array_selection((slice(None), 0), + (100, 100))) # 2D slice - eq((slice(0, 100), slice(0, 100)), + eq((slice(0, 100, 1), slice(0, 100, 1)), normalize_array_selection(Ellipsis, (100, 100))) - eq((slice(0, 100), slice(0, 100)), + eq((slice(0, 100, 1), slice(0, 100, 1)), normalize_array_selection(slice(None), (100, 100))) - eq((slice(0, 100), slice(0, 100)), + eq((slice(0, 100, 1), slice(0, 100, 1)), normalize_array_selection((slice(None), slice(None)), (100, 100))) with assert_raises(TypeError): diff --git a/zarr/util.py b/zarr/util.py index 97da2f8e51..c0329c8595 100644 --- a/zarr/util.py +++ b/zarr/util.py @@ -1,10 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division +import collections +import numbers import operator import numpy as np +from kenjutsu.format import reformat_slices +from kenjutsu.measure import len_slices from zarr.compat import integer_types, PY2, reduce @@ -111,58 +115,33 @@ def is_total_slice(item, shape): given `shape`. Used to optimize __setitem__ operations on the Chunk class.""" - # N.B., assume shape is normalized - - if item == Ellipsis: - return True - if item == slice(None): - return True - if isinstance(item, slice): - item = item, - if isinstance(item, tuple): - return all( - (isinstance(s, slice) and - ((s == slice(None)) or - ((s.stop - s.start == l) and (s.step in [1, None])))) - for s, l in zip(item, shape) - ) - else: - raise TypeError('expected slice or tuple of slices, found %r' % item) + rf_item = normalize_array_selection(item, shape) + + # Remove any duplicates from sequences. + rf_item = list(rf_item) + for i in range(len(rf_item)): + if isinstance(rf_item[i], collections.Sequence): + rf_item[i] = list(set(rf_item[i])) + rf_item = tuple(rf_item) + + return len_slices(rf_item) == shape def normalize_axis_selection(item, l): """Convenience function to normalize a selection within a single axis of size `l`.""" - if isinstance(item, int): - if item < 0: - # handle wraparound - item = l + item - if item > (l - 1) or item < 0: - raise IndexError('index out of bounds: %s' % item) - return item - - elif isinstance(item, slice): - if item.step is not None and item.step != 1: - raise NotImplementedError('slice with step not supported') - start = 0 if item.start is None else item.start - stop = l if item.stop is None else item.stop - if start < 0: - start = l + start - if stop < 0: - stop = l + stop - if start < 0 or stop < 0: - raise IndexError('index out of bounds: %s, %s' % (start, stop)) - if start >= l: - raise IndexError('index out of bounds: %s, %s' % (start, stop)) - if stop > l: - stop = l - if stop < start: - raise IndexError('index out of bounds: %s, %s' % (start, stop)) - return slice(start, stop) + rf_item = reformat_slices((item,), (l,))[0] - else: - raise TypeError('expected integer or slice, found: %r' % item) + if isinstance(rf_item, slice) and rf_item.step != 1: + raise NotImplementedError("slice with step not supported") + + if np.prod(len_slices((rf_item,))) == 0: + raise IndexError( + "index out of bounds: %s, %s" % (item.start, item.stop) + ) + + return rf_item # noinspection PyTypeChecker @@ -170,29 +149,14 @@ def normalize_array_selection(item, shape): """Convenience function to normalize a selection within an array with the given `shape`.""" - # normalize item - if isinstance(item, integer_types): - item = (int(item),) - elif isinstance(item, slice): - item = (item,) - elif item == Ellipsis: - item = (slice(None),) - - # handle tuple of indices/slices - if isinstance(item, tuple): + rf_item = reformat_slices(item, shape) - # determine start and stop indices for all axes - selection = tuple(normalize_axis_selection(i, l) - for i, l in zip(item, shape)) + # Only needed for constraint checks. + rf_item = tuple( + normalize_axis_selection(i, l) for i, l in zip(rf_item, shape) + ) - # fill out selection if not completely specified - if len(selection) < len(shape): - selection += tuple(slice(0, l) for l in shape[len(selection):]) - - return selection - - else: - raise TypeError('expected indices or slice, found: %r' % item) + return rf_item def get_chunk_range(selection, chunks):