Authors: Ruiqi Zhong, Dhruba Ghosh, Dan Klein, Jacob Steinhardt
This GitHub repo contains the pretrained models, model predictions (data), and code for the experiments. If you have any questions, need additional information about the models/data, or want to request a new feature, feel free to send an email to [email protected].
The paper was accepted to ACL 2021 Findings. Paper link: https://arxiv.org/abs/2105.06020
You can download all of our pre-trained BERT models from here. We experimented with 5 different sizes (mini, small, medium, base, large) as described in the official BERT GitHub repo and stored them in the same format. We used the exact same pre-training code as the original BERT paper and similar training settings, except that we used the training corpus from this paper, reduced the context size from 512 to 128, and increased the number of training steps from 1M to 2M. For each size, we pre-trained 10 times with the identical procedure but different random seeds.
The directory for the medium-sized BERT model with pre-training seed 3 after pretraining for 2000000 steps is pretrained/medium/pretrain_seed3step2000000.
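If you want to iterate over all of these directories, here is a minimal sketch; it assumes every model follows the naming pattern above and that the pretraining seeds are numbered 1 through 10 (an assumption based on the examples in this README):
sizes = ['mini', 'small', 'medium', 'base', 'large']
for size in sizes:
    for seed in range(1, 11):  # assumption: pretraining seeds are numbered 1..10
        # directory naming pattern taken from the example above
        print('pretrained/{}/pretrain_seed{}step2000000'.format(size, seed))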
We release the model predictions on 3 different tasks (SST-2, MNLI, and QQP) for 5 different model sizes, each with 10 pretraining seeds and 5 finetuning seeds, as described in the paper. Additionally, we provide the model predictions for several out-of-domain datasets (e.g., SNLI/HANS for MNLI, TwitterPPDB for QQP) after training for 3.0, 3.33, and 3.67 epochs (in the paper we always finetune for 4 epochs). You can download them from here.
This folder contains the raw model predictions and the corresponding datapoints. For each of the MNLI, QQP, and SST-2 tasks, there is a folder with the following files/folders (we performed 5-fold cross-validation on SST-2):
data.json
contains the data used for fine-tuning (training) and prediction, where each datapoint is represented as a dictionary.
For example,
import json
data = json.load(open('qqp/data.json'))
print('Number of datapoints for model prediction', len(data['predict'])) # prints the number of datapoints that are used for evaluation
print(data['predict'][0]) # prints the first datapoint for evaluation
and the output is
Number of datapoints for model prediction 79497
{'guid': 'dev-0', 'label': '1',
'text_b': 'Where are the best places to eat in New York City?',
'text_a': 'Where are the best places to eat in New York City that have a great vibe?'}
Notice that the string before "-" in "guid" denotes the original data split, not our train/test split. For example, "train" might occur in the evaluation set, since we used part of the original training split for testing. In our paper, we used the "dev_matched" split for MNLI, "dev" for QQP, and "train" for SST-2. We downloaded our QQP data from here, and the NLI challenge set from here.
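As a quick sanity check, the following sketch (reusing the data loaded above) counts how many evaluation datapoints come from each original split, based on the prefix of "guid":
from collections import Counter
# count the original split names encoded in the "guid" prefix (the part before "-")
split_counts = Counter(d['guid'].split('-')[0] for d in data['predict'])
print(split_counts)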
[predict/train].tf_record
contain the tokenized data in the TensorFlow format that are used for fine-tuning and prediction.
size2hyperparam.json
contains the hyper-parameters used for the different model sizes.
results/
is the folder that contains the models' predictions, each stored as a .tsv file representing a matrix of dimension (number of datapoints, number of classes).
Each row contains the model's predicted probability for each class and corresponds to one datapoint in data['predict'].
results/slargep8f5epoch9over3.tsv
contains the predictions of the large-size model with pretraining seed 8 and finetuning seed 5, evaluated at 9/3 = 3 epochs.
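For example, here is a minimal sketch that decodes such a filename and loads the corresponding prediction matrix; it assumes the naming pattern shown above and that each .tsv file is tab-separated with one row of class probabilities per datapoint:
import re
import numpy as np

fname = 'results/slargep8f5epoch9over3.tsv'
# decode s<size>p<pretrain seed>f<finetune seed>epoch<num>over<den>.tsv
m = re.match(r's(\w+?)p(\d+)f(\d+)epoch(\d+)over(\d+)\.tsv', fname.split('/')[-1])
size, pretrain_seed, finetune_seed, num, den = m.groups()
print(size, pretrain_seed, finetune_seed, int(num) / int(den))  # large 8 5 3.0

probs = np.loadtxt(fname, delimiter='\t')  # shape: (number of datapoints, number of classes)
predicted_class = probs.argmax(axis=1)     # most probable class index for each datapoint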
We extract 3 types of "correctness tensors" from the .tsv and data.json files, with the command
python3 dump_correctness.py
The results are dumped into correctness/, correctness_p/, and ensemble_c/.
correctness/
: stores binary tensors: 1 if the model prediction is correct, 0 otherwise. For example,
import pickle as pkl
qqp_size2correctness, qqp_data = pkl.load(open('correctness/qqp.pkl', 'rb')) # qqp_data is exactly data['predict'] as mentioned above
print(qqp_size2correctness.keys()) # output: dict_keys(['mini', 'small', 'medium', 'base', 'large']). qqp_size2correctness is a mapping from model size to the correctness tensor
print(qqp_size2correctness['large'].shape) # output: (79497, 10, 5, 4). 79497 is the number of datapoints, 10/5 is the number of pretraining/finetuning seeds, 4 represents different checkpoints at [3, 3.33, 3.67, 4] epochs.
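Continuing the example above, here is a sketch of how this tensor can be aggregated (averaging the binary correctness values gives accuracy):
import numpy as np
correctness = np.asarray(qqp_size2correctness['large'])  # shape (79497, 10, 5, 4)
print('overall accuracy:', correctness.mean())
print('accuracy per checkpoint (3, 3.33, 3.67, 4 epochs):', correctness.mean(axis=(0, 1, 2)))
per_datapoint_acc = correctness.mean(axis=(1, 2, 3))  # fraction of runs that got each datapoint right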
correctness_p/
: same as correctness/, but stores the probability assigned to the correct class instead of a binary indicator.
ensemble_c/
: the correctness tensor using the last checkpoint only, after marginalizing over all fine-tuning seeds.
We use almost exactly the same code as the original BERT paper to fine-tune a pre-trained model. To convert the data into the TensorFlow format, run
python3 process_input.py
To finetune the pretrained models, run, for example,
python3 run_classifier.py --data_dir qqp/ --pretrain_seed 10 --model_size medium --dataorder_seed 10 --initialization_seed 10 --tpu_name lm-4
Run pip3 install -r requirements.txt
to install the required dependencies, and download the files to the corresponding folders.
The results.ipynb
Jupyter notebook computes the following:
- within/across pretraining seed difference
- decomposition of the 0/1 loss into bias, finetuning variance, and pretraining variance
- variance conditioned on bias
- squared loss decomposition using probability assigned to the correct label
- Our approach, with CDF visualization
- Comparison with the conventional Benjamini-Hochberg Procedure
- Example decaying instances
The decaying instances are stored in the decaying/
folder. Each .pkl file is a map from the data split name to a tuple whose elements are the following (a loading sketch appears after the list):
- a list of datapoints
- whether the datapoint belongs to the control group (random non-decaying instances) or the treatment group (decaying instances)
- Ruiqi Zhong's annotation of whether he thinks the label is correct, wrong, reasonable or unsure. (None means not annotated)
- instance difference as measured by Pearson-r correlation (numbers represent model sizes; smaller numbers correspond to smaller models)
- the same, measured with Spearman rank correlation
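A minimal sketch for inspecting one of these files (the filename below is hypothetical; substitute whichever .pkl you downloaded):
import pickle as pkl
split2decaying = pkl.load(open('decaying/qqp.pkl', 'rb'))  # hypothetical filename
print(split2decaying.keys())  # the data split names
for split_name, info in split2decaying.items():
    print(split_name, 'tuple length:', len(info))
    print('first datapoint:', info[0][0])  # the first element of the tuple is the list of datapoints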
Some results might be different from those in the paper due to random seeds; however, they should be close and lead to the same qualitative conclusions.