Skip to content

Commit

Permalink
[branchformer] simplified branchformer (#2482)
Browse files Browse the repository at this point in the history
* [transformer] simplified branchformer

* fix yaml

* support MQA, gradient checkpointing, SDPA

* fix gradient checkpoint

* add deepspeed comment in layer dropout

* fix comment
  • Loading branch information
Mddct authored Apr 17, 2024
1 parent 1f0fba4 commit 2b67e6c
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 276 deletions.
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_branchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ encoder_conf:
output_size: 256
use_attn: true
attention_heads: 4
attention_layer_type: rel_selfattn
selfattention_layer_type: rel_selfattn
pos_enc_layer_type: rel_pos
use_cgmlp: true
cgmlp_linear_units: 2048
Expand Down
2 changes: 1 addition & 1 deletion examples/librispeech/s0/conf/train_u2++_branchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ encoder_conf:
output_size: 256
use_attn: true
attention_heads: 4
attention_layer_type: rel_selfattn
selfattention_layer_type: rel_selfattn
pos_enc_layer_type: rel_pos
use_cgmlp: true
cgmlp_linear_units: 2048
Expand Down
331 changes: 91 additions & 240 deletions wenet/branchformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@
"""Encoder definition."""

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union

from typing import List, Optional, Union

from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.class_utils import (
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES,
WENET_SUBSAMPLE_CLASSES,
)
WENET_ATTENTION_CLASSES, )


class BranchformerEncoder(nn.Module):
class BranchformerEncoder(BaseEncoder):
"""Branchformer encoder module."""

def __init__(
Expand All @@ -39,7 +35,7 @@ def __init__(
output_size: int = 256,
use_attn: bool = True,
attention_heads: int = 4,
attention_layer_type: str = "rel_selfattn",
selfattention_layer_type: str = "rel_selfattn",
pos_enc_layer_type: str = "rel_pos",
use_cgmlp: bool = True,
cgmlp_linear_units: int = 2048,
Expand All @@ -53,30 +49,41 @@ def __init__(
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
padding_idx: int = -1,
input_layer: str = "conv2d",
stochastic_depth_rate: Union[float, List[float]] = 0.0,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
causal: bool = False,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
super().__init__()
self._output_size = output_size

self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
positional_dropout_rate),
)
super().__init__(input_size, output_size, attention_heads,
cgmlp_linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, True,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa, layer_norm_type, norm_eps)

encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)

cgmlp_layer = ConvolutionalGatingMLP
Expand All @@ -87,6 +94,7 @@ def __init__(
dropout_rate,
use_linear_after_conv,
gate_activation,
causal,
)

if isinstance(stochastic_depth_rate, float):
Expand All @@ -110,221 +118,64 @@ def __init__(
f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
f"should be equal to num_blocks ({num_blocks})")

self.encoders = torch.nn.ModuleList([
BranchformerEncoderLayer(
output_size, WENET_ATTENTION_CLASSES[attention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate, merge_method, cgmlp_weight[lnum],
attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum])
for lnum in range(num_blocks)
])
self.after_norm = nn.LayerNorm(output_size)
self.static_chunk_size = static_chunk_size
self.global_cmvn = global_cmvn
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk

def output_size(self) -> int:
return self._output_size

def forward(
self,
xs: torch.Tensor,
ilens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
xs (torch.Tensor): Input tensor (B, T, D).
ilens (torch.Tensor): Input length (#batch).
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""

T = xs.size(1)
masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)

xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks

def forward_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
chunk_size = xs.size(1)
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
xs,
att_mask,
pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))

xs = self.after_norm(xs)

# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = torch.cat(r_att_cache, dim=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

return (xs, r_att_cache, r_cnn_cache)

def forward_chunk_by_chunk(
self,
xs: torch.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (torch.Tensor): (1, max_len, dim)
chunk_size (int): decoding chunk size
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks

# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, att_cache,
cnn_cache) = self.forward_chunk(chunk_xs, offset,
required_cache_size, att_cache,
cnn_cache)
outputs.append(y)
offset += y.size(1)
ys = torch.cat(outputs, 1)
masks = torch.ones((1, 1, ys.size(1)),
device=ys.device,
dtype=torch.bool)
return ys, masks
self.encoders = LayerDropModuleList(
p=stochastic_depth_rate,
modules=[
BranchformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate, merge_method, cgmlp_weight[lnum],
attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum],
gradient_checkpointing) for lnum in range(num_blocks)
])

@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
                                chunk_masks: torch.Tensor,
                                pos_emb: torch.Tensor,
                                mask_pad: torch.Tensor) -> torch.Tensor:
    """Run the encoder layer stack when gradient checkpointing is enabled.

    Simply delegates to ``forward_layers``: no ``torch.utils.checkpoint``
    call is made at this level. Presumably checkpointing is applied inside
    each ``BranchformerEncoderLayer`` (the ``gradient_checkpointing`` flag
    is passed to the layers at construction) — confirm against
    ``wenet/branchformer/encoder_layer.py``.

    Decorated with ``@torch.jit.ignore(drop=True)`` so TorchScript export
    drops this method instead of trying to script it.
    """
    return self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)


# modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa
class LayerDropModuleList(torch.nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation we always
    iterate over all layers.
    Usage::
        layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers: # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers: # this might iterate over all layers
            x = layer(x)
        for layer in layers: # this might not iterate over any layers
            x = layer(x)
    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    Limitations:
        1 can work with ddp when layer's gradient checkpoint disabled
        2 can't work with ddp when layer's gradient checkpoint enables
        3 can work with fsdp
        4 can work with deepspeed
    """

    def __init__(self, p: List[float], modules=None):
        super().__init__(modules)
        # One drop probability per layer; lengths must agree.
        assert len(p) == len(self)
        self.p = p

    def __iter__(self):
        # Fresh uniform draws each pass: the dropped subset changes on
        # every iteration over the list.
        keep_scores = torch.empty(len(self)).uniform_()
        for score, drop_p, module in zip(keep_scores, self.p,
                                         super().__iter__()):
            # In eval mode every layer is yielded; in training a layer
            # survives only when its draw exceeds its drop probability.
            if not self.training or score > drop_p:
                yield module
Loading

0 comments on commit 2b67e6c

Please sign in to comment.