logger called during on_train_start called multiple times during training #13554
Comments
DISCLAIMER: I'm not a Lightning maintainer, don't trust me. Regarding the error when the metric is not available, you can avoid it by passing the option `'strict': False` in the `lr_scheduler` configuration. Regarding the issue of `on_train_start` being called multiple times, here is the example I used to look into it:
```python
import pytorch_lightning as pl
import torch


class LinearRegression(pl.LightningModule):
    def __init__(self, D):
        super().__init__()
        self.linear = torch.nn.Linear(D, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'strict': False  # without this, we get an error
            }
        }

    def on_train_start(self):
        print('on_train_start')

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x).squeeze()
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x).squeeze()
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log('val_loss', loss, on_epoch=True)


if __name__ == '__main__':
    D = 5
    W = torch.rand(D)
    b = torch.rand(1)
    X = torch.rand(300, D)
    y = torch.matmul(X, W) + b
    dataset = torch.utils.data.TensorDataset(X, y)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [250, 50])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=25)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=25)
    model = LinearRegression(D)
    trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=100, check_val_every_n_epoch=2)
    trainer.fit(model, train_loader, val_loader)
```
Cool, thanks for the tip with strict. That fixed my problem. Here's a minimal code example for the bug, though:
See the loss curve that constantly spikes to 500.
So, after slightly modifying your code, I came to the conclusion that this must be an issue with logging, not with the callback.
Update: I added a
Ok, cool. Thanks for looking into it. I guess I'll change the title to 'logger used in on_train_start logs multiple times during training'
Thanks for the investigation! The repro is super helpful:

```python
import os

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loggers.logger import Logger


class BugModel(BoringModel):
    def on_train_start(self):
        print("ON_TRAIN_START")
        self.log("train_loss", 500.0)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss


class DebugLogger(Logger):
    def log_metrics(self, metrics, step=None):
        metrics["global step"] = step
        print(metrics)

    def log_hyperparams(self, *args, **kwargs):
        pass

    @property
    def name(self):
        return ""

    @property
    def version(self):
        return ""


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BugModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=10,
        enable_model_summary=False,
        logger=DebugLogger(),
        enable_checkpointing=False,
        enable_progress_bar=False,
        log_every_n_steps=10,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```
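If the only reason for logging in on_train_start is to seed a metric that a callback or scheduler needs to find, one possible way to sidestep the behaviour in the repro above is to pass logger=False to self.log, which keeps the value in trainer.callback_metrics without handing it to the logger. This is only a sketch of a workaround, not something verified in this thread; it reuses BugModel from the repro above and is meant to be dropped into the same script:

```python
# Sketch of a possible workaround (not verified in this thread).
class WorkaroundModel(BugModel):
    def on_train_start(self):
        print("ON_TRAIN_START")
        # logger=False keeps the seeded value out of Logger.log_metrics while
        # it should still be visible in trainer.callback_metrics (e.g. for a
        # ReduceLROnPlateau monitor), so it should not reappear in the logs.
        self.log("train_loss", 500.0, logger=False)
```

With this, DebugLogger should only ever see the losses coming from training_step; whether that is acceptable depends on whether the seeded value actually needs to reach the logger at all.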
🐛 Bug
I'm using a ReduceLROnPlateau scheduler and having it monitor my validation F1. When telling PyTorch Lightning to only run validation every 5 epochs (so check_val_every_n_epoch > 1), nothing is logged from my validation_step method in the intermediate epochs, so the val.f1 metric is never set. At the end of a training epoch, the scheduler tries to look up val.f1 and throws an error because it can't find it (it only sees my train metrics). To get around this, I tried logging a value of 0 for val.f1 in on_train_start, which according to the documentation should be called only once, right before the epochs loop (see https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#hooks). During training, though, this method is called after every epoch. This is not the expected behaviour.
Example of multiple logging of val.f1
To Reproduce
Include the on_train_start snippet below in your LightningModule, use a ReduceLROnPlateau scheduler that monitors a validation metric, and set check_val_every_n_epoch > 1.
```python
def on_train_start(self):
    self.log("val.f1", 0)
```
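Putting those pieces together, here is a minimal sketch of the setup described above. It leans on the BoringModel helpers already used in this thread; the class name and the constant 0.5 logged for val.f1 are illustrative, not taken from the reporter's code:

```python
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset


class ReproModel(BoringModel):
    def on_train_start(self):
        # Workaround attempted in the report: seed the monitored metric once
        # so the scheduler can find it before the first validation run.
        self.log("val.f1", 0.0)

    def validation_step(self, batch, batch_idx):
        # The real val.f1 is only logged when validation actually runs,
        # i.e. every check_val_every_n_epoch epochs.
        self.log("val.f1", 0.5)
        return super().validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val.f1"},
        }


if __name__ == "__main__":
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    trainer = Trainer(
        max_epochs=10,
        check_val_every_n_epoch=5,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(ReproModel(), train_data, val_data)
```

Running this with a logger attached should show the val.f1 value seeded in on_train_start being emitted repeatedly during training, which is the behaviour described above.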
Expected behavior
on_train_start is only called once.
(An LR scheduler that can't find its metric doesn't throw an error, only a warning...)
Packages:
- numpy: 1.21.5
- pyTorch_debug: False
- pyTorch_version: 1.11.0
- pytorch-lightning: 1.6.0
- tqdm: 4.63.0
cc @carmocca @edward-io @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy