self.log() does not work with a model compiled by PyTorch 2.0's torch.compile() #16822

@ShayanPersonal

Bug description

Logging from a model compiled with `torch.compile()` appears to be broken.

How to reproduce the bug

All I've done is modify the example from the website to run on 1 GPU and compile the model with `torch.compile`. I've tried both the release version and the Lightning master-branch version of the library, and both fail.

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
	def __init__(self):
		super().__init__()
		self.encoder = nn.Sequential(
			nn.Linear(28 * 28, 64),
			nn.ReLU(),
			nn.Linear(64, 3))
		self.decoder = nn.Sequential(
			nn.Linear(3, 64),
			nn.ReLU(),
			nn.Linear(64, 28 * 28))

	def forward(self, x):
		embedding = self.encoder(x)
		return embedding

	def configure_optimizers(self):
		optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
		return optimizer

	def training_step(self, train_batch, batch_idx):
		x, y = train_batch
		x = x.view(x.size(0), -1)
		z = self.encoder(x)    
		x_hat = self.decoder(z)
		loss = F.mse_loss(x_hat, x)
		self.log('train_loss', loss)
		return loss

	def validation_step(self, val_batch, batch_idx):
		x, y = val_batch
		x = x.view(x.size(0), -1)
		z = self.encoder(x)
		x_hat = self.decoder(z)
		loss = F.mse_loss(x_hat, x)
		self.log('val_loss', loss)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()
model = torch.compile(model)

# training
trainer = pl.Trainer(accelerator="gpu", devices=1, num_nodes=1, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)
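
For context on the `model = torch.compile(model)` line above: `torch.compile()` does not modify the module in place; it returns a `torch._dynamo.OptimizedModule` that wraps the original module. The short snippet below is illustrative only (it is not part of the original report, and `_orig_mod` is a private attribute of the wrapper):

import torch

model = LitAutoEncoder()            # the LightningModule defined above
compiled = torch.compile(model)

print(type(compiled))               # <class 'torch._dynamo.eval_frame.OptimizedModule'>
print(compiled._orig_mod is model)  # True: the original module is kept inside the wrapper

If the trainer attaches its result collection to one of these two objects while the hooks that call `self.log()` run through the other, an error like the one below would follow; that is an interpretation of the traceback, not something verified here.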
    

Error messages and logs

lightning.fabric.utilities.exceptions.MisconfigurationException: You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging
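
If the goal is just to get the `torch.compile()` speedup while keeping logging functional, one possible workaround is to compile only the inner submodules, so the object handed to `Trainer.fit()` remains an ordinary, uncompiled LightningModule. This is a sketch built on the interpretation above, not something confirmed in this issue, and the class name is illustrative:

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

class LitAutoEncoderCompiledInside(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Compile only the plain nn.Modules; the LightningModule itself is never wrapped.
        self.encoder = torch.compile(
            nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3)))
        self.decoder = torch.compile(
            nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28)))

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)  # logged from an uncompiled LightningModule
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = LitAutoEncoderCompiledInside()  # note: no torch.compile() around the whole module
trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)

Whether compiling submodules separately gives the same speedup as compiling the whole model is a separate question; this only sidesteps the wrapping of the LightningModule.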

Environment

Current environment
Running under WSL (Windows Subsystem for Linux).

* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 2080 Ti
        - available:         True
        - version:           11.8
* Lightning:
        - lightning:         2.0.0.dev0
        - lightning-cloud:   0.5.27
        - lightning-utilities: 0.6.0.post0
        - torch:             2.0.0.dev20230219
        - torchaudio:        2.0.0.dev20230219
        - torchdata:         0.7.0.dev20230219
        - torchmetrics:      0.11.1
        - torchtext:         0.15.0.dev20230219
        - torchvision:       0.15.0.dev20230219
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.10.9

More info

No response

cc @tchaton @carmocca @Blaizzy

Labels

bug (Something isn't working), logging (Related to the `LoggerConnector` and `log()`), priority: 0 (High priority task), torch.compile
