From 8e35f4832a5340a4b1fb9c385fb2f81d90029986 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Thu, 13 Jan 2022 20:07:48 +0100 Subject: [PATCH 01/12] Add inc to quantization documentation --- .../backend/mkldnn/mkldnn_quantization.md | 140 +++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 8c15af267cd4..44ac07ef84f8 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -15,7 +15,7 @@ -# Quantize with MKL-DNN backend +# Native model quantization with MKL-DNN backend This document is to introduce how to quantize the customer models from FP32 to INT8 with Apache/MXNet toolkit and APIs under Intel CPU. @@ -255,4 +255,142 @@ BTW, You can also modify the `min_calib_range` and `max_calib_range` in the JSON MXNet also supports deploy quantized models with C++. Refer [MXNet C++ Package](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/README.md) for more details. +# Model quantization using Intel® Neural Compressor + +The accuracy of the model can decrease drastically as a result of quantization. In such cases we could try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using a set of several tuning heuristics, finding the best quantization configuration that satisfies the specified accuracy requirement. + +**NOTE:** + +Most tuning strategies will try different configurations on an evaluation dataset in order to find out how each layer affects the accuracy of the model. This means that for larger models, it may take a long time to find a solution (as the tuning space is usually larger then and the evaluation itself takes longer). + +## Installation and Prerequisites + +- Install MXNet with MKLDNN enabled as described in the [previous section](#Installation-and-Prerequisites). + +- Install Intel® Neural Compressor: + + Supported python versions are: 3.6, 3.7, 3.8, 3.9. + ```bash + # install stable version from pip + pip install neural-compressor + + # install nightly version from pip + pip install -i https://test.pypi.org/simple/ neural-compressor + + # install stable version from from conda + conda install neural-compressor -c conda-forge -c intel + ``` + +## Configuration file + +You can customize the quantization tuning process in the yaml configuration file. Here is a simple example: + +file: cnn.yaml +```yaml +version: 1.0 + +model: + name: cnn + framework: mxnet + +quantization: + calibration: + sampling_size: 160 # number of samples for calibration + +tuning: + strategy: + name: basic + accuracy_criterion: + relative: 0.01 + exit_policy: + timeout: 0 # end on the first configuration that meet the accuracy criterion + random_seed: 9527 +``` + +We are using the `basic` strategy, but you could also try out different ones. [Here](https://github.com/intel/neural-compressor/blob/master/docs/tuning_strategies.md) you can find a list of strategies available in INC and details of how they work. 
You can also add your own strategy if the existing ones do not suit your needs. + +For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the original INC repo. Keep in mind that only post training quantization is currently supported for MXNet. + +## Model quantization and tuning + +In general, Intel® Neural Compressor requires 4 elements in order to run: +1. Config file - like the example above +2. Model to be quantized +3. Calibration dataloader +4. Evaluation function - a function that takes a model as an argument and returns the accuracy that it achieves on a certain evaluation dataset. + +Here is how to achieve the quantization using INC: + +1. Get the model + +```python +import logging +import mxnet as mx +from mxnet.gluon.model_zoo import vision + +logging.basicConfig() +logger = logging.getLogger('logger') +logger.setLevel(logging.INFO) + +batch_shape = (1, 3, 224, 224) +resnet18 = vision.resnet18_v1(pretrained=True) +``` + +2. Prepare the dataset: + +```python +mx.test_utils.download('https://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec') + +batch_size = 16 +mean_std = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939, + 'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375} + +data = mx.io.ImageRecordIter(path_imgrec='data/val_256_q90.rec', + batch_size=batch_size, + data_shape=batch_shape[1:], + rand_crop=False, + rand_mirror=False, + shuffle=False, + **mean_std) +data.batch_size = batch_size +``` + +3. Prepare the evaluation function: + +```python +eval_samples = batch_size*10 + +def eval_func(model): + data.reset() + metric = mx.metric.Accuracy() + for i, batch in enumerate(data): + if i * batch_size >= eval_samples: + break + x = batch.data[0].as_in_context(mx.cpu()) + label = batch.label[0].as_in_context(mx.cpu()) + outputs = model.forward(x) + metric.update(label, outputs) + return metric.get()[1] +``` + +4. Run Intel® Neural Compressor: + +```python +from neural_compressor.experimental import Quantization +quantizer = Quantization("./cnn.yaml") +quantizer.model = resnet18 +quantizer.calib_dataloader = data +quantizer.eval_func = eval_func +qnet = quantizer().model +``` + +## Tips +- In order to get a solution that generalizes well, evaluate the model (in eval_func) on a representative dataset. +- Using `history.snapshot` file (generated by INC) you can recover any model that was generated during the tuning process: + ```python + from neural_compressor.utils.utility import recover + + qmodel = recover(f32_model, 'nc_workspace//history.snapshot', 1) + ``` + From 3a148e217ea1953c56b819a84394f9db2c661368 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Thu, 13 Jan 2022 21:27:20 +0100 Subject: [PATCH 02/12] Minor fixes --- .../tutorials/performance/backend/mkldnn/mkldnn_quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 44ac07ef84f8..997492369c6c 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -309,7 +309,7 @@ tuning: We are using the `basic` strategy, but you could also try out different ones. 
[Here](https://github.com/intel/neural-compressor/blob/master/docs/tuning_strategies.md) you can find a list of strategies available in INC and details of how they work. You can also add your own strategy if the existing ones do not suit your needs. -For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the original INC repo. Keep in mind that only post training quantization is currently supported for MXNet. +For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the official INC repo. Keep in mind that only the `post training quantization` is currently supported for MXNet. ## Model quantization and tuning From a7625ab983e6083641ecd98e7bd0a0800f397436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Fri, 14 Jan 2022 14:55:48 +0100 Subject: [PATCH 03/12] review fixes --- .../performance/backend/mkldnn/mkldnn_quantization.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 997492369c6c..22d5475d13c9 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -257,11 +257,11 @@ MXNet also supports deploy quantized models with C++. Refer [MXNet C++ Package]( # Model quantization using Intel® Neural Compressor -The accuracy of the model can decrease drastically as a result of quantization. In such cases we could try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using a set of several tuning heuristics, finding the best quantization configuration that satisfies the specified accuracy requirement. +The accuracy of a model can decrease drastically as a result of quantization. In such cases we could try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using a set of several tuning heuristics, finding the best quantization configuration that satisfies the specified accuracy requirement. **NOTE:** -Most tuning strategies will try different configurations on an evaluation dataset in order to find out how each layer affects the accuracy of the model. This means that for larger models, it may take a long time to find a solution (as the tuning space is usually larger then and the evaluation itself takes longer). +Most tuning strategies will try different configurations on an evaluation dataset in order to find out how each layer affects the accuracy of the model. This means that for larger models, it may take a long time to find a solution (as the tuning space is usually larger and the evaluation itself takes longer). 
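If the tuning time itself becomes a problem, it can be bounded explicitly through the `exit_policy` section of the configuration file described below. This is only a sketch of that idea - the concrete values are placeholders and both fields are optional:

```yaml
tuning:
  exit_policy:
    timeout: 3600     # let INC tune for up to one hour (0 = stop at the first config meeting the accuracy criterion)
    max_trials: 50    # or stop after at most 50 evaluated configurations
```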
## Installation and Prerequisites @@ -277,13 +277,13 @@ Most tuning strategies will try different configurations on an evaluation datase # install nightly version from pip pip install -i https://test.pypi.org/simple/ neural-compressor - # install stable version from from conda + # install stable version from conda conda install neural-compressor -c conda-forge -c intel ``` ## Configuration file -You can customize the quantization tuning process in the yaml configuration file. Here is a simple example: +Quantization tuning process can be customized in the yaml configuation file. Here is a simple example: file: cnn.yaml ```yaml @@ -317,7 +317,7 @@ In general, Intel® Neural Compressor requires 4 elements in order to run: 1. Config file - like the example above 2. Model to be quantized 3. Calibration dataloader -4. Evaluation function - a function that takes a model as an argument and returns the accuracy that it achieves on a certain evaluation dataset. +4. Evaluation function - a function that takes a model as an argument and returns the accuracy it achieves on a certain evaluation data set. Here is how to achieve the quantization using INC: From 3cf36c6d9e2ef1ff7f08d40a86131c764c9548c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Fri, 14 Jan 2022 14:58:28 +0100 Subject: [PATCH 04/12] fix --- .../tutorials/performance/backend/mkldnn/mkldnn_quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 22d5475d13c9..4330a36319cf 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -283,7 +283,7 @@ Most tuning strategies will try different configurations on an evaluation datase ## Configuration file -Quantization tuning process can be customized in the yaml configuation file. Here is a simple example: +Quantization tuning process can be customized in the yaml configuration file. Here is a simple example: file: cnn.yaml ```yaml From 80d4f3df39ea8b30c9596126b13191b96416da9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Fri, 14 Jan 2022 14:59:48 +0100 Subject: [PATCH 05/12] fix2 --- .../tutorials/performance/backend/mkldnn/mkldnn_quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 4330a36319cf..3f6a6424f4ce 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -317,7 +317,7 @@ In general, Intel® Neural Compressor requires 4 elements in order to run: 1. Config file - like the example above 2. Model to be quantized 3. Calibration dataloader -4. Evaluation function - a function that takes a model as an argument and returns the accuracy it achieves on a certain evaluation data set. +4. Evaluation function - a function that takes a model as an argument and returns the accuracy it achieves on a certain evaluation dataset. 
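Condensed, these four elements are wired together roughly as shown below. This is only a sketch - `config.yaml`, `net`, `calib_data` and `eval_func` are placeholders, and complete, runnable examples follow later in this tutorial:

```python
from neural_compressor.experimental import Quantization

quantizer = Quantization("config.yaml")   # 1. config file
quantizer.model = net                     # 2. model to be quantized
quantizer.calib_dataloader = calib_data   # 3. calibration dataloader
quantizer.eval_func = eval_func           # 4. evaluation function
quantized_net = quantizer().model         # run the tuning (also exposed as quantizer.fit())
```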
Here is how to achieve the quantization using INC: From 65849a0374c0a7ad29fa16c7afac445b292311cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Mon, 24 Jan 2022 13:32:29 +0100 Subject: [PATCH 06/12] Add BERT example with results, review fixes --- .../backend/mkldnn/mkldnn_quantization.md | 172 ++++++++- example/quantization_inc/BERT_MRPC/bert.yaml | 36 ++ example/quantization_inc/BERT_MRPC/details.py | 328 ++++++++++++++++++ example/quantization_inc/BERT_MRPC/main.py | 56 +++ 4 files changed, 581 insertions(+), 11 deletions(-) create mode 100644 example/quantization_inc/BERT_MRPC/bert.yaml create mode 100644 example/quantization_inc/BERT_MRPC/details.py create mode 100644 example/quantization_inc/BERT_MRPC/main.py diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 3f6a6424f4ce..009f5c78e209 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -15,7 +15,7 @@ -# Native model quantization with MKL-DNN backend +# Quantize with MKL-DNN backend This document is to introduce how to quantize the customer models from FP32 to INT8 with Apache/MXNet toolkit and APIs under Intel CPU. @@ -155,7 +155,7 @@ cqsym, cqarg_params, aux_params, collector = quantize_graph(sym=sym, arg_params= quantized_dtype=quantized_dtype, logger=logger) # download imagenet validation dataset -mx.test_utils.download('https://data.mxnet.io/data/val_256_q90.rec', 'dataset.rec') +mx.test_utils.download('http://data.mxnet.io/data/val_256_q90.rec', 'dataset.rec') # set rgb info for data mean_std = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939, 'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375} # set batch size @@ -243,6 +243,8 @@ BTW, You can also modify the `min_calib_range` and `max_calib_range` in the JSON - Change calibration dataset by setting different `num_calib_batches` or shuffle your validation dataset; +- Use Intel® Neural Compressor ([see below](#Improving-accuracy-with-Intel-Neural-Compressor)) + #### Performance Tuning - Keep sure to perform graph fusion before quantization; @@ -255,9 +257,9 @@ BTW, You can also modify the `min_calib_range` and `max_calib_range` in the JSON MXNet also supports deploy quantized models with C++. Refer [MXNet C++ Package](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/README.md) for more details. -# Model quantization using Intel® Neural Compressor +# Improving accuracy with Intel® Neural Compressor -The accuracy of a model can decrease drastically as a result of quantization. In such cases we could try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using a set of several tuning heuristics, finding the best quantization configuration that satisfies the specified accuracy requirement. +The accuracy of a model can decrease as a result of quantization. When the accuracy drop is significant, we can try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. 
[Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using several tuning heuristics, finding the quantization configuration that satisfies the specified accuracy requirement. **NOTE:** @@ -269,7 +271,8 @@ Most tuning strategies will try different configurations on an evaluation datase - Install Intel® Neural Compressor: - Supported python versions are: 3.6, 3.7, 3.8, 3.9. + Use one of the commands below to install INC (supported python versions are: 3.6, 3.7, 3.8, 3.9): + ```bash # install stable version from pip pip install neural-compressor @@ -285,8 +288,9 @@ Most tuning strategies will try different configurations on an evaluation datase Quantization tuning process can be customized in the yaml configuration file. Here is a simple example: -file: cnn.yaml ```yaml +# cnn.yaml + version: 1.0 model: @@ -303,12 +307,14 @@ tuning: accuracy_criterion: relative: 0.01 exit_policy: - timeout: 0 # end on the first configuration that meet the accuracy criterion + timeout: 0 random_seed: 9527 ``` We are using the `basic` strategy, but you could also try out different ones. [Here](https://github.com/intel/neural-compressor/blob/master/docs/tuning_strategies.md) you can find a list of strategies available in INC and details of how they work. You can also add your own strategy if the existing ones do not suit your needs. +Since the value of `timeout` is 0, INC will run until it finds a configuration that satisfy the accuracy criterion and then exit. Depending on the strategy this may not be ideal, as sometimes it would be better to further explore the tuning space to find a superior configuration both in terms of accuracy and speed. To achive this, we can set a specific timeout - how long (in seconds) do we want INC to run. + For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the official INC repo. Keep in mind that only the `post training quantization` is currently supported for MXNet. ## Model quantization and tuning @@ -319,7 +325,9 @@ In general, Intel® Neural Compressor requires 4 elements in order to run: 3. Calibration dataloader 4. Evaluation function - a function that takes a model as an argument and returns the accuracy it achieves on a certain evaluation dataset. -Here is how to achieve the quantization using INC: +### Quantizing ResNet + +The previous sections described how to quantize ResNet using the native MXNet quantization. This example shows how we can achieve the same (plus the auto-tuning) using INC. 1. Get the model @@ -339,7 +347,7 @@ resnet18 = vision.resnet18_v1(pretrained=True) 2. Prepare the dataset: ```python -mx.test_utils.download('https://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec') +mx.test_utils.download('http://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec') batch_size = 16 mean_std = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939, @@ -384,13 +392,155 @@ quantizer.eval_func = eval_func qnet = quantizer().model ``` +Since this model already achieves good accuracy using native quantization (less than 1% accuracy drop), for the given configuration file, INC will end on the first configuration, quantizing all layers using `naive` calibration mode for each. To see the true potential of INC, we need a model which suffers from a larger accuracy drop after quantization. + +### Quantizing BERT + +This example shows how to use INC to quantize BERT-base for MRPC. 
In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model and task specific boilerplate has been moved to the `details.py` file. + +Here is the configuration file: +```yaml +version: 1.0 + +model: + name: bert + framework: mxnet + +quantization: + calibration: + sampling_size: 320 # number of samples for calibration + +tuning: + strategy: + name: basic + accuracy_criterion: + relative: 0.01 + exit_policy: + timeout: 0 + max_trials: 9999 # default is 100 + random_seed: 9527 +``` + +And here is the script: + +```python +from pathlib import Path +from functools import partial + +import details +from neural_compressor.experimental import Quantization, common + +# constants +INC_CONFIG_PATH = Path('./bert.yaml').resolve() +PARAMS_PATH = Path('./bert_mrpc.params').resolve() +OUTPUT_DIR_PATH = Path('./output/').resolve() +OUTPUT_MODEL_PATH = OUTPUT_DIR_PATH/'quantized_model' +OUTPUT_DIR_PATH.mkdir(parents=True, exist_ok=True) + +# Prepare the dataloaders (calib_dataloader is same as train_dataloader but without shuffling) +train_dataloader, dev_dataloader, calib_dataloader = details.preprocess_data() + +# Get the model +model = details.BERTModel(details.BACKBONE, dropout=0.1, num_classes=details.NUM_CLASSES) +model.hybridize(static_alloc=True) + +# finetune or load the parameters of already finetuned model +if not PARAMS_PATH.exists(): + model = details.finetune(model, train_dataloader, dev_dataloader, OUTPUT_DIR_PATH) + model.save_parameters(str(PARAMS_PATH)) +else: + model.load_parameters(str(PARAMS_PATH), ctx=details.CTX, cast_dtype=True) + +# run INC +calib_dataloader.batch_size = details.BATCH_SIZE +eval_func = partial(details.evaluate, dataloader=dev_dataloader) + +quantizer = Quantization(str(INC_CONFIG_PATH)) # 1. Config file +quantizer.model = common.Model(model) # 2. Model to be quantized +quantizer.calib_dataloader = calib_dataloader # 3. Calibration dataloader +quantizer.eval_func = eval_func # 4. Evaluation function +quantized_model = quantizer.fit().model + +# save the quantized model +quantized_model.export(str(OUTPUT_MODEL_PATH)) +``` + +With the evaluation function hidden in the `details.py` file: + +```python +def evaluate(model, dataloader): + metric = METRIC() + for batch in dataloader: + input_ids, segment_ids, valid_length, label = batch + input_ids = input_ids.as_in_context(CTX) + segment_ids = segment_ids.as_in_context(CTX) + valid_length = valid_length.as_in_context(CTX) + label = label.as_in_context(CTX).reshape((-1)) + + out = model(input_ids, segment_ids, valid_length) + metric.update([label], [out]) + + metric_name, metric_val = metric.get() + return metric_val +``` + +For comparision, this is how one could quantize this model using MXNet native quantization (this function is also located in the `details.py` file): + +```python +def native_quantization(model, calib_dataloader, dev_dataloader): + quantized_model = quantize_net_v2(model, + quantize_mode='smart', + calib_data=calib_dataloader, + calib_mode='naive', + num_calib_examples=BATCH_SIZE*10) + print('Native quantization results: {}'.format(evaluate(quantized_model, dev_dataloader))) + return quantized_model +``` + +For complete code, see this example on the [official GitHub repository](TODO link). 
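Once exported, the quantized model can be used for inference without INC, for example by loading it back as a `SymbolBlock`. The snippet below is only a sketch continuing from the script above - in particular, the input names passed to `imports` are an assumption and should be verified against the generated `quantized_model-symbol.json` file:

```python
import mxnet as mx

# load the files produced by quantized_model.export(...) above
qnet = mx.gluon.nn.SymbolBlock.imports(
    str(OUTPUT_MODEL_PATH) + '-symbol.json',
    ['data0', 'data1', 'data2'],               # assumed input names - check the exported symbol file
    str(OUTPUT_MODEL_PATH) + '-0000.params',
    ctx=mx.cpu())

# sanity check on a single batch from the evaluation dataloader
for input_ids, segment_ids, valid_length, label in dev_dataloader:
    out = qnet(input_ids, segment_ids, valid_length)
    break
```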
+ +#### Results: + +Results of the f32 model on the dev split: Accuracy = 0.8529, F1 = 0.8940 + +Results of quantized models on the dev split: + +| Quantization method | Accuracy | F1 | Relative accuracy loss [%] | Calibration/tuning time [s] | +|:----------------------------:|:--------:|:------:|:--------------------------:|:----------------------------:| +| Native 'naive', 10 batches | 0.8309 | 0.8799 | 2.5794 | 35 | +| Native 'naive', 20 batches | 0.8284 | 0.8783 | 2.8725 | 65 | +| Native 'entropy', 10 batches | 0.7524 | 0.7856 | 11.7833 | 60 | +| Native 'entropy', 20 batches | 0.7107 | 0.7388 | 16.6725 | 122 | +| INC, 'basic' | 0.8480 | 0.8918 | 0.5745 | 352 | +| INC, 'bayesian' | 0.8529 | 0.8935 | 0 | 242 | +| INC, 'mse' | 0.8456 | 0.8957 | 0.8559 | 703 | + +We can see that all INC strategies found configurations meeting the 1% relative accuracy loss criterion. + +Configuration generated by INC with `basic` strategy: + +- Layers quantized using min-max (`naive`) calibration algorithm: + ``` + {'bertclassifier0_dropout0_fwd', 'bertencoder0_layernorm0_layernorm0', 'bertencoder0_transformer0_dotproductselfattentioncell0_dropout0_fwd', ..., 'sg_mkldnn_fully_connected_0', 'sg_mkldnn_fully_connected_1', ..., 'sg_mkldnn_selfatt_valatt_7', 'sg_mkldnn_selfatt_valatt_9'} + ``` + +- Layers quantized using KL (`entropy`) calibration algorithm: + ``` + {'sg_mkldnn_selfatt_qk_2', 'sg_mkldnn_selfatt_qk_4', ..., 'sg_mkldnn_selfatt_qk_20', 'sg_mkldnn_selfatt_qk_22'} + ``` + +- Layers excluded from quantization: + ``` + {'sg_mkldnn_fully_connected_39'} + ``` + ## Tips - In order to get a solution that generalizes well, evaluate the model (in eval_func) on a representative dataset. -- Using `history.snapshot` file (generated by INC) you can recover any model that was generated during the tuning process: +- With `history.snapshot` file (generated by INC) you can recover any model that was generated during the tuning process: ```python from neural_compressor.utils.utility import recover - qmodel = recover(f32_model, 'nc_workspace//history.snapshot', 1) + quantized_model = recover(f32_model, 'nc_workspace//history.snapshot', configuration_idx) ``` diff --git a/example/quantization_inc/BERT_MRPC/bert.yaml b/example/quantization_inc/BERT_MRPC/bert.yaml new file mode 100644 index 000000000000..17a27e3448d9 --- /dev/null +++ b/example/quantization_inc/BERT_MRPC/bert.yaml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +version: 1.0 + +model: + name: bert + framework: mxnet + +quantization: + calibration: + sampling_size: 320 # number of samples for calibration + +tuning: + strategy: + name: basic + accuracy_criterion: + relative: 0.01 + exit_policy: + timeout: 0 + max_trials: 9999 # default is 100 + random_seed: 9527 diff --git a/example/quantization_inc/BERT_MRPC/details.py b/example/quantization_inc/BERT_MRPC/details.py new file mode 100644 index 000000000000..94971f0b4a01 --- /dev/null +++ b/example/quantization_inc/BERT_MRPC/details.py @@ -0,0 +1,328 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +import logging +import random +import itertools +import collections +import numpy as np +import numpy.ma as ma +import gluonnlp as nlp +import mxnet as mx + +from mxnet.contrib.quantization import quantize_net_v2 +from gluonnlp.model import BERTClassifier as BERTModel +from gluonnlp.data import BERTTokenizer +from gluonnlp.data import GlueMRPC +from functools import partial + +nlp.utils.check_version('0.9', warning_only=True) +logging.basicConfig() +logging = logging.getLogger() + +CTX = mx.cpu() + +TASK_NAME = 'MRPC' +MODEL_NAME = 'bert_12_768_12' +DATASET_NAME = 'book_corpus_wiki_en_uncased' +BACKBONE, VOCAB = nlp.model.get_model(name=MODEL_NAME, + dataset_name=DATASET_NAME, + pretrained=True, + ctx=CTX, + use_decoder=False, + use_classifier=False) +TOKENIZER = BERTTokenizer(VOCAB, lower=('uncased' in DATASET_NAME)) +MAX_LEN = int(512) + +LABEL_DTYPE = 'int32' +CLASS_LABELS = ['0', '1'] +NUM_CLASSES = len(CLASS_LABELS) +LABEL_MAP = {l: i for (i, l) in enumerate(CLASS_LABELS)} + +BATCH_SIZE = int(32) +LR = 3e-5 +EPSILON = 1e-6 +LOSS_FUNCTION = mx.gluon.loss.SoftmaxCELoss() +EPOCH_NUMBER = int(4) +TRAINING_STEPS = None # if specified, epochs will be ignored +ACCUMULATE = int(1) # >= 1 +WARMUP_RATIO = 0.1 +EARLY_STOP = None +TRAINING_LOG_INTERVAL = 10*ACCUMULATE + +METRIC = mx.metric.Accuracy + + +class FixedDataset: + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, idx): + input_ids, segment_ids, valid_length, label = self.dataset[idx] + return input_ids, segment_ids, np.float32(valid_length), label + + def __len__(self): + return len(self.dataset) + + +def truncate_seqs_equal(seqs, max_len): + assert isinstance(seqs, list) + lens = list(map(len, seqs)) + if sum(lens) <= max_len: + return seqs + + lens = ma.masked_array(lens, mask=[0] * len(lens)) + while True: + argmin = lens.argmin() + minval = lens[argmin] + quotient, remainder = divmod(max_len, len(lens) - sum(lens.mask)) + if minval <= quotient: # Ignore values that don't need truncation + lens.mask[argmin] = 1 + max_len -= minval + else: # Truncate all + lens.data[~lens.mask] = [ + quotient + 1 if i < remainder else quotient for i in range(lens.count()) + ] + break 
+ seqs = [seq[:length] for (seq, length) in zip(seqs, lens.data.tolist())] + return seqs + + +def concat_sequences(seqs, separators, seq_mask=0, separator_mask=1): + assert isinstance(seqs, collections.abc.Iterable) and len(seqs) > 0 + assert isinstance(seq_mask, (list, int)) + assert isinstance(separator_mask, (list, int)) + concat = sum((seq + sep for sep, seq in itertools.zip_longest(separators, seqs, fillvalue=[])), + []) + segment_ids = sum( + ([i] * (len(seq) + len(sep)) + for i, (sep, seq) in enumerate(itertools.zip_longest(separators, seqs, fillvalue=[]))), + []) + if isinstance(seq_mask, int): + seq_mask = [[seq_mask] * len(seq) for seq in seqs] + if isinstance(separator_mask, int): + separator_mask = [[separator_mask] * len(sep) for sep in separators] + + p_mask = sum((s_mask + mask for sep, seq, s_mask, mask in itertools.zip_longest( + separators, seqs, seq_mask, separator_mask, fillvalue=[])), []) + return concat, segment_ids, p_mask + + +def convert_examples_to_features(example, is_test): + truncate_length = MAX_LEN if is_test else MAX_LEN - 3 + if not is_test: + example, label = example[:-1], example[-1] + label = np.array([LABEL_MAP[label]], dtype=LABEL_DTYPE) + + tokens_raw = [TOKENIZER(l) for l in example] + tokens_trun = truncate_seqs_equal(tokens_raw, truncate_length) + tokens_trun[0] = [VOCAB.cls_token] + tokens_trun[0] + tokens, segment_ids, _ = concat_sequences(tokens_trun, [[VOCAB.sep_token]] * len(tokens_trun)) + input_ids = VOCAB[tokens] + valid_length = len(input_ids) + if not is_test: + return input_ids, segment_ids, valid_length, label + else: + return input_ids, segment_ids, valid_length + + +def preprocess_data(): + def preprocess_dataset(segment): + is_calib = segment == 'calib' + is_test = segment == 'test' + segment = 'train' if is_calib else segment + trans = partial(convert_examples_to_features, is_test=is_test) + batchify = [nlp.data.batchify.Pad(axis=0, pad_val=VOCAB[VOCAB.padding_token]), # 0. input + nlp.data.batchify.Pad(axis=0, pad_val=0), # 1. segment + nlp.data.batchify.Stack()] # 2. length + batchify += [] if is_test else [nlp.data.batchify.Stack(LABEL_DTYPE)] # 3. label + batchify_fn = nlp.data.batchify.Tuple(*batchify) + + dataset = list(map(trans, GlueMRPC(segment))) + random.shuffle(dataset) + dataset = mx.gluon.data.SimpleDataset(dataset) + + batch_arg = {} + if segment == 'train' and not is_calib: + seq_len = dataset.transform(lambda *args: args[2], lazy=False) + sampler = nlp.data.sampler.FixedBucketSampler(seq_len, BATCH_SIZE, num_buckets=10, + ratio=0, shuffle=True) + batch_arg['batch_sampler'] = sampler + else: + batch_arg['batch_size'] = BATCH_SIZE + + dataset = FixedDataset(dataset) + return mx.gluon.data.DataLoader(dataset, num_workers=0, shuffle=False, + batchify_fn=batchify_fn, **batch_arg) + + return (preprocess_dataset(seg) for seg in ['train', 'dev', 'calib']) + + +def log_train(batch_id, batch_num, metric, step_loss, epoch_id, learning_rate): + """Generate and print out the log message for training. 
""" + metric_nm, metric_val = metric.get() + if not isinstance(metric_nm, list): + metric_nm, metric_val = [metric_nm], [metric_val] + + train_str = '[Epoch %d Batch %d/%d] loss=%.4f, lr=%.7f, metrics:' + \ + ','.join([i + ':%.4f' for i in metric_nm]) + logging.info(train_str, epoch_id, batch_id, batch_num, step_loss / TRAINING_LOG_INTERVAL, + learning_rate, *metric_val) + + +def finetune(model, train_dataloader, dev_dataloader, output_dir_path): + model.classifier.initialize(init=mx.init.Normal(0.02), ctx=CTX) + + all_model_params = model.collect_params() + optimizer_params = {'learning_rate': LR, 'epsilon': EPSILON, 'wd': 0.01} + trainer = mx.gluon.Trainer(all_model_params, 'bertadam', optimizer_params, + update_on_kvstore=False) + epochs = 9999 if TRAINING_STEPS else EPOCH_NUMBER + batches_in_epoch = TRAINING_STEPS if TRAINING_STEPS else int(len(train_dataloader) / ACCUMULATE) + num_train_steps = batches_in_epoch * epochs + + logging.info('training steps=%d', num_train_steps) + num_warmup_steps = int(num_train_steps * WARMUP_RATIO) + + # Do not apply weight decay on LayerNorm and bias terms + for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): + v.wd_mult = 0.0 + + # Collect differentiable parameters + params = [p for p in all_model_params.values() if p.grad_req != 'null'] + + # Set grad_req if gradient accumulation is required + if ACCUMULATE > 1: + for p in params: + p.grad_req = 'add' + + # track best eval score + metric = METRIC() + metric_history = [] + best_metric = None + patience = EARLY_STOP + + step_num = 0 + epoch_id = 0 + finish_flag = False + while epoch_id < epochs and not finish_flag and (not EARLY_STOP or patience > 0): + epoch_id += 1 + metric.reset() + step_loss = 0 + tic = time.time() + all_model_params.zero_grad() + + for batch_id, batch in enumerate(train_dataloader): + batch_id += 1 + # learning rate schedule + if step_num < num_warmup_steps: + new_lr = LR * step_num / num_warmup_steps + else: + non_warmup_steps = step_num - num_warmup_steps + offset = non_warmup_steps / (num_train_steps - num_warmup_steps) + new_lr = LR - offset * LR + trainer.set_learning_rate(new_lr) + + # forward and backward + with mx.autograd.record(): + input_ids, segment_ids, valid_length, label = batch + input_ids = input_ids.as_in_context(CTX) + valid_length = valid_length.as_in_context(CTX).astype('float32') + label = label.as_in_context(CTX) + out = model(input_ids, segment_ids.as_in_context(CTX), valid_length) + ls = LOSS_FUNCTION(out, label).mean() + ls.backward() + + # update + if ACCUMULATE <= 1 or batch_id % ACCUMULATE == 0: + trainer.allreduce_grads() + nlp.utils.clip_grad_global_norm(params, 1) + trainer.update(ACCUMULATE) + step_num += 1 + if ACCUMULATE > 1: + # set grad to zero for gradient accumulation + all_model_params.zero_grad() + + step_loss += ls.asscalar() + label = label.reshape((-1)) + metric.update([label], [out]) + if batch_id % TRAINING_LOG_INTERVAL == 0: + log_train(batch_id, batches_in_epoch, metric, step_loss, epoch_id, + trainer.learning_rate) + step_loss = 0 + if step_num >= num_train_steps: + logging.info('Finish training step: %d', step_num) + finish_flag = True + break + mx.nd.waitall() + + # inference on dev data + metric_val = evaluate(model, dev_dataloader) + if best_metric is None or metric_val >= best_metric: + best_metric = metric_val + patience = EARLY_STOP + else: + if EARLY_STOP is not None: + patience -= 1 + metric_history.append((epoch_id, METRIC().name, metric_val)) + print('Results of evaluation on dev dataset: 
{}:{}'.format(METRIC().name, metric_val)) + + # save params + ckpt_name = 'model_bert_{}_{}.params'.format(TASK_NAME, epoch_id) + params_path = (output_dir_path / ckpt_name) + + model.save_parameters(str(params_path)) + logging.info('params saved in: %s', str(params_path)) + toc = time.time() + logging.info('Time cost=%.2fs', toc - tic) + + # we choose the best model assuming higher score stands for better model quality + metric_history.sort(key=lambda x: x[2], reverse=True) + best_epoch = metric_history[0] + ckpt_name = 'model_bert_{}_{}.params'.format(TASK_NAME, best_epoch[0]) + metric_str = 'Best model at epoch {}. Validation metrics: {}:{}'.format(*best_epoch) + logging.info(metric_str) + + model.load_parameters(str(output_dir_path / ckpt_name), ctx=CTX, cast_dtype=True) + return model + + +def evaluate(model, dataloader): + metric = METRIC() + for batch in dataloader: + input_ids, segment_ids, valid_length, label = batch + input_ids = input_ids.as_in_context(CTX) + segment_ids = segment_ids.as_in_context(CTX) + valid_length = valid_length.as_in_context(CTX) + label = label.as_in_context(CTX).reshape((-1)) + + out = model(input_ids, segment_ids, valid_length) + metric.update([label], [out]) + + metric_name, metric_val = metric.get() + return metric_val + + +def native_quantization(model, calib_dataloader, dev_dataloader): + quantized_model = quantize_net_v2(model, + quantize_mode='smart', + calib_data=calib_dataloader, + calib_mode='naive', + num_calib_examples=BATCH_SIZE*10) + print('Native quantization results: {}'.format(evaluate(quantized_model, dev_dataloader))) + return quantized_model diff --git a/example/quantization_inc/BERT_MRPC/main.py b/example/quantization_inc/BERT_MRPC/main.py new file mode 100644 index 000000000000..67d7b282237a --- /dev/null +++ b/example/quantization_inc/BERT_MRPC/main.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from pathlib import Path +from functools import partial + +import details +from neural_compressor.experimental import Quantization, common + +# constants +INC_CONFIG_PATH = Path('./bert.yaml').resolve() +PARAMS_PATH = Path('./bert_mrpc.params').resolve() +OUTPUT_DIR_PATH = Path('./output/').resolve() +OUTPUT_MODEL_PATH = OUTPUT_DIR_PATH/'quantized_model' +OUTPUT_DIR_PATH.mkdir(parents=True, exist_ok=True) + +# Prepare the dataloaders (calib_dataloader is same as train_dataloader but without shuffling) +train_dataloader, dev_dataloader, calib_dataloader = details.preprocess_data() + +# Get the model +model = details.BERTModel(details.BACKBONE, dropout=0.1, num_classes=details.NUM_CLASSES) +model.hybridize(static_alloc=True) + +# finetune or load the parameters of already finetuned model +if not PARAMS_PATH.exists(): + model = details.finetune(model, train_dataloader, dev_dataloader, OUTPUT_DIR_PATH) + model.save_parameters(str(PARAMS_PATH)) +else: + model.load_parameters(str(PARAMS_PATH), ctx=details.CTX, cast_dtype=True) + +# run INC +calib_dataloader.batch_size = details.BATCH_SIZE +eval_func = partial(details.evaluate, dataloader=dev_dataloader) + +quantizer = Quantization(str(INC_CONFIG_PATH)) # 1. Config file +quantizer.model = common.Model(model) # 2. Model to be quantized +quantizer.calib_dataloader = calib_dataloader # 3. Calibration dataloader +quantizer.eval_func = eval_func # 4. Evaluation function +quantized_model = quantizer.fit().model + +# save the quantized model +quantized_model.export(str(OUTPUT_MODEL_PATH)) From 1c0eba125f0b70371498e1b8bcc88ec5c73fa12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Tue, 8 Feb 2022 14:14:44 +0100 Subject: [PATCH 07/12] Add results from aws machine (with VNNI instructions) --- .../backend/mkldnn/mkldnn_quantization.md | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 009f5c78e209..0c97293ffb91 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -497,27 +497,34 @@ def native_quantization(model, calib_dataloader, dev_dataloader): return quantized_model ``` -For complete code, see this example on the [official GitHub repository](TODO link). +For complete code, see this example on the [official GitHub repository](https://github.com/apache/incubator-mxnet/tree/v1.x/example/quantization_inc/BERT_MRPC). 
#### Results: -Results of the f32 model on the dev split: Accuracy = 0.8529, F1 = 0.8940 +Environment: +- c6i.16xlarge Amazon EC2 instance +- Ubuntu 20.04.3 LTS +- MXNet 1.9 +- INC 1.9.1 -Results of quantized models on the dev split: +Results on the dev split: -| Quantization method | Accuracy | F1 | Relative accuracy loss [%] | Calibration/tuning time [s] | -|:----------------------------:|:--------:|:------:|:--------------------------:|:----------------------------:| -| Native 'naive', 10 batches | 0.8309 | 0.8799 | 2.5794 | 35 | -| Native 'naive', 20 batches | 0.8284 | 0.8783 | 2.8725 | 65 | -| Native 'entropy', 10 batches | 0.7524 | 0.7856 | 11.7833 | 60 | -| Native 'entropy', 20 batches | 0.7107 | 0.7388 | 16.6725 | 122 | -| INC, 'basic' | 0.8480 | 0.8918 | 0.5745 | 352 | -| INC, 'bayesian' | 0.8529 | 0.8935 | 0 | 242 | -| INC, 'mse' | 0.8456 | 0.8957 | 0.8559 | 703 | +| Quantization method | Accuracy | F1 | Relative accuracy loss [%] | Calibration/tuning time [s] | Speedup | +|:----------------------------:|:--------:|:------:|:--------------------------:|:----------------------------:|:-------:| +| **No quantization (f32)** | **0.8529** | **0.8956** | **0** | **0** | **1.0** | +| Native 'naive', 10 batches | 0.8259 | 0.8775 | 3.1657 | 31 | 1.3811 | +| Native 'naive', 20 batches | 0.8210 | 0.8731 | 3.7402 | 58 | 1.3866 | +| Native 'entropy', 10 batches | 0.8064 | 0.8557 | 5.4520 | 37 | 1.3789 | +| Native 'entropy', 20 batches | 0.8137 | 0.8624 | 4.5961 | 67 | 1.3460 | +| INC, 'basic' | 0.8456 | 0.8889 | 0.8559 | 197 | 1.4418 | +| INC, 'bayesian' | 0.8529 | 0.8888 | 0 | 129 | 1.4275 | +| INC, 'mse' | 0.8480 | 0.8954 | 0.5745 | 974 | 0.9642 | -We can see that all INC strategies found configurations meeting the 1% relative accuracy loss criterion. +All INC strategies found configurations meeting the 1% relative accuracy loss criterion. Only the `mse` strategy struggled, taking the longest time while also being slower than the f32 model. -Configuration generated by INC with `basic` strategy: +Although these results may suggest that the `mse` strategy is the worst and the `bayesian` strategy is the best, different strategies may give better results for specific models and tasks. Usually the `basic` strategy is the most stable one. 
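The speedup column above is the ratio of FP32 to INT8 inference time on the same data. The exact benchmarking code is not part of this document; a rough sketch of how such a number can be reproduced, reusing `details.evaluate` and the objects defined in the main script (the warm-up and repeat counts are arbitrary choices):

```python
import time
import mxnet as mx

def measure(net, dataloader, warmup=1, repeats=3):
    # average wall time of a full pass over the dataloader
    for _ in range(warmup):
        details.evaluate(net, dataloader)
    mx.nd.waitall()
    start = time.time()
    for _ in range(repeats):
        details.evaluate(net, dataloader)
    mx.nd.waitall()
    return (time.time() - start) / repeats

fp32_time = measure(model, dev_dataloader)
int8_time = measure(quantized_model, dev_dataloader)
print('speedup: {:.4f}'.format(fp32_time / int8_time))
```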
+ +Here is an example of a configuration generated by INC with the `basic` strategy: - Layers quantized using min-max (`naive`) calibration algorithm: ``` From 8d7f19c655da6f69ed28c7ccc3c7bc2b7a358616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Tue, 8 Feb 2022 14:23:15 +0100 Subject: [PATCH 08/12] Small fix --- .../performance/backend/mkldnn/mkldnn_quantization.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 0c97293ffb91..cc01143642b2 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -503,7 +503,7 @@ For complete code, see this example on the [official GitHub repository](https:// Environment: - c6i.16xlarge Amazon EC2 instance -- Ubuntu 20.04.3 LTS +- Ubuntu 20.04 LTS - MXNet 1.9 - INC 1.9.1 @@ -520,9 +520,7 @@ Results on the dev split: | INC, 'bayesian' | 0.8529 | 0.8888 | 0 | 129 | 1.4275 | | INC, 'mse' | 0.8480 | 0.8954 | 0.5745 | 974 | 0.9642 | -All INC strategies found configurations meeting the 1% relative accuracy loss criterion. Only the `mse` strategy struggled, taking the longest time while also being slower than the f32 model. - -Although these results may suggest that the `mse` strategy is the worst and the `bayesian` strategy is the best, different strategies may give better results for specific models and tasks. Usually the `basic` strategy is the most stable one. +All INC strategies found configurations meeting the 1% relative accuracy loss criterion. Only the `mse` strategy struggled, taking the longest time and generating configuration that is slower than the f32 model. Although these results may suggest that the `mse` strategy is the worst and the `bayesian` strategy is the best, different strategies may give better results for specific models and tasks. Usually the `basic` strategy is the most stable one. Here is an example of a configuration generated by INC with the `basic` strategy: From 11ea83d3e493a4477083133a95c67902a4b92d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Tue, 8 Feb 2022 16:26:22 +0100 Subject: [PATCH 09/12] Fix mxnet installation instruction --- .../performance/backend/mkldnn/mkldnn_quantization.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index cc01143642b2..641d6ce6b345 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -27,9 +27,10 @@ Installing MXNet with MKLDNN backend is an easy and essential process. 
You can f ``` # release version -pip install mxnet-mkl +pip install mxnet + # nightly version -pip install mxnet-mkl --pre +pip install --pre "mxnet<2" -f https://dist.mxnet.io/python ``` ## Image Classification Demo From a876903c730da29e5e43d03d90647c26fc40a4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Wed, 9 Feb 2022 11:48:51 +0100 Subject: [PATCH 10/12] Review fixes --- .../backend/mkldnn/mkldnn_quantization.md | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 641d6ce6b345..7a8ce00ca873 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -29,7 +29,7 @@ Installing MXNet with MKLDNN backend is an easy and essential process. You can f # release version pip install mxnet -# nightly version +# latest nightly development version pip install --pre "mxnet<2" -f https://dist.mxnet.io/python ``` @@ -287,7 +287,7 @@ Most tuning strategies will try different configurations on an evaluation datase ## Configuration file -Quantization tuning process can be customized in the yaml configuration file. Here is a simple example: +Quantization tuning process can be customized in the yaml configuration file. Below is a simple example: ```yaml # cnn.yaml @@ -390,16 +390,16 @@ quantizer = Quantization("./cnn.yaml") quantizer.model = resnet18 quantizer.calib_dataloader = data quantizer.eval_func = eval_func -qnet = quantizer().model +qnet = quantizer.fit().model ``` Since this model already achieves good accuracy using native quantization (less than 1% accuracy drop), for the given configuration file, INC will end on the first configuration, quantizing all layers using `naive` calibration mode for each. To see the true potential of INC, we need a model which suffers from a larger accuracy drop after quantization. ### Quantizing BERT -This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model and task specific boilerplate has been moved to the `details.py` file. +This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the script, model and task specific boilerplate code has been moved to the `details.py` file. 
-Here is the configuration file: +This is the configuration file for this example: ```yaml version: 1.0 @@ -503,12 +503,12 @@ For complete code, see this example on the [official GitHub repository](https:// #### Results: Environment: -- c6i.16xlarge Amazon EC2 instance +- c6i.16xlarge Amazon EC2 instance (Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz) - Ubuntu 20.04 LTS - MXNet 1.9 - INC 1.9.1 -Results on the dev split: +Results on the validation dataset: | Quantization method | Accuracy | F1 | Relative accuracy loss [%] | Calibration/tuning time [s] | Speedup | |:----------------------------:|:--------:|:------:|:--------------------------:|:----------------------------:|:-------:| @@ -527,17 +527,17 @@ Here is an example of a configuration generated by INC with the `basic` strategy - Layers quantized using min-max (`naive`) calibration algorithm: ``` - {'bertclassifier0_dropout0_fwd', 'bertencoder0_layernorm0_layernorm0', 'bertencoder0_transformer0_dotproductselfattentioncell0_dropout0_fwd', ..., 'sg_mkldnn_fully_connected_0', 'sg_mkldnn_fully_connected_1', ..., 'sg_mkldnn_selfatt_valatt_7', 'sg_mkldnn_selfatt_valatt_9'} + {'bertclassifier0_dropout0_fwd', 'bertencoder0_layernorm0_layernorm0', 'bertencoder0_transformer0_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer0_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer0_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer0_layernorm0_layernorm0', 'bertencoder0_transformer0_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer10_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer10_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer10_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer10_layernorm0_layernorm0', 'bertencoder0_transformer10_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer11_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer11_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer11_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer11_layernorm0_layernorm0', 'bertencoder0_transformer1_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer1_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer1_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer1_layernorm0_layernorm0', 'bertencoder0_transformer1_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer2_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer2_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer2_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer2_layernorm0_layernorm0', 'bertencoder0_transformer2_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer3_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer3_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer3_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer3_layernorm0_layernorm0', 'bertencoder0_transformer3_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer4_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer4_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer4_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer4_layernorm0_layernorm0', 'bertencoder0_transformer4_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer5_dotproductselfattentioncell0_dropout0_fwd', 
'bertencoder0_transformer5_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer5_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer5_layernorm0_layernorm0', 'bertencoder0_transformer5_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer6_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer6_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer6_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer6_layernorm0_layernorm0', 'bertencoder0_transformer6_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer7_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer7_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer7_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer7_layernorm0_layernorm0', 'bertencoder0_transformer7_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer8_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer8_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer8_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer8_layernorm0_layernorm0', 'bertencoder0_transformer8_positionwiseffn0_layernorm0_layernorm0', 'bertencoder0_transformer9_dotproductselfattentioncell0_dropout0_fwd', 'bertencoder0_transformer9_dotproductselfattentioncell0_reshape3', 'bertencoder0_transformer9_dotproductselfattentioncell0_reshape7', 'bertencoder0_transformer9_layernorm0_layernorm0', 'bertencoder0_transformer9_positionwiseffn0_layernorm0_layernorm0', 'bertmodel0_reshape0', 'sg_mkldnn_fully_connected_0', 'sg_mkldnn_fully_connected_1', 'sg_mkldnn_fully_connected_11', 'sg_mkldnn_fully_connected_12', 'sg_mkldnn_fully_connected_13', 'sg_mkldnn_fully_connected_15', 'sg_mkldnn_fully_connected_16', 'sg_mkldnn_fully_connected_17', 'sg_mkldnn_fully_connected_19', 'sg_mkldnn_fully_connected_20', 'sg_mkldnn_fully_connected_21', 'sg_mkldnn_fully_connected_23', 'sg_mkldnn_fully_connected_24', 'sg_mkldnn_fully_connected_25', 'sg_mkldnn_fully_connected_27', 'sg_mkldnn_fully_connected_28', 'sg_mkldnn_fully_connected_29', 'sg_mkldnn_fully_connected_3', 'sg_mkldnn_fully_connected_31', 'sg_mkldnn_fully_connected_32', 'sg_mkldnn_fully_connected_33', 'sg_mkldnn_fully_connected_35', 'sg_mkldnn_fully_connected_36', 'sg_mkldnn_fully_connected_37', 'sg_mkldnn_fully_connected_39', 'sg_mkldnn_fully_connected_4', 'sg_mkldnn_fully_connected_40', 'sg_mkldnn_fully_connected_41', 'sg_mkldnn_fully_connected_43', 'sg_mkldnn_fully_connected_44', 'sg_mkldnn_fully_connected_45', 'sg_mkldnn_fully_connected_47', 'sg_mkldnn_fully_connected_48', 'sg_mkldnn_fully_connected_49', 'sg_mkldnn_fully_connected_5', 'sg_mkldnn_fully_connected_7', 'sg_mkldnn_fully_connected_8', 'sg_mkldnn_fully_connected_9', 'sg_mkldnn_fully_connected_eltwise_10', 'sg_mkldnn_fully_connected_eltwise_14', 'sg_mkldnn_fully_connected_eltwise_18', 'sg_mkldnn_fully_connected_eltwise_2', 'sg_mkldnn_fully_connected_eltwise_22', 'sg_mkldnn_fully_connected_eltwise_26', 'sg_mkldnn_fully_connected_eltwise_30', 'sg_mkldnn_fully_connected_eltwise_34', 'sg_mkldnn_fully_connected_eltwise_38', 'sg_mkldnn_fully_connected_eltwise_42', 'sg_mkldnn_fully_connected_eltwise_46', 'sg_mkldnn_fully_connected_eltwise_6'} ``` - Layers quantized using KL (`entropy`) calibration algorithm: ``` - {'sg_mkldnn_selfatt_qk_2', 'sg_mkldnn_selfatt_qk_4', ..., 'sg_mkldnn_selfatt_qk_20', 'sg_mkldnn_selfatt_qk_22'} + {'sg_mkldnn_selfatt_qk_0', 'sg_mkldnn_selfatt_qk_10', 'sg_mkldnn_selfatt_qk_12', 'sg_mkldnn_selfatt_qk_14', 
'sg_mkldnn_selfatt_qk_16', 'sg_mkldnn_selfatt_qk_18', 'sg_mkldnn_selfatt_qk_2', 'sg_mkldnn_selfatt_qk_20', 'sg_mkldnn_selfatt_qk_22', 'sg_mkldnn_selfatt_qk_4', 'sg_mkldnn_selfatt_qk_6', 'sg_mkldnn_selfatt_qk_8', 'sg_mkldnn_selfatt_valatt_1', 'sg_mkldnn_selfatt_valatt_11', 'sg_mkldnn_selfatt_valatt_13', 'sg_mkldnn_selfatt_valatt_15', 'sg_mkldnn_selfatt_valatt_17', 'sg_mkldnn_selfatt_valatt_19', 'sg_mkldnn_selfatt_valatt_21', 'sg_mkldnn_selfatt_valatt_23', 'sg_mkldnn_selfatt_valatt_3', 'sg_mkldnn_selfatt_valatt_5', 'sg_mkldnn_selfatt_valatt_7', 'sg_mkldnn_selfatt_valatt_9'} ``` - Layers excluded from quantization: ``` - {'sg_mkldnn_fully_connected_39'} + {'sg_mkldnn_fully_connected_43'} ``` ## Tips From 2716b89465481286b33165845b577c5584303c2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Thu, 10 Feb 2022 11:00:54 +0100 Subject: [PATCH 11/12] Review fixes --- .../performance/backend/mkldnn/mkldnn_quantization.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index 7a8ce00ca873..ae86386a5ce3 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -260,7 +260,7 @@ MXNet also supports deploy quantized models with C++. Refer [MXNet C++ Package]( # Improving accuracy with Intel® Neural Compressor -The accuracy of a model can decrease as a result of quantization. When the accuracy drop is significant, we can try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.) but for bigger models, this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using several tuning heuristics, finding the quantization configuration that satisfies the specified accuracy requirement. +The accuracy of a model can decrease as a result of quantization. When the accuracy drop is significant, we can try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.), but for bigger models this might prove to be a difficult and time consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using several tuning heuristics, which aim to find the quantization configuration that satisfies the specified accuracy requirement. **NOTE:** @@ -314,7 +314,7 @@ tuning: We are using the `basic` strategy, but you could also try out different ones. [Here](https://github.com/intel/neural-compressor/blob/master/docs/tuning_strategies.md) you can find a list of strategies available in INC and details of how they work. You can also add your own strategy if the existing ones do not suit your needs. -Since the value of `timeout` is 0, INC will run until it finds a configuration that satisfy the accuracy criterion and then exit. Depending on the strategy this may not be ideal, as sometimes it would be better to further explore the tuning space to find a superior configuration both in terms of accuracy and speed. To achive this, we can set a specific timeout - how long (in seconds) do we want INC to run. 
+Since the value of `timeout` is 0, INC will run until it finds a configuration that satisfies the accuracy criterion and then exit. Depending on the strategy this may not be ideal, as sometimes it would be better to further explore the tuning space to find a superior configuration both in terms of accuracy and speed. To achieve this, we can set a specific `timeout` value, which will tell INC how long (in seconds) it should run. For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the official INC repo. Keep in mind that only the `post training quantization` is currently supported for MXNet. @@ -328,7 +328,7 @@ In general, Intel® Neural Compressor requires 4 elements in order to run: ### Quantizing ResNet -The previous sections described how to quantize ResNet using the native MXNet quantization. This example shows how we can achieve the same (plus the auto-tuning) using INC. +The previous sections described how to quantize ResNet using the native MXNet quantization. This example shows how we can achieve the same (with the auto-tuning) using INC. 1. Get the model @@ -397,7 +397,7 @@ Since this model already achieves good accuracy using native quantization (less ### Quantizing BERT -This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the script, model and task specific boilerplate code has been moved to the `details.py` file. +This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model and task specific boilerplate has been moved to the `details.py` file. This is the configuration file for this example: ```yaml @@ -546,7 +546,7 @@ Here is an example of a configuration generated by INC with the `basic` strategy ```python from neural_compressor.utils.utility import recover - quantized_model = recover(f32_model, 'nc_workspace//history.snapshot', configuration_idx) + quantized_model = recover(f32_model, 'nc_workspace//history.snapshot', configuration_idx).model ``` From 91e301eb0976285b639c628b6382a03fd3d0bdec Mon Sep 17 00:00:00 2001 From: bgawrych Date: Thu, 17 Feb 2022 10:09:05 +0100 Subject: [PATCH 12/12] Update docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md Co-authored-by: bartekkuncer --- .../tutorials/performance/backend/mkldnn/mkldnn_quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md index ae86386a5ce3..b7436ffaa013 100644 --- a/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md +++ b/docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md @@ -397,7 +397,7 @@ Since this model already achieves good accuracy using native quantization (less ### Quantizing BERT -This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduce a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model and task specific boilerplate has been moved to the `details.py` file. 
+This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduces a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model- and task-specific boilerplate has been moved to the `details.py` file.
 
 This is the configuration file for this example:
 ```yaml