refactor codebase, add timesformer support, improve tests (#24)
* refactor codebase, add timesformer support, improve tests

* reformat

* reformat

* update workflow order

* fix a test, add styling script

* fix readme
fcakyon authored Dec 2, 2022
1 parent 8b85f91 commit 717528a
Showing 14 changed files with 263 additions and 51 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/ci.yml
@@ -64,9 +64,6 @@ jobs:
if: matrix.operating-system == 'macos-latest'
run: pip install torch==${{ matrix.torch-version }}

- name: Install Pytorchvideo from main branch
run: pip install git+https://github.com/facebookresearch/pytorchvideo.git

- name: Lint with flake8, black and isort
run: |
pip install .[dev]
@@ -77,6 +74,12 @@ jobs:
# exit-zero treats all errors as warnings. Allowed max line length is 120.
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics
- name: Install Pytorchvideo from main branch
run: pip install git+https://github.com/facebookresearch/pytorchvideo.git

- name: Install HF/Transformers from main branch
run: pip install -U git+https://github.com/huggingface/transformers.git

- name: Install video-transformers package from local setup.py
run: >
pip install .
3 changes: 3 additions & 0 deletions .github/workflows/package_testing.yml
@@ -66,6 +66,9 @@ jobs:
- name: Install Pytorchvideo from main branch
run: pip install git+https://github.com/facebookresearch/pytorchvideo.git

- name: Install HF/Transformers from main branch
run: pip install -U git+https://github.com/huggingface/transformers.git

- name: Install latest video-transformers package
run: >
pip install --upgrade --force-reinstall video-transformers[test]
73 changes: 66 additions & 7 deletions README.md
@@ -44,10 +44,11 @@ and supports:
conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch
```

- Install pytorchvideo from main branch:
- Install pytorchvideo and transformers from main branch:

```bash
pip install git+https://github.com/facebookresearch/pytorchvideo.git
pip install git+https://github.com/huggingface/transformers.git
```

- Install `video-transformers`:
@@ -83,7 +84,48 @@ val_root
...
```

- Fine-tune CVT (from HuggingFace) + Transformer based video classifier:
- Fine-tune Timesformer (from HuggingFace) video classifier:

```python
from torch.optim import AdamW
from video_transformers import VideoModel
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.trainer import trainer_factory
from video_transformers.utils.file import download_ucf6

backbone = TransformersBackbone("facebook/timesformer-base-finetuned-k400", num_unfrozen_stages=1)

download_ucf6("./")
datamodule = VideoDataModule(
    train_root="ucf6/train",
    val_root="ucf6/val",
    batch_size=4,
    num_workers=4,
    num_timesteps=8,
    preprocess_input_size=224,
    preprocess_clip_duration=1,
    preprocess_means=backbone.mean,
    preprocess_stds=backbone.std,
    preprocess_min_short_side=256,
    preprocess_max_short_side=320,
    preprocess_horizontal_flip_p=0.5,
)

head = LinearHead(hidden_size=backbone.num_features, num_classes=datamodule.num_classes)
model = VideoModel(backbone, head)

optimizer = AdamW(model.parameters(), lr=1e-4)

Trainer = trainer_factory("single_label_classification")
trainer = Trainer(datamodule, model, optimizer=optimizer, max_epochs=8)

trainer.fit()

```

- Fine-tune ConvNeXT (from HuggingFace) + Transformer based video classifier:

```python
from torch.optim import AdamW
@@ -95,7 +137,7 @@ from video_transformers.necks import TransformerNeck
from video_transformers.trainer import trainer_factory
from video_transformers.utils.file import download_ucf6

backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0))
backbone = TimeDistributed(TransformersBackbone("facebook/convnext-small-224", num_unfrozen_stages=1))
neck = TransformerNeck(
    num_features=backbone.num_features,
    num_timesteps=8,
@@ -137,18 +179,18 @@ trainer.fit()

```

- Fine-tune MobileViT (from Timm) + GRU based video classifier:
- Fine-tune Resnet18 (from HuggingFace) + GRU based video classifier:

```python
from video_transformers import TimeDistributed, VideoModel
from video_transformers.backbones.timm import TimmBackbone
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import GRUNeck
from video_transformers.trainer import trainer_factory
from video_transformers.utils.file import download_ucf6

backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0))
backbone = TimeDistributed(TransformersBackbone("microsoft/resnet-18", num_unfrozen_stages=1))
neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True)

download_ucf6("./")
@@ -188,7 +230,7 @@ from video_transformers import VideoModel

model = VideoModel.from_pretrained(model_name_or_path)

model.predict(video_path="video.mp4")
model.predict(video_or_folder_path="video.mp4")
>> [{'filename': "video.mp4", 'predictions': {'class1': 0.98, 'class2': 0.02}}]
```

@@ -277,3 +319,20 @@ from video_transformers import VideoModel
model = VideoModel.from_pretrained("runs/exp/checkpoint")
model.to_gradio(examples=['video.mp4'], export_dir="runs/exports/", export_filename="app.py")
```


## Contributing

Before opening a PR:

- Install required development packages:

```bash
pip install -e ."[dev]"
```

- Reformat with black and isort:

```bash
python -m tests.run_code_style format
```
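
- To check formatting without modifying files, the same script (see `tests/run_code_style.py`, added in this commit) also accepts a `check` argument:

```bash
python -m tests.run_code_style check
```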
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
accelerate>=0.14.0,<0.15.0
evaluate>=0.3.0,<0.4.0
transformers>=4.24.0,<4.25.0
transformers>=4.25.0
timm>=0.6.12,<0.7.0
click==8.0.4
balanced-loss
16 changes: 16 additions & 0 deletions tests/run_code_style.py
@@ -0,0 +1,16 @@
import sys

from tests.utils import shell, validate_and_exit

if __name__ == "__main__":
    arg = sys.argv[1]

    if arg == "check":
        sts_flake = shell("flake8 . --config setup.cfg --select=E9,F63,F7,F82")
        sts_isort = shell("isort . --check --settings pyproject.toml")
        sts_black = shell("black . --check --config pyproject.toml")
        validate_and_exit(flake8=sts_flake, isort=sts_isort, black=sts_black)
    elif arg == "format":
        sts_isort = shell("isort . --settings pyproject.toml")
        sts_black = shell("black . --config pyproject.toml")
        validate_and_exit(isort=sts_isort, black=sts_black)
32 changes: 15 additions & 17 deletions tests/test_auto_backbone.py
@@ -8,9 +8,9 @@ def test_transformers_backbone(self):
from video_transformers import AutoBackbone

config = {
"framework": {"name": "timm"},
"framework": {"name": "transformers"},
"type": "2d_backbone",
"model_name": "mobilevitv2_100",
"model_name": "microsoft/resnet-18",
"num_timesteps": 8,
}
batch_size = 2
@@ -20,23 +20,21 @@ def test_transformers_backbone(self):
output = backbone(input)
self.assertEqual(output.shape, (batch_size, config["num_timesteps"], backbone.num_features))

def test_timm_backbone(self):
import torch

def test_from_transformers(self):
from video_transformers import AutoBackbone

config = {
"framework": {"name": "transformers"},
"type": "2d_backbone",
"model_name": "microsoft/cvt-13",
"num_timesteps": 8,
}
batch_size = 2

backbone = AutoBackbone.from_config(config)
input = torch.randn(batch_size, 3, config["num_timesteps"], 224, 224)
output = backbone(input)
self.assertEqual(output.shape, (batch_size, config["num_timesteps"], backbone.num_features))
backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k400")
assert backbone.model_name == "facebook/timesformer-base-finetuned-k400"
backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k600")
assert backbone.model_name == "facebook/timesformer-base-finetuned-k600"
backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-k400")
assert backbone.model_name == "facebook/timesformer-hr-finetuned-k400"
backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-k600")
assert backbone.model_name == "facebook/timesformer-hr-finetuned-k600"
backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-ssv2")
assert backbone.model_name == "facebook/timesformer-base-finetuned-ssv2"
backbone = AutoBackbone.from_transformers("facebook/timesformer-hr-finetuned-ssv2")
assert backbone.model_name == "facebook/timesformer-hr-finetuned-ssv2"


if __name__ == "__main__":
18 changes: 17 additions & 1 deletion tests/test_auto_head.py
@@ -2,7 +2,7 @@


class TestAutoHead(unittest.TestCase):
def test_liear_head(self):
def test_linear_head(self):
import torch

from video_transformers import AutoHead
@@ -20,6 +20,22 @@ def test_liear_head(self):
output = head(input)
self.assertEqual(output.shape, (batch_size, config["num_classes"]))

def test_from_transformers(self):
from video_transformers import AutoHead

linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-k400")
assert linear_head.num_classes == 400
linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-k600")
assert linear_head.num_classes == 600
linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-k400")
assert linear_head.num_classes == 400
linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-k600")
assert linear_head.num_classes == 600
linear_head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-ssv2")
assert linear_head.num_classes == 174
linear_head = AutoHead.from_transformers("facebook/timesformer-hr-finetuned-ssv2")
assert linear_head.num_classes == 174


if __name__ == "__main__":
unittest.main()
41 changes: 41 additions & 0 deletions tests/utils.py
@@ -0,0 +1,41 @@
import os
import shutil
import sys


def shell(command, exit_status=0):
    """
    Run a command through the system shell.
    Args:
        command: (str) Command string to run through the system shell.
        exit_status: (int) Expected exit status of the command.
    Returns: 0 if the actual exit status matches the expected one, otherwise the actual exit status.
    """
    actual_exit_status = os.system(command)
    if actual_exit_status == exit_status:
        return 0
    return actual_exit_status


def validate_and_exit(expected_out_status=0, **kwargs):
    if all([arg == expected_out_status for arg in kwargs.values()]):
        # Expected status, OK
        sys.exit(0)
    else:
        # Failure
        print_console_centered("Summary Results")
        fail_count = 0
        for component, exit_status in kwargs.items():
            if exit_status != expected_out_status:
                print(f"{component} failed.")
                fail_count += 1
        print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
        sys.exit(1)


def print_console_centered(text: str, fill_char="="):
    w, _ = shutil.get_terminal_size((80, 20))
    print(f" {text} ".center(w, fill_char))
2 changes: 1 addition & 1 deletion video_transformers/__init__.py
@@ -3,4 +3,4 @@
from video_transformers.auto.neck import AutoNeck
from video_transformers.modeling import TimeDistributed, VideoModel

__version__ = "0.0.7"
__version__ = "0.0.8"
21 changes: 12 additions & 9 deletions video_transformers/auto/backbone.py
@@ -15,19 +15,22 @@ def from_config(cls, config: Dict) -> Union[Backbone, TimeDistributed]:
backbone_type = config.get("type")
backbone_model_name = config.get("model_name")

if backbone_framework["name"] == "transformers":
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.backbones.transformers import TransformersBackbone

backbone = TransformersBackbone(model_name=backbone_model_name)
elif backbone_framework["name"] == "timm":
from video_transformers.backbones.timm import TimmBackbone

backbone = TimmBackbone(model_name=backbone_model_name)
else:
raise ValueError(f"Unknown framework {backbone_framework}")
backbone = TransformersBackbone(model_name=backbone_model_name)

if backbone_type == "2d_backbone":
from video_transformers.modeling import TimeDistributed

backbone = TimeDistributed(backbone)
return backbone

@classmethod
def from_transformers(cls, name_or_path: str) -> Union[Backbone, TimeDistributed]:
from video_transformers.backbones.transformers import TransformersBackbone

backbone = TransformersBackbone(model_name=name_or_path)

if backbone.type == "2d_backbone":
raise ValueError("2D backbones are not supported for from_transformers method.")
return backbone
12 changes: 12 additions & 0 deletions video_transformers/auto/head.py
@@ -18,3 +18,15 @@ def from_config(cls, config: Dict):
            return LinearHead(hidden_size, num_classes, dropout_p)
        else:
            raise ValueError(f"Unsupported head class name: {head_class_name}")

    @classmethod
    def from_transformers(cls, name_or_path: str):
        from transformers import AutoModelForVideoClassification

        from video_transformers.heads import LinearHead

        model = AutoModelForVideoClassification.from_pretrained(name_or_path)
        linear_head = LinearHead(model.classifier.in_features, model.classifier.out_features)
        linear_head.linear.weight = model.classifier.weight
        linear_head.linear.bias = model.classifier.bias
        return linear_head
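
A minimal usage sketch, assuming the `from_transformers` classmethods added in this commit and the `VideoModel(backbone, head)` constructor shown in the README:

```python
from video_transformers import AutoBackbone, AutoHead, VideoModel

# Load a Timesformer backbone and its matching classification head
# directly from a HuggingFace checkpoint.
backbone = AutoBackbone.from_transformers("facebook/timesformer-base-finetuned-k400")
head = AutoHead.from_transformers("facebook/timesformer-base-finetuned-k400")

# Combine them into a single video classification model, as in the README examples.
model = VideoModel(backbone, head)
```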