change to_NCHW to process numpy arrays
Franck Mamalet committed Oct 16, 2024
1 parent 4501a4e commit cb93507
Showing 4 changed files with 22 additions and 10 deletions.
21 changes: 15 additions & 6 deletions deel/torchlip/functional.py
@@ -212,29 +212,38 @@ def max_min(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
     return torch.cat((F.relu(input), F.relu(-input)), dim=dim)
 
 
-def group_sort(input: torch.Tensor, group_size: Optional[int] = None, dim : int = 1) -> torch.Tensor:
+def group_sort(
+    input: torch.Tensor, group_size: Optional[int] = None, dim: int = 1
+) -> torch.Tensor:
     r"""
     Applies GroupSort activation on the given tensor.
     See Also:
         :py:func:`group_sort_2`
         :py:func:`full_sort`
     """
 
     if group_size is None or group_size > input.shape[dim]:
         group_size = input.shape[dim]
 
     if input.shape[dim] % group_size != 0:
         raise ValueError("The input size must be a multiple of the group size.")
 
-    new_shape = input.shape[:dim]+(input.shape[dim]//group_size,group_size)+input.shape[dim+1:]
+    new_shape = (
+        input.shape[:dim]
+        + (input.shape[dim] // group_size, group_size)
+        + input.shape[dim + 1 :]
+    )
     if group_size == 2:
         resh_input = input.view(new_shape)
-        a, b = torch.min(resh_input, dim+1,keepdim=True)[0], torch.max(resh_input, dim+1,keepdim=True)[0]
-        return torch.cat([a, b], dim=dim+1).view(input.shape)
+        a, b = (
+            torch.min(resh_input, dim + 1, keepdim=True)[0],
+            torch.max(resh_input, dim + 1, keepdim=True)[0],
+        )
+        return torch.cat([a, b], dim=dim + 1).view(input.shape)
     fv = input.reshape(new_shape)
 
-    return torch.sort(fv,dim=dim+1)[0].reshape(input.shape)
+    return torch.sort(fv, dim=dim + 1)[0].reshape(input.shape)
 
 
 def group_sort_2(input: torch.Tensor) -> torch.Tensor:
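For orientation, a minimal usage sketch of the group_sort shown above (illustrative values; assuming the function is imported from deel/torchlip/functional.py as below):

import torch

from deel.torchlip.functional import group_sort

# batch of 1 with 4 features: with group_size=2, each consecutive pair
# along dim=1 is sorted in ascending order
x = torch.tensor([[3.0, 1.0, 4.0, 2.0]])
print(group_sort(x, group_size=2))  # tensor([[1., 3., 2., 4.]])

# group_size=None (or >= the size of dim) sorts the whole dimension
print(group_sort(x))  # tensor([[1., 2., 3., 4.]])
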
5 changes: 3 additions & 2 deletions tests/test_activations.py
@@ -131,7 +131,7 @@ def test_GroupSort(group_size, img, expected):
     xnp = np.repeat(
         np.expand_dims(np.repeat(np.expand_dims(xn, -1), 28, -1), -1), 28, -1
     )
-    xnp = uft.to_NCHW_inv(xnp) # move channel if needed (TF)
+    xnp = uft.to_NCHW_inv(xnp)  # move channel if needed (TF)
     x = uft.to_tensor(xnp)
     uft.build_layer(gs, (28, 28, 4))
     y = gs(x).numpy()
@@ -141,11 +141,12 @@ def test_GroupSort(group_size, img, expected):
     y_t = np.repeat(
         np.expand_dims(np.repeat(np.expand_dims(y_tnp, -1), 28, -1), -1), 28, -1
     )
-    y_t = uft.to_NCHW_inv(y_t) # move channel if needed (TF)
+    y_t = uft.to_NCHW_inv(y_t)  # move channel if needed (TF)
     # print("aaa",y_t.shape, y_t)
     # print("aaab",y.shape, y)
     np.testing.assert_equal(y, y_t)
 
+
 @pytest.mark.parametrize("group_size", [2, 4])
 def test_GroupSort_idempotence(group_size):
     gs = uft.get_instance_framework(GroupSort, {"group_size": group_size})
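The test_GroupSort_idempotence case visible in the hunk above relies on group_sort being idempotent: sorting already-sorted groups changes nothing. A minimal sketch of that property on the functional API (random input, purely illustrative):

import torch

from deel.torchlip.functional import group_sort

x = torch.randn(8, 16)
y = group_sort(x, group_size=4)
# the groups of y are already sorted, so a second pass is a no-op
assert torch.equal(group_sort(y, group_size=4), y)
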
4 changes: 2 additions & 2 deletions tests/test_unconstrained_layers.py
@@ -37,8 +37,8 @@
 
 def compare(x, x_ref, index_x=[], index_x_ref=[]):
     """Compare a tensor and its padded version, based on index_x and ref."""
-    x = uft.to_numpy(uft.to_NCHW(x))
-    x_ref = uft.to_numpy(uft.to_NCHW(x_ref))
+    x = uft.to_NCHW(uft.to_numpy(x))
+    x_ref = uft.to_NCHW(uft.to_numpy(x_ref))
     x_cropped = x[:, :, index_x[0] : index_x[1], index_x[3] : index_x[4]][
         :, :, :: index_x[2], :: index_x[5]
     ]
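This hunk is the behavioural change the commit title refers to: compare now converts to numpy first and only then reorders channels, since to_NCHW is expected to operate on numpy arrays. As a rough, hypothetical sketch (not the library code), a numpy-based NHWC-to-NCHW reordering for a channels-last backend could look like this, while the torch backend keeps the identity shown in the next file:

import numpy as np

def nhwc_to_nchw(x):
    # hypothetical helper: move the channel axis from last to second position
    return np.transpose(x, (0, 3, 1, 2))

x = np.zeros((2, 28, 28, 4))     # NHWC
print(nhwc_to_nchw(x).shape)     # (2, 4, 28, 28)
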
2 changes: 2 additions & 0 deletions tests/utils_framework.py
@@ -527,9 +527,11 @@ def to_framework_channel(x):
 def to_NCHW(x):
     return x
 
+
 def to_NCHW_inv(x):
     return x
 
+
 def get_NCHW(x):
     return (x.shape[0], x.shape[1], x.shape[2], x.shape[3])
 
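For the torch backend shown here the helpers are pass-throughs, since torch tensors are already channel-first, and get_NCHW simply unpacks the shape. A self-contained illustration (the two helpers are restated inline; shapes are arbitrary):

import numpy as np

def to_NCHW(x):
    # torch backend: data is already channel-first, so this is the identity
    return x

def get_NCHW(x):
    # unpack (batch, channels, height, width) from the array shape
    return (x.shape[0], x.shape[1], x.shape[2], x.shape[3])

x = np.ones((2, 4, 28, 28))
assert to_NCHW(x) is x
print(get_NCHW(x))  # (2, 4, 28, 28)
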
