Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def ndim(self):

@property
def _size(self):
return reduce(operator.mul, self._shape)
return reduce(operator.mul, self._shape, 1)

@property
def size(self):
Expand Down Expand Up @@ -325,7 +325,7 @@ def cdata_shape(self):

@property
def _nchunks(self):
return reduce(operator.mul, self._cdata_shape)
return reduce(operator.mul, self._cdata_shape, 1)

@property
def nchunks(self):
Expand Down Expand Up @@ -360,13 +360,16 @@ def __eq__(self, other):
)

def __array__(self, *args):
a = self[:]
a = self[...]
if args:
a = a.astype(args[0])
return a

def __len__(self):
return self.shape[0]
if self.shape:
return self.shape[0]
else:
raise TypeError('len() of unsized object')

def __getitem__(self, item):
"""Retrieve data for some portion of the array. Most NumPy-style
Expand Down Expand Up @@ -443,6 +446,44 @@ def __getitem__(self, item):
if not self._cache_metadata:
self._load_metadata()

# handle zero-dimensional arrays
if self._shape == ():
return self._getitem_zd(item)
else:
return self._getitem_nd(item)

def _getitem_zd(self, item):
# special case __getitem__ for zero-dimensional array

# check item is valid
if item not in ((), Ellipsis):
raise IndexError('too many indices for array')

try:

# obtain encoded data for chunk
ckey = self._chunk_key((0,))
cdata = self.chunk_store[ckey]

except KeyError:

# chunk not initialized
out = np.empty((), dtype=self._dtype)
if self._fill_value is not None:
out.fill(self._fill_value)

else:

out = self._decode_chunk(cdata)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor point. This seems a bit spaced out. Not sure if that was intentional or not.


# handle selection of the scalar value via empty tuple
out = out[item]

return out

def _getitem_nd(self, item):
# implementation of __getitem__ for array with at least one dimension

# normalize selection
selection = normalize_array_selection(item, self._shape)

Expand Down Expand Up @@ -559,6 +600,36 @@ def __setitem__(self, item, value):
if not self._cache_metadata:
self._load_metadata_nosync()

# handle zero-dimensional arrays
if self._shape == ():
return self._setitem_zd(item, value)
else:
return self._setitem_nd(item, value)

def _setitem_zd(self, item, value):
# special case __setitem__ for zero-dimensional array

# check item is valid
if item not in ((), Ellipsis):
raise IndexError('too many indices for array')

# setup data to store
arr = np.asarray(value, dtype=self._dtype)

# check value
if arr.shape != ():
raise ValueError('bad value; expected scalar, found %r' % value)

# obtain key for chunk storage
ckey = self._chunk_key((0,))

# encode and store
cdata = self._encode_chunk(arr)
self.chunk_store[ckey] = cdata

def _setitem_nd(self, item, value):
# implementation of __setitem__ for array with at least one dimension

# normalize selection
selection = normalize_array_selection(item, self._shape)

Expand All @@ -570,7 +641,7 @@ def __setitem__(self, item, value):
if np.isscalar(value):
pass
elif expected_shape != value.shape:
raise ValueError('value has wrong shape, expecting %s, found %s'
raise ValueError('value has wrong shape; expected %s, found %s'
% (str(expected_shape),
str(value.shape)))

Expand Down Expand Up @@ -646,6 +717,7 @@ def _chunk_getitem(self, cidx, item, dest):
# optimization: we want the whole chunk, and the destination is
# contiguous, so we can decompress directly from the chunk
# into the destination array

if self._compressor:
self._compressor.decode(cdata, dest)
else:
Expand Down
12 changes: 8 additions & 4 deletions zarr/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,23 @@ def _init_array_metadata(store, shape, chunks=None, dtype=None, compressor='defa
chunks = normalize_chunks(chunks, shape, dtype.itemsize)
order = normalize_order(order)

# obtain compressor config
if compressor == 'none':
# compressor prep
if shape == ():
# no point in compressing scalars
compressor = None
elif compressor == 'none':
# compatibility
compressor = None
elif compressor == 'default':
compressor = default_compressor

# obtain compressor config
compressor_config = None
if compressor:
try:
compressor_config = compressor.get_config()
except AttributeError:
err_bad_compressor(compressor)
else:
compressor_config = None

# obtain filters config
if filters:
Expand Down
46 changes: 46 additions & 0 deletions zarr/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def test_array_1d_fill_value(self):
assert_array_equal(f[310:], z[310:])

def test_array_1d_set_scalar(self):
# test setting the contents of an array with a scalar value

# setup
a = np.zeros(100)
z = self.create_array(shape=a.shape, chunks=10, dtype=a.dtype)
Expand Down Expand Up @@ -587,6 +589,50 @@ def test_0len_dim_2d(self):
with assert_raises(IndexError):
z[:, 0] = 42

def test_array_0d(self):
# test behaviour for scalars, i.e., array with 0 dimensions

# setup
a = np.zeros(())
z = self.create_array(shape=(), dtype=a.dtype, fill_value=0)

# check properties
eq(a.ndim, z.ndim)
eq(a.shape, z.shape)
eq(a.size, z.size)
eq(a.dtype, z.dtype)
with assert_raises(TypeError):
len(z)
eq((), z.chunks)
eq(1, z.nchunks)
# compressor always None - no point in compressing a scalar value
assert_is_none(z.compressor)
Copy link
Member

@jakirkham jakirkham Oct 24, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth checking nbytes as well or is that too far afield?


# check __getitem__
b = z[...]
assert_is_instance(b, np.ndarray)
eq(a.shape, b.shape)
eq(a.dtype, b.dtype)
assert_array_equal(a, np.array(z))
assert_array_equal(a, z[...])
eq(a[()], z[()])
with assert_raises(IndexError):
z[0]
with assert_raises(IndexError):
z[:]

# check __setitem__
z[...] = 42
eq(42, z[()])
z[()] = 43
eq(43, z[()])
with assert_raises(IndexError):
z[0] = 42
with assert_raises(IndexError):
z[:] = 42
with assert_raises(ValueError):
z[...] = np.array([1, 2, 3])


class TestArrayWithPath(TestArray):

Expand Down