From d87559278991ffbe9bd8b0113fb330d98618b776 Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Fri, 24 Aug 2018 11:24:00 -0700 Subject: [PATCH] create fallback arrays in place --- src/executor/attach_op_execs_pass.cc | 4 ++-- src/operator/nn/mkldnn/mkldnn_base-inl.h | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 1beeab9f0070..c619961a2f24 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -159,7 +159,7 @@ class StatefulComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(out_array, req); - in_array_fallback = CreateDefaultInputs(in_array); + CreateDefaultInputs(in_array, in_array_fallback); fcompute_(state_, op_ctx, in_array_fallback, req, out_array); return; #endif @@ -232,7 +232,7 @@ class FComputeExExecutor : public OpExecutor { // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); if (!is_mkldnn.get(attrs_.op, false)) { - in_array_fallback = CreateDefaultInputs(in_array); + CreateDefaultInputs(in_array, in_array_fallback); fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array); return; } diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 7a8e09b6b1c8..64a67237f6f9 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -357,15 +357,14 @@ static inline void InvalidateOutputs(const std::vector &arrs, } // TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added -static inline std::vector CreateDefaultInputs(const std::vector &arrs) { - std::vector buffer(arrs.size()); +static inline void CreateDefaultInputs(const std::vector &arrs, + const std::vector &out_arrs) { for (size_t i = 0; i < arrs.size(); ++i) { if (arrs[i].IsMKLDNNData()) - buffer[i] = arrs[i].Reorder2Default(); + out_arrs[i] = arrs[i].Reorder2Default(); else - buffer[i] = arrs[i]; + out_arrs[i] = arrs[i]; } - return buffer; } const mkldnn::memory *GetWeights(const NDArray &arr,