From c03ff868f61fae9955438956491a2dfb296e7d63 Mon Sep 17 00:00:00 2001 From: Matthieu Terris Date: Fri, 27 Oct 2023 09:47:37 +0200 Subject: [PATCH 1/9] added convtranspose --- .gitignore | 1 + tests/test_all_the_things.py | 16 +++ torch2jax/__init__.py | 230 ++++++++++++++++++++++++++++++++++- 3 files changed, 246 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ac38e85..60b6104 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ dist/ result __pycache__/ +.idea/ diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index a548108..fcb3292 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -382,3 +382,19 @@ def test_vit_b16(): # Models use different convolution backends and are too deep to compare gradients programmatically. But they line up # to reasonable expectations. + +def test_conv_transpose(): + model = torch.nn.ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2), bias=False) + model.eval() + + parameters = {k: t2j(v) for k, v in model.named_parameters()} + input_batch = random.normal(random.PRNGKey(123), (3, 8, 16, 16)) + res_torch = model(j2t(input_batch)) + + jaxified_module = t2j(model) + res_jax = jaxified_module(input_batch, state_dict=parameters) + res_jax_jit = jit(jaxified_module)(input_batch, state_dict=parameters) + + # Test forward pass with and without jax.jit + aac(res_jax, res_torch.numpy(force=True), atol=1e-1) + aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-1) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index cdc3d6b..394430d 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -1,13 +1,25 @@ import copy import functools import math -from typing import Optional +from typing import (Any, Optional, Sequence, Union, Tuple) + +import numpy as np import jax import jax.dlpack import jax.numpy as jnp +from jax.lib import xla_client + import torch +Precision = xla_client.PrecisionConfig.Precision +Precision.__str__ = lambda precision: precision.name +PrecisionType = Any +PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]] + +Array = Any +DType = Any +Shape = Sequence[int] def t2j_array(torch_array): # Using dlpack here causes segfaults on eg `t2j(lambda x: torch.Tensor([3.0]) * x)(jnp.array([0.0]))` when we use @@ -194,6 +206,222 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): res += coerce(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] return res +@implements(torch.nn.functional.conv_transpose2d) +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + # Padding not performed properly yet, this is only tested for padding=0. + # This implementation is taken from this PR https://github.com/google/jax/pull/5772 + + res = gradient_based_conv_transpose( + lhs=coerce(input), + rhs=coerce(weight), + strides=stride, + padding='VALID', + dimension_numbers=('NCHW', 'OIHW', 'NCHW'), + dilation=dilation + ) + if bias is not None: + res += coerce(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] + return res + +def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1): + """ Taken from https://github.com/google/jax/pull/5772 + Determines the output length of a transposed convolution given the input length. + Function modified from Keras. + Arguments: + input_length: Integer. + filter_size: Integer. + padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple. 
+ output_padding: Integer, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: Integer. + dilation: Integer. + Returns: + The output length (integer). + """ + if input_length is None: + return None + + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == 'VALID': + length = input_length * stride + max(filter_size - stride, 0) + elif padding == 'SAME': + length = input_length * stride + else: + length = ((input_length - 1) * stride + filter_size + - padding[0] - padding[1]) + + else: + if padding == 'SAME': + pad = filter_size // 2 + total_pad = pad * 2 + elif padding == 'VALID': + total_pad = 0 + else: + total_pad = padding[0] + padding[1] + + length = ((input_length - 1) * stride + filter_size - total_pad + + output_padding) + + return length + + +def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: int, stride: int, + padding: Union[str, Tuple[int, int]], dilation: int = 1,) -> Tuple[int, int]: + """ + Taken from https://github.com/google/jax/pull/5772 + Computes adjusted padding for desired ConvTranspose `output_size`. + Ported from DeepMind Haiku. + """ + kernel_size = (kernel_size - 1) * dilation + 1 + if padding == "VALID": + expected_input_size = (output_size - kernel_size + stride) // stride + if input_size != expected_input_size: + raise ValueError(f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}.") + padding_before = 0 + elif padding == "SAME": + expected_input_size = (output_size + stride - 1) // stride + if input_size != expected_input_size: + raise ValueError(f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}.") + padding_needed = max(0, + (input_size - 1) * stride + kernel_size - output_size) + padding_before = padding_needed // 2 + else: + padding_before = padding[0] # type: ignore[assignment] + + expanded_input_size = (input_size - 1) * stride + 1 + padded_out_size = output_size + kernel_size - 1 + pad_before = kernel_size - 1 - padding_before + pad_after = padded_out_size - expanded_input_size - pad_before + return (pad_before, pad_after) + +def gradient_based_conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + dilation: Optional[Sequence[int]] = None, + dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = True, + precision: PrecisionLike = None) -> Array: + """ + Taken from https://github.com/google/jax/pull/5772 + Convenience wrapper for calculating the N-d transposed convolution. + Much like `conv_transpose`, this function calculates transposed convolutions + via fractionally strided convolution rather than calculating the gradient + (transpose) of a forward convolution. However, the latter is more common + among deep learning frameworks, such as TensorFlow, PyTorch, and Keras. + This function provides the same set of APIs to help reproduce results in these frameworks. + Args: + lhs: a rank `n+2` dimensional input array. + rhs: a rank `n+2` dimensional array of kernel weights. 
+ strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution. + padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls + the before-and-after padding for each `n` spatial dimension of + the corresponding forward convolution. + output_padding: A sequence of integers specifying the amount of padding along + each spacial dimension of the output tensor, used to disambiguate the output shape of + transposed convolutions when the stride is larger than 1. + (see a detailed description at + 1https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) + The amount of output padding along a given dimension must + be lower than the stride along that same dimension. + If set to `None` (default), the output shape is inferred. + If both `output_padding` and `output_shape` are specified, they have to be mutually compatible. + output_shape: Output shape of the spatial dimensions of a transpose + convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default), + the shape is automatically calculated. + Similar to `output_padding`, `output_shape` is also for disambiguating the output shape + when stride > 1 (see also + https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose) + If both `output_padding` and `output_shape` are specified, they have to be mutually compatible. + dilation: `None`, or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution + is also known as atrous convolution. + dimension_numbers: tuple of dimension descriptors as in + lax.conv_general_dilated. Defaults to tensorflow convention. + transpose_kernel: if `True` flips spatial axes and swaps the input/output + channel axes of the kernel. This makes the output of this function identical + to the gradient-derived functions like keras.layers.Conv2DTranspose and + torch.nn.ConvTranspose2d applied to the same kernel. + Although for typical use in neural nets this is unnecessary + and makes input/output channel specification confusing, you need to set this to `True` + in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch. + precision: Optional. Either ``None``, which means the default precision for + the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. + Returns: + Transposed N-d convolution. + """ + assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2 + ndims = len(lhs.shape) + one = (1,) * (ndims - 2) + # Set dimensional layout defaults if not specified. + if dimension_numbers is None: + if ndims == 2: + dimension_numbers = ('NC', 'IO', 'NC') + elif ndims == 3: + dimension_numbers = ('NHC', 'HIO', 'NHC') + elif ndims == 4: + dimension_numbers = ('NHWC', 'HWIO', 'NHWC') + elif ndims == 5: + dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') + else: + raise ValueError('No 4+ dimensional dimension_number defaults.') + dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) + k_shape = np.take(rhs.shape, dn.rhs_spec) + k_sdims = k_shape[2:] # type: ignore[index] + i_shape = np.take(lhs.shape, dn.lhs_spec) + i_sdims = i_shape[2:] # type: ignore[index] + + # Calculate correct output shape given padding and strides. 
+ if dilation is None: + dilation = (1,) * (rhs.ndim - 2) + + if output_padding is None: + output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] + + if isinstance(padding, str): + if padding in {'SAME', 'VALID'}: + padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] + else: + raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.") + + inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, + padding, output_padding, strides, dilation)) + if output_shape is None: + output_shape = inferred_output_shape # type: ignore[assignment] + else: + if not output_shape == inferred_output_shape: + raise ValueError(f"`output_padding` and `output_shape` are not compatible." + f"Inferred output shape from `output_padding`: {inferred_output_shape}, " + f"but got `output_shape` {output_shape}") + + pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, + k_sdims, strides, padding, dilation)) + + if transpose_kernel: + # flip spatial dims and swap input / output channel axes + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, + precision=precision) + +def _flip_axes(x, axes): + """ + Taken from https://github.com/google/jax/pull/5772 + Flip ndarray 'x' along each axis specified in axes tuple.""" + for axis in axes: + x = np.flip(x, axis) + return x + @implements(torch.nn.functional.dropout) def dropout(input, p=0.5, training=True, inplace=False): assert not training, "TODO: implement dropout=True" From 348d519b2f8a5ceb9c5b65caf1123e09475827a6 Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Sat, 28 Oct 2023 13:36:20 +0200 Subject: [PATCH 2/9] Update torch2jax/__init__.py Co-authored-by: Samuel Ainsworth --- torch2jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index 394430d..1c0767e 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -1,7 +1,7 @@ import copy import functools import math -from typing import (Any, Optional, Sequence, Union, Tuple) +from typing import Optional, Sequence, Tuple, Union import numpy as np From ec65450a14c03a2ec7abe939638f527336a0430e Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Sat, 28 Oct 2023 13:36:32 +0200 Subject: [PATCH 3/9] Update torch2jax/__init__.py Co-authored-by: Samuel Ainsworth --- torch2jax/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index 1c0767e..774f64b 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -8,8 +8,6 @@ import jax import jax.dlpack import jax.numpy as jnp -from jax.lib import xla_client - import torch Precision = xla_client.PrecisionConfig.Precision From 1aaa8740cf3ffda2ec2f3b0ad6673e2bbf84c15c Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Sat, 28 Oct 2023 13:38:10 +0200 Subject: [PATCH 4/9] Update torch2jax/__init__.py Co-authored-by: Samuel Ainsworth --- torch2jax/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index 774f64b..5394bb1 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -10,14 +10,6 @@ import jax.numpy as jnp import torch -Precision = xla_client.PrecisionConfig.Precision -Precision.__str__ = lambda precision: 
precision.name -PrecisionType = Any -PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]] - -Array = Any -DType = Any -Shape = Sequence[int] def t2j_array(torch_array): # Using dlpack here causes segfaults on eg `t2j(lambda x: torch.Tensor([3.0]) * x)(jnp.array([0.0]))` when we use From e68c244389cd2d34437ae1200a0da9e2cdc8c3f5 Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Sat, 28 Oct 2023 13:38:19 +0200 Subject: [PATCH 5/9] Update torch2jax/__init__.py Co-authored-by: Samuel Ainsworth --- torch2jax/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index 5394bb1..8ecd611 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -292,14 +292,14 @@ def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: in pad_after = padded_out_size - expanded_input_size - pad_before return (pad_before, pad_after) -def gradient_based_conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], +def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], padding: Union[str, Sequence[Tuple[int, int]]], output_padding: Optional[Sequence[int]] = None, output_shape: Optional[Sequence[int]] = None, dilation: Optional[Sequence[int]] = None, dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None, transpose_kernel: bool = True, - precision: PrecisionLike = None) -> Array: + precision=None): """ Taken from https://github.com/google/jax/pull/5772 Convenience wrapper for calculating the N-d transposed convolution. From 8879a8b86562f7a567469d3bb63080d7d35b5e32 Mon Sep 17 00:00:00 2001 From: Matthieu Terris Date: Sat, 28 Oct 2023 16:51:37 +0200 Subject: [PATCH 6/9] fixed conv_transpose2d --- .gitignore | 2 +- tests/test_all_the_things.py | 33 ++++++++++++++++++--------------- torch2jax/__init__.py | 26 ++++++++++++++++---------- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 60b6104..d1421a4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ dist/ result __pycache__/ -.idea/ +.ruff_cache/ \ No newline at end of file diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index fcb3292..7dcd859 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -383,18 +383,21 @@ def test_vit_b16(): # Models use different convolution backends and are too deep to compare gradients programmatically. But they line up # to reasonable expectations. 
-def test_conv_transpose(): - model = torch.nn.ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2), bias=False) - model.eval() - - parameters = {k: t2j(v) for k, v in model.named_parameters()} - input_batch = random.normal(random.PRNGKey(123), (3, 8, 16, 16)) - res_torch = model(j2t(input_batch)) - - jaxified_module = t2j(model) - res_jax = jaxified_module(input_batch, state_dict=parameters) - res_jax_jit = jit(jaxified_module)(input_batch, state_dict=parameters) - - # Test forward pass with and without jax.jit - aac(res_jax, res_torch.numpy(force=True), atol=1e-1) - aac(res_jax_jit, res_torch.numpy(force=True), atol=1e-1) +def test_conv_transpose2d(): + for in_channels in [2, 4, 8]: + for out_channels in [2, 4, 8]: + for kernel_size in [(1, 1), (2, 2), (3, 3), 1, 2, 3, (1, 2), (2, 3)]: + for stride in [(1, 1), (2, 2), (3, 3), 1, 2, 3, (1, 2), (2, 3)]: + for bias in [False, True]: + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + output_padding = (max(stride[0]-kernel_size[0], 0), max(stride[1]-kernel_size[1], 0)) + model = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, bias=bias, output_padding=output_padding) + parameters = {k: t2j(v) for k, v in model.named_parameters()} + input_batch = random.normal(random.PRNGKey(123), (3, in_channels, 16, 16)) + res_torch = model(j2t(input_batch)) + jaxified_module = t2j(model) + res_jax = jaxified_module(input_batch, state_dict=parameters) + aac(res_jax, res_torch.numpy(force=True), atol=1e-1) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index 8ecd611..f441997 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -3,8 +3,6 @@ import math from typing import Optional, Sequence, Tuple, Union -import numpy as np - import jax import jax.dlpack import jax.numpy as jnp @@ -198,9 +196,16 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @implements(torch.nn.functional.conv_transpose2d) def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - # Padding not performed properly yet, this is only tested for padding=0. # This implementation is taken from this PR https://github.com/google/jax/pull/5772 + if isinstance(stride, int): + stride = (stride, stride) + output_padding_lax = (max(stride[0] - weight.shape[2], 0), max(stride[1] - weight.shape[3], 0)) + if isinstance(output_padding, int): + output_padding = (output_padding, output_padding) + assert output_padding == output_padding_lax, f"lax conv_transpose assumes output_padding = " \ + f"{output_padding_lax}, found {output_padding}" + res = gradient_based_conv_transpose( lhs=coerce(input), rhs=coerce(weight), @@ -237,7 +242,7 @@ def _deconv_output_length(input_length, filter_size, padding, output_padding=Non # Infer length if output padding is None, else compute the exact length if output_padding is None: if padding == 'VALID': - length = input_length * stride + max(filter_size - stride, 0) + length = input_length * stride + jax.lax.max(filter_size - stride, 0) elif padding == 'SAME': length = input_length * stride else: @@ -267,6 +272,7 @@ def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: in Ported from DeepMind Haiku. 
""" kernel_size = (kernel_size - 1) * dilation + 1 + if padding == "VALID": expected_input_size = (output_size - kernel_size + stride) // stride if input_size != expected_input_size: @@ -280,7 +286,7 @@ def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: in raise ValueError(f"The expected input size with the current set of input " f"parameters is {expected_input_size} which doesn't " f"match the actual input size {input_size}.") - padding_needed = max(0, + padding_needed = jax.lax.max(0, (input_size - 1) * stride + kernel_size - output_size) padding_before = padding_needed // 2 else: @@ -366,9 +372,9 @@ def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], else: raise ValueError('No 4+ dimensional dimension_number defaults.') dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) - k_shape = np.take(rhs.shape, dn.rhs_spec) + k_shape = jnp.take(jnp.array(rhs.shape), jnp.array(dn.rhs_spec)) k_sdims = k_shape[2:] # type: ignore[index] - i_shape = np.take(lhs.shape, dn.lhs_spec) + i_shape = jnp.take(jnp.array(lhs.shape), jnp.array(dn.lhs_spec)) i_sdims = i_shape[2:] # type: ignore[index] # Calculate correct output shape given padding and strides. @@ -399,8 +405,8 @@ def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], if transpose_kernel: # flip spatial dims and swap input / output channel axes - rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) - rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + rhs = _flip_axes(rhs, dn.rhs_spec[2:]) + rhs = jnp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, precision=precision) @@ -409,7 +415,7 @@ def _flip_axes(x, axes): Taken from https://github.com/google/jax/pull/5772 Flip ndarray 'x' along each axis specified in axes tuple.""" for axis in axes: - x = np.flip(x, axis) + x = jnp.flip(x, axis) return x @implements(torch.nn.functional.dropout) From 831e7c3aacba3cd90c19f4ba25d3d1b40edcf8c7 Mon Sep 17 00:00:00 2001 From: Matthieu Terris Date: Sat, 28 Oct 2023 17:29:31 +0200 Subject: [PATCH 7/9] ci --- tests/test_all_the_things.py | 11 ++-- torch2jax/__init__.py | 113 +++++++++++++++++++---------------- 2 files changed, 70 insertions(+), 54 deletions(-) diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index 678f46a..ffd2ef9 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -439,6 +439,7 @@ def test_vit_b16(): # Models use different convolution backends and are too deep to compare gradients programmatically. But they line up # to reasonable expectations. 
+ def test_conv_transpose2d(): for in_channels in [2, 4, 8]: for out_channels in [2, 4, 8]: @@ -446,11 +447,13 @@ def test_conv_transpose2d(): for stride in [(1, 1), (2, 2), (3, 3), 1, 2, 3, (1, 2), (2, 3)]: for bias in [False, True]: if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) + kernel_size = (kernel_size, kernel_size) if isinstance(stride, int): - stride = (stride, stride) - output_padding = (max(stride[0]-kernel_size[0], 0), max(stride[1]-kernel_size[1], 0)) - model = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, bias=bias, output_padding=output_padding) + stride = (stride, stride) + output_padding = (max(stride[0] - kernel_size[0], 0), max(stride[1] - kernel_size[1], 0)) + model = torch.nn.ConvTranspose2d( + in_channels, out_channels, kernel_size, stride, bias=bias, output_padding=output_padding + ) parameters = {k: t2j(v) for k, v in model.named_parameters()} input_batch = random.normal(random.PRNGKey(123), (3, in_channels, 16, 16)) res_torch = model(j2t(input_batch)) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index d0c79d4..af33171 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -239,16 +239,17 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi output_padding_lax = (max(stride[0] - weight.shape[2], 0), max(stride[1] - weight.shape[3], 0)) if isinstance(output_padding, int): output_padding = (output_padding, output_padding) - assert output_padding == output_padding_lax, f"lax conv_transpose assumes output_padding = " \ - f"{output_padding_lax}, found {output_padding}" + assert output_padding == output_padding_lax, ( + f"lax conv_transpose assumes output_padding = " f"{output_padding_lax}, found {output_padding}" + ) res = gradient_based_conv_transpose( - lhs=coerce(input), - rhs=coerce(weight), - strides=stride, - padding='VALID', - dimension_numbers=('NCHW', 'OIHW', 'NCHW'), - dilation=dilation + lhs=coerce(input), + rhs=coerce(weight), + strides=stride, + padding="VALID", + dimension_numbers=("NCHW", "OIHW", "NCHW"), + dilation=dilation, ) if bias is not None: res += coerce(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] @@ -256,7 +257,7 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1): - """ Taken from https://github.com/google/jax/pull/5772 + """Taken from https://github.com/google/jax/pull/5772 Determines the output length of a transposed convolution given the input length. Function modified from Keras. 
Arguments: @@ -278,31 +279,35 @@ def _deconv_output_length(input_length, filter_size, padding, output_padding=Non # Infer length if output padding is None, else compute the exact length if output_padding is None: - if padding == 'VALID': + if padding == "VALID": length = input_length * stride + jax.lax.max(filter_size - stride, 0) - elif padding == 'SAME': + elif padding == "SAME": length = input_length * stride else: - length = ((input_length - 1) * stride + filter_size - - padding[0] - padding[1]) + length = (input_length - 1) * stride + filter_size - padding[0] - padding[1] else: - if padding == 'SAME': + if padding == "SAME": pad = filter_size // 2 total_pad = pad * 2 - elif padding == 'VALID': + elif padding == "VALID": total_pad = 0 else: total_pad = padding[0] + padding[1] - length = ((input_length - 1) * stride + filter_size - total_pad + - output_padding) + length = (input_length - 1) * stride + filter_size - total_pad + output_padding return length -def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: int, stride: int, - padding: Union[str, Tuple[int, int]], dilation: int = 1,) -> Tuple[int, int]: +def _compute_adjusted_padding( + input_size: int, + output_size: int, + kernel_size: int, + stride: int, + padding: Union[str, Tuple[int, int]], + dilation: int = 1, +) -> Tuple[int, int]: """ Taken from https://github.com/google/jax/pull/5772 Computes adjusted padding for desired ConvTranspose `output_size`. @@ -313,18 +318,21 @@ def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: in if padding == "VALID": expected_input_size = (output_size - kernel_size + stride) // stride if input_size != expected_input_size: - raise ValueError(f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}.") + raise ValueError( + f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) padding_before = 0 elif padding == "SAME": expected_input_size = (output_size + stride - 1) // stride if input_size != expected_input_size: - raise ValueError(f"The expected input size with the current set of input " - f"parameters is {expected_input_size} which doesn't " - f"match the actual input size {input_size}.") - padding_needed = jax.lax.max(0, - (input_size - 1) * stride + kernel_size - output_size) + raise ValueError( + f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." 
+ ) + padding_needed = jax.lax.max(0, (input_size - 1) * stride + kernel_size - output_size) padding_before = padding_needed // 2 else: padding_before = padding[0] # type: ignore[assignment] @@ -336,14 +344,18 @@ def _compute_adjusted_padding(input_size: int, output_size: int, kernel_size: in return (pad_before, pad_after) -def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], - padding: Union[str, Sequence[Tuple[int, int]]], - output_padding: Optional[Sequence[int]] = None, - output_shape: Optional[Sequence[int]] = None, - dilation: Optional[Sequence[int]] = None, - dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None, - transpose_kernel: bool = True, - precision=None): +def gradient_based_conv_transpose( + lhs, + rhs, + strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + dilation: Optional[Sequence[int]] = None, + dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = True, + precision=None, +): """ Taken from https://github.com/google/jax/pull/5772 Convenience wrapper for calculating the N-d transposed convolution. @@ -400,15 +412,15 @@ def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], # Set dimensional layout defaults if not specified. if dimension_numbers is None: if ndims == 2: - dimension_numbers = ('NC', 'IO', 'NC') + dimension_numbers = ("NC", "IO", "NC") elif ndims == 3: - dimension_numbers = ('NHC', 'HIO', 'NHC') + dimension_numbers = ("NHC", "HIO", "NHC") elif ndims == 4: - dimension_numbers = ('NHWC', 'HWIO', 'NHWC') + dimension_numbers = ("NHWC", "HWIO", "NHWC") elif ndims == 5: - dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') + dimension_numbers = ("NHWDC", "HWDIO", "NHWDC") else: - raise ValueError('No 4+ dimensional dimension_number defaults.') + raise ValueError("No 4+ dimensional dimension_number defaults.") dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) k_shape = jnp.take(jnp.array(rhs.shape), jnp.array(dn.rhs_spec)) k_sdims = k_shape[2:] # type: ignore[index] @@ -423,30 +435,31 @@ def gradient_based_conv_transpose(lhs, rhs, strides: Sequence[int], output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item] if isinstance(padding, str): - if padding in {'SAME', 'VALID'}: + if padding in {"SAME", "VALID"}: padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item] else: raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.") - inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, - padding, output_padding, strides, dilation)) + inferred_output_shape = tuple( + map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation) + ) if output_shape is None: output_shape = inferred_output_shape # type: ignore[assignment] else: if not output_shape == inferred_output_shape: - raise ValueError(f"`output_padding` and `output_shape` are not compatible." - f"Inferred output shape from `output_padding`: {inferred_output_shape}, " - f"but got `output_shape` {output_shape}") + raise ValueError( + f"`output_padding` and `output_shape` are not compatible." 
+ f"Inferred output shape from `output_padding`: {inferred_output_shape}, " + f"but got `output_shape` {output_shape}" + ) - pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, - k_sdims, strides, padding, dilation)) + pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation)) if transpose_kernel: # flip spatial dims and swap input / output channel axes rhs = _flip_axes(rhs, dn.rhs_spec[2:]) rhs = jnp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) - return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, - precision=precision) + return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, precision=precision) def _flip_axes(x, axes): From f2ffe8ab235773a2c241f959ce0033c7b6e720a2 Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:25:11 +0100 Subject: [PATCH 8/9] Update torch2jax/__init__.py Co-authored-by: Samuel Ainsworth --- torch2jax/__init__.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index af33171..905a8ba 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -233,23 +233,18 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @implements(torch.nn.functional.conv_transpose2d) def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): # This implementation is taken from this PR https://github.com/google/jax/pull/5772 + assert input.ndim == 4, "TODO: implement non-batched input" + assert groups == 1, "TODO: implement groups != 1" - if isinstance(stride, int): - stride = (stride, stride) - output_padding_lax = (max(stride[0] - weight.shape[2], 0), max(stride[1] - weight.shape[3], 0)) - if isinstance(output_padding, int): - output_padding = (output_padding, output_padding) - assert output_padding == output_padding_lax, ( - f"lax conv_transpose assumes output_padding = " f"{output_padding_lax}, found {output_padding}" - ) - + ph, pw = (padding, padding) if isinstance(padding, int) else padding res = gradient_based_conv_transpose( lhs=coerce(input), rhs=coerce(weight), strides=stride, - padding="VALID", - dimension_numbers=("NCHW", "OIHW", "NCHW"), + padding=[(ph, ph), (pw, pw)], + output_padding=output_padding, dilation=dilation, + dimension_numbers=("NCHW", "OIHW", "NCHW"), ) if bias is not None: res += coerce(bias)[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] From 462554a684a88433fd23c101ce031f9da2fb9e59 Mon Sep 17 00:00:00 2001 From: Matthieu Terris <31830373+matthieutrs@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:25:24 +0100 Subject: [PATCH 9/9] Update tests/test_all_the_things.py Co-authored-by: Samuel Ainsworth --- tests/test_all_the_things.py | 49 ++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index ffd2ef9..1dceb89 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -441,22 +441,33 @@ def test_vit_b16(): def test_conv_transpose2d(): - for in_channels in [2, 4, 8]: - for out_channels in [2, 4, 8]: - for kernel_size in [(1, 1), (2, 2), (3, 3), 1, 2, 3, (1, 2), (2, 3)]: - for stride in [(1, 1), (2, 2), (3, 3), 1, 2, 3, (1, 2), (2, 3)]: - for bias in [False, True]: - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if isinstance(stride, int): - stride = (stride, 
stride) - output_padding = (max(stride[0] - kernel_size[0], 0), max(stride[1] - kernel_size[1], 0)) - model = torch.nn.ConvTranspose2d( - in_channels, out_channels, kernel_size, stride, bias=bias, output_padding=output_padding - ) - parameters = {k: t2j(v) for k, v in model.named_parameters()} - input_batch = random.normal(random.PRNGKey(123), (3, in_channels, 16, 16)) - res_torch = model(j2t(input_batch)) - jaxified_module = t2j(model) - res_jax = jaxified_module(input_batch, state_dict=parameters) - aac(res_jax, res_torch.numpy(force=True), atol=1e-1) + for in_channels in [1, 2]: + for out_channels in [1, 2]: + for kernel_size in [1, 2, (1, 2)]: + for stride in [1, 2, (1, 2)]: + for padding in [(0, 0), 1, 2, (1, 2)]: + for output_padding in [0, 1, 2, (1, 2)]: + for bias in [False, True]: + for dilation in [1, 2, (1, 2)]: + model = torch.nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + dilation=dilation, + ) + params = {k: random.normal(random.PRNGKey(123), v.shape) for k, v in model.named_parameters()} + model.load_state_dict({k: j2t(v) for k, v in params.items()}) + + input_batch = random.normal(random.PRNGKey(123), (3, in_channels, 16, 16)) + try: + res_torch = model(j2t(input_batch)) + except RuntimeError: + # RuntimeError: output padding must be smaller than either stride or dilation + continue + + res_jax = t2j(model)(input_batch, state_dict=params) + aac(res_jax, res_torch.numpy(force=True), atol=1e-4)
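
The output-shape bookkeeping that `_deconv_output_length` performs in the patches above (dilated kernel size, explicit before/after padding, optional `output_padding`) reduces, per spatial dimension, to PyTorch's documented ConvTranspose2d formula: out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. A minimal sketch checking that arithmetic directly against PyTorch — the helper name `expected_out` is illustrative only and is not part of the patch series:

import torch

def expected_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    # PyTorch's ConvTranspose2d output size, per spatial dimension:
    # out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

conv = torch.nn.ConvTranspose2d(8, 4, kernel_size=3, stride=2, padding=1, output_padding=1)
out = conv(torch.zeros(1, 8, 16, 16))
assert out.shape[-2:] == (expected_out(16, 3, 2, 1, 1),) * 2  # (32, 32)

This is also why `output_padding` exists at all: for stride > 1 several input sizes map to the same output size under the forward convolution, and `output_padding` (which must stay smaller than the stride or dilation, as the final test's try/except acknowledges) disambiguates which transposed-convolution output size is meant.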