Skip to content

Commit 0e3e966

Browse files
authored
Raise error in DistributedProxySampler when sampler is already a DistributedSampler (#2120)
* Raise error in DistributedProxySampler when sampler is already a DistributedSampler * Add test for TypeError when passing DistributedSampler to DistributedProxySampler
1 parent 0f7819a commit 0e3e966

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

Diff for: ignite/distributed/auto.py

+3
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,9 @@ def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: O
296296
if not isinstance(sampler, Sampler):
297297
raise TypeError(f"Argument sampler should be instance of torch Sampler, but given: {type(sampler)}")
298298

299+
if isinstance(sampler, DistributedSampler):
300+
raise TypeError("Argument sampler must not be a distributed sampler already")
301+
299302
if not hasattr(sampler, "__len__"):
300303
raise TypeError("Argument sampler should have length")
301304

Diff for: tests/ignite/distributed/test_auto.py

+3
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,6 @@ def test_dist_proxy_sampler():
273273

274274
with pytest.raises(TypeError, match=r"Argument sampler should have length"):
275275
DistributedProxySampler(Sampler([1]))
276+
277+
with pytest.raises(TypeError, match=r"Argument sampler must not be a distributed sampler already"):
278+
DistributedProxySampler(DistributedSampler(sampler, num_replicas=num_replicas, rank=0))

0 commit comments

Comments
 (0)