# data_pipeline.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import weakref
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import imports
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data._utils.collate import default_collate, default_convert
from flash.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential
from flash.data.process import Postprocess, Preprocess
from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
if TYPE_CHECKING:
from flash.core.model import Task
class DataPipeline:
"""
DataPipeline holds the engineering logic to connect
:class:`~flash.data.process.Preprocess` and/or :class:`~flash.data.process.Postprocess` objects to
the ``DataModule``, Flash ``Task`` and ``Trainer``.
Example::
class CustomPreprocess(Preprocess):
pass
class CustomPostprocess(Postprocess):
pass
custom_data_pipeline = DataPipeline(CustomPreprocess(), CustomPostprocess())
# And it can be attached to both the datamodule and model.
datamodule.data_pipeline = custom_data_pipeline
model.data_pipeline = custom_data_pipeline
"""
PREPROCESS_FUNCS: Set[str] = _PREPROCESS_FUNCS
POSTPROCESS_FUNCS: Set[str] = _POSTPROCESS_FUNCS
def __init__(self, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None) -> None:
self._preprocess_pipeline = preprocess or Preprocess()
self._postprocess_pipeline = postprocess or Postprocess()
self._postprocessor = None
self._running_stage = None
@staticmethod
def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool:
"""
Cropped Version of
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py
"""
current_method_name = method_name if prefix is None else f'{prefix}_{method_name}'
if not hasattr(process_obj, current_method_name):
return False
return getattr(process_obj, current_method_name).__code__ != getattr(super_obj, method_name).__code__
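# Illustrative sketch (hypothetical usage): assuming a ``CustomPreprocess`` that only
# overrides ``train_collate``, the prefixed lookup behaves as follows:
#
#     class CustomPreprocess(Preprocess):
#         def train_collate(self, samples):
#             return samples
#
#     DataPipeline._is_overriden('collate', CustomPreprocess(), Preprocess, prefix='train')  # -> True
#     DataPipeline._is_overriden('collate', CustomPreprocess(), Preprocess)                  # -> False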
@property
def preprocess_state(self):
if self._preprocess_pipeline:
return self._preprocess_pipeline.state
@classmethod
def _is_overriden_recursive(
cls, method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None
) -> bool:
"""
Cropped Version of
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py
"""
assert isinstance(process_obj, super_obj)
if prefix is None and not hasattr(super_obj, method_name):
raise MisconfigurationException(f"This function doesn't belong to the parent class {super_obj}")
current_method_name = method_name if prefix is None else f'{prefix}_{method_name}'
if not hasattr(process_obj, current_method_name):
return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj)
current_code = inspect.unwrap(getattr(process_obj, current_method_name)).__code__
has_different_code = current_code != getattr(super_obj, method_name).__code__
if not prefix:
return has_different_code
else:
return has_different_code or cls._is_overriden_recursive(method_name, process_obj, super_obj)
@staticmethod
def _identity(samples: Sequence[Any]) -> Sequence[Any]:
return samples
def worker_preprocessor(self, running_stage: RunningStage) -> _PreProcessor:
return self._create_collate_preprocessors(running_stage)[0]
def device_preprocessor(self, running_stage: RunningStage) -> _PreProcessor:
return self._create_collate_preprocessors(running_stage)[1]
def postprocessor(self, running_stage: RunningStage) -> _PostProcessor:
return self._create_uncollate_postprocessors(running_stage)
@classmethod
def _resolve_function_hierarchy(
cls, function_name, process_obj, stage: RunningStage, object_type: Optional[Type] = None
) -> str:
if object_type is None:
object_type = Preprocess
prefixes = ['']
if stage in (RunningStage.TRAINING, RunningStage.TUNING):
prefixes += ['train', 'fit']
elif stage == RunningStage.VALIDATING:
prefixes += ['val', 'fit']
elif stage == RunningStage.TESTING:
prefixes += ['test']
elif stage == RunningStage.PREDICTING:
prefixes += ['predict']
for prefix in prefixes:
if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix):
return f'{prefix}_{function_name}'
return function_name
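# Illustrative sketch (hypothetical usage): for ``RunningStage.TRAINING`` the 'train' and 'fit'
# prefixes are tried before falling back to the bare hook name, so a preprocess overriding only
# ``train_per_batch_transform`` resolves to that name:
#
#     class CustomPreprocess(Preprocess):
#         def train_per_batch_transform(self, batch):
#             return batch
#
#     DataPipeline._resolve_function_hierarchy(
#         'per_batch_transform', CustomPreprocess(), RunningStage.TRAINING, Preprocess
#     )  # -> 'train_per_batch_transform'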
def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
if on_device:
return self._identity, collate
else:
return collate, self._identity
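# Illustrative sketch (hypothetical usage): the collate function runs in exactly one of the two places,
# ``pipeline`` being any DataPipeline instance:
#
#     pipeline._make_collates(on_device=False, collate=default_collate)  # -> (default_collate, identity): collate in the workers
#     pipeline._make_collates(on_device=True, collate=default_collate)   # -> (identity, default_collate): collate on the device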
def _create_collate_preprocessors(
self,
stage: RunningStage,
collate_fn: Optional[Callable] = None,
) -> Tuple[_PreProcessor, _PreProcessor]:
original_collate_fn = collate_fn
if collate_fn is None:
collate_fn = default_collate
preprocess: Preprocess = self._preprocess_pipeline
prefix: str = _STAGES_PREFIX[stage]
func_names: Dict[str, str] = {
k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
for k in self.PREPROCESS_FUNCS
}
if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=prefix):
collate_fn: Callable = getattr(preprocess, func_names["collate"])
per_batch_transform_overriden: bool = self._is_overriden_recursive(
"per_batch_transform", preprocess, Preprocess, prefix=prefix
)
per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive(
"per_sample_transform_on_device", preprocess, Preprocess, prefix=prefix
)
collate_in_worker_from_transform: Optional[bool] = getattr(
preprocess, f"_{prefix}_collate_in_worker_from_transform"
)
if (
collate_in_worker_from_transform is None and per_batch_transform_overriden
and per_sample_transform_on_device_overriden
):
raise MisconfigurationException(
f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
f'are mutually exclusive for stage {stage}'
)
if isinstance(collate_in_worker_from_transform, bool):
worker_collate_fn, device_collate_fn = self._make_collates(not collate_in_worker_from_transform, collate_fn)
else:
worker_collate_fn, device_collate_fn = self._make_collates(
per_sample_transform_on_device_overriden, collate_fn
)
worker_collate_fn = worker_collate_fn.collate_fn if isinstance(
worker_collate_fn, _PreProcessor
) else worker_collate_fn
assert_contains_tensor = self._is_overriden_recursive(
"to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
)
worker_preprocessor = _PreProcessor(
preprocess, worker_collate_fn,
_Sequential(
preprocess,
getattr(preprocess, func_names['pre_tensor_transform']),
getattr(preprocess, func_names['to_tensor_transform']),
getattr(preprocess, func_names['post_tensor_transform']),
stage,
assert_contains_tensor=assert_contains_tensor,
), getattr(preprocess, func_names['per_batch_transform']), stage
)
worker_preprocessor._original_collate_fn = original_collate_fn
device_preprocessor = _PreProcessor(
preprocess,
device_collate_fn,
getattr(preprocess, func_names['per_sample_transform_on_device']),
getattr(preprocess, func_names['per_batch_transform_on_device']),
stage,
apply_per_sample_transform=device_collate_fn != self._identity,
on_device=True,
)
return worker_preprocessor, device_preprocessor
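# Illustrative sketch (hypothetical usage): the pair built above backs the public
# ``worker_preprocessor`` / ``device_preprocessor`` helpers, e.g.
#
#     pipeline = DataPipeline(CustomPreprocess())                    # hypothetical preprocess
#     worker = pipeline.worker_preprocessor(RunningStage.TRAINING)   # applied inside the DataLoader workers
#     device = pipeline.device_preprocessor(RunningStage.TRAINING)   # applied after transfer to the device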
@staticmethod
def _model_transfer_to_device_wrapper(
func: Callable, preprocessor: _PreProcessor, model: 'Task', stage: RunningStage
) -> Callable:
if not isinstance(func, _StageOrchestrator):
func = _StageOrchestrator(func, model)
func.register_additional_stage(stage, preprocessor)
return func
@staticmethod
def _model_predict_step_wrapper(func: Callable, postprocessor: _PostProcessor, model: 'Task') -> Callable:
if not isinstance(func, _StageOrchestrator):
_original = func
func = _StageOrchestrator(func, model)
func._original = _original
func.register_additional_stage(RunningStage.PREDICTING, postprocessor)
return func
@staticmethod
def _get_dataloader(model: 'Task', loader_name: str) -> Tuple[DataLoader, str]:
dataloader, attr_name = None, None
if hasattr(model, loader_name):
dataloader = getattr(model, loader_name)
attr_name = loader_name
elif model.trainer and hasattr(model.trainer, 'datamodule') and model.trainer.datamodule:
# ``getattr`` cannot resolve a dotted path, so look the loader up on the datamodule itself
dataloader = getattr(model.trainer.datamodule, loader_name, None)
attr_name = f'trainer.datamodule.{loader_name}'
return dataloader, attr_name
@staticmethod
def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None:
"""
This function is used to set the loader to model and/or datamodule
"""
*intermediates, final_name = loader_name.split('.')
curr_attr = model
# Walk the dotted path (e.g. ``trainer.datamodule``) to the object that owns the loader
# attribute, then set the new loader on both that object and the model.
for intermediate in intermediates:
curr_attr = getattr(curr_attr, intermediate)
setattr(curr_attr, final_name, new_loader)
setattr(model, final_name, new_loader)
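# Illustrative sketch (hypothetical usage): with a dotted loader name such as
# 'trainer.datamodule.train_dataloader', the loop above walks ``model.trainer.datamodule``
# and the new loader ends up on both the datamodule and the model:
#
#     DataPipeline._set_loader(model, 'trainer.datamodule.train_dataloader', new_loader)
#     # model.trainer.datamodule.train_dataloader is new_loader
#     # model.train_dataloader is new_loader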
def _attach_preprocess_to_model(
self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False
) -> None:
if not stage:
stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
elif isinstance(stage, RunningStage):
stages = [stage]
for stage in stages:
loader_name = f'{_STAGES_PREFIX[stage]}_dataloader'
dataloader, whole_attr_name = self._get_dataloader(model, loader_name)
if not dataloader:
continue
if isinstance(dataloader, (_PatchDataLoader, Callable)):
dataloader = dataloader()
if dataloader is None:
continue
if isinstance(dataloader, Sequence):
was_seq = True
else:
dataloader = [dataloader]
was_seq = False
for idx, loader in enumerate(dataloader):
# TODO: See lightning for proper reinstantiation of loader
if isinstance(loader, DataLoader):
dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
dl_args['collate_fn'], device_collate_fn = self._create_collate_preprocessors(
stage=stage, collate_fn=dl_args['collate_fn']
)
if isinstance(dl_args["dataset"], IterableDataset):
del dl_args["sampler"]
# don't have to reinstantiate loader if just rewrapping devices (happens during detach)
if not device_transform_only:
del dl_args["batch_sampler"]
loader = type(loader)(**dl_args)
dataloader[idx] = loader
# don't have to set attribute if rewrapping device part (happens during detach)
if not device_transform_only:
if not was_seq:
dataloader = dataloader[0]
if isinstance(dataloader, DataLoader):
dataloader = _PatchDataLoader(dataloader)
self._set_loader(model, whole_attr_name, dataloader)
model.transfer_batch_to_device = (
self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage)
)
def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcessor:
save_per_sample = None
save_fn = None
postprocess: Postprocess = self._postprocess_pipeline
func_names: Dict[str, str] = {
k: self._resolve_function_hierarchy(k, postprocess, stage, object_type=Postprocess)
for k in self.POSTPROCESS_FUNCS
}
# since postprocessing is used only for prediction, we don't have to check the resolution hierarchy here.
if postprocess._save_path:
save_per_sample: bool = self._is_overriden_recursive(
"save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
)
if save_per_sample:
save_per_sample: Callable = getattr(postprocess, func_names["save_sample"])
else:
save_fn: Callable = getattr(postprocess, func_names["save_data"])
return _PostProcessor(
getattr(postprocess, func_names["uncollate"]),
getattr(postprocess, func_names["per_batch_transform"]),
getattr(postprocess, func_names["per_sample_transform"]),
save_fn=save_fn,
save_per_sample=save_per_sample
)
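# Illustrative sketch (hypothetical usage): since postprocessing only runs during prediction,
# the typical call site is
#
#     postprocessor = pipeline.postprocessor(RunningStage.PREDICTING)
#     predictions = postprocessor(model_outputs)  # applies the uncollate / per-batch / per-sample hooks resolved above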
def _attach_postprocess_to_model(self, model: 'Task', stage) -> 'Task':
model.predict_step = self._model_predict_step_wrapper(
model.predict_step, self._create_uncollate_postprocessors(stage), model
)
return model
def _attach_to_model(self, model: 'Task', stage: Optional[RunningStage] = None):
# not necessary to detach. preprocessing and postprocessing for stage will be overwritten.
self._attach_preprocess_to_model(model, stage)
if not stage or stage == RunningStage.PREDICTING:
self._attach_postprocess_to_model(model, stage)
def _detach_from_model(self, model: 'Task', stage: Optional[RunningStage] = None):
self._detach_preprocessing_from_model(model, stage)
if not stage or stage == RunningStage.PREDICTING:
self._detach_postprocess_from_model(model)
def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[RunningStage] = None):
if not stage:
stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
elif isinstance(stage, RunningStage):
stages = [stage]
for stage in stages:
device_collate = None
if isinstance(model.transfer_batch_to_device, _StageOrchestrator):
device_collate = model.transfer_batch_to_device.unregister_stage(stage)
# if no additional func is available: remove the wrapper
if model.transfer_batch_to_device.is_empty():
model.transfer_batch_to_device = model.transfer_batch_to_device.func
if not device_collate:
device_collate = self._identity
loader_name = f'{_STAGES_PREFIX[stage]}_dataloader'
dataloader, whole_attr_name = self._get_dataloader(model, loader_name)
if not dataloader:
continue
if isinstance(dataloader, (_PatchDataLoader, Callable)):
dataloader = dataloader()
if isinstance(dataloader, Sequence):
was_seq = True
else:
dataloader = [dataloader]
was_seq = False
for idx, loader in enumerate(dataloader):
if isinstance(loader, DataLoader):
dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
if isinstance(dl_args['collate_fn'], _PreProcessor):
dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn
if isinstance(dl_args["dataset"], IterableAutoDataset):
del dl_args['sampler']
del dl_args["batch_sampler"]
loader = type(loader)(**dl_args)
dataloader[idx] = loader
if not was_seq:
dataloader = dataloader[0]
if isinstance(dataloader, DataLoader):
dataloader = _PatchDataLoader(dataloader)
self._set_loader(model, whole_attr_name, dataloader)
@staticmethod
def _detach_postprocess_from_model(model: 'Task'):
if hasattr(model.predict_step, '_original'):
# don't delete the predict_step here since we don't know
# if any other pipeline is attached which may rely on this!
model.predict_step = model.predict_step._original
def _generate_callable_auto_dataset(
self, data: Union[Iterable, Any], running_stage: Optional[RunningStage] = None
) -> Callable:
def fn():
return self._generate_auto_dataset(data, running_stage=running_stage)
return fn
def _generate_auto_dataset(
self,
data: Union[Iterable, Any],
running_stage: Optional[RunningStage] = None,
use_iterable_auto_dataset: bool = False
) -> Union[AutoDataset, IterableAutoDataset]:
if use_iterable_auto_dataset:
return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage)
return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage)
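# Illustrative sketch (hypothetical usage): raw predict data is typically wrapped like
#
#     dataset = pipeline._generate_auto_dataset(["img1.png", "img2.png"], RunningStage.PREDICTING)
#     # -> an AutoDataset bound to this pipeline; pass use_iterable_auto_dataset=True for streaming sources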
def to_dataloader(
self, data: Union[Iterable, Any], auto_collate: Optional[bool] = None, **loader_kwargs
) -> DataLoader:
if 'collate_fn' in loader_kwargs:
if auto_collate:
raise MisconfigurationException('auto_collate and collate_fn are mutually exclusive')
else:
if auto_collate is None:
auto_collate = True
collate_fn = self.worker_collate_fn
if collate_fn:
loader_kwargs['collate_fn'] = collate_fn
else:
loader_kwargs['collate_fn'] = default_collate if auto_collate else default_convert
return DataLoader(self._generate_auto_dataset(data), **loader_kwargs)
def __str__(self) -> str:
preprocess: Preprocess = self._preprocess_pipeline
postprocess: Postprocess = self._postprocess_pipeline
return f"{self.__class__.__name__}(preprocess={preprocess}, postprocess={postprocess})"
class _StageOrchestrator:
# This is used to map ``SANITY_CHECKING`` to ``VALIDATING``
internal_mapping = {
RunningStage.TRAINING: RunningStage.TRAINING,
RunningStage.SANITY_CHECKING: RunningStage.VALIDATING,
RunningStage.VALIDATING: RunningStage.VALIDATING,
RunningStage.TESTING: RunningStage.TESTING,
RunningStage.PREDICTING: RunningStage.PREDICTING,
RunningStage.TUNING: RunningStage.TUNING
}
def __init__(self, func_to_wrap: Callable, model: 'Task') -> None:
self.func = func_to_wrap
self._stage_mapping = {k: None for k in RunningStage}
self.model = weakref.proxy(model)
functools.update_wrapper(self, self.func)
def __call__(self, *args, **kwargs):
outputs = self.func(*args, **kwargs)
internal_running_state = self.internal_mapping[self.model.trainer._running_stage]
additional_func = self._stage_mapping.get(internal_running_state, None)
if additional_func:
outputs = additional_func(outputs)
return outputs
def register_additional_stage(self, stage: RunningStage, stage_func: Optional[Callable] = None):
assert stage_func is None or callable(stage_func)
self._stage_mapping[stage] = stage_func.to(self.model.device, self.model.dtype)
def unregister_stage(self, stage: RunningStage):
ret_val = self._stage_mapping.pop(stage)
self._stage_mapping[stage] = None
if ret_val:
ret_val = ret_val.cpu()
return ret_val
def is_empty(self):
return all([v is None for v in self._stage_mapping.values()]) or not self._stage_mapping
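# Illustrative sketch (hypothetical usage): once ``transfer_batch_to_device`` has been wrapped,
# the orchestrator first calls the original hook and then applies whichever preprocessor was
# registered for the trainer's current stage:
#
#     wrapped = _StageOrchestrator(model.transfer_batch_to_device, model)
#     wrapped.register_additional_stage(RunningStage.TRAINING, device_preprocessor)
#     batch = wrapped(batch, device)  # original transfer, then the TRAINING device preprocessor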