diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index d00cb479b92e..08ef7e02bf9f 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -850,15 +850,20 @@ class NDArray { mxnet::ShapeVector aux_shapes; /*! \brief Reference to the storage to ensure proper destruct order */ std::shared_ptr storage_ref_; + /*! \brief Reference to the engine to ensure we cleanup without calling a destructed engine */ + std::weak_ptr engine_ref_; - /*! \brief default cosntructor */ + + /*! \brief default constructor */ Chunk() : static_data(true), delay_alloc(false), - storage_ref_(Storage::_GetSharedRef()) {} + storage_ref_(Storage::_GetSharedRef()), + engine_ref_(Engine::_GetSharedRef()) {} /*! \brief construct a new chunk */ Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype) : static_data(false), delay_alloc(true), ctx(ctx_), - storage_ref_(Storage::_GetSharedRef()) { + storage_ref_(Storage::_GetSharedRef()), + engine_ref_(Engine::_GetSharedRef()) { auto size = shape.Size(); storage_shape = shape; var = Engine::Get()->NewVariable(); @@ -869,7 +874,8 @@ class NDArray { Chunk(const TBlob &data, int dev_id) : static_data(true), delay_alloc(false), - storage_ref_(Storage::_GetSharedRef()) { + storage_ref_(Storage::_GetSharedRef()), + engine_ref_(Engine::_GetSharedRef()) { CHECK(storage_type == kDefaultStorage); var = Engine::Get()->NewVariable(); if (data.dev_mask() == cpu::kDevMask) { @@ -887,7 +893,8 @@ class NDArray { Chunk(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype) : static_data(false), delay_alloc(false), - storage_ref_(Storage::_GetSharedRef()) { + storage_ref_(Storage::_GetSharedRef()), + engine_ref_(Engine::_GetSharedRef()) { var = Engine::Get()->NewVariable(); ctx = Context::CPUShared(0); shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype); @@ -903,7 +910,8 @@ class NDArray { const mxnet::ShapeVector &aux_shapes_) : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_), aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_), - aux_shapes(aux_shapes_), storage_ref_(Storage::_GetSharedRef()) { + aux_shapes(aux_shapes_), storage_ref_(Storage::_GetSharedRef()), + engine_ref_(Engine::_GetSharedRef()) { shandle.ctx = ctx; var = Engine::Get()->NewVariable(); // aux_handles always reflect the correct number of aux data @@ -921,7 +929,7 @@ class NDArray { Chunk(const NDArrayStorageType storage_type_, const TBlob &data, const std::vector &aux_data, int dev_id) : static_data(true), delay_alloc(false), storage_type(storage_type_), - storage_ref_(Storage::_GetSharedRef()) { + storage_ref_(Storage::_GetSharedRef()), engine_ref_(Engine::_GetSharedRef()) { using namespace mshadow; CHECK_NE(storage_type, kDefaultStorage); // init var diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 377bef072b03..0428434b790e 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -113,20 +113,22 @@ NDArray::Chunk::~Chunk() { // We want to delete mkldnn memory after deleting the variable. mem.mem = this->mkl_mem_; #endif - Engine::Get()->DeleteVariable([mem, skip_free](RunContext s) { - if (skip_free == false) { + if (auto engine = engine_ref_.lock()) { + engine->DeleteVariable([mem, skip_free](RunContext s) { + if (skip_free == false) { #if MXNET_USE_MKLDNN == 1 - if (mem.mem) { - CHECK_LE(mem.mem->GetSize(), mem.h.size); - CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr); - } + if (mem.mem) { + CHECK_LE(mem.mem->GetSize(), mem.h.size); + CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr); + } #endif - Storage::Get()->Free(mem.h); - for (const auto& aux : mem.aux_h) { - Storage::Get()->Free(aux); + Storage::Get()->Free(mem.h); + for (const auto &aux : mem.aux_h) { + Storage::Get()->Free(aux); + } } - } - }, shandle.ctx, var); + }, shandle.ctx, var); + } } void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {