Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix PL compatibility #690

Merged
merged 8 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where `drop_last` would be set to True during prediction and testing ([#671](https://github.com/PyTorchLightning/lightning-flash/pull/671))

- Fixed a bug where flash was not compatible with pytorch-lightning >= 1.4.3 ([#690](https://github.com/PyTorchLightning/lightning-flash/pull/690))

## [0.4.0] - 2021-06-22

### Added
Expand Down
10 changes: 8 additions & 2 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ def _attach_preprocess_to_model(
dataloader = dataloader[0]

if isinstance(dataloader, DataLoader):
dataloader = _PatchDataLoader(dataloader)
try:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
dataloader = _PatchDataLoader(dataloader, stage)
except TypeError:
dataloader = _PatchDataLoader(dataloader)

self._set_loader(model, whole_attr_name, dataloader)

Expand Down Expand Up @@ -536,7 +539,10 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
dataloader = dataloader[0]

if isinstance(dataloader, DataLoader):
dataloader = _PatchDataLoader(dataloader)
try:
dataloader = _PatchDataLoader(dataloader, stage)
except TypeError:
dataloader = _PatchDataLoader(dataloader)

self._set_loader(model, whole_attr_name, dataloader)

Expand Down