@@ -367,24 +367,21 @@ def test_dice_score(pred, target, expected):
367
367
assert score == expected
368
368
369
369
370
- @pytest .mark .parametrize (['target' , 'pred' , 'half_ones' , 'reduction' , 'remove_bg' , 'expected' ], [
371
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
372
- False , 'none' , False , torch .Tensor ([1 , 1 , 1 ])),
373
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
374
- False , 'elementwise_mean' , False , torch .Tensor ([1 ])),
375
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
376
- False , 'none' , True , torch .Tensor ([1 , 1 ])),
377
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
378
- True , 'none' , False , torch .Tensor ([0.5 , 0.5 , 0.5 ])),
379
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
380
- True , 'elementwise_mean' , False , torch .Tensor ([0.5 ])),
381
- pytest .param ((torch .arange (120 ) % 3 ).view (- 1 , 1 ), (torch .arange (120 ) % 3 ).view (- 1 , 1 ),
382
- True , 'none' , True , torch .Tensor ([0.5 , 0.5 ])),
370
+ @pytest .mark .parametrize (['half_ones' , 'reduction' , 'remove_bg' , 'expected' ], [
371
+ pytest .param (False , 'none' , False , torch .Tensor ([1 , 1 , 1 ])),
372
+ pytest .param (False , 'elementwise_mean' , False , torch .Tensor ([1 ])),
373
+ pytest .param (False , 'none' , True , torch .Tensor ([1 , 1 ])),
374
+ pytest .param (True , 'none' , False , torch .Tensor ([0.5 , 0.5 , 0.5 ])),
375
+ pytest .param (True , 'elementwise_mean' , False , torch .Tensor ([0.5 ])),
376
+ pytest .param (True , 'none' , True , torch .Tensor ([0.5 , 0.5 ])),
383
377
])
384
- def test_iou (target , pred , half_ones , reduction , remove_bg , expected ):
378
+ def test_iou (half_ones , reduction , remove_bg , expected ):
379
+ pred = (torch .arange (120 ) % 3 ).view (- 1 , 1 )
380
+ target = (torch .arange (120 ) % 3 ).view (- 1 , 1 )
385
381
if half_ones :
386
382
pred [:60 ] = 1
387
- assert torch .all (torch .eq (iou (pred , target , remove_bg = remove_bg , reduction = reduction ), expected ))
383
+ iou_val = iou (pred , target , remove_bg = remove_bg , reduction = reduction )
384
+ assert torch .allclose (iou_val , expected , atol = 1e-9 )
388
385
389
386
390
387
# example data taken from
0 commit comments