From 7343887ac780a2cd16a8d0b0f1d1d05bbe7d4b1d Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 21:40:59 +0200 Subject: [PATCH] sync with master and fix val_split --- flash/vision/segmentation/data.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 8584e537ee..ebc6e67cb6 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -56,11 +56,29 @@ def __init__( predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. """ + self.image_size = image_size + train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform, image_size ) super().__init__(train_transform, val_transform, test_transform, predict_transform) + # TODO: this is kind of boilerplate, let's simplify + def get_state_dict(self) -> Dict[str, Any]: + return { + "train_transform": self._train_transform, + "val_transform": self._val_transform, + "test_transform": self._test_transform, + "predict_transform": self._predict_transform, + "image_size": self.image_size + } + + # TODO: this is kind of boilerplate, let's simplify + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + # TODO: this is kind of boilerplate, let's simplify def _resolve_transforms( self, train_transform: Optional[Union[str, Dict]] = None, @@ -171,7 +189,6 @@ def from_filepaths( image_size: Tuple[int, int] = (196, 196), batch_size: int = 64, num_workers: Optional[int] = None, - #seed: Optional[int] = 42, # SEED NEVER USED data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED @@ -229,13 +246,12 @@ def from_filepaths( train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - # predict_load_data_input=predict_filepaths, # TODO: is it really used ? + predict_load_data_input=predict_filepaths, # TODO: is it really used ? batch_size=batch_size, num_workers=num_workers, data_fetcher=data_fetcher, preprocess=preprocess, - #seed=seed, # THIS CRASHES - #val_split=val_split, # THIS CRASHES + val_split=val_split, **kwargs, # TODO: remove and make explicit params )