[bugfix] Fix dataloading for iterable datasets and limit_train_batches #7306

Merged — 22 commits merged on May 3, 2021
Changes from 1 commit
Update test_dataloaders.py
ananthsub committed May 3, 2021
commit f1d9e4d7260409f2f47d92ecc2ecec5cd69e5310
7 changes: 4 additions & 3 deletions tests/trainer/test_dataloaders.py
@@ -320,8 +320,9 @@ def test_datasets_dataloaders_with_limit_num_batches(
 
     ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
     epoch_cb = EpochCounter()
+    epochs = 2
     trainer = Trainer(
-        max_epochs=1,
+        max_epochs=epochs,
         callbacks=[epoch_cb, ckpt_callback],
         limit_train_batches=limit_train_batches,
         limit_val_batches=limit_val_batches,
@@ -338,8 +339,8 @@ def test_datasets_dataloaders_with_limit_num_batches(
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert trainer.num_training_batches == limit_train_batches
     assert trainer.num_val_batches[0] == limit_val_batches
-    assert epoch_cb.train_epoch_count == int(limit_train_batches > 0)
-    assert epoch_cb.val_epoch_count == int(limit_val_batches > 0)
+    assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0)
+    assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0)
 
     trainer.test(model, test_dataloaders=test_dl)
    assert trainer.num_test_batches[0] == limit_test_batches
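
This commit raises max_epochs from 1 to 2, so the assertions now check that the epoch-end callbacks fire once per completed epoch (epochs times) when the corresponding limit is positive, rather than merely checking that they fired at least once. The EpochCounter callback used by the test is defined elsewhere in the test module and is not shown in this diff; below is a minimal sketch of what such a counter could look like, assuming it simply increments counters at the end of each train/val/test epoch. The class body and hook signatures are illustrative (they can differ between Lightning versions), not the repository's actual implementation.

    from pytorch_lightning import Callback

    class EpochCounter(Callback):
        """Illustrative sketch: counts how many train/val/test epochs ran."""

        def __init__(self):
            self.train_epoch_count = 0
            self.val_epoch_count = 0
            self.test_epoch_count = 0

        def on_train_epoch_end(self, trainer, pl_module, *args):
            # Incremented once per completed training epoch.
            self.train_epoch_count += 1

        def on_validation_epoch_end(self, trainer, pl_module):
            self.val_epoch_count += 1

        def on_test_epoch_end(self, trainer, pl_module):
            self.test_epoch_count += 1

With a counter like this, running Trainer(max_epochs=2, limit_train_batches=N) for N > 0 would leave train_epoch_count == 2, which is exactly what the updated assertions verify.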