Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Style and test fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
zoeygxy committed Aug 27, 2019
1 parent ddffee0 commit 2052fb7
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 177 deletions.
4 changes: 0 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@ def _get_np_basic_indexing(self, key):
for ax in new_axes: # pylint: disable=invalid-name
final_shape.insert(ax, 1)

if final_shape == []:
# Override for single element indexing
return sliced.item()
if sliced.size == 0:
return sliced.reshape(tuple(final_shape))
else:
Expand Down Expand Up @@ -222,7 +219,6 @@ def __getitem__(self, key):
if ndim == 0:
if key != ():
raise IndexError('scalar tensor can only accept `()` as index')
return self.item()
# Handle simple cases for higher speed
if isinstance(key, tuple) and len(key) == 0:
return self
Expand Down
16 changes: 8 additions & 8 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape,
common::StaticArray<index_t, ndim>* end,
common::StaticArray<index_t, ndim>* step) {
// Function returns false if output is zero-sized, true otherwise.
bool size_non_zero = true;
bool zero_size_shape = false;
CHECK_NE(dshape.ndim(), 0U);
CHECK_LE(param_begin.ndim(), dshape.ndim())
<< "Slicing axis exceeds data dimensions";
Expand Down Expand Up @@ -726,7 +726,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape,
(*step)[i] = s;
// checking begin==end
if (b == e) {
size_non_zero = false;
zero_size_shape = true;
}
}

Expand All @@ -736,7 +736,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape,
(*step)[i] = 1;
}

return size_non_zero;
return zero_size_shape;
}

inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape,
Expand Down Expand Up @@ -981,7 +981,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& dshape = (*in_attrs)[0];
if (dshape.ndim() == 0U) return false;
if (!mxnet::ndim_is_known(dshape)) return false;
mxnet::TShape vshape = dshape; // vshape is the value shape on the right hand side
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
Expand Down Expand Up @@ -1024,9 +1024,9 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
bool non_zero_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step,
bool zero_size_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step,
&begin, &end, &step);
if (!non_zero_shape) {
if (zero_size_shape) {
return; // slice_assign of zero-sized subspace needs no operation.
}
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
Expand Down Expand Up @@ -1129,9 +1129,9 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
const SliceAssignScalarParam& param = nnvm::get<SliceAssignScalarParam>(attrs.parsed);
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
bool non_zero_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step,
bool zero_size_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step,
&begin, &end, &step);
if (!non_zero_shape) {
if (zero_size_shape) {
return; // slice_assign of zero-sized subspaced needs no operation.
}
for (index_t i = 0; i < param.begin.ndim(); ++i) {
Expand Down
Loading

0 comments on commit 2052fb7

Please sign in to comment.