Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
create new instance var to store copy
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Aug 7, 2018
1 parent 3d1af14 commit c0e639a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs);

namespace exec {

class MKLDNNOpExecutor : public OpExecutor {
protected:
std::vector<NDArray> in_array_fallback;
};

// abstract OpExecutor which provides storage fallback procedure on
// non-default inputs and outputs
// FComputeExecutor and FStatefulComputeExecutor inherit from this class
Expand Down Expand Up @@ -153,14 +158,14 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {


// stateful compute_ex executor
class StatefulComputeExExecutor : public OpExecutor {
class StatefulComputeExExecutor : public MKLDNNOpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
in_array = CreateDefaultInputs(in_array);
fcompute_(state_, op_ctx, in_array, req, out_array);
in_array_fallback = CreateDefaultInputs(in_array);
fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
return;
#endif
fcompute_(state_, op_ctx, in_array, req, out_array);
Expand Down Expand Up @@ -223,16 +228,16 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
};

// fcompute_ex executor
class FComputeExExecutor : public OpExecutor {
class FComputeExExecutor : public MKLDNNOpExecutor {
public:
void Run(RunContext rctx, bool is_gpu) override {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs_.op, false)) {
in_array = CreateDefaultInputs(in_array);
fcompute_(attrs_, op_ctx, in_array, req, out_array);
in_array_fallback = CreateDefaultInputs(in_array);
fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array);
return;
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<bool>("TIsMKLDNN", true).set_attr<bool>("TIsMKLDNN", true)
.set_attr<bool>("TIsMKLDNN", true)
#endif
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
Expand Down

0 comments on commit c0e639a

Please sign in to comment.