From 67445a47d0aedf766fa83fc2a90f987f57bc6140 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sat, 13 Jul 2019 02:51:12 +0000 Subject: [PATCH 1/5] Add AMP Conversion support for bucketing module --- python/mxnet/contrib/amp/amp.py | 59 ++++++++ python/mxnet/module/bucketing_module.py | 178 +++++++++++++++++++++++- python/mxnet/module/module.py | 2 +- tests/python/gpu/test_contrib_amp.py | 35 +++++ tests/python/train/test_bucketing.py | 10 +- 5 files changed, 274 insertions(+), 10 deletions(-) diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py index ef2f7209d946..e7007d6001e8 100755 --- a/python/mxnet/contrib/amp/amp.py +++ b/python/mxnet/contrib/amp/amp.py @@ -32,6 +32,7 @@ from ... import symbol from ...context import gpu from ...symbol import Symbol +from ...module import BucketingModule from ...symbol import contrib as symbol_contrib from ... import ndarray from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP @@ -672,6 +673,64 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None, ret.collect_params().load_dict(arg_dict, ctx=ctx) return ret +def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None, + fp32_ops=None, conditional_fp32_ops=None, + excluded_sym_names=None, cast_optional_params=False): + """Given a bucketing module cast the symbols associated with the BucketingModule + and params if cast_optional_params is set. + bucketing_mod : BucketingModule instance + target_dtype : str + Currently only supports float16. The target dtype indicates to add cast layers + when possible so that lower precision computation can be leveraged. + target_dtype_ops : list of strs + Override the list of operator names casted to target_dtype. + If None, uses the framework's default list to be casted to target dtype. + fp32_ops : list of strs + Override the lists of operator names casted to FP32. + If None, uses the framework's default list to be casted to FP32. + widest_dtype_ops : list of strs + A list of op names provided by user which should run in widest precision among its inputs. + If None, uses the framework's default list of widest_precision_ops. + conditional_fp32_ops : list of (string, string, list of string) + Override the list of operators to be casted to FP32. + The format of the list is + (name of the function, name of the parameter, + list of values of the parameter that make the operator to be casted to + fp32) + excluded_sym_names : list of strs + A list of strings that represent the names of symbols that users want to exclude + from being executed in lower precision. + cast_optional_params : bool, default False + Whether to cast the arg_params and aux_params that don't require to be in FP16 + because of a cast layer following it, but will reduce the computation and memory + overhead of the model if casted. 
+ """ + assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module" + assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty" + + sym_dict = {} + assert bucketing_mod.params_initialized, \ + "bucketing_mod params should be initialized for mixed precision conversion" + arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params + for key, val in bucketing_mod._buckets.items(): + sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol, + arg_params, + aux_params, + cast_optional_params=True) + result_mod = BucketingModule.load_dict(sym_dict, + sym_gen=bucketing_mod._sym_gen, + arg_params=result_arg_params, + aux_params=result_aux_params, + default_bucket_key=bucketing_mod._default_bucket_key, + logger=bucketing_mod.logger, + context=bucketing_mod._context, + work_load_list=bucketing_mod._work_load_list, + fixed_param_names=bucketing_mod._fixed_param_names, + state_names=bucketing_mod._state_names, + group2ctxs=bucketing_mod._group2ctxs, + compression_params=bucketing_mod._compression_params) + return result_mod + def list_fp16_ops(): """Get the default list of FP16 ops for AMP """ diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 66c666659d0b..17ea994a544a 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -28,6 +28,8 @@ from .. import context as ctx from ..initializer import Uniform +from .. import ndarray as nd +from .. import symbol as sym from .base_module import BaseModule, _check_input_names from .module import Module @@ -170,7 +172,7 @@ def get_params(self): `(arg_params, aux_params)` A pair of dictionaries each mapping parameter names to NDArray values. 
""" - assert self.binded and self.params_initialized + assert self.params_initialized self._curr_module._params_dirty = self._params_dirty params = self._curr_module.get_params() self._params_dirty = False @@ -335,12 +337,16 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, self._grad_req = grad_req symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key) - module = Module(symbol, data_names, label_names, logger=self.logger, - context=self._context, work_load_list=self._work_load_list, - fixed_param_names=self._fixed_param_names, - state_names=self._state_names, - group2ctxs=self._group2ctxs, - compression_params=self._compression_params) + module = None + if not self._default_bucket_key in self._buckets: + module = Module(symbol, data_names, label_names, logger=self.logger, + context=self._context, work_load_list=self._work_load_list, + fixed_param_names=self._fixed_param_names, + state_names=self._state_names, + group2ctxs=self._group2ctxs, + compression_params=self._compression_params) + else: + module = self._buckets[self._default_bucket_key] module.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind=False, shared_module=None, grad_req=self._grad_req) self._curr_module = module @@ -380,6 +386,13 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): if self._monitor is not None: module.install_monitor(self._monitor) self._buckets[bucket_key] = module + else: + module = self._buckets[bucket_key] + if not module.binded: + module.bind(data_shapes, label_shapes, self._curr_module.for_training, + self._curr_module.inputs_need_grad, + force_rebind=False, shared_module=self._buckets[self._default_bucket_key], + grad_req=self._grad_req) self._curr_module = self._buckets[bucket_key] self._curr_bucket_key = bucket_key @@ -544,3 +557,154 @@ def install_monitor(self, mon): self._monitor = mon for mod in self._buckets.values(): mod.install_monitor(mon) + + def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): + """Saves current progress to checkpoint for all buckets in BucketingModule + Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training. + + Parameters + ---------- + prefix : str + The file prefix to checkpoint to. + epoch : int + The current epoch number. + save_optimizer_states : bool + Whether to save optimizer states to continue training. + """ + + assert len(self._buckets) > 0, "Empty BucketingModule cannot be saved" + param_name = "%s-%04d.params" % (prefix, epoch) + self.save_params(param_name) + for buckey_key, module in self._buckets.items(): + symbol, data_names, label_names = self._sym_gen(bucket_key) + symbol.save("%s-%s-symbol.json" % (prefix, epoch)) + if save_optimizer_states: + state_name = "%s-%04d.states" % (prefix, epoch) + module.save_optimizer_states(state_name) + nd.save("%s.buckets" % (prefix), nd.array(self._buckets.keys(), dtype=np.int32)) + + @staticmethod + def load(prefix, epoch, load_optimizer_states=False, sym_gen=None, default_bucket_key=None, **kwargs): + """Creates a model from previously saved checkpoint. + + Parameters + ---------- + prefix : str + path prefix of saved model files. You should have + "prefix-symbol.json", "prefix-xxxx.params", and + optionally "prefix-xxxx.states", where xxxx is the + epoch number. + epoch : int + epoch to load. + load_optimizer_states : bool + whether to load optimizer states. Checkpoint needs + to have been made with save_optimizer_states=True. 
+ sym_gen : function + A function when called with a bucket key, returns a triple + ``(symbol, data_names, label_names)``. + provide sym_gen which was used when saving bucketing module. + logger : Logger + Default is `logging`. + context : Context or list of Context + Default is ``cpu()``. + work_load_list : list of number + Default ``None``, indicating uniform workload. + fixed_param_names: list of str + Default ``None``, indicating no network parameters are fixed. + state_names : list of str + States are similar to data and label, but not provided by data iterator. + Instead they are initialized to 0 and can be set by set_states() + group2ctxs : dict of str to context or list of context, + or list of dict of str to context + Default is `None`. Mapping the `ctx_group` attribute to the context assignment. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. + Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. + """ + assert sym_gen is not None, \ + "sym_gen is required for loading BucketingModule" + assert default_bucket_key is not None, \ + "default_bucket_key is required for loading BucketingModule" + buckets = nd.load("%s.buckets" % prefix) + buckets = list(buckets[0].asnumpy().astype('int32')) + bucketing_mod = BucketingModule(sym_gen, default_bucket_key, **kwargs) + for bucket_key in buckets: + _, data_names, label_names = sym_gen(bucket_key) + symbol = sym.load("%s-%s-symbol.json" % (prefix, bucket_key)) + bucketing_mod._buckets[bucket_key] = Module(symbol, data_names, label_names, **kwargs) + if bucket_key == default_bucket_key: + bucketing_mod._curr_module = bucketing_mod._buckets[bucket_key] + arg_params, aux_params = load_params(prefix, epoch) + bucketing_mod._curr_module._arg_params = arg_params + bucketing_mod._curr_module._aux_params = aux_params + bucketing_mod._curr_module.params_initialized = True + bucketing_mod.params_initialized = True + if load_optimizer_states: + bucketing_mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch) + return bucketing_mod + + @staticmethod + def load_dict(sym_dict=None, sym_gen=None, default_bucket_key=None, arg_params=None, + aux_params=None, **kwargs): + """Creates a model from a dict mapping bucket_key to symbols and shared arg_params + and aux_params. + + Parameters + ---------- + sym_dict : dict mapping bucket_key to symbol + Dict mapping bucket key to symbol + sym_gen : function + A function when called with a bucket key, returns a triple + ``(symbol, data_names, label_names)``. + provide sym_gen which was used when saving bucketing module. + default_bucket_key : str (or any python object) + The key for the default bucket. + arg_params : dict + Required for loading the BucketingModule. + Dict of name to parameter ndarrays. + aux_params : dict + Required for loading the BucketingModule. + Dict of name to auxiliary state ndarrays. + logger : Logger + Default is `logging`. + context : Context or list of Context + Default is ``cpu()``. + work_load_list : list of number + Default ``None``, indicating uniform workload. + fixed_param_names: list of str + Default ``None``, indicating no network parameters are fixed. + state_names : list of str + States are similar to data and label, but not provided by data iterator. 
+ Instead they are initialized to 0 and can be set by set_states() + group2ctxs : dict of str to context or list of context, + or list of dict of str to context + Default is `None`. Mapping the `ctx_group` attribute to the context assignment. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. + Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. + """ + + assert sym_dict is not None, \ + "sym_dict needs to be provided for BucketingModule.load_dict" + assert arg_params is not None, \ + "arg_params need to be provided for BucketingModule.load_dict" + assert aux_params is not None, \ + "aux_params need to be provided for BucketingModule.load_dict" + assert default_bucket_key is not None, \ + "default_bucket_key needs to be provided for BucketingModule.load_dict" + + bucketing_mod = BucketingModule(sym_gen, default_bucket_key, **kwargs) + for bucket_key, sym in sym_dict.items(): + _, data_names, label_names = sym_gen(default_bucket_key) + bucketing_mod._buckets[bucket_key] = Module(sym, data_names, label_names, **kwargs) + if bucket_key == default_bucket_key: + bucketing_mod._curr_module = bucketing_mod._buckets[bucket_key] + bucketing_mod._curr_module._arg_params = arg_params + bucketing_mod._curr_module._aux_params = aux_params + bucketing_mod._curr_module.params_initialized = True + bucketing_mod.params_initialized = True + return bucketing_mod diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index c1867282e215..3ba141e94f62 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -250,7 +250,7 @@ def get_params(self): `(arg_params, aux_params)` A pair of dictionaries each mapping parameter names to NDArray values. 
""" - assert self.binded and self.params_initialized + assert self.params_initialized if self._params_dirty: self._sync_params_from_devices() diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py index 7927cc99160b..29ca32c9f964 100644 --- a/tests/python/gpu/test_contrib_amp.py +++ b/tests/python/gpu/test_contrib_amp.py @@ -19,6 +19,7 @@ import sys import mxnet as mx import numpy as np +from random import randint import warnings import collections import ctypes @@ -31,6 +32,8 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import with_seed, teardown +sys.path.insert(0, os.path.join(curr_path, '../train')) +from test_bucketing import train_model def test_amp_coverage(): conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS] @@ -300,10 +303,42 @@ def check_amp_convert_hybrid_block(): params = converted_model.collect_params() assert params["stage2_unit1_conv2_weight"].dtype == np.float16 + + def check_amp_convert_bucketing_module(): + model = train_model(context=mx.current_context()) + result_model = amp.convert_bucketing_module(model) + val_sent = [] + batch_size = 128 + invalid_label = -1 + num_sentence = 1000 + buckets = [5, 10, 20, 30, 40] + len_vocab = 50 + + for _ in range(num_sentence): + len_sentence = randint(6, max(buckets)-1) # leave out the two last buckets empty + val_sentence = [] + for _ in range(len_sentence): + val_sentence.append(randint(1, len_vocab)) + val_sent.append(val_sentence) + + data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False) + result_model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(batch_size, 1)) + + # AMP conversion with cast_optional_params set to true + result_model = amp.convert_bucketing_module(model, cast_optional_params=True) + result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False) + result_model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(batch_size, 1)) + + with mx.Context(mx.gpu(0)): check_amp_convert_symbol() check_amp_convert_model() check_amp_convert_hybrid_block() + check_amp_convert_bucketing_module() @with_seed() diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py index 882c4a4a513d..225440eac8b5 100644 --- a/tests/python/train/test_bucketing.py +++ b/tests/python/train/test_bucketing.py @@ -20,9 +20,10 @@ import mxnet as mx import random from random import randint +from mxnet.contrib.amp import amp -def test_bucket_module(): +def train_model(context=mx.cpu()): import logging head = '%(asctime)-15s %(message)s' logging.basicConfig(level=logging.DEBUG, format=head) @@ -80,7 +81,7 @@ def sym_gen(seq_len): return loss, ('data',), ('softmax_label',) - contexts = mx.cpu(0) + contexts = context model = mx.mod.BucketingModule( sym_gen=sym_gen, @@ -101,9 +102,14 @@ def sym_gen(seq_len): num_epoch=num_epochs, batch_end_callback=mx.callback.Speedometer(batch_size, 50)) logging.info('Finished fit...') + return model + + +def test_bucket_module(): # This test forecasts random sequence of words to check bucketing. # We cannot guarantee the accuracy of such an impossible task, and comments out the following line. 
# assert model.score(data_val, mx.metric.MSE())[0][1] < 350, "High mean square error." + model = train_model() if __name__ == "__main__": From d54ee9315aea510acbc528b28c6294e9b004330f Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 15 Jul 2019 19:09:29 +0000 Subject: [PATCH 2/5] Add bucketing module changes --- python/mxnet/contrib/amp/amp.py | 7 ++++++- python/mxnet/module/bucketing_module.py | 12 +++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py index e7007d6001e8..746a9a7f6d68 100755 --- a/python/mxnet/contrib/amp/amp.py +++ b/python/mxnet/contrib/amp/amp.py @@ -716,7 +716,12 @@ def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol, arg_params, aux_params, - cast_optional_params=True) + target_dtype=target_dtype, + target_dtype_ops=target_dtype_ops, + fp32_ops=fp32_ops, + conditional_fp32_ops=conditional_fp32_ops, + excluded_sym_names=excluded_sym_names, + cast_optional_params=cast_optional_params) result_mod = BucketingModule.load_dict(sym_dict, sym_gen=bucketing_mod._sym_gen, arg_params=result_arg_params, diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 17ea994a544a..bef361aad996 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -24,6 +24,7 @@ import logging import warnings +import numpy as np from .. import context as ctx @@ -33,6 +34,7 @@ from .base_module import BaseModule, _check_input_names from .module import Module +from ..model import load_params from ..name import NameManager class BucketingModule(BaseModule): @@ -575,12 +577,12 @@ def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): assert len(self._buckets) > 0, "Empty BucketingModule cannot be saved" param_name = "%s-%04d.params" % (prefix, epoch) self.save_params(param_name) - for buckey_key, module in self._buckets.items(): - symbol, data_names, label_names = self._sym_gen(bucket_key) + for bucket_key in self._buckets: + symbol, _, _ = self._sym_gen(bucket_key) symbol.save("%s-%s-symbol.json" % (prefix, epoch)) if save_optimizer_states: state_name = "%s-%04d.states" % (prefix, epoch) - module.save_optimizer_states(state_name) + self._curr_module.save_optimizer_states(state_name) nd.save("%s.buckets" % (prefix), nd.array(self._buckets.keys(), dtype=np.int32)) @staticmethod @@ -698,9 +700,9 @@ def load_dict(sym_dict=None, sym_gen=None, default_bucket_key=None, arg_params=N "default_bucket_key needs to be provided for BucketingModule.load_dict" bucketing_mod = BucketingModule(sym_gen, default_bucket_key, **kwargs) - for bucket_key, sym in sym_dict.items(): + for bucket_key, loaded_sym in sym_dict.items(): _, data_names, label_names = sym_gen(default_bucket_key) - bucketing_mod._buckets[bucket_key] = Module(sym, data_names, label_names, **kwargs) + bucketing_mod._buckets[bucket_key] = Module(loaded_sym, data_names, label_names, **kwargs) if bucket_key == default_bucket_key: bucketing_mod._curr_module = bucketing_mod._buckets[bucket_key] bucketing_mod._curr_module._arg_params = arg_params From da3a9f96b33c7cdb20cfe86a17447deed0a05ce4 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 15 Jul 2019 20:44:53 +0000 Subject: [PATCH 3/5] Add load_aprams --- python/mxnet/model.py | 24 +++++---- python/mxnet/module/bucketing_module.py | 16 ++---- tests/python/train/test_bucketing.py | 40 
++++++++------ tests/python/unittest/test_module.py | 71 +++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 39 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index aee4a8ce2b45..94487bcf5fee 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -423,6 +423,20 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params, remove_amp_ca logging.info('Saved checkpoint to \"%s\"', param_name) +def load_params(prefix, epoch): + """Load params from a file + """ + save_dict = nd.load("%s-%04d.params" % (prefix, epoch)) + arg_params = {} + aux_params = {} + for k, v in save_dict.items(): + tp, name = k.split(":", 1) + if tp == "arg": + arg_params[name] = v + if tp == "aux": + aux_params[name] = v + return (arg_params, aux_params) + def load_checkpoint(prefix, epoch): """Load model checkpoint from file. @@ -448,15 +462,7 @@ def load_checkpoint(prefix, epoch): - Parameters will be loaded from ``prefix-epoch.params``. """ symbol = sym.load('%s-symbol.json' % prefix) - save_dict = nd.load('%s-%04d.params' % (prefix, epoch)) - arg_params = {} - aux_params = {} - for k, v in save_dict.items(): - tp, name = k.split(':', 1) - if tp == 'arg': - arg_params[name] = v - if tp == 'aux': - aux_params[name] = v + arg_params, aux_params = load_params(prefix, epoch) return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index bef361aad996..a23f470a30d5 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -560,7 +560,7 @@ def install_monitor(self, mon): for mod in self._buckets.values(): mod.install_monitor(mon) - def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): + def save_checkpoint(self, prefix, epoch, remove_amp_cast=False): """Saves current progress to checkpoint for all buckets in BucketingModule Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training. @@ -570,8 +570,6 @@ def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): The file prefix to checkpoint to. epoch : int The current epoch number. - save_optimizer_states : bool - Whether to save optimizer states to continue training. """ assert len(self._buckets) > 0, "Empty BucketingModule cannot be saved" @@ -579,11 +577,8 @@ def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): self.save_params(param_name) for bucket_key in self._buckets: symbol, _, _ = self._sym_gen(bucket_key) - symbol.save("%s-%s-symbol.json" % (prefix, epoch)) - if save_optimizer_states: - state_name = "%s-%04d.states" % (prefix, epoch) - self._curr_module.save_optimizer_states(state_name) - nd.save("%s.buckets" % (prefix), nd.array(self._buckets.keys(), dtype=np.int32)) + symbol.save("%s-%s-symbol.json" % (prefix, bucket_key), remove_amp_cast=remove_amp_cast) + nd.save("%s.buckets" % (prefix), nd.array(list(self._buckets.keys()), dtype=np.int32)) @staticmethod def load(prefix, epoch, load_optimizer_states=False, sym_gen=None, default_bucket_key=None, **kwargs): @@ -598,9 +593,6 @@ def load(prefix, epoch, load_optimizer_states=False, sym_gen=None, default_bucke epoch number. epoch : int epoch to load. - load_optimizer_states : bool - whether to load optimizer states. Checkpoint needs - to have been made with save_optimizer_states=True. 
sym_gen : function A function when called with a bucket key, returns a triple ``(symbol, data_names, label_names)``. @@ -643,8 +635,6 @@ def load(prefix, epoch, load_optimizer_states=False, sym_gen=None, default_bucke bucketing_mod._curr_module._aux_params = aux_params bucketing_mod._curr_module.params_initialized = True bucketing_mod.params_initialized = True - if load_optimizer_states: - bucketing_mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch) return bucketing_mod @staticmethod diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py index 225440eac8b5..a233e46e0992 100644 --- a/tests/python/train/test_bucketing.py +++ b/tests/python/train/test_bucketing.py @@ -23,6 +23,28 @@ from mxnet.contrib.amp import amp +def prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence): + train_sent = [] + val_sent = [] + + for _ in range(num_sentence): + len_sentence = randint(6, max(buckets)-1) # leave out the two last buckets empty + train_sentence = [] + val_sentence = [] + for _ in range(len_sentence): + train_sentence.append(randint(1, len_vocab)) + val_sentence.append(randint(1, len_vocab)) + train_sent.append(train_sentence) + val_sent.append(val_sentence) + + data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets, + invalid_label=invalid_label) + + return (data_train, data_val) + + def train_model(context=mx.cpu()): import logging head = '%(asctime)-15s %(message)s' @@ -42,23 +64,7 @@ def train_model(context=mx.cpu()): invalid_label = -1 num_sentence = 1000 - train_sent = [] - val_sent = [] - - for _ in range(num_sentence): - len_sentence = randint(6, max(buckets)-1) # leave out the two last buckets empty - train_sentence = [] - val_sentence = [] - for _ in range(len_sentence): - train_sentence.append(randint(1, len_vocab)) - val_sentence.append(randint(1, len_vocab)) - train_sent.append(train_sentence) - val_sent.append(val_sentence) - - data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, - invalid_label=invalid_label) - data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets, - invalid_label=invalid_label) + data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence) stack = mx.rnn.SequentialRNNCell() for i in range(num_layers): diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index c82afdfe033a..b82933126d67 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import os import mxnet as mx import mxnet.ndarray as nd from mxnet.test_utils import * @@ -23,6 +24,9 @@ from mxnet.module.executor_group import DataParallelExecutorGroup from common import setup_module, with_seed, assertRaises, teardown from collections import namedtuple +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, "../train")) +from test_bucketing import train_model, prepare_bucketing_data @with_seed() @@ -216,6 +220,73 @@ def dict_equ(a, b): os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) +@with_seed() +def test_bucketing_save_load(): + previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1") + os.putenv('MXNET_UPDATE_ON_KVSTORE', '1') + def dict_equ(a, b): + assert set(a) == set(b) + for k in a: + assert (a[k].asnumpy() == b[k].asnumpy()).all() + + + len_vocab = 50 + num_embed = 25 + num_epochs = 5 + batch_size = 128 + num_layers = 2 + num_hidden = 25 + buckets = [5, 10, 20, 30, 40] + invalid_label = -1 + num_sentence=1000 + + stack = mx.rnn.SequentialRNNCell() + for i in range(num_layers): + stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i)) + + def sym_gen(seq_len): + data = mx.sym.Variable('data') + label = mx.sym.Variable('softmax_label') + embed = mx.sym.Embedding(data=data, input_dim=len_vocab, + output_dim=num_embed, name='embed') + stack.reset() + outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True) + + pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden)) + pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred') + + label = mx.sym.Reshape(label, shape=(-1,)) + loss = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax') + + return loss, ('data',), ('softmax_label',) + + model = train_model(context=mx.current_context()) + model.save_checkpoint("test", 0) + data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence) + mod2 = mx.mod.BucketingModule.load('test', 0, sym_gen=sym_gen, + default_bucket_key=data_train.default_bucket_key) + + mod2.bind(data_shapes=data_train.provide_data, + label_shapes=data_train.provide_label) + + for bucket_key in model._buckets.keys(): + dict_equ(model._buckets[model._default_bucket_key].get_params()[0], + mod2._buckets[mod2._default_bucket_key].get_params()[0]) + mod2.fit( + train_data=data_train, + eval_data=data_val, + eval_metric=mx.metric.Perplexity(invalid_label), # Use Perplexity for multiclass classification. 
+ kvstore='device', + optimizer='sgd', + optimizer_params={'learning_rate': 0.01, + 'momentum': 0, + 'wd': 0.00001}, + initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), + num_epoch=num_epochs, + batch_end_callback=mx.callback.Speedometer(batch_size, 50)) + os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) + + @with_seed() def test_module_reshape(): data = mx.sym.Variable('data') From 919879d207884f6c04a7fba5a101aece811f5399 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sat, 20 Jul 2019 01:05:10 +0000 Subject: [PATCH 4/5] Add bucketing module conversion --- docs/tutorials/amp/amp_tutorial.md | 2 +- example/rnn/bucketing/README.md | 6 ++++++ example/rnn/bucketing/cudnn_rnn_bucketing.py | 19 ++++++++++++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/tutorials/amp/amp_tutorial.md b/docs/tutorials/amp/amp_tutorial.md index 9da0505e9ff6..2b747c6c82f6 100644 --- a/docs/tutorials/amp/amp_tutorial.md +++ b/docs/tutorials/amp/amp_tutorial.md @@ -258,6 +258,7 @@ To do inference with mixed precision for a trained model in FP32, you can use th Below, we demonstrate for a gluon model and a symbolic model: - Conversion from FP32 model to mixed precision model. - Run inference on the mixed precision model. +- For AMP conversion of bucketing module please refer to [example/rnn/bucketing/README.md](https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/README.md). ```python with mx.Context(mx.gpu(0)): @@ -336,7 +337,6 @@ with mx.Context(mx.gpu(0)): mod.save_checkpoint("amp_tutorial_model", 0, remove_amp_cast=False) ``` - ## Current limitations of AMP - AMP's dynamic loss scaling currently supports only Gluon trainer with `update_on_kvstore=False` option set diff --git a/example/rnn/bucketing/README.md b/example/rnn/bucketing/README.md index 707370af5a96..d44b23e69b23 100644 --- a/example/rnn/bucketing/README.md +++ b/example/rnn/bucketing/README.md @@ -55,6 +55,12 @@ You can check this improved [Gluon implementation](http://gluon-nlp.mxnet.io/mod $ python3 [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0,1,2,3 +- To run the mixed precision inference for the trained model, you should use the `--dtype`. + + This uses AMP conversion API for bucketing module to convert to a mixed precision module. 
+ + $ python [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0 --model-prefix saved_rnn_model --load-epoch 12 --test --dtype float16 + ### Performance Note: diff --git a/example/rnn/bucketing/cudnn_rnn_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py index 66d5a55c02cb..38275ae3dfb8 100644 --- a/example/rnn/bucketing/cudnn_rnn_bucketing.py +++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py @@ -18,6 +18,7 @@ import numpy as np import mxnet as mx import argparse +from mxnet.contrib.amp import amp parser = argparse.ArgumentParser(description="Train RNN on Sherlock Holmes", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -67,6 +68,10 @@ help='dropout probability (1.0 - keep probability)') parser.add_argument('--rnntype', type=str, default='lstm', help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported') +parser.add_argument('--dtype', type=str, default='float32', + help='if float16 is provided AMP convert model' + 'is used to convert model to mixed precision model' + 'before running inference') #buckets = [32] buckets = [10, 20, 30, 40, 50, 60] @@ -234,12 +239,20 @@ def sym_gen(seq_len): context = contexts) model.bind(data_val.provide_data, data_val.provide_label, for_training=False) - # note here we load using SequentialRNNCell instead of FusedRNNCell. _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch) model.set_params(arg_params, aux_params) - model.score(data_val, mx.metric.Perplexity(invalid_label), - batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) + if args.dtype == "float32": + model.set_params(arg_params, aux_params) + model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) + else: + assert args.dtype == "float16", "Only float32 and float16 are supported currently" + model = amp.convert_bucketing_module(model, target_dtype="float16") + model.bind(data_val.provide_data, data_val.provide_label, + for_training=False) + model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) if __name__ == '__main__': import logging From 306e2dff0475f79be815e2a25778beaf268848dc Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 22 Jul 2019 06:10:38 +0000 Subject: [PATCH 5/5] Fix unused arg --- python/mxnet/module/bucketing_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index a23f470a30d5..dcf2ad7b8e1e 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -581,7 +581,7 @@ def save_checkpoint(self, prefix, epoch, remove_amp_cast=False): nd.save("%s.buckets" % (prefix), nd.array(list(self._buckets.keys()), dtype=np.int32)) @staticmethod - def load(prefix, epoch, load_optimizer_states=False, sym_gen=None, default_bucket_key=None, **kwargs): + def load(prefix, epoch, sym_gen=None, default_bucket_key=None, **kwargs): """Creates a model from previously saved checkpoint. Parameters
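
For reference, below is a minimal end-to-end sketch of the workflow this patch series enables: train a small FP32 BucketingModule, convert it to mixed precision with the new `amp.convert_bucketing_module` API, and score it on a validation iterator. It is not part of the patches themselves; it assumes the helpers `train_model` and `prepare_bucketing_data` added to `tests/python/train/test_bucketing.py` are importable via the same relative-path trick used in `tests/python/gpu/test_contrib_amp.py`, and a GPU context is available.

```python
import os
import sys

import mxnet as mx
from mxnet.contrib.amp import amp

# Reuse the test helpers added by this patch series. The relative path assumes
# this script lives next to the GPU tests (as tests/python/gpu/test_contrib_amp.py does).
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../train'))
from test_bucketing import train_model, prepare_bucketing_data

with mx.Context(mx.gpu(0)):
    # Train a small FP32 BucketingModule (same helper the new AMP test uses).
    model = train_model(context=mx.current_context())

    # Cast the symbol of every bucket (and optionally the parameters) to float16.
    fp16_model = amp.convert_bucketing_module(model, target_dtype="float16",
                                              cast_optional_params=True)

    # Bind the converted module for inference and score it like any other
    # BucketingModule.
    invalid_label = -1
    _, data_val = prepare_bucketing_data(buckets=[5, 10, 20, 30, 40],
                                         len_vocab=50,
                                         batch_size=128,
                                         invalid_label=invalid_label,
                                         num_sentence=1000)
    fp16_model.bind(data_val.provide_data, data_val.provide_label,
                    for_training=False)
    fp16_model.score(data_val, mx.metric.Perplexity(invalid_label))
```

The trained module can also be checkpointed across all buckets with the new `BucketingModule.save_checkpoint(prefix, epoch)` and restored with `BucketingModule.load(prefix, epoch, sym_gen=..., default_bucket_key=...)`, passing the same `sym_gen` used during training, as exercised by `test_bucketing_save_load` in `tests/python/unittest/test_module.py`.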