Squeezeformer implementation #1431
Conversation
Great job, I will go through the details as soon as I can.
wenet/transformer/encoder.py
Outdated
    if time_recover_idx is not None
    else None
)
if len(time_reduce_idx) > 0:
This will raise an error if time_reduce_idx is None.
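To make the reviewer's point concrete, here is a minimal, stdlib-only sketch of the missing None guard (the helper name check_time_indices is hypothetical; the argument names follow the diff above):

```python
def check_time_indices(time_reduce_idx, time_recover_idx):
    """Validate the reduce/recover index pair before calling len() on it.

    Hypothetical helper: the name is illustrative, but the argument
    names follow the PR discussion.
    """
    if time_reduce_idx is None:
        # Time reduction disabled; calling len(time_reduce_idx) here
        # would raise TypeError, which is the bug being pointed out.
        return None, None
    if len(time_reduce_idx) > 0 and time_recover_idx is None:
        raise ValueError(
            "time_recover_idx must be set when time_reduce_idx is used")
    return time_reduce_idx, time_recover_idx
```

With a guard like this, a config that leaves time_reduce_idx unset simply skips the reduce/recover path instead of crashing.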
wenet/transformer/convolution.py
Outdated
@@ -79,14 +81,15 @@ def __init__(self,
self.norm = nn.LayerNorm(channels)
The argument of BatchNorm1d or LayerNorm should probably be channels if use_glu else 2 * channels.
In addition, the groups argument of depthwise_conv should probably be channels if use_glu else 2 * channels too.
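The channel arithmetic behind this comment can be summarized in a tiny hypothetical helper (conv_module_width is not a name from the PR): the first pointwise conv doubles the channel count, and GLU halves it back, so layers after it see the original width only when GLU is enabled.

```python
def conv_module_width(channels: int, use_glu: bool) -> int:
    # After pointwise_conv1 the tensor has 2 * channels channels.
    # GLU gates away half of them; without GLU the doubled width
    # remains, so the depthwise conv's groups and the norm layer
    # must be sized for 2 * channels instead.
    return channels if use_glu else 2 * channels
```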
Thanks for the reviews, TeaPoly. I've checked them and fixed the corresponding code. Additionally, I found a mistake in the following code:
for i, recover_layer in enumerate(self.time_recover_layers):
    if time_reduce_level == i:
        xs = recover_layer(xs)
        xs += residual_xs  # (B, T, D)
    xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
    xs = self.after_norm(xs)
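As a purely illustrative stand-in for the recover step quoted above (not the PR's actual upsampling layer), recovery can be pictured as repeating each reduced frame and adding back the pre-reduction residual, mirroring the xs += residual_xs line:

```python
def recover_time_2x(reduced, residual):
    # Hypothetical 1-D illustration: upsample by repeating each frame,
    # trim to the residual's original length, then add the residual
    # elementwise, as in the quoted snippet's residual connection.
    upsampled = []
    for frame in reduced:
        upsampled.extend([frame, frame])
    upsampled = upsampled[:len(residual)]
    return [u + r for u, r in zip(upsampled, residual)]
```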
There should be a PreLN at the beginning of the blocks if you use Squeezeformer. This after_norm would be a duplicate, because there is a PostLN at the end of every Squeezeformer block.
As Squeezeformer always employs PostLN rather than PreLN, there's no option to choose PreLN.
Therefore, to avoid such a duplication problem, I inserted an assert to make sure normalize_before is False in SqueezeformerEncoder.
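A minimal sketch of that guard (the class name and constructor below are simplified stand-ins for the PR's SqueezeformerEncoder, not its real signature):

```python
class SqueezeformerEncoderSketch:
    """Simplified stand-in showing only the normalize_before guard."""

    def __init__(self, normalize_before: bool = False):
        # Squeezeformer blocks always end with PostLN, so an extra
        # PreLN-style after_norm would be redundant; reject it up front.
        assert normalize_before is False, (
            "Squeezeformer uses PostLN; normalize_before must be False")
        self.normalize_before = normalize_before
```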
How should the values of time_reduce_idx and time_recover_idx be set?
For example, let … This convention, setting …
Great job! Are there any experimental results of SqueezeFormer on WeNet? I'm also curious about the performance compared to Conformer.
I think it is tricky to insert the time reduction and upsampling in BaseEncoder, and it becomes even trickier when we take streaming, JIT export, and ONNX support into consideration. It would be better to decouple it from BaseEncoder, which would avoid unintended side effects.
That's right. As I'm working on other urgent projects now, I'll try to decouple it from BaseEncoder later.
I'm trying to get WER results on LibriSpeech. Maybe I can share some results in a few days if no more errors are found.
Thanks for the contribution. #1447 is taken since it gives a complete implementation and thorough experimental results.
Squeezeformer applies a 2x time reduction to get smaller FLOPs. I wonder whether the performance of Squeezeformer will be much worse than Conformer's when the speech in the audio is fast?
Maybe you are right. The authors haven't done an ablation study on this. In our experiments, the 'squeeze' operation showed no degradation on public datasets. Most of the time, audio frames greatly outnumber the actual words or the CTC path length. You could decrease the downsampling factor, or set the reduce and recover idx to None, to get better performance on fast audio.
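To make the frames-vs-labels argument concrete, here is an illustrative, stdlib-only 2x time reduction that averages adjacent frame pairs (a simplified stand-in for whatever reduction layer the PR actually uses):

```python
def reduce_time_2x(frames):
    # Hypothetical illustration: halve the time axis by averaging each
    # pair of adjacent frames; an odd trailing frame is kept as-is.
    out = []
    for i in range(0, len(frames), 2):
        pair = frames[i:i + 2]
        out.append(sum(pair) / len(pair))
    return out
```

Even after halving, a typical utterance usually still has far more frames than CTC labels, which is why the squeeze tends to cost little accuracy at normal speech rates.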
Closing this PR and leaving it as a reference.
This PR contains my personal implementation of Squeezeformer for the WeNet encoder structure.
(Original paper: https://arxiv.org/abs/2206.00888)
(Original code: https://github.com/kssteven418/Squeezeformer)
Added features
Experimental validation (on LibriSpeech)
Squeezeformer can be seen as an extension of the Conformer structure.
Thus it could be implemented by modifying the ConformerEncoder class, but I added a separate SqueezeformerEncoder class instead to avoid confusion.
On the other hand, the 2x time reduce & recover logic is inserted in the middle of the forward functions of the BaseEncoder class, which might cause unintended side effects.
(Likewise, an option not to use GLU is inserted into the existing code of convolution.py.)
Special care is needed when reviewing these parts.