Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix splits with overfit batches (#375)
Browse files Browse the repository at this point in the history
* Fix splits with overfit batches

* Update CHANGELOG.md
  • Loading branch information
ethanwharris authored Jun 8, 2021
1 parent bded76d commit 4b6d2ec
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))

### Changed

- Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346))
Expand All @@ -26,6 +22,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `SemanticSegmentation` backbone names `torchvision/fcn_resnet50` and `torchvision/fcn_resnet101`, use `fc_resnet50` and `fcn_resnet101` instead ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
- Fixed a bug where using `val_split` with `overfit_batches` would give an infinite recursion ([#375](https://github.com/PyTorchLightning/lightning-flash/pull/375))


## [0.3.0] - 2021-05-20

Expand Down
8 changes: 4 additions & 4 deletions flash/core/data/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SplitDataset(Dataset):
dataset: A dataset to be splitted
indices: List of indices to expose from the dataset
use_duplicated_indices: Wether to allow duplicated indices.
use_duplicated_indices: Whether to allow duplicated indices.
Example::
Expand Down Expand Up @@ -41,9 +41,9 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices
self.indices = indices

def __getattr__(self, key: str):
if key in self._INTERNAL_KEYS:
return getattr(self, key)
return getattr(self.dataset, key)
if key not in self._INTERNAL_KEYS:
return self.dataset.__getattribute__(key)
raise AttributeError

def __setattr__(self, name: str, value: Any) -> None:
if name in self._INTERNAL_KEYS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

import numpy as np
import pytest
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -19,25 +21,12 @@
from flash.core.data.splits import SplitDataset


def test_split_dataset(tmpdir):

def test_split_dataset():
train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
assert len(train_ds) == 90
assert len(val_ds) == 10
assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

with pytest.raises(MisconfigurationException, match="[0, 99]"):
SplitDataset(range(100), indices=[100])

with pytest.raises(MisconfigurationException, match="[0, 49]"):
SplitDataset(range(50), indices=[-1])

with pytest.raises(MisconfigurationException, match="[0, 49]"):
SplitDataset(list(range(50)) + list(range(50)), indices=[-1])

with pytest.raises(MisconfigurationException, match="[0, 99]"):
SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)

class Dataset:

def __init__(self):
Expand All @@ -57,3 +46,27 @@ def __len__(self):

split_dataset.is_passed_down = True
assert split_dataset.dataset.is_passed_down


def test_misconfiguration():
with pytest.raises(MisconfigurationException, match="[0, 99]"):
SplitDataset(range(100), indices=[100])

with pytest.raises(MisconfigurationException, match="[0, 49]"):
SplitDataset(range(50), indices=[-1])

with pytest.raises(MisconfigurationException, match="[0, 49]"):
SplitDataset(list(range(50)) + list(range(50)), indices=[-1])

with pytest.raises(MisconfigurationException, match="[0, 99]"):
SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)

with pytest.raises(MisconfigurationException, match="indices should be a list"):
SplitDataset(list(range(100)), indices="not a list")


def test_deepcopy():
"""Tests that deepcopy works with the ``SplitDataset``."""
dataset = list(range(100))
train_ds, val_ds = DataModule._split_train_val(dataset, val_split=0.1)
deepcopy(train_ds)

0 comments on commit 4b6d2ec

Please sign in to comment.