
Fix a bug to pass the test in test_contrib_rnn #14520

Merged (3 commits, Mar 26, 2019)
1 change: 0 additions & 1 deletion src/common/utils.h
@@ -752,7 +752,6 @@ inline void ConvertToNumpyShape(mxnet::TShape* shape) {
     *shape = mxnet::TShape();  // unknown shape ndim = -1
   } else {
     for (int j = 0; j < shape->ndim(); ++j) {
-      CHECK_GE((*shape)[j], 0) << "Legacy shape cannot have dim size < 0";
       if ((*shape)[j] == 0) {  // legacy shape dim_size = 0 means unknown
         (*shape)[j] = -1;  // unknown dim size = -1
       }
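The deleted CHECK_GE rejected any negative dim size, but under the NumPy-compatible convention -1 itself means "unknown", so a shape that already carries -1 dims (e.g. one converted earlier) would trip the check on a second pass; dropping it presumably makes the conversion safe to apply to such shapes. A minimal sketch of the rule the loop now enforces, with a plain std::vector<int> standing in for mxnet::TShape (not the MXNet implementation):

#include <vector>

// Legacy convention: dim size 0 means "unknown".
// NumPy convention:  dim size -1 means "unknown".
void ConvertToNumpyShapeSketch(std::vector<int>* shape) {
  for (int& d : *shape) {
    if (d == 0) d = -1;  // legacy unknown -> NumPy unknown
    // negative sizes (already-converted unknowns) now pass through untouched
  }
}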
6 changes: 4 additions & 2 deletions src/ndarray/ndarray_function.cc
@@ -210,8 +210,6 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
     Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
     for (size_t i = 0; i < nds.size(); ++i) {
       const NDArray& nd = nds[i];
-      const nnvm::dim_t num_rows = nd.shape()[0];
-      const nnvm::dim_t num_cols = nd.shape()[1];
       const TBlob& nd_data = nd.data();
 
       if (i == 0) {
@@ -234,6 +232,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
       case kCSRStorage: {
         const TBlob& nd_indices = nd.aux_data(csr::kIdx);
         const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
+        const nnvm::dim_t num_rows = nd.shape()[0];
+        const nnvm::dim_t num_cols = nd.shape()[1];
         MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
           MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, {  // indptr type
             if (nd.storage_initialized()) {
@@ -248,6 +248,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
       }
       case kRowSparseStorage: {
         const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
+        const nnvm::dim_t num_rows = nd.shape()[0];
+        const nnvm::dim_t num_cols = nd.shape()[1];
         MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
           if (nd.storage_initialized()) {
             const nnvm::dim_t nz_rows = nd_indices.Size();
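Moving num_rows/num_cols into the kCSRStorage and kRowSparseStorage branches means nd.shape()[1] is read only for inputs those branches guarantee to be 2-D; a dense 1-D input reaching the sum (as the failing RNN test presumably produced) would otherwise index past the end of its shape. A self-contained sketch of the hazard, with a hypothetical ShapeSketch standing in for mxnet::TShape:

#include <cassert>
#include <vector>

// Hypothetical stand-in for mxnet::TShape: operator[] asserts the dim exists.
struct ShapeSketch {
  std::vector<long> dims;
  int ndim() const { return static_cast<int>(dims.size()); }
  long operator[](int i) const {
    assert(i >= 0 && i < ndim() && "dim index out of range");
    return dims[i];
  }
};

int main() {
  ShapeSketch dense1d{{5}};   // a 1-D dense input to the elementwise sum
  long rows = dense1d[0];     // fine: every input has dim 0
  // long cols = dense1d[1];  // would assert: only the sparse (2-D) branches
  //                          // may read dim 1, as in the patched code
  (void)rows;
  return 0;
}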
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -462,7 +462,7 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p
                                             mkldnn_memory_format_t format);
 
 inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) {
-  if (shape.ndim() != (size_t)ndims)
+  if (shape.ndim() != ndims)
     return false;
   for (int i = 0; i < ndims; i++)
     if (shape[i] != dims[i])
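The dropped (size_t) cast matters once TShape::ndim() returns a signed int, with -1 denoting an unknown shape per the NumPy convention above: in a mixed signed/unsigned comparison the signed side is converted to size_t, so -1 silently wraps to SIZE_MAX. A short illustration of the pitfall the one-character fix avoids (an assumption about the rationale, not taken from the PR discussion):

#include <cstddef>
#include <iostream>

int main() {
  int ndim = -1;   // "unknown shape" under the signed NumPy convention
  int ndims = 2;   // rank reported by MKL-DNN

  // Mixed comparison converts ndim to size_t first: -1 wraps to SIZE_MAX.
  std::cout << static_cast<std::size_t>(ndim) << '\n';      // huge value

  // All-signed comparison, as in the patched line: -1 keeps its meaning.
  std::cout << std::boolalpha << (ndim != ndims) << '\n';   // true
  return 0;
}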