diff --git a/examples/aishell/s0/conf/train_u2++_branchformer.yaml b/examples/aishell/s0/conf/train_u2++_branchformer.yaml
index 8702fbeb4..37fda9115 100644
--- a/examples/aishell/s0/conf/train_u2++_branchformer.yaml
+++ b/examples/aishell/s0/conf/train_u2++_branchformer.yaml
@@ -5,7 +5,7 @@ encoder_conf:
     output_size: 256
     use_attn: true
     attention_heads: 4
-    attention_layer_type: rel_selfattn
+    selfattention_layer_type: rel_selfattn
     pos_enc_layer_type: rel_pos
     use_cgmlp: true
     cgmlp_linear_units: 2048
diff --git a/examples/librispeech/s0/conf/train_u2++_branchformer.yaml b/examples/librispeech/s0/conf/train_u2++_branchformer.yaml
index 3b7961444..f643831ca 100644
--- a/examples/librispeech/s0/conf/train_u2++_branchformer.yaml
+++ b/examples/librispeech/s0/conf/train_u2++_branchformer.yaml
@@ -5,7 +5,7 @@ encoder_conf:
     output_size: 256
     use_attn: true
     attention_heads: 4
-    attention_layer_type: rel_selfattn
+    selfattention_layer_type: rel_selfattn
     pos_enc_layer_type: rel_pos
     use_cgmlp: true
     cgmlp_linear_units: 2048
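The renamed YAML key mirrors the constructor argument of the refactored encoder below: `selfattention_layer_type` replaces `attention_layer_type`, so configs that still use the old name would no longer match the keyword argument. As a quick cross-check, here is a minimal instantiation sketch (not part of the patch), assuming the patched wenet package is importable; the values follow the aishell encoder_conf above and the input_size of 80 is just a typical fbank dimension chosen for illustration.

# Sketch only: the renamed option maps 1:1 onto the BranchformerEncoder
# keyword argument introduced in wenet/branchformer/encoder.py below.
import torch
from wenet.branchformer.encoder import BranchformerEncoder

encoder = BranchformerEncoder(
    input_size=80,                            # assumed fbank dimension
    output_size=256,
    use_attn=True,
    attention_heads=4,
    selfattention_layer_type="rel_selfattn",  # was: attention_layer_type
    pos_enc_layer_type="rel_pos",
    use_cgmlp=True,
    cgmlp_linear_units=2048,
)
xs = torch.randn(1, 100, 80)          # (batch, frames, feature_dim)
xs_lens = torch.tensor([100])         # valid frames per utterance
out, out_mask = encoder(xs, xs_lens)  # forward() is inherited from BaseEncoder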
diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py
index 7d00b2a70..1c67a91d2 100644
--- a/wenet/branchformer/encoder.py
+++ b/wenet/branchformer/encoder.py
@@ -16,21 +16,17 @@
 """Encoder definition."""
 import torch
-import torch.nn as nn
-from typing import List, Optional, Tuple, Union
+
+from typing import List, Optional, Union
 
 from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
 from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
-from wenet.utils.mask import make_pad_mask
-from wenet.utils.mask import add_optional_chunk_mask
+from wenet.transformer.encoder import BaseEncoder
 from wenet.utils.class_utils import (
-    WENET_ATTENTION_CLASSES,
-    WENET_EMB_CLASSES,
-    WENET_SUBSAMPLE_CLASSES,
-)
+    WENET_ATTENTION_CLASSES, )
 
 
-class BranchformerEncoder(nn.Module):
+class BranchformerEncoder(BaseEncoder):
     """Branchformer encoder module."""
 
     def __init__(
@@ -39,7 +35,7 @@ def __init__(
         output_size: int = 256,
         use_attn: bool = True,
         attention_heads: int = 4,
-        attention_layer_type: str = "rel_selfattn",
+        selfattention_layer_type: str = "rel_selfattn",
         pos_enc_layer_type: str = "rel_pos",
         use_cgmlp: bool = True,
         cgmlp_linear_units: int = 2048,
@@ -53,30 +49,41 @@ def __init__(
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
         attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        padding_idx: int = -1,
+        input_layer: str = "conv2d",
         stochastic_depth_rate: Union[float, List[float]] = 0.0,
         static_chunk_size: int = 0,
         use_dynamic_chunk: bool = False,
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
         causal: bool = False,
+        query_bias: bool = True,
+        key_bias: bool = True,
+        value_bias: bool = True,
+        gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
     ):
-        super().__init__()
-        self._output_size = output_size
-
-        self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
-            input_size,
-            output_size,
-            dropout_rate,
-            WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
-                                                  positional_dropout_rate),
-        )
+        super().__init__(input_size, output_size, attention_heads,
+                         cgmlp_linear_units, num_blocks, dropout_rate,
+                         positional_dropout_rate, attention_dropout_rate,
+                         input_layer, pos_enc_layer_type, True,
+                         static_chunk_size, use_dynamic_chunk, global_cmvn,
+                         use_dynamic_left_chunk, gradient_checkpointing,
+                         use_sdpa, layer_norm_type, norm_eps)
 
         encoder_selfattn_layer_args = (
             attention_heads,
             output_size,
             attention_dropout_rate,
+            query_bias,
+            key_bias,
+            value_bias,
+            use_sdpa,
+            n_kv_head,
+            head_dim,
         )
 
         cgmlp_layer = ConvolutionalGatingMLP
@@ -87,6 +94,7 @@ def __init__(
             dropout_rate,
             use_linear_after_conv,
             gate_activation,
+            causal,
         )
 
         if isinstance(stochastic_depth_rate, float):
@@ -110,221 +118,64 @@ def __init__(
                 f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
                 f"should be equal to num_blocks ({num_blocks})")
 
-        self.encoders = torch.nn.ModuleList([
-            BranchformerEncoderLayer(
-                output_size, WENET_ATTENTION_CLASSES[attention_layer_type](
-                    *encoder_selfattn_layer_args) if use_attn else None,
-                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
-                dropout_rate, merge_method, cgmlp_weight[lnum],
-                attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum])
-            for lnum in range(num_blocks)
-        ])
-        self.after_norm = nn.LayerNorm(output_size)
-        self.static_chunk_size = static_chunk_size
-        self.global_cmvn = global_cmvn
-        self.use_dynamic_chunk = use_dynamic_chunk
-        self.use_dynamic_left_chunk = use_dynamic_left_chunk
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs: torch.Tensor,
-        ilens: torch.Tensor,
-        decoding_chunk_size: int = 0,
-        num_decoding_left_chunks: int = -1,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Calculate forward propagation.
-
-        Args:
-            xs (torch.Tensor): Input tensor (B, T, D).
-            ilens (torch.Tensor): Input length (#batch).
-            decoding_chunk_size: decoding chunk size for dynamic chunk
-                0: default for training, use random dynamic chunk.
-                <0: for decoding, use full chunk.
-                >0: for decoding, use fixed chunk size as set.
-            num_decoding_left_chunks: number of left chunks, this is for decoding,
-            the chunk size is decoding_chunk_size.
-                >=0: use num_decoding_left_chunks
-                <0: use all left chunks
-
-        Returns:
-            encoder output tensor xs, and subsampled masks
-            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
-            masks: torch.Tensor batch padding mask after subsample
-                (B, 1, T' ~= T/subsample_rate)
-        """
-
-        T = xs.size(1)
-        masks = ~make_pad_mask(ilens, T).unsqueeze(1)  # (B, 1, T)
-        if self.global_cmvn is not None:
-            xs = self.global_cmvn(xs)
-        xs, pos_emb, masks = self.embed(xs, masks)
-        mask_pad = masks  # (B, 1, T/subsample_rate)
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
-        for layer in self.encoders:
-            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
-
-        xs = self.after_norm(xs)
-        # Here we assume the mask is not changed in encoder layers, so just
-        # return the masks before encoder layers, and the masks will be used
-        # for cross attention with decoder later
-        return xs, masks
-
-    def forward_chunk(
-        self,
-        xs: torch.Tensor,
-        offset: int,
-        required_cache_size: int,
-        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """ Forward just one chunk
-
-        Args:
-            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
-                where `time == (chunk_size - 1) * subsample_rate + \
-                        subsample.right_context + 1`
-            offset (int): current offset in encoder output time stamp
-            required_cache_size (int): cache size required for next chunk
-                compuation
-                >=0: actual cache size
-                <0: means all history cache is required
-            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
-                transformer/conformer attention, with shape
-                (elayers, head, cache_t1, d_k * 2), where
-                `head * d_k == hidden-dim` and
-                `cache_t1 == chunk_size * num_decoding_left_chunks`.
-            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
-                (elayers, b=1, hidden-dim, cache_t2), where
-                `cache_t2 == cnn.lorder - 1`
-
-        Returns:
-            torch.Tensor: output of current input xs,
-                with shape (b=1, chunk_size, hidden-dim).
-            torch.Tensor: new attention cache required for next chunk, with
-                dynamic shape (elayers, head, ?, d_k * 2)
-                depending on required_cache_size.
-            torch.Tensor: new conformer cnn cache required for next chunk, with
-                same shape as the original cnn_cache.
-
-        """
-        assert xs.size(0) == 1
-        # tmp_masks is just for interface compatibility
-        tmp_masks = torch.ones(1,
-                               xs.size(1),
-                               device=xs.device,
-                               dtype=torch.bool)
-        tmp_masks = tmp_masks.unsqueeze(1)
-        if self.global_cmvn is not None:
-            xs = self.global_cmvn(xs)
-        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
-        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
-        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
-        chunk_size = xs.size(1)
-        attention_key_size = cache_t1 + chunk_size
-        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
-                                               size=attention_key_size)
-        if required_cache_size < 0:
-            next_cache_start = 0
-        elif required_cache_size == 0:
-            next_cache_start = attention_key_size
-        else:
-            next_cache_start = max(attention_key_size - required_cache_size, 0)
-        r_att_cache = []
-        r_cnn_cache = []
-        for i, layer in enumerate(self.encoders):
-            # NOTE(xcsong): Before layer.forward
-            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
-            #   shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
-            xs, _, new_att_cache, new_cnn_cache = layer(
-                xs,
-                att_mask,
-                pos_emb,
-                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
-            # NOTE(xcsong): After layer.forward
-            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
-            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
-            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
-            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
-
-        xs = self.after_norm(xs)
-
-        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
-        #   ? may be larger than cache_t1, it depends on required_cache_size
-        r_att_cache = torch.cat(r_att_cache, dim=0)
-        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
-        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
-
-        return (xs, r_att_cache, r_cnn_cache)
-
-    def forward_chunk_by_chunk(
-        self,
-        xs: torch.Tensor,
-        decoding_chunk_size: int,
-        num_decoding_left_chunks: int = -1,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Forward input chunk by chunk with chunk_size like a streaming
-            fashion
-
-        Here we should pay special attention to computation cache in the
-        streaming style forward chunk by chunk. Three things should be taken
-        into account for computation in the current network:
-            1. transformer/conformer encoder layers output cache
-            2. convolution in conformer
-            3. convolution in subsampling
-
-        However, we don't implement subsampling cache for:
-            1. We can control subsampling module to output the right result by
-               overlapping input instead of cache left context, even though it
-               wastes some computation, but subsampling only takes a very
-               small fraction of computation in the whole model.
-            2. Typically, there are several covolution layers with subsampling
-               in subsampling module, it is tricky and complicated to do cache
-               with different convolution layers with different subsampling
-               rate.
-            3. Currently, nn.Sequential is used to stack all the convolution
-               layers in subsampling, we need to rewrite it to make it work
-               with cache, which is not prefered.
-        Args:
-            xs (torch.Tensor): (1, max_len, dim)
-            chunk_size (int): decoding chunk size
-        """
-        assert decoding_chunk_size > 0
-        # The model is trained by static or dynamic chunk
-        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
-        subsampling = self.embed.subsampling_rate
-        context = self.embed.right_context + 1  # Add current frame
-        stride = subsampling * decoding_chunk_size
-        decoding_window = (decoding_chunk_size - 1) * subsampling + context
-        num_frames = xs.size(1)
-        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
-        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
-        outputs = []
-        offset = 0
-        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-
-        # Feed forward overlap input step by step
-        for cur in range(0, num_frames - context + 1, stride):
-            end = min(cur + decoding_window, num_frames)
-            chunk_xs = xs[:, cur:end, :]
-            (y, att_cache,
-             cnn_cache) = self.forward_chunk(chunk_xs, offset,
-                                             required_cache_size, att_cache,
-                                             cnn_cache)
-            outputs.append(y)
-            offset += y.size(1)
-        ys = torch.cat(outputs, 1)
-        masks = torch.ones((1, 1, ys.size(1)),
-                           device=ys.device,
-                           dtype=torch.bool)
-        return ys, masks
+        self.encoders = LayerDropModuleList(
+            p=stochastic_depth_rate,
+            modules=[
+                BranchformerEncoderLayer(
+                    output_size,
+                    WENET_ATTENTION_CLASSES[selfattention_layer_type](
+                        *encoder_selfattn_layer_args) if use_attn else None,
+                    cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
+                    dropout_rate, merge_method, cgmlp_weight[lnum],
+                    attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum],
+                    gradient_checkpointing) for lnum in range(num_blocks)
+            ])
+
+    @torch.jit.ignore(drop=True)
+    def forward_layers_checkpointed(self, xs: torch.Tensor,
+                                    chunk_masks: torch.Tensor,
+                                    pos_emb: torch.Tensor,
+                                    mask_pad: torch.Tensor) -> torch.Tensor:
+        return self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+
+# modified from: https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py  # noqa
+class LayerDropModuleList(torch.nn.ModuleList):
+    """
+    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
+
+    We refresh the choice of which layers to drop every time we iterate
+    over the LayerDropModuleList instance. During evaluation we always
+    iterate over all layers.
+
+    Usage::
+
+        layers = LayerDropModuleList(p=[0.5, 0.5, 0.5],
+                                     modules=[layer1, layer2, layer3])
+        for layer in layers:  # this might iterate over layers 1 and 3
+            x = layer(x)
+        for layer in layers:  # this might iterate over all layers
+            x = layer(x)
+        for layer in layers:  # this might not iterate over any layers
+            x = layer(x)
+
+    Args:
+        p (List[float]): per-layer probability of dropping that layer
+        modules (iterable, optional): an iterable of modules to add
+
+    Limitations:
+        1. works with DDP when the layers' gradient checkpointing is disabled
+        2. does not work with DDP when the layers' gradient checkpointing is enabled
+        3. works with FSDP
+        4. works with DeepSpeed
+    """
+
+    def __init__(self, p: List[float], modules=None):
+        super().__init__(modules)
+        assert len(p) == len(self)
+        self.p = p
+
+    def __iter__(self):
+        dropout_probs = torch.empty(len(self)).uniform_()
+        for i, m in enumerate(super().__iter__()):
+            if not self.training or (dropout_probs[i] > self.p[i]):
+                yield m
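LayerDropModuleList re-samples which layers to keep on every pass through the list, so the effective encoder depth varies from forward call to forward call during training and is always full depth at evaluation time. A small self-contained sketch of that behaviour, not part of the patch, assuming the patched wenet.branchformer.encoder module is importable:

# Sketch only: exercising the LayerDrop semantics of the class added above.
import torch
from wenet.branchformer.encoder import LayerDropModuleList

blocks = LayerDropModuleList(
    p=[0.1, 0.1, 0.1],  # one drop probability per layer, like stochastic_depth_rate
    modules=[torch.nn.Linear(8, 8) for _ in range(3)])

x = torch.randn(2, 8)

blocks.train()        # training: layer i is skipped when a fresh U(0, 1) draw is <= p[i]
for block in blocks:  # may yield 0, 1, 2 or 3 of the layers on any given pass
    x = block(x)

blocks.eval()         # evaluation: every layer is always yielded
for block in blocks:
    x = block(x)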
diff --git a/wenet/branchformer/encoder_layer.py b/wenet/branchformer/encoder_layer.py
index 9654a2405..0cbd2e6f9 100644
--- a/wenet/branchformer/encoder_layer.py
+++ b/wenet/branchformer/encoder_layer.py
@@ -46,6 +46,7 @@ def __init__(
         cgmlp_weight: float = 0.5,
         attn_branch_drop_rate: float = 0.0,
         stochastic_depth_rate: float = 0.0,
+        gradient_checkpointing: bool = False,
     ):
         super().__init__()
         assert (attn is not None) or (
@@ -105,8 +106,9 @@ def __init__(
                 raise ValueError(f"unknown merge method: {merge_method}")
         else:
             self.merge_proj = torch.nn.Identity()
+        self.gradient_checkpointing = gradient_checkpointing
 
-    def forward(
+    def _forward(
         self,
         x: torch.Tensor,
         mask: torch.Tensor,
@@ -114,40 +116,8 @@ def forward(
         mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        stoch_layer_coeff: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute encoded features.
-
-        Args:
-            x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
-            pos_emb (torch.Tensor): positional encoding, must not be None
-                for BranchformerEncoderLayer.
-            mask_pad (torch.Tensor): batch padding mask used for conv module.
-                (#batch, 1,time), (0, 0, 0) means fake mask.
-            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
-                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
-            cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
-                (#batch=1, size, cache_t2)
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time, time.
-            torch.Tensor: att_cache tensor,
-                (#batch=1, head, cache_t1 + time, d_k * 2).
-            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
-        """
-
-        stoch_layer_coeff = 1.0
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            return x, mask, att_cache, cnn_cache
-
         # Two branches
         x1 = x
         x2 = x
@@ -232,3 +202,42 @@ def forward(
         x = self.norm_final(x)
 
         return x, mask, new_att_cache, new_cnn_cache
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute encoded features.
+
+        Args:
+            x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
+            pos_emb (torch.Tensor): positional encoding, must not be None
+                for BranchformerEncoderLayer.
+            mask_pad (torch.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
+                (#batch=1, size, cache_t2)
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time, time).
+            torch.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+        """
+
+        stoch_layer_coeff = 1.0
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        if self.training:
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+        return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache,
+                             stoch_layer_coeff)
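With this split, dropping a whole layer is handled by LayerDropModuleList in the encoder, while the layer itself only rescales its residual branches during training via the rule quoted in the comment, `x <- x + 1 / (1 - p) * f(x)`. A tiny illustrative sketch, not part of the patch, of why that coefficient keeps the expected residual contribution consistent between training (where a branch survives with probability 1 - p) and inference (where it always runs unscaled):

# Sketch only: the stochastic-depth residual scaling applied by the new
# forward() wrapper above.
import torch

def scaled_residual(x: torch.Tensor, fx: torch.Tensor,
                    stochastic_depth_rate: float, training: bool) -> torch.Tensor:
    # Training: a surviving branch is scaled by 1 / (1 - p), so its expected
    # contribution, (1 - p) * fx / (1 - p), equals the unscaled fx used at inference.
    coeff = 1.0 / (1 - stochastic_depth_rate) if training else 1.0
    return x + coeff * fx

x, fx = torch.randn(2, 4), torch.randn(2, 4)
out_train = scaled_residual(x, fx, stochastic_depth_rate=0.1, training=True)
out_eval = scaled_residual(x, fx, stochastic_depth_rate=0.1, training=False)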