First, clone the repository:

```bash
git clone https://github.com/elephantmipt/compressors.git && cd compressors
```
To train a teacher model, run:

```bash
python examples/transformers/train_teacher.py
```
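Alternatively, you can run a hyperparameter sweep with Weights & Biases. `wandb sweep` only registers the sweep; start it with the `wandb agent` command it prints.

```bash
wandb sweep examples/transformers/configs/teacher.yml
```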
I trained two teacher models with this script on the AG News dataset and got the following accuracy:
Model | Accuracy |
---|---|
BERT Medium | 0.9454 |
BERT Small | 0.9429 |
Once the teacher is trained, you can use it to train a student network:

```bash
python examples/transformers/train_student.py --teacher-path bert_teacher/checkpoints/best.pth
```
Alternatively, tune the student's hyperparameters with a Weights & Biases sweep:

```bash
wandb sweep examples/transformers/configs/student.yml
```
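The internals of the student script are not described here. As a rough, hypothetical illustration of what distilling from a teacher usually involves, the sketch below implements the classic soft-target objective (Hinton et al., 2015) for a single example in plain Python. The function names and the `temperature`/`alpha` defaults are assumptions for illustration only; the loss that `train_student.py` actually optimizes may differ.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> list[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    temperature: float = 2.0,  # hypothetical default
    alpha: float = 0.5,        # hypothetical weight of the soft-target term
) -> float:
    """Hinton-style distillation loss for one example (illustrative sketch)."""
    # Hard-label term: cross-entropy of the student's ordinary (T=1) prediction.
    ce = -math.log(softmax(student_logits)[label])
    # Soft-label term: KL between the softened teacher and student distributions,
    # scaled by T^2 so its gradient magnitude stays comparable across temperatures.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kd = kl_divergence(p_teacher, p_student) * temperature ** 2
    return alpha * kd + (1 - alpha) * ce


if __name__ == "__main__":
    # Four logits, one per AG News class (values made up for illustration).
    teacher = [3.0, 0.5, -1.0, 0.2]
    student = [2.0, 1.0, -0.5, 0.0]
    print(kd_loss(student, teacher, label=0))
```

For AG News (4 classes), each logit list holds four scores. A higher `temperature` softens the teacher's distribution, so the student also learns from how the teacher ranks the wrong classes, not just from the correct label.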
Here are the results for the models I trained:
Teacher | Student | Accuracy (with KD) | Accuracy (without KD) | Improvement |
---|---|---|---|---|
BERT Medium | BERT Small | 0.9443 | 0.9429 | +0.0014 |
BERT Medium | BERT Medium | 0.9466 | 0.9454 | +0.0012 |
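The Improvement column is the accuracy with KD minus the accuracy without it. A quick check of that arithmetic, using the accuracies from the table (the `improvement` helper is just an illustrative name):

```python
def improvement(with_kd: float, without_kd: float) -> float:
    """Absolute accuracy gain from knowledge distillation."""
    return with_kd - without_kd


# (teacher, student) -> (accuracy with KD, accuracy without KD), from the table.
results = {
    ("BERT Medium", "BERT Small"): (0.9443, 0.9429),
    ("BERT Medium", "BERT Medium"): (0.9466, 0.9454),
}

for (teacher, student), (with_kd, without_kd) in results.items():
    gain = improvement(with_kd, without_kd)
    # Show the gain both as a fraction and in percentage points (pp).
    print(f"{teacher} -> {student}: {gain:+.4f} ({gain * 100:+.2f} pp)")
```

This prints gains of +0.0014 and +0.0012, i.e. 0.14 and 0.12 percentage points of accuracy.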