This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add AMP Conversion support for BucketingModule #15528

Merged
6 commits merged on Aug 26, 2019
2 changes: 1 addition & 1 deletion docs/tutorials/amp/amp_tutorial.md
@@ -258,6 +258,7 @@ To do inference with mixed precision for a trained model in FP32, you can use th
Below, we demonstrate this for a Gluon model and a symbolic model:
- Conversion from FP32 model to mixed precision model.
- Run inference on the mixed precision model.
- For AMP conversion of a BucketingModule, refer to [example/rnn/bucketing/README.md](https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/README.md); a minimal sketch is shown below.
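
A minimal sketch of the BucketingModule conversion (assuming `model` is a trained, parameter-initialized `BucketingModule` and `data_val` is a validation data iterator; these names are placeholders):

```python
import mxnet as mx
from mxnet.contrib.amp import amp

# Convert the FP32 BucketingModule to a mixed precision module.
model = amp.convert_bucketing_module(model, target_dtype="float16")

# Re-bind the converted module for inference and evaluate as usual.
model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
model.score(data_val, mx.metric.Perplexity(None))
```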

```python
with mx.Context(mx.gpu(0)):
@@ -336,7 +337,6 @@ with mx.Context(mx.gpu(0)):
mod.save_checkpoint("amp_tutorial_model", 0, remove_amp_cast=False)
```


## Current limitations of AMP

- AMP's dynamic loss scaling currently supports only Gluon trainer with `update_on_kvstore=False` option set
6 changes: 6 additions & 0 deletions example/rnn/bucketing/README.md
@@ -55,6 +55,12 @@ You can check this improved [Gluon implementation](http://gluon-nlp.mxnet.io/mod

$ python3 [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0,1,2,3

- To run mixed precision inference with the trained model, use the `--dtype` option.

This uses the AMP conversion API for BucketingModule to convert the trained FP32 model to a mixed precision module; a sketch of the underlying calls follows the command below.

$ python [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0 --model-prefix saved_rnn_model --load-epoch 12 --test --dtype float16
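
Roughly, passing `--dtype float16` makes the script convert its bound BucketingModule with AMP before scoring (a sketch based on `cudnn_rnn_bucketing.py`, where `model` and `data_val` are the script's module and validation iterator):

        from mxnet.contrib.amp import amp

        # Convert the FP32 BucketingModule, re-bind it for inference, and score.
        model = amp.convert_bucketing_module(model, target_dtype="float16")
        model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
        model.score(data_val, mx.metric.Perplexity(invalid_label))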


### Performance Note:

19 changes: 16 additions & 3 deletions example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -18,6 +18,7 @@
import numpy as np
import mxnet as mx
import argparse
from mxnet.contrib.amp import amp

parser = argparse.ArgumentParser(description="Train RNN on Sherlock Holmes",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -67,6 +68,10 @@
help='dropout probability (1.0 - keep probability)')
parser.add_argument('--rnntype', type=str, default='lstm',
help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')
parser.add_argument('--dtype', type=str, default='float32',
                    help='if float16 is provided, the model is converted to '
                         'a mixed precision model with the AMP conversion API '
                         'before running inference')

#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]
@@ -234,12 +239,20 @@ def sym_gen(seq_len):
context = contexts)
model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

# note here we load using SequentialRNNCell instead of FusedRNNCell.
_, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
model.set_params(arg_params, aux_params)

model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))
if args.dtype == "float32":
model.set_params(arg_params, aux_params)
model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))
else:
assert args.dtype == "float16", "Only float32 and float16 are supported currently"
model = amp.convert_bucketing_module(model, target_dtype="float16")
model.bind(data_val.provide_data, data_val.provide_label,
for_training=False)
model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

if __name__ == '__main__':
import logging
64 changes: 64 additions & 0 deletions python/mxnet/contrib/amp/amp.py
@@ -32,6 +32,7 @@
from ... import symbol
from ...context import gpu
from ...symbol import Symbol
from ...module import BucketingModule
from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
@@ -672,6 +673,69 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
ret.collect_params().load_dict(arg_dict, ctx=ctx)
return ret

def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None,
excluded_sym_names=None, cast_optional_params=False):
"""Given a bucketing module cast the symbols associated with the BucketingModule
and params if cast_optional_params is set.
bucketing_mod : BucketingModule instance
target_dtype : str
Currently only supports float16. The target dtype indicates to add cast layers
when possible so that lower precision computation can be leveraged.
target_dtype_ops : list of strs
Override the list of operator names casted to target_dtype.
If None, uses the framework's default list to be casted to target dtype.
fp32_ops : list of strs
Override the lists of operator names casted to FP32.
If None, uses the framework's default list to be casted to FP32.
widest_dtype_ops : list of strs
A list of op names provided by user which should run in widest precision among its inputs.
If None, uses the framework's default list of widest_precision_ops.
conditional_fp32_ops : list of (string, string, list of string)
Override the list of operators to be casted to FP32.
The format of the list is
(name of the function, name of the parameter,
list of values of the parameter that make the operator to be casted to
fp32)
excluded_sym_names : list of strs
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_optional_params : bool, default False
Whether to cast the arg_params and aux_params that don't require to be in FP16
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module"
assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty"

sym_dict = {}
assert bucketing_mod.params_initialized, \
"bucketing_mod params should be initialized for mixed precision conversion"
arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params
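    # Convert each bucket's symbol; arg_params/aux_params are shared across buckets.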
for key, val in bucketing_mod._buckets.items():
sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol,
arg_params,
aux_params,
target_dtype=target_dtype,
target_dtype_ops=target_dtype_ops,
fp32_ops=fp32_ops,
conditional_fp32_ops=conditional_fp32_ops,
excluded_sym_names=excluded_sym_names,
cast_optional_params=cast_optional_params)
result_mod = BucketingModule.load_dict(sym_dict,
sym_gen=bucketing_mod._sym_gen,
arg_params=result_arg_params,
aux_params=result_aux_params,
default_bucket_key=bucketing_mod._default_bucket_key,
logger=bucketing_mod.logger,
context=bucketing_mod._context,
work_load_list=bucketing_mod._work_load_list,
fixed_param_names=bucketing_mod._fixed_param_names,
state_names=bucketing_mod._state_names,
group2ctxs=bucketing_mod._group2ctxs,
compression_params=bucketing_mod._compression_params)
return result_mod

def list_fp16_ops():
"""Get the default list of FP16 ops for AMP
"""
33 changes: 17 additions & 16 deletions python/mxnet/model.py
@@ -423,6 +423,22 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params, remove_amp_ca
logging.info('Saved checkpoint to \"%s\"', param_name)


def load_params(prefix, epoch):
"""Load params from a file
"""
save_dict = nd.load("%s-%04d.params" % (prefix, epoch))
arg_params = {}
aux_params = {}
if not save_dict:
logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch))
for k, v in save_dict.items():
tp, name = k.split(":", 1)
if tp == "arg":
arg_params[name] = v
if tp == "aux":
aux_params[name] = v
return (arg_params, aux_params)

def load_checkpoint(prefix, epoch):
"""Load model checkpoint from file.

@@ -448,22 +464,7 @@ def load_checkpoint(prefix, epoch):
- Parameters will be loaded from ``prefix-epoch.params``.
"""
symbol = sym.load('%s-symbol.json' % prefix)
save_dict = nd.load('%s-%04d.params' % (prefix, epoch))
arg_params = {}
aux_params = {}
#load any params in the dict, skip if params are empty
if not save_dict:
logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch))
else:
for k, v in save_dict.items():
tp, name = k.split(':', 1)
if tp == 'arg':
arg_params[name] = v
elif tp == 'aux':
aux_params[name] = v
else:
logging.warning("Params file '%s' contains unknown param '%s'",
'%s-%04d.params' % (prefix, epoch), k)
arg_params, aux_params = load_params(prefix, epoch)
return (symbol, arg_params, aux_params)

from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position