Text-Derived Knowledge Helps Vision: A Simple Cross-modal Distillation for Video-based Action Anticipation
-
Create the environment using the requirement.txt file in the teacher folder.
-
Download the content of the data folder from the google drive link
data
└───egtea_action_seq/
└───epic55_action_seq/
└───recipe1M_layers/
└───processed_data_dict.pt
└───vocab.txt
-
Download the content of the out folder from the google drive link. This also includes the different LMs pretrained (MLM) on the 1M Recipe dataset.
out
└───albert_pretrained/checkpoint-200000/
└───bert_pretrained/checkpoint-200000/
└───distilbert_pretrained/checkpoint-200000/
└───electra_pretrained/checkpoint-200000/
└───roberta_pretrained/checkpoint-200000/
└─... .
-
LM checkpoints pre-trained on action sequences derived from 1M recipe are in their respective folders
-
Code for pretraining LM on 1M-Recipe is at
code/{model_name}_pretraining.py
-
To finetune model on EGTEA-GAZE+ dataset
python ./code/egtea_finetuning.py \
  -model_type bert \
  -batch_size 16 \
  -num_epochs 5 \
  -max_len 512 \
  -checkpoint_path ./out/bert_pretrained/checkpoint-200000 \
  -weigh_classes True \
  -hist_len 15 \
  -gappy_hist True \
  -multi_task True \
  -sort_seg True

Arguments:
- -model_type: one of bert/roberta/distillbert/alberta/deberta/electra
- -batch_size: batch size
- -num_epochs: number of epochs
- -max_len: maximum number of tokens in the input; tokens beyond this number are truncated
- -checkpoint_path: path to the model checkpoint (for initialization) that was trained on 1M-Recipe through MLM
- -weigh_classes: for imbalanced data; if True, the loss is a weighted cross-entropy. This does not work with EPIC-55, as a few classes have 0 data instances
- -hist_len: context length, i.e. how many past actions are conditioned on to predict the action after the anticipation time
- -gappy_hist: EGTEA action sequences have gaps, because the action segments of a video are partitioned into train/test sets
- -multi_task: instead of predicting only the action, also predict the verb and noun
- -sort_seg: action segments in the training batch are sorted by their temporal order
-
To finetune model on EPIC-55 dataset
python ./code/epic55_finetuning.py \
  -model_type distillbert \
  -batch_size 16 \
  -num_epochs 8 \
  -max_len 512 \
  -hist_len 5 \
  -checkpoint_path ./out/distilbert_pretrained/checkpoint-200000 \
  -weigh_classes False \
  -multi_task True
-
For the EGTEA and EPIC-55 datasets, the arguments in the snippets above are the hyperparameters used to train the teacher and to report its performance.
-
Sample Slurm scripts can be found in
code/slurm scripts
-
Model (teacher) predictions for the test data can be found at the google drive link.
-
The predictions are saved as a list of dictionaries, where each element of the list has the following keys
<UID (unique segment ID), action_logit, LM_feature>
along with other information associated with the segment (UID), such as actionID, action history, etc.
-
These teacher predictions are then used to train the student, Anticipative Video Transformer (AVT), through knowledge distillation.
-
Reproducing teacher metrics reported in the paper: Download the model prediction folder
teacher/teacher_student_Predictions/
from link. Calculate the teacher performance metric reported in the paper by running the notebook code/logit_analysis.ipynb
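As a rough illustration of how a metric can be derived from the prediction records described above, the sketch below computes top-k accuracy from a list of dictionaries. It is not the notebook's code: the records are made up, the key names `action_logit` and `actionID` follow the description above but their exact spelling in the files is an assumption, and the notebook may compute other metrics.

```python
def top_k_accuracy(records, k=1, logit_key="action_logit", label_key="actionID"):
    """Fraction of records whose true action is among the k highest logits."""
    hits = 0
    for rec in records:
        logits = rec[logit_key]
        # Indices of the k classes with the largest logits.
        top_k = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
        hits += rec[label_key] in top_k
    return hits / len(records)


# Hypothetical records mimicking the described format (values are invented).
records = [
    {"UID": 101, "action_logit": [0.1, 2.3, -0.5], "actionID": 1},
    {"UID": 102, "action_logit": [1.7, 0.2, 0.9], "actionID": 2},
]

print(top_k_accuracy(records, k=1))  # 0.5
print(top_k_accuracy(records, k=2))  # 1.0
```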
-
The student training repo is in student. Our student training code is adapted from the AVT code base (link). As such, the DATA and other AVT model checkpoints should first be downloaded as explained in their documentation.
-
Set up the avt.yml Python environment.
-
Since most of the code is the same as AVT, their documentation can be used as a reference; below we describe where in the codebase our additions were made.
-
For all the experiments, we use exactly the same hyperparameters that the
AVT
authors used to train their model. -
Helper codes
-
Generating LM (teacher) pred logits/features https://github.com/sayontang/Action_Anticipation/blob/main/student/helpers/generate_preds/traindata.py
-
Processing result files
-
Get logits from the student https://github.com/sayontang/Action_Anticipation/blob/main/student/helpers/metrics_multiple.py
-
Compute verb/noun accuracy from action accuracy https://github.com/sayontang/Action_Anticipation/tree/main/student/helpers/egtea_verb_noun https://github.com/sayontang/Action_Anticipation/tree/main/student/helpers/ek55_verb_noun
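One way to derive verb and noun accuracy from action-level predictions is sketched below. It assumes every action class maps to exactly one (verb, noun) pair; the function name, the mapping, and the sample values are hypothetical and are not taken from the helper folders linked above.

```python
def verb_noun_accuracy(pred_actions, true_actions, action_to_verb_noun):
    """Map predicted and true action IDs to (verb, noun) and score each part."""
    n = len(true_actions)
    verb_hits = 0
    noun_hits = 0
    for pred, true in zip(pred_actions, true_actions):
        pred_verb, pred_noun = action_to_verb_noun[pred]
        true_verb, true_noun = action_to_verb_noun[true]
        verb_hits += pred_verb == true_verb
        noun_hits += pred_noun == true_noun
    return verb_hits / n, noun_hits / n


# Hypothetical action vocabulary: action ID -> (verb, noun).
mapping = {0: ("take", "knife"), 1: ("take", "plate"), 2: ("cut", "onion")}

# Two of three verbs match, one of three nouns matches.
verb_acc, noun_acc = verb_noun_accuracy([0, 2, 1], [1, 2, 2], mapping)
print(verb_acc, noun_acc)
```

A wrong action can still yield a correct verb or noun, so verb and noun accuracy are never lower than action accuracy under this mapping.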
-
-
-
Config description
-
Set weight for different losses (teacher + student during student training) https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/conf/config.yaml#L61
-
-
Running Experiments
- EK55 - https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/avt_ek55_ensemble_test.job
- EGTEA - https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/avt_base_feat_egtea_test.job
-
Reading teacher prediction and setting the path in student distillation training
-
File where the path needs to be set:
AVT-main/datasets/base_video_dataset.py
-
Read the predictions (feature and logits) from teacher Language Model https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/datasets/base_video_dataset.py#L271
-
Set path of LM (teacher prediction) for student training https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/datasets/base_video_dataset.py#L454
-
-
Dataloader - updates explained
- File where the path needs to be set: AVT-main/datasets/base_video_dataset.py
- Helper function - _get_past_segment_pred_by_uid https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/datasets/base_video_dataset.py#L787 Usage: fetch the LM feature and logit predictions for a given UID
- Call in dataloader - __getitem__ https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/datasets/base_video_dataset.py#L859
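The UID-based lookup described above can be pictured with the minimal sketch below. It is not the repository's `_get_past_segment_pred_by_uid`; the function names, the handling of missing UIDs, and the sample records are assumptions made for illustration.

```python
def build_uid_index(teacher_preds):
    """Index the list of teacher prediction dicts by their UID."""
    return {rec["UID"]: rec for rec in teacher_preds}


def fetch_preds_for_uids(uid_index, uids, key):
    """Return the teacher's `key` entry for each UID, or None when it is missing."""
    return [uid_index[uid][key] if uid in uid_index else None for uid in uids]


# Hypothetical teacher predictions (values are invented).
teacher_preds = [
    {"UID": 7, "action_logit": [0.2, 1.1], "LM_feature": [0.5, -0.3, 0.9]},
    {"UID": 8, "action_logit": [1.4, 0.1], "LM_feature": [0.1, 0.7, -0.2]},
]

uid_index = build_uid_index(teacher_preds)
logits = fetch_preds_for_uids(uid_index, [8, 99], "action_logit")
print(logits)  # [[1.4, 0.1], None]
```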
-
Knowledge distillation
-
KL divergence (distillation loss)
-
Config weight key:
distill
-
Loss function defined here https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/loss_fn/kld_lm.py
-
Path:
AVT-main/loss_fn/kld_lm.py
-
Usage: This function takes the student’s logits and the teacher’s logits (LM predictions) and computes the KL divergence loss between them. The softmax temperature is specified here.
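For readers unfamiliar with this loss, here is a minimal pure-Python sketch of a temperature-softened KL divergence between teacher and student logits. It illustrates the general technique only and is not the code in kld_lm.py; the direction KL(teacher || student) and the absence of any extra scaling factor are assumptions.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) over temperature-softened class distributions."""
    log_p = log_softmax(teacher_logits, temperature)  # teacher
    log_q = log_softmax(student_logits, temperature)  # student
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


student = [0.0, 0.0, 0.0]
teacher = [2.0, 0.0, -1.0]
print(kl_distillation_loss(student, teacher, temperature=1.0))
print(kl_distillation_loss(student, teacher, temperature=10.0))
```

A higher temperature flattens both distributions, which is why the second value printed is smaller than the first for these inputs.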
-
Variation 1: For EK55, since the number of classes is very large, we compute the distillation loss over only the top K classes. https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/loss_fn/kld_lm.py#L32
-
Variation 2 (not used): We compute the distillation loss over the top K classes and also the bottom K and middle K classes. https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/loss_fn/kld_lm.py#L40
-
Add the loss in training module
- Path: AVT-main/func/train_eval_ops.py https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/func/train_eval_ops.py#L109
-
Call and add the distillation loss to other (original) losses
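Combining the distillation loss with the original losses typically amounts to a weighted sum, sketched below. Only the weight key `distill` comes from this document; the key `cls_action`, the weight values, and the loss values are hypothetical, and the actual combination in train_eval_ops.py may differ.

```python
def combine_losses(losses, weights):
    """Weighted sum of named losses; a loss without a weight entry contributes 0."""
    return sum(weights.get(name, 0.0) * value for name, value in losses.items())


# Hypothetical per-batch loss values and config weights.
losses = {"cls_action": 2.0, "distill": 0.5}
weights = {"cls_action": 1.0, "distill": 20.0}

total = combine_losses(losses, weights)
print(total)  # 12.0
```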
-
-
Feature (alignment) distillation (not used)
-
Config weight key: distill_feat/distill_feat_mse
-
Loss function defined here - https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/loss_fn/sim.py
-
Path:
AVT-main/loss_fn/sim.py
-
Usage: This function takes in the student’s features and teacher (LM pred features) and computes cosine similarity loss between them.
-
Variation 1: MSE instead of cosine similarity (uses a predefined function) https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/loss_fn/mse.py
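The two feature-alignment options can be illustrated with the pure-Python sketch below. It assumes the cosine loss is defined as 1 minus the cosine similarity and that student and teacher features already have the same dimension; neither detail is confirmed by sim.py or mse.py, and the function names are hypothetical.

```python
import math


def cosine_similarity_loss(student_feat, teacher_feat, eps=1e-8):
    """1 - cos(student, teacher); 0 when the vectors point in the same direction."""
    dot = sum(s * t for s, t in zip(student_feat, teacher_feat))
    norm_s = math.sqrt(sum(s * s for s in student_feat))
    norm_t = math.sqrt(sum(t * t for t in teacher_feat))
    return 1.0 - dot / max(norm_s * norm_t, eps)


def mse_loss(student_feat, teacher_feat):
    """Mean squared error between two equally sized feature vectors."""
    return sum((s - t) ** 2 for s, t in zip(student_feat, teacher_feat)) / len(student_feat)


print(cosine_similarity_loss([1.0, 0.0], [0.0, 1.0]))  # 1.0
print(mse_loss([1.0, 2.0], [1.0, 4.0]))  # 2.0
```

The cosine variant ignores feature magnitude and only penalizes direction, whereas MSE penalizes both.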
-
Adding the loss in training module
-
Path: AVT-main/func/train_eval_ops.py https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/func/train_eval_ops.py#L110
-
Call and add the distillation loss to other (original) losses
- Path: AVT-main/func/train_eval_ops.py https://github.com/sayontang/Action_Anticipation/blob/main/student/AVT-main/func/train_eval_ops.py#L193
-
- Reproducing student metrics reported in the paper: Download the model prediction folder teacher/teacher_student_Predictions/ from link. Calculate the student performance metric reported in the paper by running the notebook code/logit_analysis.ipynb