resume_from_checkpoint broken when fault-tolerant feature enabled #8835

Closed
awaelchli opened this issue Aug 10, 2021 · 1 comment · Fixed by #9371
Assignees: awaelchli
Labels: bug (Something isn't working), priority: 1 (Medium priority task)
Milestone: v1.5

awaelchli (Contributor) commented Aug 10, 2021

🐛 Bug

A Trainer configured with the resume_from_checkpoint option does not continue training on the second fit call: it stops immediately, even though max_epochs was increased.

To Reproduce

import os
import unittest.mock

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        print(batch.sum())
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def on_train_epoch_end(self):
        print("epoch ended")

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_load_checkpoint(self, checkpoint):
        pass


@unittest.mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

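    # First fit: train for a single epoch, then save a checkpoint manually.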
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=3,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloader=train_data)

    trainer.save_checkpoint("lightning_logs/auto.pt")

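    # Second fit: resume from the saved checkpoint with max_epochs raised to 3.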
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=3,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        max_epochs=3,
        weights_summary=None,
        resume_from_checkpoint="lightning_logs/auto.pt",
    )
    trainer.fit(model, train_dataloader=train_data)


if __name__ == "__main__":
    run()

Output:

Epoch 0:   0%|          | 0/3 [00:00<?, ?it/s] tensor(6.2218)
Epoch 0:  33%|███▎      | 1/3 [00:00<00:00, 83.40it/s, loss=-1.23, v_num=82]tensor(-9.5560)
Epoch 0:  67%|██████▋   | 2/3 [00:00<00:00, 120.73it/s, loss=-1.71, v_num=82]tensor(14.5846)
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 140.56it/s, loss=-0.136, v_num=82]epoch ended
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 98.17it/s, loss=-0.136, v_num=82] 

Restoring states from the checkpoint path at lightning_logs/auto.pt
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 22712.84it/s]epoch ended
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 20004.63it/s]epoch ended
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 3038.62it/s] 

Here, training_step is not invoked at all during the second fit call; instead, each epoch ends immediately.

This ONLY happens when PL_FAULT_TOLERANT_TRAINING=1 is set; fault-tolerant training is an experimental feature that is off by default.
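Since the regression is gated on that flag, a temporary workaround (my assumption, not something verified in this report) is to leave the flag unset, or set it to "0", for the resumed run:

import os

# Hedged workaround sketch: the failure is only reported with the fault-tolerant
# flag enabled, so clearing it (or setting it to "0") before calling trainer.fit
# on the resumed run should restore the normal resume behaviour.
os.environ.pop("PL_FAULT_TOLERANT_TRAINING", None)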

Expected behavior

Training continues for two more epochs, because max_epochs was increased from 1 to 3 on the resumed run.
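A minimal sketch of how this could be asserted (CountingModel is a hypothetical helper built on the BoringModel from the reproduction above, not part of the original report):

class CountingModel(BoringModel):
    # Hypothetical helper: counts how many times training_step runs.
    def __init__(self):
        super().__init__()
        self.train_step_calls = 0

    def training_step(self, batch, batch_idx):
        self.train_step_calls += 1
        return super().training_step(batch, batch_idx)


# With limit_train_batches=3 and the resumed fit covering epochs 1 and 2,
# the second trainer.fit call should perform 2 * 3 = 6 additional training
# steps, so a counter reset to zero right before resuming should report 6.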

@awaelchli awaelchli added the bug (Something isn't working) and help wanted (Open to be worked on) labels Aug 10, 2021
@awaelchli awaelchli self-assigned this Aug 10, 2021
@awaelchli awaelchli added the priority: 0 (High priority task) label Aug 10, 2021
@awaelchli awaelchli removed the priority: 0 (High priority task), bug (Something isn't working), and help wanted (Open to be worked on) labels Aug 10, 2021
@awaelchli awaelchli reopened this Aug 10, 2021
@awaelchli awaelchli added the bug (Something isn't working) label Aug 10, 2021
@awaelchli awaelchli changed the title from "resume_from_checkpoint broken on in 1.4 - does not advance loop" to "resume_from_checkpoint broken when fault-tolerant feature enabled" Aug 10, 2021
@awaelchli awaelchli added the priority: 1 (Medium priority task) label Aug 10, 2021
@awaelchli awaelchli added this to the v1.5 milestone Aug 10, 2021