Skip to content

Commit

Permalink
Fix a bug to pass the test in test_contrib_rnn (apache#14520)
Browse files Browse the repository at this point in the history
* fix.

* remove type conversion.

* remove type cast.
  • Loading branch information
zheng-da authored and reminisce committed Apr 10, 2019
1 parent 59c46f5 commit 12311ab
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 0 additions & 1 deletion src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,6 @@ inline void ConvertToNumpyShape(mxnet::TShape* shape) {
*shape = mxnet::TShape(); // unknown shape ndim = -1
} else {
for (int j = 0; j < shape->ndim(); ++j) {
CHECK_GE((*shape)[j], 0) << "Legacy shape cannot have dim size < 0";
if ((*shape)[j] == 0) { // legacy shape dim_size = 0 means unknown
(*shape)[j] = -1; // unknown dim size = -1
}
Expand Down
6 changes: 4 additions & 2 deletions src/ndarray/ndarray_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
for (size_t i = 0; i < nds.size(); ++i) {
const NDArray& nd = nds[i];
const nnvm::dim_t num_rows = nd.shape()[0];
const nnvm::dim_t num_cols = nd.shape()[1];
const TBlob& nd_data = nd.data();

if (i == 0) {
Expand All @@ -234,6 +232,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
case kCSRStorage: {
const TBlob& nd_indices = nd.aux_data(csr::kIdx);
const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
const nnvm::dim_t num_rows = nd.shape()[0];
const nnvm::dim_t num_cols = nd.shape()[1];
MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type
MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, { // indptr type
if (nd.storage_initialized()) {
Expand All @@ -248,6 +248,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
}
case kRowSparseStorage: {
const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
const nnvm::dim_t num_rows = nd.shape()[0];
const nnvm::dim_t num_cols = nd.shape()[1];
MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type
if (nd.storage_initialized()) {
const nnvm::dim_t nz_rows = nd_indices.Size();
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p
mkldnn_memory_format_t format);

inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) {
if (shape.ndim() != (size_t)ndims)
if (shape.ndim() != ndims)
return false;
for (int i = 0; i < ndims; i++)
if (shape[i] != dims[i])
Expand Down

0 comments on commit 12311ab

Please sign in to comment.