Merge branch 'release/1.2-dev' into fix/yaml
mergify[bot] authored Jan 24, 2021
2 parents b8eda67 + ef7345d commit 62f93a3
Showing 3 changed files with 94 additions and 47 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/wandb.py
@@ -29,8 +29,9 @@
 _WANDB_AVAILABLE = _module_available("wandb")

 try:
-    import wandb
     from wandb.wandb_run import Run
+
+    import wandb
 except ImportError:
     # needed for test mocks, these tests shall be updated
     wandb, Run = None, None
41 changes: 27 additions & 14 deletions pytorch_lightning/trainer/supporters.py
@@ -14,7 +14,7 @@

 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 from torch import Tensor
@@ -306,12 +306,8 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,

         if isinstance(all_lengths, (int, float)):
             length = all_lengths
-
-        elif isinstance(all_lengths, Mapping):
-            length = compute_func(all_lengths.values())
-
-        elif isinstance(all_lengths, Sequence):
-            length = compute_func(all_lengths)
+        else:
+            length = _nested_calc_num_data(all_lengths, compute_func)

         return length

@@ -437,13 +433,8 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
         if isinstance(all_lengths, (int, float)):
             return all_lengths

-        elif isinstance(all_lengths, Mapping):
-            return min(all_lengths.values())
-
-        elif isinstance(all_lengths, Sequence):
-            return min(all_lengths)
-
-        raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping')
+        else:
+            return _nested_calc_num_data(all_lengths, min)

     def __len__(self) -> int:
         return self._calc_num_batches(self.loaders)
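
A quick illustration of what this fallback buys (our example, not part of the diff): a nested lengths structure, which the old flat Mapping/Sequence branches could not reduce, is now collapsed recursively.

    # Illustrative values: per-loader batch counts with one level of nesting.
    _nested_calc_num_data([3, [4, 4]], min)  # -> 3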
@@ -516,3 +507,25 @@ def create_loader_iters(
"""
# dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))


def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):

if isinstance(data, int):
return data

if isinstance(data, Mapping):
data = list(data.values())

if not isinstance(data, Sequence):
raise TypeError(f'Expected data to be int, Sequence or Mapping, but got {type(data).__name__}')

new_data = []

for x in data:
if isinstance(x, (Mapping, Sequence)):
new_data.append(_nested_calc_num_data(x, compute_func))
else:
new_data.append(x)

return compute_func(new_data)
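
To see the new helper end to end, here is a short self-contained sketch of ours (not part of the commit), mirroring the test cases added below: arbitrarily nested mappings and sequences of lengths collapse to a single number.

    from pytorch_lightning.trainer.supporters import _nested_calc_num_data

    lengths = {"train": [10, 20], "extra": {"a": 5}}
    assert _nested_calc_num_data(lengths, min) == 5   # minimum over all leaf values
    assert _nested_calc_num_data(lengths, max) == 20  # maximum over all leaf values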
97 changes: 65 additions & 32 deletions tests/trainer/test_supporters.py
@@ -18,6 +18,7 @@
 from torch.utils.data import TensorDataset

 from pytorch_lightning.trainer.supporters import (
+    _nested_calc_num_data,
     CombinedDataset,
     CombinedLoader,
     CombinedLoaderIterator,
@@ -61,7 +62,7 @@ def test_cycle_iterator():
 def test_none_length_cycle_iterator():
     """Test the infinite cycling function of `CycleIterator`"""
     iterator = CycleIterator(range(100))
-    assert iterator.__len__() == float('inf')
+    assert iterator.__len__() == float("inf")

     # test infinite loop
     for idx, item in enumerate(iterator):
@@ -70,12 +71,15 @@ def test_none_length_cycle_iterator():
         assert item == 0


-@pytest.mark.parametrize(['dataset_1', 'dataset_2'], [
-    ([list(range(10)), list(range(20))]),
-    ([range(10), range(20)]),
-    ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
-    ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))])
-])
+@pytest.mark.parametrize(
+    ["dataset_1", "dataset_2"],
+    [
+        ([list(range(10)), list(range(20))]),
+        ([range(10), range(20)]),
+        ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
+        ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
+    ],
+)
 def test_combined_dataset(dataset_1, dataset_2):
     """Verify the length of the CombinedDataset"""
     datasets = [dataset_1, dataset_2]
@@ -86,83 +90,91 @@ def test_combined_dataset(dataset_1, dataset_2):


 def test_combined_dataset_length_mode_error():
-    with pytest.raises(MisconfigurationException, match='Invalid Mode'):
-        CombinedDataset._calc_num_data([range(10)], 'test')
+    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
+        CombinedDataset._calc_num_data([range(10)], "test")


 def test_combined_loader_iterator_dict_min_size():
     """Test `CombinedLoaderIterator` given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }

     combined_iter = CombinedLoaderIterator(loaders)

     for idx, item in enumerate(combined_iter):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item

-    assert idx == min(len(loaders['a']), len(loaders['b'])) - 1
+    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1


 def test_combined_loader_init_mode_error():
     """Test the ValueError when constructing `CombinedLoader`"""
-    with pytest.raises(MisconfigurationException, match='selected unsupported mode'):
-        CombinedLoader([range(10)], 'testtt')
+    with pytest.raises(MisconfigurationException, match="selected unsupported mode"):
+        CombinedLoader([range(10)], "testtt")


 def test_combined_loader_loader_type_error():
     """Test the ValueError when wrapping the loaders"""
-    with pytest.raises(ValueError, match='Invalid Datatype'):
-        CombinedLoader(None, 'max_size_cycle')
+    with pytest.raises(ValueError, match="Invalid Datatype"):
+        CombinedLoader(None, "max_size_cycle")


 def test_combined_loader_calc_length_mode_error():
     """Test the ValueError when calculating the number of batches"""
-    with pytest.raises(TypeError, match='Got Type NoneType, but expected one of Sequence, int or Mapping'):
+    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
         CombinedLoader._calc_num_batches(None)


 def test_combined_loader_dict_min_size():
     """Test `CombinedLoader` of mode 'min_size' given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }

-    combined_loader = CombinedLoader(loaders, 'min_size')
+    combined_loader = CombinedLoader(loaders, "min_size")

     assert len(combined_loader) == min([len(v) for v in loaders.values()])

     for idx, item in enumerate(combined_loader):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item

     assert idx == len(combined_loader) - 1


 def test_combined_loader_dict_max_size_cycle():
     """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }

-    combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+    combined_loader = CombinedLoader(loaders, "max_size_cycle")

     assert len(combined_loader) == max([len(v) for v in loaders.values()])

     for idx, item in enumerate(combined_loader):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item

     assert idx == len(combined_loader) - 1


 def test_combined_loader_sequence_min_size():
     """Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
-    loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
-               torch.utils.data.DataLoader(range(20), batch_size=5)]
+    loaders = [
+        torch.utils.data.DataLoader(range(10), batch_size=4),
+        torch.utils.data.DataLoader(range(20), batch_size=5),
+    ]

-    combined_loader = CombinedLoader(loaders, 'min_size')
+    combined_loader = CombinedLoader(loaders, "min_size")

     assert len(combined_loader) == min([len(v) for v in loaders])

@@ -175,10 +187,12 @@ def test_combined_loader_sequence_min_size():

 def test_combined_loader_sequence_max_size_cycle():
     """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
-    loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
-               torch.utils.data.DataLoader(range(20), batch_size=5)]
+    loaders = [
+        torch.utils.data.DataLoader(range(10), batch_size=4),
+        torch.utils.data.DataLoader(range(20), batch_size=5),
+    ]

-    combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+    combined_loader = CombinedLoader(loaders, "max_size_cycle")

     assert len(combined_loader) == max([len(v) for v in loaders])

@@ -187,3 +201,22 @@ def test_combined_loader_sequence_max_size_cycle():
         assert len(item) == 2

     assert idx == len(combined_loader) - 1
+
+
+@pytest.mark.parametrize(
+    ["input_data", "compute_func", "expected_length"],
+    [
+        ([*range(10), list(range(1, 20))], min, 0),
+        ([*range(10), list(range(1, 20))], max, 19),
+        ([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
+        ([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
+        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
+        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
+        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
+        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
+    ],
+)
+def test_nested_calc_num_data(input_data, compute_func, expected_length):
+    calculated_length = _nested_calc_num_data(input_data, compute_func)
+
+    assert calculated_length == expected_length
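
Downstream, this is what should let `CombinedLoader` report a length for nested loader collections instead of raising. An untested sketch of ours, assuming `_calc_num_batches` recurses into the structure as shown in the hunk above (batch counts: 10/4 -> 3, 20/5 -> 4, 8/2 -> 4):

    import torch
    from pytorch_lightning.trainer.supporters import CombinedLoader

    # Hypothetical nested structure: a mapping whose values mix a DataLoader
    # and a list of DataLoaders.
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "nested": [
            torch.utils.data.DataLoader(range(20), batch_size=5),
            torch.utils.data.DataLoader(range(8), batch_size=2),
        ],
    }

    combined_loader = CombinedLoader(loaders, "min_size")
    assert len(combined_loader) == 3  # min over the nested batch counts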
