Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FIX] Update BERTLayerNorm Implementation #485

Merged
merged 3 commits on Dec 28, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 7 additions & 24 deletions src/gluonnlp/model/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
'BERTLayerNorm', 'bert_12_768_12', 'bert_24_1024_16']

import os
from mxnet.gluon import Block, HybridBlock
from mxnet.gluon import Block
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import model_store
import mxnet as mx
Expand All @@ -34,35 +34,18 @@
# COMPONENTS #
###############################################################################

class BERTLayerNorm(HybridBlock):
"""BERT style Layer Normalization.

Epsilon is added inside the square root.
class BERTLayerNorm(nn.LayerNorm):
"""BERT style Layer Normalization, where epsilon is added inside the square
root and set to 1e-12 by default.

Inputs:
- **data**: input tensor with arbitrary shape.
Outputs:
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, epsilon=1e-12, in_channels=0, prefix=None, params=None):
super(BERTLayerNorm, self).__init__(prefix=prefix, params=params)
self.gamma = self.params.get('gamma', shape=(in_channels,),
allow_deferred_init=True)
self.beta = self.params.get('beta', shape=(in_channels,),
allow_deferred_init=True)
self._eps = epsilon

def hybrid_forward(self, F, x, gamma, beta): # pylint: disable=arguments-differ
u = F.mean(x, -1, keepdims=True)
s = F.mean(F.broadcast_sub(x, u) ** 2, -1, keepdims=True) + self._eps
x = F.broadcast_div(F.broadcast_sub(x, u), s.sqrt())
return F.broadcast_add(F.broadcast_mul(gamma, x), beta)

def __repr__(self):
s = '{name}('
in_channels = self.gamma.shape[0]
s += 'in_channels={0}, epsilon={1})'.format(in_channels, self._eps)
return s.format(name=self.__class__.__name__)
super(BERTLayerNorm, self).__init__(epsilon=epsilon, in_channels=in_channels,
prefix=prefix, params=params)


class BERTPositionwiseFFN(BasePositionwiseFFN):
"""Structure of the Positionwise Feed-Forward Neural Network for
Expand Down