Expose extract_batch_size method and add corresponding tests. (#8357)
* expose extract_batch_size and make it public

* first pass

* early return

* add changelog

* move to utilities/data.py

* add test_data.py

* tests are passing

* precommit hook

* address pep8 failure

Co-authored-by: Carlos Mocholi <[email protected]>
kandluis and carmocca authored Jul 13, 2021
1 parent ff2aed6 commit 000fbe6
Showing 4 changed files with 51 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add `extract_batch_size` utility and corresponding tests to extract batch dimension from multiple batch types. ([#8357](https://github.com/PyTorchLightning/pytorch-lightning/pull/8357/))

- Add support for named parameter groups in `LearningRateMonitor` ([#7987](https://github.com/PyTorchLightning/pytorch-lightning/pull/7987))


26 changes: 3 additions & 23 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -14,13 +14,14 @@
from collections.abc import Generator
from dataclasses import asdict, dataclass, replace
from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
from torchmetrics import Metric

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.enums import LightningEnum
@@ -589,31 +590,10 @@ def fn(item: ResultMetric) -> None:

    def extract_batch_size(self, batch: Any) -> None:
        try:
-            self.batch_size = self._extract_batch_size(batch)
+            self.batch_size = extract_batch_size(batch)
        except RecursionError:
            self.batch_size = 1

-    def _extract_batch_size(self, batch: Any) -> int:
-        """
-        Recursively unpack a batch to find a torch.Tensor.
-
-        Returns:
-            ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
-        """
-        if isinstance(batch, torch.Tensor):
-            size = batch.size(0)
-        elif isinstance(batch, str):
-            return len(batch)
-        elif isinstance(batch, dict):
-            sample = next(iter(batch.values()), 1)
-            size = self._extract_batch_size(sample)
-        elif isinstance(batch, Iterable):
-            sample = next(iter(batch), 1)
-            size = self._extract_batch_size(sample)
-        else:
-            size = 1
-        return size
-
    def to(self, *args, **kwargs) -> 'ResultCollection':
        """Move all data to the given device."""

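The `try`/`except RecursionError` above is the only handling that stays in `ResultCollection` itself: the recursive unpacking can overflow the stack on pathological, self-referential batches. A minimal sketch of that failure mode and the fallback it triggers (illustrative only, not part of this diff; assumes a Lightning checkout containing this change):

```python
from pytorch_lightning.utilities.data import extract_batch_size

# A self-referential container never reaches a tensor, so the recursion overflows.
batch = []
batch.append(batch)  # the list now contains itself

try:
    size = extract_batch_size(batch)
except RecursionError:
    size = 1  # same fallback ResultCollection.extract_batch_size applies
assert size == 1
```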
26 changes: 25 additions & 1 deletion pytorch_lightning/utilities/data.py
@@ -12,12 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Union
+from typing import Any, Iterable, Mapping, Union

import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities import rank_zero_warn

BType = Union[torch.Tensor, str, Mapping[Any, 'BType'], Iterable['BType']]


def extract_batch_size(batch: BType) -> int:
    """
    Recursively unpack a batch to find a torch.Tensor.

    Returns:
        ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
    """
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, str):
        return len(batch)
    if isinstance(batch, dict):
        sample = next(iter(batch.values()), 1)
        return extract_batch_size(sample)
    if isinstance(batch, Iterable):
        sample = next(iter(batch), 1)
        return extract_batch_size(sample)

    return 1


def has_iterable_dataset(dataloader: DataLoader):
    return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)
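A short usage sketch of the now-public helper (illustrative only, not part of the commit; assumes this revision of `pytorch_lightning` is importable). Note that only the first value of a mapping or iterable is inspected, and empty or non-iterable inputs fall back to `1`:

```python
import torch

from pytorch_lightning.utilities.data import extract_batch_size

print(extract_batch_size(torch.zeros(32, 3, 224, 224)))  # 32: size of the first tensor dimension
print(extract_batch_size({'x': torch.zeros(16, 10)}))    # 16: recurses into the first dict value
print(extract_batch_size([[torch.zeros(8, 2)]]))         # 8: recurses through nested iterables
print(extract_batch_size("batch"))                       # 5: strings return their length
print(extract_batch_size([]))                            # 1: empty containers fall back to 1
```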
21 changes: 21 additions & 0 deletions tests/utilities/test_data.py
@@ -0,0 +1,21 @@
import torch

from pytorch_lightning.utilities.data import extract_batch_size


def test_extract_batch_size():
"""Tests the behavior of extracting the batch size."""
batch = "test string"
assert extract_batch_size(batch) == 11

batch = torch.zeros(11, 10, 9, 8)
assert extract_batch_size(batch) == 11

batch = {'test': torch.zeros(11, 10)}
assert extract_batch_size(batch) == 11

batch = [torch.zeros(11, 10)]
assert extract_batch_size(batch) == 11

batch = {'test': [{'test': [torch.zeros(11, 10)]}]}
assert extract_batch_size(batch) == 11
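
The assertions above all target the happy path where a tensor is eventually found. A hypothetical follow-up test (not part of this PR) for the documented fallback paths could look like:

```python
import torch

from pytorch_lightning.utilities.data import extract_batch_size


def test_extract_batch_size_fallbacks():
    """Hypothetical extra coverage: empty containers and non-iterable leaves fall back to 1."""
    assert extract_batch_size([]) == 1
    assert extract_batch_size({}) == 1
    assert extract_batch_size(3.14) == 1
    assert extract_batch_size((torch.zeros(7, 2),)) == 7  # tuples behave like lists
```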
