@@ -358,10 +358,10 @@ class RandomScale(nn.Module):
358358        "scale" ,
359359        "mode" ,
360360        "align_corners" ,
361-         "has_align_corners " ,
361+         "_has_align_corners " ,
362362        "recompute_scale_factor" ,
363-         "has_recompute_scale_factor " ,
364-         "is_distribution " ,
363+         "_has_recompute_scale_factor " ,
364+         "_is_distribution " ,
365365    ]
366366
367367    def  __init__ (
@@ -390,8 +390,9 @@ def __init__(
390390        super ().__init__ ()
391391        if  isinstance (scale , torch .distributions .distribution .Distribution ):
392392            # Distributions are not supported by TorchScript / JIT yet 
393+             assert  scale .batch_shape  ==  torch .Size ([])
393394            self .scale_distribution  =  scale 
394-             self .is_distribution  =  True 
395+             self ._is_distribution  =  True 
395396            self .scale  =  []
396397        else :
397398            assert  hasattr (scale , "__iter__" )
@@ -400,33 +401,12 @@ def __init__(
400401                scale  =  scale .tolist ()
401402            assert  len (scale ) >  0 
402403            self .scale  =  [float (s ) for  s  in  scale ]
403-             self .is_distribution  =  False 
404+             self ._is_distribution  =  False 
404405        self .mode  =  mode 
405406        self .align_corners  =  align_corners  if  mode  not  in   ["nearest" , "area" ] else  None 
406407        self .recompute_scale_factor  =  recompute_scale_factor 
407-         self .has_align_corners  =  torch .__version__  >=  "1.3.0" 
408-         self .has_recompute_scale_factor  =  torch .__version__  >=  "1.6.0" 
409- 
410-     def  _get_scale_mat (
411-         self ,
412-         m : float ,
413-         device : torch .device ,
414-         dtype : torch .dtype ,
415-     ) ->  torch .Tensor :
416-         """ 
417-         Create a scale matrix tensor. 
418- 
419-         Args: 
420- 
421-             m (float): The scale value to use. 
422- 
423-         Returns: 
424-             **scale_mat** (torch.Tensor): A scale matrix. 
425-         """ 
426-         scale_mat  =  torch .tensor (
427-             [[m , 0.0 , 0.0 ], [0.0 , m , 0.0 ]], device = device , dtype = dtype 
428-         )
429-         return  scale_mat 
408+         self ._has_align_corners  =  torch .__version__  >=  "1.3.0" 
409+         self ._has_recompute_scale_factor  =  torch .__version__  >=  "1.6.0" 
430410
431411    def  _scale_tensor (self , x : torch .Tensor , scale : float ) ->  torch .Tensor :
432412        """ 
@@ -440,8 +420,8 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
440420        Returns: 
441421            **x** (torch.Tensor): A scaled NCHW image tensor. 
442422        """ 
443-         if  self .has_align_corners :
444-             if  self .has_recompute_scale_factor :
423+         if  self ._has_align_corners :
424+             if  self ._has_recompute_scale_factor :
445425                x  =  F .interpolate (
446426                    x ,
447427                    scale_factor = scale ,
@@ -472,8 +452,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
472452            **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. 
473453        """ 
474454        assert  x .dim () ==  4 
475-         if  self .is_distribution :
476-             scale  =  self .scale_distribution .sample ().item ()
455+         if  self ._is_distribution :
456+             scale  =  float ( self .scale_distribution .sample ().item () )
477457        else :
478458            n  =  int (
479459                torch .randint (
@@ -508,8 +488,8 @@ class RandomScaleAffine(nn.Module):
508488        "mode" ,
509489        "padding_mode" ,
510490        "align_corners" ,
511-         "has_align_corners " ,
512-         "is_distribution " ,
491+         "_has_align_corners " ,
492+         "_is_distribution " ,
513493    ]
514494
515495    def  __init__ (
@@ -539,8 +519,9 @@ def __init__(
539519        super ().__init__ ()
540520        if  isinstance (scale , torch .distributions .distribution .Distribution ):
541521            # Distributions are not supported by TorchScript / JIT yet 
522+             assert  scale .batch_shape  ==  torch .Size ([])
542523            self .scale_distribution  =  scale 
543-             self .is_distribution  =  True 
524+             self ._is_distribution  =  True 
544525            self .scale  =  []
545526        else :
546527            assert  hasattr (scale , "__iter__" )
@@ -549,11 +530,11 @@ def __init__(
549530                scale  =  scale .tolist ()
550531            assert  len (scale ) >  0 
551532            self .scale  =  [float (s ) for  s  in  scale ]
552-             self .is_distribution  =  False 
533+             self ._is_distribution  =  False 
553534        self .mode  =  mode 
554535        self .padding_mode  =  padding_mode 
555536        self .align_corners  =  align_corners 
556-         self .has_align_corners  =  torch .__version__  >=  "1.3.0" 
537+         self ._has_align_corners  =  torch .__version__  >=  "1.3.0" 
557538
558539    def  _get_scale_mat (
559540        self ,
@@ -591,7 +572,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
591572        scale_matrix  =  self ._get_scale_mat (scale , x .device , x .dtype )[None , ...].repeat (
592573            x .shape [0 ], 1 , 1 
593574        )
594-         if  self .has_align_corners :
575+         if  self ._has_align_corners :
595576            # Pass align_corners explicitly for torch >= 1.3.0 
596577            grid  =  F .affine_grid (
597578                scale_matrix , x .size (), align_corners = self .align_corners 
@@ -620,8 +601,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
620601            **x** (torch.Tensor): A randomly scaled NCHW image *tensor*. 
621602        """ 
622603        assert  x .dim () ==  4 
623-         if  self .is_distribution :
624-             scale  =  self .scale_distribution .sample ().item ()
604+         if  self ._is_distribution :
605+             scale  =  float ( self .scale_distribution .sample ().item () )
625606        else :
626607            n  =  int (
627608                torch .randint (
@@ -1021,7 +1002,7 @@ class TransformationRobustness(nn.Module):
10211002
10221003    def  __init__ (
10231004        self ,
1024-         padding_transform : Optional [nn .Module ] =  None ,
1005+         padding_transform : Optional [nn .Module ] =  nn . ConstantPad2d ( 2 ,  value = 0.5 ) ,
10251006        translate : Optional [Union [int , List [int ]]] =  [4 ] *  10 ,
10261007        scale : Optional [NumSeqOrTensorOrProbDistType ] =  [
10271008            0.995  **  n  for  n  in  range (- 5 , 80 )
@@ -1039,7 +1020,7 @@ def __init__(
10391020
10401021            padding_transform (nn.Module, optional): A padding module instance. No 
10411022                padding will be applied before transforms if set to None. 
1042-                 Default: None  
1023+                 Default: nn.ConstantPad2d(2, value=0.5)  
10431024            translate (int or list of int, optional): The max horizontal and vertical 
10441025                 translation to use for each jitter transform. 
10451026                 Default: [4] * 10 
0 commit comments