
Commit

Refactor bias recalculation loops
bgawrych committed May 19, 2021
1 parent e6f24d9 commit a301cd5
Showing 2 changed files with 20 additions and 19 deletions.
29 changes: 15 additions & 14 deletions scripts/bert/bertpass.cc
@@ -496,22 +496,22 @@ MXReturnValue MHAInterleave(mxnet::ext::Graph *g,
     }
 #endif
     // concatenate bias terms
-    int counter = 0;
-    for (int e = 0; e < num_heads * 3; e += 3) {
-      for (int h = e * head_dimension; h < e * head_dimension + head_dimension; h++) {
-        qkv_bias_data[h] = query_bias_data[counter++];
+    int query_cnt = 0;
+    int key_cnt = 0;
+    int value_cnt = 0;
+    const int qkv_heads = num_heads * 3;
+    for (int e = 0; e < qkv_heads; e += 3) {
+      const int query_head_offset = e * head_dimension;
+      for (int h = query_head_offset; h < query_head_offset + head_dimension; h++) {
+        qkv_bias_data[h] = query_bias_data[query_cnt++];
       }
-    }
-    counter = 0;
-    for (int e = 1; e < num_heads * 3; e += 3) {
-      for (int h = e * head_dimension; h < e * head_dimension + head_dimension; h++) {
-        qkv_bias_data[h] = key_bias_data[counter++];
+      const int key_head_offset = (e + 1) * head_dimension;
+      for (int h = key_head_offset; h < key_head_offset + head_dimension; h++) {
+        qkv_bias_data[h] = key_bias_data[key_cnt++];
       }
-    }
-    counter = 0;
-    for (int e = 2; e < num_heads * 3; e += 3) {
-      for (int h = e * head_dimension; h < e * head_dimension + head_dimension; h++) {
-        qkv_bias_data[h] = value_bias_data[counter++];
+      const int value_head_offset = (e + 2) * head_dimension;
+      for (int h = value_head_offset; h < value_head_offset + head_dimension; h++) {
+        qkv_bias_data[h] = value_bias_data[value_cnt++];
       }
     }
     // set connection with new input
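
For readers skimming the hunk above: the commit folds three separate head loops (one each for query, key, and value biases) into a single pass that writes all three bias slices per head. Below is a minimal, self-contained sketch of the interleaved layout that merged loop produces; the buffer sizes, std::vector types, and main() harness are illustrative stand-ins, not the actual bertpass.cc code, which operates on raw tensor pointers.

    // Minimal standalone sketch of the interleaved QKV bias layout.
    // num_heads/head_dimension values and std::vector buffers are
    // illustrative, not taken from bertpass.cc.
    #include <cassert>
    #include <vector>

    int main() {
      const int num_heads = 2;       // illustrative
      const int head_dimension = 3;  // illustrative
      // Per-head biases laid out contiguously: [h0 | h1 | ...]
      std::vector<float> query_bias(num_heads * head_dimension, 1.0f);
      std::vector<float> key_bias(num_heads * head_dimension, 2.0f);
      std::vector<float> value_bias(num_heads * head_dimension, 3.0f);
      std::vector<float> qkv_bias(num_heads * 3 * head_dimension);

      // Single pass over heads: slot e holds Q, e+1 holds K, e+2 holds V,
      // yielding the interleaved pattern [Q0 | K0 | V0 | Q1 | K1 | V1 | ...].
      int query_cnt = 0, key_cnt = 0, value_cnt = 0;
      const int qkv_heads = num_heads * 3;
      for (int e = 0; e < qkv_heads; e += 3) {
        const int q_off = e * head_dimension;
        for (int h = q_off; h < q_off + head_dimension; h++)
          qkv_bias[h] = query_bias[query_cnt++];
        const int k_off = (e + 1) * head_dimension;
        for (int h = k_off; h < k_off + head_dimension; h++)
          qkv_bias[h] = key_bias[key_cnt++];
        const int v_off = (e + 2) * head_dimension;
        for (int h = v_off; h < v_off + head_dimension; h++)
          qkv_bias[h] = value_bias[value_cnt++];
      }

      assert(qkv_bias[0] == 1.0f);                   // Q, head 0
      assert(qkv_bias[head_dimension] == 2.0f);      // K, head 0
      assert(qkv_bias[2 * head_dimension] == 3.0f);  // V, head 0
      assert(qkv_bias[3 * head_dimension] == 1.0f);  // Q, head 1
      return 0;
    }

The single loop also lets each source buffer keep its own running counter (query_cnt, key_cnt, value_cnt) instead of resetting one shared counter between passes.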
@@ -554,6 +554,7 @@ bool CheckIfSoftmaxLengthPattern(Node *softmax_node) {
   } else {
     return false;
   }
+
   auto bcast_axis_node = reshape_node->inputs[0].node;
   if (bcast_axis_node->op == "broadcast_axis") {
     std::string axis = bcast_axis_node->attrs["axis"];
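
The second hunk only inserts a blank line inside CheckIfSoftmaxLengthPattern, but for context, the surrounding code walks the graph upward from a reshape node and matches the producing op by name and attributes. A rough sketch of that matching step is below; the Node struct is a hypothetical minimal stand-in, not the real mxnet::ext::Graph node type (whose input entries carry a .node field, as the diff shows).

    // Hypothetical minimal graph node, for illustration only.
    #include <map>
    #include <string>
    #include <vector>

    struct Node {
      std::string op;
      std::map<std::string, std::string> attrs;
      std::vector<Node*> inputs;  // producers of each input edge
    };

    // True when the node's first input is produced by broadcast_axis
    // and that producer carries an "axis" attribute to inspect next.
    bool FirstInputIsBroadcastAxis(const Node* reshape_node) {
      if (reshape_node->inputs.empty()) return false;
      const Node* bcast_axis_node = reshape_node->inputs[0];
      if (bcast_axis_node->op != "broadcast_axis") return false;
      return bcast_axis_node->attrs.count("axis") > 0;
    }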
10 changes: 5 additions & 5 deletions scripts/bert/index.rst
@@ -251,16 +251,16 @@ Question Answering
 | SQuAD 1.1 | bert_12_768_12 | 81.18 | 80.32 | 88.58 | 88.10 |`command <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/bert/calibration_squad1.1_base_mx1.6.0b20200125.sh>`__ |
 +-----------+-------------------+---------+---------+---------+---------+----------------------------------------------------------------------------------------------------------------------------+

-For all model settings above, we use a subset of evaluation dataset for calibration.
+For all model settings above, subset of evaluation dataset for calibration was used.

-We recommend to use optimization graph passes, which boost performance of inference even more. To deploy calibrated model optimized with graph passes use
---custom_pass_lib [graph_pass_library_path] and --custom_passes [graph_passes_name] arguments. E.g.:
+Using optimization graph passes is recommended to boost performance of inference even more. To deploy calibrated model optimized with graph passes following arguments can be used
+--custom_pass_lib [graph_pass_library_path] and --custom_passes [graph_passes_name]. E.g.:

 .. code-block:: console

     $ python3 finetune_squad.py --only_calibration --model_parameters ./output_dir/net.params --custom_pass_lib bertpass_lib.so --custom_passes MaskSoftmax MHAInterleave

-Use setup.py script to build graph pass library.
+Graph pass library can be built with setup.py script.

 Pre-training from Scratch
 ~~~~~~~~~~~~~~~~~~~~~~~~~
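
The hunk above leaves the build step unspecified beyond "setup.py". As an illustration only, a setuptools-style extension build could look like the following; the exact commands accepted by scripts/bert/setup.py are not shown in this commit, so treat this as an assumption:

    $ cd scripts/bert
    $ python3 setup.py build_ext --inplace  # assumed setuptools-style invocation, expected to produce bertpass_lib.so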
@@ -357,7 +357,7 @@ Once the model is exported, you can import the model by setting --only_infer, an
 The batch size can be specified via --test_batch_size option, and accuracy can be checked setting --check_accuracy.

 When using GPU and data type FP16 (--dtype float16), we recommend to use MXNET_FC_TRUE_FP16=1 for boosting performance.
-Moreover, you can use a custom graph pass for BERT, via --custom_pass_lib [custom_pass_library] and --custom_passes [space_seperated_names_of_passes_to_apply], to improve the performance on GPU. To generate the pass you can run setup.py within the BERT scripts directory. These GPU optimizations require MXNet version 1.7 or higher.
+Moreover, custom graph pass for BERT can be used via --custom_pass_lib [custom_pass_library] and --custom_passes [space_seperated_names_of_passes_to_apply], to improve the performance on GPU. To generate the pass library setup.py script can be run within the BERT scripts directory. These GPU optimizations require MXNet version 1.7 or higher.


 BERT for Sentence or Tokens Embedding
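
Putting the options from this hunk together, an illustrative GPU FP16 inference invocation might look like the following; the batch size value and the pass names (reused from the calibration example earlier in the same file) are illustrative, not prescribed by the docs:

    $ MXNET_FC_TRUE_FP16=1 python3 finetune_squad.py --only_infer --dtype float16 \
          --test_batch_size 32 --check_accuracy \
          --custom_pass_lib bertpass_lib.so --custom_passes MaskSoftmax MHAInterleave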
