@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
77 # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
88 def forward (self , img1 , img2 , flow , valid_flow_mask ):
99
10- assert all (isinstance (arg , torch .Tensor ) for arg in (img1 , img2 , flow , valid_flow_mask ) if arg is not None )
11- assert all (arg .dtype == torch .float32 for arg in (img1 , img2 , flow ) if arg is not None )
10+ if not all (isinstance (arg , torch .Tensor ) for arg in (img1 , img2 , flow , valid_flow_mask ) if arg is not None ):
11+ raise TypeError ("This method expects all input arguments to be of type torch.Tensor." )
12+ if not all (arg .dtype == torch .float32 for arg in (img1 , img2 , flow ) if arg is not None ):
13+ raise TypeError ("This method expects the tensors img1, img2 and flow of be of dtype torch.float32." )
1214
13- assert img1 .shape == img2 .shape
15+ if img1 .shape != img2 .shape :
16+ raise ValueError ("img1 and img2 should have the same shape." )
1417 h , w = img1 .shape [- 2 :]
15- if flow is not None :
16- assert flow .shape == (2 , h , w )
18+ if flow is not None and flow . shape != ( 2 , h , w ) :
19+ raise ValueError ( f" flow.shape should be (2, { h } , { w } ) instead of { flow . shape } " )
1720 if valid_flow_mask is not None :
18- assert valid_flow_mask .shape == (h , w )
19- assert valid_flow_mask .dtype == torch .bool
21+ if valid_flow_mask .shape != (h , w ):
22+ raise ValueError (f"valid_flow_mask.shape should be ({ h } , { w } ) instead of { valid_flow_mask .shape } " )
23+ if valid_flow_mask .dtype != torch .bool :
24+ raise TypeError ("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}" )
2025
2126 return img1 , img2 , flow , valid_flow_mask
2227
@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
109114 def __init__ (self , p = 0.5 , scale = (0.02 , 0.33 ), ratio = (0.3 , 3.3 ), value = 0 , inplace = False , max_erase = 1 ):
110115 super ().__init__ (p = p , scale = scale , ratio = ratio , value = value , inplace = inplace )
111116 self .max_erase = max_erase
112- assert self .max_erase > 0
117+ if self .max_erase <= 0 :
118+ raise ValueError ("max_raise should be greater than 0" )
113119
114120 def forward (self , img1 , img2 , flow , valid_flow_mask ):
115121 if torch .rand (1 ) > self .p :
0 commit comments