Incorrect result when a metric object is logged with PyTorch Lightning logging system #2231
Comments
Hi! Thanks for your contribution, great first issue!
This is another issue, but I also noticed that not providing
Hi @laclouis5, thanks for reporting this issue. See `torchmetrics/tests/integrations/test_lightning.py`, lines 496 to 537 at commit `d88a2cc`: the last assert does not fail, meaning that logging the object and logging the value are equal for this basic example.
Here is a code snippet that seems to reproduce the issue on my machine. Note that I'm using a TensorBoard logger, so the issue can be seen in the TensorBoard time series. However, I think I found the problem. It looks like adding a `self.metric_2.reset()` call after the manual `compute()` (the commented-out line below) solves it:

```python
from pathlib import Path
from typing import Any

import lightning
import torch
import torchmetrics
from lightning.pytorch.loggers import TensorBoardLogger
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset


class CustomNetwork(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.layer = torch.nn.Linear(10, 5)

    def forward(self, input):
        return self.layer(input)


class CustomDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, index):
        return torch.randn(size=(10,)), 3

    def __len__(self):
        return 100


class CustomDataModule(lightning.LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        self.train_dataset = CustomDataset()
        self.valid_dataset = CustomDataset()

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )


class CustomModel(lightning.LightningModule):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.model = CustomNetwork()
        self.metric_1 = torchmetrics.Accuracy(
            task="multiclass", average="macro", num_classes=5
        )
        self.metric_2 = torchmetrics.Accuracy(
            task="multiclass", average="macro", num_classes=5
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, input):
        return self.model(input)

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=0.001)

    def training_step(self, batch, index):
        input, gt = batch
        gt = gt.to(dtype=torch.long)
        pred = self(input)
        return self.loss_fn(pred, gt)

    def validation_step(self, batch, index):
        input, gt = batch
        gt = gt.to(dtype=torch.long)
        pred = self(input)
        self.metric_1.update(pred, gt)
        self.metric_2.update(pred, gt)

    def on_validation_epoch_end(self):
        self.log("M1", self.metric_1)
        m2 = self.metric_2.compute()
        self.log("M2", m2)
        # This solves the issue
        # self.metric_2.reset()


def main():
    torch.use_deterministic_algorithms(True, warn_only=True)
    lightning.seed_everything(31415, workers=True)
    torch.set_float32_matmul_precision("medium")

    Path("test_runs/").mkdir(exist_ok=True)
    logger = TensorBoardLogger(save_dir=Path.cwd(), name="test_runs/")

    trainer = lightning.Trainer(
        precision="16-mixed",
        logger=logger,
        max_epochs=100,
        deterministic="warn",
        log_every_n_steps=1,
    )
    model = CustomModel()
    datamodule = CustomDataModule()
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    main()
```
I'm closing the issue since I think that the error is on my side here.
🐛 Bug
I have a Lightning module very similar to the one presented in the TorchMetrics PyTorch Lightning tutorial. Basically, I log the metric object directly with `self.log`, but this gives an incorrect result. If I manually compute the result with `.compute()`, then the calculation is correct. I read the Common Pitfalls section, but none of the cases apply to my setting, I think.
My workflow is a little more complex; I'll try to sum it up in the following section.
To Reproduce
I'm addressing a simple multi-class classification problem on images. For that, I'm using PyTorch Lightning, and I keep some standard classification metrics in a `MetricCollection` object. In the following, `B` is the batch size, `C` is the number of classes, and (`H`, `W`) is the image size.
Expected behavior
As explained above, this yields an incorrect result. If I change to this, then the result is correct:
As advertised in the tutorial, logging the metric object directly should yield the exact same result.
Environment