@@ -39,7 +39,7 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
3939 key = keys [int (torch .randint (len (keys ), ()))]
4040 return key , dct [key ]
4141
42- def _check_support (self , input : Any ) -> None :
42+ def _check_unsupported (self , input : Any ) -> None :
4343 if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
4444 raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
4545
@@ -52,7 +52,7 @@ def fn(
5252 if type (input ) in {torch .Tensor , features .Image } or isinstance (input , PIL .Image .Image ):
5353 return id , input
5454
55- self ._check_support (input )
55+ self ._check_unsupported (input )
5656 return None
5757
5858 images = list (query_recursively (fn , sample ))
@@ -444,11 +444,8 @@ def forward(self, *inputs: Any) -> Any:
444444 else :
445445 magnitude = 0.0
446446
447- return _put_into_sample (
448- sample ,
449- id ,
450- self ._apply_image_transform (sample , transform_id , magnitude , interpolation = self .interpolation , fill = fill ),
451- )
447+ image = self ._apply_image_transform (image , transform_id , magnitude , interpolation = self .interpolation , fill = fill )
448+ return _put_into_sample (sample , id , image )
452449
453450
454451class AugMix (_AutoAugmentBase ):
@@ -543,7 +540,7 @@ def forward(self, *inputs: Any) -> Any:
543540 magnitude = 0.0
544541
545542 aug = self ._apply_image_transform (
546- image , transform_id , magnitude , interpolation = self .interpolation , fill = fill
543+ aug , transform_id , magnitude , interpolation = self .interpolation , fill = fill
547544 )
548545 mix .add_ (combined_weights [:, i ].view (batch_dims ) * aug )
549546 mix = mix .view (orig_dims ).to (dtype = image .dtype )
0 commit comments