Datapipeline poc #130

Changes from 6 commits
@@ -17,10 +17,13 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch import nn

from flash.core.data import DataModule, DataPipeline
from flash.core.data import DataModule
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline
from flash.data.postprocessing_pipeline import PostProcessingPipeline


def predict_context(func: Callable) -> Callable:
@@ -31,13 +34,16 @@ def predict_context(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(self, *args, **kwargs) -> Any:
    grad_enabled = torch.is_grad_enabled()
    is_training = self.training
    self.eval()
    torch.set_grad_enabled(False)

    result = func(self, *args, **kwargs)

    self.train()
    torch.set_grad_enabled(True)
    if is_training:
        self.train()
    torch.set_grad_enabled(grad_enabled)
    return result

return wrapper
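Taken together, the new wrapper no longer forces the module back into train mode and no longer unconditionally re-enables gradients; it restores whatever state was active before the call. A standalone sketch of the resulting decorator (with the old, removed lines dropped), plus a small hypothetical module to show the behaviour:

```python
import functools
from typing import Any, Callable

import torch
from torch import nn


def predict_context(func: Callable) -> Callable:
    """Sketch of the resulting wrapper after this change: prior state is restored."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs) -> Any:
        grad_enabled = torch.is_grad_enabled()  # remember autograd state
        is_training = self.training             # remember train/eval mode
        self.eval()
        torch.set_grad_enabled(False)

        result = func(self, *args, **kwargs)

        if is_training:
            self.train()
        torch.set_grad_enabled(grad_enabled)
        return result

    return wrapper


class TinyModel(nn.Module):
    """Hypothetical module, only used to demonstrate the decorator."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    @predict_context
    def predict(self, x):
        return self.layer(x)


model = TinyModel().eval()          # start in eval mode with grads enabled
model.predict(torch.randn(1, 4))
assert not model.training           # eval mode was kept, not forced back to train()
assert torch.is_grad_enabled()      # global autograd state untouched
```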
@@ -63,6 +69,8 @@ def __init__(
    learning_rate: float = 5e-5,
):
    super().__init__()
    self._last_trainer_kwargs = {}

    if model is not None:
        self.model = model
    self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
@@ -144,7 +152,7 @@ def predict(

    """
    # enable x to be a path to a folder
    if isinstance(x, str):
    if isinstance(x, str) and os.path.isdir(x):
        files = os.listdir(x)
        files = [os.path.join(x, y) for y in files]
        x = files
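With the added `os.path.isdir` guard, a string is only expanded into a list of files when it actually names a directory; any other string falls through to the normal single-input handling. A hypothetical call (the paths are made up):

```python
# hypothetical usage of Task.predict with the directory guard in place
folder_preds = model.predict("./images/test")            # directory: every file inside is predicted
single_pred = model.predict("./images/test/dog_01.jpg")  # plain path: treated as a single input
```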
@@ -163,22 +171,36 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
def data_pipeline(self) -> DataPipeline:
    # we need to save the pipeline in case this class
    # is loaded from checkpoint and used to predict
    if not self._data_pipeline:
        try:
            # datamodule pipeline takes priority
            self._data_pipeline = self.trainer.datamodule.data_pipeline
        except AttributeError:
            self._data_pipeline = self.default_pipeline()
    return self._data_pipeline
    return self._get_pipeline('data')

@data_pipeline.setter
def data_pipeline(self, data_pipeline: DataPipeline) -> None:
    self._data_pipeline = data_pipeline

@property
def postprocessing_pipeline(self) -> PostProcessingPipeline:
    return self._get_pipeline('postprocessing')

Review comment: "TODO: Missing setter"
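The TODO above points out that `postprocessing_pipeline` gets a property but, unlike `data_pipeline`, no setter, even though `new_predict` further down assigns to it. The symmetric setter would be a two-liner; a sketch of the obvious fix, not code from the PR:

```python
@postprocessing_pipeline.setter
def postprocessing_pipeline(self, postprocessing_pipeline: PostProcessingPipeline) -> None:
    self._postprocessing_pipeline = postprocessing_pipeline
```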
def _get_pipeline(self, pipeline_type: str):
    pipeline_attr_name = f'{pipeline_type}_pipline'

    if getattr(self, '_' + pipeline_attr_name) is not None:
        return getattr(self, '_' + pipeline_attr_name)

    if self.datamodule is not None and hasattr(self, pipeline_attr_name):
        return getattr(self.datamodule, pipeline_attr_name)

    if self.trainer is not None and hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None:
        if hasattr(self.trainer.datamodule,
                   pipeline_attr_name) and getattr(self.trainer.datamodule, pipeline_attr_name is not None):

Review discussion:
- "When can …"
- "It can't, that should be outside the brackets :)"

            return getattr(self.trainer.datamodule, pipeline_attr_name is not None)

    return None
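Following the thread above: the `is not None` comparison has ended up inside the `getattr` calls, and the attribute name is spelled `_pipline`. Below is a corrected standalone sketch of the intended lookup order (model attribute, then a datamodule attached to the model, then the trainer's datamodule). It reflects my reading of the code, not code from the PR; in particular, the second check reads the datamodule rather than `self`, which is presumably what was intended:

```python
def _get_pipeline(self, pipeline_type: str):
    # e.g. 'data' -> 'data_pipeline', 'postprocessing' -> 'postprocessing_pipeline'
    pipeline_attr_name = f'{pipeline_type}_pipeline'
    private_attr_name = '_' + pipeline_attr_name

    # 1. a pipeline set explicitly on the model wins
    if getattr(self, private_attr_name, None) is not None:
        return getattr(self, private_attr_name)

    # 2. then a pipeline defined on a datamodule attached to the model
    datamodule = getattr(self, 'datamodule', None)
    if datamodule is not None and getattr(datamodule, pipeline_attr_name, None) is not None:
        return getattr(datamodule, pipeline_attr_name)

    # 3. finally the datamodule hanging off the trainer
    trainer = getattr(self, 'trainer', None)
    if trainer is not None and getattr(trainer, 'datamodule', None) is not None:
        if getattr(trainer.datamodule, pipeline_attr_name, None) is not None:
            return getattr(trainer.datamodule, pipeline_attr_name)

    return None
```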
@staticmethod
def default_pipeline() -> DataPipeline:
def default_data_pipeline() -> DataPipeline:

Review discussion:
- "I think here we should take the data-type default one? For example, collate for text isn't the same as for vision."
- "Yeah, but that's why each task would have its own default."

    """Pipeline to use when there is no datamodule or it has not defined its pipeline"""
    return DataModule.default_pipeline()
    return DataModule.default_data_pipeline()

Review comment: "TODO: also do this for postprocessing"
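A hedged sketch of the direction that exchange suggests: the base implementation above stays as the fallback, and each task overrides the static method with a domain-specific default. The class and pipeline names below are illustrative, not from this PR:

```python
class ImageClassifier(Task):  # hypothetical task subclass
    @staticmethod
    def default_data_pipeline() -> DataPipeline:
        # a vision task would install image-specific loading / collation here
        return ImageClassificationDataPipeline()  # hypothetical pipeline class
```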
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    self.data_pipeline = checkpoint["pipeline"]

@@ -188,3 +210,111 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

def configure_finetune_callback(self):
    return []
### THE FOLLOWING IS A POC FOR DISTRIBUTED PREDICTION
def on_predict_start(self):

Review discussion:
- "@tchaton does it make sense to have a hook like that? (I think we need to revisit Lightning hooks in general for all stages.)"
- "Where is it called? I guess we could add a hook for predict. Needs a bit more exploration there."
- "Similar to on_fit_start, it would be called immediately after Trainer.predict was called."

    # TODO: Add hook to lightning Trainer
    if self.data_pipeline is not None:
        self.data_pipeline._attach_to_model(self)

    if self.postprocessing_pipeline is not None:
        self.postprocessing_pipeline._attach_to_model(self)
def predict_step(self, batch, batch_idx):
    # TODO: Move lightning predict loop from predict to predict_step

Review discussion:
- "@tchaton You mentioned the prediction API is not final in Lightning, right? IMO it makes sense to rename it to training_step within the …"
- "Yes, I am good with that :)"
- "That was my initial proposal, but it got turned down because users are expected to write … but not …, because nobody calls …"

    if isinstance(batch, (tuple, list)) and len(batch) == 2:
        x, y = batch

Review discussion:
- "You shouldn't have y when doing predictions."
- "Sometimes you have it when using the same loader/dataset as for training. This is just because it was already there in the old predict logic and I didn't want to remove already existing features."
- "Yeah, this should be here in the future."

    else:
        x, y = batch, None

    return self(x)
def new_predict(
    self,
    x: Any,
    skip_collate: Optional[bool] = None,
    data_pipeline: Optional[DataPipeline] = None,
    postprocessing_pipeline: Optional[PostProcessingPipeline] = None,
    data_loader_kwargs: Optional[dict] = None,
    **trainer_kwargs
):
    if data_pipeline is not None:
        self.data_pipeline = data_pipeline
    if postprocessing_pipeline is not None:
        self.postprocessing_pipeline = postprocessing_pipeline

    trainer = self._create_trainer('predict', **trainer_kwargs)

    if data_loader_kwargs is None:
        data_loader_kwargs = {}

    if 'num_workers' not in data_loader_kwargs:
        # leave one for main process
        data_loader_kwargs['num_workers'] = os.cpu_count() - 1

    auto_collate = None
    if 'collate_fn' not in data_loader_kwargs:
        auto_collate = not skip_collate

    dl = self.data_pipeline._generate_loader(x, auto_collate=auto_collate, **data_loader_kwargs)

    return trainer.predict(self, dl)
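A hypothetical call showing how the pieces above fit together: the pipelines are optional overrides, `data_loader_kwargs` is forwarded to the generated loader, and any remaining keyword arguments configure the internally created `pytorch_lightning.Trainer`. The model, pipeline, and paths are made up:

```python
# hypothetical usage of the POC prediction API
predictions = model.new_predict(
    "./data/unlabelled",                                       # folder, list of files, tensors, ...
    data_pipeline=my_data_pipeline,                            # optional override of the attached/default pipeline
    data_loader_kwargs={"batch_size": 64, "num_workers": 4},   # forwarded to the generated DataLoader
    gpus=2,                                                    # forwarded to the Trainer for distributed prediction
)
```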
def _create_trainer(self, stage: str, **trainer_kwargs):
    # TODO: Also use these for trainer creation in training?
    # TODO: Have default trainer kwargs per task?
    _trainer_kwargs = {}
    # TODO: Adjust this to trainer running stage from pl

Review discussion on lines +264 to +267:
- "Any thoughts on that, @tchaton @aribornstein?"
- "I like it. We had a similar function in a previous iteration of predict."
- "I think we had something like that for training in the beginning. The only downside I see is that it hides away the Lightning Trainer."
- "Maybe we could provide an optional argument for the user to provide a trainer, in case they don't want to use the default trainer?"
- "You mean another trainer class?"

    if stage == 'predict':
        _trainer_kwargs.update(logger=None)

    if not 'gpus' in trainer_kwargs and not 'tpu_cores' in trainer_kwargs:
        _trainer_kwargs['gpus'], _trainer_kwargs['tpu_cores'] = self._parse_default_devices()

    _trainer_kwargs.update(trainer_kwargs)

    if not hasattr(self, 'trainer') or self.trainer is None or self._last_trainer_kwargs != trainer_kwargs:
        self._last_trainer_kwargs = _trainer_kwargs
        self.trainer = None
        return Trainer(**_trainer_kwargs)

    else:
        return self.trainer
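One way the "let the user provide a trainer" idea from the thread could look; a sketch under that assumption, not part of the PR, with the caching of the last trainer omitted for brevity:

```python
from typing import Optional

from pytorch_lightning import Trainer


def _create_trainer(self, stage: str, trainer: Optional[Trainer] = None, **trainer_kwargs):
    # hypothetical variant: an explicitly passed Trainer always wins over the defaults,
    # so the Lightning Trainer is not hidden from users who want full control
    if trainer is not None:
        return trainer

    _trainer_kwargs = {}
    if stage == 'predict':
        _trainer_kwargs['logger'] = None
    if 'gpus' not in trainer_kwargs and 'tpu_cores' not in trainer_kwargs:
        _trainer_kwargs['gpus'], _trainer_kwargs['tpu_cores'] = self._parse_default_devices()
    _trainer_kwargs.update(trainer_kwargs)
    return Trainer(**_trainer_kwargs)
```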
def _parse_default_devices(self):
    gpus = None,
    tpu_cores = None

    if torch.cuda.is_available():
        gpus = torch.cuda.device_count()

    # TODO: Add logic for automatted TPU device parsing

    return gpus, tpu_cores
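Standalone, the device parsing reads as below; note that the trailing comma after `None` in the diff above makes `gpus` a one-element tuple, which is presumably unintended. TPU detection is left open, as the TODO says:

```python
import torch


def _parse_default_devices():
    gpus = None        # no trailing comma: `None,` would create a tuple
    tpu_cores = None

    if torch.cuda.is_available():
        gpus = torch.cuda.device_count()

    # TODO (from the PR): automated TPU device parsing
    return gpus, tpu_cores
```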
def serve(

Review discussion:
- "I think this serve function is confusing. Serve means different things depending on the context and should not be used to perform prediction. It is more like a mode."
- "Yes, I thought we wanted to reserve serve for actually setting up a hosting server for the model."
- "Okay, so how do we want to handle it? I mean, technically the user can set those defaults manually as well, but I feel like this is not what Flash should be aiming for."

    self,
    x,
    skip_collate: Optional[bool] = None,
    data_pipeline: Optional[DataPipeline] = None,
    postprocessing_pipeline: Optional[PostProcessingPipeline] = None,
    data_loader_kwargs: Optional[dict] = None,
    **trainer_kwargs
):
    """Serving for Production. Basically same as prediction, just other defaults (no workers, no distributed prediction)
    """

    if data_loader_kwargs is None:
        data_loader_kwargs = {}
    data_loader_kwargs['num_workers'] = 0

    trainer_kwargs['num_gpus'] = [0] if torch.cuda.is_available() else 0
    # TODO: tpu_cores
    return self.new_predict(
        x,
        skip_collate=skip_collate,
        data_pipeline=data_pipeline,
        postprocessing_pipeline=postprocessing_pipeline,
        data_loader_kwargs=data_loader_kwargs,
        **trainer_kwargs
    )
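As written, `serve` is `new_predict` with single-process, single-device defaults rather than an actual hosting server, which is what the review thread above takes issue with. A hypothetical side-by-side of the two entry points (model and paths are made up):

```python
# hypothetical comparison of the two entry points
model.new_predict("./data/unlabelled")  # defaults: cpu_count() - 1 loader workers, all detected GPUs
model.serve("./data/unlabelled")        # defaults: no loader workers, pinned to a single local device
```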
Review discussion on the pipeline naming:
- "I find it a bit confusing to have DataPipeline and PostProcessingPipeline, as people might expect a PreprocessingPipeline. Worth iterating on this one."
- "Yes, I thought so as well. Basically I named it data_pipeline since it does loading + preprocessing. But I'm fine with changing it as well."
- "It also does postprocessing, with after_uncollate."