diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b6e4761 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..0d31b1f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bab9c2d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing to LaViLa +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. 
+ +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 4 spaces for indentation rather than tabs +* 80 character line length +* PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) + +## License +By contributing to LaViLa, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0eb714d --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ + +MIT License + +Copyright (c) Meta Platforms, Inc. and affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..17ae05e --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# Learning Video Representations from Large Language Models + + +[**Learning Video Representations from Large Language Models**](http://arxiv.org/abs/2212.04501) +Yue Zhao, Ishan Misra, Philipp Krähenbühl, Rohit Girdhar +[arxiv](http://arxiv.org/abs/2212.04501) | [bibtex](#citing-lavila) | [colab](#narrator-demo) + +LaViLa (**L**anguage **a**ugmented **Vi**deo **La**nguage Pretraining) is a new approach to learning video representations from Large Language Models (LLMs). We repurpose LLMs to be visually conditioned "Narrators", and use them to automatically generate video-language paired data. We use this data to then learn a video-langauge representation, outperforming prior work by large margins. + +**Sample Generations:** + +| Video | Generation 1 | Generation 2 | +| --------|-------------|--------------| +| | so now we're going to slice the bread | now i'm going to do is just slice
this up into a nice chunk and
then we're going to place it
on the plate | + +[Try out](#narrator-demo) our Narrator to generate text descriptions for your own videos! + +The resulting video-language model sets a new **state-of-the-art** on a number of popular video tasks! +image + + + + +## Introduction and installation + +LaViLa leverages Large Language Models (LLMs) as "NARRATOR"s (and "REPHRASER"s) to densely narrate long videos, and uses these narrations to train strong dual-encoder models. + + + + +See [INSTALL.md](docs/INSTALL.md) to install this code. + +## NARRATOR + +NARRATOR is a *visually conditioned* LLM that takes video frames as input and pseudo-labels the clip with narrations. + + + + +### NARRATOR Demo + +We provide some samples generated by our NARRATOR: + +| | | | | +| :----------------: | :----------------------------------------: | :-------------------------------------: | :--------------------------------------: | +| Human
narration | C separates the yarn. | C lifts container. | C operates the camera. | +| NARRATOR generation (a) | C stretches the thread with both hands. | C wipes the countertop with a sponge. | C takes a photo shot. | +| NARRATOR generation (b) | C pulls out the yarn with her right hand. | C moves the container. | A man X looks at the camera. | + + +Run the narrator demo using Colab (no GPU needed): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gHWiEWywIotRivYQTR-8NQ6GJC7sJUe4) + +Since a free Colab account offers very limited RAM, if you'd like to run the demo with a larger model, please run [./demo_narrator.py](./demo_narrator.py) locally. For more technical details, please refer to Sec. 4.1 in our paper. + +```bash +# CPU mode +python demo_narrator.py [--video-path $TEST_VIDEO] + +# GPU mode +python demo_narrator.py --cuda +``` + + +Our narrator also works on third-person videos! Below are several examples generated by our NARRATOR, which is pre-trained on HowTo100M Auto-Aligned ([HTM-AA](https://www.robots.ox.ac.uk/~vgg/research/tan/index.html#htm-align)) and applied to some stock-footage video clips. Note that since the text corpus in HowTo100M consists of ASR transcriptions, the style of narration is slightly different from that of the ground-truth captions. However, the generated results are generally reasonable. + +| | | | | +| :--------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :---: | +| GT caption | Pastry chef cutting bread into
slices during the preparation
of a dessert, inside a kitchen. | Close-up shot of the hands
of an experienced baker
skillfully kneading bread dough. | Chef preparing a sauce in
a blender, adding different
ingredients while blending. | +| NARRATOR (a) | so now we're going to slice the bread | i'm gonna make a little hole
in the middle of the dough here | all right let's blend this up | +| NARRATOR (b) | now i'm going to do is just slice
this up into a nice chunk and
then we're going to place it
on the plate | you just keep kneading it | the last step to making this
is to blend the ingredients
in the food processor | + +Below is a demo for 3rd-person videos. +```bash +python demo_narrator_3rd_person.py [--video-path $TEST_VIDEO] [--cuda] +``` + +## Dual-Encoder + +The dual-encoder model contains a video encoder and a text encoder. It learns a video-language representation from both human annotations and generated narrations using a contrastive loss like [CLIP](https://github.com/openai/CLIP). + + +* LaViLa's dual-encoder achieves excellent **zero-shot** performance on a wide range of egocentric benchmarks, outperforming previous state-of-the-art video-language pretraining methods by a large margin. + +
+ + | | Backbone | EK-100 MIR
avg. mAP^ | EK-100 MIR
avg. nDCG^ | Charades-Ego
mAP | EGTEA
mean acc. | EgoMCQ
intra-video acc. | + | :----------: | :------: | :---------------------: | :----------------------: | :------------------: | :-----------------: | :------------------------: | + | Prev. SOTA^^ | TSF-B | 22.1/23.3 | 22.1/27.9 | 25.2 | 17.6 | 57.2 | + | LAVILA | TSF-B | 29.7/30.9 | 31.5/32.0 | 26.8 | 28.9 | 59.9 | + | LAVILA | TSF-L | 35.0/36.1 | 34.2/34.6 | 28.9 | 34.1 | 63.1 | + +
+ + ^ The two numbers are obtained by using different numbers of frames as input (4-frame and 16-frame). + + ^^ We use the checkpoints released by [EgoVLP](https://github.com/showlab/EgoVLP) and convert them to be compatible with this codebase. Also note that our reproduced numbers are better than the reported numbers, especially on EK-100 MIR, since we evaluate on raw videos directly (for more details, check out Appendix F & Table 10 in our paper). + + For details on how to get the numbers, please refer to [MODEL_ZOO.md](./docs/MODEL_ZOO.md#zero-shot). + + +* Once **fine-tuned** on a down-stream dataset, LaViLa's dual-encoder can also achieve state-of-the-art results on it. We show some key results as follows. + +
+ + | | EK-100 MIR
avg. mAP | EK-100 MIR
avg. nDCG | EK-100 CLS
Action top-1 | Charades-Ego
mAP | EGTEA
mean acc. | + | :--------: | :--------------------: | :---------------------: | :------------------------: | :------------------: | :-----------------: | + | Prev. SOTA | 45.0 | 59.4 | 50.5 | 32.1 | 65.9 | + | LAVILA | 50.9 | 66.5 | 50.9 | 36.1 | 76.0 | + +
+ + For details on how to fine-tune the pre-trained dual-encoder on down-stream datasets, please refer to [MODEL_ZOO.md](./docs/MODEL_ZOO.md#fine-tuned). + +## License +The majority of LAVILA is licensed under a [MIT License](./LICENSE), however portions of the project are available under separate license terms: + +* https://github.com/EGO4D/episodic-memory is licensed under the MIT license. + +* The videos of [cutting a loaf](https://mixkit.co/free-stock-video/pastry-chef-cutting-a-loaf-into-slices-43015/), [kneading a dough](https://mixkit.co/free-stock-video/hands-of-a-baker-kneading-a-dough-42467/), and [preparing a sauce in a blender](https://mixkit.co/free-stock-video/chef-preparing-a-sauce-in-a-blender-43034/) are licensed under the [Mixkit Stock Video Free License](https://mixkit.co/license/#videoFree). + +## Citing LaViLa + +```bibtex +@inproceedings{zhao2022lavila, + title={Learning Video Representations from Large Language Models}, + author={Zhao, Yue and Misra, Ishan and Kr{\"a}henb{\"u}hl, Philipp and Girdhar, Rohit}, + booktitle={arXiv preprint arXiv:2212.04501}, + year={2022} +} +``` diff --git a/assets/06919917-76bc-4adc-b944-2a722f165513.gif b/assets/06919917-76bc-4adc-b944-2a722f165513.gif new file mode 100644 index 0000000..66064be Binary files /dev/null and b/assets/06919917-76bc-4adc-b944-2a722f165513.gif differ diff --git a/assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 b/assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 new file mode 100644 index 0000000..26ca336 Binary files /dev/null and b/assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 differ diff --git a/assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif b/assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif new file mode 100644 index 0000000..f30c1de Binary files /dev/null and b/assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif differ diff --git a/assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif b/assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif new file mode 100644 index 0000000..2132035 Binary files /dev/null and b/assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif differ diff --git a/assets/lavila_ego4d.gif b/assets/lavila_ego4d.gif new file mode 100644 index 0000000..37100dc Binary files /dev/null and b/assets/lavila_ego4d.gif differ diff --git a/assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif b/assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif new file mode 100644 index 0000000..8eb440b Binary files /dev/null and b/assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif differ diff --git a/assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif b/assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif new file mode 100644 index 0000000..404dcb0 Binary files /dev/null and b/assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif differ diff --git a/assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif b/assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif new file mode 100644 index 0000000..8b1e9cb Binary files /dev/null and b/assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif differ diff --git a/assets/narrator.gif b/assets/narrator.gif new file mode 100644 index 0000000..1e63a84 Binary files /dev/null and b/assets/narrator.gif differ diff --git a/assets/rephraser.gif b/assets/rephraser.gif new file mode 100644 index 0000000..a39c23e Binary files /dev/null and b/assets/rephraser.gif differ diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 
0000000..966bb41 --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,153 @@ +# Preparing datasets for LAVILA + +Please download the (selected) datasets from the official websites and place or sim-link them under `$LAVILA_ROOT/datasets/`. + +```bash +$LAVILA_ROOT/datasets/ + CharadesEgo/ + EGTEA/ + EK100/ + Ego4D/ +``` + +## Ego4D +1. Download [Ego4D videos](https://ego4d-data.org/docs/start-here/#download-data) (license is required). + +2. Preprocess(TBA) + +3. Download annotations + + a. Download [egomcq.json](https://drive.google.com/file/d/1-5iRYf4BCHmj4MYQYFRMY4bhsWJUN3rW/view) to `$LAVILA_ROOT/datasets/Ego4D` (if you want to evaluate EgoMCQ). + + b. Download [metadata for train split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_val.pkl) to `$LAVILA_ROOT/datasets/Ego4D` ((if you want to train LAVILA from scratch). + +The fold should look like this: +```bash +$LAVILA_ROOT/datasets/ + Ego4D/ + ego4d_train.pkl + ego4d_val.pkl + egomcq.json + video_288px/ + 000786a7-3f9d-4fe6-bfb3-045b368f7d44.mp4/ + 0.mp4 + 300.mp4 + 000a3525-6c98-4650-aaab-be7d2c7b9402.mp4/ + 0.mp4 + ... +``` + + +## EPIC-Kitchens-100 (EK-100) + +1. Download annotations + +```bash +# Assume that you are under `datasets/EK100/` +git clone https://github.com/epic-kitchens/epic-kitchens-100-annotations +``` + +2. Download videos. + + a. For raw videos, please download them from [https://epic-kitchens.github.io/](https://epic-kitchens.github.io/). + + b. (Recommended) The raw videos are huge (~1 TB). As an alternative, please check out a [resized version](). + +3. (For EK-100 MIR) + + a. Generate the relevancy matrix of train/val splits using [the official code](https://github.com/mwray/Joint-Part-of-Speech-Embeddings). + + b. (Recommended) The generated result has some randomness. Therefore, we also provide the [replica of train split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_test.pkl). Please put them to the folder `$LAVILA_ROOT/datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/`. + + +The folder should look like this: +```bash +$LAVILA_ROOT/datasets/ + EK100/ + epic-kitchens-100-annotations/ + EPIC_100_train.csv + EPIC_100_validation.csv + ... + retrieval_annotations/relevancy/ # this appears if you do 3. + caption_relevancy_EPIC_100_retrieval_train.pkl + caption_relevancy_EPIC_100_retrieval_test.pkl + video_ht256px/ + P01/ + P01_01.MP4 + P01_02.MP4 + ... + P01_19.MP4 + P02/ + P02_01.MP4 + P02_02.MP4 + ... + P02_15.MP4 + ... +``` + +## CharadesEgo + +1. Download annotations at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego). +```bash +### Annotations +# Assume that you are under `datasets/CharadesEgo/` +wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo.zip +unzip CharadesEgo.zip && rm CharadesEgo.zip +``` + +2. Download data (~11GB) at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego). +```bash +### Data +wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo_v1_480.tar +tar -xvf CharadesEgo_v1_480.tar # Or specify an external path using `-C` and sim-link it to here +rm CharadesEgo_v1_480.tar +``` + +3. 
(For fine-tuning CharadesEgo) Download two additional metadata files: [clip-level metadata (train)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_train.pkl) and [clip-level metadata (val)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_val.pkl). Put them to the folder `$LAVILA_ROOT/datasets/CharadesEgo/CharadesEgo/`. + +The folder should look like this: +```bash +$LAVILA_ROOT/datasets/ + CharadesEgo/ + CharadesEgo/ + CharadesEgo_v1_train_only1st.csv + CharadesEgo_v1_test_only1st.csv + ... + metadata_filtered_train.pkl # this appears if you do 3. + metadata_filtered_val.pkl # this appears if you do 3. + CharadesEgo_v1_480/ + 005BU.mp4 + 005BUEGO.mp4 + ... +``` + + +## EGTEA + +1. Visit [https://cbs.ic.gatech.edu/fpv/](https://cbs.ic.gatech.edu/fpv/). + +2. Download `TRIMMED_ACTION_CLIPS` (~20GB) and `ACTION_ANNOTATIONS` and untar to the current folder `$LAVILA_ROOT/datasets/EGTEA`. + +```bash +unzip action_annotation.zip -d EGTEA/ && rm action_annotation.zip +``` + +The folder should look like this: +```bash +$LAVILA_ROOT/datasets/ + EGTEA/ + train_split1.txt + test_split1.txt + cropped_clips/ + OP01-R01-PastaSalad/ + OP01-R01-PastaSalad-1002316-1004005-F024051-F024101.mp4 + OP01-R01-PastaSalad-1004110-1021110-F024057-F024548.mp4 + OP01-R01-PastaSalad-1022590-1024050-F024539-F024581.mp4 + ... + OP01-R02-TurkeySandwich/ + OP01-R02-TurkeySandwich-102320-105110-F002449-F002529.mp4 + OP01-R02-TurkeySandwich-105440-106460-F002528-F002558.mp4 + OP01-R02-TurkeySandwich-107332-133184-F002513-F003259.mp4 + ... + ... +``` diff --git a/demo_narrator.py b/demo_narrator.py new file mode 100644 index 0000000..17deabd --- /dev/null +++ b/demo_narrator.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import argparse +import os +import urllib.request +from collections import OrderedDict + +import decord +import torch +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video + +from lavila.data.video_transforms import Permute +from lavila.data.datasets import get_frame_ids, video_loader_by_frames +from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL +from lavila.models.tokenizer import MyGPT2Tokenizer +from eval_narrator import decode_one + + +def main(args): + + vr = decord.VideoReader(args.video_path) + num_seg = 4 + frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False) + frames = video_loader_by_frames('./', args.video_path, frame_ids) + + ckpt_name = 'vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth' + ckpt_path = os.path.join('modelzoo/', ckpt_name) + os.makedirs('modelzoo/', exist_ok=True) + if not os.path.exists(ckpt_path): + print('downloading model to {}'.format(ckpt_path)) + urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/{}'.format(ckpt_name), ckpt_path) + ckpt = torch.load(ckpt_path, map_location='cpu') + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + # instantiate the model, and load the pre-trained weights + model = VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL( + text_use_cls_token=False, + project_embed_dim=256, + gated_xattn=True, + timesformer_gated_xattn=False, + freeze_lm_vclm=False, # we use model.eval() anyway + freeze_visual_vclm=False, # we use model.eval() anyway + num_frames=4, + drop_path_rate=0. + ) + model.load_state_dict(state_dict, strict=True) + if args.cuda: + model.cuda() + model.eval() + + # transforms on input frames + crop_size = 336 + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]) + ]) + frames = val_transform(frames) + frames = frames.unsqueeze(0) # fake a batch dimension + + tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True) + with torch.no_grad(): + if args.cuda: + frames = frames.cuda(non_blocking=True) + image_features = model.encode_image(frames) + generated_text_ids, ppls = model.generate( + image_features, + tokenizer, + target=None, # free-form generation + max_text_length=77, + top_k=None, + top_p=0.95, # nucleus sampling + num_return_sequences=10, # number of candidates: 10 + temperature=0.7, + early_stopping=True, + ) + + for i in range(10): + generated_text_str = decode_one(generated_text_ids[i], tokenizer) + print('{}: {}'.format(i, generated_text_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila narrator demo') + parser.add_argument('--cuda', action='store_true', help='use cuda') + parser.add_argument('--video-path', default='assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', type=str, help='video path') + args = parser.parse_args() + main(args) diff --git a/demo_narrator_3rd_person.py b/demo_narrator_3rd_person.py new file mode 100644 index 0000000..0d25e09 --- /dev/null +++ b/demo_narrator_3rd_person.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import argparse +import os +import urllib.request +from collections import OrderedDict + +import decord +import torch +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video + +from lavila.data.video_transforms import Permute +from lavila.data.datasets import get_frame_ids, video_loader_by_frames +from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL +from lavila.models.tokenizer import MyGPT2Tokenizer +from eval_narrator import decode_one + + +def main(args): + + vr = decord.VideoReader(args.video_path) + num_seg = 4 + frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False) + frames = video_loader_by_frames('./', args.video_path, frame_ids) + + ckpt_name = 'vclm_openai_timesformer_large_gpt2_xl.pt_htm.jobid_341080.ep_0001.pth' + ckpt_path = os.path.join('modelzoo/', ckpt_name) + os.makedirs('modelzoo/', exist_ok=True) + if not os.path.exists(ckpt_path): + print('downloading model to {}'.format(ckpt_path)) + urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/htm_aa/{}'.format(ckpt_name), ckpt_path) + ckpt = torch.load(ckpt_path, map_location='cpu') + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + # instantiate the model, and load the pre-trained weights + model = VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL( + text_use_cls_token=False, + project_embed_dim=256, + gated_xattn=True, + timesformer_gated_xattn=False, + freeze_lm_vclm=False, # we use model.eval() anyway + freeze_visual_vclm=False, # we use model.eval() anyway + freeze_visual_vclm_temporal=False, + num_frames=4, + drop_path_rate=0. + ) + model.load_state_dict(state_dict, strict=True) + if args.cuda: + model.cuda() + model.eval() + + # transforms on input frames + crop_size = 224 + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]) + ]) + frames = val_transform(frames) + frames = frames.unsqueeze(0) # fake a batch dimension + + tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True) + with torch.no_grad(): + if args.cuda: + frames = frames.cuda(non_blocking=True) + image_features = model.encode_image(frames) + generated_text_ids, ppls = model.generate( + image_features, + tokenizer, + target=None, # free-form generation + max_text_length=77, + top_k=None, + top_p=0.95, # nucleus sampling + num_return_sequences=10, # number of candidates: 10 + temperature=0.7, + early_stopping=True, + ) + + for i in range(10): + generated_text_str = decode_one(generated_text_ids[i], tokenizer) + print('{}: {}'.format(i, generated_text_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila narrator demo') + parser.add_argument('--cuda', action='store_true', help='use cuda') + parser.add_argument('--video-path', type=str, + default='assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.mp4') + args = parser.parse_args() + main(args) diff --git a/docs/INSTALL.md b/docs/INSTALL.md new file mode 100644 index 0000000..980079f --- /dev/null +++ b/docs/INSTALL.md @@ -0,0 +1,15 @@ +# Installation + +## Requirements + + +## Example conda environment setup + +```bash +conda create --name lavila python=3.8 -y +conda activate lavila +pip install -r requirements.txt +``` + +## datasets +If you want to train/evaluate on the 
datasets, please see [datasets/README.md](../datasets/README.md) to see how we prepare datasets for this project. diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md new file mode 100644 index 0000000..1db4cc6 --- /dev/null +++ b/docs/MODEL_ZOO.md @@ -0,0 +1,311 @@ +# LAVILA Model Zoo + +## Multi-node Training +We use multi-node training on a SLURM cluster with [submitit](https://github.com/facebookincubator/submitit) for producing the results and models in the paper. +Please install `submitit` in your conda environment: +```bash +pip install submitit +``` + + +## Pre-training + +Please refer to [PRETRAIN.md](./PRETRAIN.md). + + +## Narrator + +| Visual Encoder | Text Decoder | METEOR | ROUGE-L | CIDEr | Pre-trained
Vis. Encoder (md5) | checkpoint (md5) | +| :------------: | :----------: | :----: | :-----: | :---: | :-------------------------------: | :--------: | +| TSF-B | GPT-2 | 0.282 | 0.517 | 0.833 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.baseline.ep_0003.pth) (dbcc4d) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth) (68a71f) | +| TSF-L@HR | GPT-2 XL | 0.298 | 0.539 | 0.977 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth) (5c69b8) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth) (443263) | + + +
Ego4D val split +

+ +```bash +torchrun --nproc_per_node=1 \ + eval_narrator.py \ + --caption-top-p 0.95 --caption-temperature 0.7 \ + --eval-freq 10000 \ + --resume $CHECKPOINT +``` + +
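+The `$CHECKPOINT` here is one of the downloads from the table above; each entry lists a short md5 fragment next to its link. Below is a minimal, optional sketch (ours, not part of the released tooling) for computing a downloaded file's md5 so you can compare it against the listed fragment, assuming the fragment is simply a prefix of the full digest:
+
+```python
+import hashlib
+
+def md5sum(path, chunk_size=1 << 20):
+    """Hex md5 digest of a file, computed in chunks to keep memory use low."""
+    h = hashlib.md5()
+    with open(path, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            h.update(chunk)
+    return h.hexdigest()
+
+# Compare the printed digest against the fragment in the table, e.g. "443263"
+# for the TSF-L@HR narrator checkpoint used by demo_narrator.py.
+print(md5sum('modelzoo/vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth'))
+```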

+ +## Zero-shot + +
+ +| | Backbone | EK-100 MIR
avg. mAP | EK-100 MIR
avg. nDCG | Charades-Ego
mAP^ | EGTEA
mean acc. | EgoMCQ
intra-video acc. | checkpoint | +| :----------: | :------: | :--------------------: | :---------------------: | :------------------: | :-----------------: | :------------------------: | :----------: | +| Prev. SOTA^^ | TSF-B | 22.1/23.3 | 22.1/27.9 | 25.2 | 17.6 | 57.2 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_epo1_converted_f16.md5sum_7a3d3b.pth), [best epoch](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_converted_f16.md5sum_c33363.pth) | +| LAVILA | TSF-B | 29.7/30.9 | 31.5/32.0 | 26.8 | 28.9 | 59.9 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth)^, [Epoch 5](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) | +| LAVILA | TSF-L | 35.0/36.1 | 34.2/34.6 | 28.9 | 34.1 | 63.1 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth)^, [Epoch 3](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) | + +
+ +^ Note that the pre-trained checkpoint to evaluate CharadesEgo is different from that to evaluate other datasets. +Specifically, we use the checkpoint at epoch 1 to zero-shot evaluate CharadesEgo and the checkpoint that achieves the best average mAP on EK-100 MIR to evaluate other datasets, as is done in [EgoVLP](https://arxiv.org/pdf/2206.01670.pdf). +Our guess is that since CharadesEgo videos (captured by head-mounted mobile cameras) are visually different from Ego4D/EPIC-Kitchens videos (captured by professional action cameras, e.g., GoPro), pre-training on Ego4D videos for longer may introduce some domain discrepancy. + +^^ We use the checkpoints released by [EgoVLP](https://github.com/showlab/EgoVLP) and convert them to be compatible with this codebase. Also note that our reproduced numbers are better than the reported numbers, especially on EK-100 MIR, since we evaluate on raw videos directly (for more details, check out Appendix F & Table 10 in our paper). + +
1. EK-100 MIR +

+ +```bash +python eval_zeroshot.py --dataset ek100_mir --root datasets/EK100/video_ht256px/ --clip-length 4 --resume $PATH +``` +By increasing the number of frames per clip, e.g. `--clip-length 16`, you should see better performance. + +
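+Conceptually, this zero-shot evaluation embeds every clip and every caption with the frozen dual-encoder and ranks captions by cosine similarity; mAP and nDCG are then computed against the relevancy matrix. The sketch below only illustrates that scoring step (it assumes `encode_image`/`encode_text` methods like the ones used in the demo scripts and is not a drop-in replacement for `eval_zeroshot.py`):
+
+```python
+import torch
+import torch.nn.functional as F
+
+@torch.no_grad()
+def similarity_matrix(model, frames, tokens):
+    """frames: (N, C, T, H, W) normalized clips; tokens: (M, L) tokenized captions.
+    Returns an (N, M) cosine-similarity matrix that the mAP/nDCG code can consume."""
+    model.eval()
+    v = F.normalize(model.encode_image(frames), dim=-1)   # (N, D) video embeddings
+    t = F.normalize(model.encode_text(tokens), dim=-1)    # (M, D) text embeddings
+    return v @ t.t()
+```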

+ +
2. EK-100 CLS +

+ +```bash +python eval_zeroshot.py --dataset ek100_cls --metadata-val datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv --resume $PATH +``` + +

+ +
3. Charades-Ego +

+ +```bash +python eval_zeroshot.py --dataset charades_ego --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv --root datasets/CharadesEgo/CharadesEgo_v1_480/ --clip-length 16 --sparse-sample --resume $PATH +``` + +

+ +
4. EGTEA +

+ +```bash +python eval_zeroshot.py --dataset egtea --metadata-val datasets/EGTEA/test_split1.txt --root datasets/EGTEA/cropped_clips/ --clip-length 16 --clip-stride 2 --num-crops 3 --num-clips 10 --resume $PATH +``` + +

+ +
5. EgoMCQ +

+ +```bash +python eval_zeroshot.py --dataset ego4d_mcq --metadata-val datasets/Ego4D/egomcq.json --root datasets/Ego4D/video_5min_chunks_288px/ --clip-length 4 --resume $PATH --use-half -j 4 +``` + +

+ +## Fine-tuned + +### EK-100 MIR + +
+ +| | Backbone | avg mAP | avg nDCG | Pretrain (md5) | Fine-tuned checkpoint | training log | +| :----: | :-------:| :-----: | :------: | :----------: | :-------------------: | :----------: | +| LAVILA | TSF-B | 50.5 | 65.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.ep_0085.md5sum_c67d95.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.jobid_57361.log) | +| LAVILA | TSF-L | 50.9 | 66.5 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.ep_0095.md5sum_bd508b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.jobid_56606.log) | + +
+ + +
Training and evaluating scripts +

+ +### Multi-node training (Slurm) +```bash +# TimeSformer-Base +python run_with_submitit_finetune_retrieval.py \ + --pretrain-model $PATH \ + --use-checkpoint --nodes 4 + +# TimeSformer-Large +python run_with_submitit_finetune_retrieval.py \ + --pretrain-model $PATH \ + --batch-size 4 \ + --use-checkpoint --nodes 4 +``` + +### Single-machine training +```bash +torchrun --nproc_per_node=8 \ + main_finetune_retrieval.py \ + --output-dir $OUT_DIR \ + --pretrain-model $PATH \ + --use-checkpoint +``` + +Note that you might see a slight drop of performance when training on a single node compared to multiple nodes (everything else being the same) because of a smaller total batch size. + +### Evaluation + +Evaluation is done every `--eval-freq 5` epochs by default during fine-tuning. +If you want to evaluate any checkpoint after fine-tuning, please switch to `--evaluate` mode and specify the path to the checkpoint by `--resume $FINETUNED_CHECKPOINT`. +```bash +torchrun --nproc_per_node=1 \ + main_finetune_retrieval.py \ + --output-dir $OUT_DIR \ + --pretrain-model $PATH \ + --use-checkpoint \ + --evaluate \ + --resume $FINETUNED_CHECKPOINT +``` + + +

+ +### CharadesEgo + +
+ +| | Backbone | video mAP |Pretrain^ (md5) | Fine-tuned checkpoint | training log | +| :----: | :-------:| :------: | :-------: | :-------------------: | :----------: | +| LAVILA | TSF-B | 33.7 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth) (02dbb9) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.ep_0005.md5sum_39bf4b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.jobid_65760.log) | +| LAVILA | TSF-L | 36.1 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth) (9a25de) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.ep_0003.md5sum_9448b2.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.jobid_65760.log) | + +
+ +^ Note that the pre-trained checkpoint for fine-tuning CharadesEgo is different from that for fine-tuning EK-100 or EGTEA, for the same reason stated above. + +
Training and evaluating scripts +

+ +### Multi-node training (Slurm) + +```bash +# TimeSformer-Base +python run_with_submitit_finetune_retrieval.py \ + --dataset charades_ego \ + --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \ + --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \ + --root datasets/CharadesEgo/CharadesEgo_v1_480/ \ + --epochs 10 \ + --save-freq 1 --eval-freq 1 \ + --sparse-sample \ + --pretrain-model $PATH \ + --use-checkpoint --nodes 4 + +# TimeSformer-Large +python run_with_submitit_finetune_retrieval.py \ + --dataset charades_ego \ + --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \ + --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \ + --root datasets/CharadesEgo/CharadesEgo_v1_480/ \ + --epochs 10 \ + --save-freq 1 --eval-freq 1 \ + --sparse-sample \ + --pretrain-model $PATH \ + --batch-size 4 \ + --use-checkpoint --nodes 4 +``` + +### Evaluation +```bash +torchrun --nproc_per_node=1 \ + main_finetune_retrieval.py \ + --dataset charades_ego \ + --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \ + --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \ + --root datasets/CharadesEgo/CharadesEgo_v1_480/ \ + --output-dir $OUT_DIR \ + --sparse-sample \ + --pretrain-model $PATH \ + --evaluate \ + --resume $FINETUNED_CHECKPOINT +``` + +

+ +### EK-100 CLS + +
+ +| | Backbone | V+N+A multi-head | Verb top-1 | Noun top-1 | Action top-1 | Pretrain (md5) | Fine-tuned checkpoint | training log | +| :----: | :-------:| :--------------: | :--------: | :--------: | :---------: | :------------: | :-------------------: | :----------: | +| LAVILA | TSF-B | no | 67.7 | 56.7 | 46.2 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.ep_0100.md5sum_e8aa0c.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.jobid_73363.log) | +| LAVILA | TSF-B | yes | 69.0 | 58.4 | 46.9 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.ep_0100.md5sum_4e3575.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.jobid_73361.log) | +| LAVILA | TSF-L | yes | 72.0 | 62.9 | 51.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.ep_0090.md5sum_4a2509.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.jobid_74016.log) | +
+ +
Training and evaluating scripts +

+ +### Multi-node training (Slurm) + +```bash +# TimeSformer-Base +python run_with_submitit_finetune_classification.py \ + --pretrain-model $PATH \ + --use-vn-classifier --num-classes 97 300 3806 \ + --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \ + --use-checkpoint --node 1 + +# TimeSformer-Large +python run_with_submitit_finetune_classification.py \ + --pretrain-model $PATH \ + --use-vn-classifier --num-classes 97 300 3806 \ + --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \ + --use-checkpoint --node 4 +``` + +
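+The `--use-vn-classifier --num-classes 97 300 3806` flags correspond to the "V+N+A multi-head" rows in the table above: a shared clip embedding feeds three separate linear heads (verb, noun, action). As a rough illustration of that head structure (the actual classifier is defined in the training code):
+
+```python
+import torch.nn as nn
+
+class MultiHeadVNAClassifier(nn.Module):
+    """Illustrative verb/noun/action heads on top of a shared clip embedding."""
+    def __init__(self, feat_dim, num_classes=(97, 300, 3806)):
+        super().__init__()
+        self.heads = nn.ModuleList([nn.Linear(feat_dim, n) for n in num_classes])
+
+    def forward(self, clip_feat):
+        # returns [verb_logits, noun_logits, action_logits]
+        return [head(clip_feat) for head in self.heads]
+```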

+ +### EGTEA + +
+ +| | Backbone | mean Acc. | Pretrain (md5) | Fine-tuned checkpoint | training log | +| :----: | :-------:| :-------: | :------: | :-------------------: | :----------: | +| LAVILA | TSF-B | 70.12 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.ep_0090.md5sum_3b1faf.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.jobid_73358.log) | +| LAVILA | TSF-L | 76.00 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.ep_0095.md5sum_a5ba17.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.jobid_74026.log) | + +
+ +
Training and evaluating scripts +

+ +```bash +# TimeSformer-Base +python run_with_submitit_finetune_classification.py \ + --dataset egtea \ + --metadata-train datasets/EGTEA/train_split1.txt \ + --metadata-val datasets/EGTEA/test_split1.txt \ + --root datasets/EGTEA/cropped_clips/ \ + --pretrain-model $PATH \ + --num-classes 106 \ + --use-sgd --wd 4e-5 \ + --use-checkpoint --node 1 + +# TimeSformer-Large +python run_with_submitit_finetune_classification.py \ + --dataset egtea \ + --metadata-train datasets/EGTEA/train_split1.txt \ + --metadata-val datasets/EGTEA/test_split1.txt \ + --root datasets/EGTEA/cropped_clips/ \ + --pretrain-model $PATH \ + --num-classes 106 \ + --use-sgd --wd 4e-5 \ + --batch-size 4 \ + --use-checkpoint --node 4 +``` +### Evaluation +```bash +torchrun --nproc_per_node=1 \ + main_finetune_classification.py \ + --dataset egtea \ + --metadata-train datasets/EGTEA/train_split1.txt \ + --metadata-val datasets/EGTEA/test_split1.txt \ + --root datasets/EGTEA/cropped_clips/ \ + --output-dir $OUT_DIR \ + --pretrain-model $PATH \ + --num-classes 106 \ + --use-sgd --wd 4e-5 \ + --evaluate \ + --resume $FINETUNED_CHECKPOINT \ + --num-crops 3 --num-clips 10 \ + --use-half +``` +

diff --git a/docs/PRETRAIN.md b/docs/PRETRAIN.md new file mode 100644 index 0000000..ebd5d9f --- /dev/null +++ b/docs/PRETRAIN.md @@ -0,0 +1,125 @@ +# LAVILA Pretraining + +In this doc, we provide a step-by-step guide (with commands) to train LaViLa. +Note that we recommend running the following jobs with four 8x V100 (32GB) nodes (or eight nodes for the larger backbone) using [submitit](https://github.com/facebookincubator/submitit). +See [MODEL_ZOO.md](./MODEL_ZOO.md#multi-node-training) for how to install submitit. + + +## Pre-training Dual-Encoder Baseline + +We first pre-train a dual-encoder baseline with human annotations on Ego4D clips. +The goal is (1) to establish a comparable baseline for LAVILA, and (2) to provide a video encoder for the narrator (see below). +We use a default batch size of 32 per GPU, so the total batch size for the InfoNCE loss is `32*8*4=1024`. + +
Train a baseline dual-encoder (with TSF-B) + +```bash +python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \ + --norm-embed --freeze-temperature \ + --fix-lr --contrastive-use-vissl \ + --nodes 4 --use_volta32 +``` +
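+For reference, the InfoNCE objective mentioned above is the standard CLIP-style symmetric cross-entropy over the video-text similarity matrix of a batch. A minimal sketch (illustrative only; the actual loss in this codebase sits behind the `--contrastive-use-vissl` flag and likely also gathers features across GPUs):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def clip_style_infonce(video_emb, text_emb, logit_scale):
+    """Symmetric InfoNCE over an (N, N) video-text similarity matrix."""
+    v = F.normalize(video_emb, dim=-1)
+    t = F.normalize(text_emb, dim=-1)
+    logits = logit_scale * v @ t.t()                 # diagonal entries are the matched pairs
+    labels = torch.arange(v.size(0), device=v.device)
+    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))
+```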
+ +To fit a High-Resolution TimeSformer-Large with a sufficient batch size, we use [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), a memory-efficient text encoder, instead of the original text encoder in the CLIP. Additionally we apply [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054). + +
Train a baseline dual-encoder (with TSF-L@HR) + +```bash +python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE \ + --batch-size 8 \ + --use-checkpoint --use-zero \ + --norm-embed --freeze-temperature \ + --fix-lr --contrastive-use-vissl \ + --nodes 8 --use_volta32 +``` +
+ +## Training and Evaluating Narrator + +The narrator is a *visually conditioned* large language model (VCLM), which comprises a pre-trained video encoder (obtained above), a text decoder (GPT-2 family), and a few gated cross-attention modules that attend to visual information while captioning. Both the video encoder and the text decoder are kept frozen, while the cross-attention modules are learnable. + +Note that we turn off PyTorch's automatic mixed precision (AMP) while training the narrator; we observe that training is unstable if AMP is on. + +Also note that `$PATH` can be found in the `Vis. Encoder` column of [MODEL_ZOO.md#Narrator](./MODEL_ZOO.md#narrator). If you are using your own checkpoint (e.g. one pre-trained in the previous step), please make sure that the following keys in the checkpoint have been dropped: `epoch`, `optimizer`, and `scaler`. + +
Train a baseline narrator (TSF-B as visual encoder and GPT-2 base as textual decoder) + +```bash +python run_with_submitit_pretrain.py \ + --model VCLM_OPENAI_TIMESFORMER_BASE_GPT2 \ + --gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal \ + --fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \ + --nodes 4 --use_volta32 --resume $PATH # Eg. $PATH can be "modelzoo/clip_openai_timesformer_base.baseline.ep_0003.pth" +``` + +
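+The `--gated-xattn` flag above enables the learnable cross-attention modules described earlier. As a mental model only (the real modules live under `lavila/models`, and the details may differ), a Flamingo-style gated cross-attention layer attends from text tokens to visual tokens and scales the result by a zero-initialized gate, so the frozen GPT-2 is unchanged at the start of training:
+
+```python
+import torch
+import torch.nn as nn
+
+class GatedCrossAttention(nn.Module):
+    """Illustrative tanh-gated cross-attention from text tokens to visual tokens."""
+    def __init__(self, dim, num_heads=8):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+        self.gate = nn.Parameter(torch.zeros(1))   # gate starts closed: layer is a no-op
+
+    def forward(self, text_tokens, visual_tokens):
+        attended, _ = self.attn(text_tokens, visual_tokens, visual_tokens)
+        return text_tokens + torch.tanh(self.gate) * attended
+```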
+ +
Train a strong narrator (TSF-L@HR as visual encoder and GPT-2 XL as textual decoder) + +```bash +python run_with_submitit_pretrain.py \ + --model VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL \ + --gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal --use-checkpoint \ + --fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \ + --nodes 4 --use_volta32 --resume $PATH # Eg. $PATH can be "modelzoo/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth" +``` +
+ +
Evaluate the narrator on Ego4D val split + +```bash +torchrun --nproc_per_node=1 eval_narrator.py \ + --caption-top-p 0.95 --caption-temperature 0.7 \ + --eval-freq 10000 \ # evaluate on the val split of Ego4D (1/10000-subset for fast evaluation) + --resume $VCLM_CHECKPOINT +``` +This will output some common NLG metrics, such as BLEU-x, METEOR, ROUGE_L, and CIDEr (using the human narrations as ground-truth). +
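+Under the hood, `eval_narrator.py` collects one human narration (reference) and one generated narration (hypothesis) per clip and scores them with the `nlgeval` package. If you just want to re-score a saved set of captions, a minimal sketch along the same lines:
+
+```python
+from nlgeval import NLGEval
+
+def score_captions(references, hypotheses):
+    """references/hypotheses: parallel lists of strings, one pair per clip."""
+    nlgeval = NLGEval()  # loads the BLEU/METEOR/ROUGE_L/CIDEr scorers
+    metrics = nlgeval.compute_metrics([references], hypotheses)
+    for name, value in metrics.items():
+        print('{:16s} = {:9.3f}'.format(name, value))
+    return metrics
+```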
+ +## Narrating video clips using LAVILA-Narrator + + +
Infer the narrator + +```bash +python run_with_submitit_infer_narrator.py \ + --metadata datasets/Ego4D/ego4d_train.pkl \ + --batch-size 64 \ + --resume $PATH --use-half \ + --nodes 4 --use_volta32 +``` +
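+As described right below, the job writes its generations to `$output_dir/total.pkl`. A small sketch (the path is a placeholder) for loading that file and skimming the most confident narration per clip:
+
+```python
+import pickle
+
+# Placeholder path -- point it at the total.pkl written by the job above.
+with open('output/total.pkl', 'rb') as f:
+    narrations = pickle.load(f)   # list of (video_uid, start, end, narration_list, NLL_list)
+
+for video_uid, start, end, texts, nlls in narrations[:3]:
+    best = min(zip(nlls, texts))[1]   # lowest negative log-likelihood among the candidates
+    print('{} [{:.1f}s-{:.1f}s]: {}'.format(video_uid, start, end, best))
+```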
+ +It will generate a pickle file (`$output_dir/total.pkl`) which is a list of quintuples - `(video_uid: str, start_time: float, end_time: float, narration_list: List[str], NLL_list: List[float])`. + +For narrator-generated narrations on Ego4D ground-truth clips, we also provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.narrator_63690737.return_10.pkl). Note that the narrator used here is our best performing one. + +## Rephrasing human narrations using LAVILA-Rephraser + +Rephraser is a standard LLM that can paraphrase narrations in existing clips. +Specifically, we use an off-the-shelf T5-based paraphraser which is publicly available at [Hugging Face's model hub](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality). +For more details, please refer to the [model card](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality). + +For rephrased human narrations on Ego4D ground-truth clips, we provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.rephraser.no_punkt_top3.pkl). + + +## Pre-training LAVILA Dual-Encoder +Now we are ready to pre-train our LAVILA's dual-encoder by combining human annotations (augmented by Rephraser) and the Narrator-generated narrations. + +
Training a LaViLa dual-encoder + +```bash +python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \ + --metadata datasets/Ego4D/ego4d_train.rephraser.no_punkt_top3.pkl \ + --metadata-aux datasets/Ego4D/ego4d_train.narrator_63690737.return_10.pkl \ + --norm-embed --freeze-temperature \ + --freeze-pseudo-temperature \ + --fix-lr --contrastive-use-vissl \ + --nodes 4 --use_volta32 +``` +
+ +## Down-stream Evaluation +With the pre-trained dual-encoder at hand, we now can do zero-shot or fine-tuning evalution evaluations on down-stream benchmarks. +Please refer to [MODEL_ZOO.md](./MODEL_ZOO.md#zero-shot) for more details. diff --git a/eval_narrator.py b/eval_narrator.py new file mode 100644 index 0000000..25d7d86 --- /dev/null +++ b/eval_narrator.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os.path as osp +import time +from collections import OrderedDict + +import numpy as np +# https://github.com/numpy/numpy/issues/21079 +try: + import numpy.distutils + numpy.distutils.__config__.blas_opt_info = np.distutils.__config__.blas_ilp64_opt_info +except Exception: + pass +from nlgeval import NLGEval + +import torch +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video + +from lavila.data import datasets +from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop +from lavila.models import models +from lavila.models.utils import inflate_positional_embeds +from lavila.utils import distributed as dist_utils +from lavila.utils.preprocess import generate_tokenizer + + +def decode_one(generated_ids, tokenizer): + # get the index of + if tokenizer.eos_token_id == tokenizer.bos_token_id: + if tokenizer.eos_token_id in generated_ids[1:].tolist(): + eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1 + else: + eos_id = len(generated_ids.tolist()) - 1 + elif tokenizer.eos_token_id in generated_ids.tolist(): + eos_id = generated_ids.tolist().index(tokenizer.eos_token_id) + else: + eos_id = len(generated_ids.tolist()) - 1 + generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist()) + return generated_text_str + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False) + parser.add_argument('--dataset', default='ego4d', type=str, + choices=['ego4d']) + parser.add_argument('--root', + default='datasets/Ego4D/video_5min_chunks_288px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata-val', + default='datasets/Ego4D/ego4d_val.pkl', + type=str, help='path to metadata file (val set)') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms') + parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, eg. 
Charades)') + parser.add_argument('--clip-length', default=4, type=int, help='clip length') + parser.add_argument('--clip-stride', default=16, type=int, help='clip stride') + parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling') + parser.add_argument('--batch-size', default=16, type=int, help='batch_size') + # captioning options + parser.add_argument('--caption-sample', default='multinomial_sample', + choices=['multinomial_sample', 'beam_sample', 'group_beam_search']) + parser.add_argument('--caption-top-k', default=None, type=int, help='top-k sampling (predecessor of nucleus sampling)') + parser.add_argument('--caption-top-p', default=0.95, type=float, help='top-p sampling sampling (aka nucleus sampling)') + parser.add_argument('--caption-num-beams', default=3, type=int) + parser.add_argument('--caption-num-beam-groups', default=1, type=int) + parser.add_argument('--caption-temperature', default=0.7, type=float) + parser.add_argument('--caption-length-penalty', default=1.0, type=float) + parser.add_argument('--caption-num-return-sequences', default=1, type=int) + parser.add_argument('--caption-max-len', default=77, type=int) + parser.add_argument('--caption-disable-visual', action='store_true') + parser.add_argument('--caption-early-stop', action='store_true', help='early stopping to save computation') + parser.add_argument('--caption-output-filename', default='caption.txt', type=str) + # others + parser.add_argument('--eval-freq', default=1000, type=int, + help='percentage (1/eval_freq) of val data to evaluate (for fast prototyping)') + parser.add_argument('--print-freq', default=10, type=int) + parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') + parser.add_argument('--use-half', action='store_true') + return parser + + +def main(args): + if args.resume: + ckpt_path = args.resume + elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')): + ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt') + else: + raise Exception('no checkpoint found') + + ckpt = torch.load(ckpt_path, map_location='cpu') + + # create model + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + old_args = ckpt['args'] + print('=> creating model: {}'.format(old_args.model)) + model = getattr(models, old_args.model)( + text_use_cls_token=old_args.use_cls_token, + project_embed_dim=old_args.project_embed_dim, + gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn, + timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn, + timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space, + freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm, + freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm, + num_frames=args.clip_length, + drop_path_rate=0, + ) + model.cuda() + if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model: + # inflate weight + print('=> inflating PE in models due to different frame numbers') + state_dict = inflate_positional_embeds( + model.state_dict(), state_dict, + num_frames=args.clip_length, + load_temporal_fix='bilinear', + ) + model.load_state_dict(state_dict, strict=True) + print("=> loaded resume 
checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1'])) + + torch.backends.cudnn.benchmark = True + + tokenizer = generate_tokenizer(old_args.model) + crop_size = 224 if '336PX' not in old_args.model else 336 + if args.num_crops == 1 and args.num_clips == 1: + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + ]) + else: + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length), + SpatialCrop(crop_size=crop_size, num_crops=args.num_crops), + ]) + + val_dataset = datasets.VideoCaptionDatasetCLIP( + args.dataset, + args.root, + args.metadata_val, + transform=val_transform, + is_training=False, + tokenizer=tokenizer, + clip_length=args.clip_length, + clip_stride=args.clip_stride, + sparse_sample=False, + subsample_stride=args.eval_freq, + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False) + + validate_caption(val_loader, model, tokenizer, args.caption_output_filename, use_half=args.use_half) + + +def validate_caption(val_loader, model, tokenizer, output_filename='caption.txt', use_half=False): + model.eval() + if args.use_half: + model = model.half() + nlgeval = NLGEval() + f = open(output_filename, 'w') + ppls_all = [] + ppls_with_teacher_all = [] + reference = [] + hypothesis = [] + end_time = time.time() + id_offset = 0 + print('=> start forwarding') + with torch.no_grad(): + for i, inputs in enumerate(val_loader): + if i % args.print_freq == 0: + print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time)) + end_time = time.time() + images = inputs[0].cuda(non_blocking=True) + if use_half: + images = images.half() + target = inputs[1].cuda(non_blocking=True) + + # encode images + image_features = dist_utils.get_model(model).encode_image(images) + + # teacher forcing (to get standard ppl metric) + generated_text_ids_with_teacher, ppls_with_teacher = dist_utils.get_model(model).generate( + image_features, + tokenizer, + target=target, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + teacher_forcing=True, + early_stopping=args.caption_early_stop, + ) + + if args.caption_sample == 'multinomial_sample': + assert args.caption_num_beam_groups == 1 + generated_text_ids, ppls = dist_utils.get_model(model).generate( + image_features, + tokenizer, + target=target.repeat_interleave(args.caption_num_return_sequences, dim=0), + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + num_return_sequences=args.caption_num_return_sequences, + temperature=args.caption_temperature, + early_stopping=args.caption_early_stop, + ) + elif args.caption_sample == 'beam_sample': + 
assert args.caption_num_beam_groups == 1 + generated_text_ids, ppls = dist_utils.get_model(model).beam_sample( + image_features, + tokenizer, + target=target, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + temperature=args.caption_temperature, + length_penalty=args.caption_length_penalty, + num_beams=args.caption_num_beams, + num_return_sequences=args.caption_num_return_sequences, + early_stopping=args.caption_early_stop, + ) + elif args.caption_sample == 'group_beam_search': + assert args.caption_num_beam_groups > 1 and args.caption_num_beams % args.caption_num_beam_groups == 0 + generated_text_ids, ppls = dist_utils.get_model(model).group_beam_search( + image_features, + tokenizer, + target=target if not args.caption_no_gt else None, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + temperature=args.caption_temperature, + length_penalty=args.caption_length_penalty, + num_beams=args.caption_num_beams, + num_beam_groups=args.caption_num_beam_groups, + num_return_sequences=args.caption_num_return_sequences, + early_stopping=args.caption_early_stop, + ) + else: + raise NotImplementedError + ppls_all.append(ppls.reshape(-1, args.caption_num_return_sequences).mean(1)) + ppls_with_teacher_all.append(ppls_with_teacher) + + for j in range(generated_text_ids.shape[0] // args.caption_num_return_sequences): + for k in range(args.caption_num_return_sequences): + jj = j * args.caption_num_return_sequences + k + + generated_text_str = decode_one(generated_text_ids[jj], tokenizer) + gt_text = decode_one(target[j], tokenizer) + generated_text_str_with_teacher = decode_one(generated_text_ids_with_teacher[j], tokenizer) + + from transformers import BertTokenizer + bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + gt_text = bert_tokenizer.decode(bert_tokenizer(gt_text)['input_ids'][1:-1]) + generated_text_str = bert_tokenizer.decode(bert_tokenizer(generated_text_str)['input_ids'][1:-1]) + generated_text_str_with_teacher = bert_tokenizer.decode(bert_tokenizer(generated_text_str_with_teacher)['input_ids'][1:-1]) + reference.append(gt_text) + hypothesis.append(generated_text_str) + s1 = '[{:6d}] Groundtruth | | {}'.format(id_offset + j, gt_text) + s2 = '[{:6d}] Generated | PPL : {:9.3f} | {}'.format(id_offset + j, ppls[jj], generated_text_str) + s3 = '[{:6d}] Generated (w/. teacher) | PPL : {:9.3f} | {}'.format(id_offset + j, ppls_with_teacher[j], generated_text_str_with_teacher) + for s in [s1, s2, s3]: + # if i % args.print_freq == 0: + # print(s) + f.write('{} \n'.format(s)) + id_offset += generated_text_ids.shape[0] // args.caption_num_return_sequences + + ppls_with_teacher_all = torch.cat(ppls_with_teacher_all, dim=0) + ppls_all = torch.cat(ppls_all, dim=0) + + print('PPL (w/. teacher) = {:9.3f}'.format(ppls_with_teacher_all.mean().item())) + print('PPL (w/o. teacher) = {:9.3f}'.format(ppls_all.mean().item())) + f.write('PPL (w/. teacher) = {:9.3f} \n'.format(ppls_with_teacher_all.mean().item())) + f.write('PPL (w/o. 
teacher) = {:9.3f} \n'.format(ppls_all.mean().item())) + + print('Avg length for reference: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference))) + print('Avg length for hypothesis: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis))) + f.write('Avg length for reference: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference))) + f.write('Avg length for hypothesis: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis))) + + print('=> Calling NLGEval') + f.write('=> Calling NLGEval\n') + metrics_dict = nlgeval.compute_metrics([reference], hypothesis) + for k in metrics_dict: + print('{:16s} = {:9.3f}'.format(k, metrics_dict[k])) + f.write('{:16s} = {:9.3f} \n'.format(k, metrics_dict[k])) + f.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()]) + args = parser.parse_args() + main(args) diff --git a/eval_zeroshot.py b/eval_zeroshot.py new file mode 100644 index 0000000..14e6c84 --- /dev/null +++ b/eval_zeroshot.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import numpy as np +import os.path as osp +import time +from collections import OrderedDict + +import pandas as pd +import torch +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video +from sklearn.metrics import confusion_matrix + +from lavila.data import datasets +from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop +from lavila.models import models +from lavila.models.utils import inflate_positional_embeds +from lavila.utils import distributed as dist_utils +from lavila.utils.evaluation import accuracy, get_mean_accuracy +from lavila.utils.evaluation_egomcq import egomcq_accuracy_metrics +from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG) +from lavila.utils.evaluation_charades import charades_map +from lavila.utils.preprocess import generate_label_map, generate_tokenizer + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False) + parser.add_argument('--dataset', default='ek100_mir', type=str, + choices=['ek100_cls', 'ek100_mir', 'charades_ego', 'egtea', 'ego4d_mcq']) + parser.add_argument('--root', + default='datasets/EK100/video_ht256px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata-val', + default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv', + type=str, help='path to metadata file (val set)') + parser.add_argument('--relevancy-path', + default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl', + type=str, help='path to relevancy matrix (val set)') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms') + parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, eg. 
Charades)') + parser.add_argument('--clip-length', default=4, type=int, help='clip length') + parser.add_argument('--clip-stride', default=16, type=int, help='clip stride') + parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling') + parser.add_argument('--batch-size', default=16, type=int, help='batch_size') + parser.add_argument('--cls-use-template', action='store_true', help='use prompt in 0-shot classification') + parser.add_argument('--print-freq', default=100, type=int) + parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') + parser.add_argument('--use-half', action='store_true') + return parser + + +def main(args): + if args.resume: + ckpt_path = args.resume + elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')): + ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt') + else: + raise Exception('no checkpoint found') + + ckpt = torch.load(ckpt_path, map_location='cpu') + + # create model + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + old_args = ckpt['args'] + print('=> creating model: {}'.format(old_args.model)) + model = getattr(models, old_args.model)( + text_use_cls_token=old_args.use_cls_token, + project_embed_dim=old_args.project_embed_dim, + gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn, + timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn, + timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space, + freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm, + freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm, + num_frames=args.clip_length, + drop_path_rate=0, + ) + model.cuda() + if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model: + # inflate weight + print('=> inflating PE in models due to different frame numbers') + state_dict = inflate_positional_embeds( + model.state_dict(), state_dict, + num_frames=args.clip_length, + load_temporal_fix='bilinear', + ) + model.load_state_dict(state_dict, strict=True) + print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1'])) + + torch.backends.cudnn.benchmark = True + + if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']: + labels, mapping_vn2act = generate_label_map(args.dataset) + else: + mapping_vn2act = None + tokenizer = generate_tokenizer(old_args.model) + crop_size = 224 if '336PX' not in old_args.model else 336 + if args.num_crops == 1 and args.num_clips == 1: + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + ]) + else: + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else 
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length), + SpatialCrop(crop_size=crop_size, num_crops=args.num_crops), + ]) + + val_dataset = datasets.get_downstream_dataset( + val_transform, tokenizer, args, subset='val', label_mapping=mapping_vn2act, + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False) + + if args.cls_use_template: + templates = ['#C C {}', '#C {}'] + else: + templates = ['{}'] + + if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']: + preds, targets = validate_zeroshot(val_loader, templates, labels, model, tokenizer) + if args.dataset == 'ek100_cls': + if args.use_half: + preds = preds.float() + top1, top5 = accuracy(preds, targets, topk=(1, 5)) + print('top1 = {:.3f}'.format(top1.item())) + print('top5 = {:.3f}'.format(top5.item())) + elif args.dataset == 'charades_ego': + preds, targets = preds.numpy(), targets.numpy() + m_ap, _, _ = charades_map(preds, targets) + print('mAP = {:.3f}'.format(m_ap)) + elif args.dataset == 'egtea': + preds, targets = preds.numpy(), targets.numpy() + print(preds.shape, targets.shape) + cm = confusion_matrix(targets, preds.argmax(axis=1)) + mean_class_acc, acc = get_mean_accuracy(cm) + print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc)) + + if args.dataset == 'ek100_mir': + val_dataset = datasets.VideoCaptionDatasetCLIP( + 'ek100_mir', + args.root, + args.metadata_val, + transform=val_transform, is_training=False, + tokenizer=tokenizer, + clip_length=args.clip_length, + clip_stride=args.clip_stride, + sparse_sample=False + ) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False + ) + similarity_matrix = get_similarity_matrix(val_loader, model, print_freq=args.print_freq, use_half=args.use_half) + similarity_matrix = (similarity_matrix + 1) / 2 + video_id = pd.read_csv(args.metadata_val).values[:, 0] + text_id = pd.read_csv(args.metadata_val.replace("test.csv", "test_sentence.csv")).values[:, 0] + indexes = [video_id.tolist().index(elem) for elem in text_id] + similarity_matrix = similarity_matrix[:, indexes] + print(similarity_matrix.shape) + rel_matrix = pd.read_pickle(args.relevancy_path) + vis_map = calculate_mAP(similarity_matrix, rel_matrix) + txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T) + print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, (vis_map + txt_map) / 2)) + vis_k_counts = calculate_k_counts(rel_matrix) + txt_k_counts = calculate_k_counts(rel_matrix.T) + vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts) + txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts) + vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG) + txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG) + print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2)) + + if args.dataset == 'ego4d_mcq': + val_dataset = datasets.VideoCaptionDatasetMCQ( + args.dataset, + args.root, + args.metadata_val, + transform=val_transform, is_training=False, + tokenizer=tokenizer, + clip_length=args.clip_length, + clip_stride=args.clip_stride, + sparse_sample=False, + ) + val_loader = 
torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True, drop_last=False + ) + validate_mcq(val_loader, model, use_half=args.use_half) + + +def validate_zeroshot(val_loader, templates, labels, model, tokenizer): + model.eval() + if args.use_half: + model = model.half() + all_outputs = [] + all_targets = [] + all_vis_features = [] + print('=> encoding captions') + with torch.no_grad(): + text_features = [] + for label in labels: + if isinstance(label, list): + texts = [tmpl.format(lbl) for tmpl in templates for lbl in label] + else: + texts = [tmpl.format(label) for tmpl in templates] + texts = tokenizer(texts) + if isinstance(texts, tuple): + # Bert-style tokenizer will output both ids and mask + texts, masks = texts + texts = texts.cuda(non_blocking=True) + masks = masks.cuda(non_blocking=True) + else: + texts = texts.cuda(non_blocking=True) + masks = None + texts = texts.view(-1, 77).contiguous() + masks = masks.view(-1, 77).contiguous() if masks is not None else None + if masks is not None: + class_embeddings = dist_utils.get_model(model).encode_text(texts, attention_mask=masks) + else: + class_embeddings = dist_utils.get_model(model).encode_text(texts) + class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) + + class_embeddings = class_embeddings.mean(dim=0) + class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) + + text_features.append(class_embeddings) + text_features = torch.stack(text_features, dim=0) + + print('=> start forwarding') + end_time = time.time() + for i, (images, target) in enumerate(val_loader): + if i % args.print_freq == 0: + print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time)) + end_time = time.time() + if isinstance(images, torch.Tensor): + images = images.cuda(non_blocking=True) + if args.use_half: + images = images.half() + target = target.cuda(non_blocking=True) + + # encode images + image_features = dist_utils.get_model(model).encode_image(images) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + all_vis_features.append(image_features) + # cosine similarity as logits + logits_per_image = image_features @ text_features.t() + # logits_per_image = torch.softmax(logits_per_image, dim=1) + else: + target = target.cuda(non_blocking=True) + images_list = images + logits_all_clips = [] + for images in images_list: + images = images.cuda(non_blocking=True) + if args.use_half: + images = images.half() + image_features = dist_utils.get_model(model).encode_image(images) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + logits_per_image = image_features @ text_features.t() + logits_all_clips.append(logits_per_image) + + logits_all_clips = torch.stack(logits_all_clips, dim=0) + logits_per_image = logits_all_clips.max(0).values + # logits_per_image = logits_all_clips.mean(0) + logits_per_image = torch.softmax(logits_per_image, dim=1) + + all_outputs.append(logits_per_image.cpu()) + all_targets.append(target.cpu()) + + return torch.cat(all_outputs), torch.cat(all_targets) + + +def get_similarity_matrix(val_loader, model, print_freq=100, use_half=False): + model.eval() + if use_half: + model = model.half() + all_text_embed = [] + all_video_embed = [] + with torch.no_grad(): + print('=> encoding visual and textual') + for i, inputs in enumerate(val_loader): + if i % print_freq == 0: + print('finish batch {}/{}'.format(i, len(val_loader))) + frames = 
inputs[0].cuda(non_blocking=True) + if use_half: + frames = frames.half() + texts = inputs[1].cuda(non_blocking=True) + if len(inputs) == 4: + masks = inputs[2].cuda(non_blocking=True) + else: + masks = None + + # encode images + image_features = dist_utils.get_model(model).encode_image(frames) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + all_video_embed.append(image_features.cpu().numpy()) + + if texts.ndim == 3: + is_multiple_narrations = True + texts = texts.view(-1, texts.shape[-1]) + else: + is_multiple_narrations = False + if masks is not None: + text_features = dist_utils.get_model(model).encode_text(texts, attention_mask=masks) + else: + text_features = dist_utils.get_model(model).encode_text(texts) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + all_text_embed.append(text_features.cpu().numpy()) + + all_text_embed = np.vstack(all_text_embed) + all_video_embed = np.vstack(all_video_embed) + similarity_matrix = np.matmul(all_video_embed, all_text_embed.T) + if is_multiple_narrations: + similarity_matrix = similarity_matrix.reshape(all_video_embed.shape[0], all_video_embed.shape[0], -1) + + return similarity_matrix + + +def validate_mcq(val_loader, model, use_half=False): + model.eval() + if use_half: + model.half() + with torch.no_grad(): + print('=> start forwarding') + all_preds = [] + all_gts = [] + all_types = [] + end_time = time.time() + for i, inputs in enumerate(val_loader): + if i % args.print_freq == 0: + print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time)) + end_time = time.time() + texts_query = inputs[0].cuda(non_blocking=True) + frames_options = inputs[1].cuda(non_blocking=True) + if use_half: + frames_options = frames_options.half() + answer = inputs[3] + q_type = inputs[4] + if len(inputs) == 7: + masks_query = inputs[5].cuda(non_blocking=True) + else: + masks_query = None + + batch_size = frames_options.shape[0] + + frames_options = frames_options.view(-1, *frames_options.shape[2:]) + image_features = dist_utils.get_model(model).encode_image(frames_options) + image_features = image_features.view(batch_size, -1, *image_features.shape[1:]) + + if masks_query is not None: + query_features = dist_utils.get_model(model).encode_text(texts_query, attention_mask=masks_query) + else: + query_features = dist_utils.get_model(model).encode_text(texts_query) + + all_gts.append(answer) + all_types.append(q_type) + for j in range(batch_size): + similarity_matrix = torch.matmul(query_features[j], image_features[j].T) + similarity_matrix = similarity_matrix.cpu().detach() + all_preds.append(similarity_matrix) + all_preds = torch.stack(all_preds) + all_gts = torch.cat(all_gts) + all_types = torch.cat(all_types) + metrics = egomcq_accuracy_metrics(all_preds, all_gts, all_types) + print(metrics) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()]) + args = parser.parse_args() + main(args) diff --git a/lavila/data/datasets.py b/lavila/data/datasets.py new file mode 100644 index 0000000..22e2969 --- /dev/null +++ b/lavila/data/datasets.py @@ -0,0 +1,517 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
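+# Illustrative sketch of the scoring pattern shared by validate_zeroshot, get_similarity_matrix
+# and validate_mcq in eval_zeroshot.py above: L2-normalize both sets of embeddings, take their
+# dot product as cosine similarity, then argmax for classification / MCQ or keep the full matrix
+# for retrieval. The tensors below are random stand-ins for LaViLa features and the function
+# name `_cosine_scoring_demo` is a hypothetical placeholder.
+def _cosine_scoring_demo():
+    import torch
+    video_embed = torch.randn(8, 256)  # 8 clips, 256-d projected features (illustrative sizes)
+    text_embed = torch.randn(5, 256)   # 5 class names / narrations
+    video_embed = video_embed / video_embed.norm(dim=-1, keepdim=True)
+    text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
+    similarity = video_embed @ text_embed.t()  # (8, 5) cosine similarities
+    return similarity.argmax(dim=1)            # zero-shot prediction per clip
+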
+ +import csv +import glob +import json +import numpy as np +import os.path as osp +import pickle +import random + +import decord +import pandas as pd +import torch + + +def datetime2sec(str): + hh, mm, ss = str.split(':') + return int(hh) * 3600 + int(mm) * 60 + float(ss) + + +def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False): + if chunk_len == -1: + vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid))) + second_offset = second + if end_second is not None: + end_second = min(end_second, len(vr) / vr.get_avg_fps()) + else: + end_second = len(vr) / vr.get_avg_fps() + else: + chunk_start = int(second) // chunk_len * chunk_len + second_offset = second - chunk_start + vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start))) + if fps == -1: + fps = vr.get_avg_fps() + + # calculate frame_ids + frame_offset = int(np.round(second_offset * fps)) + total_duration = max(int((end_second - second) * fps), clip_length) + if chunk_len == -1: + if end_second <= second: + raise ValueError("end_second should be greater than second") + else: + frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter) + else: + frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter) + + # load frames + if max(frame_ids) < len(vr): + try: + frames = vr.get_batch(frame_ids).asnumpy() + except decord.DECORDError as error: + print(error) + frames = vr.get_batch([0] * len(frame_ids)).asnumpy() + else: + # find the remaining frames in the next chunk + try: + frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids)) + frames_part1 = vr.get_batch(frame_ids_part1).asnumpy() + vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len))) + frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids)) + frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2] + frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy() + frames = np.concatenate([frames_part1, frames_part2], axis=0) + # the next chunk does not exist; the current chunk is the last one + except (RuntimeError, decord.DECORDError) as error: + print(error) + frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter) + frames = vr.get_batch(frame_ids).asnumpy() + + frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames] + return torch.stack(frames, dim=0) + + +def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True): + seg_size = float(end_frame - start_frame - 1) / num_segments + seq = [] + for i in range(num_segments): + start = int(np.round(seg_size * i) + start_frame) + end = int(np.round(seg_size * (i + 1)) + start_frame) + end = min(end, end_frame) + if jitter: + frame_id = np.random.randint(low=start, high=(end + 1)) + else: + frame_id = (start + end) // 2 + seq.append(frame_id) + return seq + + +def video_loader_by_frames(root, vid, frame_ids): + vr = decord.VideoReader(osp.join(root, vid)) + try: + frames = vr.get_batch(frame_ids).asnumpy() + frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames] + except (IndexError, decord.DECORDError) as error: + print(error) + print("Erroneous video: ", vid) + frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))] + return torch.stack(frames, dim=0) + + +class 
VideoCaptionDatasetBase(torch.utils.data.Dataset): + def __init__(self, dataset, root, metadata, is_trimmed=True): + self.dataset = dataset + self.root = root + self.is_trimmed = is_trimmed + + if self.dataset == 'ego4d': + with open(metadata, 'rb') as f: + self.samples = pickle.load(f) + elif self.dataset == 'ego4d_mcq': + with open(metadata, 'r') as f: + self.samples = json.load(f) + elif self.dataset in ['ek100_cls', 'ek100_mir']: + video_list = glob.glob(osp.join(self.root, '*/*.MP4')) + fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list} + self.samples = [] + with open(metadata) as f: + csv_reader = csv.reader(f) + _ = next(csv_reader) # skip the header + for row in csv_reader: + pid, vid = row[1:3] + # start_frame, end_frame = int(row[6]), int(row[7]) + # Deprecated: some videos might have fps mismatch issue + start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5]) + narration = row[8] + verb, noun = int(row[10]), int(row[12]) + vid_path = '{}/{}.MP4'.format(pid, vid) + fps = fps_dict[osp.join(self.root, vid_path)] + start_frame = int(np.round(fps * start_timestamp)) + end_frame = int(np.ceil(fps * end_timestamp)) + self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun)) + if self.dataset == 'ek100_mir': + self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv') + if 'train' in metadata: + self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb')) + elif 'test' in metadata: + self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb')) + else: + raise ValueError('{} should contain either "train" or "test"!'.format(metadata)) + self.relevancy = .1 + elif self.dataset == 'egtea': + video_list = glob.glob(osp.join(self.root, '*/*')) + len_dict = {video: len(decord.VideoReader(video)) for video in video_list} + + vn_list, labels = [], [] + for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')): + row = row.strip() + vn = int(row.split(' ')[-1]) + vn_list.append(vn) + narration = ' '.join(row.split(' ')[:-1]) + labels.append(narration.replace('_', ' ').lower()) + # labels.append(narration) + mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)} + + self.samples = [] + with open(metadata) as f: + for row in f: + clip_id, action_idx = row.strip().split(' ')[:2] + video_id = '-'.join(clip_id.split('-')[:3]) + vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id)) + vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id)) + self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)])) + elif self.dataset == 'charades_ego': + video_list = glob.glob(osp.join(self.root, '*.mp4')) + fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list} + self.samples = [] + with open(metadata) as f: + csv_reader = csv.reader(f) + _ = next(csv_reader) # skip the header + for row in csv_reader: + video_id = row[0] + if self.is_trimmed: + for action_tuple in row[9].split(';'): + if not action_tuple: + continue + action, start_timestamp, end_timestamp = action_tuple.split(' ') + start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp) + vid_path = '{}.mp4'.format(video_id) + fps = fps_dict[osp.join(self.root, vid_path)] + start_frame = int(np.round(fps * start_timestamp)) + end_frame = int(np.ceil(fps * 
end_timestamp)) + self.samples.append((vid_path, start_frame, end_frame, action)) + else: + if not row[9]: + action_list = [] + else: + action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')] + vid_path = '{}.mp4'.format(video_id) + fps = fps_dict[osp.join(self.root, vid_path)] + duration = fps * float(row[10]) + self.samples.append((vid_path, 0, duration, action_list)) + elif self.dataset == 'charades_ego_trimmed': + with open(metadata, 'rb') as f: + self.samples = pickle.load(f) + else: + raise NotImplementedError + + def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False, + narration_selection='random'): + if self.dataset == 'ego4d': + if len(self.samples[i]) == 4: + vid, start_second, end_second, narration = self.samples[i] + frames = video_loader(self.root, vid, start_second, + end_second=end_second, + clip_length=clip_length, + jitter=is_training) + if isinstance(narration, list): + if narration_selection == 'random': + narration = random.choice(narration) + elif narration_selection == 'concat': + narration = '. '.join(narration) + elif narration_selection == 'list': + narration = narration + else: + raise ValueError + return frames, narration + elif len(self.samples[i]) == 5: + # TODO: need better filtering strategy based on nll + vid, start_second, end_second, narration, _ = self.samples[i] + frames = video_loader(self.root, vid, start_second, + end_second=end_second, + clip_length=clip_length, + jitter=is_training) + if isinstance(narration, list): + if narration_selection == 'random': + narration = random.choice(narration) + elif narration_selection == 'concat': + narration = '. '.join(narration) + elif narration_selection == 'list': + narration = narration + else: + raise ValueError + return frames, narration + elif self.dataset == 'ego4d_mcq': + itemMCQ = self.samples[str(i)] + answerIndex = itemMCQ['answer'] + textQuery = itemMCQ['query']['clip_text'] + sampleOptions = itemMCQ['choices'] + frames_options = [] + narration_options = [] + for option_id in range(len(sampleOptions)): + option = sampleOptions[str(option_id)] + frames = video_loader(self.root, option['video_uid'], + float(option['clip_start']), end_second=float(option['clip_end']), + clip_length=clip_length, + jitter=is_training) + frames_options.append(frames) + narration_options.append(option['clip_text']) + return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types'] + elif self.dataset == 'ek100_mir': + vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] + # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end + # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None) + frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training) + frames = video_loader_by_frames(self.root, vid_path, frame_ids) + if is_training: + positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist() + if positive_list != []: + pos = random.sample(positive_list, min(len(positive_list), 1))[0] + if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]: + return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos]) + else: + return frames, (narration, 1) + elif self.dataset == 'ek100_cls': + vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i] + frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training) + 
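# get_frame_ids (defined above) splits the [start_frame, end_frame) range into clip_length
+ # equal segments and keeps one frame per segment: a random frame inside each segment during
+ # training (jitter=True), the segment center at evaluation time, e.g.
+ # get_frame_ids(0, 65, num_segments=4, jitter=False) -> [8, 24, 40, 56]
+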
frames = video_loader_by_frames(self.root, vid_path, frame_ids) + return frames, '{}:{}'.format(verb, noun) + elif self.dataset == 'egtea': + vid_path, start_frame, end_frame, sentence = self.samples[i] + if is_training: + assert num_clips == 1 + if end_frame < clip_length * clip_stride: + frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame))) + zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:])) + frames = torch.cat((frames, zeros), dim=0) + frames = frames[::clip_stride] + else: + start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1) + frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride) + frames = video_loader_by_frames(self.root, vid_path, frame_ids) + else: + if end_frame < clip_length * clip_stride: + frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame))) + zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:])) + frames = torch.cat((frames, zeros), dim=0) + frames = frames[::clip_stride] + frames = frames.repeat(num_clips, 1, 1, 1) + else: + frame_ids = [] + for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int): + frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)) + frames = video_loader_by_frames(self.root, vid_path, frame_ids) + return frames, sentence + elif self.dataset == 'charades_ego': + vid_path, start_frame, end_frame, action_list = self.samples[i] + if sparse_sample: + frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training) + frames = video_loader_by_frames(self.root, vid_path, frame_ids) + else: + if end_frame < clip_length * clip_stride: + frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame))) + zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:])) + frames = torch.cat((frames, zeros), dim=0) + frames = frames[::clip_stride] + frames = frames.repeat(num_clips, 1, 1, 1) + else: + frame_ids = [] + for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int): + frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)) + print('frame_ids:', frame_ids) + frames = video_loader_by_frames(self.root, vid_path, frame_ids) + return frames, action_list + elif self.dataset == 'charades_ego_trimmed': + vid, start_second, end_second, narration = self.samples[i] + frames = video_loader(self.root, vid, start_second, + end_second=end_second, + chunk_len=-1, # no chunk for CharadesEgo + fps=-1, # could be variable fps + clip_length=clip_length, + jitter=is_training) + return frames, narration + else: + raise NotImplementedError + + def __getitem__(self, i): + raise NotImplementedError + + def __len__(self): + return len(self.samples) + + +class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase): + def __init__(self, dataset, root, metadata, transform=None, + is_training=True, tokenizer=None, + clip_length=32, clip_stride=2, sparse_sample=False, + narration_selection='random', + num_hard_negatives=0, + subsample_stride=None): + super().__init__(dataset, root, metadata) + + self.full_samples = self.samples.copy() + if isinstance(subsample_stride, int): + self.samples = self.samples[::subsample_stride] + self.transform = transform + self.is_training = is_training + self.tokenizer = tokenizer + self.clip_length = clip_length + self.clip_stride = clip_stride + self.sparse_sample = sparse_sample + 
self.narration_selection = narration_selection + self.num_hard_negatives = num_hard_negatives + if num_hard_negatives > 0: + assert self.dataset == 'htm_aa' + + def __getitem__(self, i): + frames, caption = self.get_raw_item( + i, is_training=self.is_training, + clip_length=self.clip_length, + clip_stride=self.clip_stride, + sparse_sample=self.sparse_sample, + narration_selection=self.narration_selection, + ) + + # ek100_mir will also output relevancy value + if isinstance(caption, tuple): + caption, relevancy = caption + else: + relevancy = 0. + + # apply transformation + if self.transform is not None: + frames = self.transform(frames) + + # tokenize caption + if self.tokenizer is not None: + caption = self.tokenizer(caption) + + if isinstance(caption, tuple): + caption, mask = caption + return frames, caption, mask, relevancy + else: + return frames, caption, relevancy + + +class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase): + def __init__(self, dataset, root, metadata, transform=None, + is_training=True, tokenizer=None, + clip_length=32, clip_stride=2, sparse_sample=False, + narration_selection='random'): + super().__init__(dataset, root, metadata) + + self.full_samples = self.samples.copy() + self.transform = transform + self.is_training = is_training + self.tokenizer = tokenizer + self.clip_length = clip_length + self.clip_stride = clip_stride + self.sparse_sample = sparse_sample + self.narration_selection = narration_selection + + def __getitem__(self, i): + + textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item( + i, is_training=self.is_training, + clip_length=self.clip_length, + clip_stride=self.clip_stride, + sparse_sample=self.sparse_sample, + narration_selection=self.narration_selection, + ) + + # apply transformation + if self.transform is not None: + frames_options = [self.transform(frames) for frames in frames_options] + + # tokenize caption + if self.tokenizer is not None: + textQuery = self.tokenizer(textQuery) + narration_options = self.tokenizer(narration_options) + if isinstance(textQuery, tuple): + textQuery, mask_query = textQuery + narration_options, mask_options = narration_options + return ( + textQuery, torch.stack(frames_options, dim=0), + narration_options, answerIndex, q_type, + mask_query, mask_options + ) + else: + return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type + + +class VideoClassyDataset(VideoCaptionDatasetBase): + def __init__( + self, dataset, root, metadata, transform=None, + is_training=True, label_mapping=None, + num_clips=1, + clip_length=32, clip_stride=2, + sparse_sample=False, + is_trimmed=True, + ): + super().__init__(dataset, root, metadata, is_trimmed=is_trimmed) + + self.transform = transform + self.is_training = is_training + self.label_mapping = label_mapping + self.num_clips = num_clips + self.clip_length = clip_length + self.clip_stride = clip_stride + self.sparse_sample = sparse_sample + + def __getitem__(self, i): + frames, label = self.get_raw_item( + i, is_training=self.is_training, + num_clips=self.num_clips, + clip_length=self.clip_length, + clip_stride=self.clip_stride, + sparse_sample=self.sparse_sample, + ) + + # apply transformation + if self.transform is not None: + frames = self.transform(frames) + + if self.label_mapping is not None: + if isinstance(label, list): + # multi-label case + res_array = np.zeros(len(self.label_mapping)) + for lbl in label: + res_array[self.label_mapping[lbl]] = 1. 
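+                # res_array is now a multi-hot vector over label_mapping: one entry per class,
+                # set to 1 for every action present in the clip's label list (multi-label Charades-Ego case)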
+ label = res_array + else: + label = self.label_mapping[label] + + return frames, label + + +def get_dataset(train_transform, tokenizer, args, is_training=True): + if 'narration_selection' not in args: + args.narration_selection = 'random' + if args.model.startswith('CLIP') or args.model.startswith('VCLM'): + return VideoCaptionDatasetCLIP( + args.dataset, args.root, args.metadata, train_transform, + is_training=is_training, + tokenizer=tokenizer, + clip_length=args.clip_length, clip_stride=args.clip_stride, + sparse_sample=args.sparse_sample, + narration_selection=args.narration_selection, + num_hard_negatives=args.num_hard_neg if 'num_hard_neg' in args else 0, + ) + else: + raise NotImplementedError + + +def get_downstream_dataset(transform, tokenizer, args, subset='train', label_mapping=None): + if subset == 'train': + return VideoClassyDataset( + args.dataset, args.root, args.metadata_train, transform, + is_training=True, label_mapping=label_mapping, + num_clips=args.num_clips, + clip_length=args.clip_length, clip_stride=args.clip_stride, + sparse_sample=args.sparse_sample, + ) + elif subset == 'val': + return VideoClassyDataset( + args.dataset, args.root, args.metadata_val, transform, + is_training=False, label_mapping=label_mapping, + num_clips=args.num_clips, + clip_length=args.clip_length, clip_stride=args.clip_stride, + sparse_sample=args.sparse_sample, + is_trimmed=not args.dataset == 'charades_ego' + ) + else: + assert ValueError("subset should be either 'train' or 'val'") diff --git a/lavila/data/video_transforms.py b/lavila/data/video_transforms.py new file mode 100644 index 0000000..8f9e925 --- /dev/null +++ b/lavila/data/video_transforms.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Sequence +import torch +import torch.nn as nn +from torchvision import transforms + + +class Permute(nn.Module): + """ + Permutation as an op + """ + + def __init__(self, ordering): + super().__init__() + self.ordering = ordering + + def forward(self, frames): + """ + Args: + frames in some ordering, by default (C, T, H, W) + Returns: + frames in the ordering that was specified + """ + return frames.permute(self.ordering) + + +class TemporalCrop(nn.Module): + """ + Convert the video into smaller clips temporally. + """ + + def __init__( + self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1 + ): + super().__init__() + self.frames = frames_per_clip + self.stride = stride + self.frame_stride = frame_stride + + def forward(self, video): + assert video.ndim == 4, "Must be (C, T, H, W)" + res = [] + for start in range( + 0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride + ): + end = start + (self.frames) * self.frame_stride + res.append(video[:, start: end: self.frame_stride, ...]) + return res + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
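+    Example (illustrative values): with x_offset=10 and y_offset=0, the box
+        [10, 20, 110, 120] becomes [0, 20, 100, 120].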
+ """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +class SpatialCrop(nn.Module): + """ + Convert the video into 3 smaller clips spatially. Must be used after the + temporal crops to get spatial crops, and should be used with + -2 in the spatial crop at the slowfast augmentation stage (so full + frames are passed in here). Will return a larger list with the + 3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT) + or 3x10 testing in SlowFast etc. + """ + + def __init__(self, crop_size: int = 224, num_crops: int = 3): + super().__init__() + self.crop_size = crop_size + if num_crops == 6: + self.crops_to_ext = [0, 1, 2] + # I guess Swin uses 5 crops without flipping, but that doesn't + # make sense given they first resize to 224 and take 224 crops. + # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf) + # So I'm assuming we can use flipped crops and that will add sth.. + self.flipped_crops_to_ext = [0, 1, 2] + elif num_crops == 3: + self.crops_to_ext = [0, 1, 2] + self.flipped_crops_to_ext = [] + elif num_crops == 1: + self.crops_to_ext = [1] + self.flipped_crops_to_ext = [] + else: + raise NotImplementedError( + "Nothing else supported yet, " + "slowfast only takes 0, 1, 2 as arguments" + ) + + def forward(self, videos: Sequence[torch.Tensor]): + """ + Args: + videos: A list of C, T, H, W videos. + Returns: + videos: A list with 3x the number of elements. Each video converted + to C, T, H', W' by spatial cropping. 
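+        For example, with num_crops=3 an input list of N clips becomes a list of
+        3 * N crops (left/center/right or top/center/bottom, as produced by
+        uniform_crop above); num_crops=6 additionally appends the horizontally
+        flipped crops, giving 6 * N.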
+ """ + assert isinstance(videos, list), "Must be a list of videos after temporal crops" + assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" + res = [] + for video in videos: + for spatial_idx in self.crops_to_ext: + res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) + if not self.flipped_crops_to_ext: + continue + flipped_video = transforms.functional.hflip(video) + for spatial_idx in self.flipped_crops_to_ext: + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) + return res diff --git a/lavila/models/bpe_simple_vocab_16e6.txt.gz b/lavila/models/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/lavila/models/bpe_simple_vocab_16e6.txt.gz differ diff --git a/lavila/models/coca.py b/lavila/models/coca.py new file mode 100644 index 0000000..9c27a9d --- /dev/null +++ b/lavila/models/coca.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py +# Modified by Yue Zhao +# The original code is under MIT License + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from einops import rearrange + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# normalization +# they use layernorm without bias, something that pytorch does not offer +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward +# https://arxiv.org/abs/2002.05202 +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + *, + context_dim=None, + dim_head=64, + heads=8, + parallel_ff=False, + ff_mult=4, + norm_context=False + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + context_dim = default(context_dim, dim) + + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether to have parallel feedforward + + ff_inner_dim = ff_mult * dim + + self.ff = nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False) + ) if parallel_ff else None + + def forward(self, x, context): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + # pre-layernorm, for queries and context + x = self.norm(x) + context = self.context_norm(context) + + # get queries + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', 
h=self.heads) + + # scale + q = q * self.scale + + # get key / values + k, v = self.to_kv(context).chunk(2, dim=-1) + + # query / key similarity + sim = einsum('b h i d, b j d -> b h i j', q, k) + + # attention + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + + # aggregate + out = einsum('b h i j, b j d -> b h i d', attn, v) + + # merge and combine heads + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + + # add parallel feedforward (for multimodal layers) + if exists(self.ff): + out = out + self.ff(x) + + return out diff --git a/lavila/models/distributed_utils.py b/lavila/models/distributed_utils.py new file mode 100644 index 0000000..e5c0d2f --- /dev/null +++ b/lavila/models/distributed_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from +# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and +# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py` +# Modified by Yue Zhao +# The original code is under MIT License + +import torch +import torch.distributed as dist +from typing import Tuple + + +def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This helper function converts to the correct + device and returns the tensor + original device. + """ + orig_device = "cpu" if not tensor.is_cuda else "gpu" + if ( + torch.distributed.is_available() + and torch.distributed.get_backend() == torch.distributed.Backend.NCCL + and not tensor.is_cuda + ): + tensor = tensor.cuda() + return (tensor, orig_device) + + +def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This converts the tensor back to original device. + """ + if tensor.is_cuda and orig_device == "cpu": + tensor = tensor.cpu() + return tensor + + +def is_distributed_training_run() -> bool: + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and (torch.distributed.get_world_size() > 1) + ) + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: + """ + Similar to classy_vision.generic.distributed_util.gather_from_all + except that it does not cut the gradients + """ + if tensor.ndim == 0: + # 0 dim tensors cannot be gathered. 
so unsqueeze + tensor = tensor.unsqueeze(0) + + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + gathered_tensors = GatherLayer.apply(tensor) + gathered_tensors = [ + convert_to_normal_tensor(_tensor, orig_device) + for _tensor in gathered_tensors + ] + else: + gathered_tensors = [tensor] + gathered_tensor = torch.cat(gathered_tensors, 0) + return gathered_tensor diff --git a/lavila/models/gpt2_gated.py b/lavila/models/gpt2_gated.py new file mode 100644 index 0000000..d9c06d1 --- /dev/null +++ b/lavila/models/gpt2_gated.py @@ -0,0 +1,1615 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +# Modified by Yue Zhao +# +# +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import copy +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + + +if version.parse(torch.__version__) >= version.parse("1.6"): + is_amp_available = True + from torch.cuda.amp import autocast +else: + is_amp_available = False + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def augment_gpt2_config(config, cross_attn_freq=1, gated_xattn=True): + new_config = copy.deepcopy(config) + new_config.add_cross_attention = True + new_config.add_cross_attention_freq = cross_attn_freq + new_config.is_tanh_gating = gated_xattn + return new_config + + +def load_tf_weights_in_gpt2(model, config, 
gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 
+ ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1) ** 0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + if is_amp_available: + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), 
k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + else: + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class SqReLU(nn.Module): + """ + See So: Primer: Searching for Efficient Transformers for Language Modeling (So., https://arxiv.org/abs/2109.08668). + """ + + def __init__(self): + super().__init__() + self.act = self._sqrelu_python + + def _sqrelu_python(self, input: torch.Tensor) -> torch.Tensor: + return torch.pow(nn.functional.relu(input), 2) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.act(input) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config, squared_relu=False): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + if squared_relu: + self.act = SqReLU() + else: + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.add_cross_attention_freq = config.add_cross_attention_freq + if config.add_cross_attention and layer_idx % config.add_cross_attention_freq == 0: + self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp_crossattention = GPT2MLP(inner_dim, config, squared_relu=True) + self.ln_2_crossattention = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if config.is_tanh_gating: + self.alpha_cattn = nn.Parameter(torch.zeros([])) + self.alpha_dense = nn.Parameter(torch.zeros([])) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: 
Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + if hasattr(self, "alpha_cattn"): + attn_output = torch.tanh(self.alpha_cattn) * attn_output + # residual connection + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = self.ln_2_crossattention(hidden_states) + feed_forward_hidden_states = self.mlp_crossattention(hidden_states) + if hasattr(self, "alpha_dense"): + feed_forward_hidden_states = torch.tanh(self.alpha_dense) * feed_forward_hidden_states + # residual connection + hidden_states = residual + feed_forward_hidden_states + + # Self-Attention + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + # add cross attentions (follow the original order, not to mess things up) + if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0: + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + # FFN + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. 
+ + Example: + + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained("gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
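+        For example (illustrative only), passing `{0: [0, 1]}` removes heads 0 and 1 of the first transformer block.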
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. 
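+            # (Illustrative example: a padding mask of [1, 1, 0] is mapped to [0.0, 0.0, -10000.0]
+            # by the `(1.0 - attention_mask) * -10000.0` transform below.)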
+ # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
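+
+    In this gated cross-attention variant, the GPT-2 config is typically augmented before the weights are loaded so
+    that visually-conditioned cross-attention blocks (and their tanh gates) are inserted. A minimal sketch, assuming
+    the standard `gpt2` checkpoint; the exact wiring used by the training code may differ:
+
+    ```python
+    from transformers import GPT2Config
+
+    base_config = GPT2Config.from_pretrained("gpt2")
+    config = augment_gpt2_config(base_config, cross_attn_freq=1, gated_xattn=True)
+    # The pretrained GPT-2 weights are loaded as usual; the newly added
+    # cross-attention weights are randomly initialized.
+    model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
+    model.freeze_lm_weights()  # train only the new cross-attention / gating parameters
+    ```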
+ """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + def freeze_lm_weights(self): + freeze_list, unfreeze_list = [], [] + for n, p in self.named_parameters(): + if 'crossattention' in n or 'cross_attn' in n or 'alpha_cattn' in n or 'alpha_dense' in n: + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in LM: {}".format(freeze_list)) + print(" Learn the rest parts in LM: {}".format(unfreeze_list)) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = 
None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). 
+""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, 
*optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1[`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size - 1]` All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + 
hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output="'LABEL_0'", + expected_loss=5.28, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/lavila/models/loss.py b/lavila/models/loss.py new file mode 100644 index 0000000..27f8a20 --- /dev/null +++ b/lavila/models/loss.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +import torch +import torch.distributed as dist +import torch.distributed.nn +import torch.nn as nn +import torch.nn.functional as F + +from .distributed_utils import gather_from_all + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, +): + # Adapted from: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class CLIPLoss(nn.Module): + + def __init__( + self, + use_vissl=False, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + ): + super().__init__() + self.use_vissl = use_vissl + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, outputs): + image_features = outputs['image_embed'] + text_features = outputs['text_embed'] + logit_scale = outputs['logit_scale'] + device = image_features.device + if self.world_size > 1: + if self.use_vissl: + all_image_features = gather_from_all(image_features) + all_text_features = gather_from_all(text_features) + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + # compute 
accuracy + with torch.no_grad(): + pred = torch.argmax(logits_per_image, dim=-1) + correct = pred.eq(labels).sum() + acc = 100 * correct / logits_per_image.size(0) + + return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc} + + +class SSLCLIPLoss(nn.Module): + + def __init__( + self, + use_vissl=False, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + scale_init=0.08, + freeze_scale=False, + ): + super().__init__() + self.use_vissl = use_vissl + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.logit_scale_pseudo = nn.Parameter(torch.ones([]) * np.log(1 / scale_init)) + if freeze_scale: + self.logit_scale_pseudo.requires_grad = False + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, outputs, gt_indicators): + image_features = outputs['image_embed'] + text_features = outputs['text_embed'] + logit_scale = outputs['logit_scale'] + logit_scale_pseudo = self.logit_scale_pseudo.exp() + device = image_features.device + if self.world_size > 1: + if self.use_vissl: + all_image_features = gather_from_all(image_features) + all_text_features = gather_from_all(text_features) + all_gt_indicators = gather_from_all(gt_indicators) + num = all_gt_indicators.shape[0] + mask = all_gt_indicators.repeat(num, 1) + all_gt_indicators.repeat(num, 1).T + logit_scale_mat = torch.ones((num, num), device=device) + logit_scale_mat[mask == 0] = logit_scale_pseudo + logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale) + logit_scale_mat[mask == 2] = logit_scale + logits_per_image = logit_scale_mat * (all_image_features @ all_text_features.T) + logits_per_text = logits_per_image.T + else: + raise NotImplementedError + else: + all_gt_indicators = gt_indicators + num = gt_indicators.shape[0] + mask = gt_indicators.repeat(num, 1) + gt_indicators.repeat(num, 1).T + logit_scale_mat = torch.ones((num, num), device=device) + logit_scale_mat[mask == 0] = logit_scale_pseudo + logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale) + logit_scale_mat[mask == 2] = logit_scale + logits_per_image = logit_scale_mat * (image_features @ text_features.T) + logits_per_text = logit_scale_mat * (text_features @ image_features.T) + + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + # compute accuracy + with torch.no_grad(): + pred = torch.argmax(logits_per_image, dim=-1) + correct = pred.eq(labels).sum() + acc = 100 * correct / logits_per_image.size(0) + pred_gt = pred[all_gt_indicators == 1] + labels_gt = labels[all_gt_indicators == 1] + pred_pseudo = pred[all_gt_indicators == 0] + labels_pseudo = labels[all_gt_indicators == 0] + num_gt = pred_gt.shape[0] + num_pseudo = pred_pseudo.shape[0] + correct_gt = pred_gt.eq(labels_gt).sum() + correct_pseudo = pred_pseudo.eq(labels_pseudo).sum() + acc_gt = 100 * correct_gt / num_gt + acc_pseudo = 100 * correct_pseudo / num_pseudo + + return { + 'loss': loss, 
'clip_loss': loss, 'num_gt': torch.tensor([num_gt]), 'num_pseudo': torch.tensor([num_pseudo]), + 'clip_acc': acc, 'clip_acc_gt': acc_gt, 'clip_acc_pseudo': acc_pseudo + } + + +class CaptionLoss(nn.Module): + def __init__(self, pad_id=0, tokenizer=None): + super().__init__() + self.pad_id = pad_id + self.tokenizer = tokenizer + self.pad_id = tokenizer.pad_token_id + + def forward(self, outputs): + logits = outputs['text_tokens_logits'] + labels = outputs['labels'] + # loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id) + loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id, reduction='none') + + # compute accuracy + with torch.no_grad(): + correct = 0. + total = 0. + ppls = [] + for i in range(logits.size(0)): + pred = torch.argmax(logits[i], dim=0) + nopad = labels[i].ne(self.pad_id) + correct += (pred.eq(labels[i]) & nopad).sum() + total += nopad.sum() + ppl = torch.exp(loss[i].sum() / nopad.sum()) + ppls.append(ppl) + # TODO: for debug only + # sep_pos = labels[i].tolist().index(self.tokenizer.tokenizer.sep_token_id) + # if self.tokenizer is not None: + # print('{} {} {}'.format( + # i, self.tokenizer.tokenizer.convert_ids_to_tokens(pred[:sep_pos]), + # self.tokenizer.tokenizer.convert_ids_to_tokens(labels[i, :sep_pos]), + # )) + acc = 100 * correct / (total + 1e-8) + return {'loss': loss.mean(), 'caption_loss': loss.mean(), 'caption_acc': acc, 'ppl': torch.tensor(ppls).mean()} + + +def sim_matrix(a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) + return sim_mt + + +class MaxMarginRankingLoss(nn.Module): + + def __init__(self, margin=0.2, fix_norm=True): + super().__init__() + self.fix_norm = fix_norm + self.loss = nn.MarginRankingLoss(margin) + self.margin = margin + + def forward(self, outputs, weight=None): + image_features = outputs['image_embed'] + text_features = outputs['text_embed'] + + all_image_features = gather_from_all(image_features) + all_text_features = gather_from_all(text_features) + x = sim_matrix(all_text_features, all_image_features) + + n = x.size()[0] + + x1 = torch.diag(x) + x1 = x1.unsqueeze(1) + x1 = x1.expand(n, n) + x1 = x1.contiguous().view(-1, 1) + x1 = torch.cat((x1, x1), 0) + + x2 = x.view(-1, 1) + x3 = x.transpose(0, 1).contiguous().view(-1, 1) + + x2 = torch.cat((x2, x3), 0) + max_margin = F.relu(self.margin - (x1 - x2)) + + if self.fix_norm: + # remove the elements from the diagonal + keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128 + keep1 = keep.view(-1, 1) + keep2 = keep.transpose(0, 1).contiguous().view(-1, 1) + keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten() + if x1.is_cuda: + keep_idx = keep_idx.cuda() + x1_ = torch.index_select(x1, dim=0, index=keep_idx) + x2_ = torch.index_select(x2, dim=0, index=keep_idx) + max_margin = F.relu(self.margin - (x1_ - x2_)) + + return { + 'loss': max_margin.mean(), + 'max_margin_loss': max_margin.mean() + } + + +class AdaptiveMaxMarginRankingLoss(nn.Module): + + def __init__(self, margin=0.4, fix_norm=True): + super().__init__() + self.fix_norm = fix_norm + self.loss = nn.MarginRankingLoss(margin) + self.margin = margin + + def forward(self, outputs, weight=None): + image_features = outputs['image_embed'] + text_features = outputs['text_embed'] + + all_image_features = gather_from_all(image_features) + 
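+        # the text embeddings and per-sample margin weights are gathered across
+        # ranks as well, so the ranking loss is computed over the full global batch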
all_text_features = gather_from_all(text_features) + all_weights = gather_from_all(weight) + x = sim_matrix(all_text_features, all_image_features) + + n = x.size()[0] + + x1 = torch.diag(x) + x1 = x1.unsqueeze(1) + x1 = x1.expand(n, n) + x1 = x1.contiguous().view(-1, 1) + x1 = torch.cat((x1, x1), 0) + + w1 = all_weights.unsqueeze(1) + w1 = w1.expand(n, n) + w1 = w1.contiguous().view(-1, 1) + w1 = torch.cat((w1, w1), 0) + + x2 = x.view(-1, 1) + x3 = x.transpose(0, 1).contiguous().view(-1, 1) + + x2 = torch.cat((x2, x3), 0) + max_margin = F.relu(w1 * self.margin - (x1 - x2)) + + if self.fix_norm: + # remove the elements from the diagonal + keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128 + keep1 = keep.view(-1, 1) + keep2 = keep.transpose(0, 1).contiguous().view(-1, 1) + keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten() + if x1.is_cuda: + keep_idx = keep_idx.cuda() + x1_ = torch.index_select(x1, dim=0, index=keep_idx) + w1_ = torch.index_select(w1, dim=0, index=keep_idx) + x2_ = torch.index_select(x2, dim=0, index=keep_idx) + max_margin = F.relu(w1_ * self.margin - (x1_ - x2_)) + + return { + 'loss': max_margin.mean(), + 'max_margin_loss': max_margin.mean() + } diff --git a/lavila/models/models.py b/lavila/models/models.py new file mode 100644 index 0000000..a90aee9 --- /dev/null +++ b/lavila/models/models.py @@ -0,0 +1,1218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import DistilBertModel, GPT2LMHeadModel + +import lavila.models.loss as loss +from lavila.models.gpt2_gated import GPT2LMHeadModel as GatedGPT2LMHeadModel +from lavila.models.gpt2_gated import augment_gpt2_config +from lavila.models.narrator import VCLM_HF +from lavila.models.openai_clip import load as load_openai_clip +from lavila.models.openai_model import QuickGELU, Transformer +from lavila.models.timesformer import SpaceTimeTransformer +from lavila.models.utils import remap_keys, rsetattr + + +class VideoClassifier(nn.Module): + def __init__(self, + vision_model: nn.Module, + dropout: float, + num_classes: int, + **kwargs, + ): + super().__init__() + self.visual = vision_model + self.dropout = nn.Dropout(dropout) + self.fc_cls = nn.Linear(vision_model.num_features, num_classes, bias=True) + + self.fc_cls.weight.data.normal_(mean=0.0, std=0.01) + self.fc_cls.bias.data.zero_() + + def forward(self, image, use_checkpoint=False): + image_embed = self.visual(image, use_checkpoint=use_checkpoint) + if isinstance(image_embed, list): + assert len(image_embed) == 1 + image_embed = image_embed[0] + logit = self.fc_cls(self.dropout(image_embed)) + return logit + + +class VideoClassifierMultiHead(nn.Module): + def __init__(self, + vision_model: nn.Module, + dropout: float, + num_classes_list: list, + **kwargs, + ): + super().__init__() + self.visual = vision_model + self.dropout = nn.Dropout(dropout) + self.fc_cls = nn.ModuleList( + [nn.Linear(vision_model.num_features, num_classes, bias=True) for num_classes in num_classes_list] + ) + + for m in self.fc_cls: + m.weight.data.normal_(mean=0.0, std=0.01) + m.bias.data.zero_() + + def forward(self, image, use_checkpoint=False): + image_embed = self.visual(image, use_checkpoint=use_checkpoint) + if isinstance(image_embed, list): + assert len(image_embed) == 1 + 
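+            # unwrap the single-element list that some video backbones return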
image_embed = image_embed[0] + logit_list = [m(self.dropout(image_embed)) for m in self.fc_cls] + return logit_list + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + vision_width: int, + vision_model: nn.Module, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + tempearture_init=0.07, + **kwargs, + ): + super().__init__() + + self.context_length = context_length + self.vision_width = vision_width + + self.visual = vision_model + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = nn.LayerNorm(transformer_width) # used to be `models.transformer.LayerNorm`` + + self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + print("=> initialize initial temperature with {}".format(tempearture_init)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_image(self, image, use_checkpoint=False, apply_project=True): + x = self.visual(image, use_checkpoint=use_checkpoint) + if isinstance(x, list): + assert len(x) == 1 + x = x[0] + if not apply_project: + return x + x = x @ self.image_projection + + return x + + def encode_text(self, text, use_checkpoint=False): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, use_checkpoint=use_checkpoint) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text, use_checkpoint=False, norm_embed=False): + image_embed = self.encode_image(image, use_checkpoint=use_checkpoint) + text_embed = self.encode_text(text, use_checkpoint=use_checkpoint) + + if norm_embed: + image_embed = F.normalize(image_embed, dim=-1) + text_embed = 
F.normalize(text_embed, dim=-1) + return {'image_embed': image_embed, + 'text_embed': text_embed, + 'logit_scale': self.logit_scale.exp()} + + +class CLIP_HF(nn.Module): + def __init__(self, + embed_dim: int, + # vision + vision_width: int, + vision_model: nn.Module, + # text + text_width: int, + text_model: nn.Module, + text_use_cls_token: bool, + text_is_regressive: bool, + tempearture_init=0.07, + **kwargs, + ): + super().__init__() + + self.vision_width = vision_width + self.visual = vision_model + self.text_width = text_width + self.textual = text_model + self.text_use_cls_token = text_use_cls_token + self.text_is_regressive = text_is_regressive + + if 'projection' not in kwargs: + self.projection = 'default' + else: + self.projection = kwargs['projection'] + if self.projection == 'default': + self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) + self.text_projection = nn.Parameter(torch.empty(text_width, embed_dim)) + elif self.projection == 'frozen_in_time': + self.image_projection = nn.Sequential( + nn.Linear(vision_width, embed_dim) + ) + self.text_projection = nn.Sequential( + nn.ReLU(), + nn.Linear(text_width, embed_dim) + ) + print("=> initialize initial temperature with {}".format(tempearture_init)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init)) + + self.initialize_parameters() + + def initialize_parameters(self): + if self.projection == 'default': + nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) + nn.init.normal_(self.text_projection, std=self.text_width ** -0.5) + else: + nn.init.normal_(self.image_projection[0].weight, std=self.vision_width ** -0.5) + nn.init.normal_(self.text_projection[1].weight, std=self.text_width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_image(self, image, use_checkpoint=False, apply_project=True): + x = self.visual(image, use_checkpoint=use_checkpoint) + if isinstance(x, list): + assert len(x) == 1 + x = x[0] + if not apply_project: + return x + if self.projection == 'default': + x = x @ self.image_projection + else: + x = self.image_projection(x) + + return x + + def encode_text(self, text, attention_mask=None, use_checkpoint=False): + if use_checkpoint: + if isinstance(self.textual, DistilBertModel): + pass + # print("DistilBertModel does not support gradient checkpointing. 
Skipping even if use_checkpoint=True") + else: + self.textual.gradient_checkpointing_enable() + else: + self.textual.gradient_checkpointing_disable() + # text, attention_mask = text.squeeze(1), attention_mask.squeeze(1) + # ^ uncomment this only when doing local debugging (distributed=False) + x = self.textual(text, attention_mask=attention_mask) + + if self.text_is_regressive: + # gpt-style + x = x.last_hidden_state + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + else: + # bert-style + if self.text_use_cls_token: + x = x.last_hidden_state + x = x[torch.arange(x.shape[0]), 0, :] + else: + x = x.pooler_output + if self.projection == 'default': + x = x @ self.text_projection + else: + x = self.text_projection(x) + + return x + + def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False): + image_embed = self.encode_image(image, use_checkpoint=use_checkpoint) + text_embed = self.encode_text(text, attention_mask=mask, use_checkpoint=use_checkpoint) + + if norm_embed: + image_embed = F.normalize(image_embed, dim=-1) + text_embed = F.normalize(text_embed, dim=-1) + return {'image_embed': image_embed, + 'text_embed': text_embed, + 'logit_scale': self.logit_scale.exp()} + + +def get_loss(model, args, tokenizer=None): + if model.startswith('CLIP'): + return loss.CLIPLoss( + use_vissl=args.contrastive_use_vissl, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + ) + elif model.startswith('VCLM'): + return loss.CaptionLoss(tokenizer=tokenizer) + else: + raise NotImplementedError + + +def get_metric_names(model): + if model.startswith('CLIP'): + return ['loss', 'clip_loss', 'clip_acc'] + elif model.startswith('VCLM'): + return ['loss', 'caption_loss', 'caption_acc', 'ppl'] + else: + raise NotImplementedError + + +def CLIP_OPENAI_TIMESFORMER_BASE( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + print("=> Loading CLIP (ViT-B/16) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + model = CLIP( + embed_dim=project_embed_dim, + vision_width=768, + vision_model=vision_model, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + tempearture_init=temperature_init, + **kwargs + ) + model.transformer.load_state_dict(clip_model.transformer.state_dict()) + 
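+    # also copy CLIP's token/positional embeddings and final LayerNorm so the
+    # text tower starts from the pretrained weights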
model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict()) + model.positional_embedding.data.copy_(clip_model.positional_embedding.data) + model.ln_final.load_state_dict(clip_model.ln_final.state_dict()) + if project_embed_dim == clip_model.text_projection.shape[1]: + print("=> Loading CLIP's text_projection, image_projection and logit_scale directly") + model.image_projection.data.copy_(clip_model.visual.proj.data) + model.text_projection.data.copy_(clip_model.text_projection.data) + model.logit_scale.data.copy_(clip_model.logit_scale.data) + return model + + +def CLIP_OPENAI_TIMESFORMER_LARGE( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=224, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-L/14', 'cpu') + print("=> Loading CLIP (ViT-L/14) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + model = CLIP( + embed_dim=project_embed_dim, + vision_width=1024, + vision_model=vision_model, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12, + tempearture_init=temperature_init, + **kwargs + ) + model.transformer.load_state_dict(clip_model.transformer.state_dict()) + model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict()) + model.positional_embedding.data.copy_(clip_model.positional_embedding.data) + model.ln_final.load_state_dict(clip_model.ln_final.state_dict()) + if project_embed_dim == clip_model.text_projection.shape[1]: + print("=> Loading CLIP's text_projection, image_projection and logit_scale directly") + model.image_projection.data.copy_(clip_model.visual.proj.data) + model.text_projection.data.copy_(clip_model.text_projection.data) + model.logit_scale.data.copy_(clip_model.logit_scale.data) + return model + + +def CLIP_OPENAI_TIMESFORMER_LARGE_336PX( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=336, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu') + print("=> Loading CLIP (ViT-L/14@336px) weights") + remapped_state_dict = 
remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + model = CLIP( + embed_dim=project_embed_dim, + vision_width=1024, + vision_model=vision_model, + context_length=77, + vocab_size=49408, + transformer_width=768, + transformer_heads=12, + transformer_layers=12, + tempearture_init=temperature_init, + **kwargs + ) + model.transformer.load_state_dict(clip_model.transformer.state_dict()) + model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict()) + model.positional_embedding.data.copy_(clip_model.positional_embedding.data) + model.ln_final.load_state_dict(clip_model.ln_final.state_dict()) + if project_embed_dim == clip_model.text_projection.shape[1]: + print("=> Loading CLIP's text_projection, image_projection and logit_scale directly") + model.image_projection.data.copy_(clip_model.visual.proj.data) + model.text_projection.data.copy_(clip_model.text_projection.data) + model.logit_scale.data.copy_(clip_model.logit_scale.data) + return model + + +def CLIP_OPENAI_TIMESFORMER_BASE_DISTILBERT_BASE( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + print("=> Loading CLIP (ViT-B/16) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + text_model = DistilBertModel.from_pretrained( + 'distilbert-base-uncased', + ) + kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top + model = CLIP_HF( + embed_dim=project_embed_dim, + vision_width=vision_model.embed_dim, + vision_model=vision_model, + text_width=768, + text_model=text_model, + text_use_cls_token=True, # DistilBert does not have pooler on top + text_is_regressive=False, + tempearture_init=temperature_init, + **kwargs, + ) + + return model + + +def 
CLIP_OPENAI_TIMESFORMER_LARGE_DISTILBERT_BASE( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=224, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-L/14', 'cpu') + print("=> Loading CLIP (ViT-L/14) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + text_model = DistilBertModel.from_pretrained( + 'distilbert-base-uncased', + ) + kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top + model = CLIP_HF( + embed_dim=project_embed_dim, + vision_width=vision_model.embed_dim, + vision_model=vision_model, + text_width=768, + text_model=text_model, + text_use_cls_token=True, # DistilBert does not have pooler on top + text_is_regressive=False, + tempearture_init=temperature_init, + **kwargs, + ) + + return model + + +def CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE( + num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False, + temperature_init=0.07, project_embed_dim=256, **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=336, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + drop_path_rate=drop_path_rate, + ) + clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu') + print("=> Loading CLIP (ViT-L/14@336px) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + if timesformer_freeze_space: + print("=> Freeze the space part in TimeSformer") + freeze_list, unfreeze_list = [], [] + for n, p in vision_model.named_parameters(): + if n not in remapped_state_dict or n == 'cls_token': + p.requires_grad = True + unfreeze_list.append(n) + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list)) + print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list)) + + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + text_model = DistilBertModel.from_pretrained( + 'distilbert-base-uncased', + ) + kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top + model = CLIP_HF( + 
embed_dim=project_embed_dim, + vision_width=vision_model.embed_dim, + vision_model=vision_model, + text_width=768, + text_model=text_model, + text_use_cls_token=True, # DistilBert does not have pooler on top + text_is_regressive=False, + tempearture_init=temperature_init, + **kwargs, + ) + + return model + + +def CLIP_HF_EGOVLP_DISTILBERT_BASE(num_frames=4, project_embed_dim=256, **kwargs): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ) + vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True) + vision_model.load_state_dict(vit_model.state_dict(), strict=False) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + text_model = DistilBertModel.from_pretrained( + 'distilbert-base-uncased', + ) + kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top + kwargs.update({'projection': 'frozen_in_time'}) + model = CLIP_HF( + embed_dim=project_embed_dim, + vision_width=vision_model.embed_dim, + vision_model=vision_model, + text_width=768, + text_model=text_model, + text_use_cls_token=True, # DistilBert does not have pooler on top + text_is_regressive=False, + **kwargs, + ) + + return model + + +def CLIP_HF_TIMESFORMER_DISTILBERT_BASE(num_frames=4, drop_path_rate=0, temperature_init=0.07, project_embed_dim=256, **kwargs): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + drop_path_rate=drop_path_rate, + ) + vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True) + vision_model.load_state_dict(vit_model.state_dict(), strict=False) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + text_model = DistilBertModel.from_pretrained( + 'distilbert-base-uncased', + ) + kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top + model = CLIP_HF( + embed_dim=project_embed_dim, + vision_width=vision_model.embed_dim, + vision_model=vision_model, + text_width=768, + text_model=text_model, + text_use_cls_token=True, # DistilBert does not have pooler on top + text_is_regressive=False, + tempearture_init=temperature_init, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_VITB16_GPT2_LARGE(gated_xattn=False, freeze_lm_vclm=False, + freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs): + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + vision_model = clip_model.visual + kwargs.pop('text_use_cls_token') + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-large", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=768, + vision_model=vision_model, + text_width=1280, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=20, + **kwargs, + ) + + return model + + 
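+
+# Minimal usage sketch for the VCLM factories in this file, kept as a comment so
+# it has no import-time side effects. `frames` (a B x C x T x H x W video tensor),
+# `tokenized_text`, `tokenizer` (a GPT-2 tokenizer exposing bos/eos/pad ids) and
+# the `args` namespace are placeholders, not defined in this module.
+#
+#   model = VCLM_OPENAI_VITB16_GPT2_LARGE(gated_xattn=True, text_use_cls_token=False)
+#   criterion = get_loss('VCLM_OPENAI_VITB16_GPT2_LARGE', args, tokenizer=tokenizer)  # -> loss.CaptionLoss
+#   outputs = model(frames, tokenized_text)   # {'text_tokens_logits', 'labels'}
+#   loss_dict = criterion(outputs)            # {'loss', 'caption_loss', 'caption_acc', 'ppl'}
+#
+#   # inference-time narration: pool visual tokens once, then sample captions
+#   image_tokens = model.encode_image(frames)
+#   ids, scores = model.generate(image_tokens, tokenizer, top_p=0.95, max_text_length=77)
+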
+def VCLM_OPENAI_VITB16_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False, + freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs): + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + vision_model = clip_model.visual + kwargs.pop('text_use_cls_token') + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=768, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_VITL14_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False, + freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs): + clip_model, _ = load_openai_clip('ViT-L/14', 'cpu') + vision_model = clip_model.visual + kwargs.pop('text_use_cls_token') + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=1024, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_VITL14_336PX_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False, + freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs): + clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu') + vision_model = clip_model.visual + kwargs.pop('text_use_cls_token') + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=1024, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_TIMESFORMER_BASE_GPT2( + gated_xattn=False, + 
random_init_gpt2=False, + freeze_lm_vclm=False, + freeze_visual_vclm=False, + freeze_visual_vclm_temporal=False, + num_frames=4, + timesformer_gated_xattn=False, + **kwargs, +): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + ) + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + print("=> Loading CLIP (ViT-B/16) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + if not random_init_gpt2: + print('Loading LM from pretrained weights..') + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=768, + vision_model=vision_model, + text_width=768, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=12, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_TIMESFORMER_BASE_GPT2_XL( + gated_xattn=False, + freeze_lm_vclm=False, + freeze_visual_vclm=False, + freeze_visual_vclm_temporal=False, + num_frames=4, + timesformer_gated_xattn=False, + **kwargs, +): + vision_model = SpaceTimeTransformer( + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + ) + clip_model, _ = load_openai_clip('ViT-B/16', 'cpu') + print("=> Loading CLIP (ViT-B/16) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=768, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL( + gated_xattn=False, + freeze_lm_vclm=False, + freeze_visual_vclm=False, + 
freeze_visual_vclm_temporal=False, + num_frames=4, + timesformer_gated_xattn=False, + **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=224, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + ) + clip_model, _ = load_openai_clip('ViT-L/14', 'cpu') + print("=> Loading CLIP (ViT-L/14x) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=1024, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2( + gated_xattn=False, + freeze_lm_vclm=False, + freeze_visual_vclm=False, + freeze_visual_vclm_temporal=False, + num_frames=4, + timesformer_gated_xattn=False, + **kwargs +): + vision_model = SpaceTimeTransformer( + img_size=224, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + ) + clip_model, _ = load_openai_clip('ViT-L/14', 'cpu') + print("=> Loading CLIP (ViT-L/14x) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=1024, + vision_model=vision_model, + text_width=768, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=12, + **kwargs, + ) + + return model + + +def VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL( + gated_xattn=False, + freeze_lm_vclm=False, + freeze_visual_vclm=False, + 
freeze_visual_vclm_temporal=False, + num_frames=4, + timesformer_gated_xattn=False, + **kwargs, +): + vision_model = SpaceTimeTransformer( + img_size=336, patch_size=14, + embed_dim=1024, depth=24, num_heads=16, + num_frames=num_frames, + time_init='zeros', + attention_style='frozen-in-time', + ln_pre=True, + act_layer=QuickGELU, + is_tanh_gating=timesformer_gated_xattn, + ) + clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu') + print("=> Loading CLIP (ViT-L/14@336px) weights") + remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24) + res = vision_model.load_state_dict(remapped_state_dict, strict=False) + print(res) + vision_model.head = nn.Identity() + vision_model.pre_logits = nn.Identity() + vision_model.fc = nn.Identity() + + gpt2 = GPT2LMHeadModel.from_pretrained( + "gpt2-xl", + use_cache=False, + ) + new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=3, gated_xattn=gated_xattn) + text_decoder = GatedGPT2LMHeadModel(new_config) + for n, p in gpt2.named_parameters(): + rsetattr(text_decoder, n + '.data', p.data) + + if freeze_lm_vclm: + print('Freeze the LM part of TextDecoder of VCLM') + text_decoder.freeze_lm_weights() + + if freeze_visual_vclm: + print('Freeze the spatial part of VideoEncoder of VCLM') + vision_model.freeze_spatial_weights() + + if freeze_visual_vclm_temporal: + print('Freeze the temporal part of VideoEncoder of VCLM') + vision_model.freeze_temporal_weights() + + model = VCLM_HF( + vision_width=1024, + vision_model=vision_model, + text_width=1600, + text_decoder=text_decoder, + num_img_queries=256, + dim_head=64, + heads=25, + **kwargs, + ) + + return model + + +def CLIP_OPENAI_VITB32(**kwargs): + model, _ = load_openai_clip('ViT-B/32', 'cpu') + return model + + +def CLIP_OPENAI_VITB16(**kwargs): + model, _ = load_openai_clip('ViT-B/16', 'cpu') + return model + + +def CLIP_OPENAI_VITL14(**kwargs): + model, _ = load_openai_clip('ViT-L/14', 'cpu') + return model + + +def CLIP_OPENAI_VITL14_336PX(**kwargs): + model, _ = load_openai_clip('ViT-L/14@336px', 'cpu') + return model diff --git a/lavila/models/narrator.py b/lavila/models/narrator.py new file mode 100644 index 0000000..61f3e77 --- /dev/null +++ b/lavila/models/narrator.py @@ -0,0 +1,385 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
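+
+# VCLM_HF below is the visually-conditioned "Narrator": video features are pooled
+# into a fixed set of learned query tokens via cross-attention and fed to a
+# (gated) GPT-2 decoder as encoder_hidden_states; generate(), beam_sample() and
+# group_beam_search() implement caption sampling on top of it.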
+ +# Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/generation_utils.py +# Modified by Yue Zhao +# The original code is under Apache 2.0 License + + +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BeamSearchScorer +from transformers.generation_logits_process import ( + LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper, + TemperatureLogitsWarper, TypicalLogitsWarper, LogitNormalization +) + +from lavila.models.coca import CrossAttention, LayerNorm +from lavila.models.openai_model import VisionTransformer +from lavila.models.timesformer import SpaceTimeTransformer + + +class VCLM_HF(nn.Module): + def __init__(self, + # vision + vision_width: int, + vision_model: nn.Module, + # text + text_width: int, + text_decoder: nn.Module, + num_img_queries=256, + dim_head=64, + heads=8, + **kwargs, + ): + super().__init__() + self.vision_width = vision_width + self.visual = vision_model + self.text_width = text_width + self.text_decoder = text_decoder + + self.img_queries = nn.Parameter(torch.empty(num_img_queries, text_width)) + self.img_attn_pool = CrossAttention( + dim=text_width, context_dim=vision_width, + dim_head=dim_head, heads=heads, + norm_context=True + ) + self.img_attn_pool_norm = LayerNorm(text_width) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.img_queries, std=self.text_width ** -0.5) + + def encode_image(self, image, use_checkpoint=False): + if isinstance(self.visual, VisionTransformer): + # openai_model.VisionTransformer accepts (N, C, H, W) instead of (N, C, T, H, W) + image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW + bb, tt, _, _, _ = image.shape + x = self.visual(image.reshape(-1, *image.shape[2:]), use_checkpoint=use_checkpoint, cls_at_last=False) # NLD + x = x.view(bb, tt, *x.shape[1:]) + x = x.permute(0, 3, 1, 2) + elif isinstance(self.visual, SpaceTimeTransformer): + image = image.permute(0, 2, 1, 3, 4).contiguous() # BCTHW -> BTCHW + bb, tt, _, _, _ = image.shape + x = self.visual.forward_features(image, use_checkpoint=use_checkpoint, cls_at_last=False) # NLD + x = x.permute(0, 2, 1) + else: + x = self.visual(image, use_checkpoint=use_checkpoint, mean_at_last=False) + if isinstance(x, list): + assert len(x) == 1 + x = x[0] + + x = x.flatten(start_dim=2) # BDTHW -> BD(THW) + x = x.permute(0, 2, 1) # BDN -> BND + img_queries = repeat(self.img_queries, 'n d -> b n d', b=x.shape[0]) + img_queries = self.img_attn_pool(img_queries, x) + img_queries = self.img_attn_pool_norm(img_queries) + return img_queries + + def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False): + if use_checkpoint: + self.text_decoder.gradient_checkpointing_enable() + else: + self.text_decoder.gradient_checkpointing_disable() + + text, labels = text[:, :-1], text[:, 1:] + # mask = mask[:, :-1] + image_tokens = self.encode_image(image, use_checkpoint=use_checkpoint) + + output_decoder = self.text_decoder(text.contiguous(), encoder_hidden_states=image_tokens) + text_tokens_logits = output_decoder.logits + text_tokens_logits = rearrange(text_tokens_logits, 'b n c -> b c n') + + return {'text_tokens_logits': text_tokens_logits, + 'labels': labels} + + def generate(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None, + num_return_sequences=1, temperature=1.0, teacher_forcing=False, early_stopping=False): + image_tokens = 
image_tokens.repeat_interleave(num_return_sequences, dim=0) + device = image_tokens.device + generated_text_ids = torch.LongTensor([[tokenizer.bos_token_id]] * image_tokens.shape[0]).to(device) + condition_text_ids = generated_text_ids.clone() + + logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=1) + + nlls, num_tokens = torch.zeros(image_tokens.shape[0]).to(device), torch.zeros(image_tokens.shape[0]).to(device) + is_reach_eos = torch.zeros(image_tokens.shape[0]).bool().to(device) + with torch.no_grad(): + for i in range(max_text_length - 1): + output_decoder = self.text_decoder(condition_text_ids, encoder_hidden_states=image_tokens) + decoded_token_logits = output_decoder.logits + next_token_logits = decoded_token_logits[:, -1, :] + if target is not None: + nll = F.cross_entropy(next_token_logits, target[:, i+1], ignore_index=tokenizer.pad_token_id, reduction='none') + nlls += nll + num_tokens += target[:, i+1].ne(tokenizer.pad_token_id) + else: + nll = torch.special.entr(F.softmax(next_token_logits, dim=1)).sum(dim=1) + nlls += nll * (~is_reach_eos) + num_tokens += (~is_reach_eos) + # filtered_p = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p, device=device) + next_token_logits = logits_warper(generated_text_ids, next_token_logits) + filtered_p = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(filtered_p, num_samples=1) + is_reach_eos = is_reach_eos | (next_token[:, 0] == tokenizer.eos_token_id) + if early_stopping and torch.all(is_reach_eos): + break + + if teacher_forcing: + condition_text_ids = target[:, :i+2] + else: + condition_text_ids = torch.cat((generated_text_ids, next_token), dim=1) + + generated_text_ids = torch.cat((generated_text_ids, next_token), dim=1) + if target is not None: + return generated_text_ids, torch.exp(nlls / num_tokens) + else: + return generated_text_ids, torch.exp(nlls / num_tokens) + + def beam_sample(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None, + temperature=1.0, length_penalty=1., + num_beams=3, num_return_sequences=1, teacher_forcing=False, early_stopping=False): + batch_size = image_tokens.shape[0] + device = image_tokens.device + input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long) + input_ids = input_ids * tokenizer.bos_token_id + + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams * num_return_sequences).view(-1).to(device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + batch_beam_size, cur_len = input_ids.shape + + logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams) + + beam_scorer = BeamSearchScorer( + batch_size=batch_size * num_return_sequences, num_beams=num_beams, + device=device, + length_penalty=length_penalty, + ) + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + beam_scores = torch.zeros((batch_size, num_beams)).to(device) + beam_scores = beam_scores.view((batch_size * num_beams,)) + + is_reach_eos = torch.zeros(batch_beam_size).bool().to(device) + with torch.no_grad(): + for i in range(max_text_length - 1): + output_decoder = self.text_decoder( + input_ids, + encoder_hidden_states=image_tokens.repeat_interleave(num_beams * num_return_sequences, dim=0) + ) + decoded_token_logits = output_decoder.logits + next_token_logits = decoded_token_logits[:, -1, :] + + next_token_scores = 
F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + # supposed to be the line below, but ignore temporarily + # next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores_processed = next_token_scores + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + # supposed to be the line below, but do a simple top_k+top_p temporarily + next_token_scores = logits_warper(input_ids, next_token_scores) + # next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device) + + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id) + if beam_scorer.is_done or torch.all(is_reach_eos): + break + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + max_length=max_text_length, + ) + + sequences = sequence_outputs["sequences"] + sequence_scores = sequence_outputs["sequence_scores"] + return sequences, sequence_scores + + def group_beam_search(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None, + temperature=1.0, length_penalty=1., + num_beams=6, num_beam_groups=3, + num_return_sequences=1, teacher_forcing=False, early_stopping=False): + batch_size = image_tokens.shape[0] + device = image_tokens.device + input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long) + input_ids = input_ids * tokenizer.bos_token_id + + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams).view(-1).to(device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + batch_beam_size, cur_len = input_ids.shape + + logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams) + + beam_scorer = BeamSearchScorer( + batch_size=batch_size, num_beams=num_beams, + num_beam_groups=num_beam_groups, + num_beam_hyps_to_keep=num_return_sequences, device=device, + length_penalty=length_penalty, + ) + num_sub_beams = num_beams // num_beam_groups + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + is_reach_eos = torch.zeros(batch_beam_size).bool().to(device) + with torch.no_grad(): + + # predicted tokens in cur_len step + current_tokens = 
torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + for i in range(max_text_length - 1): + output_decoder = self.text_decoder( + input_ids, + encoder_hidden_states=image_tokens.repeat_interleave(num_beams, dim=0) + ) + decoded_token_logits = output_decoder.logits + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = decoded_token_logits[batch_group_indices, -1, :] + + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + vocab_size = next_token_scores.shape[-1] + + # supposed to be the line below, but ignore temporarily + # next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores_processed = next_token_scores + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + next_token_scores = logits_warper(input_ids, next_token_scores) + # next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + beam_indices=None + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id) + if beam_scorer.is_done or torch.all(is_reach_eos): + break + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + max_length=max_text_length, + beam_indices=None, + ) + + sequences = sequence_outputs["sequences"] + sequence_scores = sequence_outputs["sequence_scores"] + return sequences, sequence_scores + + def _get_logits_warper( + self, top_k=None, 
top_p=None, typical_p=None, + temperature=None, num_beams=None, renormalize_logits=None, + ): + top_k = top_k if top_k is not None else 0 + top_p = top_p if top_p is not None else 1.0 + typical_p = typical_p if typical_p is not None else 1. + temperature = temperature if temperature is not None else 1. + warpers = LogitsProcessorList() + + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + # `LogitNormalization` should always be the last logit processor, when present + if renormalize_logits is True: + warpers.append(LogitNormalization()) + return warpers diff --git a/lavila/models/openai_clip.py b/lavila/models/openai_clip.py new file mode 100644 index 0000000..f2d12da --- /dev/null +++ b/lavila/models/openai_clip.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/clip.py +# Modified by Yue Zhao +# The original code is under MIT License + +import hashlib +import os +import urllib +import warnings +from typing import Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .openai_model import build_model +from .tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": 
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
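+
+    Example (illustrative)::
+        tokens = tokenize(["a photo of a cat", "a photo of a dog"])
+        # tokens.shape == (2, 77)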
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/lavila/models/openai_model.py b/lavila/models/openai_model.py new file mode 100644 index 0000000..a6c1524 --- /dev/null +++ b/lavila/models/openai_model.py @@ -0,0 +1,485 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/model.py +# Modified by Yue Zhao +# The original code is under MIT License + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def 
forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = 
None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm` + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm` + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward_part1(self, x): + return self.attention(self.ln_1(x)) + + def forward_part2(self, x): + return self.mlp(self.ln_2(x)) + + def forward(self, x: torch.Tensor, use_checkpoint=False): + if use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part1, x) + else: + x = x + self.forward_part1(x) + + if use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor, use_checkpoint=False): + if use_checkpoint: + for i in range(self.layers): + x = checkpoint.checkpoint(self.resblocks[i], x) + return x + else: + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, apply_project=True, use_checkpoint=False, cls_at_last=True): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, use_checkpoint=use_checkpoint) + x = x.permute(1, 0, 2) # LND -> NLD + + if cls_at_last: + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None and apply_project: + x = x @ self.proj + + return x + else: + return x[:, 1:, :] + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + 
super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, apply_project=True, use_checkpoint=False): + if image.ndim == 4: + return self.visual(image.type(self.dtype)) + else: + image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW + bb, tt, _, _, _ = image.shape + x = self.visual(image.reshape(-1, *image.shape[2:]), apply_project=apply_project, use_checkpoint=use_checkpoint) # ND + x = x.view(bb, tt, -1) + image_features = x.mean(1) + # image_features = x.max(1).values + return image_features + + def encode_text(self, text, use_checkpoint=False): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + 
x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, use_checkpoint=use_checkpoint) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text, use_checkpoint=False, norm_embed=True): + image_features = self.encode_image(image, use_checkpoint=use_checkpoint) + text_features = self.encode_text(text, use_checkpoint=use_checkpoint) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # # cosine similarity as logits + # logit_scale = self.logit_scale.exp() + # logits_per_image = logit_scale * image_features @ text_features.t() + # logits_per_text = logits_per_image.t() + + # # shape = [global_batch_size, global_batch_size] + # return logits_per_image, logits_per_text + + return {'image_embed': image_features, + 'text_embed': text_features, + 'logit_scale': self.logit_scale.exp()} + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", 
"vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/lavila/models/timesformer.py b/lavila/models/timesformer.py new file mode 100644 index 0000000..0019262 --- /dev/null +++ b/lavila/models/timesformer.py @@ -0,0 +1,390 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py +# Modified by Yue Zhao +# The original code is under MIT License + +""" +Implementations of Video Transformers in PyTorch +A PyTorch implementation of space-time transformer as described in +'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650 +A PyTorch implementation of timesformer as described in +'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095 +Acknowledgments: +- This code builds on Ross Wightman's vision_transformer code in pytorch-image-models: +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +- It is also inspired by lucidrains timesformer implementation: +https://github.com/lucidrains/TimeSformer-pytorch +Hacked together by Max Bain +""" + +from collections import OrderedDict +from functools import partial + +import torch +import torch.utils.checkpoint as checkpoint +from einops import rearrange, repeat +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from torch import einsum, nn + + +def attn(q, k, v): + sim = einsum('b i d, b j d -> b i j', q, k) + attn = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', attn, v) + return out + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class VideoPatchEmbed(nn.Module): + """ Video to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, + num_frames=8, ln_pre=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.num_frames = num_frames + self.embed_dim = embed_dim + # ln_pre is inserted to be compatible with CLIP-style model + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre) + + def forward(self, x): + B, F, C, H, W = x.shape + assert F <= self.num_frames + x = x.view(-1, C, H, W) + x = self.proj(x) + return x + + +class VarAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + initialize='random'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set 
manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + if initialize == 'zeros': + self.qkv.weight.data.fill_(0) + self.qkv.bias.data.fill_(0) + # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs + # are multiplied by 0*0, which is hard for the model to move out of. + self.proj.weight.data.fill_(1) + self.proj.bias.data.fill_(0) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, einops_from, einops_to, einops_dims): + h = self.num_heads + # project x to q, k, v vaalues + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + q *= self.scale + + # splice out CLS token at index 1 + (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) + + # let CLS token attend to key / values of all patches across time and space + cls_out = attn(cls_q, k, v) + # rearrange across time or space + q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_)) + + # expand cls token keys and values across time or space and concat + r = q_.shape[0] // cls_k.shape[0] + cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v)) + + k_ = torch.cat((cls_k, k_), dim=1) + v_ = torch.cat((cls_v, v_), dim=1) + + # attention + out = attn(q_, k_, v_) + + # merge back time or space + out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) + + # concat back the cls token + out = torch.cat((cls_out, out), dim=1) + + # merge back the heads + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + # to out + x = self.proj(out) + x = self.proj_drop(x) + return x + + +class SpaceTimeBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros', + attention_style='frozen-in-time', is_tanh_gating=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = VarAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.timeattn = VarAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + initialize=time_init) + + if is_tanh_gating: + self.alpha_timeattn = nn.Parameter(torch.zeros([])) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm3 = norm_layer(dim) + + self.attention_style = attention_style + + def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time, + time_n, space_f, use_checkpoint=False): + if use_checkpoint: + time_output = checkpoint.checkpoint( + self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n} + ) + else: + time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}) + if hasattr(self, "alpha_timeattn"): + time_output = torch.tanh(self.alpha_timeattn) * time_output + time_residual = x + time_output + if use_checkpoint: + space_output = checkpoint.checkpoint( + self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f} + ) + else: + space_output = self.attn(self.norm1(time_residual), einops_from_space, + einops_to_space, {"f": space_f}) + if self.attention_style == 'frozen-in-time': + space_residual = x + self.drop_path(space_output) + else: + raise NotImplementedError + + x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual))) + + return x + + +class SpaceTimeTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain. + https://arxiv.org/abs/2104.00650 + Based off: + - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py] + lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch]. + Notable differences: + - allows for variable length input frames (<= num_frames) + - allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED] + - different attention block mechanism + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False, + act_layer=nn.GELU, is_tanh_gating=False): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module + norm_layer: (nn.Module): normalization layer + num_frames: (int) maximum number of frames expected as input + time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off + as ViT. + attention_style: (str) how to attend to space and time. 
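+
+        Example (illustrative)::
+            vit = SpaceTimeTransformer(num_frames=4)
+            out = vit(torch.randn(2, 3, 4, 224, 224))   # (B, C, T, H, W) -> (B, num_classes)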
+ """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_frames = num_frames + self.embed_dim = embed_dim + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + print("######USING ATTENTION STYLE: ", attention_style) + if hybrid_backbone is not None: + raise NotImplementedError('hybrid backbone not implemented') + else: + self.patch_embed = VideoPatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre) + num_patches = self.patch_embed.num_patches + self.patches_per_frame = num_patches // num_frames + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patches_per_frame + 1, + embed_dim)) # remember to take pos_embed[1:] for tiling over time + self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) + + if ln_pre: + self.ln_pre = nn.LayerNorm(embed_dim) + else: + self.ln_pre = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + SpaceTimeBlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init, + attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + + # if num_frames > 1, then we perform ViT inflation and initialise time attention to zero so not necessary. 
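+        # Pretrained 2D ViT / CLIP weights are usually remapped onto this model afterwards
+        # (see remap_keys and inflate_positional_embeds in lavila/models/utils.py), so random
+        # re-initialization below is restricted to the single-frame case.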
+ if num_frames == 1: + self.apply(self._init_weights) + + # einops transformations + self.einops_from_space = 'b (f n) d' + self.einops_to_space = '(b f) n d' + self.einops_from_time = 'b (f n) d' + self.einops_to_time = '(b n) f d' + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def freeze_spatial_weights(self): + freeze_list = [] + for n, p in self.named_parameters(): + if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n: + pass + else: + p.requires_grad = False + freeze_list.append(n) + print("Freeze the pretrained parts in vision model: {}".format(freeze_list)) + + def freeze_temporal_weights(self): + freeze_list = [] + for n, p in self.named_parameters(): + if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n: + p.requires_grad = False + freeze_list.append(n) + else: + pass + print("Freeze the pretrained parts in vision model: {}".format(freeze_list)) + + def forward_features(self, x, use_checkpoint=False, cls_at_last=True): + # print(x.shape) + b, curr_frames, channels, _, _ = x.shape + x = self.patch_embed(x) + x = x.flatten(2).transpose(2, 1) + x = x.reshape(b, -1, self.patch_embed.embed_dim) + + BF = x.shape[0] + cls_tokens = self.cls_token.expand(BF, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...) + cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) + tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1) + # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...) + tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1) + total_pos_embed = tile_pos_embed + tile_temporal_embed + total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) + + curr_patches = x.shape[1] + x = x + total_pos_embed[:, :curr_patches] + if self.ln_pre is not None: + x = self.ln_pre(x) + x = self.pos_drop(x) + n = self.patches_per_frame + f = curr_frames + + for blk in self.blocks: + x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time, + self.einops_to_time, + time_n=n, space_f=f, use_checkpoint=use_checkpoint) + + if cls_at_last: + x = self.norm(x)[:, 0] + x = self.pre_logits(x) + + return x + else: + return self.norm(x) + + def forward(self, x, use_checkpoint=False): + # Note: B C T H W => B T C H W + # The default input order is different from the one in Frozen-in-Time + x = x.permute(0, 2, 1, 3, 4).contiguous() + x = self.forward_features(x, use_checkpoint=use_checkpoint) + x = self.head(x) + return x diff --git a/lavila/models/tokenizer.py b/lavila/models/tokenizer.py new file mode 100644 index 0000000..387516a --- /dev/null +++ b/lavila/models/tokenizer.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
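+
+# Tokenizers used by LaViLa: the CLIP byte-pair-encoding tokenizer (SimpleTokenizer) and thin
+# wrappers around HuggingFace BERT / DistilBERT / GPT-2 tokenizers, all of which pad or truncate
+# captions to a fixed context length (77 tokens by default).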
+ +# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py +# Modified by Yue Zhao +# The original code is under MIT License + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re +import torch + +from transformers import (BertTokenizer, DistilBertTokenizer, GPT2Tokenizer) + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if 
len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def __call__(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, :len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class MyBertTokenizer(object): + def __init__(self, name=''): + print('=> Initialize MyBertTokenizer ({})'.format(name)) + self.tokenizer = BertTokenizer.from_pretrained(name) + self.bos_token_id, self.eos_token_id = self.tokenizer('').input_ids + self.pad_token_id = 0 + + def __call__(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + result = torch.zeros(len(texts), context_length, dtype=torch.long) + mask = torch.zeros(len(texts), context_length, dtype=torch.float32) + for i, text in enumerate(texts): + tokens = self.tokenizer(text) + input_ids = tokens.input_ids[:context_length] + attention_mask = tokens.attention_mask[:context_length] + result[i, :len(input_ids)] = torch.tensor(input_ids) + mask[i, :len(attention_mask)] = torch.tensor(attention_mask) + + if len(result) == 1: + return result[0], mask[0] + return result, mask + + +class MyDistilBertTokenizer(object): + def __init__(self, name=''): + print('=> Initialize MyDistilBertTokenizer ({})'.format(name)) + self.tokenizer = DistilBertTokenizer.from_pretrained(name) + + def __call__(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + result = torch.zeros(len(texts), context_length, dtype=torch.long) + mask = torch.zeros(len(texts), context_length, dtype=torch.float32) + for i, text in enumerate(texts): + tokens = self.tokenizer(text) + input_ids = tokens.input_ids[:context_length] + attention_mask = tokens.attention_mask[:context_length] + result[i, :len(input_ids)] = torch.tensor(input_ids) + mask[i, :len(attention_mask)] = torch.tensor(attention_mask) + + if len(result) == 1: + return result[0], mask[0] + return result, mask + + +class MyGPT2Tokenizer(object): + def __init__(self, name='', add_bos=False): + print('=> Initialize MyGPT2Tokenizer ({})'.format(name)) + self.tokenizer = GPT2Tokenizer.from_pretrained(name) + self.bos_token_id, self.eos_token_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id + self.pad_token_id = 0 + self.add_bos = add_bos + # num_added_tokens = self.tokenizer.add_special_tokens({'pad_token': "[PAD]"}) + # print('num_added_tokens={}'.format(len(num_added_tokens))) + + def __call__(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + result = torch.zeros(len(texts), context_length, dtype=torch.long) + for i, text in enumerate(texts): 
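+            # Truncate so there is room for the special tokens appended below ([EOS] only, or
+            # [BOS] + [EOS] when add_bos=True); remaining positions stay 0, matching pad_token_id.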
+ tokens = self.tokenizer(text) + if not self.add_bos: + input_ids = tokens.input_ids[:context_length - 1] + input_ids = input_ids + [self.tokenizer.eos_token_id] # add [EOS] + else: + input_ids = tokens.input_ids[:context_length - 2] + input_ids = [self.tokenizer.bos_token_id] + input_ids + [self.tokenizer.eos_token_id] # add [EOS] + # attention_mask = tokens.attention_mask[:context_length] + # attention_mask = attention_mask + [0.] * pad_length + result[i, :len(input_ids)] = torch.tensor(input_ids) + + if len(result) == 1: + return result[0] + return result diff --git a/lavila/models/utils.py b/lavila/models/utils.py new file mode 100644 index 0000000..0657f73 --- /dev/null +++ b/lavila/models/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +import functools +import torch +import torch.nn.functional as F + + +def inflate_positional_embeds( + current_model_state_dict, new_state_dict, + num_frames=4, + load_temporal_fix='bilinear', +): + # allow loading of timesformer with fewer num_frames + curr_keys = list(current_model_state_dict.keys()) + if 'visual.temporal_embed' in new_state_dict and 'visual.temporal_embed' in curr_keys: + load_temporal_embed = new_state_dict['visual.temporal_embed'] + load_num_frames = load_temporal_embed.shape[1] + curr_num_frames = num_frames + embed_dim = load_temporal_embed.shape[2] + + if load_num_frames != curr_num_frames: + if load_num_frames > curr_num_frames: + print(f'### loaded SpaceTimeTransformer model has MORE frames than current...' + f'### loading weights, filling in the extras via {load_temporal_fix}') + new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :] + else: + print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...' + f'### loading weights, filling in the extras via {load_temporal_fix}') + if load_temporal_fix == 'zeros': + new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim]) + new_temporal_embed[:, :load_num_frames] = load_temporal_embed + elif load_temporal_fix in ['interp', 'bilinear']: + # interpolate + # unsqueeze so pytorch thinks its an image + mode = 'nearest' + if load_temporal_fix == 'bilinear': + mode = 'bilinear' + load_temporal_embed = load_temporal_embed.unsqueeze(0) + new_temporal_embed = F.interpolate(load_temporal_embed, + (curr_num_frames, embed_dim), mode=mode).squeeze(0) + else: + raise NotImplementedError + new_state_dict['visual.temporal_embed'] = new_temporal_embed + # allow loading with smaller spatial patches. 
assumes custom border crop, to append the + # border patches to the input sequence + if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys: + load_pos_embed = new_state_dict['visual.pos_embed'] + load_num_patches = load_pos_embed.shape[1] + curr_pos_embed = current_model_state_dict['visual.pos_embed'] + if load_num_patches != curr_pos_embed.shape[1]: + raise NotImplementedError( + 'Loading models with different spatial resolution / patch number not yet implemented, sorry.') + + return new_state_dict + + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +# util functions to convert CLIP-style model keys to TimeSformer-style +def remap_keys(clip_state_dict, transformer_layers=12): + remapped_state_dict = OrderedDict() + key_mapping = { + "class_embedding": "cls_token", + "positional_embedding": "pos_embed", + "conv1.weight": "patch_embed.proj.weight", + "ln_pre.weight": "ln_pre.weight", + "ln_pre.bias": "ln_pre.bias", + "ln_post.weight": "norm.weight", + "ln_post.bias": "norm.bias", + } + for layer in range(transformer_layers): + key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight" + key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias" + key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight" + key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias" + key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight" + key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias" + key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight" + key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias" + key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight" + key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias" + key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight" + key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias" + + for key in clip_state_dict: + if key == 'proj': + continue # due to possible dim mismatch, we load this later + if key == "class_embedding": + clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0) + if key == "positional_embedding": + clip_state_dict[key] = clip_state_dict[key].unsqueeze(0) + remapped_state_dict[key_mapping[key]] = clip_state_dict[key] + + return remapped_state_dict diff --git a/lavila/utils/distributed.py b/lavila/utils/distributed.py new file mode 100644 index 0000000..bcbf22d --- /dev/null +++ b/lavila/utils/distributed.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
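The inflate_positional_embeds helper in lavila/models/utils.py above reduces, in the bilinear case, to a single F.interpolate call over the frame axis of visual.temporal_embed (stored as [1, num_frames, embed_dim]). A minimal sketch of just that step, not part of the patch; the 8-frame source and 4-frame target sizes are made up for illustration:

import torch
import torch.nn.functional as F

# Hypothetical checkpoint embedding: [1, load_num_frames, embed_dim]
load_temporal_embed = torch.randn(1, 8, 768)
curr_num_frames, embed_dim = 4, 768

# Treat the embedding as a 1x1-channel "image" so F.interpolate can resize the
# (frames, embed_dim) plane; only the frame axis actually changes size here.
new_temporal_embed = F.interpolate(
    load_temporal_embed.unsqueeze(0),   # [1, 1, 8, 768]
    (curr_num_frames, embed_dim),       # -> [1, 1, 4, 768]
    mode='bilinear',
).squeeze(0)                            # [1, 4, 768]

print(new_temporal_embed.shape)         # torch.Size([1, 4, 768])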
+ +import os +import shutil +import torch +import torch.distributed as dist + + +def get_model(model): + if isinstance(model, torch.nn.DataParallel) \ + or isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + else: + return model + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + else: + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(state, is_best, output_dir, is_epoch=True): + if is_main_process(): + ckpt_path = f'{output_dir}/checkpoint.pt' + best_path = f'{output_dir}/checkpoint_best.pt' + if is_best: + torch.save(state, best_path) + if is_epoch: + if isinstance(state['epoch'], int): + ckpt2_path = '{}/checkpoint_{:04d}.pt'.format(output_dir, state['epoch']) + else: + ckpt2_path = '{}/checkpoint_{:.4f}.pt'.format(output_dir, state['epoch']) + torch.save(state, ckpt_path) + shutil.copy(ckpt_path, ckpt2_path) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) diff --git a/lavila/utils/evaluation.py b/lavila/utils/evaluation.py new file mode 100644 index 0000000..d598d2f --- /dev/null +++ b/lavila/utils/evaluation.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
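init_distributed_mode in lavila/utils/distributed.py above reads the RANK/WORLD_SIZE/LOCAL_RANK environment variables that torchrun (or a SLURM launcher) sets, and falls back to single-process mode when they are absent. A minimal usage sketch, not part of the patch, assuming an argparse namespace carrying the same dist_url/dist_backend fields the training scripts below define:

import argparse
from lavila.utils import distributed as dist_utils

# Hypothetical stand-in for the parsed training arguments.
args = argparse.Namespace(dist_url='env://', dist_backend='nccl',
                          rank=0, world_size=1, gpu=0, distributed=False)

dist_utils.init_distributed_mode(args)   # prints "Not using distributed mode" if RANK is unset

if dist_utils.is_main_process():
    # Only rank 0 writes checkpoints via save_on_master(...) in the scripts below.
    print('world size:', dist_utils.get_world_size())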
+ +import numpy as np +import torch + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def get_mean_accuracy(cm): + list_acc = [] + for i in range(len(cm)): + acc = 0 + if cm[i, :].sum() > 0: + acc = cm[i, i] / cm[i, :].sum() + list_acc.append(acc) + + return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm) diff --git a/lavila/utils/evaluation_charades.py b/lavila/utils/evaluation_charades.py new file mode 100644 index 0000000..73614a9 --- /dev/null +++ b/lavila/utils/evaluation_charades.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + + +def compute_map(submission_array, gt_array): + """ Returns mAP, weighted mAP, and AP array """ + m_aps = [] + n_classes = submission_array.shape[1] + for oc_i in range(n_classes): + sorted_idxs = np.argsort(-submission_array[:, oc_i]) + tp = gt_array[:, oc_i][sorted_idxs] == 1 + fp = np.invert(tp) + n_pos = tp.sum() + if n_pos < 0.1: + m_aps.append(float('nan')) + continue + fp.sum() + f_pcs = np.cumsum(fp) + t_pcs = np.cumsum(tp) + prec = t_pcs / (f_pcs+t_pcs).astype(float) + avg_prec = 0 + for i in range(submission_array.shape[0]): + if tp[i]: + avg_prec += prec[i] + m_aps.append(avg_prec / n_pos.astype(float)) + m_aps = np.array(m_aps) + m_ap = np.mean(m_aps) + w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float)) + return m_ap, w_ap, m_aps + + +def charades_map(submission_array, gt_array): + """ + Approximate version of the charades evaluation function + For precise numbers, use the submission file with the official matlab script + """ + fix = submission_array.copy() + empty = np.sum(gt_array, axis=1) == 0 + fix[empty, :] = np.NINF + return compute_map(fix, gt_array) + + +def create_submission(video_list, predictions, out_file): + assert len(video_list) == predictions.shape[0] + with open(out_file, 'w') as f: + for i, video_id in enumerate(video_list): + pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist())) + f.write('{} {}\n\n'.format(video_id, pred_str)) diff --git a/lavila/utils/evaluation_egomcq.py b/lavila/utils/evaluation_egomcq.py new file mode 100644 index 0000000..98e4254 --- /dev/null +++ b/lavila/utils/evaluation_egomcq.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
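As a quick sanity check of the accuracy helper in lavila/utils/evaluation.py above: with two samples where the first is right at top-1 and the second only at top-2, the returned percentages are 50 and 100. Toy logits, not real model output:

import torch
from lavila.utils.evaluation import accuracy

logits = torch.tensor([[0.1, 0.9, 0.0],    # predicts class 1 (correct)
                       [0.5, 0.2, 0.3]])   # top-1 is class 0 (wrong), class 2 only at top-2
target = torch.tensor([1, 2])

top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0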
+ +import torch + + +def egomcq_accuracy_metrics(preds, labels, types): + metrics = {} + type_list = torch.unique(types) + group_list = ["Intra-video", "Inter-video"] + for type_i, group_i in zip(type_list, group_list): + correct = 0 + total = 0 + for pred, label, type in zip(preds, labels, types): + if type == type_i: + pred_ = torch.argmax(pred) + if pred_.item() == label.item(): + correct += 1 + total += 1 + accuracy = correct/total + metrics[group_i] = accuracy * 100 + return metrics diff --git a/lavila/utils/evaluation_ek100cls.py b/lavila/utils/evaluation_ek100cls.py new file mode 100644 index 0000000..6b83d46 --- /dev/null +++ b/lavila/utils/evaluation_ek100cls.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from https://github.com/fpv-iplab/rulstm/blob/master/RULSTM/utils.py +# Modified by Yue Zhao + +import numpy as np + + +def get_marginal_indexes(actions, mode): + """For each verb/noun retrieve the list of actions containing that verb/name + Input: + mode: "verb" or "noun" + Output: + a list of numpy array of indexes. If verb/noun 3 is contained in actions 2,8,19, + then output[3] will be np.array([2,8,19]) + """ + vi = [] + for v in range(actions[mode].max()+1): + vals = actions[actions[mode] == v].index.values + if len(vals) > 0: + vi.append(vals) + else: + vi.append(np.array([0])) + return vi + + +def marginalize(probs, indexes): + mprobs = [] + for ilist in indexes: + mprobs.append(probs[:, ilist].sum(1)) + return np.array(mprobs).T diff --git a/lavila/utils/evaluation_ek100mir.py b/lavila/utils/evaluation_ek100mir.py new file mode 100644 index 0000000..c54587b --- /dev/null +++ b/lavila/utils/evaluation_ek100mir.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Part of the code is from +# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/NDCG.py` +# and +# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/mAP.py` +# Modified by Yue Zhao + +import numpy as np + + +def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts): + """ + Calculates the Discounted Cumulative Gain (DCG) between two modalities for + the first modality. + DCG = \sum_{i=1}^k \frac{rel_i}{log_2(i + 1)} + i.e. the sum of the k relevant retrievals which is calculated as the scaled + relevancy for the ith item. The scale is designed such that early + retrievals are more important than later retrievals. + Params: + - similarity_matrix: matrix of size n1 x n2 where n1 is the number of + items in the first modality and n2 is the number of items in the + second modality. The [ith,jth] element is the predicted similarity + between the ith item from the first modality and the jth item from + the second modality. + - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix + above). The [ith, jth] element is the semantic relevancy between the + ith item from the first modality and the jth item from the second + modality. + - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which + includes information on which items to use to calculate the DCG for + (see calculate_k_counts for more info on this matrix). 
+ Returns: + - The DCG for each item in the first modality, a n1 length vector. + """ + x_sz, y_sz = similarity_matrix.shape + ranks = np.argsort(similarity_matrix)[:, ::-1] + # Create vector of size (n,) where n is the length of the last dimension in + # similarity matrix + # This vector is of the form log(i+1) + logs = np.log2(np.arange(y_sz) + 2) + # Convert logs into the divisor for the DCG calculation, of size similarity + # matrix + divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0) + + # mask out the sorted relevancy matrix to only use the first k relevant + # retrievals for each item. + columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1) + numerators = relevancy_matrix[columns, ranks] * k_counts + # Calculate the final DCG score (note that this isn't expected to sum to 1) + return np.sum(numerators / divisors, axis=1) + + +def calculate_k_counts(relevancy_matrix): + """ + Works out the maximum number of allowed retrievals when working out the + Discounted Cumulative Gain. For each query the DCG only uses the first k + items retrieved which constitute the k relevant items for that query + (otherwise the nDCG scores can be deceptively high for bad rankings). + Params: + - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of + items in the first modality and n2 is the number of items in the + second modality. The [ith, jth] element is the semantic relevancy + between the ith item from the first modality and the jth item from + the second modality. + Returns: + - Matrix of size n1 x n2 (see relevancy matrix for more info). This is + created as a mask such that if the [ith, jth] element is 1 it + represents a valid item to use for the calculation of DCG for the + ith item after sorting. For example, if relevancy matrix of: + [[1, 0.5, 0], + [0, 0 , 1]] + is given, then the k_counts matrix will be: + [[1, 1, 0], + [1, 0, 0]] + i.e. the first row has 2 non-zero items, so the first two retrieved + items should be used in the calculation. In the second row there is + only 1 relevant item, therefore only the first retrieved item should + be used for the DCG calculation. + """ + return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int) + + +def calculate_IDCG(relevancy_matrix, k_counts): + """ + Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value + of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the + items in the second modality were retrieved in order of their descending + relevancy. + Params: + - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of + items in the first modality and n2 is the number of items in the + second modality. The [ith, jth] element is the semantic relevancy + between the ith item from the first modality and the jth item from + the second modality. + - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which + includes information on which items to use to calculate the DCG for + (see calculate_k_counts for more info on this matrix). + """ + return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts) + + +def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'): + """ + Calculates the normalised Discounted Cumulative Gain (nDCG) between two + modalities for the first modality using the Discounted Cumulative Gain + (DCG) and the Ideal Discounted Cumulative Gain (IDCG). 
+ nDCG = \frac{DCG}{IDCG} + Params: + - similarity_matrix: matrix of size n1 x n2 where n1 is the number of + items in the first modality and n2 is the number of items in the second + modality. The [ith,jth] element is the predicted similarity between + the ith item from the first modality and the jth item from the second + modality. + - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix + above). The [ith, jth] element is the semantic relevancy between the + ith item from the first modality and the jth item from the second + modality. + - k_counts: optional parameter: matrix of size n1 x n2 (see + similarity_matrix above) which includes information on which items to + use to calculate the DCG for (see calculate_k_counts for more info on + this matrix). This will be calculated using calculate_IDCG if not + present, but should be pre-processed for efficiency. + - IDCG: Optional parameter which includes the pre-processed Ideal + Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see + similarity_matrix above) which contains the IDCG value for each item + from the first modality. This will be calculated using calculate_IDCG + if not present, but should be pre-processed for efficiency. + - reduction: what to use to reduce the different nDCG scores. By + default this applies np.mean across all different queries. + Returns: + - The nDCG values for the first modality. + """ + if k_counts is None: + k_counts = calculate_k_counts(relevancy_matrix) + DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts) + if IDCG is None: + IDCG = calculate_IDCG(relevancy_matrix, k_counts) + if reduction == 'mean': + return np.mean(DCG / IDCG) + elif reduction is None: + return DCG / IDCG + + +def calculate_mAP(sim_mat, relevancy_matrix): + """ + Computes the mean average precision according to the following formula of + average precision: + \frac{\sum_{k=1}^n p(k) x rel(k)}{num_rel_docs} + where p(k) is the precision at k, rel(k) is an indicator function + determining whether the kth returned item is relevant or not and + num_rel_docs is the number of relevant items to find within the search. + The mean average precision is the mean of the average precision for each + query item (i.e row in the matrix) + This function takes in two parameters: + - sim_mat: a NxM matrix which represents the similarity between two + modalities (with modality 1 being of size N and modality 2 of size M). + - relevancy_matrix: an NxM matrix which represents the relevancy between two + modalities of items (with modality 1 being of size N and modality 2 of + size M). 
+ """ + # Find the order of the items in modality 2 according to modality 1 + ranked_order = (-sim_mat).argsort() + ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order] + # re-order the relevancy matrix to accommodate the proposals + ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order] + + # find the number of relevant items found at each k + cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1) + # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above) + cumulative_rel_mat[ranked_rel_mat != 1] = 0 + # find the divisor for p(k) + divisor = np.arange(ranked_rel_mat.shape[1]) + 1 + + # find the number of relevant docs per query item + number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1) + + # find the average precision per query, within np.sum finds p(k) * rel(k) + avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs + mAP = np.mean(avg_precision) + return mAP + + +def get_mAP(similarity_matrix, rel_matrix): + vis_map = calculate_mAP(similarity_matrix, rel_matrix) + txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T) + return vis_map, txt_map, (vis_map + txt_map) / 2 + + +def get_nDCG(similarity_matrix, rel_matrix): + vis_k_counts = calculate_k_counts(rel_matrix) + txt_k_counts = calculate_k_counts(rel_matrix.T) + vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts) + txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts) + vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG) + txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG) + return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2 diff --git a/lavila/utils/meter.py b/lavila/utils/meter.py new file mode 100644 index 0000000..08e7086 --- /dev/null +++ b/lavila/utils/meter.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
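The k-count masking in lavila/utils/evaluation_ek100mir.py above is easiest to see on the toy relevancy matrix from its own docstring; with a similarity matrix that already ranks items in relevancy order, both directions of get_nDCG and get_mAP come out at 1.0. A small sketch, not part of the patch, with values copied from that docstring example:

import numpy as np
from lavila.utils.evaluation_ek100mir import calculate_k_counts, get_nDCG, get_mAP

rel = np.array([[1.0, 0.5, 0.0],
                [0.0, 0.0, 1.0]])
sim = rel.copy()                              # a "perfect" retrieval

print(calculate_k_counts(rel))                # [[1 1 0]
                                              #  [1 0 0]]
print(get_nDCG(sim, rel))                     # video->text, text->video, average: all 1.0
print(get_mAP(sim, (rel > 0).astype(float)))  # binary relevancy: all three mAP values are 1.0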
+ +import torch +import torch.distributed as dist +from lavila.utils import distributed as dist_utils + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def synchronize(self): + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.sum = int(t[0]) + self.count = t[1] + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def synchronize(self): + for meter in self.meters: + meter.synchronize() + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' diff --git a/lavila/utils/preprocess.py b/lavila/utils/preprocess.py new file mode 100644 index 0000000..67cbb6f --- /dev/null +++ b/lavila/utils/preprocess.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
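AverageMeter in lavila/utils/meter.py above keeps a running, batch-size-weighted average (synchronizable across ranks), and ProgressMeter pretty-prints a list of meters. A toy single-process sketch, not part of the patch:

from lavila.utils.meter import AverageMeter, ProgressMeter

losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
progress = ProgressMeter(num_batches=100, meters=[losses, top1], prefix='Epoch: [0]')

for step, (loss, acc, batch_size) in enumerate([(0.9, 25.0, 32), (0.7, 50.0, 32)]):
    losses.update(loss, n=batch_size)   # running average weighted by batch size
    top1.update(acc, n=batch_size)
    progress.display(step)  # e.g. "Epoch: [0][  0/100]	Loss 9.0000e-01 (9.0000e-01)	Acc@1  25.00 ( 25.00)"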
+ +import csv + +from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer + + +def generate_label_map(dataset): + if dataset == 'ek100_cls': + print("Preprocess ek100 action label space") + vn_list = [] + mapping_vn2narration = {} + for f in [ + 'datasets/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv', + 'datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv', + ]: + csv_reader = csv.reader(open(f)) + _ = next(csv_reader) # skip the header + for row in csv_reader: + vn = '{}:{}'.format(int(row[10]), int(row[12])) + narration = row[8] + if vn not in vn_list: + vn_list.append(vn) + if vn not in mapping_vn2narration: + mapping_vn2narration[vn] = [narration] + else: + mapping_vn2narration[vn].append(narration) + # mapping_vn2narration[vn] = [narration] + vn_list = sorted(vn_list) + print('# of action= {}'.format(len(vn_list))) + mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)} + labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))] + print(labels[:5]) + elif dataset == 'charades_ego': + print("=> preprocessing charades_ego action label space") + vn_list = [] + labels = [] + with open('datasets/CharadesEgo/CharadesEgo/Charades_v1_classes.txt') as f: + csv_reader = csv.reader(f) + for row in csv_reader: + vn = row[0][:4] + vn_list.append(vn) + narration = row[0][5:] + labels.append(narration) + mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)} + print(labels[:5]) + elif dataset == 'egtea': + print("=> preprocessing egtea action label space") + labels = [] + with open('datasets/EGTEA/action_idx.txt') as f: + for row in f: + row = row.strip() + narration = ' '.join(row.split(' ')[:-1]) + labels.append(narration.replace('_', ' ').lower()) + # labels.append(narration) + mapping_vn2act = {label: i for i, label in enumerate(labels)} + print(len(labels), labels[:5]) + else: + raise NotImplementedError + return labels, mapping_vn2act + + +def generate_tokenizer(model): + if model.endswith('DISTILBERT_BASE'): + tokenizer = MyDistilBertTokenizer('distilbert-base-uncased') + elif model.endswith('BERT_BASE'): + tokenizer = MyBertTokenizer('bert-base-uncased') + elif model.endswith('BERT_LARGE'): + tokenizer = MyBertTokenizer('bert-large-uncased') + elif model.endswith('GPT2'): + tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True) + elif model.endswith('GPT2_MEDIUM'): + tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True) + elif model.endswith('GPT2_LARGE'): + tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True) + elif model.endswith('GPT2_XL'): + tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True) + else: + print("Using SimpleTokenizer because of model '{}'. " + "Please check if this is what you want".format(model)) + tokenizer = SimpleTokenizer() + return tokenizer diff --git a/lavila/utils/random.py b/lavila/utils/random.py new file mode 100644 index 0000000..1a74644 --- /dev/null +++ b/lavila/utils/random.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
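generate_tokenizer in lavila/utils/preprocess.py above dispatches purely on the suffix of the model name, so any name ending in 'GPT2' gets a GPT-2 tokenizer with a BOS token prepended. A small sketch, not part of the patch; the model string is hypothetical and the Hugging Face 'gpt2' vocabulary is downloaded on first use:

from lavila.utils.preprocess import generate_tokenizer

tokenizer = generate_tokenizer('SOME_DUALENCODER_GPT2')     # hypothetical model name
ids = tokenizer('#C C opens the fridge', context_length=77)
print(ids.shape)      # torch.Size([77]); a single caption returns a 1-D tensor
print(ids[0].item())  # 50256, the GPT-2 BOS id, since add_bos=True for GPT-2 models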
+ +import random +import numpy as np +import torch + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) diff --git a/lavila/utils/scheduler.py b/lavila/utils/scheduler.py new file mode 100644 index 0000000..f7435f0 --- /dev/null +++ b/lavila/utils/scheduler.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule diff --git a/main_finetune_classification.py b/main_finetune_classification.py new file mode 100644 index 0000000..264d7ad --- /dev/null +++ b/main_finetune_classification.py @@ -0,0 +1,716 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from collections import OrderedDict +import json +import math +import numpy as np +import os +import pandas as pd +import sys +import time + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.cuda.amp as amp +from torch.distributed.optim import ZeroRedundancyOptimizer +import torch.nn.parallel +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video +from sklearn.metrics import confusion_matrix +import wandb + +from lavila.data import datasets +from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop +from lavila.models import models +from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer) +from lavila.models.utils import inflate_positional_embeds +from lavila.utils import distributed as dist_utils +from lavila.utils.evaluation import accuracy, get_mean_accuracy +from lavila.utils.meter import AverageMeter, ProgressMeter +from lavila.utils.preprocess import generate_label_map +from lavila.utils.random import random_seed +from lavila.utils.scheduler import cosine_scheduler +from lavila.utils.evaluation_ek100cls import get_marginal_indexes, marginalize + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='lavila finetune and evaluation', add_help=False) + # Data + parser.add_argument('--dataset', default='ek100_cls', type=str, + choices=['ek100_cls', 'egtea']) + parser.add_argument('--root', + default='datasets/EK100/video_ht256px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata-train', + default='datasets/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv', + type=str, help='path to metadata file (train set)') + parser.add_argument('--metadata-val', + default='datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv', + type=str, help='path to metadata file (val set)') + parser.add_argument('--relevancy-path', + 
default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl', + type=str, help='path to relevancy matrix (val set)') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms for val') + parser.add_argument('--num-clips', default=1, type=int, help='number of clips for val') + parser.add_argument('--clip-length', default=16, type=int, help='clip length') + parser.add_argument('--clip-stride', default=2, type=int, help='clip stride') + parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling') + # Model + parser.add_argument('--pretrain-model', default='', type=str, help='path to pretrain model') + parser.add_argument('--resume', default='', type=str, help='path to resume from') + parser.add_argument('--find-unused-parameters', action='store_true', + help='do this during DDP (useful for models with tied weights)') + parser.add_argument('--drop-path-rate', default=0.1, type=float, help='drop path ratio') + parser.add_argument('--dropout-ratio', default=0.5, type=float, help='dropout ratio for the last linear layer') + parser.add_argument('--num-classes', default=3806, nargs='+', type=int, help='number of classes for the last linear layer') + parser.add_argument('--use-vn-classifier', action='store_true') + parser.add_argument('--use-half', action='store_true', help='use half precision at inference') + # Training + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--warmup-epochs', default=1, type=int) + parser.add_argument('--start-epoch', default=0, type=int) + parser.add_argument('--batch-size', default=16, type=int, + help='number of samples per-device/per-gpu') + parser.add_argument('--use-sgd', action='store_true') + parser.add_argument('--freeze-temperature', action='store_true', help='freeze temperature if set to True') + parser.add_argument('--lr', default=3e-3, type=float) + parser.add_argument('--fix-lr', action='store_true', help='disable cosine lr decay if set True') + parser.add_argument('--lr-start', default=1e-6, type=float, + help='initial warmup lr') + parser.add_argument('--lr-end', default=1e-5, type=float, + help='minimum final lr') + parser.add_argument('--lr-multiplier-on-backbone', default=0.1, type=float, help='lr multiplier for the backbone') + parser.add_argument('--clip-grad-type', default='norm', choices=['norm', 'value']) + parser.add_argument('--clip-grad-value', default=None, type=float, help='') + parser.add_argument('--update-freq', default=1, type=int, + help='optimizer update frequency (i.e. 
gradient accumulation steps)') + parser.add_argument('--wd', default=0.01, type=float) + parser.add_argument('--betas', default=(0.9, 0.999), nargs=2, type=float) + parser.add_argument('--eps', default=1e-8, type=float) + parser.add_argument('--label-smoothing', default=0.1, type=float, help='label smoothing') + parser.add_argument('--eval-freq', default=5, type=int) + parser.add_argument('--save-freq', default=5, type=int) + parser.add_argument('--disable-amp', action='store_true', + help='disable mixed-precision training (requires more memory and compute)') + parser.add_argument('--use-zero', action='store_true', + help='use ZeroRedundancyOptimizer to save memory') + parser.add_argument('--use-checkpoint', action='store_true', + help='use gradient checkpointing during training for significantly less GPU usage') + # System + parser.add_argument('--print-freq', default=100, type=int, help='print frequency') + parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--evaluate', action='store_true', help='eval only') + parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--dist-url', default='env://', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='nccl', type=str) + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + parser.add_argument('--wandb', action='store_true', help='Enable WandB logging') + return parser + + +def main(args): + dist_utils.init_distributed_mode(args) + + global best_acc1 + random_seed(args.seed, dist_utils.get_rank()) + + if args.pretrain_model: + ckpt_path = args.pretrain_model + else: + raise Exception('no checkpoint found') + ckpt = torch.load(ckpt_path, map_location='cpu') + + if args.use_vn_classifier: + assert args.dataset == 'ek100_cls' and len(args.num_classes) == 3 + + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + old_args = ckpt['args'] + print("=> creating model: {}".format(old_args.model)) + model = getattr(models, old_args.model)( + pretrained=old_args.load_visual_pretrained, + pretrained2d=old_args.load_visual_pretrained is not None, + text_use_cls_token=old_args.use_cls_token, + project_embed_dim=old_args.project_embed_dim, + timesformer_gated_xattn=False, + timesformer_freeze_space=False, + num_frames=args.clip_length, + drop_path_rate=args.drop_path_rate, + ) + if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model: + # inflate weight + print('=> inflating PE in models due to different frame numbers') + state_dict = inflate_positional_embeds( + model.state_dict(), state_dict, + num_frames=args.clip_length, + load_temporal_fix='bilinear', + ) + model.load_state_dict(state_dict, strict=True) + print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch'])) + + if args.use_vn_classifier: + model = models.VideoClassifierMultiHead( + model.visual, + dropout=args.dropout_ratio, + num_classes_list=args.num_classes + ) + else: + assert len(args.num_classes) == 1 + model = models.VideoClassifier( + model.visual, + dropout=args.dropout_ratio, + num_classes=args.num_classes[0] + ) + + 
model.cuda(args.gpu) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], bucket_cap_mb=200, + find_unused_parameters=args.find_unused_parameters + ) + + p_wd, p_non_wd = [], [] + p_head_wd, p_head_non_wd = [], [] + for n, p in model.named_parameters(): + if 'fc_cls' in n: + if 'bias' in n: + p_head_non_wd.append(p) + else: + p_head_wd.append(p) + elif not p.requires_grad: + continue # frozen weights + elif p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: + p_non_wd.append(p) + else: + p_wd.append(p) + + optim_params = [ + {"params": p_wd, "weight_decay": args.wd, "lr": args.lr * args.lr_multiplier_on_backbone}, + {"params": p_non_wd, "weight_decay": 0, "lr": args.lr * args.lr_multiplier_on_backbone}, + {"params": p_head_wd, "weight_decay": args.wd}, + {"params": p_head_non_wd, "weight_decay": 0} + ] + + if args.use_zero: + optimizer = ZeroRedundancyOptimizer( + optim_params, optimizer_class=torch.optim.SGD if args.use_sgd else torch.optim.AdamW, + lr=args.lr, betas=args.betas, eps=args.eps, weight_decay=args.wd + ) + else: + if args.use_sgd: + optimizer = torch.optim.SGD(optim_params, lr=args.lr, momentum=args.betas[0], weight_decay=args.wd) + else: + optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas, + eps=args.eps, weight_decay=args.wd) + scaler = amp.GradScaler(enabled=not args.disable_amp) + # optionally resume from a checkpoint (takes precedence over autoresume) + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + args.resume = '' + if args.resume: + if os.path.isfile(args.resume): + print("=> loading resume checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location='cpu') + epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0 + args.start_epoch = epoch + if not args.distributed: + state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + result = model.load_state_dict(state_dict, strict=False) + else: + result = model.load_state_dict(checkpoint['state_dict'], strict=False) + print(result) + optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else () + scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else () + best_acc1 = checkpoint['best_acc1'] + print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})" + .format(args.resume, epoch, best_acc1)) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + else: + # auto-resume from latest checkpoint in output directory + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + print("=> loading latest checkpoint '{}'".format(latest)) + latest_checkpoint = torch.load(latest, map_location='cpu') + args.start_epoch = latest_checkpoint['epoch'] + model.load_state_dict(latest_checkpoint['state_dict']) + optimizer.load_state_dict(latest_checkpoint['optimizer']) + scaler.load_state_dict(latest_checkpoint['scaler']) + best_acc1 = latest_checkpoint['best_acc1'] + print("=> loaded latest checkpoint '{}' (epoch {})" + .format(latest, latest_checkpoint['epoch'])) + + cudnn.benchmark = True + + # Data loading code + print("=> creating dataset") + if old_args.model.endswith('DISTILBERT_BASE'): + tokenizer = MyDistilBertTokenizer('distilbert-base-uncased') + elif old_args.model.endswith('BERT_BASE'): + tokenizer = MyBertTokenizer('bert-base-uncased') + elif old_args.model.endswith('BERT_LARGE'): + tokenizer = 
MyBertTokenizer('bert-large-uncased') + elif old_args.model.endswith('GPT2'): + tokenizer = MyGPT2Tokenizer('gpt2') + elif old_args.model.endswith('GPT2_MEDIUM'): + tokenizer = MyGPT2Tokenizer('gpt2-medium') + elif old_args.model.endswith('GPT2_LARGE'): + tokenizer = MyGPT2Tokenizer('gpt2-large') + elif old_args.model.endswith('GPT2_XL'): + tokenizer = MyGPT2Tokenizer('gpt2-xl') + else: + print("Using SimpleTokenizer because of model '{}'. " + "Please check if this is what you want".format(old_args.model)) + tokenizer = SimpleTokenizer() + + criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).cuda(args.gpu) + + crop_size = 224 if '336PX' not in old_args.model else 336 + transforms_list = [ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)), + transforms.RandomHorizontalFlip(p=0.5), + ] + if 'OPENAI' in old_args.model: + transforms_list.append(transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])) + else: + transforms_list.append(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])) + train_transform = transforms.Compose(transforms_list) + + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if 'OPENAI' not in old_args.model else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length), + SpatialCrop(crop_size=crop_size, num_crops=args.num_crops), + ]) + + # build dataset + _, mapping_vn2act = generate_label_map(args.dataset) + if args.dataset == 'ek100_cls': + args.mapping_act2v = {i: int(vn.split(':')[0]) for (vn, i) in mapping_vn2act.items()} + args.mapping_act2n = {i: int(vn.split(':')[1]) for (vn, i) in mapping_vn2act.items()} + args.actions = pd.DataFrame.from_dict({'verb': args.mapping_act2v.values(), 'noun': args.mapping_act2n.values()}) + num_clips_at_val = args.num_clips + args.num_clips = 1 + train_dataset = datasets.get_downstream_dataset( + train_transform, tokenizer, args, subset='train', label_mapping=mapping_vn2act, + ) + args.num_clips = num_clips_at_val + val_dataset = datasets.get_downstream_dataset( + val_transform, tokenizer, args, subset='val', label_mapping=mapping_vn2act, + ) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + val_sampler = torch.utils.data.SequentialSampler(val_dataset) # disable distributed + else: + train_sampler = None + val_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True + ) + print('len(train_loader) = {}'.format(len(train_loader))) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False + ) + print('len(val_loader) = {}'.format(len(val_loader))) + + if args.evaluate: + if args.use_vn_classifier: + val_stats = validate_multihead(val_loader, model, args) + else: + val_stats = validate(val_loader, model, args) + return + + if args.fix_lr: + lr_schedule = 
None + else: + lr_schedule = cosine_scheduler( + args.lr, args.lr_end, args.epochs, len(train_loader) // args.update_freq, + warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start, + ) + + if dist_utils.is_main_process() and args.wandb: + wandb_id = os.path.split(args.output_dir)[-1] + wandb.init(project='LaViLa', id=wandb_id, config=args, resume='allow') + + print(args) + + best_metric = 0. + print("=> beginning training") + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + + train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args) + + is_epoch = ((epoch + 1) % args.save_freq) == 0 + + print('=> saving checkpoint') + dist_utils.save_on_master({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + 'best_acc1': 0, + 'args': args, + }, False, args.output_dir, is_epoch=is_epoch) + + if ((epoch + 1) % args.eval_freq) == 0: + if args.use_vn_classifier: + val_stats = validate_multihead(val_loader, model, args) + else: + val_stats = validate(val_loader, model, args) + if val_stats['acc1'] > best_metric: + is_best = True + best_metric = val_stats['acc1'] + else: + is_best = False + + print('=> saving checkpoint') + dist_utils.save_on_master({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + 'best_acc1': best_metric, + 'args': args, + }, is_best, args.output_dir, is_epoch=is_epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in val_stats.items()}, + 'epoch': epoch} + + if dist_utils.is_main_process(): + if args.wandb: + wandb.log(log_stats) + with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f: + f.write(json.dumps(log_stats) + '\n') + + +def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + mem = AverageMeter('Mem (GB)', ':6.1f') + iters_per_epoch = len(train_loader) // args.update_freq + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + top1_noun = AverageMeter('Noun Acc@1', ':6.2f') + top1_verb = AverageMeter('Verb Acc@1', ':6.2f') + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, losses, top1, top5, top1_noun, top1_verb], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for data_iter, (images, target) in enumerate(train_loader): + optim_iter = data_iter // args.update_freq + + # measure data loading time + data_time.update(time.time() - end) + + # update weight decay and learning rate according to their schedule + it = iters_per_epoch * epoch + optim_iter # global training iteration + for k, param_group in enumerate(optimizer.param_groups): + if lr_schedule is not None: + param_group['lr'] = lr_schedule[it] * args.lr_multiplier_on_backbone + else: + param_group['lr'] = lr_schedule[it] + + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + with amp.autocast(enabled=not args.disable_amp): + output = model(images, use_checkpoint=args.use_checkpoint) + if isinstance(output, list): + assert len(output) == 3 + target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) 
+ loss = criterion(output[0], target_to_verb) + target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + loss += criterion(output[1], target_to_noun) + loss += criterion(output[2], target) + else: + loss = criterion(output, target) + loss /= args.update_freq + + if not math.isfinite(loss.item()): + print("Loss is {}, stopping training".format(loss.item())) + sys.exit(1) + + scaler.scale(loss).backward() + + if (data_iter + 1) % args.update_freq != 0: + continue + + if args.clip_grad_value is not None: + scaler.unscale_(optimizer) + if args.clip_grad_type == 'norm': + torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_grad_value, norm_type=2. + ) + elif args.clip_grad_type == 'value': + torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_value) + else: + assert False, f"Unknown clip mode ({args.clip_grad_type})." + # compute gradient and do SGD step + scaler.step(optimizer) + scaler.update() + model.zero_grad(set_to_none=True) + + if isinstance(output, list): + target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + acc1_verb, _ = accuracy(output[0], target_to_verb, topk=(1, 5)) + top1_verb.update(acc1_verb.item(), images.size(0)) + target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + acc1_noun, _ = accuracy(output[1], target_to_noun, topk=(1, 5)) + top1_noun.update(acc1_noun.item(), images.size(0)) + acc1, acc5 = accuracy(output[2], target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1.item(), images.size(0)) + top5.update(acc5.item(), images.size(0)) + else: + output = torch.softmax(output, dim=1) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1.item(), images.size(0)) + top5.update(acc5.item(), images.size(0)) + if args.dataset == 'ek100_cls': + vi = get_marginal_indexes(args.actions, 'verb') + ni = get_marginal_indexes(args.actions, 'noun') + verb_scores = torch.tensor(marginalize(output.detach().cpu().numpy(), vi)).cuda(args.gpu, non_blocking=True) + noun_scores = torch.tensor(marginalize(output.detach().cpu().numpy(), ni)).cuda(args.gpu, non_blocking=True) + target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + acc1_verb, _ = accuracy(verb_scores, target_to_verb, topk=(1, 5)) + acc1_noun, _ = accuracy(noun_scores, target_to_noun, topk=(1, 5)) + top1_verb.update(acc1_verb.item(), images.size(0)) + top1_noun.update(acc1_noun.item(), images.size(0)) + else: + top1_verb.update(0., images.size(0)) + top1_noun.update(0., images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if optim_iter % args.print_freq == 0: + if dist_utils.is_main_process() and args.wandb: + wandb.log({ + 'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg, + 'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg, + }) + progress.display(optim_iter) + progress.synchronize() + return { + 'acc1': top1.avg, 'acc5': top5.avg, 'loss': losses.avg, + 'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg, + 'lr': optimizer.param_groups[0]['lr'], + } + + +def validate(val_loader, model, args): + batch_time = AverageMeter('Time', ':6.2f') + 
data_time = AverageMeter('Data', ':6.2f') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, top1, top5], + prefix='Test: ' + ) + + # switch to eval mode + model.eval() + if args.use_half: + model.half() + + all_outputs = [] + all_targets = [] + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + # measure data loading time + data_time.update(time.time() - end) + if isinstance(images, list): + logit_allcrops = [] + for crop in images: + crop = crop.cuda(args.gpu, non_blocking=True) + if args.use_half: + crop = crop.half() + logit = model(crop, use_checkpoint=args.use_checkpoint) + logit_allcrops.append(logit) + logit_allcrops = torch.stack(logit_allcrops, 0) + logit = logit_allcrops.mean(0) + logit = torch.softmax(logit, dim=1) + target = target.cuda(args.gpu, non_blocking=True) + + acc1, acc5 = accuracy(logit, target, topk=(1, 5)) + top1.update(acc1.item(), target.size(0)) + top5.update(acc5.item(), target.size(0)) + else: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + if args.use_half: + images = images.half() + + logit = model(images, use_checkpoint=args.use_checkpoint) + logit = torch.softmax(logit, dim=1) + + acc1, acc5 = accuracy(logit, target, topk=(1, 5)) + top1.update(acc1.item(), images.size(0)) + top5.update(acc5.item(), images.size(0)) + + all_outputs.append(logit) + all_targets.append(target) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + progress.synchronize() + if args.dataset == 'ek100_cls': + print('EK100 * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + else: + print('EGTEA * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + all_outputs = torch.cat(all_outputs).cpu().numpy() + all_targets = torch.cat(all_targets).cpu().numpy() + cm = confusion_matrix(all_targets, all_outputs.argmax(axis=1)) + mean_acc, acc = get_mean_accuracy(cm) + print('Mean Acc. = {:.3f}, Top-1 Acc. 
= {:.3f}'.format(mean_acc, acc)) + + if args.dataset == 'ek100_cls': + vi = get_marginal_indexes(args.actions, 'verb') + ni = get_marginal_indexes(args.actions, 'noun') + verb_scores = marginalize(all_outputs, vi) + noun_scores = marginalize(all_outputs, ni) + target_to_verb = np.array([args.mapping_act2v[a] for a in all_targets.tolist()]) + target_to_noun = np.array([args.mapping_act2n[a] for a in all_targets.tolist()]) + cm = confusion_matrix(target_to_verb, verb_scores.argmax(axis=1)) + _, acc = get_mean_accuracy(cm) + print('Verb Acc@1: {:.3f}'.format(acc)) + cm = confusion_matrix(target_to_noun, noun_scores.argmax(axis=1)) + _, acc = get_mean_accuracy(cm) + print('Noun Acc@1: {:.3f}'.format(acc)) + return {'acc1': top1.avg, 'acc5': top5.avg, 'mean_acc': mean_acc} + + +def validate_multihead(val_loader, model, args): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + top1_verb = AverageMeter('Verb Acc@1', ':6.2f') + top1_noun = AverageMeter('Noun Acc@1', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, top1, top5, top1_verb, top1_noun], + prefix='Test: ' + ) + + # switch to eval mode + model.eval() + if args.use_half: + model.half() + + all_verb_outputs = [] + all_noun_outputs = [] + all_action_outputs = [] + all_verb_targets = [] + all_noun_targets = [] + all_action_targets = [] + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + # measure data loading time + data_time.update(time.time() - end) + if isinstance(images, torch.Tensor): + images = [images, ] + logit_verb_allcrops = [] + logit_noun_allcrops = [] + logit_action_allcrops = [] + for crop in images: + crop = crop.cuda(args.gpu, non_blocking=True) + if args.use_half: + crop = crop.half() + logit = model(crop, use_checkpoint=args.use_checkpoint) + logit_verb_allcrops.append(logit[0]) + logit_noun_allcrops.append(logit[1]) + logit_action_allcrops.append(logit[2]) + logit_verb_allcrops = torch.stack(logit_verb_allcrops, 0) + logit_noun_allcrops = torch.stack(logit_noun_allcrops, 0) + logit_action_allcrops = torch.stack(logit_action_allcrops, 0) + logit_verb = logit_verb_allcrops.mean(0) + logit_noun = logit_noun_allcrops.mean(0) + logit_action = logit_action_allcrops.mean(0) + logit_noun = torch.softmax(logit_noun, dim=1) + logit_verb = torch.softmax(logit_verb, dim=1) + logit_action = torch.softmax(logit_action, dim=1) + target = target.cuda(args.gpu, non_blocking=True) + target_to_verb = torch.tensor([args.mapping_act2v[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + target_to_noun = torch.tensor([args.mapping_act2n[a] for a in target.tolist()]).cuda(args.gpu, non_blocking=True) + + acc1, acc5 = accuracy(logit_action, target, topk=(1, 5)) + acc1_verb, _ = accuracy(logit_verb, target_to_verb, topk=(1, 5)) + acc1_noun, _ = accuracy(logit_noun, target_to_noun, topk=(1, 5)) + top1.update(acc1.item(), target.size(0)) + top5.update(acc5.item(), target.size(0)) + top1_verb.update(acc1_verb.item(), target_to_verb.size(0)) + top1_noun.update(acc1_noun.item(), target_to_noun.size(0)) + + all_verb_outputs.append(logit_verb) + all_noun_outputs.append(logit_noun) + all_action_outputs.append(logit_action) + all_verb_targets.append(target_to_verb) + all_noun_targets.append(target_to_noun) + all_action_targets.append(target) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + 
progress.display(i) + progress.synchronize() + print('EK100 * Verb Acc@1 {top1.avg:.3f}'.format(top1=top1_verb)) + print('EK100 * Noun Acc@1 {top1.avg:.3f}'.format(top1=top1_noun)) + print('EK100 * Action Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + return {'acc1': top1.avg, 'acc5': top5.avg, 'acc1_verb': top1_verb.avg, 'acc1_noun': top1_noun.avg} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila finetune and evaluation', parents=[get_args_parser()]) + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + main(args) diff --git a/main_finetune_retrieval.py b/main_finetune_retrieval.py new file mode 100644 index 0000000..e2b332a --- /dev/null +++ b/main_finetune_retrieval.py @@ -0,0 +1,651 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from collections import OrderedDict +import json +import math +import numpy as np +import os +import pandas as pd +import sys +import time + +import torch +import torch.backends.cudnn as cudnn +import torch.cuda.amp as amp +from torch.distributed.optim import ZeroRedundancyOptimizer +import torch.nn.parallel +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video +import wandb + +from lavila.data import datasets +from lavila.data.video_transforms import Permute +from lavila.models import models, loss +from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer) +from lavila.models.utils import inflate_positional_embeds +from lavila.utils import distributed as dist_utils +from lavila.utils.evaluation_charades import charades_map +from lavila.utils.meter import AverageMeter, ProgressMeter +from lavila.utils.preprocess import generate_label_map +from lavila.utils.random import random_seed +from lavila.utils.scheduler import cosine_scheduler +from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG) + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='lavila finetune and evaluation', add_help=False) + # Data + parser.add_argument('--dataset', default='ek100_mir', type=str, + choices=['ek100_mir', 'charades_ego']) + parser.add_argument('--root', + default='datasets/EK100/video_ht256px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata', + default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv', + type=str, help='path to metadata file (train set)') + parser.add_argument('--metadata-val', + default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv', + type=str, help='path to metadata file (val set)') + parser.add_argument('--relevancy-path', + default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl', + type=str, help='path to relevancy matrix (val set)') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--clip-length', default=16, type=int, help='clip length') + parser.add_argument('--clip-stride', default=4, type=int, help='clip stride') + parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling') + # Model + parser.add_argument('--pretrain-model', 
default='', type=str, help='path to pretrain model') + parser.add_argument('--resume', default='', type=str, help='path to resume from') + parser.add_argument('--find-unused-parameters', action='store_true', + help='do this during DDP (useful for models with tied weights)') + parser.add_argument('--drop-path-rate', default=0.1, type=float, help='drop path ratio') + # Training + parser.add_argument('--epochs', default=100, type=int) + parser.add_argument('--warmup-epochs', default=1, type=int) + parser.add_argument('--start-epoch', default=0, type=int) + parser.add_argument('--batch-size', default=16, type=int, + help='number of samples per-device/per-gpu') + parser.add_argument('--freeze-temperature', action='store_true', help='freeze temperature if set to True') + parser.add_argument('--lr', default=3e-5, type=float) + parser.add_argument('--fix-lr', action='store_true', help='disable cosine lr decay if set True') + parser.add_argument('--lr-start', default=1e-6, type=float, + help='initial warmup lr') + parser.add_argument('--lr-end', default=1e-5, type=float, + help='minimum final lr') + parser.add_argument('--clip-grad-type', default='norm', choices=['norm', 'value']) + parser.add_argument('--clip-grad-value', default=None, type=float, help='') + parser.add_argument('--update-freq', default=1, type=int, + help='optimizer update frequency (i.e. gradient accumulation steps)') + parser.add_argument('--wd', default=0.01, type=float) + parser.add_argument('--betas', default=(0.9, 0.999), nargs=2, type=float) + parser.add_argument('--eps', default=1e-8, type=float) + parser.add_argument('--eval-freq', default=5, type=int) + parser.add_argument('--save-freq', default=5, type=int) + parser.add_argument('--disable-amp', action='store_true', + help='disable mixed-precision training (requires more memory and compute)') + parser.add_argument('--use-zero', action='store_true', + help='use ZeroRedundancyOptimizer to save memory') + parser.add_argument('--use-checkpoint', action='store_true', + help='use gradient checkpointing during training for significantly less GPU usage') + # System + parser.add_argument('--print-freq', default=100, type=int, help='print frequency') + parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--evaluate', action='store_true', help='eval only') + parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--dist-url', default='env://', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='nccl', type=str) + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + parser.add_argument('--wandb', action='store_true', help='Enable WandB logging') + return parser + + +def main(args): + dist_utils.init_distributed_mode(args) + + global best_acc1 + random_seed(args.seed, dist_utils.get_rank()) + + if args.pretrain_model: + ckpt_path = args.pretrain_model + else: + raise Exception('no checkpoint found') + ckpt = torch.load(ckpt_path, map_location='cpu') + + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + old_args = ckpt['args'] + print("=> creating model: 
{}".format(old_args.model)) + model = getattr(models, old_args.model)( + pretrained=old_args.load_visual_pretrained, + pretrained2d=old_args.load_visual_pretrained is not None, + text_use_cls_token=old_args.use_cls_token, + project_embed_dim=old_args.project_embed_dim, + timesformer_gated_xattn=False, + timesformer_freeze_space=False, + num_frames=args.clip_length, + drop_path_rate=args.drop_path_rate, + ) + model.logit_scale.requires_grad = False + model.cuda(args.gpu) + if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model: + # inflate weight + print('=> inflating PE in models due to different frame numbers') + state_dict = inflate_positional_embeds( + model.state_dict(), state_dict, + num_frames=args.clip_length, + load_temporal_fix='bilinear', + ) + model.load_state_dict(state_dict, strict=True) + print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch'])) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], bucket_cap_mb=200, + find_unused_parameters=args.find_unused_parameters + ) + + p_wd, p_non_wd = [], [] + for n, p in model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n: + p_non_wd.append(p) + else: + p_wd.append(p) + + optim_params = [{"params": p_wd, "weight_decay": args.wd}, + {"params": p_non_wd, "weight_decay": 0}] + + if args.use_zero: + optimizer = ZeroRedundancyOptimizer( + optim_params, optimizer_class=torch.optim.AdamW, + lr=args.lr, betas=args.betas, eps=args.eps, weight_decay=args.wd + ) + else: + optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas, + eps=args.eps, weight_decay=args.wd) + scaler = amp.GradScaler(enabled=not args.disable_amp) + # optionally resume from a checkpoint (takes precedence over autoresume) + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + args.resume = '' + if args.resume: + if os.path.isfile(args.resume): + print("=> loading resume checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location='cpu') + epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0 + args.start_epoch = epoch + if not args.distributed: + state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + result = model.load_state_dict(state_dict, strict=False) + else: + result = model.load_state_dict(checkpoint['state_dict'], strict=False) + print(result) + optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else () + scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else () + best_acc1 = checkpoint['best_acc1'] + print("=> loaded resume checkpoint '{}' (epoch {})" + .format(args.resume, epoch)) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + else: + # auto-resume from latest checkpoint in output directory + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + print("=> loading latest checkpoint '{}'".format(latest)) + latest_checkpoint = torch.load(latest, map_location='cpu') + args.start_epoch = latest_checkpoint['epoch'] + model.load_state_dict(latest_checkpoint['state_dict']) + optimizer.load_state_dict(latest_checkpoint['optimizer']) + scaler.load_state_dict(latest_checkpoint['scaler']) + best_acc1 = latest_checkpoint['best_acc1'] + print("=> loaded latest checkpoint '{}' (epoch {})" + .format(latest, latest_checkpoint['epoch'])) + + 
cudnn.benchmark = True + + # Data loading code + print("=> creating dataset") + if old_args.model.endswith('DISTILBERT_BASE'): + tokenizer = MyDistilBertTokenizer('distilbert-base-uncased') + elif old_args.model.endswith('BERT_BASE'): + tokenizer = MyBertTokenizer('bert-base-uncased') + elif old_args.model.endswith('BERT_LARGE'): + tokenizer = MyBertTokenizer('bert-large-uncased') + elif old_args.model.endswith('GPT2'): + tokenizer = MyGPT2Tokenizer('gpt2') + elif old_args.model.endswith('GPT2_MEDIUM'): + tokenizer = MyGPT2Tokenizer('gpt2-medium') + elif old_args.model.endswith('GPT2_LARGE'): + tokenizer = MyGPT2Tokenizer('gpt2-large') + elif old_args.model.endswith('GPT2_XL'): + tokenizer = MyGPT2Tokenizer('gpt2-xl') + else: + print("Using SimpleTokenizer because of model '{}'. " + "Please check if this is what you want".format(old_args.model)) + tokenizer = SimpleTokenizer() + + if args.dataset == 'ek100_mir': + criterion = loss.MaxMarginRankingLoss(margin=0.2, fix_norm=True).cuda(args.gpu) + elif args.dataset == 'charades_ego': + criterion = loss.CLIPLoss( + use_vissl=True, + cache_labels=True, + rank=args.rank, + world_size=args.world_size + ) + + crop_size = 224 if '336PX' not in old_args.model else 336 + transforms_list = [ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)), + ] + if 'OPENAI' in old_args.model: + transforms_list.append(transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])) + else: + transforms_list.append(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])) + train_transform = transforms.Compose(transforms_list) + + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if 'OPENAI' not in old_args.model else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + ]) + + # build dataset + args.model = old_args.model + args.norm_embed = old_args.norm_embed + if args.dataset == 'ek100_mir': + train_dataset = datasets.get_dataset(train_transform, tokenizer, args, is_training=True) + args.metadata = args.metadata.replace('train', 'test') + val_dataset = datasets.get_dataset(val_transform, tokenizer, args, is_training=False) + args.metadata = args.metadata.replace('test', 'train') + elif args.dataset == 'charades_ego': + train_dataset = datasets.VideoCaptionDatasetCLIP( + 'charades_ego_trimmed', args.root, args.metadata, + transform=train_transform, is_training=True, tokenizer=tokenizer, + clip_length=args.clip_length, clip_stride=args.clip_stride + ) + labels, mapping_vn2act = generate_label_map(args.dataset) + val_dataset = datasets.VideoClassyDataset( + args.dataset, args.root, args.metadata_val, + transform=val_transform, is_training=False, + label_mapping=mapping_vn2act, is_trimmed=False, + num_clips=1, clip_length=args.clip_length, clip_stride=args.clip_stride, + sparse_sample=args.sparse_sample, + ) + else: + raise NotImplementedError + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + val_sampler = torch.utils.data.SequentialSampler(val_dataset) # disable distributed + else: + train_sampler = None + val_sampler = None + + train_loader = torch.utils.data.DataLoader( + 
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True + ) + print('len(train_loader) = {}'.format(len(train_loader))) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False + ) + print('len(val_loader) = {}'.format(len(val_loader))) + + if args.evaluate: + if args.dataset == 'ek100_mir': + _ = validate_mir(val_loader, model, criterion, args) + elif args.dataset == 'charades_ego': + _ = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args) + return + + if args.fix_lr: + lr_schedule = None + else: + lr_schedule = cosine_scheduler( + args.lr, args.lr_end, args.epochs, len(train_loader) // args.update_freq, + warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start, + ) + + if dist_utils.is_main_process() and args.wandb: + wandb_id = os.path.split(args.output_dir)[-1] + wandb.init(project='LaViLa', id=wandb_id, config=args, resume='allow') + + print(args) + + print("=> zero-shot testing") + if args.dataset == 'ek100_mir': + _ = validate_mir(val_loader, model, criterion, args) + elif args.dataset == 'charades_ego': + _ = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args) + + print("=> beginning training") + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + + train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args) + + is_epoch = ((epoch + 1) % args.save_freq) == 0 + + print('=> saving checkpoint') + dist_utils.save_on_master({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + 'best_acc1': 0, + 'args': args, + }, False, args.output_dir, is_epoch=is_epoch) + + if (epoch + 1) % args.eval_freq != 0: + continue + + # TODO: add evaluation + if args.dataset == 'ek100_mir': + val_stats = validate_mir(val_loader, model, criterion, args) + elif args.dataset == 'charades_ego': + val_stats = validate_cls(val_loader, ['{}'], labels, model, tokenizer, args) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in val_stats.items()}, + 'epoch': epoch} + + if dist_utils.is_main_process(): + if args.wandb: + wandb.log(log_stats) + with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f: + f.write(json.dumps(log_stats) + '\n') + + +def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + mem = AverageMeter('Mem (GB)', ':6.1f') + if args.dataset == 'ek100_mir': + metric_names = ['loss', 'max_margin_loss'] + elif args.dataset == 'charades_ego': + metric_names = models.get_metric_names(args.model) + iters_per_epoch = len(train_loader) // args.update_freq + metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, *metrics.values()], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for data_iter, inputs in enumerate(train_loader): + optim_iter = data_iter // args.update_freq + + # measure data loading time + data_time.update(time.time() - end) + + # update weight decay and learning rate according to their schedule + it = 
iters_per_epoch * epoch + optim_iter # global training iteration + for k, param_group in enumerate(optimizer.param_groups): + if lr_schedule is not None: + param_group['lr'] = lr_schedule[it] + + inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] + relevancies = inputs.pop() + + # compute output + with amp.autocast(enabled=not args.disable_amp): + outputs = model( + *inputs, + use_checkpoint=args.use_checkpoint, + norm_embed=args.norm_embed + ) + if args.dataset == 'ek100_mir': + loss_dict = criterion(outputs, weight=relevancies) + elif args.dataset == 'charades_ego': + loss_dict = criterion(outputs) + loss = loss_dict['loss'] + loss /= args.update_freq + + if not math.isfinite(loss.item()): + print("Loss is {}, stopping training".format(loss.item())) + sys.exit(1) + + scaler.scale(loss).backward() + # TODO: for debug only + # for n, p in model.named_parameters(): + # if p.grad is not None: + # print('{}: {} | {}'.format(n, torch.mean(torch.abs(p.data)), torch.mean(torch.abs(p.grad))), flush=True) + # else: + # print('{}: {} | {}'.format(n, torch.mean(torch.abs(p.data)), 'None'), flush=True) + # if torch.isnan(loss): + # for n, p in model.named_parameters(): + # print(f'{n}:', p.grad, flush=True) + + if (data_iter + 1) % args.update_freq != 0: + continue + + if args.clip_grad_value is not None: + scaler.unscale_(optimizer) + if args.clip_grad_type == 'norm': + torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_grad_value, norm_type=2. + ) + elif args.clip_grad_type == 'value': + torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_value) + else: + assert False, f"Unknown clip mode ({args.clip_grad_type})." + # compute gradient and do SGD step + scaler.step(optimizer) + scaler.update() + model.zero_grad(set_to_none=True) + + if hasattr(dist_utils.get_model(model), 'logit_scale'): + # clamp logit scale to [0, 100] + dist_utils.get_model(model).logit_scale.data.clamp_(0, 4.6052) + logit_scale = dist_utils.get_model(model).logit_scale.exp().item() + else: + logit_scale = torch.nan + + for k in loss_dict: + metrics[k].update(loss_dict[k].item(), args.batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if optim_iter % args.print_freq == 0: + if dist_utils.is_main_process() and args.wandb: + wandb.log({**{k: v.item() for k, v in loss_dict.items()}, + 'scaler': scaler.get_scale(), 'logit': logit_scale}) + progress.display(optim_iter) + progress.synchronize() + return {**{k: v.avg for k, v in metrics.items()}, + 'lr': optimizer.param_groups[0]['lr'], + 'logit_scale': logit_scale} + + +def validate_mir(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + mem = AverageMeter('Mem (GB)', ':6.1f') + metric_names = ['loss', 'max_margin_loss'] + iters_per_epoch = len(val_loader) // args.update_freq + metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, *metrics.values()], + prefix="Test: " + ) + + # switch to eval mode + model.eval() + + all_video_embed = [] + all_text_embed = [] + with torch.no_grad(): + end = time.time() + for i, inputs in enumerate(val_loader): + # measure data loading time + data_time.update(time.time() - end) + + inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] + relevancies = inputs.pop() + + # compute output + outputs = 
model( + *inputs, + use_checkpoint=args.use_checkpoint, + norm_embed=args.norm_embed + ) + loss_dict = criterion(outputs, weight=relevancies) + + for k in loss_dict: + metrics[k].update(loss_dict[k].item(), args.batch_size) + + image_features = outputs['image_embed'] + text_features = outputs['text_embed'] + all_video_embed.append(image_features.cpu().numpy()) + all_text_embed.append(text_features.cpu().numpy()) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if i % args.print_freq == 0: + if dist_utils.is_main_process() and args.wandb: + wandb.log({**{k: v.item() for k, v in loss_dict.items()}}) + progress.display(i) + progress.synchronize() + all_text_embed = np.vstack(all_text_embed) + all_video_embed = np.vstack(all_video_embed) + similarity_matrix = np.matmul(all_video_embed, all_text_embed.T) + similarity_matrix = (similarity_matrix + 1) / 2 + video_id = pd.read_csv(args.metadata.replace('train', 'test')).values[:, 0] + text_id = pd.read_csv(args.metadata.replace('train', 'test_sentence')).values[:, 0] + indexes = [video_id.tolist().index(elem) for elem in text_id] + similarity_matrix = similarity_matrix[:, indexes] + print(similarity_matrix.shape) + rel_matrix = pd.read_pickle( + args.relevancy_path + ) + vis_map = calculate_mAP(similarity_matrix, rel_matrix) + txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T) + print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, (vis_map + txt_map) / 2)) + vis_k_counts = calculate_k_counts(rel_matrix) + txt_k_counts = calculate_k_counts(rel_matrix.T) + vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts) + txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts) + vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG) + txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG) + print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2)) + return {**{k: v.avg for k, v in metrics.items()}} + + +def validate_cls(val_loader, templates, labels, model, tokenizer, args): + # switch to eval mode + model.eval() + + all_outputs = [] + all_targets = [] + with torch.no_grad(): + text_features = [] + for label in labels: + if isinstance(label, list): + texts = [tmpl.format(lbl) for tmpl in templates for lbl in label] + else: + texts = [tmpl.format(label) for tmpl in templates] + texts = tokenizer(texts) + if isinstance(texts, tuple): + # Bert-style tokenizer will output both ids and mask + texts, masks = texts + texts = texts.cuda(non_blocking=True) + masks = masks.cuda(non_blocking=True) + else: + texts = texts.cuda(non_blocking=True) + masks = None + texts = texts.view(-1, 77).contiguous() + masks = masks.view(-1, 77).contiguous() if masks is not None else None + if masks is not None: + class_embeddings = dist_utils.get_model(model).encode_text(texts, attention_mask=masks) + else: + class_embeddings = dist_utils.get_model(model).encode_text(texts) + class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) + class_embeddings = class_embeddings.mean(dim=0) + class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) + + text_features.append(class_embeddings) + text_features = torch.stack(text_features, dim=0) + + print('=> start forwarding') + end_time = time.time() + for i, (images, target) in enumerate(val_loader): + if i % args.print_freq == 0: + 
print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time)) + end_time = time.time() + if isinstance(images, torch.Tensor): + images = images.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # encode images + image_features = dist_utils.get_model(model).encode_image(images) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ text_features.t() + logits_per_image = torch.softmax(logits_per_image, dim=1) + else: + target = target.cuda(non_blocking=True) + images_list = images + logits_all_clips = [] + for images in images_list: + images = images.cuda(non_blocking=True) + image_features = dist_utils.get_model(model).encode_image(images) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + logits_per_image = image_features @ text_features.t() + logits_all_clips.append(logits_per_image) + + logits_all_clips = torch.stack(logits_all_clips, dim=0) + # logits_per_image = logits_all_clips.max(0).values + logits_per_image = logits_all_clips.mean(0) + logits_per_image = torch.softmax(logits_per_image, dim=1) + + all_outputs.append(logits_per_image.cpu()) + all_targets.append(target.cpu()) + all_outputs = torch.cat(all_outputs) + all_targets = torch.cat(all_targets) + preds, targets = all_outputs.numpy(), all_targets.numpy() + m_ap, _, _ = charades_map(preds, targets) + print('mAP = {:.3f}'.format(m_ap)) + return {'mAP': m_ap} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila finetune and evaluation', parents=[get_args_parser()]) + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + main(args) diff --git a/main_infer_narrator.py b/main_infer_narrator.py new file mode 100644 index 0000000..5a5dcb6 --- /dev/null +++ b/main_infer_narrator.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
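+# This script runs the pretrained narrator over a video dataset and caches the
+# generated captions. Each rank writes its own 'cache.{rank}.pkl'; rank 0 then
+# merges them (interleaved by rank) into 'total.pkl' and deletes the per-rank caches.
+# Illustrative launch (checkpoint/output paths below are placeholders, not part of the repo):
+#   torchrun --nproc_per_node=8 main_infer_narrator.py \
+#       --resume <narrator_checkpoint.pt> --output-dir <output_dir> \
+#       --caption-sample multinomial_sample --caption-num-return-sequences 10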
+ + +import argparse +from collections import OrderedDict +import os +import os.path as osp +import pickle +import time + +import torch +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video + +from lavila.data import datasets +from lavila.data.video_transforms import Permute +from lavila.models import models +from lavila.utils.preprocess import generate_tokenizer +from lavila.utils import distributed as dist_utils +from eval_narrator import decode_one + + +class IndexedDataset(torch.utils.data.Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, index): + return index, self.dataset[index] + + def __len__(self): + return len(self.dataset) + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='lavila infer narrator', add_help=False) + parser.add_argument('--dataset', default='ego4d', type=str, choices=['ego4d']) + parser.add_argument('--root', + default='datasets/Ego4D/video_5min_chunks_288px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata', + default='datasets/Ego4D/ego4d_train.pkl', + type=str, help='path to metadata file') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--batch-size', default=64, type=int) + parser.add_argument('--use-half', action='store_true') + parser.add_argument('--clip-length', default=4, type=int, help='clip length') + parser.add_argument('--clip-stride', default=16, type=int, help='clip stride') + parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') + parser.add_argument('--caption-sample', default='multinomial_sample', + choices=['multinomial_sample', 'beam_sample', 'group_beam_search']) + parser.add_argument('--caption-top-k', default=None, type=int) + parser.add_argument('--caption-top-p', default=0.95, type=float) + parser.add_argument('--caption-num-beams', default=1, type=int) + parser.add_argument('--caption-num-beam-groups', default=1, type=int) + parser.add_argument('--caption-temperature', default=0.7, type=float) + parser.add_argument('--caption-length-penalty', default=1.0, type=float) + parser.add_argument('--caption-num-return-sequences', default=10, type=int) + parser.add_argument('--caption-max-len', default=77, type=int) + parser.add_argument('--caption-early-stop', action='store_true', help='early stopping to save computation') + # System + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--dist-url', default='env://', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='nccl', type=str) + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + return parser + + +def main(args): + dist_utils.init_distributed_mode(args) + print(args) + + if args.resume: + ckpt_path = args.resume + elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')): + ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt') + else: + raise Exception('no checkpoint found') + + ckpt = torch.load(ckpt_path, 
map_location='cpu') + state_dict = OrderedDict() + for k, v in ckpt['state_dict'].items(): + state_dict[k.replace('module.', '')] = v + + # create model + old_args = ckpt['args'] + print('=> creating model: {}'.format(old_args.model)) + model = getattr(models, old_args.model)( + text_use_cls_token=old_args.use_cls_token, + gated_xattn=old_args.gated_xattn, + timesformer_gated_xattn=old_args.timesformer_gated_xattn, + num_frames=old_args.clip_length, + drop_path_rate=0, + ) + model.cuda() + model.load_state_dict(state_dict, strict=True) + print("=> loaded resume checkpoint '{}' (epoch {})".format(args.resume, ckpt['epoch'])) + + torch.backends.cudnn.benchmark = True + + # Data loading + print("=> creating dataset") + tokenizer = generate_tokenizer(old_args.model) + + crop_size = 224 if '336PX' not in old_args.model else 336 + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if 'OPENAI' not in old_args.model else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])), + ]) + + val_dataset = datasets.VideoCaptionDatasetCLIP( + args.dataset, + args.root, + args.metadata, + transform=val_transform, + is_training=False, + tokenizer=tokenizer, + clip_length=args.clip_length, + clip_stride=args.clip_stride, + sparse_sample=False, + subsample_stride=1, + ) + val_dataset = IndexedDataset(val_dataset) + + print(len(val_dataset)) + + if args.distributed: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) + else: + val_sampler = None + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False + ) + print('len(val_loader) = {}'.format(len(val_loader))) + + model.eval() + if args.use_half: + model.half() + + id_offset = 0 + all_captions_cache = [] + end = time.time() + with torch.no_grad(): + for data_iter, (indices, inputs) in enumerate(val_loader): + indices = indices.tolist() + if data_iter % args.print_freq == 0: + print("finished {}/{} in {}".format(data_iter, len(val_loader), time.time() - end)) + end = time.time() + if len(inputs) == 2 or len(inputs) == 3: + images = inputs[0].cuda(non_blocking=True) + if args.use_half: + images = images.half() + + image_features = dist_utils.get_model(model).encode_image(images) + if not isinstance(image_features, (list, tuple)): + image_tokens = image_features + else: + image_tokens = image_features[1] + if args.caption_sample == 'multinomial_sample': + generated_text_ids, ppls = dist_utils.get_model(model).generate( + image_tokens, + tokenizer, + target=None, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + num_return_sequences=args.caption_num_return_sequences, + temperature=args.caption_temperature, + early_stopping=args.caption_early_stop, + ) + elif args.caption_sample == 'beam_sample': + generated_text_ids, ppls = dist_utils.get_model(model).beam_sample( + image_tokens, + tokenizer, + target=None, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + temperature=args.caption_temperature, + length_penalty=args.caption_length_penalty, + num_beams=args.caption_num_beams, + num_return_sequences=args.caption_num_return_sequences, + ) 
+ elif args.caption_sample == 'group_beam_search': + assert args.caption_num_beam_groups > 1 and args.caption_num_beams % args.caption_num_beam_groups == 0 + generated_text_ids, ppls = dist_utils.get_model(model).group_beam_search( + image_tokens, + tokenizer, + target=None, + max_text_length=args.caption_max_len, + top_k=args.caption_top_k, + top_p=args.caption_top_p, + temperature=args.caption_temperature, + length_penalty=args.caption_length_penalty, + num_beams=args.caption_num_beams, + num_beam_groups=args.caption_num_beam_groups, + num_return_sequences=args.caption_num_return_sequences, + ) + for j in range(generated_text_ids.shape[0] // args.caption_num_return_sequences): + generated_text_str_list = [] + ppls_list = [] + for k in range(args.caption_num_return_sequences): + jj = j * args.caption_num_return_sequences + k + generated_text_str = decode_one(generated_text_ids[jj], tokenizer) + generated_text_str_list.append(generated_text_str) + ppls_list.append(ppls[jj].item()) + video_uid, t_start, t_end, _ = val_loader.dataset.dataset.samples[indices[j]] + if args.caption_num_return_sequences == 1: + all_captions_cache.append((video_uid, t_start, t_end, generated_text_str, ppls[jj].item())) + else: + all_captions_cache.append((video_uid, t_start, t_end, generated_text_str_list, ppls_list)) + id_offset += generated_text_ids.shape[0] + + pickle.dump(all_captions_cache, open(osp.join(args.output_dir, 'cache.{}.pkl'.format(args.rank)), 'wb')) + + torch.distributed.barrier() + disorded_list = [] + total_num = 0 + if args.rank == 0: + for i in range(args.world_size): + print('=> reading {}'.format(osp.join(args.output_dir, f'cache.{i}.pkl'))) + sublist = pickle.load(open(osp.join(args.output_dir, f'cache.{i}.pkl'), 'rb')) + disorded_list.append(sublist) + total_num += len(sublist) + ordered_list = [] + for i in range(total_num): + ordered_list.append(disorded_list[i % args.world_size][i // args.world_size]) + print(f"{len(val_dataset)}/{len(ordered_list)}") + ordered_list = ordered_list[:len(val_dataset)] + pickle.dump(ordered_list, open(osp.join(args.output_dir, 'total.pkl'), 'wb')) + for i in range(args.world_size): + print('=> deleting {}'.format(osp.join(args.output_dir, f'cache.{i}.pkl'))) + os.remove(osp.join(args.output_dir, f'cache.{i}.pkl')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('lavila infer narrator', parents=[get_args_parser()]) + args = parser.parse_args() + main(args) diff --git a/main_pretrain.py b/main_pretrain.py new file mode 100644 index 0000000..edb9a39 --- /dev/null +++ b/main_pretrain.py @@ -0,0 +1,614 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
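+# Pretrains a video-language model (CLIP-style dual encoder or VCLM narrator) on
+# Ego4D clips, optionally mixing in pseudo-narrated clips passed via --metadata-aux.
+# For CLIP-style models, zero-shot EK100 multi-instance retrieval (mAP/nDCG) is
+# evaluated during training and used to select the best checkpoint.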
+ +import argparse +from collections import OrderedDict +import json +import math +import os +import pandas as pd +import sys +import time + +import torch +import torch.backends.cudnn as cudnn +import torch.cuda.amp as amp +from torch.distributed.optim import ZeroRedundancyOptimizer +import torch.nn.parallel +import torchvision.transforms as transforms +import torchvision.transforms._transforms_video as transforms_video +import wandb + +from eval_zeroshot import get_similarity_matrix +from lavila.data import datasets +from lavila.data.video_transforms import Permute +from lavila.models import models +from lavila.utils.meter import AverageMeter, ProgressMeter +from lavila.utils import distributed as dist_utils +from lavila.utils.evaluation_ek100mir import get_mAP, get_nDCG +from lavila.utils.preprocess import generate_tokenizer +from lavila.utils.random import random_seed +from lavila.utils.scheduler import cosine_scheduler + + +class GroundTruthDataset(torch.utils.data.Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, index): + return 1, self.dataset[index] + + def __len__(self): + return len(self.dataset) + + +class PseudoLabelDataset(torch.utils.data.Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, index): + return 0, self.dataset[index] + + def __len__(self): + return len(self.dataset) + + +def get_args_parser(): + parser = argparse.ArgumentParser(description='LaVid training and evaluation', add_help=False) + # Data + parser.add_argument('--dataset', default='ego4d', type=str, choices=['ego4d']) + parser.add_argument('--root', default='datasets/Ego4D/video_5min_chunks_288px/', + type=str, help='path to dataset root') + parser.add_argument('--metadata', default='datasets/Ego4D/ego4d_train.pkl', + type=str, help='path to metadata file') + parser.add_argument('--metadata-aux', default=None, nargs='+', + type=str, help='path to metadata file (auxiliary data with pseudo narrations)') + parser.add_argument('--output-dir', default='./', type=str, help='output dir') + parser.add_argument('--clip-length', default=4, type=int, help='clip length') + parser.add_argument('--clip-stride', default=16, type=int, help='clip stride') + parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling') + parser.add_argument('--narration-selection', default='random', + choices=['random', 'concat'], + type=str, help='selection strategy if multiple narrations per clip') + parser.add_argument('--num-hard-neg', default=0, type=int, help='number of hard negatives per video') + # Model + parser.add_argument('--model', default='CLIP_OPENAI_TIMESFORMER_BASE', type=str) + parser.add_argument('--norm-embed', action='store_true', help='norm text and visual embed if set True') + parser.add_argument('--resume', default='', type=str, help='path to resume from') + parser.add_argument('--load-visual-pretrained', default=None, type=str, + help='path to pretrained model (in1k/in21k/...)') + parser.add_argument('--project-embed-dim', default=256, type=int, help='embed dim after projection') + parser.add_argument('--use-cls-token', action='store_true', help='use feature at [CLS] if set True') + parser.add_argument('--contrastive-use-vissl', action='store_true', help='use contrastive implementation in vissl') + parser.add_argument('--gated-xattn', action='store_true', help='use gated x-attn in VCLM_GPT2') + parser.add_argument('--random-init-gpt2', action='store_true', help='random initialize params of text decoder 
in VCLM_GPT2') + parser.add_argument('--timesformer-gated-xattn', action='store_true', help='use gated x-attn in TimeSformer') + parser.add_argument('--timesformer-freeze-space', action='store_true', help='freeze space part in TimeSformer') + parser.add_argument('--drop-path-rate', default=0., type=float, help='DropPath rate') + parser.add_argument('--freeze-visual-vclm', action='store_true', help='freeze the visual model in VCLM_GPT2') + parser.add_argument('--freeze-visual-vclm-temporal', action='store_true', help='freeze the temporal part of visual model in VCLM_GPT2') + parser.add_argument('--freeze-lm-vclm', action='store_true', help='freeze the lm in VCLM_GPT2') + parser.add_argument('--find-unused-parameters', action='store_true', + help='do this during DDP (useful for models with tied weights)') + # Training + parser.add_argument('--epochs', default=5, type=int) + parser.add_argument('--warmup-epochs', default=1, type=int) + parser.add_argument('--start-epoch', default=0, type=int) + parser.add_argument('--batch-size', default=32, type=int, + help='number of samples per-device/per-gpu') + parser.add_argument('--temperature-init', default=0.07, type=float, + help='init. logit temperature for samples') + parser.add_argument('--freeze-temperature', action='store_true', + help='freeze logit temperature') + parser.add_argument('--pseudo-temperature-init', default=0.07, type=float, + help='init. logit temperature for pseudo-narrated samples') + parser.add_argument('--freeze-pseudo-temperature', action='store_true', + help='freeze logit temperature (for pseudo-narrated samples)') + parser.add_argument('--lr', default=3e-5, type=float) + parser.add_argument('--fix-lr', action='store_true', help='disable cosine lr decay if set True') + parser.add_argument('--lr-start', default=1e-6, type=float, + help='initial warmup lr') + parser.add_argument('--lr-end', default=1e-5, type=float, + help='minimum final lr') + parser.add_argument('--clip-grad-type', default='norm', choices=['norm', 'value']) + parser.add_argument('--clip-grad-value', default=None, type=float, help='') + parser.add_argument('--update-freq', default=1, type=int, + help='optimizer update frequency (i.e. 
gradient accumulation steps)') + parser.add_argument('--wd', default=0.01, type=float) + parser.add_argument('--betas', default=(0.9, 0.999), nargs=2, type=float) + parser.add_argument('--eps', default=1e-8, type=float) + parser.add_argument('--eval-freq', default=99, type=int) + parser.add_argument('--eval-in-middle-freq', default=-1, type=int) + parser.add_argument('--save-freq', default=1, type=int) + parser.add_argument('--disable-amp', action='store_true', + help='disable mixed-precision training (requires more memory and compute)') + parser.add_argument('--use-zero', action='store_true', + help='use ZeroRedundancyOptimizer to save memory') + parser.add_argument('--use-checkpoint', action='store_true', + help='use gradient checkpointing during training for significantly less GPU usage') + parser.add_argument('--use-half', action='store_true', help='evaluate using half-precision') + # System + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', + help='number of data loading workers per process') + parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--dist-url', default='env://', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='nccl', type=str) + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') + parser.add_argument('--wandb', action='store_true', help='Enable WandB logging') + return parser + + +def main(args): + dist_utils.init_distributed_mode(args) + + global best_acc1 + random_seed(args.seed, dist_utils.get_rank()) + + print("=> creating model: {}".format(args.model)) + model = getattr(models, args.model)( + pretrained=args.load_visual_pretrained, + pretrained2d=args.load_visual_pretrained is not None, + text_use_cls_token=args.use_cls_token, + project_embed_dim=args.project_embed_dim, + gated_xattn=args.gated_xattn, + random_init_gpt2=args.random_init_gpt2, + timesformer_gated_xattn=args.timesformer_gated_xattn, + timesformer_freeze_space=args.timesformer_freeze_space, + freeze_lm_vclm=args.freeze_lm_vclm, + freeze_visual_vclm=args.freeze_visual_vclm, + freeze_visual_vclm_temporal=args.freeze_visual_vclm_temporal, + num_frames=args.clip_length, + drop_path_rate=args.drop_path_rate, + temperature_init=args.temperature_init, + ) + if args.freeze_temperature: + print('Freeze logit temperature') + model.logit_scale.requires_grad = False + model.cuda(args.gpu) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], bucket_cap_mb=200, + find_unused_parameters=args.find_unused_parameters + ) + + tokenizer = generate_tokenizer(args.model) + + if args.metadata_aux is None: + criterion = models.get_loss(args.model, args, tokenizer=tokenizer).cuda(args.gpu) + else: + criterion = models.loss.SSLCLIPLoss( + use_vissl=args.contrastive_use_vissl, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + scale_init=args.pseudo_temperature_init, + freeze_scale=args.freeze_pseudo_temperature, + ).cuda(args.gpu) + + p_wd, p_non_wd = [], [] + for n, p in model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + if p.ndim < 2 or 'bias' in 
n or 'ln' in n or 'bn' in n: + p_non_wd.append(p) + else: + p_wd.append(p) + for n, p in criterion.named_parameters(): + if not p.requires_grad: + continue + p_non_wd.append(p) + + optim_params = [{"params": p_wd, "weight_decay": args.wd}, + {"params": p_non_wd, "weight_decay": 0}] + + if args.use_zero: + optimizer = ZeroRedundancyOptimizer( + optim_params, optimizer_class=torch.optim.AdamW, + lr=args.lr, betas=args.betas, eps=args.eps, weight_decay=args.wd + ) + else: + optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas, + eps=args.eps, weight_decay=args.wd) + scaler = amp.GradScaler(enabled=not args.disable_amp) + # optionally resume from a checkpoint (takes precedence over autoresume) + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + args.resume = '' + if args.resume: + if os.path.isfile(args.resume): + print("=> loading resume checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume, map_location='cpu') + epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0 + args.start_epoch = epoch + result = model.load_state_dict(checkpoint['state_dict'], strict=False) + print(result) + optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else () + scaler.load_state_dict(checkpoint['scaler']) if 'scaler' in checkpoint else () + criterion.load_state_dict(checkpoint['criterion']) if 'criterion' in checkpoint else () + best_acc1 = checkpoint['best_acc1'] + print("=> loaded resume checkpoint '{}' (epoch {})" + .format(args.resume, epoch)) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + else: + # auto-resume from latest checkpoint in output directory + latest = os.path.join(args.output_dir, 'checkpoint.pt') + if os.path.isfile(latest): + print("=> loading latest checkpoint '{}'".format(latest)) + latest_checkpoint = torch.load(latest, map_location='cpu') + args.start_epoch = latest_checkpoint['epoch'] + model.load_state_dict(latest_checkpoint['state_dict']) + optimizer.load_state_dict(latest_checkpoint['optimizer']) + scaler.load_state_dict(latest_checkpoint['scaler']) + best_acc1 = latest_checkpoint['best_acc1'] + print("=> loaded latest checkpoint '{}' (epoch {})" + .format(latest, latest_checkpoint['epoch'])) + + cudnn.benchmark = True + + # Data loading code + print("=> creating dataset") + + crop_size = 224 if '336PX' not in args.model else 336 + transforms_list = [ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)), + ] + if 'OPENAI' in args.model: + transforms_list.append(transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])) + else: + transforms_list.append(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])) + train_transform = transforms.Compose(transforms_list) + + # TODO: uncomment when evaluation is done later + val_transform = transforms.Compose([ + Permute([3, 0, 1, 2]), # T H W C -> C T H W + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if 'OPENAI' not in args.model else + transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])) + ]) + + assert 'train' in args.metadata + train_dataset = datasets.get_dataset(train_transform, tokenizer, args, is_training=True) + args.metadata = 
args.metadata.replace('train', 'val') + val_dataset = datasets.get_dataset(val_transform, tokenizer, args, is_training=False) + args.metadata = args.metadata.replace('val', 'train') + if args.metadata_aux is not None: + train_dataset = GroundTruthDataset(train_dataset) + old_metadata = args.metadata + aux_dataset_list = [] + for aux_i, aux_pkl in enumerate(args.metadata_aux): + args.metadata = aux_pkl + aux_dataset = datasets.get_dataset(train_transform, tokenizer, args, is_training=True) + aux_dataset_list.append(PseudoLabelDataset(aux_dataset)) + print("auxiliary dataset [{}] : source = {}, len(aux_dataset) = {}".format(aux_i, aux_pkl, len(aux_dataset))) + pseudo_label_dataset = torch.utils.data.ConcatDataset(aux_dataset_list) + args.metadata = old_metadata + train_dataset = torch.utils.data.ConcatDataset([train_dataset, pseudo_label_dataset]) + val_dataset = GroundTruthDataset(val_dataset) + + ek100_dataset = datasets.VideoCaptionDatasetCLIP( + 'ek100_mir', + 'datasets/EK100/video_ht256px/', + 'datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv', + transform=val_transform, + is_training=False, + tokenizer=tokenizer, + clip_length=args.clip_length, + clip_stride=args.clip_stride, + sparse_sample=False + ) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + ek100_sampler = torch.utils.data.SequentialSampler(ek100_dataset) + else: + train_sampler = None + val_sampler = None + ek100_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True + ) + print('len(train_loader) = {}'.format(len(train_loader))) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False + ) + print('len(val_loader) = {}'.format(len(val_loader))) + ek100_loader = torch.utils.data.DataLoader( + ek100_dataset, batch_size=args.batch_size * (1 + args.num_hard_neg), shuffle=(ek100_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=ek100_sampler, drop_last=False + ) + print('len(ek100_loader) = {}'.format(len(ek100_loader))) + + if args.fix_lr: + lr_schedule = None + else: + lr_schedule = cosine_scheduler( + args.lr, args.lr_end, args.epochs, len(train_loader) // args.update_freq, + warmup_epochs=args.warmup_epochs, start_warmup_value=args.lr_start, + ) + + if dist_utils.is_main_process() and args.wandb: + wandb_id = os.path.split(args.output_dir)[-1] + wandb.init(project='LaVid', id=wandb_id, config=args, resume='allow') + + print(args) + + best_metric = 0. 
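+    # best_metric tracks the zero-shot EK100 MIR average mAP computed after each
+    # epoch (for CLIP-style models); it determines which checkpoint is saved as best.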
+ print("=> beginning training") + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + + if hasattr(args, 'eval_in_middle_freq') and args.eval_in_middle_freq > 0: + train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args, + ek100_loader=ek100_loader, eval_in_middle=args.eval_in_middle_freq) + else: + train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args) + + if args.model.startswith('CLIP'): + print('=> 0-shot on EK100') + similarity_matrix = get_similarity_matrix(ek100_loader, model, use_half=args.use_half) + similarity_matrix = (similarity_matrix + 1) / 2 + video_id = pd.read_csv("datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv").values[:, 0] + text_id = pd.read_csv("datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test_sentence.csv").values[:, 0] + indexes = [video_id.tolist().index(elem) for elem in text_id] + similarity_matrix = similarity_matrix[:, indexes] + rel_matrix = pd.read_pickle( + 'datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl' + ) + vis_map, txt_map, avg_map = get_mAP(similarity_matrix, rel_matrix) + print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map)) + vis_ndcg, txt_ndcg, avg_ndcg = get_nDCG(similarity_matrix, rel_matrix) + print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_ndcg, txt_ndcg, avg_ndcg)) + if avg_map > best_metric: + is_best = True + best_metric = avg_map + else: + is_best = False + else: + is_best = False + + is_epoch = ((epoch + 1) % args.save_freq) == 0 + + if args.distributed and args.use_zero: + print("=> consolidating state_dict before saving (due to ZeRO)") + optimizer.consolidate_state_dict() + + print('=> saving checkpoint') + dist_utils.save_on_master({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'criterion': criterion.state_dict(), + 'optimizer': optimizer.state_dict() if dist_utils.get_rank() == 0 else {}, + 'scaler': scaler.state_dict(), + 'best_acc1': best_metric, + 'args': args, + }, is_best, args.output_dir, is_epoch=is_epoch) + + if (epoch + 1) % args.eval_freq != 0: + continue + + # TODO: add evaluation + val_stats = validate(val_loader, model, criterion, args) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in val_stats.items()}, + 'epoch': epoch} + + if dist_utils.is_main_process(): + if args.wandb: + wandb.log(log_stats) + with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f: + f.write(json.dumps(log_stats) + '\n') + + +def train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args, ek100_loader=None, eval_in_middle=0): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + mem = AverageMeter('Mem (GB)', ':6.1f') + metric_names = models.get_metric_names(args.model) + if args.metadata_aux is not None: + metric_names.extend(['num_gt', 'num_pseudo', 'clip_acc_gt', 'clip_acc_pseudo']) + iters_per_epoch = len(train_loader) // args.update_freq + metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, *metrics.values()], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for data_iter, inputs in 
enumerate(train_loader): + # evaluate in the middle of training + if eval_in_middle > 0 and (data_iter > 0 and data_iter % eval_in_middle) and ek100_loader is not None: + model.eval() + print('=> 0-shot on EK100 in the middle of training') + similarity_matrix = get_similarity_matrix(ek100_loader, model, use_half=args.use_half) + similarity_matrix = (similarity_matrix + 1) / 2 + video_id = pd.read_csv("datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv").values[:, 0] + text_id = pd.read_csv("datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test_sentence.csv").values[:, 0] + indexes = [video_id.tolist().index(elem) for elem in text_id] + similarity_matrix = similarity_matrix[:, indexes] + rel_matrix = pd.read_pickle( + 'datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl' + ) + vis_map, txt_map, avg_map = get_mAP(similarity_matrix, rel_matrix) + print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map)) + vis_ndcg, txt_ndcg, avg_ndcg = get_nDCG(similarity_matrix, rel_matrix) + print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_ndcg, txt_ndcg, avg_ndcg)) + best_metric = avg_map + + print('=> saving checkpoint') + dist_utils.save_on_master({ + 'epoch': epoch + data_iter / len(train_loader), + 'state_dict': model.state_dict(), + 'criterion': criterion.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + 'best_acc1': best_metric, + 'args': args, + }, False, args.output_dir, is_epoch=True) # save every time (not to conflict the best_metric tracking in the regular validation phrase) + model.train() + + if args.metadata_aux is not None: + gt_indicators, inputs = inputs + + optim_iter = data_iter // args.update_freq + + # measure data loading time + data_time.update(time.time() - end) + + # update weight decay and learning rate according to their schedule + it = iters_per_epoch * epoch + optim_iter # global training iteration + for k, param_group in enumerate(optimizer.param_groups): + if lr_schedule is not None: + param_group['lr'] = lr_schedule[it] + + inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] + _ = inputs.pop() # loader will a "relevancy" variable which is not needed except ek100_mir + + # compute output + with amp.autocast(enabled=not args.disable_amp): + outputs = model( + *inputs, + use_checkpoint=args.use_checkpoint, + norm_embed=args.norm_embed + ) + if args.metadata_aux is None: + loss_dict = criterion(outputs) + else: + loss_dict = criterion(outputs, gt_indicators) + loss = loss_dict['loss'] + loss /= args.update_freq + + if not math.isfinite(loss.item()): + print("Loss is {}, stopping training".format(loss.item())) + sys.exit(1) + + scaler.scale(loss).backward() + + if (data_iter + 1) % args.update_freq != 0: + continue + + if args.clip_grad_value is not None: + scaler.unscale_(optimizer) + if args.clip_grad_type == 'norm': + torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_grad_value, norm_type=2. + ) + elif args.clip_grad_type == 'value': + torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_value) + else: + assert False, f"Unknown clip mode ({args.clip_grad_type})." 
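+        # When AMP is enabled, gradients must be unscaled before clipping, which is
+        # why scaler.unscale_(optimizer) is called above prior to clip_grad_norm_ /
+        # clip_grad_value_; scaler.step() below then skips the redundant unscale.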
+ # compute gradient and do SGD step + scaler.step(optimizer) + scaler.update() + model.zero_grad(set_to_none=True) + + if hasattr(dist_utils.get_model(model), 'logit_scale'): + # clamp logit scale to [0, 100] + dist_utils.get_model(model).logit_scale.data.clamp_(0, 4.6052) + logit_scale = dist_utils.get_model(model).logit_scale.exp().item() + else: + logit_scale = torch.nan + + for k in loss_dict: + metrics[k].update(loss_dict[k].item(), args.batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if optim_iter % args.print_freq == 0: + if dist_utils.is_main_process() and args.wandb: + wandb.log({**{k: v.item() for k, v in loss_dict.items()}, + 'scaler': scaler.get_scale(), 'logit': logit_scale}) + progress.display(optim_iter) + progress.synchronize() + return {**{k: v.avg for k, v in metrics.items()}, + 'lr': optimizer.param_groups[0]['lr'], + 'logit_scale': logit_scale} + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.2f') + data_time = AverageMeter('Data', ':6.2f') + mem = AverageMeter('Mem (GB)', ':6.1f') + metric_names = models.get_metric_names(args.model) + iters_per_epoch = len(val_loader) // args.update_freq + metrics = OrderedDict([(name, AverageMeter(name, ':.2e')) for name in metric_names]) + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, *metrics.values()], + prefix="Test: " + ) + + # switch to eval mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, inputs in enumerate(val_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.metadata_aux is not None: + gt_indicators, inputs = inputs + + inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs] + _ = inputs.pop() # loader will a "relevancy" variable which is not needed except ek100_mir + + # compute output + outputs = model( + *inputs, + use_checkpoint=args.use_checkpoint, + norm_embed=args.norm_embed + ) + if args.metadata_aux is None: + loss_dict = criterion(outputs) + else: + loss_dict = criterion(outputs, gt_indicators) + + for k in loss_dict: + metrics[k].update(loss_dict[k].item(), args.batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if i % args.print_freq == 0: + if dist_utils.is_main_process() and args.wandb: + wandb.log({**{k: v.item() for k, v in loss_dict.items()}}) + progress.display(i) + progress.synchronize() + return {**{k: v.avg for k, v in metrics.items()}} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('LaVid training and evaluation', parents=[get_args_parser()]) + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + main(args) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2056e3a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +timm==0.5.4 +torch==1.10.1 +torchvision==0.11.2 +decord==0.6.0 +einops==0.4.1 +pandas==1.4.2 +pytorchvideo==0.1.5 +transformers==4.21 +ftfy==4.4.3 +spacy==3.4.1 +scikit-learn==1.1.1 +git+https://github.com/Maluuba/nlg-eval.git@master diff --git a/run_with_submitit_finetune_classification.py b/run_with_submitit_finetune_classification.py new file mode 100644 index 0000000..25d10df --- /dev/null +++ b/run_with_submitit_finetune_classification.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+A script to run multinode training with submitit.
+"""
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import main_finetune_classification as main_finetune
+import submitit
+
+
+def parse_args():
+    parser = main_finetune.get_args_parser()
+    parser = argparse.ArgumentParser("Submitit for lavila fine-tuning", parents=[parser])
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=8, type=int, help="Number of nodes to request")
+    parser.add_argument("--timeout", default=2880, type=int, help="Duration of the job")
+    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+
+    parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
+    parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
+    parser.add_argument('--comment', default="", type=str,
+                        help='Comment to pass to scheduler, e.g. priority message')
+    return parser.parse_args()
+
+
+def get_shared_folder() -> Path:
+    user = os.getenv("USER")
+    if Path("/checkpoint/").is_dir():
+        p = Path(f"/checkpoint/{user}/experiments/lavila_ft")
+        p.mkdir(exist_ok=True)
+        return p
+    raise RuntimeError("No shared folder available")
+
+
+def get_init_file():
+    # Init file must not exist, but its parent dir must exist.
+    os.makedirs(str(get_shared_folder()), exist_ok=True)
+    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
+    if init_file.exists():
+        os.remove(str(init_file))
+    return init_file
+
+
+class Trainer(object):
+    def __init__(self, args):
+        self.args = args
+
+    def __call__(self):
+        import main_finetune_classification as main_finetune
+
+        self._setup_gpu_args()
+        main_finetune.main(self.args)
+
+    def checkpoint(self):
+        import submitit
+
+        self.args.dist_url = get_init_file().as_uri()
+        print("Requeuing ", self.args)
+        empty_trainer = type(self)(self.args)
+        return submitit.helpers.DelayedSubmission(empty_trainer)
+
+    def _setup_gpu_args(self):
+        import submitit
+        from pathlib import Path
+
+        job_env = submitit.JobEnvironment()
+        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
+        self.args.gpu = job_env.local_rank
+        self.args.rank = job_env.global_rank
+        self.args.world_size = job_env.num_tasks
+        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
+def main():
+    args = parse_args()
+    if args.job_dir == "":
+        args.job_dir = get_shared_folder() / "%j"
+
+    # Note that the folder will depend on the job_id, to easily track experiments
+    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+    num_gpus_per_node = args.ngpus
+    nodes = args.nodes
+    timeout_min = args.timeout
+
+    partition = args.partition
+    kwargs = {}
+    if args.use_volta32:
+        kwargs['slurm_constraint'] = 'volta32gb'
+    if args.comment:
+        kwargs['slurm_comment'] = args.comment
+
+    executor.update_parameters(
+        mem_gb=40 * num_gpus_per_node,
+        gpus_per_node=num_gpus_per_node,
+        tasks_per_node=num_gpus_per_node,  # one task per GPU
+        cpus_per_task=10,
+        nodes=nodes,
+        timeout_min=timeout_min,  # max is 60 * 72
+        # Below are cluster dependent parameters
+        slurm_partition=partition,
+        slurm_signal_delay_s=120,
+        **kwargs
+    )
+
+    executor.update_parameters(name="lavila_ft")
+
+    args.dist_url = get_init_file().as_uri()
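+    # output_dir still contains the "%j" placeholder here; _setup_gpu_args() replaces it with the real job id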
+    args.output_dir = args.job_dir
+
+    trainer = Trainer(args)
+    job = executor.submit(trainer)
+
+    print("Submitted job_id:", job.job_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/run_with_submitit_finetune_retrieval.py b/run_with_submitit_finetune_retrieval.py
new file mode 100644
index 0000000..04f40b6
--- /dev/null
+++ b/run_with_submitit_finetune_retrieval.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+A script to run multinode training with submitit.
+"""
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import main_finetune_retrieval as main_finetune
+import submitit
+
+
+def parse_args():
+    parser = main_finetune.get_args_parser()
+    parser = argparse.ArgumentParser("Submitit for lavila fine-tuning", parents=[parser])
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=8, type=int, help="Number of nodes to request")
+    parser.add_argument("--timeout", default=2880, type=int, help="Duration of the job")
+    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+
+    parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
+    parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
+    parser.add_argument('--comment', default="", type=str,
+                        help='Comment to pass to scheduler, e.g. priority message')
+    return parser.parse_args()
+
+
+def get_shared_folder() -> Path:
+    user = os.getenv("USER")
+    if Path("/checkpoint/").is_dir():
+        p = Path(f"/checkpoint/{user}/experiments/lavila_ft")
+        p.mkdir(exist_ok=True)
+        return p
+    raise RuntimeError("No shared folder available")
+
+
+def get_init_file():
+    # Init file must not exist, but its parent dir must exist.
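+    # a fresh uuid per call ensures a requeued job never reuses a stale rendezvous file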
+    os.makedirs(str(get_shared_folder()), exist_ok=True)
+    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
+    if init_file.exists():
+        os.remove(str(init_file))
+    return init_file
+
+
+class Trainer(object):
+    def __init__(self, args):
+        self.args = args
+
+    def __call__(self):
+        import main_finetune_retrieval as main_finetune
+
+        self._setup_gpu_args()
+        main_finetune.main(self.args)
+
+    def checkpoint(self):
+        import submitit
+
+        self.args.dist_url = get_init_file().as_uri()
+        print("Requeuing ", self.args)
+        empty_trainer = type(self)(self.args)
+        return submitit.helpers.DelayedSubmission(empty_trainer)
+
+    def _setup_gpu_args(self):
+        import submitit
+        from pathlib import Path
+
+        job_env = submitit.JobEnvironment()
+        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
+        self.args.gpu = job_env.local_rank
+        self.args.rank = job_env.global_rank
+        self.args.world_size = job_env.num_tasks
+        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
+def main():
+    args = parse_args()
+    if args.job_dir == "":
+        args.job_dir = get_shared_folder() / "%j"
+
+    # Note that the folder will depend on the job_id, to easily track experiments
+    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+    num_gpus_per_node = args.ngpus
+    nodes = args.nodes
+    timeout_min = args.timeout
+
+    partition = args.partition
+    kwargs = {}
+    if args.use_volta32:
+        kwargs['slurm_constraint'] = 'volta32gb'
+    if args.comment:
+        kwargs['slurm_comment'] = args.comment
+
+    executor.update_parameters(
+        mem_gb=40 * num_gpus_per_node,
+        gpus_per_node=num_gpus_per_node,
+        tasks_per_node=num_gpus_per_node,  # one task per GPU
+        cpus_per_task=10,
+        nodes=nodes,
+        timeout_min=timeout_min,  # max is 60 * 72
+        # Below are cluster dependent parameters
+        slurm_partition=partition,
+        slurm_signal_delay_s=120,
+        **kwargs
+    )
+
+    executor.update_parameters(name="lavila_ft")
+
+    args.dist_url = get_init_file().as_uri()
+    args.output_dir = args.job_dir
+
+    trainer = Trainer(args)
+    job = executor.submit(trainer)
+
+    print("Submitted job_id:", job.job_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/run_with_submitit_infer_narrator.py b/run_with_submitit_infer_narrator.py
new file mode 100644
index 0000000..16646e9
--- /dev/null
+++ b/run_with_submitit_infer_narrator.py
@@ -0,0 +1,127 @@
+
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+A script to run multinode training with submitit.
+"""
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import main_infer_narrator
+import submitit
+
+
+def parse_args():
+    parser = main_infer_narrator.get_args_parser()
+    parser = argparse.ArgumentParser("Submitit for inferring narrator", parents=[parser])
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=4, type=int, help="Number of nodes to request")
+    parser.add_argument("--timeout", default=2880, type=int, help="Duration of the job")
+    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+
+    parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
+    parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
Use this") + parser.add_argument('--comment', default="", type=str, + help='Comment to pass to scheduler, e.g. priority message') + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/checkpoint/").is_dir(): + p = Path(f"/checkpoint/{user}/experiments/extract_caption") + p.mkdir(exist_ok=True) + return p + raise RuntimeError("No shared folder available") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. + os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + import main_infer_narrator + + self._setup_gpu_args() + main_infer_narrator.main(self.args) + + def checkpoint(self): + import submitit + + self.args.dist_url = get_init_file().as_uri() + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + import submitit + from pathlib import Path + + job_env = submitit.JobEnvironment() + self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.job_dir == "": + args.job_dir = get_shared_folder() / "%j" + + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + + partition = args.partition + kwargs = {} + if args.use_volta32: + kwargs['slurm_constraint'] = 'volta32gb' + if args.comment: + kwargs['slurm_comment'] = args.comment + + executor.update_parameters( + mem_gb=55 * num_gpus_per_node, + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + slurm_partition=partition, + slurm_signal_delay_s=120, + **kwargs + ) + + executor.update_parameters(name="infer_narrator") + + args.dist_url = get_init_file().as_uri() + args.output_dir = args.job_dir + + trainer = Trainer(args) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id) + + +if __name__ == "__main__": + main() diff --git a/run_with_submitit_pretrain.py b/run_with_submitit_pretrain.py new file mode 100644 index 0000000..2a9013e --- /dev/null +++ b/run_with_submitit_pretrain.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +A script to run multinode training with submitit. 
+""" +import argparse +import os +import uuid +from pathlib import Path + +import main_pretrain +import submitit + + +def parse_args(): + parser = main_pretrain.get_args_parser() + parser = argparse.ArgumentParser("Submitit for lavila pre-training", parents=[parser]) + parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") + parser.add_argument("--nodes", default=8, type=int, help="Number of nodes to request") + parser.add_argument("--timeout", default=2880, type=int, help="Duration of the job") + parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") + + parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") + parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") + parser.add_argument('--comment', default="", type=str, + help='Comment to pass to scheduler, e.g. priority message') + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/checkpoint/").is_dir(): + p = Path(f"/checkpoint/{user}/experiments/lavila_pretrain") + p.mkdir(exist_ok=True) + return p + raise RuntimeError("No shared folder available") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. + os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + import main_pretrain + + self._setup_gpu_args() + main_pretrain.main(self.args) + + def checkpoint(self): + import submitit + + self.args.dist_url = get_init_file().as_uri() + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + import submitit + from pathlib import Path + + job_env = submitit.JobEnvironment() + self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.job_dir == "": + args.job_dir = get_shared_folder() / "%j" + + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + + partition = args.partition + kwargs = {} + if args.use_volta32: + kwargs['slurm_constraint'] = 'volta32gb' + if args.comment: + kwargs['slurm_comment'] = args.comment + + executor.update_parameters( + mem_gb=40 * num_gpus_per_node, + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + slurm_partition=partition, + slurm_signal_delay_s=120, + **kwargs + ) + + executor.update_parameters(name="lavila_pretrain") + + args.dist_url = get_init_file().as_uri() + args.output_dir = args.job_dir + + trainer = Trainer(args) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id) + + +if __name__ == "__main__": + main()