Switching into training mode in training_step #20216

@heth27

Bug description

Training mode is not switched on automatically in training_step, so if the user accidentally puts the model into eval mode, it stays that way. This can be useful, for example when training with batchnorm layers kept in eval mode, but it is not clear from the "Lightning in 15 minutes" tutorial that this is intended (see the video screenshot below), if it really is the desired behavior.

So I am not sure whether this is a bug, or whether it would just be useful to add a note about this behavior to the training_step() documentation.

[Screenshot from the "Lightning in 15 minutes" video]
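
If this is intended, a possible workaround (a minimal sketch of my own, not an official recommendation) is to switch the submodule back into train mode in one of the training hooks, e.g. `on_train_epoch_start`. `TestLitModel` here is the class from the reproduction script below:

```python
# Workaround sketch (hypothetical): force train mode back on at the start of
# every training epoch, using the TestLitModel from the reproduction script.
class TestLitModelWithTrainMode(TestLitModel):
    def on_train_epoch_start(self) -> None:
        # Undo any .eval() call made before trainer.fit(); Lightning otherwise
        # preserves whatever mode the submodule was left in.
        self.test_module_obj.train()
```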

What version are you seeing the problem on?

v2.4

How to reproduce the bug

```python
import hashlib
import time
from typing import Any

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler


# Stand-in for the reporter's private helper
# (src.main.ml.data.data_augmentation.helpers.random_numbers.create_rng_from_string)
# so the script is self-contained: derives a deterministic NumPy RNG from a string.
def create_rng_from_string(seed_string: str) -> np.random.Generator:
    seed = int.from_bytes(hashlib.sha256(seed_string.encode()).digest()[:8], "little")
    return np.random.default_rng(seed)


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)


class TestBatchSampler(Sampler):
    def __init__(self, step=0):
        super().__init__()
        self.step = step

    # no __len__: the sampler is unbounded, so the dataloader length is unknown
    # (the original `return 1e100` was not a valid int length anyway)

    def __iter__(self):
        return self

    def __next__(self):
        # yields single-element batches with ever-increasing indices
        return_value = self.step
        self.step += 1
        return [return_value]


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # deterministic pseudo-random sample per index
        rng = create_rng_from_string(str(idx) + "_random_items")
        return torch.tensor(rng.random(self.in_dim), dtype=torch.float32)


class TestDataModule(pl.LightningDataModule):
    def __init__(self, start_step=0):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1
        self.start_step = start_step

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        # `shuffle` must not be passed together with a custom batch_sampler
        return DataLoader(train_ds, batch_sampler=TestBatchSampler(step=self.start_step), num_workers=4)

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        # self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.test_module_obj = model
        # self.test_module_obj.eval()
        self.automatic_optimization = True

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print(f"training mode in validation_step:{self.test_module_obj.training}")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"training mode in training_step:{self.test_module_obj.training}")

        time.sleep(0.5)

        output = self.test_module_obj(batch)

        loss = output.sum()

        # for checking manual mode
        # optimizer = self.optimizers()
        # self.manual_backward(loss)
        #
        # optimizer.step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer


if __name__ == '__main__':
    test_data_module = TestDataModule()
    module_ = TestModule(in_dim=512, out_dim=16)
    module_.eval()  # put the model into eval mode before fitting
    test_lit_model = TestLitModel(module_)

    trainer = pl.Trainer(
        log_every_n_steps=1,
        max_epochs=-1,
        max_steps=20,
        val_check_interval=5,
    )

    trainer.fit(test_lit_model, test_data_module)
```

Error messages and logs

```
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]training mode in validation_step:False
Epoch 0: |          | 0/? [00:00<?, ?it/s] training mode in training_step:False
Epoch 0: |          | 1/? [00:00<00:00,  1.03it/s, v_num=14]training mode in training_step:False
Epoch 0: |          | 2/? [00:01<00:00,  1.36it/s, v_num=14]training mode in training_step:False
Epoch 0: |          | 3/? [00:01<00:00,  1.52it/s, v_num=14]training mode in training_step:False
Epoch 0: |          | 4/? [00:02<00:00,  1.62it/s, v_num=14]training mode in training_step:False
Epoch 0: |          | 5/? [00:02<00:00,  1.68it/s, v_num=14]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/1 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]training mode in validation_step:False
```

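With such a hook in place (or with the module left in train mode), the same run would be expected to print `training mode in training_step:True` at every step.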

### Environment

<details>
  <summary>Current environment</summary>

#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(conda, pip, source):


</details>


### More info

_No response_

cc @lantiga @justusschock @borda
