Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix engine crash in shutdown phase (#14382)
Browse files Browse the repository at this point in the history
* fix engine crash in shutdown phase

* fix lint

* Revert "Bypass ThreadedEngine in test_operator_gpu.py:test_convolution_multiple_streams. (#14338)"

This reverts commit d6eafca.
  • Loading branch information
arcadiaphy authored and nswamy committed Apr 5, 2019
1 parent 43b03ab commit d5bf85b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
9 changes: 9 additions & 0 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/omp.h>
#include <mxnet/storage.h>
#include <vector>
#include <functional>
#include <condition_variable>
Expand Down Expand Up @@ -306,6 +307,8 @@ class ThreadedEngine : public Engine {
objpool_varblk_ref_ = common::ObjectPool<VersionedVarBlock>::_GetSharedRef();
objpool_var_ref_ = common::ObjectPool<ThreadedVar>::_GetSharedRef();

storage_ref_ = Storage::_GetSharedRef();

// Get a ref to the profiler so that it doesn't get killed before us
profiler::Profiler::Get(&profiler_);
}
Expand Down Expand Up @@ -549,6 +552,12 @@ class ThreadedEngine : public Engine {
std::shared_ptr<common::ObjectPool<VersionedVarBlock> > objpool_varblk_ref_;
std::shared_ptr<common::ObjectPool<ThreadedVar> > objpool_var_ref_;

/*!
 * \brief Async destruction of some objects relies on storage,
 * so prevent storage from being destructed too early.
 */
std::shared_ptr<Storage> storage_ref_;

#if MXNET_USE_CUDA
/*! \brief Number of GPU devices available */
std::atomic<int> device_count_{-1};
Expand Down
12 changes: 1 addition & 11 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,18 +547,8 @@ def _conv_with_num_streams(seed):

@with_seed()
def test_convolution_multiple_streams():
engines = ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']

if os.getenv('MXNET_ENGINE_TYPE') is not None:
engines = [os.getenv('MXNET_ENGINE_TYPE'),]
print("Only running against '%s'" % engines[0], file=sys.stderr, end='')
# Remove this else clause when the ThreadedEngine can handle this test
else:
engines.remove('ThreadedEngine')
print("SKIP: 'ThreadedEngine', only running against %s" % engines, file=sys.stderr, end='')

for num_streams in [1, 2]:
for engine in engines:
for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']:
print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr)
run_in_spawned_process(_conv_with_num_streams,
{'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})
Expand Down

0 comments on commit d5bf85b

Please sign in to comment.