[PoC] Add KFold - External Loop. #8715

Closed
wants to merge 19 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))


- Added `KFoldLoop` example ([#8715](https://github.com/PyTorchLightning/pytorch-lightning/pull/8715))


### Changed

- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))
162 changes: 162 additions & 0 deletions pl_examples/loops_customisation/k_fold.py
@@ -0,0 +1,162 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
WARNING: Loop customization is in `pre-alpha` release and the API is likely to change quite a lot!
Please open issues with your particular requests so the Lightning Team can progressively converge on a great API.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import Dataset

from pytorch_lightning import _logger as log
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.loops.base import ExternalLoop
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.boring_model import BoringDataModule, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException

seed_everything(42)
Review comment (Member), suggested change: remove `seed_everything(42)`
rather not seed anything globally
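
One way to honor this suggestion, sketched here as an assumption rather than part of the PR, is to seed only inside the example's entry point instead of at import time:

if __name__ == "__main__":
    # Seed only when the example script is executed directly, not on import.
    seed_everything(42)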



class SplitDataset(Dataset):
"""SplitDataset is used to create Dataset Subset using indices.
Args:
dataset: A dataset to be splitted
indices: List of indices to expose from the dataset
use_duplicated_indices: Whether to allow duplicated indices.
Example::
split_ds = SplitDataset(dataset, indices=[10, 14, 25])
split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
"""

_INTERNAL_KEYS = ("dataset", "indices", "data")

def __init__(self, dataset: Any, indices: Optional[List[int]] = None, use_duplicated_indices: bool = False) -> None:
if indices is None:
indices = []
if not isinstance(indices, list):
raise MisconfigurationException("indices should be a list")

if use_duplicated_indices:
indices = list(indices)
else:
indices = list(np.unique(indices))

if np.max(indices) >= len(dataset) or np.min(indices) < 0:
raise MisconfigurationException(f"`indices` should be within [0, {len(dataset) -1}].")

self.dataset = dataset
self.indices = indices

def __getattr__(self, key: str):
if key not in self._INTERNAL_KEYS:
return self.dataset.__getattribute__(key)
raise AttributeError

def __setattr__(self, name: str, value: Any) -> None:
if name in self._INTERNAL_KEYS:
self.__dict__[name] = value
else:
setattr(self.dataset, name, value)

def __getitem__(self, index: int) -> Any:
return self.dataset[self.indices[index]]

def __len__(self) -> int:
return len(self.indices)


@dataclass
class KFoldLoop(ExternalLoop):

num_folds: int
num_epochs: int = 10
best_model_paths: List[str] = field(default_factory=list)
restarting: bool = False

@staticmethod
def loop_base_callback() -> Type[Callback]:
class BaseKFoldCallback(Callback):
@rank_zero_only
def on_fold_start(self, trainer, pl_module, counter):
"""Override with your own logic"""

return BaseKFoldCallback
Review comment (Member) on lines +148 to +154:
can't we define this outside this class but in the file namespace?
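
A sketch of the reviewer's suggestion, with the base callback defined at module level instead of inside `loop_base_callback` (an illustration, not part of the PR):

class BaseKFoldCallback(Callback):
    """Base callback for k-fold loops; subclasses override `on_fold_start`."""

    @rank_zero_only
    def on_fold_start(self, trainer, pl_module, counter):
        """Override with your own logic"""

`KFoldCallback` below could then inherit from `BaseKFoldCallback` directly instead of calling `KFoldLoop.loop_base_callback()`.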


@property
def done(self) -> bool:
return self.current_fold >= self.num_folds

def reset(self) -> None:
if not self.restarting:
self.current_fold = 0
self.set_max_epochs(self.num_epochs)

def generate_fold(self, dataloader_kwargs: Dict[str, Any], stage: str):
dataset = dataloader_kwargs["dataset"]
kfold = KFold(self.num_folds, random_state=42, shuffle=True)
train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold]
if stage == "train":
dataloader_kwargs["dataset"] = SplitDataset(dataset, train_indices.tolist())
else:
dataloader_kwargs["dataset"] = SplitDataset(dataset, validation_indices.tolist())
dataloader_kwargs["sampler"].data_source = dataloader_kwargs["dataset"]
return dataloader_kwargs

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
# temporary hack
self.trainer.datamodule.setup("fit")

def on_advance_start(self):
self.reload_train_dataloader(self.generate_fold)
self.reload_val_dataloaders(self.generate_fold)
self.trainer.call_hook("on_fold_start", self.current_fold)
self.lightning_module.reset_parameters()

def advance(self):
return self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader)

def on_advance_end(self) -> None:
self.current_fold += 1
self.increment_max_epochs(self.num_epochs)
# stored best weight path for this fold
self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path)

def on_save_checkpoint(self) -> Dict:
return {"current_fold": self.current_fold}

def on_load_checkpoint(self, state_dict) -> None:
self.current_fold = state_dict["current_fold"]


class KFoldCallback(KFoldLoop.loop_base_callback()):

"""This callback demonstrates how to implement your own callback API."""

@rank_zero_only
def on_fold_start(self, trainer, pl_module, counter):
log.info(f"Starting to train on fold {counter}")


loop = KFoldLoop(5)
model = BoringModel()
datamodule = BoringDataModule()
trainer = Trainer(callbacks=KFoldCallback())
trainer.run_loop(model, datamodule=datamodule, loop=loop)
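
Note that `on_advance_start` above calls `self.lightning_module.reset_parameters()`, which `BoringModel` does not provide out of the box; the example assumes the model defines such a method. A minimal sketch of what it could look like (a hypothetical `ResettableBoringModel`, not part of the PR):

import torch.nn as nn

from pytorch_lightning.utilities.boring_model import BoringModel


class ResettableBoringModel(BoringModel):
    def reset_parameters(self) -> None:
        # Re-initialize every linear layer so each fold starts from fresh weights.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.reset_parameters()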
71 changes: 70 additions & 1 deletion pytorch_lightning/loops/base.py
@@ -13,9 +13,11 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from functools import partial
from typing import Any, Callable, Dict, List, Optional

from deprecate import void
from torch.utils.data.dataloader import DataLoader
from torchmetrics import Metric

import pytorch_lightning as pl
@@ -238,3 +240,70 @@ def _load_from_state_dict(

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True


class ExternalLoop(Loop):
"""This Loop is meant wrap trainer calls"""

@property
def trainer(self) -> Optional["pl.Trainer"]:
return self._trainer

@trainer.setter
def trainer(self, trainer: "pl.Trainer"):
"""Connects this loop's trainer and its children"""
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
if hasattr(self, "_trainer") and isinstance(self._trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be attached to only 1 `Trainer` instance."
)
self._trainer = trainer
for v in self.__dict__.values():
if isinstance(v, Loop):
v.trainer = trainer

def set_max_epochs(self, max_epochs: int):
self.trainer.fit_loop.max_epochs = max_epochs

def increment_max_epochs(self, max_epochs: int):
self.trainer.fit_loop.max_epochs += max_epochs

def set_max_steps(self, max_steps: int):
self.trainer.fit_loop.max_steps = max_steps

def increment_max_steps(self, max_steps: int):
self.trainer.fit_loop.max_steps += max_steps

def reload_train_dataloader(self, user_function: Optional[Callable] = None) -> DataLoader:
self.trainer.train_dataloader = None
self.trainer.reset_train_dataloader(self.trainer.lightning_module)
if user_function:
user_function = partial(user_function, stage="train")
loaders = self.trainer.train_dataloader.loaders
loaders = loaders if isinstance(loaders, DataLoader) else loaders.loaders
self.trainer.train_dataloader.loaders = self.trainer.apply_user_function(loaders, user_function)
return self.trainer.train_dataloader

def reload_val_dataloaders(self, user_function: Optional[Callable] = None) -> List[DataLoader]:
self.trainer.reset_val_dataloader(self.trainer.lightning_module)
if user_function:
user_function = partial(user_function, stage="val")
self.trainer.val_dataloaders = [
self.trainer.apply_user_function(dl, user_function) for dl in self.trainer.val_dataloaders
]
return self.trainer.val_dataloaders

@property
def lightning_module(self):
return self.trainer.lightning_module

@property
def train_dataloader(self) -> DataLoader:
return self.trainer.train_dataloader

@property
def val_dataloaders(self) -> List[DataLoader]:
return self.trainer.val_dataloaders
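
As a rough illustration of the `ExternalLoop` contract introduced above (a sketch against this PoC API, not an official example), a loop that simply re-runs `trainer.fit` a fixed number of times while growing the epoch budget could look like this:

from dataclasses import dataclass

from pytorch_lightning.loops.base import ExternalLoop


@dataclass
class RepeatFitLoop(ExternalLoop):
    """Hypothetical loop that calls `trainer.fit` several times in a row."""

    num_repeats: int
    num_epochs: int = 5
    restarting: bool = False

    @property
    def done(self) -> bool:
        return self.current_repeat >= self.num_repeats

    def reset(self) -> None:
        if not self.restarting:
            self.current_repeat = 0
        self.set_max_epochs(self.num_epochs)

    def advance(self) -> None:
        # Each `advance` runs one full fit on the current dataloaders.
        self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader)

    def on_advance_end(self) -> None:
        self.current_repeat += 1
        self.increment_max_epochs(self.num_epochs)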
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -237,6 +237,12 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def user_defined_hook(self, hook_name: str, *args, **kwargs):
"""Called when a user calls call_hook directly with its own hook name."""
for callback in self.callbacks:
if hasattr(callback, hook_name):
getattr(callback, hook_name)(self, self.lightning_module, *args, **kwargs)

@staticmethod
def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
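
`user_defined_hook` lets any callback react to hook names that are not part of the standard `Callback` interface, which is how the `KFoldLoop` example dispatches `trainer.call_hook("on_fold_start", ...)` (presumably `call_hook` falls back to this method elsewhere in the PR). A sketch of a callback consuming such a custom hook:

from pytorch_lightning.callbacks.base import Callback


class FoldLoggingCallback(Callback):
    """Hypothetical callback reacting to the custom `on_fold_start` hook."""

    def on_fold_start(self, trainer, pl_module, counter):
        # Receives `trainer` and `pl_module` first, then the extra args passed to `call_hook`.
        print(f"fold {counter} is starting")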
61 changes: 61 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -252,6 +252,67 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage]
dataloader = dl_cls(**dl_kwargs)
return dataloader

def apply_user_function(self, dataloader: DataLoader, user_function: Callable) -> DataLoader:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context

# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")

# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_kwargs.update(self._resolve_batch_sampler(dataloader, dataloader.sampler, mode=RunningStage.translate))

required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
and p.name not in dl_kwargs
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

dl_cls = type(dataloader)
dataloader = dl_cls(**user_function(dl_kwargs))
return dataloader

def _get_distributed_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DistributedSampler:
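
`apply_user_function` above hands the `user_function` the re-constructed `__init__` kwargs of the original dataloader and rebuilds the dataloader from whatever dictionary the function returns. A sketch of such a function (a hypothetical `halve_batch_size`, following the same `(dataloader_kwargs, stage)` contract as `generate_fold` in the example):

from typing import Any, Dict


def halve_batch_size(dataloader_kwargs: Dict[str, Any], stage: str) -> Dict[str, Any]:
    # `batch_size` is only present when it differs from the DataLoader default,
    # hence the fallback to 1.
    if stage == "train":
        dataloader_kwargs["batch_size"] = max(1, dataloader_kwargs.get("batch_size", 1) // 2)
    return dataloader_kwargs

It would then be passed through the loop helpers, e.g. `self.reload_train_dataloader(halve_batch_size)`.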