Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[v0.10.x] Softmax optimization & bertpass refactor (#1565)
Browse files Browse the repository at this point in the history
* fix export

* Separate graph passes

* Add softmax mask pass optimization

* Add graph passes to finetune_squad

* Fix bias

* Update finetune_squad

* Change filename and update documentation

* Add support for graph passes in finetune_classifier

* Fix for mxnet 1.7

* Fix review

* Add tests for finetune_* scripts

* Fix review

* sanity

* Refactor bias recalculation loops

* Remove nightly version of mxnet from warning
  • Loading branch information
bgawrych authored May 27, 2021
1 parent 02b5b72 commit b4d7c0f
Show file tree
Hide file tree
Showing 8 changed files with 574 additions and 62 deletions.
317 changes: 292 additions & 25 deletions scripts/bert/bertpass_gpu.cc → scripts/bert/bertpass.cc

Large diffs are not rendered by default.

69 changes: 41 additions & 28 deletions scripts/bert/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,19 @@
default='float32',
help='Data type used for training. Either float32 or float16')

parser.add_argument('--custom_pass',

parser.add_argument('--custom_pass_lib',
type=str,
default=None,
help='Specify a custom graph pass for the network (library),'
help='Specify a custom graph pass library for the network,'
'allowing to customize the graph')

parser.add_argument('--custom_passes',
type=str,
nargs='+',
default=None,
help='Specify a list of custom graph pass for the network to apply')

parser.add_argument('--max_iters',
type=int,
default=None,
Expand Down Expand Up @@ -312,34 +319,40 @@ def export(prefix):
assert os.path.isfile(prefix + '-symbol.json')
assert os.path.isfile(prefix + '-0000.params')

if args.custom_pass is not None:
if args.custom_pass_lib is not None \
and args.custom_passes is not None:
# load library
libpath = os.path.abspath(args.custom_pass)
libpath = os.path.abspath(args.custom_pass_lib)
mx.library.load(libpath)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)

arg_array = arg_params
arg_array['data0'] = mx.nd.ones((test_batch_size, seq_length), dtype='float32')
arg_array['data1'] = mx.nd.ones((test_batch_size, seq_length), dtype='float32')
arg_array['data2'] = mx.nd.ones((test_batch_size, ), dtype='float32')
custom_sym = sym.optimize_for('custom_pass', arg_array, aux_params)
if (mx.__version__ <= '1.7.0'):
nheads = 12
if args.bert_model == 'bert_24_1024_16':
nheads = 24
for i in range(nheads):
basename = 'bertencoder0_transformer' + str(i) + '_dotproductselfattentioncell0'
arg_array.pop(basename + '_query_weight')
arg_array.pop(basename + '_key_weight')
arg_array.pop(basename + '_value_weight')
arg_array.pop(basename + '_query_bias')
arg_array.pop(basename + '_key_bias')
arg_array.pop(basename + '_value_bias')
arg_array.pop('data0')
arg_array.pop('data1')
arg_array.pop('data2')

mx.model.save_checkpoint(prefix, 0, custom_sym, arg_params, aux_params)

for graph_pass in args.custom_passes:
log.info('Applying custom graph pass %s', graph_pass)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)

arg_array = arg_params
arg_array['data0'] = mx.nd.ones((test_batch_size, seq_length), dtype='float32')
arg_array['data1'] = mx.nd.ones((test_batch_size, seq_length), dtype='float32')
arg_array['data2'] = mx.nd.ones((test_batch_size, ), dtype='float32')
custom_sym = sym.optimize_for(graph_pass, arg_array, aux_params)
if (graph_pass == 'MHAInterleave' and mx.__version__ <= '1.7.0'):
nheads = 12
if args.bert_model == 'bert_24_1024_16':
nheads = 24
for i in range(nheads):
basename = 'bertencoder0_transformer' + str(i) + '_dotproductselfattentioncell0'
arg_array.pop(basename + '_query_weight')
arg_array.pop(basename + '_key_weight')
arg_array.pop(basename + '_value_weight')
arg_array.pop(basename + '_query_bias')
arg_array.pop(basename + '_key_bias')
arg_array.pop(basename + '_value_bias')
arg_array.pop('data0')
arg_array.pop('data1')
arg_array.pop('data2')

mx.model.save_checkpoint(prefix, 0, custom_sym, arg_params, aux_params)
elif not (args.custom_pass_lib is None and args.custom_passes is None):
warnings.warn('Graph passes skipped! To apply custom pass provide both library path and pass names')

# Function to preprocess dataset to test, which depends on the task
def preprocess_data(tokenizer, task):
Expand Down
50 changes: 48 additions & 2 deletions scripts/bert/finetune_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from gluonnlp.data.bert.glue import truncate_seqs_equal, concat_sequences
from gluonnlp.model import BERTClassifier, RoBERTaClassifier
from gluonnlp.calibration import BertLayerCollector
from utils import QuantizableNet, QuantizableRobertaNet, run_graphpass, RobertaCalibIter

nlp.utils.check_version('0.9', warning_only=True)

Expand Down Expand Up @@ -185,6 +186,12 @@
choices=['none', 'naive', 'entropy', 'customize'],
help='calibration mode used for generating calibration table '
'for the quantized symbol.')
parser.add_argument('--custom_pass_lib', type=str, default=None,
help='Specify a custom graph pass library for the network,'
'allowing to customize the graph')

parser.add_argument('--custom_passes', nargs='+', type=str, default=None,
help='Specify names of custom graph passes for the network to apply from `custom_pass_lib`')

args = parser.parse_args()

Expand Down Expand Up @@ -426,7 +433,31 @@ def calibration(net, dev_data_list, num_calib_batches, quantized_dtype, calib_mo
assert ctx == mx.cpu(), \
'Currently only supports CPU with MKL-DNN backend.'
logging.info('Now we are doing calibration on dev with %s.', ctx)

model_name = args.bert_model
run_softmax_pass = False
if args.custom_passes is None:
args.custom_passes = []

if 'MaskSoftmax' in args.custom_passes:
run_softmax_pass = True
args.custom_passes.remove('MaskSoftmax')

if args.custom_pass_lib is not None:
# load library
libpath = os.path.abspath(args.custom_pass_lib)
mx.library.load(libpath)
for pass_name in args.custom_passes:
net = run_graphpass(net, model_name, dev_batch_size, args.max_len,
pass_name, use_roberta=use_roberta)

for _, dev_data in dev_data_list:
if use_roberta:
dev_data = RobertaCalibIter(dev_data)
net = QuantizableRobertaNet(net)
else:
net = QuantizableNet(net)

collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=logging)
num_calib_examples = dev_batch_size * num_calib_batches
net = mx.contrib.quantization.quantize_net_v2(net, quantized_dtype=quantized_dtype,
Expand All @@ -439,7 +470,21 @@ def calibration(net, dev_data_list, num_calib_batches, quantized_dtype, calib_mo
ctx=ctx,
LayerOutputCollector=collector,
logger=logging)
if run_softmax_pass:
net = run_graphpass(net, model_name, dev_batch_size, args.max_len,
'MaskSoftmax', use_roberta=use_roberta)
# save params
net.hybridize()
input_ids = mx.nd.zeros((dev_batch_size, args.max_len))
segment_ids = mx.nd.zeros((dev_batch_size, args.max_len))
valid_length = mx.nd.zeros((dev_batch_size ,))
if use_roberta:
out = net(input_ids, valid_length)
else:
out = net(input_ids, segment_ids, valid_length)

out.wait_to_read()

ckpt_name = 'model_bert_{0}_quantized_{1}'.format(task_name, calib_mode)
params_saved = os.path.join(output_dir, ckpt_name)
net.export(params_saved, epoch=0)
Expand Down Expand Up @@ -698,8 +743,9 @@ def evaluate(loader_dev, metric, segment):
num_calib_batches,
quantized_dtype,
calib_mode)
except AttributeError:
except AttributeError as e:
warnings.warn(e)
nlp.utils.version.check_version('1.7.0', warning_only=True, library=mx)
warnings.warn('INT8 Quantization for BERT need mxnet-mkl >= 1.6.0b20200115')
warnings.warn('INT8 Quantization for BERT need mxnet >= 1.7')
else:
train(task.metrics)
48 changes: 45 additions & 3 deletions scripts/bert/finetune_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from gluonnlp.calibration import BertLayerCollector
from model.qa import BertForQALoss, BertForQA
from bert_qa_evaluate import get_F1_EM, predict, PredResult
from utils import QuantizableNet, run_graphpass

np.random.seed(6)
random.seed(6)
Expand Down Expand Up @@ -234,7 +235,7 @@
parser.add_argument('--only_calibration', action='store_true',
help='quantize model')

parser.add_argument('--num_calib_batches', type=int, default=10,
parser.add_argument('--num_calib_batches', type=int, default=1,
help='number of batches for calibration')

parser.add_argument('--quantized_dtype', type=str, default='auto',
Expand All @@ -246,6 +247,18 @@
help='calibration mode used for generating calibration table '
'for the quantized symbol.')

parser.add_argument('--custom_pass_lib',
type=str,
default=None,
help='Specify a custom graph pass library for the network,'
'allowing to customize the graph')

parser.add_argument('--custom_passes',
nargs='+',
type=str,
default=None,
help='Specify names of custom graph passes for the network to apply from `custom_pass_lib`')

args = parser.parse_args()

output_dir = args.output_dir
Expand Down Expand Up @@ -581,11 +594,30 @@ def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
num_workers=4, batch_size=test_batch_size,
shuffle=False, last_batch='keep')

model_name = args.bert_model
run_softmax_pass = False
if args.custom_passes is None:
args.custom_passes = []

if 'MaskSoftmax' in args.custom_passes:
run_softmax_pass = True
args.custom_passes.remove('MaskSoftmax')

if args.custom_pass_lib is not None:
# load library
libpath = os.path.abspath(args.custom_pass_lib)
mx.library.load(libpath)
for pass_name in args.custom_passes:
net = run_graphpass(net, model_name,
test_batch_size, max_seq_length,
pass_name)

assert ctx == mx.cpu(), \
'Currently only supports CPU with MKL-DNN backend.'
log.info('Now we are doing calibration on dev with %s.', ctx)
collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=log)
num_calib_examples = test_batch_size * num_calib_batches
net = QuantizableNet(net)
net = mx.contrib.quantization.quantize_net_v2(net, quantized_dtype=quantized_dtype,
exclude_layers=[],
quantize_mode='smart',
Expand All @@ -596,7 +628,16 @@ def calibration(net, num_calib_batches, quantized_dtype, calib_mode):
ctx=ctx,
LayerOutputCollector=collector,
logger=log)

if run_softmax_pass:
net = run_graphpass(net, model_name, test_batch_size, max_seq_length, 'MaskSoftmax')

# save params
net.hybridize()
out = net(mx.nd.ones((test_batch_size, max_seq_length)),
mx.nd.zeros((test_batch_size, max_seq_length)),
mx.nd.zeros((test_batch_size ,)))
out.wait_to_read()
ckpt_name = 'model_bert_squad_quantized_{0}'.format(calib_mode)
params_saved = os.path.join(output_dir, ckpt_name)
net.export(params_saved, epoch=0)
Expand Down Expand Up @@ -854,9 +895,10 @@ def preprocess_dataset(tokenizer,
num_calib_batches,
quantized_dtype,
calib_mode)
except AttributeError:
except AttributeError as e:
nlp.utils.version.check_version('1.7.0', warning_only=True, library=mx)
warnings.warn('INT8 Quantization for BERT need mxnet-mkl >= 1.6.0b20200115')
warnings.warn('INT8 Quantization for BERT needs mxnet >= 1.7')
warnings.warn(e)
elif not only_predict:
train()
evaluate()
Expand Down
13 changes: 11 additions & 2 deletions scripts/bert/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,16 @@ Question Answering
| SQuAD 1.1 | bert_12_768_12 | 81.18 | 80.32 | 88.58 | 88.10 |`command <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/bert/calibration_squad1.1_base_mx1.6.0b20200125.sh>`__ |
+-----------+-------------------+---------+---------+---------+---------+----------------------------------------------------------------------------------------------------------------------------+

For all model settings above, we use a subset of evaluation dataset for calibration.
For all model settings above, subset of evaluation dataset for calibration was used.

Using optimization graph passes is recommended to boost performance of inference even more. To deploy calibrated model optimized with graph passes following arguments can be used
--custom_pass_lib [graph_pass_library_path] and --custom_passes [graph_passes_name]. E.g.:

.. code-block:: console
$ python3 finetune_squad.py --only_calibration --model_parameters ./output_dir/net.params --custom_pass_lib bertpass_lib.so --custom_passes MaskSoftmax MHAInterleave
Graph pass library can be built with setup.py script.

Pre-training from Scratch
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -348,7 +357,7 @@ Once the model is exported, you can import the model by setting --only_infer, an
The batch size can be specified via --test_batch_size option, and accuracy can be checked setting --check_accuracy.

When using GPU and data type FP16 (--dtype float16), we recommend to use MXNET_FC_TRUE_FP16=1 for boosting performance.
Moreover, you can use a custom graph pass for BERT, via --custom_pass [custom_pass_file], to improve the performance on GPU. To generate the pass you can run setup.py within the BERT scripts directory. These GPU optimizations require MXNet version 1.7 or higher.
Moreover, custom graph pass for BERT can be used via --custom_pass_lib [custom_pass_library] and --custom_passes [space_seperated_names_of_passes_to_apply], to improve the performance on GPU. To generate the pass library setup.py script can be run within the BERT scripts directory. These GPU optimizations require MXNet version 1.7 or higher.


BERT for Sentence or Tokens Embedding
Expand Down
2 changes: 1 addition & 1 deletion scripts/bert/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def CompileBERTCustomPass():
"""Compiles custom graph pass for BERT into a library. It offers performance improvements"""
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
log = logging.getLogger()
input_pass_file = 'bertpass_gpu.cc'
input_pass_file = 'bertpass.cc'
out_lib_file = 'bertpass_lib.so'
log.info(' ... compiling BERT custom graph pass into %s', out_lib_file)
mxnet_path = pathlib.Path(mxnet.__file__).parent.absolute()
Expand Down
Loading

0 comments on commit b4d7c0f

Please sign in to comment.