diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index de523567f3eb..e7f0676bd6e9 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable= arguments-differ, too-many-lines
+# pylint: disable= arguments-differ, too-many-lines, reimported
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
@@ -25,6 +25,7 @@
 import warnings
 import re
 from collections import OrderedDict, defaultdict
+import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, np_symbol
@@ -1353,6 +1354,28 @@ def _clear_cached_op(self):
     def cast(self, dtype):
         self._clear_cached_op()
         super(SymbolBlock, self).cast(dtype)
+        if np.dtype(dtype).name == 'float16':
+            # correct BatchNorm types back to float32 due to its special requirement
+            out = self._cached_graph[1]
+            params_list = out.get_internals().list_inputs()
+            for node in params_list:
+                if node.endswith('running_var'):
+                    prefix = node[:-11]
+                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
+                if node.endswith('moving_var'):
+                    # another convention used
+                    prefix = node[:-10]
+                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
 
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b938b5783415..6aa766351b62 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -594,6 +594,24 @@ def hybrid_forward(self, F, a, b):
     assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()),
                                                  mx.nd.ones((10,), ctx=mx.cpu())))
 
+@with_seed()
+def test_symbol_block_symbolic_bn_fp16_cast():
+    with mx.gpu(0):
+        net = mx.gluon.nn.HybridSequential()
+        sym = mx.sym.var('data')
+        conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
+        bn = mx.sym.BatchNorm(conv, name='bn_test')
+        internals = bn.get_internals()
+        net.add(mx.gluon.nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
+        net.add(mx.gluon.nn.Conv2D(10, kernel_size=1))
+        net.initialize()
+        x = mx.nd.zeros((1, 3, 32, 32), dtype='float32')
+        y = net(x)
+        assert np.dtype(y.dtype).name == 'float32'
+        net.cast('float16')
+        x = x.astype('float16')
+        y1 = net(x)
+        assert np.dtype(y1.dtype).name == 'float16'
 
 if __name__ == '__main__':
     import nose
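
For reference, the patched `cast()` identifies a BatchNorm layer purely by naming convention: a parameter ending in `running_var` (or `moving_var`, the other convention) is treated as BatchNorm state only if its sibling parameters with the same prefix (the matching mean, plus `gamma` and `beta`) also appear among the graph's inputs. The standalone sketch below illustrates the same suffix-matching heuristic outside the patch; `find_bn_param_groups` is a hypothetical name, not part of the MXNet API, and it uses `len()`-based slicing in place of the hard-coded `node[:-11]` / `node[:-10]` offsets.

# Standalone sketch of the suffix-matching heuristic used in SymbolBlock.cast()
# above. `find_bn_param_groups` is illustrative only, not an MXNet function.
def find_bn_param_groups(param_names):
    """Group parameter names that look like one BatchNorm layer.

    A name ending in 'running_var' (or 'moving_var') counts as a
    BatchNorm variance parameter only if the sibling parameters
    (mean, gamma, beta) with the same prefix are also present.
    """
    names = set(param_names)
    groups = []
    # Check both naming conventions handled by the patch.
    for var_suffix, mean_suffix in (('running_var', 'running_mean'),
                                    ('moving_var', 'moving_mean')):
        for name in param_names:
            if not name.endswith(var_suffix):
                continue
            prefix = name[:-len(var_suffix)]
            sibs = [prefix + s for s in (mean_suffix, 'gamma', 'beta')]
            if all(s in names for s in sibs):
                # These four parameters would be cast back to float32.
                groups.append([name] + sibs)
    return groups


# Example with the parameter names produced by the 'bn_test' layer in the test:
params = ['conv0_weight', 'bn_test_gamma', 'bn_test_beta',
          'bn_test_running_mean', 'bn_test_running_var']
print(find_bn_param_groups(params))
# [['bn_test_running_var', 'bn_test_running_mean', 'bn_test_gamma', 'bn_test_beta']]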