Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Oct 22, 2019
1 parent b30f423 commit 380bf03
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
14 changes: 7 additions & 7 deletions src/operator/numpy/np_memory_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
* \file np_memory_op.cu
*/

#include "./np_memory_op.h"
#include "./np_memory_op.h"

namespace mxnet {
namespace op {
namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_share_memory)
.set_attr<FCompute>("FCompute<gpu>", NumpyShareMemoryCompute<gpu>);
NNVM_REGISTER_OP(_npi_share_memory)
.set_attr<FCompute>("FCompute<gpu>", NumpyShareMemoryCompute<gpu>);

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet
10 changes: 6 additions & 4 deletions src/operator/numpy/np_memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include "../operator_common.h"
Expand All @@ -43,12 +44,13 @@ void NumpyShareMemoryCompute(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& a = inputs[0];
const TBlob& b = inputs[1];
const TBlob& outdata = outputs[0];
Tensor<xpu, 1, bool> outdata = outputs[0].FlatTo1D<xpu, bool>(s);

if (a.Size() == 0 || b.Size() == 0) {
*(outdata.dptr<bool>()) = false;
ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, false);
return;
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(a.type_flag_, AType, {
Expand All @@ -58,9 +60,9 @@ void NumpyShareMemoryCompute(const nnvm::NodeAttrs& attrs,
uint64_t start2 = reinterpret_cast<uint64_t>(b.dptr_);
uint64_t end2 = start2 + b.Size() * sizeof(BType);
if (!(start1 < end2 && start2 < end1 && start1 < end1 && start2 < end2)) {
*(outdata.dptr<bool>()) = false;
ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, false);
} else {
*(outdata.dptr<bool>()) = true;
ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, true);
}
});
});
Expand Down

0 comments on commit 380bf03

Please sign in to comment.