157 changes: 66 additions & 91 deletions equinox/nn/conv.py
@@ -26,46 +26,6 @@ def parse(x: Any) -> tuple:
return parse


def compute_adjusted_padding(
input_size: int,
kernel_size: int,
stride: int,
padding: int,
output_padding: int,
dilation: int,
) -> Tuple[int, int]:
"""Computes adjusted padding for desired ConvTranspose `output_padding`."""
kernel_size = (kernel_size - 1) * dilation + 1
output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding
if padding == 0:
expected_input_size = (output_size - kernel_size + stride) // stride
if input_size != expected_input_size:
raise ValueError(
f"The expected input size with the current set of input "
f"parameters is {expected_input_size} which doesn't "
f"match the actual input size {input_size}."
)
padding_before = 0
elif padding == 1:
expected_input_size = (output_size + stride - 1) // stride
if input_size != expected_input_size:
raise ValueError(
f"The expected input size with the current set of input "
f"parameters is {expected_input_size} which doesn't "
f"match the actual input size {input_size}."
)
padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size)
padding_before = padding_needed // 2
else:
raise ValueError(f"`padding` must be '0' or '1'. Passed: {padding}.")

expanded_input_size = (input_size - 1) * stride + 1
padded_out_size = output_size + kernel_size - 1
pad_before = kernel_size - 1 - padding_before
pad_after = padded_out_size - expanded_input_size - pad_before
return (pad_before, pad_after)


class Conv(Module):
"""General N-dimensional convolution."""

@@ -298,11 +258,10 @@ class ConvTranspose(Module):
out_channels: int = static_field()
kernel_size: Tuple[int, ...] = static_field()
stride: Tuple[int, ...] = static_field()
padding: Tuple[int, ...] = static_field()
padding: Tuple[Tuple[int, int], ...] = static_field()
output_padding: Tuple[int, ...] = static_field()
dilation: Tuple[int, ...] = static_field()
use_bias: bool = static_field()
dimension_numbers: Tuple[str, ...] = static_field()

def __init__(
self,
@@ -326,10 +285,9 @@ def __init__(
- `in_channels`: The number of input channels.
- `out_channels`: The number of output channels.
- `kernel_size`: The size of the transposed convolutional kernel.
- `stride`: The stride of the transposed convolution.
- `padding`: The amount of implicit padding on both sides for `dilation *
(kernel_size - 1) - padding` points.
- `output_padding`: The additional size added to the output shape.
- `stride`: The stride used on the equivalent [`eqx.nn.Conv`][].
- `padding`: The amount of padding used on the equivalent [`eqx.nn.Conv`][].
- `output_padding`: Additional padding for the output shape.
- `dilation`: The spacing between kernel points.
- `use_bias`: Whether to add on a bias after the transposed convolution.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
@@ -344,51 +302,74 @@ def __init__(
are an integer then the same kernel size / stride / padding / dilation will
be used along every spatial dimension.

"""
!!! tip

Transposed convolutions are often used to go in the "opposite direction" to
a normal convolution. That is, from something with the shape of the output
of a convolution to something with the shape of the input to a convolution.
Moreover, to do so with the same "connectivity", i.e. which inputs can
affect which outputs.

Relative to an [`eqx.nn.Conv`][] layer, this can be accomplished by
switching the values of `in_channels` and `out_channels`, whilst keeping
`kernel_size`, `stride`, `padding`, and `dilation` the same.

When `stride > 1` then [`eqx.nn.Conv`][] maps multiple input shapes to the
same output shape. `output_padding` is provided to resolve this ambiguity,
by adding a little extra padding to just the bottom/right edges of the
input.

See [these animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations)
and [this report](https://arxiv.org/abs/1603.07285) for a nice reference.
""" # noqa: E501

super().__init__(**kwargs)
self.num_spatial_dims = num_spatial_dims
parse = _ntuple(self.num_spatial_dims)
wkey, bkey = jrandom.split(key, 2)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = parse(kernel_size)
self.use_bias = use_bias
self.output_padding = parse(output_padding)
self.padding = parse(padding)
lim = 1 / np.sqrt(self.in_channels * np.prod(self.kernel_size))
if self.num_spatial_dims == 1:
self.dimension_numbers = ("NCH", "IOH", "NCH")
elif self.num_spatial_dims == 2:
self.dimension_numbers = ("NCHW", "IOHW", "NCHW")
elif self.num_spatial_dims == 3:
self.dimension_numbers = ("NCDHW", "IODHW", "NCDHW")
else:
raise NotImplementedError(
"`ConvTranspose` only supports between 1 and 3 spatial dims",
f"({self.num_spatial_dims} was given)",
)

parse = _ntuple(num_spatial_dims)
kernel_size = parse(kernel_size)
stride = parse(stride)
output_padding = parse(output_padding)
dilation = parse(dilation)

for s, o in zip(stride, output_padding):
if o >= s:
raise ValueError("Must have `output_padding < stride` (elementwise).")

lim = 1 / np.sqrt(in_channels * np.prod(kernel_size))
self.weight = jrandom.uniform(
wkey,
(
self.in_channels,
self.out_channels,
)
+ self.kernel_size,
(out_channels, in_channels) + kernel_size,
minval=-lim,
maxval=lim,
)
if self.use_bias:
if use_bias:
self.bias = jrandom.uniform(
bkey,
(self.out_channels,) + (1,) * self.num_spatial_dims,
(out_channels,) + (1,) * num_spatial_dims,
minval=-lim,
maxval=lim,
)
else:
self.bias = None

self.stride = parse(stride)
self.dilation = parse(dilation)
self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
if isinstance(padding, int):
self.padding = tuple((padding, padding) for _ in range(num_spatial_dims))
elif isinstance(padding, Sequence) and len(padding) == num_spatial_dims:
self.padding = tuple((p, p) for p in padding)
else:
raise ValueError(
"`padding` must either be an int or tuple of length "
f"{num_spatial_dims}."
)
self.output_padding = output_padding
self.dilation = dilation
self.use_bias = use_bias

def __call__(
self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None
@@ -411,26 +392,20 @@ def __call__(
f" but input has shape {x.shape}.",
)
x = jnp.expand_dims(x, axis=0)
padding = self.padding
if self.output_padding is not None:
padding = tuple(
map(
compute_adjusted_padding,
x.shape[2:],
self.weight.shape[2:],
self.stride,
self.padding,
self.output_padding,
self.dilation,
)
# Given by Relationship 14 of https://arxiv.org/abs/1603.07285
padding = tuple(
(d * (k - 1) - p0, d * (k - 1) - p1 + o)
for k, (p0, p1), o, d in zip(
self.kernel_size, self.padding, self.output_padding, self.dilation
)
x = lax.conv_transpose(
)
x = lax.conv_general_dilated(
lhs=x,
rhs=self.weight,
strides=self.stride,
window_strides=(1,) * self.num_spatial_dims,
padding=padding,
lhs_dilation=self.stride,
rhs_dilation=self.dilation,
dimension_numbers=self.dimension_numbers,
)
if self.use_bias:
x = x + self.bias
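To make the docstring tip above concrete, here is a minimal sketch (not part of this diff; the channel counts, sizes, and key handling are illustrative assumptions). A `ConvTranspose1d` with `in_channels` and `out_channels` swapped, but the same `kernel_size`, `stride`, and `padding`, maps a `Conv1d`'s output shape back to its input shape, with `output_padding` choosing between the input lengths that the strided convolution collapses together:

```python
import equinox as eqx
import jax.random as jrandom

ckey, tkey, xkey = jrandom.split(jrandom.PRNGKey(0), 3)

# With kernel_size=3, stride=2, padding=1, Conv1d sends both length-31 and
# length-32 inputs to a length-16 output.
conv = eqx.nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, key=ckey)
x = jrandom.normal(xkey, (3, 32))
y = conv(x)
assert y.shape == (8, 16)

# Same kernel_size/stride/padding, channels swapped: output_padding picks
# which of the two candidate input lengths (31 or 32) is recovered.
convt = eqx.nn.ConvTranspose1d(
    8, 3, kernel_size=3, stride=2, padding=1, output_padding=1, key=tkey
)
assert convt(y).shape == x.shape  # output_padding=0 would give (3, 31)
```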
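Likewise, the padding rule used in `__call__` (Relationship 14 of https://arxiv.org/abs/1603.07285) can be sanity-checked against `jax.lax.conv_general_dilated` on its own: dilate the input by the stride, pad each side by `dilation * (kernel_size - 1) - padding`, and add `output_padding` on the right. A standalone sketch with arbitrarily chosen numbers:

```python
import jax.numpy as jnp
from jax import lax

k, s, p, o, d = 3, 2, 1, 1, 2  # kernel, stride, padding, output_padding, dilation
x = jnp.ones((1, 1, 31))       # NCH layout
w = jnp.ones((1, 1, k))        # OIH layout

pad = (d * (k - 1) - p, d * (k - 1) - p + o)  # the rule from __call__ above
y = lax.conv_general_dilated(
    x,
    w,
    window_strides=(1,),
    padding=(pad,),
    lhs_dilation=(s,),  # input dilation implements the "stride" of the transpose
    rhs_dilation=(d,),
    dimension_numbers=("NCH", "OIH", "NCH"),
)
# Standard transposed-conv size: (31 - 1)*s - 2*p + d*(k - 1) + 1 + o = 64
assert y.shape == (1, 1, 64)
```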
85 changes: 17 additions & 68 deletions tests/test_nn.py
@@ -262,41 +262,24 @@ def test_convtranspose1d(getkey):
x = jrandom.normal(getkey(), (1, 32))
assert conv(x).shape == (3, 34)

# Some keyword arguments
conv = eqx.nn.ConvTranspose1d(1, out_channels=3, kernel_size=(3,), key=getkey())
x = jrandom.normal(getkey(), (1, 32))
assert conv(x).shape == (3, 34)

# All keyword arguments
conv = eqx.nn.ConvTranspose1d(
in_channels=1,
out_channels=3,
kernel_size=(3,),
padding=0,
output_padding=0,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (1, 32))
assert conv(x).shape == (3, 34)

# Test strides
# Test stride and dilation
conv = eqx.nn.ConvTranspose1d(
in_channels=3,
out_channels=1,
kernel_size=(3,),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=True,
dilation=2,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (3, 32))
x = jrandom.normal(getkey(), (3, 31))
assert conv(x).shape == (1, 64)

# Test value matches
conv = eqx.nn.ConvTranspose1d(1, 3, kernel_size=3, padding=0, key=getkey())
new_weight = jnp.arange(9).reshape(1, 3, 3)
new_weight = jnp.arange(9).reshape(3, 1, 3)
new_bias = jnp.array([1, 2, 3]).reshape(3, 1)
data = jnp.arange(-3, 3).reshape(1, -1)
assert new_weight.shape == conv.weight.shape
@@ -330,7 +313,7 @@ def test_convtranspose1d(getkey):
15,
]
).reshape(3, 8)
assert jnp.allclose(conv(data), answer)
assert jnp.all(conv(data) == answer)
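For reference, the shapes asserted in these tests follow the standard transposed-convolution output-size formula. A small helper (hypothetical, for illustration only; not part of the test suite) makes the arithmetic explicit:

```python
def convtranspose_out_size(in_size, kernel, stride, padding, output_padding, dilation):
    # out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + 1 + output_padding
    return (
        (in_size - 1) * stride
        - 2 * padding
        + dilation * (kernel - 1)
        + 1
        + output_padding
    )

assert convtranspose_out_size(32, 3, 1, 0, 0, 1) == 34  # the simple case above
assert convtranspose_out_size(31, 3, 2, 1, 1, 2) == 64  # the stride/dilation case
```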


def test_convtranspose2d(getkey):
Expand All @@ -339,35 +322,19 @@ def test_convtranspose2d(getkey):
x = jrandom.normal(getkey(), (1, 32, 32))
assert conv(x).shape == (3, 34, 34)

# Some keyword arguments
conv = eqx.nn.ConvTranspose2d(1, out_channels=3, kernel_size=(3, 3), key=getkey())
x = jrandom.normal(getkey(), (1, 32, 32))
assert conv(x).shape == (3, 34, 34)

# All keyword arguments
conv = eqx.nn.ConvTranspose2d(
in_channels=1,
out_channels=3,
kernel_size=(3, 3),
padding=1,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (1, 32, 32))
assert conv(x).shape == (3, 32, 32)

# Test strides
# Test stride and dilation
conv = eqx.nn.ConvTranspose2d(
in_channels=3,
out_channels=1,
kernel_size=(3, 3),
stride=2,
padding=1,
output_padding=1,
use_bias=True,
dilation=2,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (3, 32, 32))
x = jrandom.normal(getkey(), (3, 31, 31))
assert conv(x).shape == (1, 64, 64)

# Test value matches
Expand All @@ -379,7 +346,7 @@ def test_convtranspose2d(getkey):
assert new_bias.shape == conv.bias.shape
conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias))
answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3)
assert jnp.allclose(conv(data), answer)
assert jnp.all(conv(data) == answer)


def test_convtranspose3d(getkey):
Expand All @@ -388,37 +355,19 @@ def test_convtranspose3d(getkey):
x = jrandom.normal(getkey(), (1, 3, 32, 32))
assert conv(x).shape == (3, 5, 34, 34)

# Some keyword arguments
conv = eqx.nn.ConvTranspose3d(
1, out_channels=3, kernel_size=(3, 3, 3), key=getkey()
)
x = jrandom.normal(getkey(), (1, 3, 32, 32))
assert conv(x).shape == (3, 5, 34, 34)

# All keyword arguments
conv = eqx.nn.ConvTranspose3d(
in_channels=1,
out_channels=3,
kernel_size=(3, 3, 3),
padding=1,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (1, 3, 32, 32))
assert conv(x).shape == (3, 3, 32, 32)

# Test strides
# Test stride and dilation
conv = eqx.nn.ConvTranspose3d(
in_channels=3,
out_channels=1,
kernel_size=(3, 3, 3),
stride=2,
padding=1,
output_padding=1,
use_bias=True,
dilation=2,
use_bias=False,
key=getkey(),
)
x = jrandom.normal(getkey(), (3, 3, 32, 32))
x = jrandom.normal(getkey(), (3, 2, 31, 31))
assert conv(x).shape == (1, 6, 64, 64)

# Test value matches
@@ -462,7 +411,7 @@ def test_convtranspose3d(getkey):
1,
]
).reshape(1, 3, 3, 3)
assert jnp.allclose(conv(data), answer)
assert jnp.all(conv(data) == answer)


def test_multihead_attention(getkey):