How to load data every epoch #231
Comments
Could you explain more? Do you have pseudocode?

What happens if you just don't use the …

@neggert that's the way to do it. I'll add this to the docs.

@sadbb actually just submitted a PR to enable this.

I think it is pretty standard to create a dataloader at the beginning of each epoch. I think it should be the default.

It is already the default; this PR is to support the non-default case.
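For reference, the underlying pattern being discussed (independent of any framework hook) is simply constructing a fresh `DataLoader` at the start of each epoch. A minimal sketch in plain PyTorch; `make_dataset` is a hypothetical helper standing in for whatever produces the per-epoch data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dataset(epoch):
    # Hypothetical helper: returns different (here, seeded) data each epoch.
    g = torch.Generator().manual_seed(epoch)
    x = torch.randn(100, 4, generator=g)
    y = torch.randint(0, 2, (100,), generator=g)
    return TensorDataset(x, y)

for epoch in range(3):
    # Rebuild the DataLoader so each epoch iterates over fresh data.
    loader = DataLoader(make_dataset(epoch), batch_size=32, shuffle=True)
    for batch_x, batch_y in loader:
        pass  # training step goes here
```

The cost of recreating the `DataLoader` object itself is negligible; the expensive part, if any, is whatever `make_dataset` does.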
One callout: when doing validation, …
The problem is here, in

```python
if test and len(self.get_test_dataloaders()) > 1:
    args.append(dataloader_idx)
elif not test and len(self.get_val_dataloaders()) > 1:
    args.append(dataloader_idx)
```

Honestly, I never really liked passing different args to `validation_step` depending on the number of dataloaders anyway. Maybe we should think about changing the design slightly here.
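The design complaint can be illustrated in isolation. The helper below is a stand-in, not Lightning internals; it only mirrors the quoted branch to show why user-facing `validation_step` signatures end up varying with the loader count:

```python
def build_step_args(batch, batch_idx, dataloader_idx, num_dataloaders):
    # Mirrors the quoted branch: the extra index is only appended when
    # more than one dataloader is configured, so the user's step method
    # must accept a different number of arguments in each case.
    args = [batch, batch_idx]
    if num_dataloaders > 1:
        args.append(dataloader_idx)
    return args
```

With one dataloader the step is called as `step(batch, batch_idx)`; with several, as `step(batch, batch_idx, dataloader_idx)`. Always passing `dataloader_idx` (defaulting it to `0`) would make the signature stable.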
I strongly discourage removing the …

```python
from torch.utils.data import Dataset

class DynamicSet(Dataset):
    """Wrapper whose underlying dataset can be swapped between epochs."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def reset(self, dataset):
        self.dataset = dataset
```

or you can use indices if desired:

```python
from contextlib import contextmanager

from torch.utils.data import Subset

class DynamicSet(Subset):
    def __init__(self, dataset, indices=None):
        if indices is None:
            # Note: len(dataset), not len(self.dataset) -- the attribute
            # is not assigned until super().__init__() runs.
            indices = list(range(len(dataset)))
        super().__init__(dataset, indices)
        self.enumerated = False

    def __getitem__(self, index):
        out = super().__getitem__(index)
        if self.enumerated:
            out = (self.indices[index], out)
        return out

    def reveal(self, indices):
        self.indices = list(set(self.indices).union(indices))

    @property
    def hidden(self):
        return list(set(range(len(self.dataset))) - set(self.indices))

    @contextmanager
    def state(self, indices=None, enumerated=None):
        try:
            if indices is not None:
                old_indices = self.indices
                self.indices = indices
            if enumerated is not None:
                old_enumerated = self.enumerated
                self.enumerated = enumerated
            yield
        finally:
            if indices is not None:
                self.indices = old_indices
            if enumerated is not None:
                self.enumerated = old_enumerated
```

Usage example (active learning):

```python
# inside model.train_dataloader():
train_set = DynamicSet(dataset, indices=[])

# inside model.on_epoch_start():
train_set = model.train_dataloader().dataset
if not train_set.indices:
    # use the first half of the full dataset in the beginning
    # (// keeps range() happy; / would pass it a float)
    train_set.indices = list(range(len(train_set.dataset) // 2))
else:
    # add to the dataset the hidden items that pass a threshold
    with train_set.state(train_set.hidden, enumerated=True):
        indices = [i for i, item in train_set if pass_threshold(item)]
    train_set.reveal(indices)
```

Side Note: …
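A minimal end-to-end sketch of the simpler reset-based `DynamicSet` above, showing how one wrapper instance lets new data flow through an already-constructed `DataLoader`; `load_epoch_data` is a hypothetical helper:

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

class DynamicSet(Dataset):
    # Same wrapper as above: delegates to an inner dataset that can be
    # swapped without recreating the DataLoader that holds the wrapper.
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def reset(self, dataset):
        self.dataset = dataset

def load_epoch_data(epoch):
    # Hypothetical helper returning different data each epoch.
    return TensorDataset(torch.full((4, 1), float(epoch)))

wrapper = DynamicSet(load_epoch_data(0))
loader = DataLoader(wrapper, batch_size=2)
for epoch in range(2):
    wrapper.reset(load_epoch_data(epoch))  # swap the data in place
    for (batch,) in loader:
        pass  # batch now comes from the current epoch's dataset
```

This works because `DataLoader` re-queries its dataset every time a new iterator is created, so the swap takes effect at the next epoch's `for batch in loader`.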
To me, it is still not entirely clear how to achieve this in the best way... In the past, I used … Another option (if just different parts of the same dataset need to be used) would be …
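For the "different parts of the same dataset" case, one concrete option (named here as an assumption, since the comment above is truncated) is a sampler such as `SubsetRandomSampler`, which leaves the dataset object untouched:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

def loader_for_epoch(epoch, chunk=5):
    # Select a different index window each epoch; the sampler also
    # randomizes iteration order within that window.
    start = (epoch * chunk) % len(dataset)
    indices = list(range(start, min(start + chunk, len(dataset))))
    return DataLoader(dataset, batch_size=2,
                      sampler=SubsetRandomSampler(indices))
```

Since only the index list changes per epoch, this avoids both reloading data and mutating the dataset.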
Hi,
Because of my task, I must load new train_data every epoch, but in this package data can only be loaded once, at the beginning of training. How can I load data every epoch?
Thanks.