Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[feat] Update VideoClassifier (#276)
Browse files Browse the repository at this point in the history
* update

* update

* update
  • Loading branch information
tchaton authored May 11, 2021
1 parent d5ca5e5 commit e24aa62
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 37 deletions.
2 changes: 1 addition & 1 deletion flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.__about__ import * # noqa: F401 F403

_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
Expand Down
4 changes: 2 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ def available_models(cls) -> List[str]:
return registry.available_keys()

@classmethod
def get_model_details(cls, key) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "models", None)
def get_backbone_details(cls, key) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
if registry is None:
return []
return [v for v in inspect.signature(registry.get(key)).parameters.items()]
Expand Down
44 changes: 24 additions & 20 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@
from torch.utils.data import DistributedSampler
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.core.classification import ClassificationTask, Labels
from flash.core.registry import FlashRegistry
from flash.data.process import Serializer
from flash.utils.imports import _PYTORCHVIDEO_AVAILABLE

_VIDEO_CLASSIFIER_MODELS = FlashRegistry("backbones")
_VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones")

if _PYTORCHVIDEO_AVAILABLE:
from pytorchvideo.models import hub
for fn_name in dir(hub):
if "__" not in fn_name:
fn = getattr(hub, fn_name)
if isinstance(fn, FunctionType):
_VIDEO_CLASSIFIER_MODELS(fn=fn)
_VIDEO_CLASSIFIER_BACKBONES(fn=fn)


class VideoClassifierFinetuning(BaseFinetuning):
Expand All @@ -49,7 +50,7 @@ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: i
self.unfreeze_epoch = unfreeze_epoch

def freeze_before_training(self, pl_module: LightningModule) -> None:
self.freeze(modules=list(pl_module.model.children())[:-self.num_layers], train_bn=self.train_bn)
self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn)

def finetune_function(
self,
Expand All @@ -61,7 +62,7 @@ def finetune_function(
if epoch != self.unfreeze_epoch:
return
self.unfreeze_and_add_param_group(
modules=list(pl_module.model.children())[-self.num_layers:],
modules=list(pl_module.backbone.children())[-self.num_layers:],
optimizer=optimizer,
train_bn=self.train_bn,
)
Expand All @@ -72,7 +73,8 @@ class VideoClassifier(ClassificationTask):
Args:
num_classes: Number of classes to classify.
model: A string mapped to ``pytorch_video`` models or ``nn.Module``, defaults to ``"slowfast_r50"``.
backbone: A string mapped to ``pytorch_video`` backbones or ``nn.Module``, defaults to ``"slow_r50"``.
backbone_kwargs: Arguments to customize the backbone from PyTorchVideo.
pretrained: Use a pretrained backbone, defaults to ``True``.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
Expand All @@ -81,43 +83,45 @@ class VideoClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
"""

models: FlashRegistry = _VIDEO_CLASSIFIER_MODELS
backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES

def __init__(
self,
num_classes: int,
model: Union[str, nn.Module] = "slow_r50",
model_kwargs: Optional[Dict] = None,
backbone: Union[str, nn.Module] = "slow_r50",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
learning_rate: float = 1e-3,
head: Optional[Union[FunctionType, nn.Module]] = None,
serializer: Optional[Serializer] = None,
):
super().__init__(
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
serializer=serializer or Labels()
)

self.save_hyperparameters()

if not model_kwargs:
model_kwargs = {}
if not backbone_kwargs:
backbone_kwargs = {}

model_kwargs["pretrained"] = pretrained
model_kwargs["head_activation"] = None
backbone_kwargs["pretrained"] = pretrained
backbone_kwargs["head_activation"] = None

if isinstance(model, nn.Module):
self.model = model
elif isinstance(model, str):
self.model = self.models.get(model)(**model_kwargs)
num_features = self.model.blocks[-1].proj.out_features
if isinstance(backbone, nn.Module):
self.backbone = backbone
elif isinstance(backbone, str):
self.backbone = self.backbones.get(backbone)(**backbone_kwargs)
num_features = self.backbone.blocks[-1].proj.out_features
else:
raise MisconfigurationException(f"model should be either a string or a nn.Module. Found: {model}")
raise MisconfigurationException(f"backbone should be either a string or a nn.Module. Found: {backbone}")

self.head = head or nn.Sequential(
nn.Flatten(),
Expand All @@ -140,7 +144,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
return super().step((batch["video"], batch["label"]), batch_idx)

def forward(self, x: Any) -> Any:
x = self.model(x)
x = self.backbone(x)
if self.head is not None:
x = self.head(x)
return x
Expand Down
23 changes: 11 additions & 12 deletions flash_examples/finetuning/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@

if __name__ == '__main__':

_PATH_ROOT = os.path.dirname(os.path.abspath(__file__))

# 1. Download a video clip dataset. Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

Expand Down Expand Up @@ -73,9 +71,9 @@ def make_transform(

# 3. Load the data from directories.
datamodule = VideoClassificationData.from_folders(
train_folder=os.path.join(_PATH_ROOT, "data/kinetics/train"),
val_folder=os.path.join(_PATH_ROOT, "data/kinetics/val"),
predict_folder=os.path.join(_PATH_ROOT, "data/kinetics/predict"),
train_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/train"),
val_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/val"),
predict_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict"),
train_transform=make_transform(train_post_tensor_transform),
val_transform=make_transform(val_post_tensor_transform),
predict_transform=make_transform(val_post_tensor_transform),
Expand All @@ -87,20 +85,21 @@ def make_transform(
)

# 4. List the available backbones
print(VideoClassifier.available_models())
print(VideoClassifier.available_backbones())
# out: ['efficient_x3d_s', 'efficient_x3d_xs', ..., 'slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs']
print(VideoClassifier.get_model_details("x3d_xs"))
print(VideoClassifier.get_backbone_details("x3d_xs"))

# 5. Build the model - `x3d_xs` comes with `nn.Softmax` by default for their `head_activation`.
model = VideoClassifier(model="x3d_xs", num_classes=datamodule.num_classes)
model.serializer = Labels()
# 5. Build the VideoClassifier with a PyTorchVideo backbone.
model = VideoClassifier(
backbone="x3d_xs", num_classes=datamodule.num_classes, serializer=Labels(), pretrained=False
)

# 6. Finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer = flash.Trainer(fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze())

trainer.save_checkpoint("video_classification.pt")

# 7. Make a prediction
predictions = model.predict(os.path.join(_PATH_ROOT, "data/kinetics/predict"))
predictions = model.predict(os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict"))
print(predictions)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
torch>=1.7 # TODO: regenerate weights with lower PT version
torchmetrics
torchvision==0.8 # TODO: lower to 0.7 after PT 1.6
pytorch-lightning>=1.3.0rc1
pytorch-lightning>=1.3.1
lightning-bolts>=0.3.3
PyYAML>=5.1
Pillow>=7.2
Expand Down
2 changes: 1 addition & 1 deletion tests/video/test_video_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_image_classifier_finetune(tmpdir):
expected_t_shape = 5
assert sample["video"].shape[1] == expected_t_shape

assert len(VideoClassifier.available_models()) > 5
assert len(VideoClassifier.available_backbones()) > 5

train_transform = {
"post_tensor_transform": Compose([
Expand Down

0 comments on commit e24aa62

Please sign in to comment.