Optim-wip: Consolidate CenterCrop & improve ToRGB #573
```diff
@@ -1,6 +1,6 @@
 import math
 import numbers
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 import torch
```
```diff
@@ -73,15 +73,21 @@ def i1i2i3_transform() -> torch.Tensor:
         ]
         return torch.Tensor(i1i2i3_matrix)
 
-    def __init__(self, transform_name: str = "klt") -> None:
+    def __init__(self, transform_matrix: Union[str, torch.Tensor] = "klt") -> None:
         super().__init__()
-        if transform_name == "klt":
+        assert isinstance(transform_matrix, str) or torch.is_tensor(transform_matrix)
+        if torch.is_tensor(transform_matrix):
+            assert list(transform_matrix.shape) == [3, 3]
+            self.register_buffer("transform", transform_matrix)
+        elif transform_matrix == "klt":
             self.register_buffer("transform", ToRGB.klt_transform())
-        elif transform_name == "i1i2i3":
+        elif transform_matrix == "i1i2i3":
             self.register_buffer("transform", ToRGB.i1i2i3_transform())
         else:
-            raise ValueError("transform_name has to be either 'klt' or 'i1i2i3'")
+            raise ValueError(
+                "transform_matrix has to be either 'klt', 'i1i2i3',"
+                + " or a matrix tensor."
+            )
 
     def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
         assert x.dim() == 3 or x.dim() == 4
```
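With this change, `ToRGB` accepts either a named transform (`"klt"` or `"i1i2i3"`) or a user-supplied 3x3 color matrix. A minimal usage sketch; the commented import path and the identity matrix are illustrative assumptions, not part of the diff:

```python
import torch

# Assumed import location on the optim-wip branch; adjust if it differs:
# from captum.optim._param.image.transforms import ToRGB

# Named transform, as before: uses the built-in KLT matrix.
to_rgb_klt = ToRGB(transform_matrix="klt")

# New: pass any 3x3 color transform matrix directly.
custom_matrix = torch.eye(3)  # illustrative placeholder for a real decorrelation matrix
to_rgb_custom = ToRGB(transform_matrix=custom_matrix)

x = torch.randn(1, 3, 224, 224)          # NCHW image tensor
y = to_rgb_custom(x)                      # apply the color transform
x_back = to_rgb_custom(y, inverse=True)   # map back to the original color space
```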
```diff
@@ -118,57 +124,65 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
 
 class CenterCrop(torch.nn.Module):
     """
-    Center crop the specified amount of pixels from the edges.
+    Center crop a specified amount from a tensor
     Arguments:
         size (int, sequence) or (int): Number of pixels to center crop away.
+        pixels_from_edges (bool): Whether to treat crop size values as the number
+            of pixels from the tensor's edge, or an exact shape in the center.
     """
 
-    def __init__(self, size: TransformSize = 0) -> None:
+    def __init__(self, size: TransformSize = 0, pixels_from_edges: bool = True) -> None:
         super(CenterCrop, self).__init__()
-        if type(size) is list or type(size) is tuple:
-            assert len(size) == 2, (
-                "CenterCrop requires a single crop value or a tuple of (height,width)"
-                + "in pixels for cropping."
-            )
-            self.crop_val = size
-        else:
-            self.crop_val = [size] * 2
+        self.crop_vals = size
+        self.pixels_from_edges = pixels_from_edges
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        assert (
-            input.dim() == 3 or input.dim() == 4
-        ), "Input to CenterCrop must be 3D or 4D"
-        if input.dim() == 4:
-            h, w = input.size(2), input.size(3)
-        elif input.dim() == 3:
-            h, w = input.size(1), input.size(2)
-        h_crop = h - self.crop_val[0]
-        w_crop = w - self.crop_val[1]
-        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
-        return input[..., sh : sh + h_crop, sw : sw + w_crop]
+        """
+        Center crop an input.
+        Arguments:
+            input (torch.Tensor): Input to center crop.
+        Returns:
+            tensor (torch.Tensor): A center cropped tensor.
+        """
+        return center_crop(input, self.crop_vals, self.pixels_from_edges)
 
 
-def center_crop_shape(input: torch.Tensor, output_size: List[int]) -> torch.Tensor:
+def center_crop(
+    input: torch.Tensor, crop_vals: TransformSize, pixels_from_edges: bool = True
+) -> torch.Tensor:
     """
-    Crop NCHW & CHW outputs by specifying the desired output shape.
+    Center crop a specified amount from a tensor
     Arguments:
         input (tensor): A CHW or NCHW image tensor to center crop.
+        size (int, sequence) or (int): Number of pixels to center crop away.
+        pixels_from_edges (bool): Whether to treat crop size values as the number
+            of pixels from the tensor's edge, or an exact shape in the center.
     Returns:
         *tensor*: A center cropped tensor.
     """
 
-    assert input.dim() == 4 or input.dim() == 3
-    output_size = [output_size] if not hasattr(output_size, "__iter__") else output_size
-    assert len(output_size) == 1 or len(output_size) == 2
-    output_size = output_size * 2 if len(output_size) == 1 else output_size
+    assert input.dim() == 3 or input.dim() == 4
+    crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
+    crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
+    assert len(crop_vals) == 1 or len(crop_vals) == 2
+    crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals
 
     if input.dim() == 4:
         h, w = input.size(2), input.size(3)
     if input.dim() == 3:
         h, w = input.size(1), input.size(2)
 
-    h_crop = h - int(round((h - output_size[0]) / 2.0))
-    w_crop = w - int(round((w - output_size[1]) / 2.0))
-
-    return input[
-        ..., h_crop - output_size[0] : h_crop, w_crop - output_size[1] : w_crop
-    ]
+    if pixels_from_edges:
+        h_crop = h - crop_vals[0]
+        w_crop = w - crop_vals[1]
+        sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
+        x = input[..., sh : sh + h_crop, sw : sw + w_crop]
+    else:
+        h_crop = h - int(round((h - crop_vals[0]) / 2.0))
+        w_crop = w - int(round((w - crop_vals[1]) / 2.0))
+        x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
+    return x
 
 
 def rand_select(transform_values: TransformValList) -> TransformVal:
```
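The consolidated `center_crop` keeps both behaviours behind the `pixels_from_edges` flag: the default removes the given number of pixels measured from the edges, while `pixels_from_edges=False` treats the values as the exact centered output size. A rough sketch of the resulting shapes, assuming the `CenterCrop` module and `center_crop` function from the diff are in scope (the shapes follow from the arithmetic above):

```python
import torch

x = torch.randn(1, 3, 128, 128)  # NCHW input

# Default mode: crop 28 pixels away, measured from the edges,
# so a 128x128 input becomes 100x100.
crop_edges = CenterCrop(size=28, pixels_from_edges=True)
print(crop_edges(x).shape)  # torch.Size([1, 3, 100, 100])

# Exact-shape mode: keep a centered 28x28 region instead.
crop_exact = CenterCrop(size=28, pixels_from_edges=False)
print(crop_exact(x).shape)  # torch.Size([1, 3, 28, 28])

# The functional form mirrors the module and accepts (height, width) pairs.
y = center_crop(x, crop_vals=(24, 32), pixels_from_edges=False)
print(y.shape)  # torch.Size([1, 3, 24, 32])
```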