K-fold cross validation example in Fabric #16902

Closed
awaelchli opened this issue Feb 28, 2023 · 4 comments · Fixed by #16909
Labels: example, fabric
Milestone: future
Comments

awaelchli commented Feb 28, 2023

Description & Motivation

Loop customization was removed in Trainer 2.0, and thus solutions like https://github.com/SkafteNicki/pl_cross no longer work. We want to provide examples and answers for users who need to perform k-fold cross-validation.

Pitch

Build an example showing how to do k-fold cross-validation with Fabric.

This could also be done for the Trainer, but that is more involved and probably requires changes to the Trainer itself. An example with Fabric would be a good starting point.

Additional context

#839
See also Slack and social media, where several users have asked for this (not specifically with Fabric, but in general).

cc @Borda @carmocca @justusschock @awaelchli

@awaelchli added the feature, needs triage, example, and fabric labels and removed the needs triage and feature labels on Feb 28, 2023
@awaelchli added this to the future milestone on Feb 28, 2023
@shenoynikhil (Contributor)

I'll take it up!

@peterchristofferholm

I was thinking of something along these lines:

from collections import deque, namedtuple

import torch
from torch.utils.data import random_split, ConcatDataset, DataLoader, Dataset
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset, BoringModel

K_FOLDS = 5
MAX_EPOCHS = 100


def split_dataset(data: Dataset, k: int = 5):
    splits = deque(random_split(data, lengths=[1 / k] * k))
    for _ in range(k):
        validation = splits.popleft()
        training = ConcatDataset(splits)
        yield (training, validation)
        splits.append(validation)


def create_splits(fabric, k):
    Split = namedtuple("Split", ["model", "optim", "train_dl", "val_dl"])
    data = RandomDataset(32, 500)
    for train, val in split_dataset(data, k):
        model = BoringModel()
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)
        model, optim = fabric.setup(model, optim)
        train_dl = fabric.setup_dataloaders(DataLoader(train))
        val_dl = fabric.setup_dataloaders(DataLoader(val))
        yield Split(model, optim, train_dl, val_dl)


fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

splits = list(create_splits(fabric, K_FOLDS))

for epoch in range(MAX_EPOCHS):
    epoch_loss = 0
    for k in range(K_FOLDS):

        # select the k'th version. ########################################
        train_dataloader = splits[k].train_dl
        validation_dataloader = splits[k].val_dl
        model = splits[k].model
        optimizer = splits[k].optim
        ###################################################################

        #### training start ###############################################
        _ = model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            loss = model.training_step(batch, batch_idx)["loss"]
            optimizer.zero_grad()
            fabric.backward(loss)
            optimizer.step()
        ###################################################################

        #### validation start #############################################
        _ = torch.set_grad_enabled(False)
        _ = model.eval()
        epoch_split_loss = 0

        for batch_idx, batch in enumerate(validation_dataloader):
            loss = model.validation_step(batch, batch_idx)["x"]
            epoch_split_loss += loss
        epoch_loss += epoch_split_loss

        _ = torch.set_grad_enabled(True)
        ###################################################################

    print(f"{epoch=} loss: {epoch_loss / K_FOLDS}")

But I'm not terribly fond of having to keep K different models around in memory. Maybe it would make sense to save and reload the state_dict instead?
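
For reference, a rough sketch of that state_dict variant (reusing split_dataset, K_FOLDS, MAX_EPOCHS, and the fabric object from the snippet above; the fold_states list of in-memory checkpoints is purely illustrative, and the validation part is omitted for brevity):

import copy

data = RandomDataset(32, 500)
fold_loaders = [
    (fabric.setup_dataloaders(DataLoader(train)), fabric.setup_dataloaders(DataLoader(val)))
    for train, val in split_dataset(data, K_FOLDS)
]

# a single live model/optimizer pair instead of K of them
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

# one weight checkpoint per fold, all starting from the same initialization
fold_states = [copy.deepcopy(model.state_dict()) for _ in range(K_FOLDS)]

for epoch in range(MAX_EPOCHS):
    for k, (train_dl, val_dl) in enumerate(fold_loaders):
        # restore this fold's weights, train for one epoch, then persist them again
        # (for full correctness the optimizer state would need the same treatment)
        model.load_state_dict(fold_states[k])
        _ = model.train()
        for batch_idx, batch in enumerate(train_dl):
            loss = model.training_step(batch, batch_idx)["loss"]
            optimizer.zero_grad()
            fabric.backward(loss)
            optimizer.step()
        fold_states[k] = copy.deepcopy(model.state_dict())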

@shenoynikhil (Contributor)

@peterchristofferholm Can you have a look at the updated code in PR #16909?

@peterchristofferholm

@shenoynikhil looks reasonable, but you could consider packaging the training loop into some sort of basic Trainer-like class, for example:

import torch
from collections import deque
from torch.utils.data import random_split, ConcatDataset, DataLoader


class BasicTrainer:
    def __init__(self, fabric, LitModel, train_dl, valid_dl):
        self.fabric = fabric
        model = LitModel()  # instantiate
        optim = model.configure_optimizers()
        self.model, self.optim = fabric.setup(model, optim)
        self.train_dl = fabric.setup_dataloaders(train_dl)
        self.valid_dl = fabric.setup_dataloaders(valid_dl)

    def __iter__(self):
        while True:
            # one training epoch
            _ = self.model.train()
            for batch_idx, batch in enumerate(self.train_dl):
                loss = self.model.training_step(batch, batch_idx)
                self.optim.zero_grad()
                self.fabric.backward(loss)
                self.optim.step()

            # one validation epoch; yield the accumulated validation loss
            with torch.set_grad_enabled(False):
                _ = self.model.eval()
                loss = 0
                for batch_idx, batch in enumerate(self.valid_dl):
                    loss += self.model.validation_step(batch, batch_idx)
            yield loss


def _split_dataset(dataset, k):
    splits = deque(random_split(dataset, lengths=[1 / k] * k))
    for _ in range(k):
        valid = splits.popleft()
        train = ConcatDataset(splits)
        yield (DataLoader(train), DataLoader(valid))
        splits.append(valid)


def cross_validation(fabric, dataset, model, n_splits):
    ensemble = (
        BasicTrainer(fabric, model, train_dl, valid_dl)
        for (train_dl, valid_dl) in _split_dataset(dataset, n_splits)
    )
    yield from zip(*ensemble)
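
Purely as an illustration (not part of the PR), one way to drive that generator end to end, using a thin BoringModel wrapper so the step methods return plain loss tensors and a single optimizer, as the BasicTrainer sketch expects; max_epochs is an arbitrary cutoff:

from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class PlainBoringModel(BoringModel):
    # BoringModel's step methods return dicts, so unwrap them here
    def training_step(self, batch, batch_idx):
        return super().training_step(batch, batch_idx)["loss"]

    def validation_step(self, batch, batch_idx):
        return super().validation_step(batch, batch_idx)["x"]

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)


fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

dataset = RandomDataset(32, 500)
max_epochs = 10

# cross_validation yields, once per epoch, a tuple holding one validation loss per fold
for epoch, fold_losses in enumerate(cross_validation(fabric, dataset, PlainBoringModel, n_splits=5)):
    print(f"{epoch=} mean validation loss: {sum(fold_losses) / len(fold_losses)}")
    if epoch + 1 == max_epochs:
        break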
