diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py index fd0c9129..1d45e4c7 100644 --- a/equinox/nn/__init__.py +++ b/equinox/nn/__init__.py @@ -1,5 +1,17 @@ +from .attention import MultiheadAttention from .composed import MLP, Sequential -from .conv import Conv, Conv1d, Conv2d, Conv3d +from .conv import ( + Conv, + Conv1d, + Conv2d, + Conv3d, + ConvTranspose, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) from .dropout import Dropout +from .embedding import Embedding from .linear import Identity, Linear +from .normalization import LayerNorm from .rnn import GRUCell, LSTMCell diff --git a/equinox/nn/attention.py b/equinox/nn/attention.py new file mode 100644 index 00000000..ec417823 --- /dev/null +++ b/equinox/nn/attention.py @@ -0,0 +1,141 @@ +from typing import Optional + +import jax +import jax.numpy as jnp +import jax.random as jrandom +import numpy as np + +from ..custom_types import Array +from ..module import Module, static_field +from .dropout import Dropout +from .linear import Linear + + +class MultiheadAttention(Module): + """ + Multihead Attention layer from `Attention Is All You Need `_. + + $\text{MultiHeadAttention}(Q, K, V) = \text{Concat}(head_{1},...,head_{h})W^{out}$, + where $head_i = \text{softmax}(\frac{QW_i^Q(KW_i^K)^\intercal}{\sqrt{d_k}})VW_i^V$ + + """ + + embed_dim: int = static_field() + num_heads: int = static_field() + kdim: int = static_field() + vdim: int = static_field() + _qkv_same_embed_dim: bool = static_field() + head_dim: int = static_field() + q_proj: Linear + k_proj: Linear + v_proj: Linear + out_proj: Linear + dropout: Dropout + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + use_bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + add_bias_kv: bool = False, + *, + key: "jax.random.PRNGKey", + **kwargs, + ): + """**Arguments:** + + - `embed_dim`: Dimension of the model. + - `num_heads`: Number of parallel attention heads. + - `dropout`: Dropout probability on attention matrix. Default: `0.0`. + - `use_bias`: Whether to use a bias term on the output projection. Default: `True`. + - `kdim`: Total number of features for keys. Default: `None` (use `kdim=embed_dim`). + - `vdim`: Total number of features for values. Default: `None` (use `vdim=embed_dim`). + - `add_bias_kv`: Whether to use bias term for value and key projections. Default: `False`. + - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter + initialisation. (Keyword only argument.) 
+
+        """
+        super().__init__(**kwargs)
+        key1, key2, key3, key4 = jrandom.split(key, 4)
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        if self.embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got embed_dim = {self.embed_dim}"
+                f" and num_heads = {self.num_heads})"
+            )
+        if self.kdim % num_heads != 0:
+            raise ValueError(
+                f"kdim must be divisible by num_heads (got kdim = {self.kdim} and "
+                f"num_heads = {self.num_heads})"
+            )
+        if self.vdim % num_heads != 0:
+            raise ValueError(
+                f"vdim must be divisible by num_heads (got vdim = {self.vdim} and "
+                f"num_heads = {self.num_heads})"
+            )
+        if dropout == 0.0:
+            self.dropout = Dropout(dropout, deterministic=True)
+        else:
+            self.dropout = Dropout(dropout)
+        self.q_proj = Linear(self.embed_dim, self.embed_dim, use_bias=False, key=key1)
+        self.k_proj = Linear(self.kdim, self.embed_dim, use_bias=add_bias_kv, key=key2)
+        self.v_proj = Linear(self.vdim, self.embed_dim, use_bias=add_bias_kv, key=key3)
+        self.out_proj = Linear(embed_dim, embed_dim, use_bias=use_bias, key=key4)
+
+    def __call__(
+        self,
+        query: Array,
+        key_: Array,
+        value: Array,
+        attn_mask: Optional[Array] = None,
+        *,
+        key: Optional["jax.random.PRNGKey"] = None,
+    ) -> Array:
+        """**Arguments:**
+
+        - `query`: Query embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`.
+        - `key_`: Key embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`.
+        - `value`: Value embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`.
+        - `attn_mask`: A mask preventing attention to certain positions.
+        - `key`: A PRNGKey used for dropout.
+
+        **Returns:**
+
+        A JAX array of shape `(sequence_length, embed_dim)`.
+        """
+        d1, _ = query.shape
+        query_heads = self._project(self.q_proj, query)
+        key_heads = self._project(self.k_proj, key_)
+        value_heads = self._project(self.v_proj, value)
+        attn_logits = jnp.einsum("shd,Shd->hsS", query_heads, key_heads)
+        sqrt_key_size = np.sqrt(self.kdim // self.num_heads).astype(key_.dtype)
+        attn_logits = attn_logits / sqrt_key_size
+        attn_logits = self.dropout(attn_logits, key=key)
+
+        if attn_mask is not None:
+            if attn_mask.ndim != attn_logits.ndim:
+                raise ValueError(
+                    f"Mask dimensionality {attn_mask.ndim} must match logits "
+                    f"{attn_logits.ndim}."
+ ) + attn_logits = jnp.where(attn_mask, attn_logits, -1e30) + attn_weights = jax.nn.softmax(attn_logits, axis=-1) + attn = jnp.einsum("hsS,Shd->shd", attn_weights, value_heads) + attn_vec = jnp.reshape(attn, (*query.shape[:-1], -1)) + + return jax.vmap(self.out_proj)(attn_vec) + + def _project(self, proj, x): + d1, _ = x.shape + projection = jax.vmap(proj)(x).reshape( + d1, self.num_heads, self.embed_dim // self.num_heads + ) + return projection diff --git a/equinox/nn/conv.py b/equinox/nn/conv.py index 26e788f3..dcee76d9 100644 --- a/equinox/nn/conv.py +++ b/equinox/nn/conv.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import jax.random as jrandom import numpy as np -from jax.lax import conv_general_dilated +from jax.lax import conv_general_dilated, conv_transpose from ..custom_types import Array from ..module import Module, static_field @@ -26,6 +26,46 @@ def parse(x: Any) -> tuple: return parse +def compute_adjusted_padding( + input_size: int, + kernel_size: int, + stride: int, + padding: int, + output_padding: int, + dilation: int = 1, +) -> Tuple[int, int]: + """Computes adjusted padding for desired ConvTranspose `output_padding`.""" + kernel_size = (kernel_size - 1) * dilation + 1 + output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding + if padding == 0: + expected_input_size = (output_size - kernel_size + stride) // stride + if input_size != expected_input_size: + raise ValueError( + f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) + padding_before = 0 + elif padding == 1: + expected_input_size = (output_size + stride - 1) // stride + if input_size != expected_input_size: + raise ValueError( + f"The expected input size with the current set of input " + f"parameters is {expected_input_size} which doesn't " + f"match the actual input size {input_size}." + ) + padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size) + padding_before = padding_needed // 2 + else: + raise ValueError(f"`padding` must be '0' or '1'. Passed: {padding}.") + + expanded_input_size = (input_size - 1) * stride + 1 + padded_out_size = output_size + kernel_size - 1 + pad_before = kernel_size - 1 - padding_before + pad_after = padded_out_size - expanded_input_size - pad_before + return (pad_before, pad_after) + + class Conv(Module): """General N-dimensional convolution.""" @@ -74,7 +114,7 @@ def __init__( All of `kernel_size`, `stride`, `padding`, `dilation` can be either an integer or a sequence of integers. If they are a sequence then the sequence should be of length equal to `num_spatial_dims`, and specify the value of - each property down each spatial dimension in turn.. If they are an integer + each property down each spatial dimension in turn. If they are an integer then the same kernel size / stride / padding / dilation will be used along every spatial dimension. 
@@ -246,3 +286,249 @@ def __init__(
             key=key,
             **kwargs,
         )
+
+
+class ConvTranspose(Module):
+    """General N-dimensional transposed convolution."""
+
+    num_spatial_dims: int = static_field()
+    weight: Array
+    bias: Optional[Array]
+    in_channels: int = static_field()
+    out_channels: int = static_field()
+    kernel_size: Tuple[int] = static_field()
+    stride: Tuple[int] = static_field()
+    padding: Tuple[int] = static_field()
+    output_padding: Tuple[int] = static_field()
+    dilation: Tuple[int] = static_field()
+    use_bias: bool = static_field()
+    dimension_numbers: Tuple[str] = static_field()
+
+    def __init__(
+        self,
+        num_spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Sequence[int]],
+        stride: Union[int, Sequence[int]] = 1,
+        padding: Union[int, Sequence[int]] = 0,
+        output_padding: Union[int, Sequence[int]] = 0,
+        dilation: Union[int, Sequence[int]] = 1,
+        use_bias: bool = True,
+        *,
+        key: "jax.random.PRNGKey",
+        **kwargs,
+    ):
+        """**Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions. For example traditional
+            convolutions for image processing have this set to `2`.
+        - `in_channels`: The number of input channels.
+        - `out_channels`: The number of output channels.
+        - `kernel_size`: The size of the transposed convolutional kernel.
+        - `stride`: The stride of the transposed convolution.
+        - `padding`: The amount of implicit padding, such that `dilation *
+            (kernel_size - 1) - padding` zeros are added to both sides of the input.
+        - `output_padding`: The additional size added to one side of the output shape.
+        - `dilation`: The spacing between kernel points.
+        - `use_bias`: Whether to add on a bias after the transposed convolution.
+        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
+            initialisation. (Keyword only argument.)
+
+        !!! info
+
+            All of `kernel_size`, `stride`, `padding`, `output_padding`, `dilation` can
+            be either an integer or a sequence of integers. If they are a sequence then
+            the sequence should be of length equal to `num_spatial_dims`, and specify
+            the value of each property down each spatial dimension in turn. If they
+            are an integer then the same kernel size / stride / padding / dilation will
+            be used along every spatial dimension.
+ + """ + super().__init__(**kwargs) + self.num_spatial_dims = num_spatial_dims + parse = _ntuple(self.num_spatial_dims) + wkey, bkey = jrandom.split(key, 2) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = parse(kernel_size) + self.use_bias = use_bias + self.output_padding = parse(output_padding) + self.padding = parse(padding) + lim = 1 / np.sqrt(self.in_channels * np.prod(self.kernel_size)) + if self.num_spatial_dims == 1: + self.dimension_numbers = ("NCH", "IOH", "NCH") + elif self.num_spatial_dims == 2: + self.dimension_numbers = ("NCHW", "IOHW", "NCHW") + elif self.num_spatial_dims == 3: + self.dimension_numbers = ("NCDHW", "IODHW", "NCDHW") + else: + raise NotImplementedError( + "`ConvTranspose` only supports between 1 and 3 spatial dims", + f"({self.num_spatial_dims} was given)", + ) + self.weight = jrandom.uniform( + wkey, + ( + self.in_channels, + self.out_channels, + ) + + self.kernel_size, + minval=-lim, + maxval=lim, + ) + if self.use_bias: + self.bias = jrandom.uniform( + bkey, + (self.out_channels,) + (1,) * self.num_spatial_dims, + minval=-lim, + maxval=lim, + ) + else: + self.bias = None + + self.stride = parse(stride) + self.dilation = parse(dilation) + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + """**Arguments:** + + - `x`: The input. Should be a JAX array of shape `(in_channels, dim_1, ..., dim_N)`, where + `N = num_spatial_dims`. + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + + **Returns:** + + A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`. + """ + unbatched_rank = self.num_spatial_dims + 1 + if x.ndim != unbatched_rank: + raise ValueError( + f"Input to `ConvTranspose` needs to have rank {unbatched_rank},", + f" but input has shape {x.shape}.", + ) + x = jnp.expand_dims(x, axis=0) + padding = self.padding + if self.output_padding is not None: + padding = tuple( + map( + compute_adjusted_padding, + x.shape[2:], + self.weight.shape[2:], + self.stride, + self.padding, + self.output_padding, + self.dilation, + ) + ) + x = conv_transpose( + lhs=x, + rhs=self.weight, + strides=self.stride, + padding=padding, + rhs_dilation=self.dilation, + dimension_numbers=self.dimension_numbers, + ) + if self.use_bias: + x += jnp.broadcast_to(self.bias, x.shape) + x = jnp.squeeze(x, axis=0) + return x + + +class ConvTranspose1d(ConvTranspose): + """As [`equinox.nn.ConvTranspose`][] with `num_spatial_dims=1`.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + output_padding=0, + padding=0, + dilation=1, + use_bias=True, + *, + key, + **kwargs, + ): + super().__init__( + num_spatial_dims=1, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + output_padding=output_padding, + padding=padding, + dilation=dilation, + use_bias=use_bias, + key=key, + **kwargs, + ) + + +class ConvTranspose2d(ConvTranspose): + """As [`equinox.nn.ConvTranspose`][] with `num_spatial_dims=2`.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=(1, 1), + output_padding=(0, 0), + padding=(0, 0), + dilation=(1, 1), + use_bias=True, + *, + key, + **kwargs, + ): + super().__init__( + num_spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + output_padding=output_padding, + padding=padding, + dilation=dilation, + use_bias=use_bias, + key=key, + **kwargs, + ) + + +class 
ConvTranspose3d(ConvTranspose): + """As [`equinox.nn.ConvTranspose`][] with `num_spatial_dims=3`.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=(1, 1, 1), + output_padding=(0, 0, 0), + padding=(0, 0, 0), + dilation=(1, 1, 1), + use_bias=True, + *, + key, + **kwargs, + ): + super().__init__( + num_spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + output_padding=output_padding, + padding=padding, + dilation=dilation, + use_bias=use_bias, + key=key, + **kwargs, + ) diff --git a/equinox/nn/embedding.py b/equinox/nn/embedding.py new file mode 100644 index 00000000..f471f3af --- /dev/null +++ b/equinox/nn/embedding.py @@ -0,0 +1,61 @@ +from typing import Optional + +import jax +import jax.random as jrandom + +from ..custom_types import Array +from ..module import Module, static_field + + +class Embedding(Module): + """Simple lookup table style embedding""" + + num_embeddings: int = static_field() + embedding_dim: int = static_field() + weight: Array + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + weight: Optional[Array] = None, + *, + key: "jax.random.PRNGKey", + **kwargs, + ): + """**Arguments:** + + - `num_embeddings`: Size of embedding dictionary. + - `embedding_dim`: Size of each embedding vector. + - `weight`: If given, the embedding lookup table. + - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter + initialisation. (Keyword only argument.) + + """ + super().__init__(**kwargs) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if weight is None: + self.weight = jrandom.normal(key, (num_embeddings, embedding_dim)) + else: + if list(weight.shape) != [num_embeddings, embedding_dim]: + raise ValueError( + f"Shape of weight ({weight.shape}) does not match num_embeddings" + f" ({num_embeddings}) and embedding_dim ({embedding_dim})" + ) + self.weight = weight + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + """**Arguments:** + + - `x`: The table index. + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + + **Returns:** + + A JAX array of shape `embedding_dim` that gives the xth index of the embedding table. + """ + return self.weight[x] diff --git a/equinox/nn/normalization.py b/equinox/nn/normalization.py new file mode 100644 index 00000000..aca0ef90 --- /dev/null +++ b/equinox/nn/normalization.py @@ -0,0 +1,61 @@ +from typing import Optional, Sequence, Union + +import jax +import jax.numpy as jnp + +from ..custom_types import Array +from ..module import Module, static_field + + +class LayerNorm(Module): + """Layer Normalization as described in https://arxiv.org/abs/1607.06450""" + + normalized_shape: Union[int, Sequence[int]] = static_field() + eps: float = static_field() + elementwise_affine: bool = static_field() + weight: Array + bias: Array + + def __init__( + self, + normalized_shape: Union[int, Sequence[int]], + eps: float = 1e-5, + elementwise_affine: bool = True, + *, + key: "jax.random.PRNGKey", + **kwargs, + ): + """**Arguments:** + - `normalized_shape`: Input shape. + - `eps`: Value added to denominator for numerical stability. Default: `1e-5`. + - `elementwise_affine`: Whether the module has learnable affine parameters. Default: `True`. + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) 
+ """ + super().__init__(**kwargs) + self.normalized_shape = normalized_shape + self.eps = eps + self.elementwise_affine = elementwise_affine + self.weight = jnp.ones(self.normalized_shape) if elementwise_affine else None + self.bias = jnp.zeros(self.normalized_shape) if elementwise_affine else None + + def __call__( + self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None + ) -> Array: + """**Arguments:** + + - `x`: A JAX array of shape `normalized_shape`. + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + + **Returns:** + + A JAX array of shape `normalized_shape`. + """ + mean = jnp.mean(x, keepdims=True) + variance = jnp.var(x, keepdims=True) + inv = jax.lax.rsqrt(variance + self.eps) + out = (x - mean) * inv + if self.elementwise_affine: + out = self.weight * out + self.bias + return out diff --git a/tests/test_nn.py b/tests/test_nn.py index d66f9107..34595426 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -254,3 +254,269 @@ def test_conv3d(getkey): conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias)) answer = jnp.array([-3, -2, -1, 0, 1, 2, 3, 4, 1, 1, 1, 1]).reshape(1, 3, 2, 2) assert jnp.allclose(conv(data), answer) + + +def test_convtranspose1d(getkey): + # Positional arguments + conv = eqx.nn.ConvTranspose1d(1, 3, 3, key=getkey()) + x = jrandom.normal(getkey(), (1, 32)) + assert conv(x).shape == (3, 34) + + # Some keyword arguments + conv = eqx.nn.ConvTranspose1d(1, out_channels=3, kernel_size=(3,), key=getkey()) + x = jrandom.normal(getkey(), (1, 32)) + assert conv(x).shape == (3, 34) + + # All keyword arguments + conv = eqx.nn.ConvTranspose1d( + in_channels=1, + out_channels=3, + kernel_size=(3,), + padding=0, + output_padding=0, + use_bias=False, + key=getkey(), + ) + x = jrandom.normal(getkey(), (1, 32)) + assert conv(x).shape == (3, 34) + + # Test strides + conv = eqx.nn.ConvTranspose1d( + in_channels=3, + out_channels=1, + kernel_size=(3,), + stride=2, + padding=1, + output_padding=1, + use_bias=True, + key=getkey(), + ) + x = jrandom.normal(getkey(), (3, 32)) + assert conv(x).shape == (1, 64) + + # Test value matches + conv = eqx.nn.ConvTranspose1d(1, 3, kernel_size=3, padding=0, key=getkey()) + new_weight = jnp.arange(9).reshape(1, 3, 3) + new_bias = jnp.array([1, 2, 3]).reshape(3, 1) + data = jnp.arange(-3, 3).reshape(1, -1) + assert new_weight.shape == conv.weight.shape + assert new_bias.shape == conv.bias.shape + conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias)) + answer = jnp.array( + [ + -5, + -6, + -3, + 0, + 3, + 6, + 3, + 1, + -13, + -20, + -20, + -8, + 4, + 16, + 13, + 8, + -21, + -34, + -37, + -16, + 5, + 26, + 23, + 15, + ] + ).reshape(3, 8) + assert jnp.allclose(conv(data), answer) + + +def test_convtranspose2d(getkey): + # Positional arguments + conv = eqx.nn.ConvTranspose2d(1, 3, 3, key=getkey()) + x = jrandom.normal(getkey(), (1, 32, 32)) + assert conv(x).shape == (3, 34, 34) + + # Some keyword arguments + conv = eqx.nn.ConvTranspose2d(1, out_channels=3, kernel_size=(3, 3), key=getkey()) + x = jrandom.normal(getkey(), (1, 32, 32)) + assert conv(x).shape == (3, 34, 34) + + # All keyword arguments + conv = eqx.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + use_bias=False, + key=getkey(), + ) + x = jrandom.normal(getkey(), (1, 32, 32)) + assert conv(x).shape == (3, 32, 32) + + # Test strides + conv = eqx.nn.ConvTranspose2d( + in_channels=3, + out_channels=1, + kernel_size=(3, 3), + 
stride=2, + padding=1, + output_padding=1, + use_bias=True, + key=getkey(), + ) + x = jrandom.normal(getkey(), (3, 32, 32)) + assert conv(x).shape == (1, 64, 64) + + # Test value matches + conv = eqx.nn.ConvTranspose2d(1, 1, kernel_size=3, padding=1, key=getkey()) + new_weight = jnp.arange(9).reshape(1, 1, 3, 3) + new_bias = jnp.array([1]).reshape(1, 1, 1) + data = jnp.arange(-4, 5).reshape(1, 3, 3) + assert new_weight.shape == conv.weight.shape + assert new_bias.shape == conv.bias.shape + conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias)) + answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3) + assert jnp.allclose(conv(data), answer) + + +def test_convtranspose3d(getkey): + # Positional arguments + conv = eqx.nn.ConvTranspose3d(1, 3, 3, key=getkey()) + x = jrandom.normal(getkey(), (1, 3, 32, 32)) + assert conv(x).shape == (3, 5, 34, 34) + + # Some keyword arguments + conv = eqx.nn.ConvTranspose3d( + 1, out_channels=3, kernel_size=(3, 3, 3), key=getkey() + ) + x = jrandom.normal(getkey(), (1, 3, 32, 32)) + assert conv(x).shape == (3, 5, 34, 34) + + # All keyword arguments + conv = eqx.nn.ConvTranspose3d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3, 3), + padding=1, + use_bias=False, + key=getkey(), + ) + x = jrandom.normal(getkey(), (1, 3, 32, 32)) + assert conv(x).shape == (3, 3, 32, 32) + + # Test strides + conv = eqx.nn.ConvTranspose3d( + in_channels=3, + out_channels=1, + kernel_size=(3, 3, 3), + stride=2, + padding=1, + output_padding=1, + use_bias=True, + key=getkey(), + ) + x = jrandom.normal(getkey(), (3, 3, 32, 32)) + assert conv(x).shape == (1, 6, 64, 64) + + # Test value matches + conv = eqx.nn.ConvTranspose3d( + 1, 1, kernel_size=(2, 2, 2), padding=(0, 0, 0), key=getkey() + ) + new_weight = jnp.arange(8).reshape(1, 1, 2, 2, 2) + new_bias = jnp.array([1]).reshape(1, 1, 1, 1) + data = jnp.arange(-4, 4).reshape(1, 2, 2, 2) + assert new_weight.shape == conv.weight.shape + assert new_bias.shape == conv.bias.shape + conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias)) + answer = jnp.array( + [ + -27, + -44, + -17, + -33, + -49, + -17, + -9, + -12, + -3, + -11, + -9, + 1, + 5, + 29, + 21, + 9, + 23, + 13, + 1, + 4, + 3, + 7, + 15, + 7, + 3, + 4, + 1, + ] + ).reshape(1, 3, 3, 3) + assert jnp.allclose(conv(data), answer) + + +def test_multihead_attention(getkey): + atn = eqx.nn.MultiheadAttention(128, 4, key=getkey()) + x = jrandom.uniform(getkey(), (4, 128)) + assert atn(x, x, x).shape == (4, 128) + + atn = eqx.nn.MultiheadAttention(embed_dim=512, num_heads=8, key=getkey()) + x = jrandom.uniform(getkey(), (2, 512)) + assert atn(x, x, x).shape == (2, 512) + + atn = eqx.nn.MultiheadAttention(4, 2, use_bias=False, key=getkey()) + atn = eqx.tree_at( + lambda x: ( + x.q_proj.weight, + x.k_proj.weight, + x.v_proj.weight, + x.out_proj.weight, + ), + atn, + [jnp.arange(16).reshape(4, 4) for _ in range(4)], + ) + x = jnp.array([[1, 2, 3, 4]]) + assert jnp.allclose(atn(x, x, x), jnp.array([[680.0, 1960.0, 3240.0, 4520.0]])) + + +def test_embedding(getkey): + emb = eqx.nn.Embedding(100, 512, key=getkey()) + x = jnp.array([1]) + assert emb(x).shape == (1, 512) + + emb = eqx.nn.Embedding(num_embeddings=10, embedding_dim=20, key=getkey()) + x = jnp.array([0]) + assert emb(x).shape == (1, 20) + + emb = eqx.nn.Embedding( + 10, 10, weight=jnp.linspace(0.1, 10, 100).reshape(10, 10), key=getkey() + ) + x = jnp.array([-1]) + assert jnp.allclose(emb(x), jnp.linspace(9.1, 10.0, 10)) + + +def test_layer_norm(getkey): + ln 
= eqx.nn.LayerNorm(128, key=getkey()) + x = jrandom.uniform(getkey(), (128,)) + assert ln(x).shape == (128,) + + ln = eqx.nn.LayerNorm(normalized_shape=(128, 128), key=getkey()) + x = jrandom.uniform(getkey(), (128, 128)) + assert ln(x).shape == (128, 128) + + ln = eqx.nn.LayerNorm(10, key=getkey()) + x1 = jnp.linspace(0.1, 1, 10) + x2 = jnp.linspace(0, 1, 10) + x3 = (x1 - x1.mean()) / jnp.sqrt(x1.var() + 1e-5) + assert jnp.allclose(ln(x1), ln(x2), atol=1e-4) + assert jnp.allclose(ln(x1), x3, atol=1e-4)
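Taken together, the new tests pin down how the added layers are meant to be called. Below is a minimal usage sketch, assuming this diff has been applied to an Equinox checkout; the particular shapes, seeds and hyperparameters are illustrative only:

```python
import jax.numpy as jnp
import jax.random as jrandom

import equinox as eqx

akey, ckey, ekey, lkey, xkey = jrandom.split(jrandom.PRNGKey(0), 5)

# Self-attention over a length-4 sequence of 128-dimensional embeddings.
attn = eqx.nn.MultiheadAttention(embed_dim=128, num_heads=4, key=akey)
seq = jrandom.normal(xkey, (4, 128))
assert attn(seq, seq, seq).shape == (4, 128)

# Strided transposed convolution upsampling a 32x32 feature map to 64x64.
deconv = eqx.nn.ConvTranspose2d(
    in_channels=3,
    out_channels=1,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1,
    key=ckey,
)
image = jrandom.normal(xkey, (3, 32, 32))
assert deconv(image).shape == (1, 64, 64)

# Lookup-table embedding: integer indices in, embedding vectors out.
embed = eqx.nn.Embedding(num_embeddings=10, embedding_dim=20, key=ekey)
assert embed(jnp.array([3])).shape == (1, 20)

# Layer normalisation of a single 128-dimensional vector.
norm = eqx.nn.LayerNorm(128, key=lkey)
assert norm(seq[0]).shape == (128,)
```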