### Bug description

Training mode is not switched on automatically before `training_step`, so if the user accidentally puts the model into eval mode, it stays that way. This can be useful, for example when training with batchnorm kept in eval mode, but the "Lightning in 15 minutes" guide does not make clear that this is the intended behavior (see video screenshot), if it really is intended.
So I am not sure whether this is a bug, or whether it would simply be useful to add a note to the `training_step()` documentation.
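If the behavior is intended, one possible workaround (just a sketch of my assumption, not something from the docs; the `ForceTrainMode` module below is purely illustrative) is to re-enable train mode in a standard hook such as `on_train_epoch_start`:

```python
import lightning.pytorch as pl
import torch
from torch import nn


class ForceTrainMode(pl.LightningModule):
    """Illustration only: re-enable train mode at the start of each training
    epoch, so an earlier (accidental) .eval() call does not leak into training."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def on_train_epoch_start(self) -> None:
        # standard LightningModule hook; .train() is inherited from nn.Module
        # and flips this module and all its submodules back to training mode
        self.train()

    def training_step(self, batch, batch_idx):
        # without the hook above, self.layer.training would stay False if
        # .eval() had been called on the model before trainer.fit()
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```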
### What version are you seeing the problem on?
v2.4
### How to reproduce the bug

```python
import time
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler

# Original import was project-specific:
# from src.main.ml.data.data_augmentation.helpers.random_numbers import create_rng_from_string
# Minimal stand-in so the example runs outside that codebase (assumption, not the original helper):
import zlib
import numpy as np


def create_rng_from_string(seed_string: str) -> np.random.Generator:
    return np.random.default_rng(zlib.crc32(seed_string.encode()))


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)


class TestBatchSampler(Sampler):
    def __init__(self, step=0):
        super().__init__()
        self.step = step

    def __len__(self) -> int:
        # effectively "infinite"; the trainer treats the dataloader as unsized
        return 1e100

    def __iter__(self):  # -> Iterator[int]:
        return self

    def __next__(self):  # -> Iterator[int]:
        return_value = self.step
        self.step += 1
        return [return_value]


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        rng = create_rng_from_string(str(idx) + "_" + "random_items")
        return torch.tensor(rng.random(self.in_dim), dtype=torch.float32)


class TestDataModule(pl.LightningDataModule):
    def __init__(self, start_step=0):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1
        self.start_step = start_step

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(
            train_ds,
            batch_sampler=TestBatchSampler(step=self.start_step),
            num_workers=4,
            shuffle=False,
        )
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        # self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.test_module_obj = model
        # self.test_module_obj.eval()
        self.automatic_optimization = True

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print(f"training mode in validation_step:{self.test_module_obj.training}")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"training mode in training_step:{self.test_module_obj.training}")
        time.sleep(0.5)
        output = self.test_module_obj(batch)
        loss = output.sum()
        # for checking manual mode
        # optimizer = self.optimizers()
        # self.manual_backward(loss)
        #
        # optimizer.step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.test_module_obj.parameters())
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    module_ = TestModule(in_dim=512, out_dim=16)
    # the model is put into eval mode *before* fit(); per this report,
    # Lightning never switches it back to train mode
    module_.eval()
    test_lit_model = TestLitModel(module_)
    trainer = pl.Trainer(
        log_every_n_steps=1,
        max_epochs=-1,
        max_steps=20,
        val_check_interval=5,
    )
    trainer.fit(test_lit_model, test_data_loader)
```

### Error messages and logs
```
Sanity Checking DataLoader 0: 0%| | 0/1 [00:00<?, ?it/s]training mode in validation_step:False
Epoch 0: | | 0/? [00:00<?, ?it/s] training mode in training_step:False
Epoch 0: | | 1/? [00:00<00:00, 1.03it/s, v_num=14]training mode in training_step:False
Epoch 0: | | 2/? [00:01<00:00, 1.36it/s, v_num=14]training mode in training_step:False
Epoch 0: | | 3/? [00:01<00:00, 1.52it/s, v_num=14]training mode in training_step:False
Epoch 0: | | 4/? [00:02<00:00, 1.62it/s, v_num=14]training mode in training_step:False
Epoch 0: | | 5/? [00:02<00:00, 1.68it/s, v_num=14]
Validation: | | 0/? [00:00<?, ?it/s]
Validation: 0%| | 0/1 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/1 [00:00<?, ?it/s]training mode in validation_step:False
```
### Environment
<details>
<summary>Current environment</summary>
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (conda, pip, source):
</details>
### More info
_No response_
cc @lantiga @justusschock @borda
