Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #89 from tqchen/master
Browse files Browse the repository at this point in the history
Add Finalize module
  • Loading branch information
antinucleon committed Sep 17, 2015
2 parents 4f1441e + fbb1418 commit a9d5227
Show file tree
Hide file tree
Showing 22 changed files with 164 additions and 57 deletions.
6 changes: 6 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ typedef mshadow::TShape TShape;
/*! \brief storage container type */
typedef mshadow::TBlob TBlob;

/*!
* \brief Finalize and shutdown all related modules of mxnet.
* Call this function at end of program to ensure correct shutdown.
*/
void Finalize();

/*! \brief Context information about the execution environment */
struct Context {
/*! \brief the device type the op runs on; can be cpu::kDevMask or gpu::kDevMask */
Expand Down
5 changes: 5 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ MXNET_DLL const char *MXGetLastError();
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXRandomSeed(int seed);
/*!
* \brief Finalize and shutdown all related modules of mxnet.
* Call this function at end of program to ensure correct shutdown.
*/
MXNET_DLL int MXFinalize();
//-------------------------------------
// Part 1: NDArray creation and deletion
//-------------------------------------
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ class Engine {
ret.param_ = param;
return ret;
}
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal engine to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
}; // class Engine
#endif // DMLC_USE_CXX11
} // namespace mxnet
Expand Down
11 changes: 7 additions & 4 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class KVStore {
virtual void Start();

/**
* \brief Stop
* \brief Finalize the KVStore
*
* clears all stored key-value pairs, the updater, and the bound devices
*/
virtual void Stop() {
if (impl_) { impl_->Stop(); delete impl_; impl_ = NULL; }
virtual void Finalize() {
  // Nothing to tear down when no implementation object exists; this also
  // makes repeated calls harmless because impl_ is reset below.
  if (!impl_) return;
  // Let the implementation release its own state before it is destroyed.
  impl_->Finalize();
  delete impl_;
  impl_ = NULL;
}

/**
Expand Down Expand Up @@ -178,7 +178,10 @@ class KVStore {

protected:
KVStore() : impl_(NULL) { }
virtual ~KVStore() { delete impl_; impl_ = NULL; }

virtual ~KVStore() {
delete impl_; impl_ = NULL;
}

private:
inline KVStore* get_impl() const {
Expand Down
12 changes: 11 additions & 1 deletion include/mxnet/resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,21 @@ class ResourceManager {
*/
virtual void SeedRandom(uint32_t seed) = 0;
/*! \brief virtual destructor */
virtual ~ResourceManager() {}
virtual ~ResourceManager() DMLC_THROW_EXCEPTION {}
/*!
* \return Resource manager singleton.
*/
static ResourceManager *Get();

protected:
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal resource manager to release all resources.
* It is safe to call this function multiple times.
*/
virtual void Finalize() = 0;
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_
10 changes: 10 additions & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ class Storage {
*/
static std::shared_ptr<Storage> _GetSharedRef();

protected:
// friend function
friend void ::mxnet::Finalize();
/*!
* \brief Idempotent Finalize function.
* This function will signal the storage manager to release all resources.
* It is safe to call this function multiple times.
*/
void Finalize();

private:
/*!
* \brief Hidden constructors.
Expand Down
14 changes: 14 additions & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@

from .context import Context, current_context, cpu, gpu
from .base import MXNetError
from . import base
from . import ndarray
from . import symbol
from . import kvstore as kv
from . import io
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import random
import atexit

__version__ = "0.1.0"

def finalize():
    """Stop all the components in mxnet.

    There is no need to call this function manually: it is registered
    with ``atexit`` in this module and runs automatically at exit.

    The native library is finalized first, then the Python-side kvstore
    callbacks are released. The callback cleanup runs even when the
    native finalize call reports a failure, so the callback reference is
    never left behind; any error from ``check_call`` still propagates.
    """
    # pylint: disable=protected-access
    try:
        base.check_call(base._LIB.MXFinalize())
    finally:
        # Always drop the kvstore callbacks, even if MXFinalize failed.
        kv._cleanup()

atexit.register(finalize)
1 change: 0 additions & 1 deletion python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,3 @@ def ctypes2numpy_shared(cptr, shape):
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


9 changes: 3 additions & 6 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from .ndarray import NDArray
from .base import _LIB
from .base import check_call, c_array, NDArrayHandle
import atexit

__all__ = ['start', 'init', 'push', 'pull', 'stop', 'set_updater']
__all__ = ['start', 'init', 'push', 'pull', 'set_updater']

def _ctype_key_value(keys, vals):
"""
Expand Down Expand Up @@ -213,11 +212,9 @@ def set_updater(updater):
_updater_func = _updater_proto(_updater_wrapper(updater))
check_call(_LIB.MXKVStoreSetUpdater(_updater_func))

def stop():
""" Stop the kvstore """
check_call(_LIB.MXKVStoreStop())
def _cleanup():
    """Release the module-level updater callback.

    Resets the global ``_updater_func`` to ``None`` so this module no
    longer holds the callback installed by ``set_updater``.
    """
    # need to clear _updater_func before _LIB
    global _updater_func
    _updater_func = None

atexit.register(stop)
12 changes: 6 additions & 6 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ int MXRandomSeed(int seed) {
API_END();
}

// C API entry point for shutting down mxnet: forwards to mxnet::Finalize().
// NOTE(review): API_BEGIN/API_END presumably supply the int return value and
// turn failures into the error code reported to C callers - confirm against
// the macro definitions.
int MXFinalize() {
API_BEGIN();
mxnet::Finalize();
API_END();
}

int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray();
Expand Down Expand Up @@ -891,12 +897,6 @@ int MXKVStoreStart() {
API_END();
}

int MXKVStoreStop() {
API_BEGIN();
KVStore::Get()->Stop();
API_END();
}

int MXKVStoreSetUpdater(MXKVStoreUpdater updater) {
API_BEGIN();
auto updt = [updater](int key, const NDArray& recv, NDArray* local) {
Expand Down
6 changes: 6 additions & 0 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@ class NaiveEngine final : public Engine {
}
// virtual destructor
// Runs Finalize() so that any streams still alive are released when the
// engine is destroyed, whether or not Finalize() was already called.
virtual ~NaiveEngine() {
Finalize();
}

void Finalize() override {
#if MXNET_USE_CUDA
  // Destroy every stream that is still alive. Each slot is reset to nullptr
  // once its stream has been deleted, so a second call finds nothing to do.
  for (auto &stream : streams_) {
    if (stream == nullptr) continue;
    mshadow::DeleteStream(stream);
    stream = nullptr;
  }
#endif
}

// new variables
VarHandle NewVariable() override {
return nullptr;
Expand Down
9 changes: 6 additions & 3 deletions src/engine/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ template <std::size_t kNumGpus, std::size_t kStreams>
class StreamManager {
public:
StreamManager();
~StreamManager();
~StreamManager() {
Finalize();
}
RunContext GetRunContext(Context const& ctx);
RunContext GetIORunContext(Context const& ctx);

void Finalize();
private:
std::mutex m_;
#if MXNET_USE_CUDA
Expand Down Expand Up @@ -111,13 +113,14 @@ StreamManager<kNumGpus, kStreams>::StreamManager() {
}

template <std::size_t kNumGpus, std::size_t kStreams>
StreamManager<kNumGpus, kStreams>::~StreamManager() {
void StreamManager<kNumGpus, kStreams>::Finalize() {
#if MXNET_USE_CUDA
for (std::size_t i = 0; i < kNumGpus; ++i) {
if (gpu_cnt_.at(i) != -1) {
for (auto&& j : gpu_streams_.at(i)) {
mshadow::DeleteStream<gpu>(j);
}
gpu_cnt_.at(i) = -1;
}
}
#endif // MXNET_USE_CUDA
Expand Down
6 changes: 6 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() { return pending_.load() == 0; });
}

// Releases threads blocked in WaitForAll(): forces the pending-operation
// counter to zero (the value the WaitForAll() predicate waits for) and then
// wakes every waiter on finished_cv_.
// NOTE(review): the store and the notify are issued without taking the mutex
// that WaitForAll() waits with; a waiter that has checked its predicate but
// not yet blocked could miss this wakeup - confirm whether a lock is needed.
void ThreadedEngine::Finalize() {
// unlock all threads
pending_.store(0);
finished_cv_.notify_all();
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
Expand Down
1 change: 1 addition & 0 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class ThreadedEngine : public Engine {
threaded_opr->fn(run_ctx, callback);
OprBlock::Delete(opr_block);
}
void Finalize() override;

private:
/*!
Expand Down
20 changes: 13 additions & 7 deletions src/engine/threaded_engine_perdevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,17 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 2);
gpu_worker_nthreads_ = dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2);
gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 1);

// create CPU task
auto *cpu_queue = &(cpu_worker_.task_queue);
cpu_worker_.pool.reset(new ThreadPool(
cpu_worker_.reset(new ThreadWorkerBlock());
auto *cpu_queue = &(cpu_worker_->task_queue);
cpu_worker_->pool.reset(new ThreadPool(
cpu_worker_nthreads_, [this, cpu_queue] {
this->CPUWorker(cpu_queue);
}));
// GPU tasks will be created lazily
}
~ThreadedEnginePerDevice() noexcept(false) {
// wait until all the tasks are completed.
this->WaitForAll();
Finalize();
}

protected:
Expand All @@ -56,7 +55,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
this->ExecuteOprBlock(run_ctx, opr_block);
} else {
if (ctx.dev_mask == cpu::kDevMask) {
cpu_worker_.task_queue.Push(opr_block);
cpu_worker_->task_queue.Push(opr_block);
} else {
CHECK_EQ(ctx.dev_mask, gpu::kDevMask);
ThreadWorkerBlock* block = this->GetGPUWorkerBlock(
Expand All @@ -65,6 +64,13 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
}
}
}
// finalize the internal resources
// Order: clear both lazily-allocated GPU worker arrays, destroy the CPU
// worker block owned by cpu_worker_, then run the base-class Finalize().
// NOTE(review): cpu_worker_ is null after this call, and the CPU push path
// dereferences it unconditionally - presumably the engine must not receive
// new work after Finalize(); confirm.
void Finalize() override {
gpu_normal_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_worker_.reset(nullptr);
ThreadedEngine::Finalize();
}

private:
// working unit for each of the task.
Expand All @@ -85,7 +91,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
/*! \brief number of concurrent thread each gpu copy worker uses */
int gpu_copy_nthreads_;
// cpu worker
ThreadWorkerBlock cpu_worker_;
std::unique_ptr<ThreadWorkerBlock> cpu_worker_;
// workers doing normal works on GPU
common::LazyAllocArray<ThreadWorkerBlock> gpu_normal_workers_;
// workers doing copy works from/to GPU
Expand Down
17 changes: 12 additions & 5 deletions src/engine/threaded_engine_pooled.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ class ThreadedEnginePooled : public ThreadedEngine {
io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}

~ThreadedEnginePooled() noexcept(false) {
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
Finalize();
}

protected:
Expand All @@ -42,6 +38,17 @@ class ThreadedEnginePooled : public ThreadedEngine {
DoPushToQueue(opr_block);
}
}
// finalize the internal resources
// Order: block until all pending work has completed, finalize the stream
// manager, signal both task queues for kill, then run the base-class
// Finalize(). Also invoked from the destructor.
void Finalize() override {
// wait until all the tasks are completed.
// TODO(hotpxl) think if this is the correct thing to do
this->WaitForAll();
// presumably releases the streams held by streams_ - confirm
streams_.Finalize();
// signal the normal and IO task queues for kill
task_queue_.SignalForKill();
io_task_queue_.SignalForKill();
ThreadedEngine::Finalize();
}


private:
/*! \brief Concurrency for thread pool */
Expand Down
21 changes: 21 additions & 0 deletions src/global.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*!
* Copyright (c) 2015 by Contributors
* \file global.cc
* \brief Implementation of project global related functions.
*/
#include <mxnet/base.h>
#include <mxnet/engine.h>
#include <mxnet/storage.h>
#include <mxnet/resource.h>
#include <mxnet/kvstore.h>

namespace mxnet {
// finalize the mxnet modules
// Calls Finalize() on the ResourceManager, KVStore, Engine and Storage
// singletons, in that order; the engine is drained with WaitForAll() before
// its own Finalize() is invoked.
// NOTE(review): the ordering looks deliberate (resource manager and kvstore
// are finalized while the engine can still run work, storage goes last) -
// confirm before reordering any of these calls.
void Finalize() {
ResourceManager::Get()->Finalize();
KVStore::Get()->Finalize();
Engine::Get()->WaitForAll();
Engine::Get()->Finalize();
Storage::Get()->Finalize();
}
} // namespace mxnet
2 changes: 1 addition & 1 deletion src/kvstore/kvstore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
namespace mxnet {

void KVStore::Start() {
if (impl_ != NULL) Stop();
if (impl_ != NULL) Finalize();
char* num_worker = getenv("DMLC_NUM_WORKER");
if (num_worker == NULL || atoi(num_worker) == 1) {
impl_ = new KVStoreLocal();
Expand Down
3 changes: 1 addition & 2 deletions src/kvstore/kvstore_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class KVStoreLocal : public KVStore {
virtual ~KVStoreLocal() { Clear(); }

virtual void Start() { }

virtual void Stop() { Clear(); }
virtual void Finalize() { Clear(); }

virtual void set_updater(const Updater& updater) {
updater_ = updater;
Expand Down
Loading

0 comments on commit a9d5227

Please sign in to comment.