data_parallel.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import threading
from collections.abc import Mapping, Iterable
from itertools import chain

import torch
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warning_utils import WarningCache


def _find_tensors(obj):  # pragma: no-cover
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
if isinstance(obj, dict):
return itertools.chain(*map(_find_tensors, obj.values()))
return []


def get_a_var(obj):  # pragma: no-cover
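    # return the first tensor found in ``obj`` (or None); used below to infer
    # the CUDA device an input lives on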
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, (list, tuple)):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None


warning_cache = WarningCache()


class LightningDataParallel(DataParallel):
"""
    Override DataParallel.forward so calls are routed to the LightningModule's
    training_step, test_step or validation_step, depending on its current state.
"""
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError("module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj, t.device))
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
# lightning
if self.module.training:
return self.module.training_step(*inputs[0], **kwargs[0])
if self.module.testing:
return self.module.test_step(*inputs[0], **kwargs[0])
return self.module.validation_step(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
if isinstance(outputs[0], Result):
outputs = self.__gather_structured_result(outputs)
else:
outputs = self.gather(outputs)
return outputs
def __gather_structured_result(self, outputs):
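        # each replica returns a ``Result`` (a dict subclass): strip the 'meta'
        # bookkeeping, gather the remaining entries across devices, rebuild the
        # original Result class (passing 'minimize' to the constructor when
        # present, as TrainResult expects), then reattach the meta info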
prototype_output = outputs[0]
original_class = prototype_output.__class__
outputs = [dict(x) for x in outputs]
# remove all the meta info
meta = outputs[0]['meta']
        for output in outputs:
del output['meta']
outputs = self.gather(outputs)
# pass minimize to constructor for TrainResult
if 'minimize' in outputs:
result = original_class(outputs['minimize'])
else:
result = original_class()
result.update(outputs)
result['meta'] = meta
return result
def gather(self, outputs):
r"""
Override the gather method to support python scalars as well.
"""
def gather_map(outputs):
elem = outputs[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
return Gather.apply(self.output_device, self.dim, *outputs)
if elem is None:
return None
if isinstance(elem, Mapping):
if not all((len(elem) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return elem_type(((k, gather_map([d[k] for d in outputs]))
for k in elem))
if isinstance(elem, Iterable) and not isinstance(elem, str):
return elem_type(map(gather_map, zip(*outputs)))
return outputs
# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.
try:
res = gather_map(outputs)
finally:
gather_map = None
return res
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])


class LightningDistributedDataParallel(DistributedDataParallel):
"""
    Override DistributedDataParallel.forward so calls are routed to the LightningModule's
    training_step, test_step or validation_step, depending on its current state.
"""
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
def forward(self, *inputs, **kwargs): # pragma: no-cover
self._sync_params()
fx_called: str = ''
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
# --------------
# LIGHTNING MOD
# --------------
# normal
# output = self.module(*inputs[0], **kwargs[0])
# lightning
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
fx_called = 'training_step'
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
fx_called = 'test_step'
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
fx_called = 'validation_step'
else:
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
# output = self.module(*inputs, **kwargs)
# normal lightning (ddp_cpu)
if self.module.training:
output = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
output = self.module.test_step(*inputs, **kwargs)
else:
output = self.module.validation_step(*inputs, **kwargs)
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
if output is None:
            warn_missing_output(fx_called)
return output


def warn_missing_output(fx_called):
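    # only a None from training_step is flagged here; a None returned by
    # validation_step/test_step is ignored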
if fx_called == 'training_step':
warning_cache.warn("Your training_step returned None. Make sure that was your intention!")


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no-cover
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.
    Args:
        modules (list of Module): modules to be parallelized
        inputs (list of tensor): inputs to the modules
        kwargs_tup (tuple of dict, optional): keyword arguments for each module
        devices (list of int or torch.device): CUDA devices

:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(inputs)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
devices = list(map(lambda x: _get_device_index(x, True), devices))
lock = threading.Lock()
results = {}
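    # grad mode is thread-local in PyTorch, so capture the caller's setting and
    # re-apply it inside each worker thread below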
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
fx_called: str = ''
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
module = module.to(device)
# ---------------
# CHANGE
if module.training:
output = module.training_step(*input, **kwargs)
fx_called = 'training_step'
elif module.testing:
output = module.test_step(*input, **kwargs)
fx_called = 'test_step'
else:
output = module.validation_step(*input, **kwargs)
fx_called = 'validation_step'
if output is None:
warn_missing_output(fx_called)
                if output is not None and (module.use_dp or module.use_ddp2):
                    output = auto_squeeze_dim_zeros(output)
# ---------------
with lock:
results[i] = output
except Exception as ex:
with lock:
results[i] = ex
# TODO: fix hack (maybe not a hack)
# make sure each module knows what training state it's in...
# fixes weird bug where copies are out of sync
root_m = modules[0]
for m in modules[1:]:
m.training = root_m.training
m.testing = root_m.testing
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs


def auto_squeeze_dim_zeros(output):
    """
    In DP or DDP2 scalar (zero-dimensional) tensors need an extra dim 0
    so they can be gathered across devices.

    :param output: a tensor or a dict of tensors returned by a step function
    :return: the output with scalar tensors unsqueezed along dim 0
    """
    if isinstance(output, torch.Tensor):
        if output.dim() == 0:
            output = output.unsqueeze(0)
        return output

    for k, v in output.items():
        if not isinstance(v, torch.Tensor):
            continue

        is_scalar = v.dim() == 0
        if is_scalar:
            output[k] = output[k].unsqueeze(0)

    return output
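
# e.g. (sketch): auto_squeeze_dim_zeros({'loss': torch.tensor(0.3)}) returns the
# dict with 'loss' unsqueezed to tensor([0.3]), so it can be gathered along dim 0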