Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.7.x] Fix memory leaks in Gluon #18358

Merged
merged 1 commit into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import threading
import copy
import warnings
import re
import weakref
from collections import OrderedDict, defaultdict

import re
import numpy as np

from ..base import mx_real_t, MXNetError
Expand All @@ -46,7 +48,7 @@ class _BlockScope(object):
_current = threading.local()

def __init__(self, block):
self._block = block
self._block = weakref.ref(block) if block is not None else None
self._counter = {}
self._old_scope = None
self._name_scope = None
Expand All @@ -55,7 +57,8 @@ def __init__(self, block):
def create(prefix, params, hint):
"""Creates prefix and params for new `Block`."""
current = getattr(_BlockScope._current, "value", None)
if current is None:
block = current._block() if current is not None else None
if current is None or block is None:
if prefix is None:
if not hasattr(_name.NameManager._current, "value"):
_name.NameManager._current.value = _name.NameManager()
Expand All @@ -71,23 +74,25 @@ def create(prefix, params, hint):
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = current._block.params
parent = block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
return current._block.prefix+prefix, params
return block.prefix + prefix, params

def __enter__(self):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope = _name.Prefix(block.prefix)
self._name_scope.__enter__()
return self

def __exit__(self, ptype, value, trace):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import os
import tempfile
import gc

import mxnet as mx
from mxnet import gluon
Expand Down Expand Up @@ -3212,6 +3213,44 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

def test_no_memory_leak_in_gluon():
# Collect all other garbage prior to this test. Otherwise the test may fail
# due to unrelated memory leaks.
gc.collect()

gc_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize()
del net
gc.collect()
gc.set_debug(gc_flags) # reset gc flags

# Check for leaked NDArrays
seen = set()
def has_array(element):
try:
if element in seen:
return False
seen.add(element)
except TypeError: # unhashable
pass

if isinstance(element, mx.nd._internal.NDArrayBase):
return True
elif hasattr(element, '__dict__'):
return any(has_array(x) for x in vars(element))
elif isinstance(element, dict):
return any(has_array(x) for x in element.items())
else:
try:
return any(has_array(x) for x in element)
except (TypeError, KeyError):
return False

assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
del gc.garbage[:]

if __name__ == '__main__':
import nose
nose.runmodule()
5 changes: 3 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ def __init__(self, prefix):
status = [False]
event = threading.Event()
def f():
with block._BlockScope(dummy_block("spawned_")):
x= NameManager.current.get(None, "hello")
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
with block._BlockScope(net):
x = NameManager.current.get(None, "hello")
event.wait()
if x == "spawned_hello0":
status[0] = True
Expand Down