K-fold cross validation example in Fabric #16902
Comments
I'll take it up!
I was thinking something along the lines of this:

```python
from collections import deque, namedtuple

import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split

from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

K_FOLDS = 5
MAX_EPOCHS = 100


def split_dataset(data: Dataset, k: int = 5):
    # rotate a deque of k random splits: each iteration pops one split off
    # as the validation fold and concatenates the remaining k-1 for training
    splits = deque(random_split(data, lengths=[1 / k] * k))
    for _ in range(k):
        validation = splits.popleft()
        training = ConcatDataset(splits)
        yield (training, validation)
        splits.append(validation)


def create_splits(fabric, k):
    Split = namedtuple("Split", ["model", "optim", "train_dl", "val_dl"])
    data = RandomDataset(32, 500)
    for train, val in split_dataset(data, k):
        model = BoringModel()
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)
        model, optim = fabric.setup(model, optim)
        train_dl = fabric.setup_dataloaders(DataLoader(train))
        val_dl = fabric.setup_dataloaders(DataLoader(val))
        yield Split(model, optim, train_dl, val_dl)


fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()
splits = list(create_splits(fabric, K_FOLDS))

for epoch in range(MAX_EPOCHS):
    epoch_loss = 0
    for k in range(K_FOLDS):
        # select the k'th fold
        train_dataloader = splits[k].train_dl
        validation_dataloader = splits[k].val_dl
        model = splits[k].model
        optimizer = splits[k].optim

        # training
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            loss = model.training_step(batch, batch_idx)["loss"]
            optimizer.zero_grad()
            fabric.backward(loss)  # with Fabric, use fabric.backward instead of loss.backward
            optimizer.step()

        # validation
        model.eval()
        with torch.no_grad():
            epoch_split_loss = 0
            for batch_idx, batch in enumerate(validation_dataloader):
                loss = model.validation_step(batch, batch_idx)["x"]
                epoch_split_loss += loss
            epoch_loss += epoch_split_loss

    print(f"{epoch=} loss: {epoch_loss / K_FOLDS}")
```

But I'm not terribly fond of having to keep …
@peterchristofferholm Can you have a look at the updated code in PR #16909?
@shenoynikhil looks reasonable, but you could consider packaging the training loop in some sort of basic trainer class:

```python
# (reuses the imports from the previous example)


class BasicTrainer:
    def __init__(self, fabric, LitModel, train_dl, valid_dl):
        model = LitModel()  # instantiate
        optim = model.configure_optimizers()
        self.fabric = fabric  # keep a reference instead of relying on a global
        self.model, self.optim = fabric.setup(model, optim)
        self.train_dl = fabric.setup_dataloaders(train_dl)
        self.valid_dl = fabric.setup_dataloaders(valid_dl)

    def __iter__(self):
        # infinite generator: each iteration runs one training epoch and
        # yields the summed validation loss
        while True:
            self.model.train()
            for batch_idx, batch in enumerate(self.train_dl):
                loss = self.model.training_step(batch, batch_idx)
                self.optim.zero_grad()
                self.fabric.backward(loss)
                self.optim.step()
            with torch.set_grad_enabled(False):
                self.model.eval()
                loss = 0
                for batch_idx, batch in enumerate(self.valid_dl):
                    loss += self.model.validation_step(batch, batch_idx)
            yield loss


def _split_dataset(dataset, k):
    splits = deque(random_split(dataset, lengths=[1 / k] * k))
    for _ in range(k):
        valid = splits.popleft()
        train = ConcatDataset(splits)
        yield (DataLoader(train), DataLoader(valid))
        splits.append(valid)


def cross_validation(fabric, dataset, model, n_splits):
    # one trainer per fold; zip advances them in lockstep, so each yielded
    # tuple holds one validation loss per fold for that epoch
    ensemble = (
        BasicTrainer(fabric, model, train_dl, valid_dl)
        for (train_dl, valid_dl) in _split_dataset(dataset, n_splits)
    )
    yield from zip(*ensemble)
```
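For illustration, a minimal sketch of how this `cross_validation` generator could be driven. `MyLitModel` and `my_dataset` are hypothetical placeholders, and the sketch assumes a LightningModule whose `training_step`/`validation_step` return plain loss tensors (unlike `BoringModel`, which returns dicts):

```python
fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# each item from cross_validation() is a tuple with one validation loss per fold
folds = cross_validation(fabric, my_dataset, MyLitModel, n_splits=5)
for epoch, fold_losses in zip(range(MAX_EPOCHS), folds):
    print(f"{epoch=} mean fold loss: {sum(fold_losses) / len(fold_losses)}")
```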
Description & Motivation
Loop customization was removed from the Trainer in 2.0, and thus solutions like https://github.com/SkafteNicki/pl_cross no longer work. We want to provide examples and answers for users who need to perform k-fold cross validation.
Pitch
Build an example of how to do k-fold cross validation with Fabric.
This could also be done for the Trainer, but that is more involved and probably requires changes to the Trainer. An example with Fabric would be a good starting point.
Additional context
#839
See also Slack and social media, where several users have asked for this (not specifically with Fabric, but in general).
cc @Borda @carmocca @justusschock @awaelchli