-
-
Notifications
You must be signed in to change notification settings - Fork 178
ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
patrick-kidger
merged 19 commits into
patrick-kidger:attn-convt-layernorm
from
andyehrenberg:main
Mar 9, 2022
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
7a7a9ed
conv transpose layers
andyehrenberg 4fd42f1
fix output_padding default
andyehrenberg d192b3f
conv transpose tests
andyehrenberg c1b1baf
attention and embedding layers
andyehrenberg 0f344d5
LayerNorm
andyehrenberg 5dda0fc
documentation for attention
andyehrenberg eeb5e59
embedding documentation
andyehrenberg a8145f3
fix import error
andyehrenberg 1ce7b42
Merge branch 'patrick-kidger:main' into main
andyehrenberg 56e29b5
fix attention
andyehrenberg 2a599c5
Merge branch 'main' of https://github.com/andyehrenberg/equinox into …
andyehrenberg 6574edf
_project convenience fn
andyehrenberg dc95ebc
attention and embedding tests
andyehrenberg 7c81430
normalization test
andyehrenberg f5816ec
fix tests
andyehrenberg cf22c67
tolerance for layernorm test
andyehrenberg 40367ea
fix embedding test
andyehrenberg 1c093a9
better docs for attention, kwargs in init, other fixes
andyehrenberg 6c407ee
fix typos, style for conv
andyehrenberg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,17 @@ | ||
| from .attention import MultiheadAttention | ||
| from .composed import MLP, Sequential | ||
| from .conv import Conv, Conv1d, Conv2d, Conv3d | ||
| from .conv import ( | ||
| Conv, | ||
| Conv1d, | ||
| Conv2d, | ||
| Conv3d, | ||
| ConvTranspose, | ||
| ConvTranspose1d, | ||
| ConvTranspose2d, | ||
| ConvTranspose3d, | ||
| ) | ||
| from .dropout import Dropout | ||
| from .embedding import Embedding | ||
| from .linear import Identity, Linear | ||
| from .normalization import LayerNorm | ||
| from .rnn import GRUCell, LSTMCell |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| from typing import Optional | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import jax.random as jrandom | ||
| import numpy as np | ||
|
|
||
| from ..custom_types import Array | ||
| from ..module import Module, static_field | ||
| from .dropout import Dropout | ||
| from .linear import Linear | ||
|
|
||
|
|
||
| class MultiheadAttention(Module): | ||
| """ | ||
| Multihead Attention layer from `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. | ||
|
|
||
| $\text{MultiHeadAttention}(Q, K, V) = \text{Concat}(head_{1},...,head_{h})W^{out}$, | ||
| where $head_i = \text{softmax}(\frac{QW_i^Q(KW_i^K)^\intercal}{\sqrt{d_k}})VW_i^V$ | ||
|
|
||
| """ | ||
|
|
||
| embed_dim: int = static_field() | ||
| num_heads: int = static_field() | ||
| kdim: int = static_field() | ||
| vdim: int = static_field() | ||
| _qkv_same_embed_dim: bool = static_field() | ||
| head_dim: int = static_field() | ||
| q_proj: Linear | ||
| k_proj: Linear | ||
| v_proj: Linear | ||
| out_proj: Linear | ||
| dropout: Dropout | ||
|
|
||
| def __init__( | ||
| self, | ||
| embed_dim: int, | ||
| num_heads: int, | ||
| dropout: float = 0.0, | ||
| use_bias: bool = True, | ||
| kdim: Optional[int] = None, | ||
| vdim: Optional[int] = None, | ||
| add_bias_kv: bool = False, | ||
| *, | ||
| key: "jax.random.PRNGKey", | ||
andyehrenberg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| **kwargs, | ||
| ): | ||
| """**Arguments:** | ||
|
|
||
| - `embed_dim`: Dimension of the model. | ||
| - `num_heads`: Number of parallel attention heads. | ||
| - `dropout`: Dropout probability on attention matrix. Default: `0.0`. | ||
| - `use_bias`: Whether to use a bias term on the output projection. Default: `True`. | ||
| - `kdim`: Total number of features for keys. Default: `None` (use `kdim=embed_dim`). | ||
| - `vdim`: Total number of features for values. Default: `None` (use `vdim=embed_dim`). | ||
| - `add_bias_kv`: Whether to use bias term for value and key projections. Default: `False`. | ||
| - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter | ||
| initialisation. (Keyword only argument.) | ||
|
|
||
| """ | ||
| super().__init__(**kwargs) | ||
| key1, key2, key3, key4 = jrandom.split(key, 4) | ||
| self.embed_dim = embed_dim | ||
| self.kdim = kdim if kdim is not None else embed_dim | ||
| self.vdim = vdim if vdim is not None else embed_dim | ||
| self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim | ||
| self.num_heads = num_heads | ||
| self.head_dim = embed_dim // num_heads | ||
| if self.embed_dim % num_heads != 0: | ||
| raise ValueError( | ||
| f"embed_dim must be divisible by num_heads (got embed_dim = {self.embed_dim}" | ||
| f" and num_heads = {self.num_heads})" | ||
| ) | ||
| if self.kdim % num_heads != 0: | ||
| raise ValueError( | ||
| f"kdim must be divisible by num_heads (got kdim = {self.kdim} and " | ||
| f"num_heads = {self.num_heads})" | ||
| ) | ||
| if self.vdim % num_heads != 0: | ||
| raise ValueError( | ||
| f"vdim must be divisible by num_heads (got vdim = {self.vdim} and " | ||
| f"num_heads = {self.num_heads})" | ||
| ) | ||
| if dropout == 0.0: | ||
| self.dropout = Dropout(dropout, deterministic=True) | ||
| else: | ||
| self.dropout = Dropout(dropout) | ||
| self.q_proj = Linear(self.embed_dim, self.embed_dim, use_bias=False, key=key1) | ||
| self.k_proj = Linear(self.kdim, self.embed_dim, use_bias=add_bias_kv, key=key2) | ||
| self.v_proj = Linear(self.vdim, self.embed_dim, use_bias=add_bias_kv, key=key3) | ||
| self.out_proj = Linear(embed_dim, embed_dim, use_bias=use_bias, key=key4) | ||
|
|
||
| def __call__( | ||
| self, | ||
| query: Array, | ||
| key_: Array, | ||
| value: Array, | ||
| attn_mask: Optional[Array] = None, | ||
| *, | ||
| key: Optional["jax.random.PRNGKey"] = None, | ||
| ) -> Array: | ||
| """**Arguments:** | ||
|
|
||
| - `query`: Query embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`. | ||
| - `key_`: Key embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`. | ||
| - `value`: Value embedding. Should be a JAX array of shape `(sequence_length, embed_dim)`. | ||
| - `attn_mask`: A mask preventing attention to certain positions. | ||
| - `key`: A PRNGKey used for dropout. | ||
|
|
||
| **Returns:** | ||
|
|
||
| A JAX array of shape `(sequence_length, embed_dim)`. | ||
| """ | ||
| d1, _ = query.shape | ||
| query_heads = self._project(self.q_proj, query) | ||
| key_heads = self._project(self.k_proj, key_) | ||
| value_heads = self._project(self.v_proj, value) | ||
| attn_logits = jnp.einsum("shd,Shd->hsS", query_heads, key_heads) | ||
| sqrt_key_size = np.sqrt(self.kdim // self.num_heads).astype(key.dtype) | ||
| attn_logits = attn_logits / sqrt_key_size | ||
| attn_logits = self.dropout(attn_logits, key=key) | ||
|
|
||
| if attn_mask is not None: | ||
| if attn_mask.ndim != attn_logits.ndim: | ||
| raise ValueError( | ||
| f"Mask dimensionality {attn_mask.ndim} must match logits " | ||
| f"{attn_logits.ndim}." | ||
| ) | ||
| attn_logits = jnp.where(attn_mask, attn_logits, -1e30) | ||
| attn_weights = jax.nn.softmax(attn_logits, axis=-1) | ||
| attn = jnp.einsum("hsS,Shd->shd", attn_weights, value_heads) | ||
| attn_vec = jnp.reshape(attn, (*query.shape[:-1], -1)) | ||
|
|
||
| return jax.vmap(self.out_proj)(attn_vec) | ||
|
|
||
| def _project(self, proj, x): | ||
| d1, _ = x.shape | ||
| projection = jax.vmap(proj)(x).reshape( | ||
| d1, self.num_heads, self.embed_dim // self.num_heads | ||
| ) | ||
| return projection | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to add the mathematical formulation here?
PyTorch kind-of do this for attention here, but I'm imagining something more precise, a la here.
The documentation generation supports LaTeX:
stuff $\alpha$ morestuff.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added documentation in the latest commit - though I had to skip using flake8 because it's complaining about the '\i' and '\s' in '\intercal' and '\sqrt'. How can this be ignored by the checks in this repo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because Python/flake8 is trying to interpret those as escape codes, like
\nfor a newline. Parsing of escape codes can be disabled by prefixing anrbefore the string, i.e.:r""" ... stuff ... """