Skip to content

Commit 71a1fbc

Browse files
awaelchliSeanNaren
authored and
SeanNaren
committed
fix recursive call for apply_to_collection(include_none=False) (#8719)
(cherry picked from commit 963c267)
1 parent 408752e commit 71a1fbc

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))
1818

1919

20+
- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
21+
22+
23+
2024
## [1.4.0] - 2021-07-27
2125

2226
### Added

pytorch_lightning/utilities/apply_func.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def apply_to_collection(
101101
if isinstance(data, Mapping):
102102
out = []
103103
for k, v in data.items():
104-
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
104+
v = apply_to_collection(
105+
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
106+
)
105107
if include_none or v is not None:
106108
out.append((k, v))
107109
return elem_type(OrderedDict(out))
@@ -111,15 +113,25 @@ def apply_to_collection(
111113
if is_namedtuple or is_sequence:
112114
out = []
113115
for d in data:
114-
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
116+
v = apply_to_collection(
117+
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
118+
)
115119
if include_none or v is not None:
116120
out.append(v)
117121
return elem_type(*out) if is_namedtuple else elem_type(out)
118122

119123
if _is_dataclass_instance(data):
120124
out = {}
121125
for field in data.__dataclass_fields__:
122-
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
126+
v = apply_to_collection(
127+
getattr(data, field),
128+
dtype,
129+
function,
130+
*args,
131+
wrong_dtype=wrong_dtype,
132+
include_none=include_none,
133+
**kwargs
134+
)
123135
if include_none or v is not None:
124136
out[field] = v
125137
return elem_type(**out)

tests/utilities/test_apply_func.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,17 @@ def __init__(self, initial_dict):
151151

152152

153153
def test_apply_to_collection_include_none():
154-
to_reduce = [1, 2, 3.4, 5.6, 7]
154+
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]
155155

156156
def fn(x):
157157
if isinstance(x, float):
158158
return x
159159

160160
reduced = apply_to_collection(to_reduce, (int, float), fn)
161-
assert reduced == [None, None, 3.4, 5.6, None]
161+
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
162162

163163
reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
164-
assert reduced == [3.4, 5.6]
164+
assert reduced == [3.4, 5.6, (9.1, {})]
165165

166166

167167
def test_apply_to_collections():

0 commit comments

Comments
 (0)