Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
more on finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 17, 2015
1 parent 84350f3 commit fbb1418
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
14 changes: 14 additions & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@

from .context import Context, current_context, cpu, gpu
from .base import MXNetError
from . import base
from . import ndarray
from . import symbol
from . import kvstore as kv
from . import io
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import random
import atexit

__version__ = "0.1.0"

def finalize():
"""Stop all the components in mxnet.
There is no need to call this function.
This function will be automatically called at module exit.
"""
# pylint: disable=protected-access
base.check_call(base._LIB.MXFinalize())
kv._cleanup()

atexit.register(finalize)
8 changes: 0 additions & 8 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import ctypes
import platform
import numpy as np
import atexit

__all__ = ['MXNetError']
#----------------------------
Expand Down Expand Up @@ -180,10 +179,3 @@ def ctypes2numpy_shared(cptr, shape):
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


def stop_all():
"""Stop All the components in mxnet."""
check_call(_LIB.MXFinalize())


atexit.register(stop_all)
9 changes: 3 additions & 6 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from .ndarray import NDArray
from .base import _LIB
from .base import check_call, c_array, NDArrayHandle
import atexit

__all__ = ['start', 'init', 'push', 'pull', 'stop', 'set_updater']
__all__ = ['start', 'init', 'push', 'pull', 'set_updater']

def _ctype_key_value(keys, vals):
"""
Expand Down Expand Up @@ -213,11 +212,9 @@ def set_updater(updater):
_updater_func = _updater_proto(_updater_wrapper(updater))
check_call(_LIB.MXKVStoreSetUpdater(_updater_func))

def stop():
""" Stop the kvstore """
def _cleanup():
""" cleanup callbacks """
# need to clear _updater_func before _LIB
global _updater_func
_updater_func = None


atexit.register(stop)
21 changes: 21 additions & 0 deletions src/global.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*!
* Copyright (c) 2015 by Contributors
* \file global.cc
* \brief Implementation of project global related functions.
*/
#include <mxnet/base.h>
#include <mxnet/engine.h>
#include <mxnet/storage.h>
#include <mxnet/resource.h>
#include <mxnet/kvstore.h>

namespace mxnet {
// finalize the mxnet modules
void Finalize() {
ResourceManager::Get()->Finalize();
KVStore::Get()->Finalize();
Engine::Get()->WaitForAll();
Engine::Get()->Finalize();
Storage::Get()->Finalize();
}
} // namespace mxnet
8 changes: 0 additions & 8 deletions tests/python/unittest/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ def init_kv():
# list
mx.kv.init(keys, [mx.nd.zeros(shape)] * len(keys))

def stop_kv():
"""stop kv """
mx.kv.stop()

def check_diff_to_scalar(A, x):
""" assert A == x"""
assert(np.sum(np.abs((A - x).asnumpy())) == 0)
Expand All @@ -30,7 +26,6 @@ def test_single_kv_pair():
mx.kv.pull(3, out = val)
check_diff_to_scalar(val, 1)

stop_kv()

def test_list_kv_pair():
"""list key-value pair push & pull"""
Expand All @@ -43,7 +38,6 @@ def test_list_kv_pair():
for v in val:
check_diff_to_scalar(v, 4)

stop_kv()

def test_aggregator():
"""aggregate value on muliple devices"""
Expand Down Expand Up @@ -72,7 +66,6 @@ def test_aggregator():
for v in vv:
check_diff_to_scalar(v, num_devs * 2.0)

stop_kv()

def updater(key, recv, local):
"""use updater: +="""
Expand Down Expand Up @@ -110,7 +103,6 @@ def test_updater(dev = 'cpu'):
for v in vv:
check_diff_to_scalar(v, num_devs * num_push)

stop_kv()

if __name__ == '__main__':
test_single_kv_pair()
Expand Down

0 comments on commit fbb1418

Please sign in to comment.