Skip to content

Commit

Permalink
make conformer able to do things autoregressively, to save issues wit…
Browse files Browse the repository at this point in the history
…h variable lengths in soundstorm
  • Loading branch information
lucidrains committed May 17, 2023
1 parent a37a2ad commit fc70d51
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions conformer/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def __init__(
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.):
dropout = 0.
):
super().__init__()

inner_dim = dim * expansion_factor
Expand Down Expand Up @@ -185,12 +186,13 @@ def __init__(
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

self.attn = PreNorm(dim, self.attn)
Expand Down Expand Up @@ -222,7 +224,8 @@ def __init__(
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
self.dim = dim
Expand All @@ -236,6 +239,7 @@ def __init__(
ff_mult = ff_mult,
conv_expansion_factor = conv_expansion_factor,
conv_kernel_size = conv_kernel_size,
conv_causal = conv_causal

))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'conformer',
packages = find_packages(),
version = '0.3.1',
version = '0.3.2',
license='MIT',
description = 'The convolutional module from the Conformer paper',
author = 'Phil Wang',
Expand Down

0 comments on commit fc70d51

Please sign in to comment.