This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Datapipeline poc #130

Closed · wants to merge 7 commits
156 changes: 143 additions & 13 deletions flash/core/model.py
@@ -17,10 +17,13 @@

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch import nn

from flash.core.data import DataModule, DataPipeline
from flash.core.data import DataModule
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline
from flash.data.postprocessing_pipeline import PostProcessingPipeline


def predict_context(func: Callable) -> Callable:
@@ -31,13 +34,16 @@ def predict_context(func: Callable) -> Callable:

@functools.wraps(func)
def wrapper(self, *args, **kwargs) -> Any:
grad_enabled = torch.is_grad_enabled()
is_training = self.training
self.eval()
torch.set_grad_enabled(False)

result = func(self, *args, **kwargs)

self.train()
torch.set_grad_enabled(True)
if is_training:
self.train()
torch.set_grad_enabled(grad_enabled)
return result

return wrapper
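For context, the decorator is intended to wrap a task's user-facing predict entry point. A minimal usage sketch (the MyTask class below is illustrative, not part of this PR):

import pytorch_lightning as pl

from flash.core.model import predict_context

class MyTask(pl.LightningModule):

    @predict_context
    def predict(self, x):
        # runs with eval() and gradients disabled; the wrapper restores
        # the previous training/grad state once prediction returns
        return self(x)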
@@ -63,6 +69,8 @@ def __init__(
learning_rate: float = 5e-5,
):
super().__init__()
self._last_trainer_kwargs = {}

if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
@@ -144,7 +152,7 @@ def predict(

"""
# enable x to be a path to a folder
if isinstance(x, str):
if isinstance(x, str) and os.path.isdir(x):
files = os.listdir(x)
files = [os.path.join(x, y) for y in files]
x = files
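With this change, predict also accepts a directory and expands it into the contained file paths before building the pipeline. An illustrative call (the path is hypothetical):

predictions = model.predict("data/test_images/")  # expands to the files inside the folder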
@@ -163,22 +171,36 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
def data_pipeline(self) -> DataPipeline:
# we need to save the pipeline in case this class
# is loaded from checkpoint and used to predict
if not self._data_pipeline:
try:
# datamodule pipeline takes priority
self._data_pipeline = self.trainer.datamodule.data_pipeline
except AttributeError:
self._data_pipeline = self.default_pipeline()
return self._data_pipeline
return self._get_pipeline('data')

@data_pipeline.setter
def data_pipeline(self, data_pipeline: DataPipeline) -> None:
Contributor: I find it a bit confusing to have DataPipeline and PostProcessingPipeline, as people might expect a PreprocessingPipeline. Worth iterating on this one.

Member Author: Yes, I thought so as well. Basically I named it data_pipeline since it does loading + preprocessing, but I'm fine with changing it as well.

Contributor: It also does postprocessing, with after_uncollate.

self._data_pipeline = data_pipeline

@property
def postprocessing_pipeline(self) -> PostProcessingPipeline:
return self._get_pipeline('postprocessing')

Member Author: TODO: Missing setter.

def _get_pipeline(self, pipeline_type: str):
pipeline_attr_name = f'{pipeline_type}_pipline'
Contributor: _pipline typo?


if getattr(self, '_' + pipeline_attr_name) is not None:
return getattr(self, '_' + pipeline_attr_name)

if self.datamodule is not None and hasattr(self, pipeline_attr_name):
return getattr(self.datamodule, pipeline_attr_name)

if self.trainer is not None and hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None:
if hasattr(self.trainer.datamodule,
pipeline_attr_name) and getattr(self.trainer.datamodule, pipeline_attr_name is not None):
Contributor: When can pipeline_attr_name be None?

Member Author: It can't; that should be outside the brackets :)

return getattr(self.trainer.datamodule, pipeline_attr_name is not None)

return None
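Folding in the two review points above (the _pipline typo and the misplaced is-not-None check), the lookup could read roughly as follows; a sketch for discussion, not the code in this PR (it also checks the attribute on the datamodule rather than on self):

def _get_pipeline(self, pipeline_type: str):
    # e.g. 'data' -> '_data_pipeline', 'postprocessing' -> '_postprocessing_pipeline'
    pipeline_attr_name = f'{pipeline_type}_pipeline'
    private_attr_name = '_' + pipeline_attr_name

    # 1. a pipeline explicitly set on the model wins
    if getattr(self, private_attr_name, None) is not None:
        return getattr(self, private_attr_name)

    # 2. next, a datamodule attached directly to the model
    if self.datamodule is not None and getattr(self.datamodule, pipeline_attr_name, None) is not None:
        return getattr(self.datamodule, pipeline_attr_name)

    # 3. finally, the datamodule attached to the trainer
    if self.trainer is not None and getattr(self.trainer, 'datamodule', None) is not None:
        if getattr(self.trainer.datamodule, pipeline_attr_name, None) is not None:
            return getattr(self.trainer.datamodule, pipeline_attr_name)

    return None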

@staticmethod
def default_pipeline() -> DataPipeline:
def default_data_pipeline() -> DataPipeline:
Contributor: I think here we should take the data-type-specific default? For example, collate for text isn't the same as for vision.

Member Author: Yeah, but that's why each task would have its own default.

"""Pipeline to use when there is no datamodule or it has not defined its pipeline"""
return DataModule.default_pipeline()
return DataModule.default_data_pipeline()
Member Author: TODO: also do this for postprocessing.
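Following the reviewer's point that defaults should be data-type specific, each task subclass can simply override the static method. A hypothetical sketch (the subclass and pipeline names are illustrative, not part of this PR):

# Hypothetical: a vision task shipping its own default pipeline.
class ImageClassifier(Task):

    @staticmethod
    def default_data_pipeline() -> DataPipeline:
        # vision-specific loading/collation would live in this pipeline
        return VisionDataPipeline(image_size=224)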


def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.data_pipeline = checkpoint["pipeline"]
@@ -188,3 +210,111 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

def configure_finetune_callback(self):
return []

### THE FOLLOWING IS A POC FOR DISTRIBUTED PREDICTION
def on_predict_start(self):
Member Author: @tchaton does it make sense to have a hook like that? (I think we need to revisit lightning hooks in general for all stages.)

Contributor: Where is it called? I guess we could add a hook for predict. Needs a bit more exploration there.

Member Author: Similar to on_fit_start, it would be called immediately after Trainer.predict was called.

# TODO: Add hook to lightning Trainer
if self.data_pipeline is not None:
self.data_pipeline._attach_to_model(self)

if self.postprocessing_pipeline is not None:
self.postprocessing_pipeline._attach_to_model(self)

def predict_step(self, batch, batch_idx):
# TODO: Move lightning predict loop from predict to predict_step
Member Author: @tchaton You mentioned the prediction API is not final in lightning, right?

IMO it makes sense to rename it to predict_step within the LightningModule, since (similar to training_step etc.) it only runs prediction for one batch at a time, making it more of a step (plus we can use the predict keyword independently :) )

Contributor: Yes, I am good with that :)

Contributor: That was my initial proposal, but it got turned down because users are expected to write model.predict(...) but not model.predict_step(), because nobody calls model.training_step.

if isinstance(batch, (tuple, list)) and len(batch) == 2:
x, y = batch
Contributor: You shouldn't have y when doing predictions.

Member Author: Sometimes you do, when using the same loader/data as for training. This is just because it was already there in the old predict logic and I didn't want to remove existing features.
else:
x, y = batch, None

return self(x)

def new_predict(
self,
x: Any,
skip_collate: Optional[bool] = None,
data_pipeline: Optional[DataPipeline] = None,
postprocessing_pipeline: Optional[PostProcessingPipeline] = None,
data_loader_kwargs: Optional[dict] = None,
**trainer_kwargs
):
if data_pipeline is not None:
self.data_pipeline = data_pipeline
if postprocessing_pipeline is not None:
self.postprocessing_pipeline = postprocessing_pipeline

trainer = self._create_trainer('predict', **trainer_kwargs)

if data_loader_kwargs is None:
data_loader_kwargs = {}

if 'num_workers' not in data_loader_kwargs:
# leave one for main process
data_loader_kwargs['num_workers'] = os.cpu_count() - 1

auto_collate = None
if 'collate_fn' not in data_loader_kwargs:
auto_collate = not skip_collate

dl = self.data_pipeline._generate_loader(x, auto_collate=auto_collate, **data_loader_kwargs)

return trainer.predict(self, dl)
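For reference, a hypothetical call to the proposed entry point (checkpoint path and kwargs are illustrative):

model = ImageClassifier.load_from_checkpoint("image_classifier.ckpt")
predictions = model.new_predict(
    "data/test_images/",                    # folder, list of files, or tensors
    data_loader_kwargs={"batch_size": 32},  # forwarded to the generated DataLoader
    gpus=1,                                 # forwarded to the internally created Trainer
)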

def _create_trainer(self, stage: str, **trainer_kwargs):
# TODO: Also use these for trainer creation in training?
# TODO: Have default trainer kwargs per task?
_trainer_kwargs = {}
# TODO: Adjust this to trainer running stage from pl
Comment on lines +264 to +267
Member Author: Any thoughts on that @tchaton @aribornstein?

Contributor: I like it. We had a similar function in a previous iteration of predict.

Member Author: I think we had something like that for training in the beginning. The only downside I see is that it hides away the Lightning Trainer.

Contributor: Maybe we could provide an optional argument for the user to pass their own trainer, in case they don't want to use the default one?

Member Author: You mean another trainer class?
Actually, the trainer class is something I'd hardcode here tbh.
This is one of the very fundamental lightning aspects, and I feel that if a user wants to change it, they should either look into customization with callbacks/plugins or subclass the task to override it here directly.

if stage == 'predict':
_trainer_kwargs.update(logger=None)

if not 'gpus' in trainer_kwargs and not 'tpu_cores' in trainer_kwargs:
_trainer_kwargs['gpus'], _trainer_kwargs['tpu_cores'] = self._parse_default_devices()

_trainer_kwargs.update(trainer_kwargs)

if not hasattr(self, 'trainer') or self.trainer is None or self._last_trainer_kwargs != trainer_kwargs:
self._last_trainer_kwargs = _trainer_kwargs
self.trainer = None
return Trainer(**_trainer_kwargs)

else:
return self.trainer
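As an aside, the optional-trainer idea from the thread above could look roughly like this; a sketch of the suggestion only, since the author leans towards hardcoding the Trainer here (the trainer argument is hypothetical and not part of this PR):

def _create_trainer(self, stage: str, trainer: Optional[Trainer] = None, **trainer_kwargs):
    if trainer is not None:
        # caller supplied their own Trainer; skip the default construction
        return trainer
    ...  # otherwise fall back to the kwargs-based construction above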

def _parse_default_devices(self):
gpus = None
tpu_cores = None

if torch.cuda.is_available():
gpus = torch.cuda.device_count()

# TODO: Add logic for automated TPU device parsing

return gpus, tpu_cores

def serve(
Contributor: I think this serve function is confusing. Serve means different things depending on the context and should not be used to perform prediction. It is more like a mode.

Contributor: Yes, I thought we wanted to reserve serve for actually setting up a hosting server for the model.

Member Author: Okay, so how do we want to handle it? Technically the user can set those defaults manually as well, but I feel like this is not what flash should be aiming for.

self,
x,
skip_collate: Optional[bool] = None,
data_pipeline: Optional[DataPipeline] = None,
postprocessing_pipeline: Optional[PostProcessingPipeline] = None,
data_loader_kwargs: Optional[dict] = None,
**trainer_kwargs
):
"""Serving for Production. Basically same as prediction, just other defaults (no workers, no distributed prediction)
"""

if data_loader_kwargs is None:
data_loader_kwargs = {}
data_loader_kwargs['num_workers'] = 0

trainer_kwargs['gpus'] = [0] if torch.cuda.is_available() else 0
# TODO: tpu_cores
return self.new_predict(
x,
skip_collate=skip_collate,
data_pipeline=data_pipeline,
postprocessing_pipeline=postprocessing_pipeline,
data_loader_kwargs=data_loader_kwargs,
**trainer_kwargs
)