This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix splits with overfit batches #375

Merged 2 commits on Jun 8, 2021

9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -10,10 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

-### Fixed
-
-- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
-
### Changed

- Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346))
@@ -26,6 +22,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `SemanticSegmentation` backbone names `torchvision/fcn_resnet50` and `torchvision/fcn_resnet101`, use `fc_resnet50` and `fcn_resnet101` instead ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

+### Fixed
+
+- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
+- Fixed a bug where using `val_split` with `overfit_batches` would give an infinite recursion ([#375](https://github.com/PyTorchLightning/lightning-flash/pull/375))


## [0.3.0] - 2021-05-20

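For context on the new #375 entry above: the infinite recursion surfaced when a `SplitDataset` produced by `val_split` was deep-copied, which is exactly what the new `test_deepcopy` below exercises. A minimal sketch of the trigger; the `DataModule` import path here is an assumption, while `_split_train_val` and its signature come from the tests in this PR:

```python
from copy import deepcopy

from flash.core.data.data_module import DataModule  # import path is an assumption

# `val_split` wraps the data in SplitDataset objects; before this PR,
# deep-copying one of them re-entered SplitDataset.__getattr__ forever and
# raised RecursionError. With the fix, the copy succeeds.
train_ds, val_ds = DataModule._split_train_val(list(range(100)), val_split=0.1)
deepcopy(train_ds)
```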
8 changes: 4 additions & 4 deletions flash/core/data/splits.py
@@ -13,7 +13,7 @@ class SplitDataset(Dataset):

dataset: A dataset to be splitted
indices: List of indices to expose from the dataset
-        use_duplicated_indices: Wether to allow duplicated indices.
+        use_duplicated_indices: Whether to allow duplicated indices.

Example::

@@ -41,9 +41,9 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices
self.indices = indices

def __getattr__(self, key: str):
-        if key in self._INTERNAL_KEYS:
-            return getattr(self, key)
-        return getattr(self.dataset, key)
+        if key not in self._INTERNAL_KEYS:
+            return self.dataset.__getattribute__(key)
+        raise AttributeError

def __setattr__(self, name: str, value: Any) -> None:
if name in self._INTERNAL_KEYS:
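Why the old `__getattr__` above recursed: `copy.deepcopy` rebuilds the wrapper without running `__init__`, so the copy protocol's attribute probes (e.g. `hasattr(obj, "__setstate__")`) hit `__getattr__` while `dataset` and `indices` are still missing; the old body then called `getattr(self, key)` (or touched `self.dataset`), which re-entered `__getattr__` indefinitely. The new body only delegates non-internal keys to the wrapped dataset and raises `AttributeError` otherwise, so the lookup terminates. A standalone sketch of the two behaviours, using stand-in classes rather than the real `SplitDataset`:

```python
from copy import deepcopy


class Broken:
    """Mimics the old SplitDataset.__getattr__ delegation."""

    _INTERNAL_KEYS = ("dataset", "indices")

    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, key):
        # Runs only when normal lookup fails. On a half-built deepcopy the
        # instance __dict__ is empty, so getattr(self, "dataset") lands back
        # here and the recursion never ends.
        if key in self._INTERNAL_KEYS:
            return getattr(self, key)
        return getattr(self.dataset, key)


class Fixed:
    """Mimics the new SplitDataset.__getattr__ delegation."""

    _INTERNAL_KEYS = ("dataset", "indices")

    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, key):
        # Delegate unknown attributes to the wrapped dataset; for internal
        # keys that are genuinely unset, fail fast instead of recursing.
        if key not in self._INTERNAL_KEYS:
            return self.dataset.__getattribute__(key)
        raise AttributeError(key)


deepcopy(Fixed(list(range(10))))     # copies cleanly
# deepcopy(Broken(list(range(10))))  # raises RecursionError (the pre-fix behaviour)
```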
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from copy import deepcopy

import numpy as np
import pytest
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -19,25 +21,12 @@
from flash.core.data.splits import SplitDataset


-def test_split_dataset(tmpdir):
-
+def test_split_dataset():
train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
assert len(train_ds) == 90
assert len(val_ds) == 10
assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

-    with pytest.raises(MisconfigurationException, match="[0, 99]"):
-        SplitDataset(range(100), indices=[100])
-
-    with pytest.raises(MisconfigurationException, match="[0, 49]"):
-        SplitDataset(range(50), indices=[-1])
-
-    with pytest.raises(MisconfigurationException, match="[0, 49]"):
-        SplitDataset(list(range(50)) + list(range(50)), indices=[-1])
-
-    with pytest.raises(MisconfigurationException, match="[0, 99]"):
-        SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)

class Dataset:

def __init__(self):
@@ -57,3 +46,27 @@ def __len__(self):

split_dataset.is_passed_down = True
assert split_dataset.dataset.is_passed_down


+def test_misconfiguration():
+    with pytest.raises(MisconfigurationException, match="[0, 99]"):
+        SplitDataset(range(100), indices=[100])
+
+    with pytest.raises(MisconfigurationException, match="[0, 49]"):
+        SplitDataset(range(50), indices=[-1])
+
+    with pytest.raises(MisconfigurationException, match="[0, 49]"):
+        SplitDataset(list(range(50)) + list(range(50)), indices=[-1])
+
+    with pytest.raises(MisconfigurationException, match="[0, 99]"):
+        SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)
+
+    with pytest.raises(MisconfigurationException, match="indices should be a list"):
+        SplitDataset(list(range(100)), indices="not a list")
+
+
+def test_deepcopy():
+    """Tests that deepcopy works with the ``SplitDataset``."""
+    dataset = list(range(100))
+    train_ds, val_ds = DataModule._split_train_val(dataset, val_split=0.1)
+    deepcopy(train_ds)