Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
sync with master and fix val_split
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Apr 28, 2021
1 parent d1a91fd commit 7343887
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions flash/vision/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,29 @@ def __init__(
predict_transform: Dictionary with the set of transforms to apply during prediction.
image_size: A tuple with the expected output image size.
"""
self.image_size = image_size

train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms(
train_transform, val_transform, test_transform, predict_transform, image_size
)
super().__init__(train_transform, val_transform, test_transform, predict_transform)

# TODO: this is kind of boilerplate, let's simplify
def get_state_dict(self) -> Dict[str, Any]:
return {
"train_transform": self._train_transform,
"val_transform": self._val_transform,
"test_transform": self._test_transform,
"predict_transform": self._predict_transform,
"image_size": self.image_size
}

# TODO: this is kind of boilerplate, let's simplify
@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):
return cls(**state_dict)

# TODO: this is kind of boilerplate, let's simplify
def _resolve_transforms(
self,
train_transform: Optional[Union[str, Dict]] = None,
Expand Down Expand Up @@ -171,7 +189,6 @@ def from_filepaths(
image_size: Tuple[int, int] = (196, 196),
batch_size: int = 64,
num_workers: Optional[int] = None,
#seed: Optional[int] = 42, # SEED NEVER USED
data_fetcher: BaseDataFetcher = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED
Expand Down Expand Up @@ -229,13 +246,12 @@ def from_filepaths(
train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None,
val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None,
test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None,
# predict_load_data_input=predict_filepaths, # TODO: is it really used ?
predict_load_data_input=predict_filepaths, # TODO: is it really used ?
batch_size=batch_size,
num_workers=num_workers,
data_fetcher=data_fetcher,
preprocess=preprocess,
#seed=seed, # THIS CRASHES
#val_split=val_split, # THIS CRASHES
val_split=val_split,
**kwargs, # TODO: remove and make explicit params
)

Expand Down

0 comments on commit 7343887

Please sign in to comment.