Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/move-logger-profiler-teardown
Browse files Browse the repository at this point in the history
  • Loading branch information
tangbinh committed Aug 4, 2021
2 parents d404e45 + 963c267 commit 580c882
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))


- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))


Expand Down
18 changes: 15 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def apply_to_collection(
if isinstance(data, Mapping):
out = []
for k, v in data.items():
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append((k, v))
return elem_type(OrderedDict(out))
Expand All @@ -111,15 +113,25 @@ def apply_to_collection(
if is_namedtuple or is_sequence:
out = []
for d in data:
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
out = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
getattr(data, field),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs
)
if include_none or v is not None:
out[field] = v
return elem_type(**out)
Expand Down
6 changes: 3 additions & 3 deletions tests/utilities/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ def __init__(self, initial_dict):


def test_apply_to_collection_include_none():
to_reduce = [1, 2, 3.4, 5.6, 7]
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]

def fn(x):
if isinstance(x, float):
return x

reduced = apply_to_collection(to_reduce, (int, float), fn)
assert reduced == [None, None, 3.4, 5.6, None]
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]

reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
assert reduced == [3.4, 5.6]
assert reduced == [3.4, 5.6, (9.1, {})]


def test_apply_to_collections():
Expand Down

0 comments on commit 580c882

Please sign in to comment.