This repository is an adaptation of the bert repository.
The purpose of this repository is to visualize BERT's self-attention weights after it has been fine-tuned on the IMDb dataset. However, it can be extended to any text classification dataset by creating an appropriate DataProcessor class. See run_classifier.py for details.
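As a rough illustration of what such a processor might look like, the sketch below mirrors the method names used by the processors in the official BERT run_classifier.py (get_train_examples, get_dev_examples, get_test_examples, get_labels). The InputExample stand-in, the ImdbLikeProcessor name, the file names, and the two-column TSV layout (label, then text) are all assumptions made so the snippet runs on its own. In the real code you would subclass DataProcessor from run_classifier.py instead.

```python
import csv
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class InputExample:
    """Stand-in for run_classifier.InputExample (fields assumed to match)."""
    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None


class ImdbLikeProcessor:
    """Hypothetical processor for a binary sentiment dataset stored as TSV."""

    def get_train_examples(self, data_dir: str) -> List[InputExample]:
        return self._read_examples(os.path.join(data_dir, "train.tsv"), "train")

    def get_dev_examples(self, data_dir: str) -> List[InputExample]:
        return self._read_examples(os.path.join(data_dir, "dev.tsv"), "dev")

    def get_test_examples(self, data_dir: str) -> List[InputExample]:
        return self._read_examples(os.path.join(data_dir, "test.tsv"), "test")

    def get_labels(self) -> List[str]:
        # Assumed label encoding: "0" = negative, "1" = positive.
        return ["0", "1"]

    def _read_examples(self, path: str, set_type: str) -> List[InputExample]:
        examples = []
        with open(path, newline="", encoding="utf-8") as f:
            # Assumed layout: one example per row, label <TAB> text, no header.
            reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
            for i, row in enumerate(reader):
                label, text = row[0], row[1]
                examples.append(
                    InputExample(guid=f"{set_type}-{i}", text_a=text, label=label)
                )
        return examples


# Usage: write a tiny TSV file to a temporary directory and read it back.
data_dir = tempfile.mkdtemp()
with open(os.path.join(data_dir, "test.tsv"), "w", encoding="utf-8") as f:
    f.write("1\tgreat movie\n0\tterrible plot\n")

examples = ImdbLikeProcessor().get_test_examples(data_dir)
print(len(examples), examples[0].label, examples[0].text_a)
```

If your TSV files have a header row or a different column order, adjust _read_examples accordingly.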
- Create a TSV file for each of the IMDb training and test sets.
Refer to the imdb_data repo for instructions.
- Fine-tune BERT on the IMDb training set.
Refer to the official BERT repo for fine-tuning instructions.
Alternatively, you can skip this step by downloading the fine-tuned model from here.
The pre-trained model (BERT base uncased) used to perform the fine-tuning can also be downloaded from here.
- Visualize BERT's weights.
Refer to the BERT_viz_attention_imdb notebook for more details.
The forward pass has been modified to return a list of {layer_i: layer_i_attention_weights} dictionaries. The shape of layer_i_attention_weights is (batch_size, num_multihead_attn, max_seq_length, max_seq_length).
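To make the structure above concrete, here is a small NumPy mock of that list. It only imitates the shapes: the key naming ("layer_0", "layer_1", ...) and the toy dimensions are assumptions, and the real values are TensorFlow tensors produced by the model rather than random arrays. BERT base has 12 layers and 12 attention heads per layer.

```python
import numpy as np

# Toy dimensions (assumed); the fine-tuned model uses max_seq_length = 512.
batch_size, num_heads, max_seq_length, num_layers = 2, 12, 16, 12

rng = np.random.default_rng(0)
attentions = []
for i in range(num_layers):
    scores = rng.random((batch_size, num_heads, max_seq_length, max_seq_length))
    # Normalize each row so it sums to 1, as softmax attention weights do.
    weights = scores / scores.sum(axis=-1, keepdims=True)
    attentions.append({f"layer_{i}": weights})  # key format is an assumption

last_layer_weights = attentions[-1][f"layer_{num_layers - 1}"]
print(len(attentions), last_layer_weights.shape)
```

Each entry of the last two axes can be read as how much one token position attends to another.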
You can specify a function to process the above list by passing it as a parameter to the load_bert_model function in the explain.model module. The function's output is available as part of the result of the Estimator's predict call under the key 'attention'.
Currently, only two attention processor functions have been defined, namely average_last_layer_by_head and average_first_layer_by_head.
See explain.attention for implementation details.
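For intuition, an attention processor can be thought of as a function that takes the list of per-layer dictionaries and reduces it to something easier to plot. The NumPy sketch below is one reading based purely on the function names: pick a layer and average its weights over the head axis. The helper name average_layer_by_head and the toy input are hypothetical, and the actual functions in explain.attention operate on TensorFlow tensors and may do more (for example, selecting a particular token's row or renormalizing).

```python
import numpy as np


def average_layer_by_head(attentions, layer_index):
    """Average one layer's attention weights over the head axis.

    attentions: list of single-entry dicts {layer_name: array}, each array of
    shape (batch_size, num_heads, max_seq_length, max_seq_length).
    Returns an array of shape (batch_size, max_seq_length, max_seq_length).
    """
    (weights,) = attentions[layer_index].values()
    return weights.mean(axis=1)


# Toy input: 2 layers, batch of 1, 4 heads, sequence length 3, uniform attention.
toy_attentions = [
    {f"layer_{i}": np.full((1, 4, 3, 3), 1.0 / 3.0)} for i in range(2)
]

last_layer_avg = average_layer_by_head(toy_attentions, -1)   # last layer
first_layer_avg = average_layer_by_head(toy_attentions, 0)   # first layer
print(last_layer_avg.shape, first_layer_avg.shape)
```

Averaging over heads keeps each row a valid distribution over positions, which makes the result convenient to render as a heatmap.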
The fine-tuned model achieved an accuracy of 0.9407 on the test set.
Fine-tuning used the following hyperparameters:
- maximum sequence length: 512
- training batch size: 8
- learning rate: 3e-5
- number of epochs: 3
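To get a feel for what these settings imply, the sketch below estimates the number of optimizer steps using the same arithmetic as run_classifier.py in the official BERT repo. It assumes the full 25,000-review IMDb training split was used and that the warmup proportion was left at the script's default of 0.1; neither is stated above.

```python
num_train_examples = 25_000  # assumption: full IMDb training split
train_batch_size = 8
num_train_epochs = 3
warmup_proportion = 0.1      # assumption: run_classifier.py default

# Total training steps, then the portion spent on learning-rate warmup.
num_train_steps = int(num_train_examples / train_batch_size * num_train_epochs)
num_warmup_steps = int(num_train_steps * warmup_proportion)

print(num_train_steps, num_warmup_steps)  # 9375 937
```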