Commit: fix-pickle
ananthsub committed Oct 23, 2020
1 parent ade68df commit 85ad318
Showing 1 changed file with 14 additions and 14 deletions.
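
The change moves ModelWithDataLoaderDistributedSampler from inside the test function up to module level. The likely rationale, judging from the commit title: spawn-based DDP hands the model to worker processes via pickle, and pickle serializes a class by its module-qualified name, so a class defined inside a function body cannot be pickled. A minimal standalone sketch of that behavior (the names below are illustrative, not from this commit):

import pickle

class TopLevelModel:
    """Defined at module level: pickle can resolve it by qualified name."""

def build_model_class():
    class NestedModel:
        """Defined inside a function: its qualified name contains
        '<locals>', which pickle cannot look up at load time."""
    return NestedModel

pickle.dumps(TopLevelModel())  # succeeds

try:
    pickle.dumps(build_model_class()())
except (pickle.PicklingError, AttributeError) as exc:
    print(f"nested class is unpicklable: {exc}")

Moving the class above the test keeps its qualified name importable, which is what the spawn-based DDP path needs.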
28 changes: 14 additions & 14 deletions tests/trainer/test_dataloaders.py
@@ -746,24 +746,25 @@ def test_dataloader_distributed_sampler(tmpdir):
     trainer.test(ckpt_path=None)
 
 
+class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
+
+    def train_dataloader(self):
+        dataloader = super().train_dataloader()
+        dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False)
+        return DataLoader(
+            dataloader.dataset,
+            batch_size=self.batch_size,
+            drop_last=False,
+            sampler=dist_sampler,
+            shuffle=False
+        )
+
+
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_distributed_sampler_already_attached(tmpdir):
     """ Test DistributedSampler and its arguments for DDP backend when a DistributedSampler is already attached to the dataloader """
 
-    class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
-
-        def train_dataloader(self):
-            dataloader = super().train_dataloader()
-            dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False)
-            return DataLoader(
-                dataloader.dataset,
-                batch_size=self.batch_size,
-                drop_last=False,
-                sampler=dist_sampler,
-                shuffle=False
-            )
-
     model = ModelWithDataLoaderDistributedSampler()
     trainer = Trainer(
         gpus=[0, 1],
@@ -778,7 +779,6 @@ def train_dataloader(self):
     assert result == 1, "DDP Training failed"
 
 
-
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus(tmpdir):
     # we need at least 3 gpus for this test
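
For reference, the pattern this test exercises — building the DataLoader with an explicit DistributedSampler so the Trainer leaves the user's sampler in place — looks like this in plain PyTorch. A standalone sketch; the dataset and the num_replicas/rank values are illustrative stand-ins for what torch.distributed would normally provide:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 10))

# Passing num_replicas/rank explicitly avoids needing an initialized
# process group for this sketch; under real DDP they are auto-detected.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)

# shuffle must be False (or omitted) whenever a sampler is supplied;
# DataLoader raises a ValueError for shuffle=True plus a sampler.
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=False, shuffle=False)

for (batch,) in loader:
    pass  # each rank iterates only its own shard of the dataset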
