From d69b7f6bf7b70bd804440cf03edd2178121f6ded Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 30 Aug 2018 09:41:33 -0700
Subject: [PATCH 01/13] Infer dtype in SymbolBlock import from input symbol

---
 python/mxnet/gluon/block.py | 81 ++++++++++++++++++++++++++++++++++---
 1 file changed, 75 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d0830dcc8cae..43b80f39a0de 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1053,13 +1053,36 @@ def __init__(self, outputs, inputs, params=None):
                     "SymbolBlock doesn't support Parameter '%s' because its storage " \
                     "type is 'row_sparse'." % j.name

-        for i in out.list_arguments():
-            if i not in input_names:
-                self.params.get(i, allow_deferred_init=True)
+        # Infer type of parameters. Without this, every parameter will be created with
+        # default type i.e., fp32
+        arg_params = out.list_arguments()
+        aux_params = out.list_auxiliary_states()
+
+        infer_type_success, arg_types, aux_types = _infer_param_types(self,
+                                                                      inputs[0],
+                                                                      out,
+                                                                      arg_params,
+                                                                      aux_params)
+
+        if infer_type_success:
+            # Use inferred types for params
+            for i, arg in enumerate(arg_params):
+                if arg not in input_names:
+                    self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])
+
+            for i, aux in enumerate(aux_params):
+                if aux not in input_names:
+                    self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])
+        else:
+            # Use default types for params
+            for i, arg in enumerate(arg_params):
+                if arg not in input_names:
+                    dt = inputs[0].infer_type()[0]
+                    self.params.get(arg, allow_deferred_init=True)

-        for i in out.list_auxiliary_states():
-            if i not in input_names:
-                self.params.get(i, grad_req='null', allow_deferred_init=True)
+            for i, aux in out.list_auxiliary_states():
+                if aux not in input_names:
+                    self.params.get(aux, grad_req='null', allow_deferred_init=True)

         self._cached_graph = syms, out
         len_prefix = len(_common_prefix(list(self._params.keys())))
@@ -1086,3 +1109,49 @@ def _clear_cached_op(self):

     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
+
+def _infer_param_types(self, in_params, out_params, arg_params, aux_params):
+    """Utility function that helps in inferring DType of arg and aux params
+    from given input param.
+
+    Parameters
+    ----------
+    in_params: Symbol
+        Input symbol variable.
+    out_params: Symbol
+        Output symbol variable.
+    arg_params: List of Str
+        List of names of argument parameters.
+    aux_params: List of Str
+        List of names of auxiliary parameters.
+
+    Returns
+    -------
+    infer_type_success: Boolean
+        True if able to infer types for all given arg_params and aux_params.
+        False, otherwise.
+    arg_types: List of numpy.dtype
+        List of arg_params type. Order is same as arg_params.
+        None if unable to infer type.
+    aux_types: List of numpy.dtype
+        List of aux_params type. Order is same as aux_params.
+        None if unable to infer type.
+    """
+    infer_type_success = False
+    arg_types = None
+    aux_types = None
+
+    # Get Input symbol details. This will be used to infer types of
+    # other parameters.
+    input_sym_name = in_params.name
+    input_sym_arg_type = in_params.infer_type()[0]
+
+    # Try to infer types of other parameters.
+    if input_sym_arg_type and len(input_sym_arg_type) > 0:
+        params = {input_sym_name:input_sym_arg_type[0]}
+        arg_types, _, aux_types = out_params.infer_type(**params)
+        if arg_types is not None and len(arg_types) == len(arg_params) and \
+           aux_types is not None and len(aux_types) == len(aux_params):
+            infer_type_success = True
+
+    return (infer_type_success, arg_types, aux_types)
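Note: the mechanism patch 01 relies on is MXNet's symbolic type inference -- seeding Symbol.infer_type with the dtype declared on the input variable propagates that dtype to every argument and auxiliary state in the graph. A minimal, self-contained sketch (the toy 'fc1' graph below is illustrative, not from the patch):

    import mxnet as mx
    import numpy as np

    # A toy graph standing in for an imported symbol.
    data = mx.sym.var('data', dtype=np.float64)
    out = mx.sym.FullyConnected(data=data, num_hidden=10, name='fc1')

    # Read the dtype declared on the input variable ...
    input_dtype = data.infer_type()[0][0]

    # ... and seed whole-graph inference with it, as _infer_param_types does.
    arg_types, _, aux_types = out.infer_type(data=input_dtype)
    print(dict(zip(out.list_arguments(), arg_types)))
    # 'fc1_weight' and 'fc1_bias' come back as numpy.float64
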
From f82344d915e79116cf69570309f1b2cb22adc80e Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 30 Aug 2018 11:12:47 -0700
Subject: [PATCH 02/13] Fix lint issues and make existing tests pass

---
 python/mxnet/gluon/block.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 43b80f39a0de..4af71a26e39f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1058,8 +1058,7 @@ def __init__(self, outputs, inputs, params=None):
         arg_params = out.list_arguments()
         aux_params = out.list_auxiliary_states()

-        infer_type_success, arg_types, aux_types = _infer_param_types(self,
-                                                                      inputs[0],
+        infer_type_success, arg_types, aux_types = _infer_param_types(inputs[0],
                                                                       out,
                                                                       arg_params,
                                                                       aux_params)
@@ -1077,10 +1076,9 @@ def __init__(self, outputs, inputs, params=None):
         else:
             # Use default types for params
             for i, arg in enumerate(arg_params):
                 if arg not in input_names:
-                    dt = inputs[0].infer_type()[0]
                     self.params.get(arg, allow_deferred_init=True)

-            for i, aux in out.list_auxiliary_states():
+            for i, aux in enumerate(aux_params):
                 if aux not in input_names:
                     self.params.get(aux, grad_req='null', allow_deferred_init=True)

@@ -1110,7 +1108,7 @@ def _clear_cached_op(self):

     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError

-def _infer_param_types(self, in_params, out_params, arg_params, aux_params):
+def _infer_param_types(in_params, out_params, arg_params, aux_params):
     """Utility function that helps in inferring DType of arg and aux params
     from given input param.

From 980c9e38ab644e6b9d8e6b9583d9d01e01c1e1eb Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 30 Aug 2018 14:00:30 -0700
Subject: [PATCH 03/13] Add tests for importing a fp64 model into symbol block

---
 tests/python/unittest/test_gluon.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index bf9f5a77c844..5dcbb9a3aca2 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -336,6 +336,24 @@ def hybrid_forward(self, F, x):
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)

+    # Test case to verify initializing the SymbolBlock from a model whose
+    # params have a dtype other than fp32.
+    # Load a resnet model, cast it to fp64 and export
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True)
+    net_fp32.cast('float64')
+    net_fp32.hybridize()
+    data = mx.nd.zeros((1,3,224,224), dtype='float64')
+    net_fp32.forward(data)
+    net_fp32.export('resnet34_fp64', 0)
+
+    # Load the saved model and verify that all the params are loaded correctly,
+    # and choose one of the params to verify the type.
+    sm = mx.sym.load('resnet34_fp64-symbol.json')
+    inputs = mx.sym.var('data', dtype='float64')
+    net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp64.collect_params().load('resnet34_fp64-0000.params')
+    assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float64)
+
 @with_seed()
 @raises(AssertionError)
 def test_sparse_symbol_block():
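Note: stripped to its core, the import recipe this test encodes -- the dtype declared on the input variable is what drives the parameter dtype inference added in patch 01 (file names below are placeholders, not from the patch):

    import mxnet as mx

    sym = mx.sym.load('model-symbol.json')        # hypothetical file names
    inputs = mx.sym.var('data', dtype='float64')  # declares the input dtype
    net = mx.gluon.SymbolBlock(sym, inputs)       # params created as fp64
    net.collect_params().load('model-0000.params')
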
From 62acade64d8ed33162f7bf927981c9a8a9717e35 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 30 Aug 2018 17:44:49 -0700
Subject: [PATCH 04/13] Fixing failing test for test symbol block

---
 tests/python/unittest/test_gluon.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 5dcbb9a3aca2..331a9900d011 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.

+import os
+import tempfile
+
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
@@ -339,19 +342,22 @@ def hybrid_forward(self, F, x):
     # Test case to verify initializing the SymbolBlock from a model whose
     # params have a dtype other than fp32.
     # Load a resnet model, cast it to fp64 and export
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'resnet34_fp64')
+
     net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True)
     net_fp32.cast('float64')
     net_fp32.hybridize()
     data = mx.nd.zeros((1,3,224,224), dtype='float64')
     net_fp32.forward(data)
-    net_fp32.export('resnet34_fp64', 0)
+    net_fp32.export(tmpfile, 0)

     # Load the saved model and verify that all the params are loaded correctly,
     # and choose one of the params to verify the type.
-    sm = mx.sym.load('resnet34_fp64-symbol.json')
+    sm = mx.sym.load(tmpfile + '-symbol.json')
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
-    net_fp64.collect_params().load('resnet34_fp64-0000.params')
+    net_fp64.collect_params().load(tmpfile + '-0000.params')
     assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float64)

 @with_seed()
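Note: tempfile.mkdtemp() keeps the test self-contained but leaves the directory behind. Python 3's context manager would also clean up after the test; MXNet still supported Python 2 at the time, which presumably ruled it out. For reference only:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:  # Python 3 only
        tmpfile = os.path.join(tmp, 'resnet34_fp64')
        # ... export and re-import the model here; tmp is removed on exit ...
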
From 0c4e45bee6b26b78219e34358d154c72e782a3a4 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 30 Aug 2018 22:19:41 -0700
Subject: [PATCH 05/13] Set context in unit tests

---
 tests/python/unittest/test_gluon.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 331a9900d011..f2d552da9bba 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -344,11 +344,12 @@ def hybrid_forward(self, F, x):
     # Load a resnet model, cast it to fp64 and export
     tmp = tempfile.mkdtemp()
     tmpfile = os.path.join(tmp, 'resnet34_fp64')
+    ctx = mx.cpu(0)

-    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True)
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx)
     net_fp32.cast('float64')
     net_fp32.hybridize()
-    data = mx.nd.zeros((1,3,224,224), dtype='float64')
+    data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
     net_fp32.forward(data)
     net_fp32.export(tmpfile, 0)
@@ -357,7 +358,7 @@ def hybrid_forward(self, F, x):
     # Load the saved model and verify that all the params are loaded correctly,
     # and choose one of the params to verify the type.
     sm = mx.sym.load(tmpfile + '-symbol.json')
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
-    net_fp64.collect_params().load(tmpfile + '-0000.params')
+    net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
     assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float64)

 @with_seed()
""" - infer_type_success = False arg_types = None aux_types = None @@ -1148,8 +1133,15 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params): if input_sym_arg_type and len(input_sym_arg_type) > 0: params = {input_sym_name:input_sym_arg_type[0]} arg_types, _, aux_types = out_params.infer_type(**params) - if arg_types is not None and len(arg_types) == len(arg_params) and \ - aux_types is not None and len(aux_types) == len(aux_params): - infer_type_success = True - return (infer_type_success, arg_types, aux_types) + if arg_types is None or len(arg_types) != len(arg_params): + arg_types = [] + for _ in arg_params: + arg_types.append(default_dtype) + + if aux_types is None or len(aux_types) != len(aux_params): + aux_types = [] + for _ in aux_params: + aux_types.append(default_dtype) + + return (arg_types, aux_types) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 24c86f4e0fa7..f53eeb00694a 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -727,6 +727,8 @@ def get(self, name, **kwargs): if matched: param._shape = tuple(inferred_shape) continue + elif k == 'dtype' and np.dtype(v) == np.dtype(existing): + continue assert v is None or v == existing, \ "Cannot retrieve Parameter '%s' because desired attribute " \ diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 69375afdfe0a..bb5ef444b5e0 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys import os +import tempfile import time import multiprocessing as mp import unittest @@ -202,6 +203,31 @@ def get_num_devices(): _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) +@with_seed() +def test_symbol_block_fp16(): + # Test case to verify if initializing the SymbolBlock from a model with params + # other than fp32 param dtype. + + # 1. Load a resnet model, cast it to fp16 and export + tmp = tempfile.mkdtemp() + tmpfile = os.path.join(tmp, 'resnet34_fp16') + ctx = mx.gpu(0) + + net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx) + net_fp32.cast('float16') + net_fp32.hybridize() + data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx) + net_fp32.forward(data) + net_fp32.export(tmpfile, 0) + + # 2. Load the saved model and verify if all the params are loaded correctly. + # and choose one of the param to verify the type if fp16. + sm = mx.sym.load(tmpfile + '-symbol.json') + inputs = mx.sym.var('data', dtype='float16') + net_fp64 = mx.gluon.SymbolBlock(sm, inputs) + net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx) + assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float16) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index f2d552da9bba..25bdce06a2d2 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -341,7 +341,8 @@ def hybrid_forward(self, F, x): # Test case to verify if initializing the SymbolBlock from a model with params # other than fp32 param dtype. - # Load a resnet model, cast it to fp64 and export + + # 1. 
From 20333475bcc63676e99eab8fe25e2a3cc7eca2f1 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Fri, 31 Aug 2018 14:21:31 -0700
Subject: [PATCH 07/13] Use tmp directory as root for loading from model zoo to
 avoid race condition

---
 tests/python/gpu/test_gluon_gpu.py  | 2 +-
 tests/python/unittest/test_gluon.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index bb5ef444b5e0..956a742b19b0 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -213,7 +213,7 @@ def test_symbol_block_fp16():
     tmpfile = os.path.join(tmp, 'resnet34_fp16')
     ctx = mx.gpu(0)

-    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx)
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
     net_fp32.cast('float16')
     net_fp32.hybridize()
     data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 25bdce06a2d2..dc9487ebb366 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -347,7 +347,7 @@ def hybrid_forward(self, F, x):
     tmpfile = os.path.join(tmp, 'resnet34_fp64')
     ctx = mx.cpu(0)

-    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx)
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
     net_fp32.cast('float64')
     net_fp32.hybridize()
     data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)

From 346632d3ecda8dcf89c49a5a21b42e7ccfb7f837 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Fri, 31 Aug 2018 14:37:22 -0700
Subject: [PATCH 08/13] Fixing naming and parameter selection in test case

---
 tests/python/gpu/test_gluon_gpu.py  | 6 +++---
 tests/python/unittest/test_gluon.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 956a742b19b0..0f090792899b 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -224,9 +224,9 @@ def test_symbol_block_fp16():
     # and choose one of the params to verify the type is fp16.
     sm = mx.sym.load(tmpfile + '-symbol.json')
     inputs = mx.sym.var('data', dtype='float16')
-    net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
-    net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float16)
+    net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+    assert (net_fp16.params[list(net_fp16.params.keys())[0]].dtype is np.float16)

 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index dc9487ebb366..ac1ba1a827f1 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -360,7 +360,7 @@ def hybrid_forward(self, F, x):
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
     net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert (net_fp64.params['resnetv20_stage1_conv2_weight'].dtype is np.float64)
+    assert (net_fp64.params[list(net_fp64.params.keys())[0]].dtype is np.float64)
From db83669c0735f147f8c52720e3d7a262edfabe16 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Fri, 31 Aug 2018 18:20:22 -0700
Subject: [PATCH 09/13] Fixing failing GPU tests

---
 tests/python/gpu/test_gluon_gpu.py  | 2 +-
 tests/python/unittest/test_gluon.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 0f090792899b..9f576614c3ec 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -226,7 +226,7 @@ def test_symbol_block_fp16():
     inputs = mx.sym.var('data', dtype='float16')
     net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
     net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert (net_fp16.params[list(net_fp16.params.keys())[0]].dtype is np.float16)
+    assert np.dtype(net_fp16.params['resnetv20_conv0_weight'].dtype) == np.dtype(np.float16)

 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index ac1ba1a827f1..fe344230e238 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -360,7 +360,7 @@ def hybrid_forward(self, F, x):
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
     net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert (net_fp64.params[list(net_fp64.params.keys())[0]].dtype is np.float64)
+    assert np.dtype(net_fp64.params['resnetv20_stage1_conv2_weight'].dtype) == np.dtype(np.float64)

 @with_seed()
 @raises(AssertionError)
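Note: patch 09 drops `is` because identity only holds when both sides are literally the same object; a Parameter's dtype may be stored as a numpy scalar type or as an np.dtype instance depending on how it was set, so only np.dtype()-normalized equality is reliable. Standalone illustration:

    import numpy as np

    print(np.dtype(np.float16) is np.float16)  # False: dtype instance vs. type object
    print(np.dtype(np.float16) == np.float16)  # True: equality coerces both sides
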
From 05287e5a25626896595f8d50f5d43bc673877281 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Tue, 4 Sep 2018 09:16:50 -0700
Subject: [PATCH 10/13] Make unit test more deterministic to get param name

---
 tests/python/gpu/test_gluon_gpu.py  | 7 ++++++-
 tests/python/unittest/test_gluon.py | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 9f576614c3ec..8394276c8ef0 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -226,7 +226,12 @@ def test_symbol_block_fp16():
     inputs = mx.sym.var('data', dtype='float16')
     net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
     net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert np.dtype(net_fp16.params['resnetv20_conv0_weight'].dtype) == np.dtype(np.float16)
+    # 3. Get a conv layer's weight parameter name. Conv layer's weight param is
+    # expected to be of the casted dtype, fp16.
+    for param_name in net_fp16.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)

 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index fe344230e238..08028e887925 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -360,7 +360,12 @@ def hybrid_forward(self, F, x):
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
     net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    assert np.dtype(net_fp64.params['resnetv20_stage1_conv2_weight'].dtype) == np.dtype(np.float64)
+    # 3. Get a conv layer's weight parameter name. Conv layer's weight param is
+    # expected to be of the casted dtype, fp64.
+    for param_name in net_fp64.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)

 @with_seed()
 @raises(AssertionError)
 def test_sparse_symbol_block():
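Note: the for/break idiom above relies on the loop variable outliving the loop, which is valid Python but easy to misread; an equivalent formulation with next() (same assumptions about the parameter naming scheme; raises StopIteration if nothing matches):

    param_name = next(name for name in net_fp64.params.keys()
                      if 'conv' in name and 'weight' in name)
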
From fb6158a5dcfdda0257372e3523cefbf4ff3599f8 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Tue, 11 Sep 2018 13:33:39 -0700
Subject: [PATCH 11/13] Override cast in symbol block, handle grouped symbol

---
 python/mxnet/gluon/block.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2d2a56229398..4478247e7244 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1092,6 +1092,10 @@ def _clear_cached_op(self):
         super(SymbolBlock, self)._clear_cached_op()
         self._cached_graph = tmp

+    def cast(self, dtype):
+        self._clear_cached_op()
+        super(SymbolBlock, self).cast(dtype)
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
@@ -1126,12 +1130,12 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_d
     # Get Input symbol details. This will be used to infer types of
     # other parameters.
-    input_sym_name = in_params.name
+    input_sym_names = in_params.list_inputs()
     input_sym_arg_type = in_params.infer_type()[0]

     # Try to infer types of other parameters.
     if input_sym_arg_type and len(input_sym_arg_type) > 0:
-        params = {input_sym_name:input_sym_arg_type[0]}
+        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_type)}
         arg_types, _, aux_types = out_params.infer_type(**params)

     if arg_types is None or len(arg_types) != len(arg_params):

From 0ff23f8a0033517bcd64bf20b941b7c75f247f1b Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Tue, 11 Sep 2018 16:17:21 -0700
Subject: [PATCH 12/13] Handle multiple symbolic input usecase

---
 python/mxnet/gluon/block.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 4478247e7244..6cb9fc690b5a 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1059,7 +1059,7 @@ def __init__(self, outputs, inputs, params=None):
         arg_params = out.list_arguments()
         aux_params = out.list_auxiliary_states()

-        arg_types, aux_types = _infer_param_types(inputs[0], out, arg_params, aux_params)
+        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)

         for i, arg in enumerate(arg_params):
             if arg not in input_names:
@@ -1105,8 +1105,8 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_d

     Parameters
     ----------
-    in_params: Symbol
-        Input symbol variable.
+    in_params: List of Symbol
+        List of input symbol variables.
     out_params: Symbol
         Output symbol variable.
     arg_params: List of Str
@@ -1130,12 +1130,23 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_d
     # Get Input symbol details. This will be used to infer types of
     # other parameters.
-    input_sym_names = in_params.list_inputs()
-    input_sym_arg_type = in_params.infer_type()[0]
+    input_sym_names = [in_param.name for in_param in in_params]
+
+    # Try to infer input types. If not successful, we will set default dtype.
+    # If successful, we will try to infer other params in the graph.
+    input_sym_arg_types = []
+    can_infer_input_type = True
+    for in_param in in_params:
+        input_sym_arg_type = in_param.infer_type()[0]
+        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
+            can_infer_input_type = False
+            break
+        else:
+            input_sym_arg_types.append(in_param.infer_type()[0][0])

     # Try to infer types of other parameters.
-    if input_sym_arg_type and len(input_sym_arg_type) > 0:
-        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_type)}
+    if can_infer_input_type:
+        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
         arg_types, _, aux_types = out_params.infer_type(**params)

     if arg_types is None or len(arg_types) != len(arg_params):
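Note: patch 12 generalizes the seeding dict to every input of a grouped symbol. The same pattern in isolation, with two illustrative inputs:

    import mxnet as mx
    import numpy as np

    a = mx.sym.var('a', dtype=np.float16)
    b = mx.sym.var('b', dtype=np.float16)
    out = mx.sym.broadcast_add(a, b)

    # Seed infer_type with every input's declared dtype at once.
    seed = {s.name: s.infer_type()[0][0] for s in (a, b)}
    arg_types, _, aux_types = out.infer_type(**seed)
    print(dict(zip(out.list_arguments(), arg_types)))  # both inferred as float16
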
From b58ea417e1123cc86ce55306fb336a75d07c8108 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Tue, 11 Sep 2018 20:02:45 -0700
Subject: [PATCH 13/13] Add tests to verify behavior of SymbolBlock.cast

---
 tests/python/unittest/test_gluon.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 08028e887925..796182e2b735 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -367,6 +367,13 @@ def hybrid_forward(self, F, x):
             break
     assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)

+    # Cast the symbol block to FP32 and try a forward pass with FP32 data.
+    # This will verify SymbolBlock.cast() functionality.
+    net_fp64.cast('float32')
+    fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
+    prediction = net_fp64.forward(fp32_data)
+    assert np.dtype(prediction.dtype) == np.dtype(np.float32)
+
 @with_seed()
 @raises(AssertionError)
 def test_sparse_symbol_block():
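Note: cast() must clear the cached op because a hybridized block caches a compiled graph whose dtypes were fixed when it was first built; casting the parameters without discarding it would leave the cached graph expecting the old types. Typical end-to-end usage after this series, sketched with placeholder file names:

    import mxnet as mx

    sym = mx.sym.load('model-symbol.json')  # hypothetical files
    net = mx.gluon.SymbolBlock(sym, mx.sym.var('data', dtype='float64'))
    net.collect_params().load('model-0000.params', ctx=mx.cpu())

    net.cast('float32')  # clears the cached op; rebuilt with fp32 on next call
    out = net(mx.nd.zeros((1, 3, 224, 224), dtype='float32'))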