Skip to content

Commit

Permalink
Revert "Fix memory leaks in Gluon (apache#18328) (apache#18359)" (apa…
Browse files Browse the repository at this point in the history
…che#19181)

This reverts commit b523527.
  • Loading branch information
leezu authored Sep 19, 2020
1 parent 0fce381 commit 0496690
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 55 deletions.
21 changes: 8 additions & 13 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
import threading
import copy
import warnings
import weakref
from collections import OrderedDict, defaultdict

import re
from collections import OrderedDict, defaultdict
import numpy as np

from ..base import mx_real_t, MXNetError
Expand All @@ -48,7 +46,7 @@ class _BlockScope(object):
_current = threading.local()

def __init__(self, block):
self._block = weakref.ref(block) if block is not None else None
self._block = block
self._counter = {}
self._old_scope = None
self._name_scope = None
Expand All @@ -57,8 +55,7 @@ def __init__(self, block):
def create(prefix, params, hint):
"""Creates prefix and params for new `Block`."""
current = getattr(_BlockScope._current, "value", None)
block = current._block() if current is not None else None
if current is None or block is None:
if current is None:
if prefix is None:
if not hasattr(_name.NameManager._current, "value"):
_name.NameManager._current.value = _name.NameManager()
Expand All @@ -74,25 +71,23 @@ def create(prefix, params, hint):
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = block.params
parent = current._block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
return block.prefix + prefix, params
return current._block.prefix+prefix, params

def __enter__(self):
block = self._block()
if block is None or block._empty_prefix:
if self._block._empty_prefix:
return self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(block.prefix)
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope.__enter__()
return self

def __exit__(self, ptype, value, trace):
block = self._block()
if block is None or block._empty_prefix:
if self._block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
Expand Down
39 changes: 0 additions & 39 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import os
import tempfile
import gc

import mxnet as mx
from mxnet import gluon
Expand Down Expand Up @@ -3230,44 +3229,6 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

def test_no_memory_leak_in_gluon():
# Collect all other garbage prior to this test. Otherwise the test may fail
# due to unrelated memory leaks.
gc.collect()

gc_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize()
del net
gc.collect()
gc.set_debug(gc_flags) # reset gc flags

# Check for leaked NDArrays
seen = set()
def has_array(element):
try:
if element in seen:
return False
seen.add(element)
except TypeError: # unhashable
pass

if isinstance(element, mx.nd._internal.NDArrayBase):
return True
elif hasattr(element, '__dict__'):
return any(has_array(x) for x in vars(element))
elif isinstance(element, dict):
return any(has_array(x) for x in element.items())
else:
try:
return any(has_array(x) for x in element)
except (TypeError, KeyError):
return False

assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
del gc.garbage[:]

if __name__ == '__main__':
import nose
nose.runmodule()
5 changes: 2 additions & 3 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ def __init__(self, prefix):
status = [False]
event = threading.Event()
def f():
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
with block._BlockScope(net):
x = NameManager.current.get(None, "hello")
with block._BlockScope(dummy_block("spawned_")):
x= NameManager.current.get(None, "hello")
event.wait()
if x == "spawned_hello0":
status[0] = True
Expand Down

0 comments on commit 0496690

Please sign in to comment.