Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ensure proper destruct order of NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed Feb 26, 2019
1 parent 43428c1 commit 9be8574
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,14 +848,17 @@ class NDArray {
// The shape of aux data. The default value for the shape depends on the type of storage.
// If aux_shapes[i].Size() is zero, aux data i is empty.
std::vector<TShape> aux_shapes;
/*! \brief Reference to the storage to ensure proper destruct order */
std::shared_ptr<Storage> storage_ref_;

/*! \brief default cosntructor */
Chunk() : static_data(true), delay_alloc(false) {}
Chunk() : static_data(true), delay_alloc(false),
storage_ref_(Storage::_GetSharedRef()) {}

/*! \brief construct a new chunk */
Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype)
: static_data(false), delay_alloc(true), ctx(ctx_) {
auto size = shape.Size();
: static_data(false), delay_alloc(true), ctx(ctx_),
storage_ref_(Storage::_GetSharedRef()) { auto size = shape.Size();
storage_shape = shape;
var = Engine::Get()->NewVariable();
shandle.size = size * mshadow::mshadow_sizeof(dtype);
Expand All @@ -864,7 +867,8 @@ class NDArray {
}

Chunk(const TBlob &data, int dev_id)
: static_data(true), delay_alloc(false) {
: static_data(true), delay_alloc(false),
storage_ref_(Storage::_GetSharedRef()) {
CHECK(storage_type == kDefaultStorage);
var = Engine::Get()->NewVariable();
if (data.dev_mask() == cpu::kDevMask) {
Expand All @@ -881,7 +885,8 @@ class NDArray {
}

Chunk(int shared_pid, int shared_id, const TShape& shape, int dtype)
: static_data(false), delay_alloc(false) {
: static_data(false), delay_alloc(false),
storage_ref_(Storage::_GetSharedRef()) {
var = Engine::Get()->NewVariable();
ctx = Context::CPUShared(0);
shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
Expand All @@ -897,7 +902,7 @@ class NDArray {
const std::vector<TShape> &aux_shapes_)
: static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
aux_shapes(aux_shapes_) {
aux_shapes(aux_shapes_), storage_ref_(Storage::_GetSharedRef()) {
shandle.ctx = ctx;
var = Engine::Get()->NewVariable();
// aux_handles always reflect the correct number of aux data
Expand All @@ -914,7 +919,8 @@ class NDArray {

Chunk(const NDArrayStorageType storage_type_, const TBlob &data,
const std::vector<TBlob> &aux_data, int dev_id)
: static_data(true), delay_alloc(false), storage_type(storage_type_) {
: static_data(true), delay_alloc(false), storage_type(storage_type_),
storage_ref_(Storage::_GetSharedRef()) {
using namespace mshadow;
CHECK_NE(storage_type, kDefaultStorage);
// init var
Expand Down

0 comments on commit 9be8574

Please sign in to comment.