From 16e770008922550708bfec0aea261df8b9a55232 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 10 Mar 2022 14:28:53 +0000 Subject: [PATCH 1/2] Tidied; simplified; generalised ConvTranspose implementation. --- equinox/nn/conv.py | 157 +++++++++++++++++++-------------------------- tests/test_nn.py | 85 +++++------------------- 2 files changed, 83 insertions(+), 159 deletions(-) diff --git a/equinox/nn/conv.py b/equinox/nn/conv.py index 34b447cd..3bd00f4e 100644 --- a/equinox/nn/conv.py +++ b/equinox/nn/conv.py @@ -26,46 +26,6 @@ def parse(x: Any) -> tuple: return parse -def compute_adjusted_padding( - input_size: int, - kernel_size: int, - stride: int, - padding: int, - output_padding: int, - dilation: int, -) -> Tuple[int, int]: - """Computes adjusted padding for desired ConvTranspose `output_padding`.""" - kernel_size = (kernel_size - 1) * dilation + 1 - output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding - if padding == 0: - expected_input_size = (output_size - kernel_size + stride) // stride - if input_size != expected_input_size: - raise ValueError( - f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}." - ) - padding_before = 0 - elif padding == 1: - expected_input_size = (output_size + stride - 1) // stride - if input_size != expected_input_size: - raise ValueError( - f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}." - ) - padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size) - padding_before = padding_needed // 2 - else: - raise ValueError(f"`padding` must be '0' or '1'. Passed: {padding}.") - - expanded_input_size = (input_size - 1) * stride + 1 - padded_out_size = output_size + kernel_size - 1 - pad_before = kernel_size - 1 - padding_before - pad_after = padded_out_size - expanded_input_size - pad_before - return (pad_before, pad_after) - - class Conv(Module): """General N-dimensional convolution.""" @@ -298,11 +258,10 @@ class ConvTranspose(Module): out_channels: int = static_field() kernel_size: Tuple[int, ...] = static_field() stride: Tuple[int, ...] = static_field() - padding: Tuple[int, ...] = static_field() + padding: Tuple[Tuple[int, int], ...] = static_field() output_padding: Tuple[int, ...] = static_field() dilation: Tuple[int, ...] = static_field() use_bias: bool = static_field() - dimension_numbers: Tuple[str, ...] = static_field() def __init__( self, @@ -326,10 +285,9 @@ def __init__( - `in_channels`: The number of input channels. - `out_channels`: The number of output channels. - `kernel_size`: The size of the transposed convolutional kernel. - - `stride`: The stride of the transposed convolution. - - `padding`: The amount of implicit padding on both sides for `dilation * - (kernel_size - 1) - padding` points. - - `output_padding`: The additional size added to the output shape. + - `stride`: The stride used on the equivalent [`eqx.nn.Conv`][]. + - `padding`: The amount of padding used on the equivalent [`eqx.nn.Conv`][]. + - `output_padding`: Additional padding for the output shape. - `dilation`: The spacing between kernel points. - `use_bias`: Whether to add on a bias after the transposed convolution. 
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
@@ -344,51 +302,74 @@ def __init__(
         are an integer then the same kernel size / stride / padding / dilation will be
         used along every spatial dimension.
 
-        """
+        !!! tip
+
+            Transposed convolutions are often used to go in the "opposite direction" to
+            a normal convolution. That is, from something with the shape of the output
+            of a convolution to something with the shape of the input to a convolution.
+            Moreover, to do so with the same "connectivity", i.e. which inputs can
+            affect which outputs.
+
+            Relative to an [`eqx.nn.Conv`][] layer, this can be accomplished by
+            switching the values of `in_channels` and `out_channels`, whilst keeping
+            `kernel_size`, `stride`, `padding`, and `dilation` the same.
+
+            When `stride > 1` then [`eqx.nn.Conv`][] maps multiple input shapes to the
+            same output shape. `output_padding` is provided to resolve this ambiguity,
+            by adding a little extra padding to just the bottom/right edges of the
+            input.
+
+            See [these animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations)
+            and [this report](https://arxiv.org/abs/1603.07285) as a nice reference.
+        """  # noqa: E501
+
         super().__init__(**kwargs)
-        self.num_spatial_dims = num_spatial_dims
-        parse = _ntuple(self.num_spatial_dims)
         wkey, bkey = jrandom.split(key, 2)
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = parse(kernel_size)
-        self.use_bias = use_bias
-        self.output_padding = parse(output_padding)
-        self.padding = parse(padding)
-        lim = 1 / np.sqrt(self.in_channels * np.prod(self.kernel_size))
-        if self.num_spatial_dims == 1:
-            self.dimension_numbers = ("NCH", "IOH", "NCH")
-        elif self.num_spatial_dims == 2:
-            self.dimension_numbers = ("NCHW", "IOHW", "NCHW")
-        elif self.num_spatial_dims == 3:
-            self.dimension_numbers = ("NCDHW", "IODHW", "NCDHW")
-        else:
-            raise NotImplementedError(
-                "`ConvTranspose` only supports between 1 and 3 spatial dims",
-                f"({self.num_spatial_dims} was given)",
-            )
+
+        parse = _ntuple(num_spatial_dims)
+        kernel_size = parse(kernel_size)
+        stride = parse(stride)
+        output_padding = parse(output_padding)
+        dilation = parse(dilation)
+
+        for s, o in zip(stride, output_padding):
+            if o >= s:
+                raise ValueError("Must have `output_padding < stride` (elementwise).")
+
+        lim = 1 / np.sqrt(in_channels * np.prod(kernel_size))
         self.weight = jrandom.uniform(
             wkey,
-            (
-                self.in_channels,
-                self.out_channels,
-            )
-            + self.kernel_size,
+            (out_channels, in_channels) + kernel_size,
             minval=-lim,
             maxval=lim,
         )
-        if self.use_bias:
+        if use_bias:
             self.bias = jrandom.uniform(
                 bkey,
-                (self.out_channels,) + (1,) * self.num_spatial_dims,
+                (out_channels,) + (1,) * num_spatial_dims,
                 minval=-lim,
                 maxval=lim,
             )
         else:
             self.bias = None
-        self.stride = parse(stride)
-        self.dilation = parse(dilation)
+        self.num_spatial_dims = num_spatial_dims
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        if isinstance(padding, int):
+            self.padding = tuple((padding, padding) for _ in range(num_spatial_dims))
+        elif isinstance(padding, Sequence) and len(padding) == num_spatial_dims:
+            self.padding = tuple((p, p) for p in padding)
+        else:
+            raise ValueError(
+                "`padding` must either be an int or a tuple of length "
+                f"{num_spatial_dims}."
+ ) + self.output_padding = output_padding + self.dilation = dilation + self.use_bias = use_bias def __call__( self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None @@ -411,26 +392,20 @@ def __call__( f" but input has shape {x.shape}.", ) x = jnp.expand_dims(x, axis=0) - padding = self.padding - if self.output_padding is not None: - padding = tuple( - map( - compute_adjusted_padding, - x.shape[2:], - self.weight.shape[2:], - self.stride, - self.padding, - self.output_padding, - self.dilation, - ) + # Given by Relationship 14 of https://arxiv.org/abs/1603.07285 + padding = tuple( + (d * (k - 1) - p0, d * (k - 1) - p1 + o) + for k, (p0, p1), o, d in zip( + self.kernel_size, self.padding, self.output_padding, self.dilation ) - x = lax.conv_transpose( + ) + x = lax.conv_general_dilated( lhs=x, rhs=self.weight, - strides=self.stride, + window_strides=(1,) * self.num_spatial_dims, padding=padding, + lhs_dilation=self.stride, rhs_dilation=self.dilation, - dimension_numbers=self.dimension_numbers, ) if self.use_bias: x = x + self.bias diff --git a/tests/test_nn.py b/tests/test_nn.py index 69d38335..1f65eae9 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -262,41 +262,24 @@ def test_convtranspose1d(getkey): x = jrandom.normal(getkey(), (1, 32)) assert conv(x).shape == (3, 34) - # Some keyword arguments - conv = eqx.nn.ConvTranspose1d(1, out_channels=3, kernel_size=(3,), key=getkey()) - x = jrandom.normal(getkey(), (1, 32)) - assert conv(x).shape == (3, 34) - - # All keyword arguments - conv = eqx.nn.ConvTranspose1d( - in_channels=1, - out_channels=3, - kernel_size=(3,), - padding=0, - output_padding=0, - use_bias=False, - key=getkey(), - ) - x = jrandom.normal(getkey(), (1, 32)) - assert conv(x).shape == (3, 34) - - # Test strides + # Test stride and dilation conv = eqx.nn.ConvTranspose1d( in_channels=3, out_channels=1, - kernel_size=(3,), + kernel_size=3, stride=2, padding=1, output_padding=1, - use_bias=True, + dilation=2, + use_bias=False, key=getkey(), ) - x = jrandom.normal(getkey(), (3, 32)) + x = jrandom.normal(getkey(), (3, 31)) assert conv(x).shape == (1, 64) # Test value matches conv = eqx.nn.ConvTranspose1d(1, 3, kernel_size=3, padding=0, key=getkey()) - new_weight = jnp.arange(9).reshape(1, 3, 3) + new_weight = jnp.arange(9).reshape(3, 1, 3) new_bias = jnp.array([1, 2, 3]).reshape(3, 1) data = jnp.arange(-3, 3).reshape(1, -1) assert new_weight.shape == conv.weight.shape @@ -330,7 +313,7 @@ def test_convtranspose1d(getkey): 15, ] ).reshape(3, 8) - assert jnp.allclose(conv(data), answer) + assert jnp.all(conv(data) == answer) def test_convtranspose2d(getkey): @@ -339,24 +322,7 @@ def test_convtranspose2d(getkey): x = jrandom.normal(getkey(), (1, 32, 32)) assert conv(x).shape == (3, 34, 34) - # Some keyword arguments - conv = eqx.nn.ConvTranspose2d(1, out_channels=3, kernel_size=(3, 3), key=getkey()) - x = jrandom.normal(getkey(), (1, 32, 32)) - assert conv(x).shape == (3, 34, 34) - - # All keyword arguments - conv = eqx.nn.ConvTranspose2d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3), - padding=1, - use_bias=False, - key=getkey(), - ) - x = jrandom.normal(getkey(), (1, 32, 32)) - assert conv(x).shape == (3, 32, 32) - - # Test strides + # Test stride and dilation conv = eqx.nn.ConvTranspose2d( in_channels=3, out_channels=1, @@ -364,10 +330,11 @@ def test_convtranspose2d(getkey): stride=2, padding=1, output_padding=1, - use_bias=True, + dilation=2, + use_bias=False, key=getkey(), ) - x = jrandom.normal(getkey(), (3, 32, 32)) + x = jrandom.normal(getkey(), (3, 31, 
31)) assert conv(x).shape == (1, 64, 64) # Test value matches @@ -379,7 +346,7 @@ def test_convtranspose2d(getkey): assert new_bias.shape == conv.bias.shape conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias)) answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3) - assert jnp.allclose(conv(data), answer) + assert jnp.all(conv(data) == answer) def test_convtranspose3d(getkey): @@ -388,26 +355,7 @@ def test_convtranspose3d(getkey): x = jrandom.normal(getkey(), (1, 3, 32, 32)) assert conv(x).shape == (3, 5, 34, 34) - # Some keyword arguments - conv = eqx.nn.ConvTranspose3d( - 1, out_channels=3, kernel_size=(3, 3, 3), key=getkey() - ) - x = jrandom.normal(getkey(), (1, 3, 32, 32)) - assert conv(x).shape == (3, 5, 34, 34) - - # All keyword arguments - conv = eqx.nn.ConvTranspose3d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3, 3), - padding=1, - use_bias=False, - key=getkey(), - ) - x = jrandom.normal(getkey(), (1, 3, 32, 32)) - assert conv(x).shape == (3, 3, 32, 32) - - # Test strides + # Test stride and dilation conv = eqx.nn.ConvTranspose3d( in_channels=3, out_channels=1, @@ -415,10 +363,11 @@ def test_convtranspose3d(getkey): stride=2, padding=1, output_padding=1, - use_bias=True, + dilation=2, + use_bias=False, key=getkey(), ) - x = jrandom.normal(getkey(), (3, 3, 32, 32)) + x = jrandom.normal(getkey(), (3, 2, 31, 31)) assert conv(x).shape == (1, 6, 64, 64) # Test value matches @@ -462,7 +411,7 @@ def test_convtranspose3d(getkey): 1, ] ).reshape(1, 3, 3, 3) - assert jnp.allclose(conv(data), answer) + assert jnp.all(conv(data) == answer) def test_multihead_attention(getkey): From 6e6b17600f2001d60f02481b52f7979377802e05 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 10 Mar 2022 20:04:21 +0000 Subject: [PATCH 2/2] Update conv.py --- equinox/nn/conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/nn/conv.py b/equinox/nn/conv.py index 3bd00f4e..846ecb1b 100644 --- a/equinox/nn/conv.py +++ b/equinox/nn/conv.py @@ -320,7 +320,7 @@ def __init__( input. See [these animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations) - and [this report](https://arxiv.org/abs/1603.07285) as a nice reference. + and [this report](https://arxiv.org/abs/1603.07285) for a nice reference. """ # noqa: E501 super().__init__(**kwargs)
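A note on the shape arithmetic behind the padding computation in `__call__` and the updated test shapes: Relationship 14 of the report linked in the docstring gives the output size of a transposed convolution. A minimal sketch of that formula; the helper name `transpose_output_size` is illustrative only, not part of the patch:

```python
# Output size of a transposed convolution along one spatial dimension,
# per Relationship 14 of https://arxiv.org/abs/1603.07285.
# Hypothetical helper for illustration; not part of the equinox API.
def transpose_output_size(
    input_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    output_padding: int,
    dilation: int,
) -> int:
    # Dilation spaces out the kernel taps, enlarging the kernel's effective extent.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (input_size - 1) * stride - 2 * padding + effective_kernel + output_padding


# Matches the updated tests: a length-31 input with kernel_size=3, stride=2,
# padding=1, output_padding=1, dilation=2 produces a length-64 output.
assert (
    transpose_output_size(
        input_size=31, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=2
    )
    == 64
)
```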
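And a sketch of the docstring's tip, assuming the API as it stands after this patch: a `ConvTranspose` undoes the shape change of a strided `Conv` by swapping `in_channels` and `out_channels`, keeping `kernel_size`/`stride`/`padding`, and using `output_padding` to pick among the input shapes that the stride makes ambiguous.

```python
import equinox as eqx
import jax.random as jrandom

dkey, ukey, xkey = jrandom.split(jrandom.PRNGKey(0), 3)

# A stride-2 convolution halving the spatial resolution.
down = eqx.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, key=dkey)
# Its shape-inverse: channels swapped, same kernel_size/stride/padding.
# Both 31x31 and 32x32 inputs map to 16x16 under `down`; output_padding=1
# selects 32x32 as the reconstructed shape.
up = eqx.nn.ConvTranspose2d(
    8, 3, kernel_size=3, stride=2, padding=1, output_padding=1, key=ukey
)

x = jrandom.normal(xkey, (3, 32, 32))
y = down(x)
assert y.shape == (8, 16, 16)
assert up(y).shape == (3, 32, 32)
```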