Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix sanity problem
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiZpvoh committed May 17, 2020
1 parent c004b29 commit 695d995
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 143 deletions.
1 change: 0 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def __getitem__(self, key): # pylint: disable = too-many-return-statements, inco
' are supported! Received key={}'.format(key))
if is_symbol_tuple:
return result

new_shape += (-4,)
sliced = _npi.slice(self, begin, end, step)
return _npi.reshape(sliced, new_shape)
Expand Down
85 changes: 46 additions & 39 deletions src/operator/numpy/np_indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ void AdvancedIndexingOpForward<cpu>(const nnvm::NodeAttrs& attrs,
stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
prefix_sum.data(), col_size);
});
} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) {
using namespace mshadow;
const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape();
Expand Down Expand Up @@ -237,15 +237,16 @@ void AdvancedIndexingOpForward<cpu>(const nnvm::NodeAttrs& attrs,
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
Kernel<AdvancedIndexingTakeCPU, cpu>::Launch(s, idxshape.Size(),
outputs[np_indexing_::kOut].data().dptr<DType>(),
inputs[np_indexing_::kArr].data().dptr<DType>(),
inputs[np_indexing_::kIdx].data().dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[0]);
outputs[np_indexing_::kOut].data().dptr<DType>(),
inputs[np_indexing_::kArr].data().dptr<DType>(),
inputs[np_indexing_::kIdx].data().dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[0]);
});
});
} else {
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
<< "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
}
}

Expand All @@ -261,10 +262,11 @@ void AdvancedIndexingMultipleOpForward<cpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(outputs.size(), 1U);

if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) {
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Multi-dimension boolean indexing is not supported.";
} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
<< "Multi-dimension boolean indexing is not supported.";
} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) {
using namespace mshadow;
const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape();
Expand All @@ -274,7 +276,7 @@ void AdvancedIndexingMultipleOpForward<cpu>(const nnvm::NodeAttrs& attrs,
return;
}

CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array
CHECK_EQ(arrshape[0], idxshape[0]); // size of index must equal to size of array

mxnet::TShape oshape(arrshape.ndim() - 1, -1);
oshape[0] = arrshape[0];
Expand All @@ -297,15 +299,16 @@ void AdvancedIndexingMultipleOpForward<cpu>(const nnvm::NodeAttrs& attrs,
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
Kernel<AdvancedIndexingTakeMultiDimensionCPU, cpu>::Launch(s, idxshape.Size(),
outputs[np_indexing_::kOut].data().dptr<DType>(),
inputs[np_indexing_::kArr].data().dptr<DType>(),
inputs[np_indexing_::kIdx].data().dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[1]);
outputs[np_indexing_::kOut].data().dptr<DType>(),
inputs[np_indexing_::kArr].data().dptr<DType>(),
inputs[np_indexing_::kIdx].data().dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[1]);
});
});
} else {
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
<< "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
}
}

Expand Down Expand Up @@ -348,9 +351,9 @@ void AdvancedIndexingOpBackward<cpu>(const nnvm::NodeAttrs& attrs,
}
});
});
} else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 ||
} else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) {
using namespace mshadow;
using namespace mshadow::expr;
Expand Down Expand Up @@ -396,12 +399,12 @@ void AdvancedIndexingOpBackward<cpu>(const nnvm::NodeAttrs& attrs,
} else {
LOG(FATAL) << "wrong req";
}

});
});
} else {
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
<< "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
}
}

Expand Down Expand Up @@ -444,9 +447,9 @@ void AdvancedIndexingMultipleOpBackward<cpu>(const nnvm::NodeAttrs& attrs,
}
});
});
} else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 ||
} else if (inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt8 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt16 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt32 ||
inputs[np_indexing_::kIdx+1].dtype() == mshadow::kInt64) {
using namespace mxnet_op;
using namespace mshadow;
Expand All @@ -463,20 +466,23 @@ void AdvancedIndexingMultipleOpBackward<cpu>(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(inputs[2].dtype(), IType, { // index type
if (req[0] != kAddTo) outputs[0].data().FlatTo1D<cpu, DType>(s) = 0;
if (trailing == 1) {
Kernel<pick_grad<2, true>, cpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
inputs[0].data().dptr<DType>(), inputs[2].data().dptr<IType>(),
M, 1, Shape2(leading, M), Shape2(leading, 1));
Kernel<pick_grad<2, true>, cpu>::Launch(s, inputs[0].data().Size(),
outputs[0].data().dptr<DType>(), inputs[0].data().dptr<DType>(),
inputs[2].data().dptr<IType>(), M,
1, Shape2(leading, M), Shape2(leading, 1));
} else {
Kernel<pick_grad<3, true>, cpu>::Launch(s, inputs[0].data().Size(), outputs[0].data().dptr<DType>(),
inputs[0].data().dptr<DType>(), inputs[2].data().dptr<IType>(),
M, trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
Kernel<pick_grad<3, true>, cpu>::Launch(s, inputs[0].data().Size(),
outputs[0].data().dptr<DType>(), inputs[0].data().dptr<DType>(),
inputs[2].data().dptr<IType>(), M,
trailing, Shape3(leading, M, trailing),
Shape3(leading, 1, trailing));
}
});
});
} else {
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "arrays used as indices must be explictly declared as integer (or boolean) type. "
<< "Use np.astype() to cast indices to integer or boolean.";
dmlc::LogMessageFatal(__FILE__, __LINE__).stream()
<< "arrays used as indices must be explictly declared as integer (or boolean) type."
<< "Use np.astype() to cast indices to integer or boolean.";
}
}

Expand Down Expand Up @@ -545,7 +551,8 @@ which stands for the rows in x where the corresonding element in index is non-ze
})
.set_attr<nnvm::FInferType>("FInferType", AdvancedIndexingOpType)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdvancedIndexingMultipleOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_np_advanced_indexing_multiple"})
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_np_advanced_indexing_multiple"})
.set_attr<FInferStorageType>("FInferStorageType", AdvancedIndexingOpStorageType)
.add_argument("data", "NDArray-or-Symbol", "Data")
.add_argument("indices", "NDArray-or-Symbol", "Indices");
Expand Down
Loading

0 comments on commit 695d995

Please sign in to comment.