Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport #16902 to 1.6 #16919

Merged
merged 1 commit into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 52 additions & 29 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,26 +847,32 @@ def _basic_indexing_slice_is_contiguous(slc_key, shape):
"""Whether indexing with the given key results in a contiguous array.

The rule is: From right to left, if in an axis, a slice produces a
proper subset, no later axis can produce a proper subset or use
a step different from 1.
proper subset, the later slice must have <=1 elements.

The ``slc_key`` sequence must have the same length as ``shape`` and
only contain `slice` objects.
"""
assert len(slc_key) == len(shape)
subset = False
is_subset = False
total_sliced_elements = np.prod([_get_slice_len(slc, n)
for slc, n in zip(slc_key, shape)])
if total_sliced_elements in (0, 1):
return True
for idx, n in zip(reversed(slc_key), reversed(shape)):
start, stop, step = idx.indices(n)
if step > 0:
num = int(np.ceil(max(stop - start, 0) / step))
else:
num = int(np.ceil(min(stop - start, 0) / step))

if num != 1 and (subset or step != 1):
_, _, step = idx.indices(n)
num_elements = _get_slice_len(idx, n)
if num_elements == 0:
return True
elif num_elements > 1 and (step > 1 or step < 0):
# We do not support the case of reverse slicing of multiple elements and
# forward slicing of #elements > 1 and step > 1
return False
if num != n:
subset = True

elif is_subset:
if num_elements > 1:
return False
else:
if num_elements < n:
is_subset = True
return True
# pylint: enable=invalid-name

Expand All @@ -875,30 +881,27 @@ def _basic_indexing_sliced_shape(slc_key, shape):
"""Return the shape after slicing with the given key."""
assert len(slc_key) == len(shape)
sliced_shape = []
for idx, n in zip(slc_key, shape):
start, stop, step = idx.indices(n)
if step > 0:
num = int(np.ceil(max(stop - start, 0) / step))
else:
num = int(np.ceil(min(stop - start, 0) / step))
sliced_shape.append(num)

for slc, n in zip(slc_key, shape):
num_elements = _get_slice_len(slc, n)
sliced_shape.append(num_elements)
return tuple(sliced_shape)

# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
"""Return the flat indices of begin and end for contiguous slicing."""
assert len(slc_key) == len(shape)
begin, end, _ = slc_key[0].indices(shape[0])
flat_begin, flat_end = begin, end - 1
for idx, n in zip(slc_key[1:], shape[1:]):
flat_begin, flat_end = 0, 0
for slc, n in zip(slc_key, shape):
flat_begin *= n
flat_end *= n
begin, end, _ = idx.indices(n)
flat_begin += begin
flat_end += end - 1

begin, _, _ = slc.indices(n)
num_elements = _get_slice_len(slc, n)
if num_elements == 0:
return 0, 0
else:
flat_begin += begin
flat_end += begin + num_elements - 1
return flat_begin, flat_end + 1
# pylint: enable=invalid-name

Expand Down Expand Up @@ -1062,7 +1065,7 @@ def _get_nd_basic_indexing(self, key):
for ax in new_axes: # pylint: disable=invalid-name
final_shape.insert(ax, 1)

if final_shape == []:
if len(final_shape) == 0:
# Override for single element indexing
final_shape = [1]
return sliced.reshape(final_shape)
Expand Down Expand Up @@ -3108,6 +3111,26 @@ def _get_dim_size(start, stop, step):
return dim_size


def _get_slice_len(slc, seq_length):
"""Given a python slice object and the length of the sequence, calculate the number of elements
in the slice.

Parameters
----------
slc : py_slice
The slice object
seq_length : int
The length of the object you are going to apply the slice on

Returns
-------
ret : int
Total number of elements in the slice
"""
start, stop, step = slc.indices(seq_length)
return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)


def _get_broadcast_shape(shape1, shape2):
"""Given two shapes that are not identical, find the shape
that both input shapes can broadcast to."""
Expand Down
5 changes: 2 additions & 3 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
NDArray ret = this->Slice(begin, end);
if (!Imperative::Get()->is_recording()) return ret;
// fake a slice_axis op
// fake a slice op
nnvm::NodeAttrs attrs;
attrs.op = nnvm::Op::Get("slice_axis");
attrs.dict.insert({"axis", "0"});
attrs.op = nnvm::Op::Get("slice");
attrs.dict.insert({"begin", std::to_string(begin)});
attrs.dict.insert({"end", std::to_string(end)});
attrs.op->attr_parser(&attrs);
Expand Down
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) {

static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
int ndim = shape.ndim();
if (ndim == 0 || shape.Size() == 0) {
// MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
return false;
}
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}

Expand Down
20 changes: 13 additions & 7 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,18 @@ def test_getitem(np_array, index):
)
np_indexed_array = np_array[np_index]
mx_np_array = np.array(np_array, dtype=np_array.dtype)
try:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
for autograd in [True, False]:
try:
if autograd:
with mx.autograd.record():
mx_indexed_array = mx_np_array[index]
else:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_setitem(np_array, index):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
Expand Down Expand Up @@ -768,6 +773,7 @@ def test_setitem_autograd(np_array, index):
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
slice(1, 2, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
Expand Down