
logger called during on_train_start called multiple times during training #13554

Closed
paulhager opened this issue Jul 6, 2022 · 6 comments · Fixed by #14061


paulhager commented Jul 6, 2022

🐛 Bug

I'm using a ReduceLROnPlateau scheduler that monitors my validation F1. When I tell PyTorch Lightning to run validation only every 5 epochs (i.e. check_val_every_n_epoch > 1), nothing is logged from my validation_step in the intermediate epochs, so the val.f1 metric is never set. At the end of a training epoch without validation, the scheduler looks for val.f1 and throws an error that it can't find it (it only sees my train metrics). To work around this, I tried logging a value of 0 for val.f1 in on_train_start, which according to the documentation should be called only once, right before the epochs loop (see https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#hooks). During training, though, this hook appears to be called after every epoch. This is not the expected behaviour.

(Screenshot: example of val.f1 being logged multiple times during training.)

To Reproduce

Add the following on_train_start override to your LightningModule, use a ReduceLROnPlateau scheduler that monitors a validation metric, and set check_val_every_n_epoch > 1.

def on_train_start(self):
    self.log("val.f1", 0)
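
For completeness, this is roughly the scheduler/trainer setup I'm describing (the optimizer choice and exact values are just placeholders from my setup):

# inside my LightningModule
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # mode="max" since a higher F1 is better
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val.f1",  # only produced when validation actually runs
        },
    }

# validation (and therefore val.f1) only happens every 5th epoch
trainer = Trainer(check_val_every_n_epoch=5)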

Expected behavior

on_train_start is only called once.

(Also: an LR scheduler that can't find its metric should only emit a warning, not throw an error...)

Packages:
- numpy: 1.21.5
- pyTorch_debug: False
- pyTorch_version: 1.11.0
- pytorch-lightning: 1.6.0
- tqdm: 4.63.0

cc @carmocca @edward-io @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy

@paulhager paulhager added the needs triage Waiting to be triaged by maintainers label Jul 6, 2022

lodo1995 commented Jul 7, 2022

DISCLAIMER: I'm not a Lightning maintainer, don't trust me.


Regarding the error when the metric is not available: you can avoid it by passing 'strict': False in the lr_scheduler dictionary that you return from configure_optimizers (see my code below).

Regarding the issue of on_train_start being called multiple times: I could not reproduce it with the code below. Could you provide a minimal working example that shows it happening? You can trim your code down to the bare minimum needed to trigger the issue, or start from my code and add whatever triggers it.

import pytorch_lightning as pl
import torch


class LinearRegression(pl.LightningModule):
    def __init__(self, D):
        super().__init__()
        self.linear = torch.nn.Linear(D,1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr = 0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'strict': False # without this, we get an error
            }
        }

    def on_train_start(self):
        print('on_train_start')

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x).squeeze()
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x).squeeze()
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log('val_loss', loss, on_epoch = True)


if __name__ == '__main__':
    D = 5
    W = torch.rand(D)
    b = torch.rand(1)
    X = torch.rand(300, D)
    y = torch.matmul(X, W) + b

    dataset = torch.utils.data.TensorDataset(X, y)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [250, 50])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 25)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 25)

    model = LinearRegression(D)

    # run validation (and the scheduler's metric lookup) only every 2nd epoch
    trainer = pl.Trainer(logger = False, enable_checkpointing = False, max_epochs = 100, check_val_every_n_epoch = 2)
    trainer.fit(model, train_loader, val_loader)


paulhager commented Jul 7, 2022

Cool, thanks for the tip with strict. That fixed my problem. Here's a minimal code example for the bug, though:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

from pytorch_lightning.loggers import WandbLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_start(self):
        self.log("train_loss", 500)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    logger = WandbLogger(project='on_train_start_bug')

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=100,
        enable_model_summary=False,
        logger=logger,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

See the loss curve, which repeatedly spikes back to 500:
https://imgur.com/17dyWfZ.png


lodo1995 commented Jul 7, 2022

So, after slightly modifying your code, I came to the conclusion that this must be an issue with logging, not with the hook itself.

In on_train_start I added a print statement and incremented a counter. The print appeared only once, at the beginning of training, and in run, after both fit and test, printing the counter still yielded 1.

So on_train_start is indeed called only once. The problem is that the value logged there lingers around and keeps being re-emitted.
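
Roughly what I added to check this (the attribute name is just for illustration):

def on_train_start(self):
    # count how many times this hook actually fires
    self.on_train_start_calls = getattr(self, "on_train_start_calls", 0) + 1
    print("on_train_start call", self.on_train_start_calls)
    self.log("train_loss", 500)

# in run(), after trainer.fit(...) and trainer.test(...):
# print(model.on_train_start_calls)  # prints 1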


lodo1995 commented Jul 7, 2022

Update: I added a CSVLogger and verified that it also shows spikes to 500, so the problem is not in a specific logger but in the shared logging infrastructure.
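
For reference, the only change to the script above was swapping the logger (save_dir and name here are placeholders):

from pytorch_lightning.loggers import CSVLogger

# the spikes to 500 also show up in the resulting metrics.csv
logger = CSVLogger(save_dir="logs", name="on_train_start_bug")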

paulhager (Author) commented

Ok, cool. Thanks for looking into it. I guess I'll change the title to 'logger used in on_train_start logs multiple times during training'

@paulhager paulhager changed the title on_train_start called multiple times during training logger called during on_train_start called multiple times during training Jul 7, 2022
@carmocca carmocca added bug Something isn't working logging Related to the `LoggerConnector` and `log()` labels Aug 5, 2022
@carmocca carmocca self-assigned this Aug 5, 2022
@carmocca carmocca added this to the pl:1.7.x milestone Aug 5, 2022
@carmocca carmocca removed the needs triage Waiting to be triaged by maintainers label Aug 5, 2022

carmocca commented Aug 5, 2022

Thanks for the investigation! The repro is super helpful:

import os

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loggers.logger import Logger


class BugModel(BoringModel):
    def on_train_start(self):
        print("ON_TRAIN_START")
        self.log("train_loss", 500.0)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss


class DebugLogger(Logger):
    def log_metrics(self, metrics, step=None):
        metrics["global step"] = step
        print(metrics)

    def log_hyperparams(self, *args, **kwargs):
        pass

    @property
    def name(self):
        return ""

    @property
    def version(self):
        return ""


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BugModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=10,
        enable_model_summary=False,
        logger=DebugLogger(),
        enable_checkpointing=False,
        enable_progress_bar=False,
        log_every_n_steps=10,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
