Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Nov 5, 2021
1 parent 6716308 commit d2d1d7a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
18 changes: 5 additions & 13 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader

import flash
from flash.core.data.data_module import DataModule
from flash.core.data.new_data_module import DataModule as NewDataModule
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE

Expand Down Expand Up @@ -287,18 +286,11 @@ def request_dataloader(
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)

dataloader = None
if _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
if (
not source.is_module()
or not isinstance(source.instance, DataModule)
or isinstance(source.instance, (LightningModule, NewDataModule))
):
dataloader = source.dataloader()

if dataloader is None:
if is_overridden(hook, model):
dataloader = self.call_hook(hook, pl_module=model)
elif _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
dataloader = source.dataloader()

if isinstance(dataloader, tuple):
dataloader = list(dataloader)
Expand Down
7 changes: 6 additions & 1 deletion flash/image/classification/integrations/baal/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.utilities.model_helpers import is_overridden

import flash
from flash.core.data.data_pipeline import DataLoaderGetter
Expand Down Expand Up @@ -166,7 +167,11 @@ def _reset_testing(self):
def _reset_dataloader_for_stage(self, running_state: RunningStage):
dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
# If the dataloader exists, we reset it.
dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
dataloader = (
getattr(self.trainer.datamodule, dataloader_name)
if is_overridden(dataloader_name, self.trainer.datamodule)
else None
)
if dataloader:
if _PL_GREATER_EQUAL_1_5_0:
setattr(
Expand Down

0 comments on commit d2d1d7a

Please sign in to comment.