diff --git a/README.md b/README.md index 76c3679..03a0335 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ datamodule = VideoDataModule( num_workers=4, num_timesteps=8, preprocess_input_size=224, + preprocess_clip_duration=1, preprocess_means=backbone.mean, preprocess_stds=backbone.std, preprocess_min_short_side_scale=256, @@ -119,7 +120,8 @@ Trainer = trainer_factory("single_label_classification") trainer = Trainer( datamodule, model, - optimizer=optimizer + optimizer=optimizer, + max_epochs=8 ) trainer.fit() @@ -142,12 +144,12 @@ neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2 datamodule = VideoDataModule( train_root=".../ucf6/train", val_root=".../ucf6/val", - clip_duration=2, train_dataset_multiplier=1, batch_size=4, num_workers=4, num_timesteps=8, preprocess_input_size=224, + preprocess_clip_duration=1, preprocess_means=backbone.mean, preprocess_stds=backbone.std, preprocess_min_short_side_scale=256, @@ -162,6 +164,7 @@ Trainer = trainer_factory("single_label_classification") trainer = Trainer( datamodule, model, + max_epochs=8 ) trainer.fit() diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 827d4d4..4252292 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -28,12 +28,15 @@ def test_onnx_export(self): "transformer_enc_num_layers": 2, "return_mean": True, }, - "preprocess_means": [0.485, 0.456, 0.406], - "preprocess_stds": [0.229, 0.224, 0.225], - "preprocess_min_short_side_scale": 256, - "preprocess_input_size": 224, - "num_timesteps": 8, + "preprocessor": { + "means": [0.485, 0.456, 0.406], + "stds": [0.229, 0.224, 0.225], + "min_short_side": 256, + "input_size": 224, + "num_timesteps": 8, + }, "labels": ["BodyWeightSquats", "JumpRope", "Lunges", "PullUps", "PushUps", "WallPushups"], + "task": "single_label_classification", } model = VideoClassificationModel.from_config(config) @@ -66,12 +69,15 @@ def test_quantized_onnx_export(self): "transformer_enc_num_layers": 2, "return_mean": True, }, - "preprocess_means": [0.485, 0.456, 0.406], - "preprocess_stds": [0.229, 0.224, 0.225], - "preprocess_min_short_side_scale": 256, - "preprocess_input_size": 224, - "num_timesteps": 8, + "preprocessor": { + "means": [0.485, 0.456, 0.406], + "stds": [0.229, 0.224, 0.225], + "min_short_side": 256, + "input_size": 224, + "num_timesteps": 8, + }, "labels": ["BodyWeightSquats", "JumpRope", "Lunges", "PullUps", "PushUps", "WallPushups"], + "task": "single_label_classification", } model = VideoClassificationModel.from_config(config) diff --git a/tests/test_video_classification_model.py b/tests/test_video_classification_model.py index 78cc71a..5d834f8 100644 --- a/tests/test_video_classification_model.py +++ b/tests/test_video_classification_model.py @@ -30,18 +30,21 @@ def test_transformers_backbone(self): "transformer_enc_num_layers": 2, "return_mean": True, }, - "preprocess_means": [0.485, 0.456, 0.406], - "preprocess_tds": [0.229, 0.224, 0.225], - "preprocess_min_short_side_scale": 256, - "preprocess_input_size": 224, - "num_timesteps": 8, + "preprocessor": { + "means": [0.485, 0.456, 0.406], + "stds": [0.229, 0.224, 0.225], + "min_short_side": 256, + "input_size": 224, + "num_timesteps": 8, + }, "labels": ["BodyWeightSquats", "JumpRope", "Lunges", "PullUps", "PushUps", "WallPushups"], + "task": "single_label_classification", } batch_size = 2 model = VideoClassificationModel.from_config(config) - input = torch.randn(batch_size, 3, config["num_timesteps"], 224, 224) + input = torch.randn(batch_size, 3, 
config["preprocessor"]["num_timesteps"], 224, 224) output = model(input) self.assertEqual(output.shape, (batch_size, model.head.num_classes)) diff --git a/video_transformers/__init__.py b/video_transformers/__init__.py index e7813f3..2ada70c 100644 --- a/video_transformers/__init__.py +++ b/video_transformers/__init__.py @@ -1,6 +1,6 @@ from video_transformers.auto.backbone import AutoBackbone from video_transformers.auto.head import AutoHead from video_transformers.auto.neck import AutoNeck -from video_transformers.modules import TimeDistributed, VideoClassificationModel +from video_transformers.modeling import TimeDistributed, VideoClassificationModel -__version__ = "0.0.5" +__version__ = "0.0.6" diff --git a/video_transformers/auto/backbone.py b/video_transformers/auto/backbone.py index 37436bd..50490d2 100644 --- a/video_transformers/auto/backbone.py +++ b/video_transformers/auto/backbone.py @@ -1,7 +1,7 @@ from typing import Dict, Union from video_transformers.backbones.base import Backbone -from video_transformers.modules import TimeDistributed +from video_transformers.modeling import TimeDistributed class AutoBackbone: @@ -27,7 +27,7 @@ def from_config(cls, config: Dict) -> Union[Backbone, TimeDistributed]: raise ValueError(f"Unknown framework {backbone_framework}") if backbone_type == "2d_backbone": - from video_transformers.modules import TimeDistributed + from video_transformers.modeling import TimeDistributed backbone = TimeDistributed(backbone) return backbone diff --git a/video_transformers/backbones/timm.py b/video_transformers/backbones/timm.py index feb23f5..45eba1a 100644 --- a/video_transformers/backbones/timm.py +++ b/video_transformers/backbones/timm.py @@ -3,7 +3,7 @@ from torch import nn from video_transformers.backbones.base import Backbone -from video_transformers.modules import Identity +from video_transformers.modeling import Identity from video_transformers.utils.torch import unfreeze_last_n_stages as unfreeze_last_n_stages_torch diff --git a/video_transformers/data.py b/video_transformers/data.py index a240680..f6c03a5 100644 --- a/video_transformers/data.py +++ b/video_transformers/data.py @@ -14,15 +14,15 @@ from torch.utils.data import DataLoader from torchvision.transforms import CenterCrop, Compose, Lambda, RandomCrop, RandomHorizontalFlip -from video_transformers.utils.dataset import LabeledVideoDataset, LabeledVideoPaths +from video_transformers.pytorchvideo_wrapper.data.labeled_video_paths import LabeledVideoDataset, LabeledVideoPaths from video_transformers.utils.extra import class_to_config logger = get_logger(__name__) -class VideoPreprocess: +class VideoPreprocessor: @classmethod - def from_config(cls, config: Dict, **kwargs) -> "VideoPreprocess": + def from_config(cls, config: Dict, **kwargs) -> "VideoPreprocessor": """ Creates an instance of the class from a config. @@ -36,25 +36,27 @@ def from_config(cls, config: Dict, **kwargs) -> "VideoPreprocess": def __init__( self, - timesteps: int = 8, + num_timesteps: int = 8, input_size: int = 224, means: Tuple[float] = (0.45, 0.45, 0.45), stds: Tuple[float] = (0.225, 0.225, 0.225), - min_short_side_scale: int = 256, - max_short_side_scale: int = 320, + min_short_side: int = 256, + max_short_side: int = 320, horizontal_flip_p: float = 0.5, + clip_duration: int = 1, ): """ Creates preprocess transforms. 
Args: - timesteps: number of frames in a video clip + num_timesteps: number of frames in a video clip input_size: model input isze means: mean of the video clip stds: standard deviation of the video clip min_short_side_scale: minimum short side of the video clip max_short_side_scale: maximum short side of the video clip horizontal_flip_p: probability of horizontal flip + clip_duration: duration of each video clip Properties: train_transform: transforms for training @@ -65,23 +67,24 @@ def __init__( """ super().__init__() - self.timesteps = timesteps + self.num_timesteps = num_timesteps self.input_size = input_size self.means = means self.stds = stds - self.min_short_side_scale = min_short_side_scale - self.max_short_side_scale = max_short_side_scale + self.min_short_side = min_short_side + self.max_short_side = max_short_side self.horizontal_flip_p = horizontal_flip_p + self.clip_duration = clip_duration # Transforms applied to train dataset. self.train_video_transform = Compose( [ - UniformTemporalSubsample(self.timesteps), + UniformTemporalSubsample(self.num_timesteps), Lambda(lambda x: x / 255.0), Normalize(self.means, self.stds), RandomShortSideScale( - min_size=self.min_short_side_scale, - max_size=self.max_short_side_scale, + min_size=self.min_short_side, + max_size=self.max_short_side, ), RandomCrop(self.input_size), RandomHorizontalFlip(p=self.horizontal_flip_p), @@ -93,10 +96,10 @@ def __init__( # Transforms applied on val dataset or for inference. self.val_video_transform = Compose( [ - UniformTemporalSubsample(self.timesteps), + UniformTemporalSubsample(self.num_timesteps), Lambda(lambda x: x / 255.0), Normalize(self.means, self.stds), - ShortSideScale(self.min_short_side_scale), + ShortSideScale(self.min_short_side), CenterCrop(self.input_size), ] ) @@ -109,16 +112,16 @@ def __init__( train_root: str, val_root: str, test_root: str = None, - clip_duration: int = 2, train_dataset_multiplier: int = 1, batch_size: int = 4, num_workers: int = 4, num_timesteps: int = 8, preprocess_input_size: int = 224, + preprocess_clip_duration: int = 1, preprocess_means: Tuple[float] = (0.45, 0.45, 0.45), preprocess_stds: Tuple[float] = (0.225, 0.225, 0.225), - preprocess_min_short_side_scale: int = 256, - preprocess_max_short_side_scale: int = 320, + preprocess_min_short_side: int = 256, + preprocess_max_short_side: int = 320, preprocess_horizontal_flip_p: float = 0.5, ): """ @@ -169,25 +172,26 @@ def __init__( Mean pixel value to be used during normalization. preprocess_stds: Tuple[float] Standard deviation pixel value to be used during normalization. - preprocess_min_short_side_scale: int + preprocess_min_short_side: int Minimum value of the short side of the clip after resizing. - preprocess_max_short_side_scale: int + preprocess_max_short_side: int Maximum value of the short side of the clip after resizing. preprocess_horizontal_flip_p: float Probability of horizontal flip. 
""" - self.preprocess_config = { - "timesteps": num_timesteps, + self.preprocessor_config = { + "num_timesteps": num_timesteps, "input_size": preprocess_input_size, "means": preprocess_means, "stds": preprocess_stds, - "min_short_side_scale": preprocess_min_short_side_scale, - "max_short_side_scale": preprocess_max_short_side_scale, + "min_short_side": preprocess_min_short_side, + "max_short_side": preprocess_max_short_side, "horizontal_flip_p": preprocess_horizontal_flip_p, + "clip_duration": preprocess_clip_duration, } - self.preprocess = VideoPreprocess.from_config(self.preprocess_config) + self.preprocessor = VideoPreprocessor.from_config(self.preprocessor_config) - self.dataloader_config = {"batch_size": batch_size, "num_workers": num_workers, "clip_duration": clip_duration} + self.dataloader_config = {"batch_size": batch_size, "num_workers": num_workers} self.train_root = train_root self.val_root = val_root @@ -211,12 +215,12 @@ def _get_train_dataloader(self): labeled_video_paths = LabeledVideoPaths.from_path(self.train_root) labeled_video_paths.path_prefix = "" video_sampler = torch.utils.data.RandomSampler - clip_sampler = pytorchvideo.data.make_clip_sampler("random", self.dataloader_config["clip_duration"]) + clip_sampler = pytorchvideo.data.make_clip_sampler("random", self.preprocessor_config["clip_duration"]) dataset = LabeledVideoDataset( labeled_video_paths, clip_sampler, video_sampler, - self.preprocess.train_transform, + self.preprocessor.train_transform, decode_audio=False, decoder="pyav", dataset_multiplier=self.train_dataset_multiplier, @@ -233,12 +237,15 @@ def _get_val_dataloader(self): labeled_video_paths = LabeledVideoPaths.from_path(self.val_root) labeled_video_paths.path_prefix = "" video_sampler = torch.utils.data.SequentialSampler - clip_sampler = pytorchvideo.data.make_clip_sampler("uniform", self.dataloader_config["clip_duration"]) + clip_sampler = pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart( + clip_duration=self.preprocessor_config["clip_duration"], + truncation_duration=self.preprocessor_config["clip_duration"], + ) dataset = LabeledVideoDataset( labeled_video_paths, clip_sampler, video_sampler, - self.preprocess.val_transform, + self.preprocessor.val_transform, decode_audio=False, decoder="pyav", ) @@ -253,12 +260,15 @@ def _get_test_dataloader(self): labeled_video_paths = LabeledVideoPaths.from_path(self.test_root) labeled_video_paths.path_prefix = "" video_sampler = torch.utils.data.SequentialSampler - clip_sampler = pytorchvideo.data.make_clip_sampler("uniform", self.dataloader_config["clip_duration"]) + clip_sampler = pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart( + clip_duration=self.preprocessor_config["clip_duration"], + truncation_duration=self.preprocessor_config["clip_duration"], + ) dataset = LabeledVideoDataset( labeled_video_paths, clip_sampler, video_sampler, - self.preprocess.val_transform, + self.preprocessor.val_transform, decode_audio=False, decoder="pyav", ) diff --git a/video_transformers/deployment/__init__.py b/video_transformers/deployment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/video_transformers/deployment/onnx.py b/video_transformers/deployment/onnx.py index 25a6eb4..0e97597 100644 --- a/video_transformers/deployment/onnx.py +++ b/video_transformers/deployment/onnx.py @@ -45,7 +45,7 @@ def export( from onnxruntime.quantization import quantize_dynamic - export_filename = Path(export_path).stem + f"_quantize.{Path(export_path).suffix}" + export_filename = 
Path(export_path).stem + f"_quantized.{Path(export_path).suffix}" target_model_path = Path(export_path).parent / export_filename diff --git a/video_transformers/modules.py b/video_transformers/modeling.py similarity index 86% rename from video_transformers/modules.py rename to video_transformers/modeling.py index 87f1a8e..f14be06 100644 --- a/video_transformers/modules.py +++ b/video_transformers/modeling.py @@ -1,6 +1,7 @@ import json import os -from typing import Dict, List, Union +from pathlib import Path +from typing import Dict, List, Tuple, Union import torch from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME @@ -9,7 +10,7 @@ from torch import nn import video_transformers.backbones.base -from video_transformers.deployment.onnx import export +import video_transformers.deployment.onnx from video_transformers.heads import LinearHead from video_transformers.utils.torch import get_num_total_params, get_num_trainable_params @@ -129,9 +130,9 @@ def from_config(cls, config: Dict) -> "VideoClassificationModel": backbone=backbone, head=head, neck=neck, - timesteps=config["num_timesteps"], - input_size=config["preprocess_input_size"], + preprocessor_config=config["preprocessor"], labels=config["labels"], + task=config["task"], ) @classmethod @@ -180,29 +181,30 @@ def __init__( backbone: Union[TimeDistributed, video_transformers.backbones.base.Backbone], head: LinearHead, neck=None, - timesteps: int = None, - input_size: int = None, labels: List[str] = None, + preprocessor_config: Dict = None, + task: str = None, ): """ Args: backbone: Backbone model. head: Head model. neck: Neck model. - timesteps: Number of input timesteps (required for onnx export). - input_size: Input size of model (required for onnx export). - labels: List of labels (required for onnx export). + labels: List of labels (required for onnx export and predict). + preprocessor_config: Preprocessor config (required for onnx export and predict). + task: Task name (required for predict). """ super().__init__() self.backbone = backbone self.neck = neck self.head = head - # required for exporting to ONNX - self.timesteps = timesteps - self.input_size = input_size + # required for exporting to ONNX and predict + self.preprocessor_config = preprocessor_config self.labels = labels + self.task = task + @property def num_features(self): return self.backbone.num_features @@ -238,6 +240,7 @@ def num_total_params(self): @property def config(self): config = {} + config["task"] = self.task config["backbone"] = self.backbone.config config["head"] = self.head.config if self.neck is not None: @@ -261,12 +264,18 @@ def to_onnx( export_filename: Filename to export model to. 
""" - export(self, quantize, opset_version, export_dir, export_filename) + video_transformers.deployment.onnx.export(self, quantize, opset_version, export_dir, export_filename) @property def example_input_array(self): - if self.timesteps and self.input_size: - return torch.rand(1, 3, self.timesteps, self.input_size, self.input_size) + if self.preprocessor_config: + return torch.rand( + 1, + 3, + self.preprocessor_config["num_timesteps"], + self.preprocessor_config["input_size"], + self.preprocessor_config["input_size"], + ) else: return None diff --git a/video_transformers/utils/dataset.py b/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py similarity index 77% rename from video_transformers/utils/dataset.py rename to video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py index ba8ff21..8415c07 100644 --- a/video_transformers/utils/dataset.py +++ b/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py @@ -8,22 +8,83 @@ import pathlib import zipfile from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast import torch from iopath.common.file_io import g_pathmgr from pytorchvideo.data.clip_sampling import ClipSampler from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset as LabeledVideoDataset_ from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths as LabeledVideoPaths_ -from pytorchvideo.data.labeled_video_paths import make_dataset_from_video_folders from pytorchvideo.data.video import VideoPathHandler -from torchvision.datasets.folder import make_dataset +from torchvision.datasets.folder import find_classes, has_file_allowed_extension, make_dataset from video_transformers.utils.file import download_file logger = logging.getLogger(__name__) +def make_dataset_from_video_folders( + directory: str, + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. 
+ """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + elif not class_to_idx: + raise ValueError("'class_to_index' must have at least one entry to collect any samples.") + + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + + if extensions is not None: + + def is_valid_folder(x: str) -> bool: + if g_pathmgr.ls(x): + return has_file_allowed_extension(g_pathmgr.ls(x)[0], extensions) + else: + return False + + is_valid_file = cast(Callable[[str], bool], is_valid_file) + + instances = [] + available_classes = set() + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, fnames, _ in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_folder(path): + item = path, class_index + instances.append(item) + + if target_class not in available_classes: + available_classes.add(target_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes: + msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + class LabeledVideoPaths(LabeledVideoPaths_): """ LabeledVideoPaths contains pairs of video path and integer index label. @@ -42,8 +103,10 @@ def from_path(cls, data_path: str) -> LabeledVideoPaths: file_path (str): The path to the file to be read. """ - if g_pathmgr.isfile(data_path): + if g_pathmgr.isfile(data_path) and data_path.endswith(".csv"): return LabeledVideoPaths.from_csv(data_path) + elif g_pathmgr.isfile(data_path) and has_file_allowed_extension(data_path, extensions=("mp4", "avi")): + return LabeledVideoPaths.from_video_path(data_path) elif g_pathmgr.isdir(data_path): class_0 = g_pathmgr.ls(data_path)[0] video_0 = g_pathmgr.ls(pathlib.Path(data_path) / class_0)[0] @@ -55,6 +118,15 @@ def from_path(cls, data_path: str) -> LabeledVideoPaths: else: raise FileNotFoundError(f"{data_path} not found.") + @classmethod + def from_video_path(cls, video_path: str) -> LabeledVideoPaths: + """ + Creates a LabeledVideoPaths object from a single video path. + Args: + video_path (str): The path to the video. 
+ """ + return LabeledVideoPaths([(video_path, -1)]) + @classmethod def from_directory(cls, dir_path: str) -> LabeledVideoPaths: """ @@ -188,6 +260,8 @@ def __init__( self._last_clip_end_time = None self.video_path_handler = VideoPathHandler() + self._len = None + @property def labels(self): """ @@ -206,12 +280,18 @@ def videos_per_class(self): return [class_id_to_number[class_id] for class_id in range(max(class_ids) + 1)] def __len__(self): + if self._len is not None: + return self._len + if isinstance(self.video_sampler, torch.utils.data.SequentialSampler): - return sum([1 for _ in self]) + # self._len = sum([1 for _ in self]) + self._len = len(self.video_sampler) + return self._len elif isinstance(self.video_sampler, torch.utils.data.RandomSampler): - return len(self.video_sampler) + self._len = len(self.video_sampler) + return self._len else: - raise ValueError(f"Lenght calculation not implemented for sampler: {type(self.video_sampler)}.") + raise ValueError(f"Length calculation not implemented for sampler: {type(self.video_sampler)}.") def labeled_video_dataset( diff --git a/video_transformers/tasks/base.py b/video_transformers/tasks/base.py index 6e2da95..c92b64a 100644 --- a/video_transformers/tasks/base.py +++ b/video_transformers/tasks/base.py @@ -10,3 +10,19 @@ def validation_step(self, batch): def on_validation_epoch_end(self): raise NotImplementedError() + + @property + def train_metrics(self): + raise NotImplementedError() + + @property + def val_metrics(self): + raise NotImplementedError() + + @property + def last_train_result(self): + return None + + @property + def last_val_result(self): + return None diff --git a/video_transformers/tasks/single_label_classification.py b/video_transformers/tasks/single_label_classification.py index ca7f2c3..14a18f7 100644 --- a/video_transformers/tasks/single_label_classification.py +++ b/video_transformers/tasks/single_label_classification.py @@ -34,6 +34,7 @@ def __init__(self, *args, **kwargs): self._val_metrics = Combine(["f1", "precision", "recall"]) self._last_train_result = None self._last_val_result = None + self.task = "single_label_classification" @property def train_metrics(self): @@ -55,8 +56,8 @@ def training_step(self, batch): inputs = batch["video"] labels = batch["label"] outputs = self.model(inputs) - propabilities = torch.nn.functional.softmax(outputs, dim=1) - predictions = propabilities.argmax(dim=-1) + probabilities = torch.nn.functional.softmax(outputs, dim=1) + predictions = probabilities.argmax(dim=-1) # gather all predictions and targets all_predictions = self.accelerator.gather(predictions) all_labels = self.accelerator.gather(labels) @@ -83,10 +84,9 @@ def on_training_epoch_end(self): def validation_step(self, batch): inputs = batch["video"] labels = batch["label"] - with torch.no_grad(): - outputs = self.model(inputs) - propabilities = torch.nn.functional.softmax(outputs, dim=1) - predictions = propabilities.argmax(dim=-1) + outputs = self.model(inputs) + probabilities = torch.nn.functional.softmax(outputs, dim=1) + predictions = probabilities.argmax(dim=-1) # gather all predictions and targets all_predictions = self.accelerator.gather(predictions) all_labels = self.accelerator.gather(labels) diff --git a/video_transformers/trainer.py b/video_transformers/trainer.py index 597cdbd..4ed3b90 100644 --- a/video_transformers/trainer.py +++ b/video_transformers/trainer.py @@ -10,7 +10,7 @@ from tqdm.auto import tqdm import video_transformers.data -from video_transformers.modules import VideoClassificationModel +from 
video_transformers.modeling import VideoClassificationModel from video_transformers.schedulers import get_linear_scheduler_with_warmup from video_transformers.tasks.single_label_classification import SingleLabelClassificationTaskMixin from video_transformers.tracking import TensorBoardTracker @@ -223,11 +223,14 @@ def save_checkpoint(self, save_path: Union[str, Path]): data_config = self.hparams["data"] config.update( { - "preprocess_means": data_config["preprocess_config"]["means"], - "preprocess_stds": data_config["preprocess_config"]["stds"], - "preprocess_min_short_side_scale": data_config["preprocess_config"]["min_short_side_scale"], - "preprocess_input_size": data_config["preprocess_config"]["input_size"], - "num_timesteps": data_config["preprocess_config"]["timesteps"], + "preprocessor": { + "means": data_config["preprocess_config"]["means"], + "stds": data_config["preprocess_config"]["stds"], + "min_short_side": data_config["preprocess_config"]["min_short_side"], + "input_size": data_config["preprocess_config"]["input_size"], + "clip_duration": data_config["preprocess_config"]["clip_duration"], + "num_timesteps": data_config["preprocess_config"]["num_timesteps"], + }, "labels": data_config["labels"], } ) @@ -298,13 +301,12 @@ def _one_val_loop(self, pbar, len_val_dataloader): return val_loss def fit(self): - self.accelerator.print("Calculating training & validation sizes for better experience...") - len_train_dataloader = len(self.train_dataloader) - len_val_dataloader = len(self.val_dataloader) - self.accelerator.print(f"Trainable parameteres: {self.model.num_trainable_params}") self.accelerator.print(f"Total parameteres: {self.model.num_total_params}") + len_train_dataloader = len(self.train_dataloader) + len_val_dataloader = len(self.val_dataloader) + for epoch in range(self.starting_epoch, self.hparams["max_epochs"]): self._log_last_lr() @@ -326,7 +328,8 @@ def fit(self): pbar.set_postfix({"loss": f"{train_loss:.4f}", train_score[0]: f"{train_score[1]:.3f}"}) # val loop - val_loss = self._one_val_loop(pbar, len_val_dataloader) + with torch.inference_mode(): + val_loss = self._one_val_loop(pbar, len_val_dataloader) # call the end of val epoch hook self.on_validation_epoch_end()
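
A minimal sketch of the renamed preprocessor config keys introduced above, assuming only the new VideoPreprocessor signature and from_config classmethod shown in video_transformers/data.py; the values are simply the signature defaults, used here for illustration.

    from video_transformers.data import VideoPreprocessor

    preprocessor_config = {
        "num_timesteps": 8,       # was "timesteps"
        "input_size": 224,
        "means": (0.45, 0.45, 0.45),
        "stds": (0.225, 0.225, 0.225),
        "min_short_side": 256,    # was "min_short_side_scale"
        "max_short_side": 320,    # was "max_short_side_scale"
        "horizontal_flip_p": 0.5,
        "clip_duration": 1,       # new key; the clip samplers now read this instead of the dataloader config
    }

    preprocessor = VideoPreprocessor.from_config(preprocessor_config)
    train_transform = preprocessor.train_transform  # consumed by the train dataloader
    val_transform = preprocessor.val_transform      # consumed by the val/test dataloaders

The same nested key names ("num_timesteps", "input_size", "means", "stds", "min_short_side") appear under the "preprocessor" entry of the model config written by Trainer.save_checkpoint and consumed by VideoClassificationModel.from_config.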