Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix RTD build (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 22, 2021
1 parent ae525d9 commit 40cb3ab
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
36 changes: 30 additions & 6 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_epoch_end(self, outputs) -> None:
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -120,13 +120,21 @@ def process_train_dataset(
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset,
trainer,
batch_size,
num_workers,
pin_memory,
collate_fn,
shuffle,
drop_last,
sampler,
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -136,13 +144,21 @@ def process_val_dataset(
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset,
trainer,
batch_size,
num_workers,
pin_memory,
collate_fn,
shuffle,
drop_last,
sampler,
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -152,7 +168,15 @@ def process_test_dataset(
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset,
trainer,
batch_size,
num_workers,
pin_memory,
collate_fn,
shuffle,
drop_last,
sampler,
)

def process_predict_dataset(
Expand Down
9 changes: 3 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def process_train_dataset(
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
persistent_workers: bool = True,
) -> DataLoader:
return self._process_dataset(
dataset,
Expand All @@ -155,7 +154,7 @@ def process_train_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=persistent_workers and num_workers > 0,
persistent_workers=num_workers > 0,
)

def process_val_dataset(
Expand All @@ -169,7 +168,6 @@ def process_val_dataset(
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
persistent_workers: bool = True,
) -> DataLoader:
return self._process_dataset(
dataset,
Expand All @@ -180,7 +178,7 @@ def process_val_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=persistent_workers and num_workers > 0,
persistent_workers=num_workers > 0,
)

def process_test_dataset(
Expand All @@ -194,7 +192,6 @@ def process_test_dataset(
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
persistent_workers: bool = True,
) -> DataLoader:
return self._process_dataset(
dataset,
Expand All @@ -205,7 +202,7 @@ def process_test_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=persistent_workers and num_workers > 0,
persistent_workers=num_workers > 0,
)

def process_predict_dataset(
Expand Down
17 changes: 7 additions & 10 deletions flash/image/classification/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _labels_to_indices(data):

def _convert_dataset(
self,
trainer: flash.Trainer,
trainer: "flash.Trainer",
dataset: BaseAutoDataset,
ways: int,
shots: int,
Expand Down Expand Up @@ -334,14 +334,14 @@ def _sanetize_batch_size(self, batch_size: int) -> int:
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool,
drop_last: bool,
sampler: Optional[Sampler],
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
dataset = self._convert_dataset(
trainer=trainer,
Expand All @@ -366,13 +366,12 @@ def process_train_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=True,
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand Down Expand Up @@ -404,13 +403,12 @@ def process_val_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=True,
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand Down Expand Up @@ -442,7 +440,6 @@ def process_test_dataset(
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
persistent_workers=True,
)

def process_predict_dataset(
Expand Down

0 comments on commit 40cb3ab

Please sign in to comment.