diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 034b497de264..7c2d20789365 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -181,8 +181,8 @@ def RandomInit(narray): block = list(zip(grad_narrays, arg_narrays, momentum_narrays)) np.random.seed(0) -# set random weight -x = mx.random.uniform(-1, 1, (10,),mx.gpu()) + + for name, narray in inputs.items(): if "weight" in name: narray[:] = np.random.uniform(-0.1, 0.1, narray.shape) diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 7a28f6cf2c0d..72a4456f592a 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -199,13 +199,6 @@ class Engine { ret.param_ = param; return ret; } - - /*! - * \brief Idempotent Finalize function. - * This function will signal engine to release all resources. - * It is safe to call this function multiple times. - */ - virtual void Finalize() = 0; }; // class Engine #endif // DMLC_USE_CXX11 } // namespace mxnet diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h index 7f8b07df245d..53a6dd6698e6 100644 --- a/include/mxnet/resource.h +++ b/include/mxnet/resource.h @@ -115,14 +115,6 @@ class ResourceManager { * \return Resource manager singleton. */ static ResourceManager *Get(); - - protected: - /*! - * \brief Idempotent Finalize function. - * This function will signal resource manager to release all resources. - * It is safe to call this function multiple times. - */ - virtual void Finalize() = 0; }; } // namespace mxnet #endif // MXNET_RESOURCE_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 41e67faf14ca..71d303ff01f3 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -63,14 +63,6 @@ class Storage { */ static std::shared_ptr _GetSharedRef(); - protected: - /*! - * \brief Idempotent Finalize function. - * This function will signal engine to release all resources. - * It is safe to call this function multiple times. - */ - virtual void Finalize(); - private: /*! * \brief Hidden constructors. diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 609021f3dfea..9351e8a05163 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -15,13 +15,9 @@ class NaiveEngine final : public Engine { } // virtual destructor virtual ~NaiveEngine() { - Finalize(); - } - - void Finalize() override { #if MXNET_USE_CUDA for (size_t i = 0; i < streams_.size(); ++i) { - if (streams_[i] != nullptr) { + if (streams_[i] != nullptr) { // Catch exception for CUDA driver shutdown try { mshadow::DeleteStream(streams_[i]); diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 96c180fbc4d8..c392e7d6ce3c 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "../common/cuda_utils.h" @@ -117,7 +118,7 @@ void StreamManager::Finalize() { #if MXNET_USE_CUDA for (std::size_t i = 0; i < kNumGpus; ++i) { if (gpu_cnt_.at(i) != -1) { - for (auto&& j : gpu_streams_.at(i)) { + for (auto&& j : gpu_streams_.at(i)) { // Catch exception for CUDA driver shutdown try { mshadow::DeleteStream(j); diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 74314479cbe2..1a3144e783ec 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -291,12 +291,6 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0; }); } -void ThreadedEngine::Finalize() { - // unlock all threads - pending_.store(0); - finished_cv_.notify_all(); -} - inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { // Mark complete for read variables for (auto&& i : threaded_opr->const_vars) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 33914bbeafcc..fa29939d291f 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -249,7 +249,6 @@ class ThreadedEngine : public Engine { threaded_opr->fn(run_ctx, callback); OprBlock::Delete(opr_block); } - void Finalize() override; private: /*! diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index ff163a4be7ce..35d333b3f85b 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -37,7 +37,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine { // GPU tasks will be created lazily } ~ThreadedEnginePerDevice() noexcept(false) { - Finalize(); + gpu_normal_workers_.Clear(); + gpu_copy_workers_.Clear(); + cpu_worker_.reset(nullptr); } protected: @@ -63,13 +65,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } } } - // finalize the internal resources - void Finalize() override { - gpu_normal_workers_.Clear(); - gpu_copy_workers_.Clear(); - cpu_worker_.reset(nullptr); - ThreadedEngine::Finalize(); - } private: // working unit for each of the task. diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 49b120869b73..3a32623776b5 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -28,7 +28,12 @@ class ThreadedEnginePooled : public ThreadedEngine { io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} ~ThreadedEnginePooled() noexcept(false) { - Finalize(); + // wait until all the tasks are completed. + // TODO(hotpxl) think if this is the correct thing to do + this->WaitForAll(); + streams_.Finalize(); + task_queue_.SignalForKill(); + io_task_queue_.SignalForKill(); } protected: @@ -39,17 +44,6 @@ class ThreadedEnginePooled : public ThreadedEngine { DoPushToQueue(opr_block); } } - // finalize the internal resources - void Finalize() override { - // wait until all the tasks are completed. - // TODO(hotpxl) think if this is the correct thing to do - this->WaitForAll(); - streams_.Finalize(); - task_queue_.SignalForKill(); - io_task_queue_.SignalForKill(); - ThreadedEngine::Finalize(); - } - private: /*! \brief Concurrency for thread pool */ diff --git a/src/resource.cc b/src/resource.cc index 953b5383e54a..a2d51c5516a0 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -29,7 +29,18 @@ class ResourceManagerImpl : public ResourceManager { Context(cpu::kDevMask, 0), cpu_temp_space_copy_)); } ~ResourceManagerImpl() { - Finalize(); + // need explicit delete, before engine get killed + cpu_rand_.reset(nullptr); + cpu_space_.reset(nullptr); +#if MXNET_USE_CUDA + gpu_rand_.Clear(); + gpu_space_.Clear(); +#endif + if (engine_ref_ != nullptr) { + engine_ref_->WaitForAll(); + // release the reference to engine. + engine_ref_ = nullptr; + } } // request resources @@ -74,22 +85,6 @@ class ResourceManagerImpl : public ResourceManager { #endif } - protected: - void Finalize() override { - // need explicit delete, before engine get killed - cpu_rand_.reset(nullptr); - cpu_space_.reset(nullptr); -#if MXNET_USE_CUDA - gpu_rand_.Clear(); - gpu_space_.Clear(); -#endif - if (engine_ref_ != nullptr) { - engine_ref_->WaitForAll(); - // release the reference to engine. - engine_ref_ = nullptr; - } - } - private: /*! \brief Maximum number of GPUs */ static constexpr std::size_t kMaxNumGPUs = 16; @@ -115,7 +110,7 @@ class ResourceManagerImpl : public ResourceManager { } ~ResourceRandom() { mshadow::Random *r = prnd; - Engine::Get()->DeleteVariable( + Engine::Get()->DeleteVariable( [r](RunContext rctx){ delete r; }, ctx, resource.var); } // set seed to a PRNG diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 769cb35f0866..e41f4e701f80 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -52,10 +52,6 @@ struct Storage::Impl { Storage::Handle Storage::Alloc(size_t size, Context ctx) { // space already recycled, ignore request - if (impl_.get() == nullptr) { - LOG(FATAL) << "Alloc called after finalize"; - } - Handle hd; hd.ctx = ctx; hd.size = size; @@ -92,9 +88,6 @@ Storage::Handle Storage::Alloc(size_t size, Context ctx) { } void Storage::Free(Storage::Handle handle) { - // space already recycled, ignore request - if (impl_.get() == nullptr) return; - std::lock_guard lock{impl_->m}; Impl::ActivateDevice(handle.ctx); impl_->storage_managers.at(handle.ctx.dev_mask) @@ -116,8 +109,5 @@ Storage* Storage::Get() { Storage::Storage() : impl_{new Impl{}} {} -void Storage::Finalize() { - impl_.reset(nullptr); -} } // namespace mxnet