refactor(model): keep name pattern of class mapping (#2175)
xingchensong authored Nov 28, 2023
1 parent 4c4878e commit 0df2759
Showing 9 changed files with 40 additions and 50 deletions.
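
In short, this commit replaces the get_activation / get_rnn factory helpers with module-level class-mapping dictionaries (WENET_ACTIVATION_CLASSES, WENET_RNN_CLASSES), matching the naming pattern already used by WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES and WENET_SUBSAMPLE_CLASSES. A minimal sketch of the call-site change, using names taken from the diffs below (the "swish" key is only an illustrative choice):

    # Before: the helper returned an activation *instance*.
    activation = get_activation("swish")

    # After: look up the class in the mapping, then instantiate it explicitly.
    activation = WENET_ACTIVATION_CLASSES["swish"]()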
4 changes: 2 additions & 2 deletions wenet/branchformer/cgmlp.py
@@ -25,7 +25,7 @@
 from typing import Tuple
 import torch
 import torch.nn as nn
-from wenet.utils.class_utils import get_activation
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES


 class ConvolutionalSpatialGatingUnit(torch.nn.Module):
@@ -73,7 +73,7 @@ def __init__(
         if gate_activation == "identity":
             self.act = torch.nn.Identity()
         else:
-            self.act = get_activation(gate_activation)
+            self.act = WENET_ACTIVATION_CLASSES[gate_activation]()

         self.dropout = torch.nn.Dropout(dropout_rate)

4 changes: 2 additions & 2 deletions wenet/e_branchformer/encoder.py
@@ -28,7 +28,7 @@
 from wenet.utils.mask import add_optional_chunk_mask
 from wenet.utils.class_utils import (
     WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES,
-    get_activation,
+    WENET_ACTIVATION_CLASSES,
 )

 class EBranchformerEncoder(nn.Module):
@@ -65,7 +65,7 @@ def __init__(
         macaron_style: bool = True,
     ):
         super().__init__()
-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()
         self._output_size = output_size

         self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
4 changes: 2 additions & 2 deletions wenet/efficient_conformer/encoder.py
@@ -33,7 +33,7 @@
 from wenet.utils.mask import add_optional_chunk_mask
 from wenet.utils.class_utils import (
     WENET_ATTENTION_CLASSES, WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES,
-    get_activation,
+    WENET_ACTIVATION_CLASSES,
 )

 class EfficientConformerEncoder(torch.nn.Module):
@@ -104,7 +104,7 @@ def __init__(
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk

-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()
         self.num_blocks = num_blocks
         self.attention_heads = attention_heads
         self.cnn_module_kernel = cnn_module_kernel
4 changes: 2 additions & 2 deletions wenet/squeezeformer/encoder.py
@@ -29,7 +29,7 @@
     import PositionwiseFeedForward
 from wenet.squeezeformer.convolution import ConvolutionModule
 from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
-from wenet.utils.class_utils import get_activation
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES


 class SqueezeformerEncoder(nn.Module):
@@ -114,7 +114,7 @@ def __init__(
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
         self.pos_enc_layer_type = pos_enc_layer_type
-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
         if pos_enc_layer_type != "rel_pos":
7 changes: 4 additions & 3 deletions wenet/transducer/joint.py
@@ -2,7 +2,7 @@

 import torch
 from torch import nn
-from wenet.utils.class_utils import get_activation
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES


 class TransducerJoint(torch.nn.Module):
@@ -23,7 +23,7 @@ def __init__(self,
         assert joint_mode in ['add']
         super().__init__()

-        self.activatoin = get_activation(activation)
+        self.activatoin = WENET_ACTIVATION_CLASSES[activation]()
         self.prejoin_linear = prejoin_linear
         self.postjoin_linear = postjoin_linear
         self.joint_mode = joint_mode
@@ -55,7 +55,8 @@ def __init__(self,
             torch.nn.Tanh(), torch.nn.Dropout(dropout_rate),
             torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid())
         self.token_pred = torch.nn.Sequential(
-            get_activation(hat_activation), torch.nn.Dropout(dropout_rate),
+            WENET_ACTIVATION_CLASSES[hat_activation](),
+            torch.nn.Dropout(dropout_rate),
             torch.nn.Linear(join_dim, self.vocab_size - 1))

     def forward(self,
18 changes: 9 additions & 9 deletions wenet/transducer/predictor.py
@@ -2,7 +2,7 @@

 import torch
 from torch import nn
-from wenet.utils.class_utils import get_activation, get_rnn
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES, WENET_RNN_CLASSES


 def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
@@ -79,12 +79,12 @@ def __init__(self,
         # NOTE(Mddct): rnn base from torch not support layer norm
         # will add layer norm and prune value in cell and layer
         # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
-        self.rnn = get_rnn(rnn_type=rnn_type)(input_size=embed_size,
-                                              hidden_size=hidden_size,
-                                              num_layers=num_layers,
-                                              bias=bias,
-                                              batch_first=True,
-                                              dropout=dropout)
+        self.rnn = WENET_RNN_CLASSES[rnn_type](input_size=embed_size,
+                                               hidden_size=hidden_size,
+                                               num_layers=num_layers,
+                                               bias=bias,
+                                               batch_first=True,
+                                               dropout=dropout)
         self.projection = nn.Linear(hidden_size, output_size)

     def output_size(self):
@@ -237,7 +237,7 @@ def __init__(self,
         self.embed_dropout = nn.Dropout(p=embed_dropout)
         self.ffn = nn.Linear(self.embed_size, self.embed_size)
         self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
-        self.activatoin = get_activation(activation)
+        self.activatoin = WENET_ACTIVATION_CLASSES[activation]()

     def output_size(self):
         return self.embed_size
@@ -398,7 +398,7 @@ def __init__(self,
                              groups=embed_size,
                              bias=bias)
         self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
-        self.activatoin = get_activation(activation)
+        self.activatoin = WENET_ACTIVATION_CLASSES[activation]()

     def output_size(self):
         return self.embed_size
4 changes: 2 additions & 2 deletions wenet/transformer/decoder.py
@@ -22,7 +22,7 @@
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES, WENET_ATTENTION_CLASSES,
-    get_activation
+    WENET_ACTIVATION_CLASSES,
 )
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)

@@ -68,7 +68,7 @@ def __init__(
     ):
         super().__init__()
         attention_dim = encoder_output_size
-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()

         self.embed = torch.nn.Sequential(
             torch.nn.Identity() if input_layer == "no_pos" else torch.nn.Embedding(
6 changes: 3 additions & 3 deletions wenet/transformer/encoder.py
@@ -25,7 +25,7 @@
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES, WENET_SUBSAMPLE_CLASSES, WENET_ATTENTION_CLASSES,
-    get_activation
+    WENET_ACTIVATION_CLASSES,
 )
 from wenet.utils.mask import make_pad_mask
 from wenet.utils.mask import add_optional_chunk_mask
@@ -326,7 +326,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk,
                          global_cmvn, use_dynamic_left_chunk)
-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
@@ -391,7 +391,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk,
                          global_cmvn, use_dynamic_left_chunk)
-        activation = get_activation(activation_type)
+        activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
         encoder_selfattn_layer_args = (
39 changes: 14 additions & 25 deletions wenet/utils/class_utils.py
@@ -3,6 +3,7 @@
 # Copyright [2023-11-28] <[email protected], Xingchen Song>
 import torch

+from wenet.transformer.swish import Swish
 from wenet.transformer.subsampling import (
     LinearNoSubsampling, EmbedinigNoSubsampling,
     Conv1dSubsampling2, Conv2dSubsampling4,
@@ -21,32 +22,20 @@
 from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention


-def get_activation(act):
-    """Return activation function."""
-    # Lazy load to avoid unused import
-    from wenet.transformer.swish import Swish
-
-    activation_funcs = {
-        "hardtanh": torch.nn.Hardtanh,
-        "tanh": torch.nn.Tanh,
-        "relu": torch.nn.ReLU,
-        "selu": torch.nn.SELU,
-        "swish": getattr(torch.nn, "SiLU", Swish),
-        "gelu": torch.nn.GELU
-    }
-
-    return activation_funcs[act]()
-
-
-def get_rnn(rnn_type: str) -> torch.nn.Module:
-    assert rnn_type in ["rnn", "lstm", "gru"]
-    if rnn_type == "rnn":
-        return torch.nn.RNN
-    elif rnn_type == "lstm":
-        return torch.nn.LSTM
-    else:
-        return torch.nn.GRU
+WENET_ACTIVATION_CLASSES = {
+    "hardtanh": torch.nn.Hardtanh,
+    "tanh": torch.nn.Tanh,
+    "relu": torch.nn.ReLU,
+    "selu": torch.nn.SELU,
+    "swish": getattr(torch.nn, "SiLU", Swish),
+    "gelu": torch.nn.GELU,
+}

+WENET_RNN_CLASSES = {
+    "rnn": torch.nn.RNN,
+    "lstm": torch.nn.LSTM,
+    "gru": torch.nn.GRU,
+}

 WENET_SUBSAMPLE_CLASSES = {
     "linear": LinearNoSubsampling,
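
For reference, a small, self-contained usage sketch of the new mappings; the keys and tensor sizes below are illustrative choices, not taken from the commit:

    import torch

    from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES, WENET_RNN_CLASSES

    # Each mapping stores classes, so callers instantiate explicitly.
    activation = WENET_ACTIVATION_CLASSES["gelu"]()      # -> torch.nn.GELU()
    rnn = WENET_RNN_CLASSES["lstm"](input_size=80,       # -> torch.nn.LSTM(...)
                                    hidden_size=256,
                                    num_layers=2,
                                    batch_first=True)

    x = torch.randn(4, 10, 80)      # (batch, time, feature)
    out, _ = rnn(activation(x))     # activation is a module; LSTM returns (output, state)

Because the dictionaries are built at import time, Swish must now be imported at the top of class_utils.py (hence the added import) instead of being lazily imported inside get_activation.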
