Skip to content

Commit

Permalink
fix recursive call for apply_to_collection(include_none=False) (#8719)
Browse files Browse the repository at this point in the history
(cherry picked from commit 963c267)
  • Loading branch information
awaelchli authored and lexierule committed Aug 11, 2021
1 parent f4def30 commit 58d7c56
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))


- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))



## [1.4.0] - 2021-07-27

### Added
Expand Down
18 changes: 15 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def apply_to_collection(
if isinstance(data, Mapping):
out = []
for k, v in data.items():
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append((k, v))
return elem_type(OrderedDict(out))
Expand All @@ -111,15 +113,25 @@ def apply_to_collection(
if is_namedtuple or is_sequence:
out = []
for d in data:
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
out = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
getattr(data, field),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs
)
if include_none or v is not None:
out[field] = v
return elem_type(**out)
Expand Down
6 changes: 3 additions & 3 deletions tests/utilities/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ def __init__(self, initial_dict):


def test_apply_to_collection_include_none():
to_reduce = [1, 2, 3.4, 5.6, 7]
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]

def fn(x):
if isinstance(x, float):
return x

reduced = apply_to_collection(to_reduce, (int, float), fn)
assert reduced == [None, None, 3.4, 5.6, None]
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]

reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
assert reduced == [3.4, 5.6]
assert reduced == [3.4, 5.6, (9.1, {})]


def test_apply_to_collections():
Expand Down

0 comments on commit 58d7c56

Please sign in to comment.