Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix lint, move finalize function into destructor
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 19, 2015
1 parent 1a9d167 commit 84df51a
Show file tree
Hide file tree
Showing 12 changed files with 27 additions and 86 deletions.
4 changes: 2 additions & 2 deletions example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ def RandomInit(narray):
block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

np.random.seed(0)
# set random weight
x = mx.random.uniform(-1, 1, (10,),mx.gpu())


for name, narray in inputs.items():
if "weight" in name:
narray[:] = np.random.uniform(-0.1, 0.1, narray.shape)
Expand Down
7 changes: 0 additions & 7 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,6 @@ class Engine {
ret.param_ = param;
return ret;
}

/*!
* \brief Idempotent Finalize function.
* This function will signal engine to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
}; // class Engine
#endif // DMLC_USE_CXX11
} // namespace mxnet
Expand Down
8 changes: 0 additions & 8 deletions include/mxnet/resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,6 @@ class ResourceManager {
* \return Resource manager singleton.
*/
static ResourceManager *Get();

protected:
/*!
* \brief Idempotent Finalize function.
* This function will signal resource manager to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_
8 changes: 0 additions & 8 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,6 @@ class Storage {
*/
static std::shared_ptr<Storage> _GetSharedRef();

protected:
/*!
* \brief Idempotent Finalize function.
* This function will signal engine to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize();

private:
/*!
* \brief Hidden constructors.
Expand Down
6 changes: 1 addition & 5 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ class NaiveEngine final : public Engine {
}
// virtual destructor
virtual ~NaiveEngine() {
Finalize();
}

void Finalize() override {
#if MXNET_USE_CUDA
for (size_t i = 0; i < streams_.size(); ++i) {
if (streams_[i] != nullptr) {
if (streams_[i] != nullptr) {
// Catch exception for CUDA driver shutdown
try {
mshadow::DeleteStream(streams_[i]);
Expand Down
3 changes: 2 additions & 1 deletion src/engine/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mxnet/base.h>
#include <cstddef>
#include <array>
#include <string>
#include <mutex>
#include "../common/cuda_utils.h"

Expand Down Expand Up @@ -117,7 +118,7 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
#if MXNET_USE_CUDA
for (std::size_t i = 0; i < kNumGpus; ++i) {
if (gpu_cnt_.at(i) != -1) {
for (auto&& j : gpu_streams_.at(i)) {
for (auto&& j : gpu_streams_.at(i)) {
// Catch exception for CUDA driver shutdown
try {
mshadow::DeleteStream<gpu>(j);
Expand Down
6 changes: 0 additions & 6 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,6 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() { return pending_.load() == 0; });
}

void ThreadedEngine::Finalize() {
// unlock all threads
pending_.store(0);
finished_cv_.notify_all();
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
Expand Down
1 change: 0 additions & 1 deletion src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ class ThreadedEngine : public Engine {
threaded_opr->fn(run_ctx, callback);
OprBlock::Delete(opr_block);
}
void Finalize() override;

private:
/*!
Expand Down
11 changes: 3 additions & 8 deletions src/engine/threaded_engine_perdevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
// GPU tasks will be created lazily
}
~ThreadedEnginePerDevice() noexcept(false) {
Finalize();
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_worker_.reset(nullptr);
}

protected:
Expand All @@ -63,13 +65,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
}
}
}
// finalize the internal resources
void Finalize() override {
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_worker_.reset(nullptr);
ThreadedEngine::Finalize();
}

private:
// working unit for each of the task.
Expand Down
18 changes: 6 additions & 12 deletions src/engine/threaded_engine_pooled.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ class ThreadedEnginePooled : public ThreadedEngine {
io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}

~ThreadedEnginePooled() noexcept(false) {
Finalize();
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
streams_.Finalize();
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
}

protected:
Expand All @@ -39,17 +44,6 @@ class ThreadedEnginePooled : public ThreadedEngine {
DoPushToQueue(opr_block);
}
}
// finalize the internal resources
void Finalize() override {
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
streams_.Finalize();
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
ThreadedEngine::Finalize();
}


private:
/*! \brief Concurrency for thread pool */
Expand Down
31 changes: 13 additions & 18 deletions src/resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@ class ResourceManagerImpl : public ResourceManager {
Context(cpu::kDevMask, 0), cpu_temp_space_copy_));
}
~ResourceManagerImpl() {
Finalize();
// need explicit delete, before engine get killed
cpu_rand_.reset(nullptr);
cpu_space_.reset(nullptr);
#if MXNET_USE_CUDA
gpu_rand_.Clear();
gpu_space_.Clear();
#endif
if (engine_ref_ != nullptr) {
engine_ref_->WaitForAll();
// release the reference to engine.
engine_ref_ = nullptr;
}
}

// request resources
Expand Down Expand Up @@ -74,22 +85,6 @@ class ResourceManagerImpl : public ResourceManager {
#endif
}

protected:
void Finalize() override {
// need explicit delete, before engine get killed
cpu_rand_.reset(nullptr);
cpu_space_.reset(nullptr);
#if MXNET_USE_CUDA
gpu_rand_.Clear();
gpu_space_.Clear();
#endif
if (engine_ref_ != nullptr) {
engine_ref_->WaitForAll();
// release the reference to engine.
engine_ref_ = nullptr;
}
}

private:
/*! \brief Maximum number of GPUs */
static constexpr std::size_t kMaxNumGPUs = 16;
Expand All @@ -115,7 +110,7 @@ class ResourceManagerImpl : public ResourceManager {
}
~ResourceRandom() {
mshadow::Random<xpu> *r = prnd;
Engine::Get()->DeleteVariable(
Engine::Get()->DeleteVariable(
[r](RunContext rctx){ delete r; }, ctx, resource.var);
}
// set seed to a PRNG
Expand Down
10 changes: 0 additions & 10 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ struct Storage::Impl {

Storage::Handle Storage::Alloc(size_t size, Context ctx) {
// space already recycled, ignore request
if (impl_.get() == nullptr) {
LOG(FATAL) << "Alloc called after finalize";
}

Handle hd;
hd.ctx = ctx;
hd.size = size;
Expand Down Expand Up @@ -92,9 +88,6 @@ Storage::Handle Storage::Alloc(size_t size, Context ctx) {
}

void Storage::Free(Storage::Handle handle) {
// space already recycled, ignore request
if (impl_.get() == nullptr) return;

std::lock_guard<std::mutex> lock{impl_->m};
Impl::ActivateDevice(handle.ctx);
impl_->storage_managers.at(handle.ctx.dev_mask)
Expand All @@ -116,8 +109,5 @@ Storage* Storage::Get() {

Storage::Storage() : impl_{new Impl{}} {}

void Storage::Finalize() {
impl_.reset(nullptr);
}

} // namespace mxnet

0 comments on commit 84df51a

Please sign in to comment.