@@ -326,3 +326,78 @@ def forward(
326326 )
327327
328328 return image , target
329+
330+
class FixedSizeCrop(nn.Module):
    """Crop and/or pad an image and its detection target to a fixed size.

    Images larger than the requested size along a dimension are cropped at a
    random offset; images smaller along a dimension are padded on the right
    and bottom. Boxes, labels and (when present) masks in ``target`` are
    updated to match.

    NOTE: ``target["boxes"]`` is modified in place. The box handling shifts
    even columns horizontally and odd columns vertically, and compares column
    0 against 2 and 1 against 3, i.e. it assumes an (x1, y1, x2, y2) layout.

    Args:
        size: Desired output size, normalised through ``T._setup_size``; the
            first element is used as the height and the second as the width.
        fill: Fill value forwarded to ``F.pad`` for the image.
        padding_mode: Padding mode forwarded to ``F.pad`` for the image.
    """

    def __init__(self, size, fill=0, padding_mode="constant"):
        super().__init__()
        crop_size = tuple(
            T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        )
        self.crop_height = crop_size[0]
        self.crop_width = crop_size[1]
        # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.fill = fill
        self.padding_mode = padding_mode

    def _pad(self, img, target, padding):
        """Pad ``img`` and shift the annotations in ``target`` accordingly.

        Args:
            img: Image accepted by ``F.pad``.
            target: Annotation dict or ``None``. When given, ``"boxes"`` is
                shifted in place and ``"masks"`` (if present) is zero-padded.
            padding: An int (all sides), or a sequence read as ``[all]``,
                ``[left/right, top/bottom]`` or ``[left, top, right, bottom]``.

        Returns:
            The ``(img, target)`` pair after padding.
        """
        # Expansion rules taken from the functional_tensor.py pad.
        if isinstance(padding, int):
            left = top = right = bottom = padding
        else:
            n_values = len(padding)
            if n_values == 1:
                left = top = right = bottom = padding[0]
            elif n_values == 2:
                left = right = padding[0]
                top = bottom = padding[1]
            else:
                left, top, right, bottom = padding[0], padding[1], padding[2], padding[3]

        ltrb = [left, top, right, bottom]
        img = F.pad(img, ltrb, self.fill, self.padding_mode)
        if target is None:
            return img, target

        # Only the left/top padding moves the origin, so only those shift boxes.
        boxes = target["boxes"]
        boxes[:, 0::2] += left
        boxes[:, 1::2] += top
        if "masks" in target:
            # Masks are always padded with zeros, independent of self.fill.
            target["masks"] = F.pad(target["masks"], ltrb, 0, "constant")

        return img, target

    def _crop(self, img, target, top, left, height, width):
        """Crop ``img`` and clip/filter the annotations in ``target``.

        Boxes are translated into the crop's coordinate frame, clamped to
        ``[0, width]`` x ``[0, height]``, and dropped (together with their
        labels and masks) when the clamped box has no positive extent.

        Args:
            img: Image accepted by ``F.crop``.
            target: Annotation dict or ``None``. Must contain ``"boxes"`` and
                ``"labels"`` when given; ``"masks"`` is optional.
            top: Vertical offset of the crop window.
            left: Horizontal offset of the crop window.
            height: Height of the crop window.
            width: Width of the crop window.

        Returns:
            The ``(img, target)`` pair after cropping.
        """
        img = F.crop(img, top, left, height, width)
        if target is None:
            return img, target

        boxes = target["boxes"]
        # Move into the crop window's frame, then clip to its bounds (in place).
        boxes[:, 0::2] -= left
        boxes[:, 1::2] -= top
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        has_width = boxes[:, 0] < boxes[:, 2]
        has_height = boxes[:, 1] < boxes[:, 3]
        keep = has_width & has_height

        target["boxes"] = boxes[keep]
        target["labels"] = target["labels"][keep]
        if "masks" in target:
            target["masks"] = F.crop(target["masks"][keep], top, left, height, width)

        return img, target

    def forward(self, img, target=None):
        """Apply the random crop followed by right/bottom padding.

        Args:
            img: Image whose dimensions can be read by ``F.get_dimensions``.
            target: Optional annotation dict, updated alongside the image.

        Returns:
            The ``(img, target)`` pair at the configured crop size.
        """
        _, height, width = F.get_dimensions(img)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        if not (new_height == height and new_width == width):
            max_top = max(height - self.crop_height, 0)
            max_left = max(width - self.crop_width, 0)

            # NOTE: a single random draw scales both offsets, so the vertical
            # and horizontal positions of the crop window are not independent.
            fraction = torch.rand(1)
            top = int(max_top * fraction)
            left = int(max_left * fraction)

            img, target = self._crop(img, target, top, left, new_height, new_width)

        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)
        if pad_bottom or pad_right:
            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

        return img, target
0 commit comments