
Extend val_check_interval support #8135

Closed
amoshyc opened this issue Jun 25, 2021 · 22 comments · Fixed by #11993

Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), priority: 1 (Medium priority task)

amoshyc commented Jun 25, 2021

🐛 Bug

When I train the model by specifying a number of training steps instead of epochs, val_check_interval behaves strangely. Please see the following Colab:

https://colab.research.google.com/drive/1I0ySRH03T9LdXHoGwCp3Q242dHEHWP_0?usp=sharing

In the code, I log the global_step on each validation.

class Model(pl.LightningModule):
    ...

    def on_validation_start(self):
        print('global_step:', self.global_step)

I set the Trainer's max_steps to 100 and val_check_interval to 10.
But when I run cells In[4] and In[5], the outputs are different.
The only difference between In[4] and In[5] is the number of samples in the dataset, which should not matter.

In[4]:

train_set = RandomDataset(1, 40)
valid_set = RandomDataset(1, 40)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=2)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=2)

trainer = pl.Trainer(
    gpus=1,
    max_steps=100,
    val_check_interval=10,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    progress_bar_refresh_rate=0,
)
model = Model()
trainer.fit(model, train_loader, valid_loader)

Out[4]:

global_step: 9
global_step: 19
global_step: 29
global_step: 39
global_step: 49
global_step: 59
global_step: 69
global_step: 79
global_step: 89
global_step: 99

In[5]:

train_set = RandomDataset(1, 32)
valid_set = RandomDataset(1, 32)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=2)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=2)

trainer = pl.Trainer(
    gpus=1,
    max_steps=100,
    val_check_interval=10,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    progress_bar_refresh_rate=0,
)
model = Model()
trainer.fit(model, train_loader, valid_loader)

Out[5]:

global_step: 9
global_step: 25
global_step: 41
global_step: 57
global_step: 73
global_step: 89

Expected behavior

Since I specify max_steps and set val_check_interval to an integer, I expect the result to be the same as Out[4] no matter how many samples are in the dataset. The doc says that val_check_interval specifies the number of training steps between validations, so Out[5] should be the same as Out[4].

I also expect the number of times validation is performed to be the same. The x-axis in TensorBoard is also wrong; you can see that in Out[9].

Environment

* CUDA:
	- GPU:
		- Tesla T4
	- available:         True
	- version:           10.2
* Packages:
	- numpy:             1.19.5
	- pyTorch_debug:     False
	- pyTorch_version:   1.9.0+cu102
	- pytorch-lightning: 1.3.7post0
	- tqdm:              4.41.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- 
	- processor:         x86_64
	- python:            3.7.10
	- version:           #1 SMP Sat Jun 5 09:50:34 PDT 2021

cc @Borda @tchaton

@amoshyc added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Jun 25, 2021
@amoshyc changed the title from "global_step in valiation is wrong when setting val_check_interval to a integer." to "global_step in validation is wrong when setting val_check_interval to a integer." on Jun 25, 2021
@tchaton added the priority: 1 (Medium priority task) and good first issue (Good for newcomers) labels on Jun 28, 2021
tchaton commented Jun 28, 2021

Dear @amoshyc,

Thanks for reporting this issue. Would you be interested in creating a fix PR?
I believe this shouldn't be too hard, and it's a great way to learn more about Lightning.

Best.
T.C

MirMustafaAli commented Jul 2, 2021

Hi @tchaton,
Can I work on this issue, since it is labelled as a good first issue?

amoshyc commented Jul 2, 2021

Sorry for the late reply. I've been busy with my research and haven't had time to fix it.
The last time I checked the code, it seemed that the bug is here and someone had already noticed it.
But I'm not quite sure, and I'm not working on it.

tchaton commented Jul 6, 2021

Dear @MirMustafaAli,

Yes, feel free to work on this one and ping us if you get blocked :)

Best,
T.C

@MirMustafaAli

Thanks

MirMustafaAli commented Jul 27, 2021

@tchaton
The issue seems to be in the method _should_check_fx, where the batch index of the current epoch is passed as a parameter. I was able to resolve it by passing the total_batch_idx of the epoch training loop to the method.

https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/loops/epoch/training_epoch_loop.py#L169-L170

Feedback on this solution and on how to proceed would be appreciated.
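
In pseudocode, the difference between the two counters is roughly as follows (illustrative names only, not the actual Lightning internals):

def should_check_val_within_epoch(batch_idx: int, val_check_batch: int) -> bool:
    # `batch_idx` restarts at 0 every epoch, so the check point shifts whenever
    # the epoch length is not a multiple of `val_check_batch`.
    return (batch_idx + 1) % val_check_batch == 0

def should_check_val_across_epochs(total_batch_idx: int, val_check_batch: int) -> bool:
    # `total_batch_idx` keeps counting across epochs, so validation fires every
    # `val_check_batch` training batches regardless of epoch boundaries.
    return (total_batch_idx + 1) % val_check_batch == 0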

carmocca commented Jul 28, 2021

train_set = RandomDataset(1, 32)
valid_set = RandomDataset(1, 32)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=2)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=2)

Given this definition, your dataset has 32 samples. With batch_size=2, that gives 32/2 = 16 batches per epoch.

Since you set val_check_interval=10 and there are 16 batches per epoch, you only get one chance to run validation in each epoch (at the 10th batch, index 9).
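
Concretely, that single check lands at global steps 16*k + 9 = 9, 25, 41, 57, 73, 89, which is exactly your Out[5]; a counter that carried across epoch boundaries would instead fire at 9, 19, 29, ..., as in Out[4].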

The issue here is that the docs' definition is ambiguous:

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

"use int to check every n steps (batches)" does not explicitly state whether this is within an epoch or between epochs. The current implementation is for within an epoch.

Changing this would be a breaking change, and I can imagine users wanting to be able to choose which one to do.

cc @PyTorchLightning/core-contributors

Repro bug report model
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_validation_start(self):
        print('global_step:', self.global_step)


def run():
    train_set = RandomDataset(1, 32)
    valid_set = RandomDataset(1, 32)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=2)
    valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_steps=100,
        val_check_interval=10,
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)


if __name__ == "__main__":
    run()

@awaelchli

@carmocca note there is also check_val_every_n_epoch, which is currently not mutually exclusive with val_check_interval, if I understand correctly. If we did exactly what the docs said, we would have to make them mutually exclusive somehow, right?

carmocca commented Jul 28, 2021

we would have to make them mutually exclusive somehow, right?

I think so.

These are all the options

Also similar to #12000

MirMustafaAli commented Jul 28, 2021

@carmocca is it possible to have one option preferred over another?

If check_val_every_n_epoch has been provided, it automatically supersedes val_check_interval, making it generic and simple for the user to use and understand.

These are a few things we might have to change if we implement it:

  • We will have to update the docs to state it explicitly.

tchaton commented Jul 29, 2021

Hey @MirMustafaAli,

I believe the simplest option to implement to resolve this issue is the one shared by @carmocca:

check every n batches across epochs: Not possible. Potentially val_check_interval=n, check_val_every_n_epoch=None?

@MirMustafaAli

@tchaton
will we explicitly explain it in docs?

@carmocca

will we explicitly explain it in docs?

Yes, we should describe all the possible options and the flags to set to achieve each of them.

@MirMustafaAli

Hey @MirMustafaAli,

I believe the simplest to implement to resolve this issue is as shared by @carmocca.

check every n batches across epochs: Not possible. Potentially val_check_interval=n, check_val_every_n_epoch=None?

Shall I start with this option?

stale bot commented Aug 30, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix (This will not be worked on) label on Aug 30, 2021
stale bot closed this as completed on Sep 7, 2021
csvance commented Dec 8, 2021

So if I understand correctly, right now there is no good way to handle training with max_steps & val_check_interval where max_steps > len(dataloader)?

Ideally we want to always trigger validation after the final step. So for instance, if I train 30,000 steps, validate every 3000 steps, and the length of the dataloader is 20,000, validation will currently happen after the following steps:

  • 2999
  • 5999
  • 9999
  • 11999
  • 14999
  • 17999
  • 22999
  • 25999
  • 28999
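
(Presumably the gap across the epoch boundary comes from the per-epoch counter resetting at step 20,000: the next check then lands at 20,000 + 3,000 - 1 = 22,999 rather than 20,999, and since training stops at max_steps, no check fires after the final step 29,999.)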

Right now for this workflow, I manually run validation after trainer.fit finishes and save an additional checkpoint after step 29999.

In terms of potential workarounds, I could have my training dataset lie about its length and report some ridiculously high number, then take the modulo of the requested index with the true size of the dataset to make it valid. Then all steps would fit inside a single "epoch" and things should just work. On the other hand, even if the dataset reports a multiple of its original length, this still has an impact on sampling, as you no longer train on every sample once before training on them again. This could be addressed by using a custom sampler which is aware of the original dataset length:

import numpy as np
import torch.distributed as dist


class PsuedoEpochSampler(object):
    def __init__(self, l: int, n: int):
        """
        Parameters
        ----------
        l
            Length of the dataset (multiple of n and the original length of the dataset)
        n
            Dataset reports its original length multiplied by this number
        """
        assert l % n == 0
        self._l = l // n
        self._n = n
        self._g = None

    def _generator(self):
        idx = np.arange(self._l)
        for ni in range(0, self._n):
            np.random.shuffle(idx)
            for idxi in idx:
                yield int(idxi)

    def __iter__(self):
        self._g = self._generator()
        return self

    def __next__(self):
        return next(self._g)

    def __len__(self):
        return self._l*self._n

class PsuedoEpochSamplerDistributed(object):
    def __init__(self, l: int, n: int, s: int = 0):
        """
        Parameters
        ----------
        l
            Length of the dataset (multiple of n and the original length of the dataset)
        n
            Dataset reports its original length multiplied by this number
        """
        assert l % n == 0
        self._l = l // n
        self._n = n
        self._g = None

        idx = np.arange(self._l)

        # Initial shuffle for splitting dataset inbetween ranks
        state_o = np.random.get_state()
        np.random.seed(s)
        np.random.shuffle(idx)
        np.random.set_state(state_o)

        self._idx = np.array_split(idx, dist.get_world_size())
        self._rank_len = min([len(_) for _ in self._idx])
        for i in range(0, len(self._idx)):
            self._idx[i] = self._idx[i][:self._rank_len]

    def _generator(self):
        idx = self._idx[dist.get_rank()].copy()
        for ni in range(0, self._n):
            np.random.shuffle(idx)
            for idxi in idx:
                yield int(idxi)

    def __iter__(self):
        self._g = self._generator()
        return self

    def __next__(self):
        return next(self._g)

    def __len__(self):
        return self._n*self._rank_len
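
A rough sketch of the "lie about its length" wrapper described above, and how it could be wired together with the sampler (the wrapper class and the wiring are hypothetical, just to illustrate the workaround):

from torch.utils.data import DataLoader, Dataset

class PseudoEpochDataset(Dataset):
    """Hypothetical wrapper: report n times the true length and wrap the
    requested index back onto the real dataset with a modulo."""

    def __init__(self, dataset, n: int):
        self.dataset = dataset
        self.n = n

    def __getitem__(self, index):
        # Map the inflated index back to a valid index of the real dataset.
        return self.dataset[index % len(self.dataset)]

    def __len__(self):
        return len(self.dataset) * self.n

# Example wiring: one "epoch" now spans n passes over the data, so an integer
# val_check_interval is never reset mid-run by an epoch boundary.
# train_set = PseudoEpochDataset(RandomDataset(1, 32), n=10)
# sampler = PsuedoEpochSampler(l=len(train_set), n=10)
# train_loader = DataLoader(train_set, batch_size=2, sampler=sampler)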

@carmocca added the feature (Is an improvement or enhancement) label and removed the bug (Something isn't working) and good first issue (Good for newcomers) labels on Dec 15, 2021
@carmocca added this to the 1.7 milestone on Dec 15, 2021
@carmocca reopened this on Dec 15, 2021
stale bot removed the won't fix (This will not be worked on) label on Dec 15, 2021
@mitchelldehaven

What's the status on this? As @csvance pointed out, whenever a new epoch starts it throws off doing validation by update step count.

carmocca commented Jan 3, 2022

It's labeled as "help wanted", meaning it's waiting for somebody to start working on it.

@mitchelldehaven

Is it fine if I start working on it then?

carmocca commented Jan 3, 2022

Yes. You can take one of the not possible/implemented options I described here:
#8135 (comment)
and work on it.

@mitchelldehaven

Okay, I think I will work on "check every n batches across epochs: Not implemented. What this issue is requesting. Potentially val_check_interval=n, check_val_every_n_epoch=None", since that is the use case I am interested in.
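
Roughly, the configuration that option is aiming at would look like the following (a sketch of the proposed behaviour only, not of a released API at the time of this comment):

from pytorch_lightning import Trainer

# Proposed usage sketch: disable the per-epoch gate so the integer
# val_check_interval counts training batches across epoch boundaries.
trainer = Trainer(
    max_steps=30_000,
    val_check_interval=3_000,
    check_val_every_n_epoch=None,
)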

@carmocca changed the title from "global_step in validation is wrong when setting val_check_interval to a integer." to "Extend val_check_interval support" on Jan 12, 2022
@carmocca modified the milestones: 1.7, 1.6 on Feb 1, 2022
@carmocca modified the milestones: 1.6, 1.7 on Feb 16, 2022
@yuvalkirstain

@carmocca Any news about this issue? This feature is really useful for scenarios in which the training set is small and validation is expensive (e.g. few-shot for generation tasks).
