From ce365b333fd6cf1dcd60363eb41193375837aa2a Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang"
Date: Fri, 3 Jan 2020 12:14:18 -0800
Subject: [PATCH 1/5] fix symbolblock with bn+fp16

---
 python/mxnet/gluon/block.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index de523567f3eb..c5c83340d3ae 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -25,6 +25,7 @@
 import warnings
 import re
 from collections import OrderedDict, defaultdict
+import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, np_symbol
@@ -1298,6 +1299,7 @@ def __init__(self, outputs, inputs, params=None):
             assert len(i.get_internals().list_outputs()) == 1, \
                 "Input symbols must be variable, but %s is an output of operators"%str(i)
             input_names.add(i.name)
+        self._input_names = input_names
 
         # check if any symbol is row_sparse
         row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
@@ -1353,7 +1355,29 @@ def _clear_cached_op(self):
     def cast(self, dtype):
         self._clear_cached_op()
         super(SymbolBlock, self).cast(dtype)
-
+        if np.dtype(dtype).name == 'float16':
+            # correct BatchNorm types back to float32 due to its special requirement
+            out = self._cached_graph[1]
+            params_list = out.get_internals().list_inputs()
+            for node in params_list:
+                if node.endswith('running_var'):
+                    prefix = node[:-11]
+                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
+                if node.endswith('moving_var'):
+                    # another convention used
+                    prefix = node[:-10]
+                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError

From 9d741ab70a4e3d1f521c41a14c9d07ddda88ad41 Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang"
Date: Fri, 3 Jan 2020 13:56:00 -0800
Subject: [PATCH 2/5] add unittest

---
 tests/python/gpu/test_gluon_gpu.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b938b5783415..b977224df773 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -594,6 +594,24 @@ def hybrid_forward(self, F, a, b):
     assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()),
                                                  mx.nd.ones((10,), ctx=mx.cpu())))
 
+@with_seed()
+def test_symbol_block_symbolic_bn_fp16_cast():
+    ctx = default_context()
+    net = nn.HybridSequential()
+    sym = mx.sym.var('data')
+    conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
+    bn = mx.sym.BatchNorm(conv, name='bn_test')
+    internals = bn.get_internals()
+    net.add(nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
+    net.add(nn.Conv2D(10, kernel_size=1))
+    net.initialize(ctx=ctx)
+    x = mx.nd.zeros((1, 3, 32, 32), dtype='float32', ctx=ctx)
+    y = net(x)
+    assert np.dtype(y.dtype).name == 'float32'
+    net.cast('float16')
+    x = x.astype('float16')
+    y1 = net(x)
+    assert np.dtype(y1.dtype).name == 'float16'
 
 if __name__ == '__main__':
     import nose

From c21ed517d6a8b263b515d1c67a69d87c37ce04d5 Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang"
Date: Fri, 3 Jan 2020 13:59:38 -0800
Subject: [PATCH 3/5] fix

---
 tests/python/gpu/test_gluon_gpu.py | 32 +++++++++++++++---------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b977224df773..6aa766351b62 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -596,22 +596,22 @@ def hybrid_forward(self, F, a, b):
 
 @with_seed()
 def test_symbol_block_symbolic_bn_fp16_cast():
-    ctx = default_context()
-    net = nn.HybridSequential()
-    sym = mx.sym.var('data')
-    conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
-    bn = mx.sym.BatchNorm(conv, name='bn_test')
-    internals = bn.get_internals()
-    net.add(nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
-    net.add(nn.Conv2D(10, kernel_size=1))
-    net.initialize(ctx=ctx)
-    x = mx.nd.zeros((1, 3, 32, 32), dtype='float32', ctx=ctx)
-    y = net(x)
-    assert np.dtype(y.dtype).name == 'float32'
-    net.cast('float16')
-    x = x.astype('float16')
-    y1 = net(x)
-    assert np.dtype(y1.dtype).name == 'float16'
+    with mx.gpu(0):
+        net = mx.gluon.nn.HybridSequential()
+        sym = mx.sym.var('data')
+        conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
+        bn = mx.sym.BatchNorm(conv, name='bn_test')
+        internals = bn.get_internals()
+        net.add(mx.gluon.nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
+        net.add(mx.gluon.nn.Conv2D(10, kernel_size=1))
+        net.initialize()
+        x = mx.nd.zeros((1, 3, 32, 32), dtype='float32')
+        y = net(x)
+        assert np.dtype(y.dtype).name == 'float32'
+        net.cast('float16')
+        x = x.astype('float16')
+        y1 = net(x)
+        assert np.dtype(y1.dtype).name == 'float16'
 
 if __name__ == '__main__':
     import nose

From 05f78a2bbdfdfb0e326d5d2e51b095461f6e01ec Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang"
Date: Fri, 3 Jan 2020 14:13:19 -0800
Subject: [PATCH 4/5] remove unused

---
 python/mxnet/gluon/block.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index c5c83340d3ae..b793c34ab88f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1299,7 +1299,6 @@ def __init__(self, outputs, inputs, params=None):
             assert len(i.get_internals().list_outputs()) == 1, \
                 "Input symbols must be variable, but %s is an output of operators"%str(i)
             input_names.add(i.name)
-        self._input_names = input_names
 
         # check if any symbol is row_sparse
         row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']

From e48ba29fe2a76d98711bb515e8929b148f8be211 Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang"
Date: Fri, 3 Jan 2020 17:02:21 -0800
Subject: [PATCH 5/5] fix lint

---
 python/mxnet/gluon/block.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index b793c34ab88f..e7f0676bd6e9 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable= arguments-differ, too-many-lines
+# pylint: disable= arguments-differ, too-many-lines, reimported
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
@@ -1376,7 +1376,7 @@ def cast(self, dtype):
                         self.params.get(node).cast('float32')
                         for sib in sibs:
                             self.params.get(sib).cast('float32')
-
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError