FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm #1274
Changes from all commits
The first file in the diff is the package `__init__`, which re-exports the new classes alongside the existing fused layer norms:

```diff
@@ -1 +1 @@
-from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
+from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
```
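With that re-export in place, the new modules are reachable from `apex.normalization` just like `FusedLayerNorm`. A minimal usage sketch, assuming apex was installed with the CUDA extension so `fused_layer_norm_cuda` is importable:

```python
import torch
from apex.normalization import FusedRMSNorm

# RMS norm over the last dimension of size 768, with a learnable scale (gamma).
rms = FusedRMSNorm(768).cuda()
x = torch.randn(4, 128, 768, device="cuda")
y = rms(x)          # same shape as the input
print(y.shape)      # torch.Size([4, 128, 768])
```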
The rest of the changes go into the fused layer norm module. The first hunk adds a pure-PyTorch reference implementation, used by the new modules on the CPU path:

```diff
@@ -12,6 +12,23 @@
 fused_layer_norm_cuda = None


+# Reference implementation from Huggingface
+def manual_rms_norm(input, normalized_shape, weight, eps):
+    # layer norm should always be calculated in float32
+    dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1))
+    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    input = input * torch.rsqrt(variance + eps)
+
+    if weight is None:
+        return input
+
+    # convert into half-precision if necessary
+    if weight.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(weight.dtype)
+
+    return weight * input
+
+
 class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, weight, bias, normalized_shape, eps):
```
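The helper mirrors the Hugging Face T5 layer norm: statistics in float32, scaling by the reciprocal root mean square, then an optional elementwise weight. A quick sanity check of what it returns (a sketch only; `manual_rms_norm` is a module-internal helper defined in the hunk above):

```python
import torch

x = torch.randn(2, 8)
w = torch.ones(8)

# Scale-only normalization: no mean subtraction, no bias term.
y = manual_rms_norm(x, (8,), w, eps=1e-5)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
torch.testing.assert_close(y, expected)
```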
```diff
@@ -39,6 +56,31 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None


+class FusedRMSNormAffineFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        output, invvar = fused_layer_norm_cuda.rms_forward_affine(
+            input_, ctx.normalized_shape, weight_, ctx.eps)
+        ctx.save_for_backward(input_, weight_, invvar)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_, invvar = ctx.saved_tensors
+        grad_input = grad_weight = None
+        grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
+            grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps
+        )
+        return grad_input, grad_weight, None, None
+
+
 class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):

     @staticmethod
```
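The autograd function is normally reached through the `fused_rms_norm_affine` wrapper added further down, but it can also be applied directly. A hedged sketch of the affine path end to end, using the names defined in the hunk above:

```python
import torch

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
g = torch.ones(1024, device="cuda", requires_grad=True)

# forward calls rms_forward_affine, backward calls rms_backward_affine
y = FusedRMSNormAffineFunction.apply(x, g, (1024,), 1e-5)
y.sum().backward()
print(x.grad.shape, g.grad.shape)   # torch.Size([16, 1024]) torch.Size([1024])
```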
```diff
@@ -58,6 +100,25 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
         return output


+class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):
+
+    @staticmethod
+    def forward(ctx, input, weight, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(
+            input_, ctx.normalized_shape, weight_, ctx.eps
+        )
+
+        ctx.save_for_backward(input_, weight_, invvar)
+        return output
+
+
 class FusedLayerNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, eps):
```
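The mixed-dtype variant only overrides `forward` to call the `*_mixed_dtypes` kernel; `backward` is inherited from `FusedRMSNormAffineFunction`. A small introspection check of that (no CUDA call involved):

```python
# backward is not redefined on the subclass, so the parent implementation is reused
assert "backward" not in FusedRMSNormAffineMixedDtypesFunction.__dict__
assert FusedRMSNormAffineMixedDtypesFunction.backward is FusedRMSNormAffineFunction.backward
```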
```diff
@@ -81,6 +142,29 @@ def backward(ctx, grad_output):
         return grad_input, None, None


+class FusedRMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, normalized_shape, eps):
+        global fused_layer_norm_cuda
+        if fused_layer_norm_cuda is None:
+            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        input_ = input.contiguous()
+        output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps)
+        ctx.save_for_backward(input_, invvar)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, invvar = ctx.saved_tensors
+        grad_input = None
+        grad_input = fused_layer_norm_cuda.rms_backward(
+            grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps
+        )
+        return grad_input, None, None
+
+
 def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
     args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
     with torch.cuda.amp.autocast(enabled=False):
```
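A minimal sketch of the non-affine path: there is no weight, so `backward` only returns a gradient for the input.

```python
import torch

x = torch.randn(4, 256, device="cuda", requires_grad=True)
y = FusedRMSNormFunction.apply(x, (256,), 1e-5)
y.mean().backward()      # gradient comes from rms_backward
print(x.grad.shape)      # torch.Size([4, 256])
```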
```diff
@@ -99,6 +183,24 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
         return FusedLayerNormAffineMixedDtypesFunction.apply(*args)


+def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
+    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
+    with torch.cuda.amp.autocast(enabled=False):
+        return FusedRMSNormAffineFunction.apply(*args)
+
+
+def fused_rms_norm(input, normalized_shape, eps=1e-6):
+    args = _cast_if_autocast_enabled(input, normalized_shape, eps)
+    with torch.cuda.amp.autocast(enabled=False):
+        return FusedRMSNormFunction.apply(*args)
+
+
+def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
+    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
+    with torch.cuda.amp.autocast(enabled=False):
+        return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
+
+
 class FusedLayerNorm(torch.nn.Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
     the paper `Layer Normalization`_ .
```
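A hedged sketch of the functional wrappers in use: under `torch.cuda.amp.autocast`, `_cast_if_autocast_enabled` casts the tensor arguments up front, and the fused kernel then runs with autocast disabled so it sees a consistent dtype.

```python
import torch

x = torch.randn(8, 1024, device="cuda")
g = torch.ones(1024, device="cuda")

with torch.cuda.amp.autocast():
    y = fused_rms_norm_affine(x, g, (1024,), eps=1e-6)   # with learnable scale (gamma)
    z = fused_rms_norm(x, (1024,), eps=1e-6)             # no learnable scale
```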
```diff
@@ -195,6 +297,99 @@ def extra_repr(self):
         return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


+class FusedRMSNorm(torch.nn.Module):
+    r"""Applies RMS Normalization over a mini-batch of inputs
+
+    Currently only runs on cuda() tensors.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated separately over the last
+    certain number dimensions which have to be of the shape specified by
+    :attr:`normalized_shape`.
+    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
```
> **Review comment on lines +306 to +312**
>
> **Contributor:** this looks like a copy-n-paste error - as this version has no bias and no mean subtraction in the math formula. I think the note below needs updating as well wrt bias.
>
> **Collaborator:** Could you check the follow-up #1285?
>
> **Contributor:** oh! I missed that one - thank you! looks much better now - with only one small issue - commented in the other PR.
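For reference, the quantity the new RMS kernels compute is a scale-only root-mean-square normalization, with no mean subtraction and no bias (this restates the reviewer's point; the docstring above still shows the layer norm formula):

```math
y = \frac{x}{\sqrt{\mathrm{E}[x^{2}] + \epsilon}} * \gamma
```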
The hunk continues with the rest of the docstring and the module implementation:

```diff
+
+    .. note::
+        Unlike Batch Normalization and Instance Normalization, which applies
+        scalar scale and bias for each entire channel/plane with the
+        :attr:`affine` option, Layer Normalization applies per-element scale and
+        bias with :attr:`elementwise_affine`.
+
+    This layer uses statistics computed from input data in both training and
+    evaluation modes.
+
+    Args:
+        normalized_shape (int or list or torch.Size): input shape from an expected input
+            of size
+
+            .. math::
+                [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
+                    \times \ldots \times \text{normalized}\_\text{shape}[-1]]
+
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        elementwise_affine: a boolean value that when set to ``True``, this module
+            has learnable per-element affine parameters initialized to ones (for weights)
+            and zeros (for biases). Default: ``True``.
+
+    Shape:
+        - Input: :math:`(N, *)`
+        - Output: :math:`(N, *)` (same shape as input)
+
+    Examples::
+
+        >>> input = torch.randn(20, 5, 10, 10)
+        >>> # With Learnable Parameters
+        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:])
+        >>> # Without Learnable Parameters
+        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)
+        >>> # Normalize over last two dimensions
+        >>> m = apex.normalization.FusedRMSNorm([10, 10])
+        >>> # Normalize over last dimension of size 10
+        >>> m = apex.normalization.FusedRMSNorm(10)
+        >>> # Activating the module
+        >>> output = m(input)
+
+    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+    """
+
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+        super().__init__()
+
+        global fused_layer_norm_cuda
+        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = Parameter(torch.Tensor(*normalized_shape))
+        else:
+            self.register_parameter("weight", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.elementwise_affine:
+            init.ones_(self.weight)
+
+    def forward(self, input):
+        if not input.is_cuda:
+            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+
+        if self.elementwise_affine:
+            return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
+        else:
+            return fused_rms_norm(input, self.normalized_shape, self.eps)
+
+    def extra_repr(self):
+        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+
+
 # NOTE (mkozuki): Why "mixed"?
 # MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
 # as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
```
```diff
@@ -216,3 +411,26 @@ def forward(self, input: torch.Tensor):
         if not input.is_cuda:
             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
         return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
+
+
+# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
+# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
+# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
+class MixedFusedRMSNorm(FusedRMSNorm):
+
+    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
+        if "elementwise_affine" in kwargs:
+            import warnings
+            warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument")
+            elementwise_affine = kwargs.pop("elementwise_affine")
+            if not elementwise_affine:
+                raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`")
+
+        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)
+
+    def forward(self, input: torch.Tensor):
+        # NOTE (mkozuki): CPU path is here mainly for unittest sake.
+        # TODO Manual RMS Norm Implementation Here
+        if not input.is_cuda:
+            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+        return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
```
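A hedged illustration of the "mixed dtype" behaviour described in the comments above: the output dtype is expected to follow the parameter dtype rather than the input dtype (the exact promotion rules live in the CUDA extension, so treat this as a sketch).

```python
import torch
from apex.normalization import MixedFusedRMSNorm

m = MixedFusedRMSNorm(1024).cuda().to(torch.bfloat16)   # parameters cast to bf16
x = torch.randn(8, 1024, device="cuda")                 # fp32 activations

y = m(x)
print(y.dtype)   # expected: torch.bfloat16, i.e. the parameter dtype
```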