From ee73151afb4139aff5bf0c86d17f671ff9ace68a Mon Sep 17 00:00:00 2001 From: Lucky Wong Date: Mon, 18 Sep 2023 19:32:54 +0800 Subject: [PATCH] Add E-Branchformer module (#2013) * Add E-Branchformer module * update result --- examples/aishell/s0/README.md | 17 + .../aishell/s0/conf/train_ebranchformer.yaml | 88 ++++ wenet/branchformer/cgmlp.py | 2 + wenet/e_branchformer/encoder.py | 389 ++++++++++++++++++ wenet/e_branchformer/encoder_layer.py | 177 ++++++++ wenet/utils/init_model.py | 5 + 6 files changed, 678 insertions(+) create mode 100644 examples/aishell/s0/conf/train_ebranchformer.yaml create mode 100644 wenet/e_branchformer/encoder.py create mode 100644 wenet/e_branchformer/encoder_layer.py diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index ae5532fd2..85ba82b33 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -203,4 +203,21 @@ | attention rescoring | 4.81 | | LM + attention rescoring | 4.46 | +## E-Branchformer Result +* Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb +* * Model info: + * Model Params: 47,570,132 + * Num Encoder Layer: 17 + * CNN Kernel Size: 31 +* Training info: lr 0.001, weight_decay: 0.000001, batch size 16, 4 gpu, acc_grad 1, 240 epochs +* Decoding info: ctc_weight 0.3, average_num 30 +* Git hash: 89962d1dcae18dd3a281782a40e74dd2721ae8fe + +| decoding mode | CER | +| ---------------------- | ---- | +| attention decoder | 4.73 | +| ctc greedy search | 4.77 | +| ctc prefix beam search | 4.77 | +| attention rescoring | 4.39 | +| LM + attention rescoring | 4.22 | diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml new file mode 100644 index 000000000..edc952295 --- /dev/null +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -0,0 +1,88 @@ +# network architecture +# encoder related +encoder: e_branchformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 8 + linear_units: 1024 # the number of units of position-wise feed forward + num_blocks: 17 # the number of encoder blocks + cgmlp_linear_units: 1024 + cgmlp_conv_kernel: 31 + use_linear_after_conv: false + gate_activation: identity + merge_conv_kernel: 31 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + activation_type: 'swish' + causal: false + pos_enc_layer_type: 'rel_pos' + attention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + spec_sub: true + spec_sub_conf: + num_t_sub: 3 + max_t: 30 + spec_trim: false + spec_trim_conf: + max_t: 50 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or 
dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 1 +max_epoch: 240 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.000001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 35000 diff --git a/wenet/branchformer/cgmlp.py b/wenet/branchformer/cgmlp.py index eb9b41559..190981d88 100644 --- a/wenet/branchformer/cgmlp.py +++ b/wenet/branchformer/cgmlp.py @@ -143,6 +143,7 @@ def __init__( dropout_rate: float, use_linear_after_conv: bool, gate_activation: str, + causal: bool = True, ): super().__init__() @@ -155,6 +156,7 @@ def __init__( dropout_rate=dropout_rate, use_linear_after_conv=use_linear_after_conv, gate_activation=gate_activation, + causal=causal, ) self.channel_proj2 = torch.nn.Linear(linear_units // 2, size) diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py new file mode 100644 index 000000000..a590e3c37 --- /dev/null +++ b/wenet/e_branchformer/encoder.py @@ -0,0 +1,389 @@ +# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University) +# 2023 Voicecomm Inc (Kai Li) +# 2023 Lucky Wong +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""Encoder definition.""" + +import torch +import torch.nn as nn +from typing import List, Optional, Tuple, Union + +from wenet.transformer.attention import ( + MultiHeadedAttention, + RelPositionMultiHeadedAttention, +) +from wenet.transformer.embedding import ( + RelPositionalEncoding, + PositionalEncoding, + NoPositionalEncoding, +) +from wenet.transformer.subsampling import ( + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, +) + +from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer +from wenet.branchformer.cgmlp import ConvolutionalGatingMLP +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.utils.mask import make_pad_mask +from wenet.utils.mask import add_optional_chunk_mask +from wenet.utils.common import get_activation + + +class EBranchformerEncoder(nn.Module): + """E-Branchformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + attention_layer_type: str = "rel_selfattn", + pos_enc_layer_type: str = "rel_pos", + activation_type: str = "swish", + cgmlp_linear_units: int = 2048, + cgmlp_conv_kernel: int = 31, + use_linear_after_conv: bool = False, + gate_activation: str = "identity", + merge_method: str = "concat", + num_blocks: int = 12, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + padding_idx: int = -1, + stochastic_depth_rate: Union[float, List[float]] = 0.0, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + causal: bool = False, + merge_conv_kernel: int = 3, + use_ffn: bool = True, + macaron_style: bool = True, + ): + 
super().__init__() + activation = get_activation(activation_type) + self._output_size = output_size + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "no_pos": + pos_enc_class = NoPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert attention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling4( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + if attention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif attention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + attention_layer_type) + + cgmlp_layer = ConvolutionalGatingMLP + cgmlp_layer_args = ( + output_size, + cgmlp_linear_units, + cgmlp_conv_kernel, + dropout_rate, + use_linear_after_conv, + gate_activation, + causal + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + + if isinstance(stochastic_depth_rate, float): + stochastic_depth_rate = [stochastic_depth_rate] * num_blocks + if len(stochastic_depth_rate) != num_blocks: + raise ValueError( + f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " + f"should be equal to num_blocks ({num_blocks})" + ) + + self.encoders = torch.nn.ModuleList([ + EBranchformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + cgmlp_layer(*cgmlp_layer_args), + positionwise_layer(*positionwise_layer_args) if use_ffn else None, + positionwise_layer(*positionwise_layer_args) + if use_ffn and macaron_style else None, + dropout_rate, + merge_conv_kernel=merge_conv_kernel, + causal=causal, + stochastic_depth_rate=stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) + ]) + + self.after_norm = nn.LayerNorm(output_size) + self.static_chunk_size = static_chunk_size + self.global_cmvn = global_cmvn + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + ilens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + xs (torch.Tensor): Input tensor (B, T, D). 
+ ilens (torch.Tensor): Input length (#batch). + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + + T = xs.size(1) + masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + for layer in self.encoders: + xs, chunk_masks, _ , _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
+ + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
+ Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, cnn_cache) = self.forward_chunk( + chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + return ys, masks diff --git a/wenet/e_branchformer/encoder_layer.py b/wenet/e_branchformer/encoder_layer.py new file mode 100644 index 000000000..08b0c4d92 --- /dev/null +++ b/wenet/e_branchformer/encoder_layer.py @@ -0,0 +1,177 @@ +# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University) +# 2023 Voicecomm Inc (Kai Li) +# 2023 Lucky Wong +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""EBranchformerEncoderLayer definition.""" + +import torch +import torch.nn as nn +from typing import Optional, Tuple + + +class EBranchformerEncoderLayer(torch.nn.Module): + """E-Branchformer encoder layer module. 
+ + Args: + size (int): model dimension + attn: standard self-attention or efficient attention + cgmlp: ConvolutionalGatingMLP + feed_forward: feed-forward module, optional + feed_forward: macaron-style feed-forward module, optional + dropout_rate (float): dropout probability + merge_conv_kernel (int): kernel size of the depth-wise conv in merge module + """ + + def __init__( + self, + size: int, + attn: torch.nn.Module, + cgmlp: torch.nn.Module, + feed_forward: Optional[torch.nn.Module], + feed_forward_macaron: Optional[torch.nn.Module], + dropout_rate: float, + merge_conv_kernel: int = 3, + causal: bool = True, + stochastic_depth_rate=0.0, + ): + super().__init__() + + self.size = size + self.attn = attn + self.cgmlp = cgmlp + + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.ff_scale = 1.0 + if self.feed_forward is not None: + self.norm_ff = nn.LayerNorm(size) + if self.feed_forward_macaron is not None: + self.ff_scale = 0.5 + self.norm_ff_macaron = nn.LayerNorm(size) + + self.norm_mha = nn.LayerNorm(size) # for the MHA module + self.norm_mlp = nn.LayerNorm(size) # for the MLP module + # for the final output of the block + self.norm_final = nn.LayerNorm(size) + + self.dropout = torch.nn.Dropout(dropout_rate) + + if causal: + padding = 0 + self.lorder = merge_conv_kernel - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (merge_conv_kernel - 1) % 2 == 0 + padding = (merge_conv_kernel - 1) // 2 + self.lorder = 0 + self.depthwise_conv_fusion = torch.nn.Conv1d( + size + size, + size + size, + kernel_size=merge_conv_kernel, + stride=1, + padding=padding, + groups=size + size, + bias=True, + ) + self.merge_proj = torch.nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time, time). + pos_emb (torch.Tensor): positional encoding, must not be None + for BranchformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in cgmlp layer + (#batch=1, size, cache_t2) + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time. + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + stoch_layer_coeff = 1.0 + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
+ if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + return x, mask, att_cache, cnn_cache + + if self.feed_forward_macaron is not None: + residual = x + x = self.norm_ff_macaron(x) + x = residual + stoch_layer_coeff * self.ff_scale * self.dropout( + self.feed_forward_macaron(x) + ) + + # Two branches + x1 = x + x2 = x + + # Branch 1: multi-headed attention module + x1 = self.norm_mha(x1) + x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, att_cache) + x1 = self.dropout(x_att) + + # Branch 2: convolutional gating mlp + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + x2 = self.norm_mlp(x2) + x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache) + x2 = self.dropout(x2) + + # Merge two branches + x_concat = torch.cat([x1, x2], dim=-1) + x_tmp = x_concat.transpose(1, 2) + if self.lorder > 0: + x_tmp = nn.functional.pad(x_tmp, (self.lorder, 0), "constant", 0.0) + assert x_tmp.size(2) > self.lorder + x_tmp = self.depthwise_conv_fusion(x_tmp) + x_tmp = x_tmp.transpose(1, 2) + x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x_concat + x_tmp)) + + if self.feed_forward is not None: + # feed forward module + residual = x + x = self.norm_ff(x) + x = residual + stoch_layer_coeff * self.ff_scale * self.dropout( + self.feed_forward(x) + ) + + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index f01e03469..886bca77c 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -23,6 +23,7 @@ from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder from wenet.branchformer.encoder import BranchformerEncoder +from wenet.e_branchformer.encoder import EBranchformerEncoder from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder from wenet.paraformer.paraformer import Paraformer @@ -65,6 +66,10 @@ def init_model(configs): encoder = BranchformerEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'e_branchformer': + encoder = EBranchformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) else: encoder = TransformerEncoder(input_dim, global_cmvn=global_cmvn,
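Note (not part of the patch): below is a minimal smoke-test sketch for the new encoder. The constructor arguments mirror `conf/train_ebranchformer.yaml`; `input_size=80` is an assumption matching `fbank_conf.num_mel_bins`, and the tensor shapes follow the `forward()` contract documented in `wenet/e_branchformer/encoder.py`.

```python
# Hypothetical standalone check (not part of the patch): build the encoder with
# the hyper-parameters from examples/aishell/s0/conf/train_ebranchformer.yaml
# and run one forward pass on random fbank-like features.
import torch
from wenet.e_branchformer.encoder import EBranchformerEncoder

encoder = EBranchformerEncoder(
    input_size=80,                  # assumed to equal fbank_conf.num_mel_bins
    output_size=256,
    attention_heads=8,
    linear_units=1024,
    num_blocks=17,
    cgmlp_linear_units=1024,
    cgmlp_conv_kernel=31,
    use_linear_after_conv=False,
    gate_activation="identity",
    merge_conv_kernel=31,
    dropout_rate=0.1,
    causal=False,                   # non-streaming setup, as in the yaml
)
encoder.eval()

xs = torch.randn(2, 200, 80)        # (B, T, D): two padded fbank utterances
xs_lens = torch.tensor([200, 160])  # valid frame counts before padding
with torch.no_grad():
    ys, masks = encoder(xs, xs_lens)
print(ys.shape, masks.shape)        # roughly (2, 49, 256) and (2, 1, 49) after conv2d subsampling
```

Streaming inference through `forward_chunk_by_chunk` additionally requires a model trained with `static_chunk_size > 0` or `use_dynamic_chunk=True`, as asserted in the encoder code; with the non-streaming defaults used in this sketch it is intentionally unavailable.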