
Squeezeformer implementation #1431

Closed · wants to merge 9 commits

Conversation

swigls
Contributor

@swigls commented Sep 5, 2022

This PR contains my personal implementation of Squeezeformer for the WeNet encoder structure.
(Original paper: https://arxiv.org/abs/2206.00888)
(Original code: https://github.com/kssteven418/Squeezeformer)

  • Added features

    • SqueezeformerEncoder / SqueezeformerEncoderLayer
      • 2x Time reduce & recover logic (in forward functions of BaseEncoder)
        • TimeReduction2 layer (in subsampling.py)
      • Transformer-style block (Att.->Feedforward->Conv->Feedforward)
      • Scale & Bias layer (in place of pre layer-norm)
      • Option: not to use GLU (in convolution.py)
      • Depthwise conv2d input layers (in subsampling.py)
  • Experimental validation (on LibriSpeech)

    • Batch training
    • Batch inference (both for full-utterance & chunk-wise)
    • Streaming inference (e.g., JIT)
    • Configuration file and WER results

Squeezeformer can be seen as an extension of the Conformer structure.
It could therefore be implemented by modifying the ConformerEncoder class, but I added a separate SqueezeformerEncoder class instead to avoid confusion.

On the other hand, the 2x time reduce & recover logic is inserted in the middle of the forward functions of the BaseEncoder class, which might cause unintended side effects.
(Likewise, the option not to use GLU is inserted into the existing scripts of convolution.py.)
These parts need special care during review.
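To make the reduce & recover idea concrete, here is a minimal, self-contained sketch of what a 2x time reduction followed by a 2x recovery with a residual connection could look like. All names (TimeReduction2, recover) and layer details are illustrative assumptions, not the exact WeNet code:

```python
import torch
import torch.nn as nn


class TimeReduction2(nn.Module):
    """Hypothetical 2x time-reduction layer: a strided depthwise conv
    halves the number of frames. Illustrative, not the WeNet API."""

    def __init__(self, dim: int):
        super().__init__()
        self.dw_conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2,
                                 padding=1, groups=dim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        # xs: (B, T, D) -> (B, ceil(T/2), D)
        return self.dw_conv(xs.transpose(1, 2)).transpose(1, 2)


def recover(xs: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    """2x time recovery: repeat each frame twice, trim to the saved
    length, and add back the pre-reduction residual."""
    xs = torch.repeat_interleave(xs, 2, dim=1)[:, :residual.size(1)]
    return xs + residual


xs = torch.randn(2, 10, 8)
residual = xs                 # saved before the reduction
xs = TimeReduction2(8)(xs)    # (2, 5, 8)
xs = recover(xs, residual)    # back to (2, 10, 8)
```

The real implementation additionally has to shorten and re-extend the chunk masks so that attention over the reduced sequence stays consistent.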

@robin1001
Collaborator

Great job, I will go through the details as soon as I can.

if time_recover_idx is not None
else None
)
if len(time_reduce_idx) > 0:
Contributor

This will raise an error if time_reduce_idx is None.

@@ -79,14 +81,15 @@ def __init__(self,
self.norm = nn.LayerNorm(channels)
Contributor

The argument of BatchNorm1d or LayerNorm should probably be channels if use_glu else 2 * channels.
In addition, the groups argument of the depthwise conv should probably be channels if use_glu else 2 * channels too.
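The reviewer's point is a channel-bookkeeping issue: the first pointwise conv outputs 2 * channels, and nn.GLU halves that back to channels; without GLU, the depthwise conv and the norm must be sized for 2 * channels. A minimal sketch (module names and shapes are illustrative, not the exact convolution.py code):

```python
import torch
import torch.nn as nn

channels, use_glu = 8, False
inner = channels if use_glu else 2 * channels

pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
activation = nn.GLU(dim=1) if use_glu else nn.Identity()
# groups must match the actual channel count after the activation
depthwise_conv = nn.Conv1d(inner, inner, kernel_size=3, padding=1,
                           groups=inner)
norm = nn.LayerNorm(inner)

x = torch.randn(2, channels, 10)     # (B, C, T)
y = depthwise_conv(activation(pointwise_conv1(x)))
y = norm(y.transpose(1, 2))          # LayerNorm over the channel dim
print(y.shape)                       # torch.Size([2, 10, 16])
```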

@swigls
Contributor Author

swigls commented Sep 9, 2022

Thanks for the reviews, TeaPoly. I've checked them and fixed the corresponding codes.

Additionally, I found a mistake: use_glu was set to True in the SqueezeformerEncoder in the last commit.
I have fixed this part as well.

for i, recover_layer in enumerate(self.time_recover_layers):
    if time_reduce_level == i:
        xs = recover_layer(xs)
        xs += residual_xs  # (B,T,D)
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
    xs = self.after_norm(xs)
Contributor

This after_norm only makes sense together with a PreLN at the beginning of the blocks. For Squeezeformer it would be redundant, because there is already a PostLN at the end of every Squeezeformer block.

Contributor Author

As Squeezeformer always employs PostLN rather than PreLN, there is no option to choose PreLN.
Therefore, to avoid this duplication problem, I inserted an assertion to make sure normalize_before is False in SqueezeformerEncoder.
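The guard described above could look like the following minimal sketch (the constructor is reduced to the one relevant argument; the real class takes many more):

```python
class SqueezeformerEncoder:
    """Sketch: Squeezeformer blocks end with PostLN, so a trailing
    after_norm driven by normalize_before would apply LayerNorm twice."""

    def __init__(self, normalize_before: bool = False):
        assert normalize_before is False, (
            "Squeezeformer always uses PostLN; "
            "normalize_before must be False")
        self.normalize_before = normalize_before


enc = SqueezeformerEncoder()  # ok: defaults to PostLN
# SqueezeformerEncoder(normalize_before=True) raises AssertionError
```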

@uloveqian2021
Contributor

Thanks for the reviews, TeaPoly. I've checked them and fixed the corresponding codes.

Additionally, I found a mistake that use_glu was set to be True in the SqueezeformerEncoder of the last commit. I also modified this part.

How to set the values of time_reduce_idx and time_recover_idx?

@swigls
Contributor Author

swigls commented Sep 13, 2022

How to set the values of time_reduce_idx and time_recover_idx?
time_reduce_idx and time_recover_idx should be set to lists of (encoder layer) indices at which time reduction or recovery happens.

For example, let time_reduce_idx=[2,5] and time_recover_idx=[8,11] where num_blocks=12.
It means on top of basic conv2d-subsampling (e.g., conv2d with 4x subsampling), additional 2x time reductions are done at the 3rd and 6th encoder blocks. Likewise, 2x time recoveries with residual connections are done at the 9th and last blocks.
In this case, the 1st, 2nd, and the last blocks are processed with a 40 ms stride while the 3rd, 4th, 5th, 9th, 10th, and 11th blocks are processed with an 80 ms stride. The 6th, 7th, and 8th blocks are processed with a 160 ms stride.

This convention, setting time_reduce_idx and time_recover_idx to be lists of indexes, came from the original code (https://github.com/kssteven418/Squeezeformer).
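The stride pattern in the example above can be computed mechanically. The helper below is hypothetical (not part of the PR); it assumes each reduction halves the frame rate and each recovery restores it, starting from the 40 ms stride produced by 4x conv2d subsampling of 10 ms frames:

```python
def block_strides(num_blocks, time_reduce_idx, time_recover_idx,
                  base_stride_ms=40):
    """Return the frame stride (ms) seen by each encoder block,
    given the lists of 2x reduce / recover block indices."""
    strides, stride = [], base_stride_ms
    for i in range(num_blocks):
        if i in time_reduce_idx:
            stride *= 2   # 2x time reduction at this block
        if i in time_recover_idx:
            stride //= 2  # 2x time recovery at this block
        strides.append(stride)
    return strides


print(block_strides(12, [2, 5], [8, 11]))
# [40, 40, 80, 80, 80, 160, 160, 160, 80, 80, 80, 40]
```

The output reproduces the description above: blocks 1, 2, and 12 run at 40 ms, blocks 3-5 and 9-11 at 80 ms, and blocks 6-8 at 160 ms.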

@robin1001
Collaborator

Great job! Any experimental result of SqueezeFormer on wenet? And I'm curious about the performance compared to Conformer.

@robin1001
Collaborator

robin1001 commented Sep 15, 2022

On the other hand, 2x Time reduce & recover logic code is inserted in the middle of forward functions in BaseEncoder class, which might cause unintended side effects.

I think it is tricky if we insert the time reduction and upsampling in BaseEncoder, and it becomes even trickier once we take streaming, JIT export, and ONNX support into consideration. It would be better to decouple it from BaseEncoder, which also avoids unintended side effects.

@robin1001
Collaborator

@yygle is also working on SqueezeFormer, and he has made great progress, please see #1446. We can work together.

@swigls
Contributor Author

swigls commented Sep 15, 2022

On the other hand, 2x Time reduce & recover logic code is inserted in the middle of forward functions in BaseEncoder class, which might cause unintended side effects.

I think it is tricky if we insert the time reduction and upsampling in BaseEncoder, and it becomes even trickier once we take streaming, JIT export, and ONNX support into consideration. It would be better to decouple it from BaseEncoder, which also avoids unintended side effects.

That's right. As I'm working on other urgent projects now, I'll try to decouple it from BaseEncoder later.

@swigls
Contributor Author

swigls commented Sep 15, 2022

Great job! Any experimental result of SqueezeFormer on wenet? And I'm curious about the performance compared to Conformer.

I'm trying to get WER results on LibriSpeech. Maybe I can share some results in a few days if no more errors are found.

@robin1001
Collaborator

Thanks for the contribution. #1447 is taken since it provides a complete implementation and thorough experimental results.

@heyuandeng

Squeezeformer applies 2x time reduction to get smaller FLOPs. I wonder whether the performance of Squeezeformer will be much worse than Conformer's when the speech is fast?

@yygle
Contributor

yygle commented Nov 10, 2022

Maybe you are right. The authors haven't done an ablation study on this. In our experiments, the 'squeeze' operation shows no degradation on public datasets. Most of the time, there are far more audio frames than actual words or CTC path length. You could decrease the downsampling factor, or set the reduce and recover idx to None, to get better performance on fast audio.

@xingchensong
Member

close this PR and leave it as a reference

7 participants