How to switch dataloader every n training steps #12415
Hi everyone,

This is what I would like to achieve: switch the training dataloader every n training steps. I found this solution on the old forum, but it only switches the dataset after each epoch. Here is my current attempt at switching it every n batches:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SimpleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...
        self.batch_size = ...
        self.change_every_n_batch = 20

    def train_dataloader(self):
        # Pick dataset 0 or 1 depending on how many blocks of n steps have elapsed.
        self.current_dataset = (self.global_step // self.change_every_n_batch) % 2
        if self.current_dataset == 0:
            dataset = Dataset1()  # my first dataset
        elif self.current_dataset == 1:
            dataset = Dataset2()  # my second dataset
        dataloader = DataLoader(dataset, batch_size=self.batch_size)
        return dataloader

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # When we cross a block boundary, ask the trainer to rebuild the train dataloader.
        new_dataset = (self.global_step // self.change_every_n_batch) % 2
        if new_dataset != self.current_dataset:
            self.trainer.reset_train_dataloader(self)
```

Any idea what could be going wrong? Or do you have a solution for what I want to achieve? Thanks!
  
hey @matprst ! you can set:

- `limit_train_batches=n`: this will ensure that every training epoch progresses for only n batches
- `reload_dataloaders_every_n_epochs=1`: this will ensure that the train dataloader is reloaded after every epoch

and inside `train_dataloader`, flip the dataloader on each reload. Something like:

```python
def train_dataloader(self):
    if self.some_flag:
        dataset = Dataset1()
    else:
        dataset = Dataset2()
    self.some_flag = not self.some_flag
    return DataLoader(dataset, batch_size=self.batch_size)
```
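As a rough sketch of how these pieces might fit together (not spelled out in the reply above; it assumes a PyTorch Lightning 1.x `Trainer`, that `SimpleModule` defines `train_dataloader` as in the snippet, and that `some_flag` is initialized, e.g. to `True`, in `__init__`; `n = 20` and `max_epochs=100` are just example values):

```python
import pytorch_lightning as pl

n = 20  # number of batches to draw from each dataset before switching

model = SimpleModule()
trainer = pl.Trainer(
    limit_train_batches=n,                # each training "epoch" runs for only n batches
    reload_dataloaders_every_n_epochs=1,  # rebuild train_dataloader() after every epoch
    max_epochs=100,                       # total number of n-batch blocks to train for
)
trainer.fit(model)
```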
  
Works like a charm, and much cleaner than what I thought! Thanks for the reply! I realise now that since I am using iterable datasets (they are large and don't fit into memory), the reloading restarts the iterable from the beginning rather than continuing where it stopped (or at least returning a random batch). This is another problem with the dataset, so I will consider the question answered.
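One possible workaround for that restart issue (just a sketch, not from this thread): with `num_workers=0`, an `IterableDataset` can hand back the same long-lived iterator on every reload, so the stream continues where it stopped instead of restarting. `make_stream` below is a hypothetical factory for the underlying sample stream.

```python
from torch.utils.data import DataLoader, IterableDataset


class ResumableStream(IterableDataset):
    """Keeps one persistent iterator so dataloader reloads resume, not restart.

    Only valid with num_workers=0: worker processes would each receive their
    own copy of the iterator and the position would no longer be shared.
    """

    def __init__(self, make_stream):
        # Hypothetical factory for the sample stream; created once and reused
        # across every dataloader reload.
        self._stream = iter(make_stream())

    def __iter__(self):
        return self._stream


# Usage sketch:
# dataset = ResumableStream(make_stream)
# loader = DataLoader(dataset, batch_size=32, num_workers=0)
```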
  