@@ -256,3 +256,88 @@ def apply_recursively(obj: Any) -> Any:
256256 return obj
257257
258258 return apply_recursively (inputs if len (inputs ) > 1 else inputs [0 ])
259+
260+
261+ class RandomZoomOut (Transform ):
262+ def __init__ (
263+ self , fill : Union [float , Sequence [float ]] = 0.0 , side_range : Tuple [float , float ] = (1.0 , 4.0 ), p : float = 0.5
264+ ) -> None :
265+ super ().__init__ ()
266+
267+ if fill is None :
268+ fill = 0.0
269+ self .fill = fill
270+
271+ self .side_range = side_range
272+ if side_range [0 ] < 1.0 or side_range [0 ] > side_range [1 ]:
273+ raise ValueError (f"Invalid canvas side range provided { side_range } ." )
274+
275+ self .p = p
276+
277+ def _get_params (self , sample : Any ) -> Dict [str , Any ]:
278+ image = query_image (sample )
279+ orig_c , orig_h , orig_w = get_image_dimensions (image )
280+
281+ r = self .side_range [0 ] + torch .rand (1 ) * (self .side_range [1 ] - self .side_range [0 ])
282+ canvas_width = int (orig_w * r )
283+ canvas_height = int (orig_h * r )
284+
285+ r = torch .rand (2 )
286+ left = int ((canvas_width - orig_w ) * r [0 ])
287+ top = int ((canvas_height - orig_h ) * r [1 ])
288+ right = canvas_width - (left + orig_w )
289+ bottom = canvas_height - (top + orig_h )
290+ padding = [left , top , right , bottom ]
291+
292+ fill = self .fill
293+ if not isinstance (fill , collections .abc .Sequence ):
294+ fill = [fill ] * orig_c
295+
296+ return dict (padding = padding , fill = fill )
297+
298+ def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
299+ if isinstance (input , features .Image ) or is_simple_tensor (input ):
300+ # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
301+ output = F .pad_image_tensor (input , params ["padding" ], fill = 0 , padding_mode = "constant" )
302+
303+ left , top , right , bottom = params ["padding" ]
304+ fill = torch .tensor (params ["fill" ], dtype = input .dtype , device = input .device ).to ().view (- 1 , 1 , 1 )
305+
306+ if top > 0 :
307+ output [..., :top , :] = fill
308+ if left > 0 :
309+ output [..., :, :left ] = fill
310+ if bottom > 0 :
311+ output [..., - bottom :, :] = fill
312+ if right > 0 :
313+ output [..., :, - right :] = fill
314+
315+ if isinstance (input , features .Image ):
316+ output = features .Image .new_like (input , output )
317+
318+ return output
319+ elif isinstance (input , PIL .Image .Image ):
320+ return F .pad_image_pil (
321+ input ,
322+ params ["padding" ],
323+ fill = tuple (int (v ) if input .mode != "F" else v for v in params ["fill" ]),
324+ padding_mode = "constant" ,
325+ )
326+ elif isinstance (input , features .BoundingBox ):
327+ output = F .pad_bounding_box (input , params ["padding" ], format = input .format )
328+
329+ left , top , right , bottom = params ["padding" ]
330+ height , width = input .image_size
331+ height += top + bottom
332+ width += left + right
333+
334+ return features .BoundingBox .new_like (input , output , image_size = (height , width ))
335+ else :
336+ return input
337+
338+ def forward (self , * inputs : Any ) -> Any :
339+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
340+ if torch .rand (1 ) >= self .p :
341+ return sample
342+
343+ return super ().forward (sample )
0 commit comments