import torch

__all__ = ["DeepSpeech"]


class FullyConnected(torch.nn.Module):
    """Fully connected layer with a clipped ReLU activation and optional dropout.

    Args:
        n_feature: Number of input features.
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability; ``0.0`` disables dropout.
        relu_max_clip: Value at which the ReLU activation is clipped. (Default: ``20``)
    """

    def __init__(self,
                 n_feature: int,
                 n_hidden: int,
                 dropout: float,
                 relu_max_clip: int = 20) -> None:
        super(FullyConnected, self).__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        # Clipped ReLU from the paper: min(max(x, 0), relu_max_clip).
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x


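# A minimal usage sketch for FullyConnected (illustrative only; the tensor
# sizes are assumptions, not part of this module):
#
#     fc = FullyConnected(n_feature=80, n_hidden=2048, dropout=0.0)
#     out = fc(torch.rand(4, 1, 100, 80))  # out.shape == (4, 1, 100, 2048)

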
class DeepSpeech(torch.nn.Module):
    """
    DeepSpeech model architecture from
    `Deep Speech: Scaling up end-to-end speech recognition
    <https://arxiv.org/abs/1412.5567>`_.

    Args:
        n_feature: Number of input features.
        n_hidden: Internal hidden unit size. (Default: ``2048``)
        n_class: Number of output classes. (Default: ``40``)
        dropout: Dropout probability applied in each ``FullyConnected`` layer. (Default: ``0.0``)
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super(DeepSpeech, self).__init__()
        self.n_hidden = n_hidden
        # Three fully connected layers, one bidirectional recurrent layer, and a
        # fourth fully connected layer make up the paper's five hidden layers.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(
            n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True
        )
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # N x C x T x F
        x = self.fc1(x)
        # N x C x T x H
        x = self.fc2(x)
        # N x C x T x H
        x = self.fc3(x)
        # N x C x T x H
        # Drop the channel dimension; the model expects a single channel (C == 1).
        x = x.squeeze(1)
        # N x T x H
        x = x.transpose(0, 1)
        # T x N x H
        x, _ = self.bi_rnn(x)
        # The fifth (non-recurrent) layer takes both the forward and backward units as inputs.
        x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
        # T x N x H
        x = self.fc4(x)
        # T x N x H
        x = self.out(x)
        # T x N x n_class
        x = x.permute(1, 0, 2)
        # N x T x n_class
        x = torch.nn.functional.log_softmax(x, dim=2)
        # N x T x n_class
        return x
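

# A minimal smoke test (a sketch; the batch size of 4, 100 time steps, and
# 80 input features below are illustrative assumptions, not part of the module):
if __name__ == "__main__":
    model = DeepSpeech(n_feature=80)
    inputs = torch.rand(4, 1, 100, 80)  # N x C x T x F
    log_probs = model(inputs)
    print(log_probs.shape)  # expected: torch.Size([4, 100, 40])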