diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 6014cc0cba4f..57983dcc49af 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -353,8 +353,12 @@ static inline void InvalidateOutputs(const std::vector &arrs, static inline std::vector CreateDefaultInputs(const std::vector &arrs) { std::vector buffer(arrs.size()); - for (size_t i = 0; i < arrs.size(); ++i) - buffer[i] = arrs[i].Reorder2Default(); + for (size_t i = 0; i < arrs.size(); ++i) { + if (arrs[i].IsMKLDNNData()) { + buffer[i] = arrs[i].Reorder2Default(); + } + buffer[i] = arrs[i]; + } return buffer; }