Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix drop last for predicting and testing (#671)
Browse files Browse the repository at this point in the history
* Fix drop last for predicting and testing

* Update CHANGELOG.md

* Update CHANGELOG.md

* Fixes
  • Loading branch information
ethanwharris authored Aug 17, 2021
1 parent 9b86a0f commit 4e89a37
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TestClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660))

- Fixed a bug where `drop_last` would be set to True during prediction and testing ([#671](https://github.com/PyTorchLightning/lightning-flash/pull/671))

## [0.4.0] - 2021-06-22

### Added
Expand Down
4 changes: 2 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def process_test_dataset(
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = True,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self._process_dataset(
Expand All @@ -204,7 +204,7 @@ def process_predict_dataset(
pin_memory: bool = False,
collate_fn: Callable = None,
shuffle: bool = False,
drop_last: bool = True,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self._process_dataset(
Expand Down
20 changes: 12 additions & 8 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from itertools import chain
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
Expand Down Expand Up @@ -52,14 +53,20 @@ class Image:


class DummyDataset(torch.utils.data.Dataset):
def __init__(self, num_samples: int = 9):
self.num_samples = num_samples

def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item()

def __len__(self) -> int:
return 9
return self.num_samples


class PredictDummyDataset(DummyDataset):
def __init__(self, num_samples: int):
super().__init__(num_samples)

def __getitem__(self, index: int) -> Tensor:
return torch.rand(1, 28, 28)

Expand Down Expand Up @@ -211,15 +218,12 @@ def _rand_image():
def test_classification_task_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
ds = PredictDummyDataset()
batch_size = 3
predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
ds = PredictDummyDataset(10)
batch_size = 6
predict_dl = task.process_predict_dataset(ds, batch_size=batch_size)
trainer = pl.Trainer(default_root_dir=tmpdir)
predictions = trainer.predict(task, predict_dl)
assert len(predictions) == len(ds) // batch_size
for batch_pred in predictions:
assert len(batch_pred) == batch_size
assert all(y < 10 for y in batch_pred)
assert len(list(chain.from_iterable(predictions))) == 10


def test_task_datapipeline_save(tmpdir):
Expand Down

0 comments on commit 4e89a37

Please sign in to comment.