diff --git a/zarr/core.py b/zarr/core.py index eb054e4e6b..61ea8752d4 100644 --- a/zarr/core.py +++ b/zarr/core.py @@ -268,7 +268,7 @@ def ndim(self): @property def _size(self): - return reduce(operator.mul, self._shape) + return reduce(operator.mul, self._shape, 1) @property def size(self): @@ -313,8 +313,11 @@ def nbytes_stored(self): @property def _cdata_shape(self): - return tuple(int(np.ceil(s / c)) - for s, c in zip(self._shape, self._chunks)) + if self._shape == (): + return (1,) + else: + return tuple(int(np.ceil(s / c)) + for s, c in zip(self._shape, self._chunks)) @property def cdata_shape(self): @@ -325,7 +328,7 @@ def cdata_shape(self): @property def _nchunks(self): - return reduce(operator.mul, self._cdata_shape) + return reduce(operator.mul, self._cdata_shape, 1) @property def nchunks(self): @@ -360,13 +363,16 @@ def __eq__(self, other): ) def __array__(self, *args): - a = self[:] + a = self[...] if args: a = a.astype(args[0]) return a def __len__(self): - return self.shape[0] + if self.shape: + return self.shape[0] + else: + raise TypeError('len() of unsized object') def __getitem__(self, item): """Retrieve data for some portion of the array. Most NumPy-style @@ -443,6 +449,41 @@ def __getitem__(self, item): if not self._cache_metadata: self._load_metadata() + # handle zero-dimensional arrays + if self._shape == (): + return self._getitem_zd(item) + else: + return self._getitem_nd(item) + + def _getitem_zd(self, item): + # special case __getitem__ for zero-dimensional array + + # check item is valid + if item not in ((), Ellipsis): + raise IndexError('too many indices for array') + + try: + # obtain encoded data for chunk + ckey = self._chunk_key((0,)) + cdata = self.chunk_store[ckey] + + except KeyError: + # chunk not initialized + out = np.empty((), dtype=self._dtype) + if self._fill_value is not None: + out.fill(self._fill_value) + + else: + out = self._decode_chunk(cdata) + + # handle selection of the scalar value via empty tuple + out = out[item] + + return out + + def _getitem_nd(self, item): + # implementation of __getitem__ for array with at least one dimension + # normalize selection selection = normalize_array_selection(item, self._shape) @@ -559,6 +600,36 @@ def __setitem__(self, item, value): if not self._cache_metadata: self._load_metadata_nosync() + # handle zero-dimensional arrays + if self._shape == (): + return self._setitem_zd(item, value) + else: + return self._setitem_nd(item, value) + + def _setitem_zd(self, item, value): + # special case __setitem__ for zero-dimensional array + + # check item is valid + if item not in ((), Ellipsis): + raise IndexError('too many indices for array') + + # setup data to store + arr = np.asarray(value, dtype=self._dtype) + + # check value + if arr.shape != (): + raise ValueError('bad value; expected scalar, found %r' % value) + + # obtain key for chunk storage + ckey = self._chunk_key((0,)) + + # encode and store + cdata = self._encode_chunk(arr) + self.chunk_store[ckey] = cdata + + def _setitem_nd(self, item, value): + # implementation of __setitem__ for array with at least one dimension + # normalize selection selection = normalize_array_selection(item, self._shape) @@ -570,7 +641,7 @@ def __setitem__(self, item, value): if np.isscalar(value): pass elif expected_shape != value.shape: - raise ValueError('value has wrong shape, expecting %s, found %s' + raise ValueError('value has wrong shape; expected %s, found %s' % (str(expected_shape), str(value.shape))) @@ -646,6 +717,7 @@ def _chunk_getitem(self, cidx, item, dest): # optimization: we want the whole chunk, and the destination is # contiguous, so we can decompress directly from the chunk # into the destination array + if self._compressor: self._compressor.decode(cdata, dest) else: diff --git a/zarr/creation.py b/zarr/creation.py index c40977eb43..ec32de64bc 100644 --- a/zarr/creation.py +++ b/zarr/creation.py @@ -320,7 +320,7 @@ def array(data, **kwargs): z = create(**kwargs) # fill with data - z[:] = data + z[...] = data return z diff --git a/zarr/storage.py b/zarr/storage.py index 782361bae1..aa6e949139 100644 --- a/zarr/storage.py +++ b/zarr/storage.py @@ -279,19 +279,23 @@ def _init_array_metadata(store, shape, chunks=None, dtype=None, compressor='defa chunks = normalize_chunks(chunks, shape, dtype.itemsize) order = normalize_order(order) - # obtain compressor config - if compressor == 'none': + # compressor prep + if shape == (): + # no point in compressing a 0-dimensional array, only a single value + compressor = None + elif compressor == 'none': # compatibility compressor = None elif compressor == 'default': compressor = default_compressor + + # obtain compressor config + compressor_config = None if compressor: try: compressor_config = compressor.get_config() except AttributeError: err_bad_compressor(compressor) - else: - compressor_config = None # obtain filters config if filters: diff --git a/zarr/tests/test_core.py b/zarr/tests/test_core.py index 90885cea91..1ef8bd9f33 100644 --- a/zarr/tests/test_core.py +++ b/zarr/tests/test_core.py @@ -158,6 +158,8 @@ def test_array_1d_fill_value(self): assert_array_equal(f[310:], z[310:]) def test_array_1d_set_scalar(self): + # test setting the contents of an array with a scalar value + # setup a = np.zeros(100) z = self.create_array(shape=a.shape, chunks=10, dtype=a.dtype) @@ -587,6 +589,52 @@ def test_0len_dim_2d(self): with assert_raises(IndexError): z[:, 0] = 42 + def test_array_0d(self): + # test behaviour for array with 0 dimensions + + # setup + a = np.zeros(()) + z = self.create_array(shape=(), dtype=a.dtype, fill_value=0) + + # check properties + eq(a.ndim, z.ndim) + eq(a.shape, z.shape) + eq(a.size, z.size) + eq(a.dtype, z.dtype) + eq(a.nbytes, z.nbytes) + with assert_raises(TypeError): + len(z) + eq((), z.chunks) + eq(1, z.nchunks) + eq((1,), z.cdata_shape) + # compressor always None - no point in compressing a single value + assert_is_none(z.compressor) + + # check __getitem__ + b = z[...] + assert_is_instance(b, np.ndarray) + eq(a.shape, b.shape) + eq(a.dtype, b.dtype) + assert_array_equal(a, np.array(z)) + assert_array_equal(a, z[...]) + eq(a[()], z[()]) + with assert_raises(IndexError): + z[0] + with assert_raises(IndexError): + z[:] + + # check __setitem__ + z[...] = 42 + eq(42, z[()]) + z[()] = 43 + eq(43, z[()]) + with assert_raises(IndexError): + z[0] = 42 + with assert_raises(IndexError): + z[:] = 42 + with assert_raises(ValueError): + z[...] = np.array([1, 2, 3]) + class TestArrayWithPath(TestArray): diff --git a/zarr/tests/test_hierarchy.py b/zarr/tests/test_hierarchy.py index 2f34d0e3dc..4ab1073dae 100644 --- a/zarr/tests/test_hierarchy.py +++ b/zarr/tests/test_hierarchy.py @@ -641,6 +641,10 @@ def test_setitem(self): data = np.arange(200) g['foo'] = data assert_array_equal(data, g['foo']) + # 0d array + g['foo'] = 42 + eq((), g['foo'].shape) + eq(42, g['foo'][()]) except NotImplementedError: pass