
Commit

incorporate some einops suggestions from @arogozhnikov
lucidrains authored Jan 4, 2021
1 parent 86b8316 commit 0f8fbd3
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions conformer/conformer.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 
 from einops import rearrange
+from einops.layers.torch import Rearrange
 
 # helper functions
 
@@ -22,15 +23,6 @@ class Swish(nn.Module):
     def forward(self, x):
         return x * x.sigmoid()
 
-class Transpose(nn.Module):
-    def __init__(self, dims):
-        super().__init__()
-        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
-        self.dims = dims
-
-    def forward(self, x):
-        return x.transpose(*self.dims)
-
 class GLU(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -104,7 +96,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
 
         # shaw's relative positional embedding
         seq = torch.arange(n, device = device)
-        dist = seq[:, None] - seq[None, :]
+        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
         dist = dist.clip(-max_pos_emb, max_pos_emb) + max_pos_emb
         rel_pos_emb = self.rel_pos_emb(dist).to(q)
         pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
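
The rearrange calls on the added line are just explicit axis insertion. A minimal standalone check (illustrative sequence length and window size, not part of this commit) that they reproduce the indexing they replace:

import torch
from einops import rearrange

n, max_pos_emb = 8, 4                              # illustrative values
seq = torch.arange(n)

# '()' inserts a unit axis, so these behave like seq[:, None] and seq[None, :]
dist_new = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
dist_old = seq[:, None] - seq[None, :]

assert torch.equal(dist_new, dist_old)             # same (n, n) matrix of pairwise offsets
dist = dist_new.clip(-max_pos_emb, max_pos_emb) + max_pos_emb   # shifted into [0, 2 * max_pos_emb]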
@@ -114,7 +106,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
             mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
             context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
             mask_value = -torch.finfo(dots.dtype).max
-            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
+            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
             dots.masked_fill_(~mask, mask_value)
 
         attn = dots.softmax(dim = -1)
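
Likewise for the mask: a quick sanity check (batch and sequence sizes are made up for illustration) that the rearrange-based broadcast builds the same (b, 1, i, j) mask as the old [:, None, ...] indexing:

import torch
from einops import rearrange

b, i, j = 2, 5, 7                                  # illustrative shapes
mask = torch.ones(b, i)
context_mask = torch.ones(b, j)

old = mask[:, None, :, None] * context_mask[:, None, None, :]
new = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')

assert torch.equal(old, new)                       # both broadcast to shape (b, 1, i, j)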
@@ -158,14 +150,14 @@ def __init__(
 
         self.net = nn.Sequential(
             nn.LayerNorm(dim),
-            Transpose((1, 2)),
+            Rearrange('b n c -> b c n'),
             nn.Conv1d(dim, inner_dim * 2, 1),
             GLU(dim=1),
             DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
             nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
             Swish(),
             nn.Conv1d(inner_dim, dim, 1),
-            Transpose((1, 2)),
+            Rearrange('b c n -> b n c'),
             nn.Dropout(dropout)
         )
 
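Rearrange from einops.layers.torch is an nn.Module, which is why it can sit inside nn.Sequential where the deleted Transpose helper used to be. A minimal sketch (dimensions chosen for illustration) of the channel/length swap around Conv1d:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

dim, n = 16, 10                        # illustrative channel count and sequence length
net = nn.Sequential(
    Rearrange('b n c -> b c n'),       # (batch, length, dim) -> (batch, dim, length) for Conv1d
    nn.Conv1d(dim, dim, 1),
    Rearrange('b c n -> b n c'),       # back to (batch, length, dim)
)

x = torch.randn(2, n, dim)
print(net(x).shape)                    # torch.Size([2, 10, 16])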
