
Conversation

@eqy (Collaborator) commented Jan 15, 2022

#1271

Pattern-matched implementation of FusedRMSNorm based on FusedLayerNorm. Tests are passing (needed threshold adjustment for float16), awaiting benchmark results and cleanup.
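
A minimal usage sketch of the new module, mirroring the constructor arguments exercised by the tests discussed below (the batch/sequence/hidden sizes are illustrative, taken from the benchmark shapes):

    import torch
    from apex.normalization import FusedRMSNorm

    # Constructor arguments mirror the test code later in this thread; the
    # hidden size of 1024 matches the benchmark shapes and is only an example.
    rms_norm = FusedRMSNorm(normalized_shape=1024, elementwise_affine=True).cuda()
    x = torch.randn(8, 32, 1024, device="cuda", requires_grad=True)
    y = rms_norm(x)     # normalizes over the last dimension; no mean subtraction, no bias
    y.sum().backward()  # exercises the fused backward (grad_input and grad_gamma)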

@eqy (Collaborator, Author) commented Jan 18, 2022

Some benchmark data on A100:

[------------------------------------------------------------------------------------------------------------ forward ------------------------------------------------------------------------------------------------------------]
                     |  torch.float32  |  fused torch.float32  |  torch.bfloat16  |  fused torch.bfloat16  |  autocast torch.float32  |  fused autocast torch.float32  |  autocast torch.bfloat16  |  fused autocast torch.bfloat16
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      [8, 28, 1024]  |       62.2      |          27.9         |       75.3       |          28.0          |           63.3           |              41.3              |            63.2           |               64.8
      [8, 30, 1024]  |       62.4      |          28.1         |       74.8       |          27.6          |           63.3           |              41.0              |            64.7           |               66.1
      [8, 32, 1024]  |       61.7      |          28.1         |       74.7       |          27.8          |           63.1           |              41.1              |            62.6           |               65.6
      [8, 34, 1024]  |       63.3      |          28.2         |       76.2       |          28.8          |           63.1           |              41.1              |            64.6           |               66.7
      [8, 36, 1024]  |       61.9      |          27.8         |       74.6       |          28.1          |           62.8           |              41.5              |            64.4           |               66.7
      [8, 38, 1024]  |       63.6      |          28.3         |       74.5       |          28.2          |           63.6           |              40.9              |            64.4           |               65.8
      [8, 40, 1024]  |       62.2      |          28.4         |       75.5       |          28.3          |           63.6           |              42.5              |            64.7           |               66.4
      [8, 42, 1024]  |       62.5      |          28.1         |       74.6       |          28.6          |           63.9           |              41.1              |            64.7           |               65.4
      [8, 44, 1024]  |       62.2      |          27.9         |       74.9       |          28.4          |           64.2           |              40.7              |            63.6           |               65.6
      [8, 46, 1024]  |       62.8      |          28.1         |       75.9       |          29.0          |           63.9           |              41.0              |            64.0           |               65.6
      [8, 48, 1024]  |       62.5      |          28.5         |       74.9       |          28.4          |           63.4           |              41.3              |            63.1           |               65.2
      [8, 50, 1024]  |       62.8      |          28.2         |       74.5       |          27.9          |           63.6           |              40.8              |            62.8           |               65.7
      [8, 52, 1024]  |       62.2      |          28.4         |       76.5       |          28.9          |           64.1           |              41.4              |            62.9           |               65.6
      [8, 54, 1024]  |       62.7      |          28.2         |       74.8       |          28.3          |           63.3           |              41.4              |            63.5           |               65.7
      [8, 56, 1024]  |       62.8      |          28.0         |       75.1       |          28.1          |           63.1           |              41.0              |            63.2           |               65.4

Times are in microseconds (us).

[------------------------------------------------------------------------------------------------------------ backward -----------------------------------------------------------------------------------------------------------]
                     |  torch.float32  |  fused torch.float32  |  torch.bfloat16  |  fused torch.bfloat16  |  autocast torch.float32  |  fused autocast torch.float32  |  autocast torch.bfloat16  |  fused autocast torch.bfloat16
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      [8, 28, 1024]  |      386.3      |         150.2         |      425.3       |         149.8          |          382.8           |             161.0              |           384.8           |              229.7
      [8, 30, 1024]  |      662.0      |         363.6         |      427.3       |         144.8          |          379.5           |             163.7              |           394.8           |              222.1
      [8, 32, 1024]  |      368.2      |         149.4         |      438.6       |         149.8          |          377.5           |             162.2              |           374.3           |              225.6
      [8, 34, 1024]  |      376.6      |         149.2         |      429.5       |         148.2          |          376.8           |             162.5              |           381.4           |              213.9
      [8, 36, 1024]  |      395.9      |         154.3         |      425.7       |         143.9          |          383.2           |             162.2              |           390.6           |              234.8
      [8, 38, 1024]  |      403.5      |         158.5         |      432.1       |         149.4          |          390.5           |             164.1              |           411.9           |              353.4
      [8, 40, 1024]  |      378.3      |         152.8         |      437.0       |         152.8          |          660.5           |             223.3              |           386.4           |              223.4
      [8, 42, 1024]  |      381.3      |         148.1         |      427.7       |         145.6          |          363.2           |             164.2              |           384.1           |              212.9
      [8, 44, 1024]  |      375.2      |         150.3         |      425.4       |         149.5          |          383.1           |             161.5              |           380.7           |              216.3
      [8, 46, 1024]  |      372.6      |         147.3         |      440.5       |         152.9          |          380.9           |             162.1              |           383.3           |              216.6
      [8, 48, 1024]  |      618.1      |         239.7         |      437.4       |         145.0          |          380.3           |             159.9              |           379.1           |              222.2
      [8, 50, 1024]  |      374.2      |         146.1         |      424.7       |         153.4          |          381.0           |             161.8              |           382.6           |              225.7
      [8, 52, 1024]  |      379.7      |         148.7         |      433.8       |         148.2          |          390.0           |             162.3              |           376.4           |              226.2
      [8, 54, 1024]  |      374.3      |         150.6         |      428.7       |         150.8          |          381.5           |             162.5              |           376.7           |              212.1
      [8, 56, 1024]  |      377.3      |         145.9         |      426.5       |         144.5          |          379.0           |             161.2              |           389.1           |              217.3

Times are in microseconds (us).
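
The table layout matches torch.utils.benchmark output; a rough sketch of how numbers like these could be collected (the shapes and labels mirror the tables above, but this is not necessarily the exact script used):

    import torch
    from torch.utils import benchmark
    from apex.normalization import FusedRMSNorm

    fused = FusedRMSNorm(normalized_shape=1024).cuda()
    results = []
    for seq in range(28, 57, 2):
        x = torch.randn(8, seq, 1024, device="cuda", requires_grad=True)
        # Forward-pass timing only; backward timing needs a separate stmt
        # that rebuilds the graph before each y.backward().
        results.append(benchmark.Timer(
            stmt="m(x)",
            globals={"m": fused, "x": x},
            label="forward",
            sub_label=str([8, seq, 1024]),
            description="fused torch.float32",
        ).blocked_autorange(min_run_time=1))
    benchmark.Compare(results).print()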

@crcrpar (Collaborator) left a comment

A bunch of suggestions to remove commented-out lines. You can batch the suggestions into one if you like; see https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request#applying-suggested-changes.

What do you think about splitting apex/normalization/fused_layer_norm.py into fused_layer_norm.py and fused_rms_norm.py?

class FusedRMSNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
    #def forward(ctx, input, weight, bias, normalized_shape, eps):

Suggested change (remove the commented-out line):
    #def forward(ctx, input, weight, bias, normalized_shape, eps):

        #)
        output, invvar = fused_layer_norm_cuda.rms_forward_affine(
            input_, ctx.normalized_shape, weight_, ctx.eps)
        #ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

Suggested change (remove the commented-out line):
        #ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    // at::Tensor* beta,

Suggested change (remove the commented-out line):
    // at::Tensor* beta,

    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma)
    // at::Tensor* grad_beta

Suggested change (remove the commented-out line):
    // at::Tensor* grad_beta

    using accscalar_t = at::acc_type<scalar_t_in, true>;
    HostRMSNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        // mean->DATA_PTR<accscalar_t>(),

Suggested change (remove the commented-out line):
        // mean->DATA_PTR<accscalar_t>(),

        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
        // if gamma Tensor is NULL on input.
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        // gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,

Suggested change (remove the commented-out line):
        // gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,

        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL);
        // gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);

Suggested change (remove the commented-out line):
        // gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);

Comment on lines +191 to +197

    native = apex.normalization.FusedRMSNorm(
        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
    )
    fused = apex.normalization.FusedRMSNorm(
        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
    ).cuda()
    return native, fused
Contributor:

this testing won't do much good as it's comparing to itself :)

Since there isn't a torch.nn.RMSNorm, perhaps writing one out in plain Python?

@eqy (Collaborator, Author):

It's a bit opaque here, but the "native" version is computed on CPU, which dispatches to a manual plain-Python version sourced from T5LayerNorm:

return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

def manual_rms_norm(input, normalized_shape, weight, eps):
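
For reference, a T5LayerNorm-style implementation of that plain-Python path might look roughly like this (a sketch: the float32 accumulation and half-precision handling follow the T5LayerNorm pattern and may differ in detail from the file under review):

    import torch

    def manual_rms_norm(input, normalized_shape, weight, eps):
        # Reduce over the trailing `normalized_shape` dimensions.
        dims = tuple(range(-len(normalized_shape), 0))
        # Accumulate the mean of squares in float32 for numerical stability,
        # as T5LayerNorm does.
        variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
        input = input * torch.rsqrt(variance + eps)
        if weight is None:
            return input
        # Match the weight dtype before the elementwise scale (fp16/bf16 weights).
        if weight.dtype in (torch.float16, torch.bfloat16):
            input = input.to(weight.dtype)
        return weight * input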

Copy link
Contributor:

Thank you for explaining this nuance, @eqy. I can see it now.

@eqy changed the title from [WIP] FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm to FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm on Jan 31, 2022
@eqy (Collaborator, Author) commented Jan 31, 2022

@crcrpar, this is now refactored to use the existing FusedLayerNorm implementation via an added rms_only flag.
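
Roughly, the rms_only flag lets a single normalization routine cover both cases; a plain-Python sketch of the flag's semantics (illustrative only; the actual switch lives in the shared CUDA kernels, and norm_reference is a made-up name):

    import torch

    def norm_reference(x, gamma, beta, eps, rms_only=False):
        if rms_only:
            # RMSNorm: scale by the root-mean-square; no mean subtraction, no bias.
            variance = x.pow(2).mean(-1, keepdim=True)
            return x * torch.rsqrt(variance + eps) * gamma
        # Ordinary LayerNorm over the last dimension.
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, unbiased=False, keepdim=True)
        return (x - mean) * torch.rsqrt(variance + eps) * gamma + beta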

@crcrpar (Collaborator) left a comment

I think this is the last iteration

Oof, nice catch, thanks

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
Comment on lines +306 to +312

    y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
    The mean and standard-deviation are calculated separately over the last
    certain number dimensions which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
@stas00 (Contributor) commented Feb 10, 2022

This looks like a copy-n-paste error, as this version has no bias and no mean subtraction in the math formula.

I think the note below needs updating as well w.r.t. bias.
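
For reference, a formula consistent with what the RMSNorm path computes (and with the manual_rms_norm reference above) would be the following; the exact docstring wording was settled in the follow-up PRs:

    y = \frac{x}{\sqrt{\mathrm{mean}(x^{2}) + \epsilon}} * \gamma

i.e. only a learnable :math:`\gamma` (no :math:`\beta`) when :attr:`elementwise_affine` is ``True``.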

Collaborator:

Could you check the follow-up #1285?

Contributor:

oh! I missed that one - thank you!

looks much better now - with only one small issue - commented in the other PR.

hubertlu-tw pushed a commit to ROCm/apex that referenced this pull request Apr 15, 2022
* FusedRMSNorm based on FusedLayerNorm

* refactor duplicated kernels

* delete comments

* delete comments

* cleanup

* cleanup

* cleanup, fixed clobbering forward_affine_mixed_dtypes

* fix pybind naming and add MixedFused test

* undo skipping

* check elementwise_affine

* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Oof, nice catch, thanks

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
jithunnair-amd pushed a commit to ROCm/apex that referenced this pull request Aug 5, 2022
* FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (NVIDIA#1274)

* FusedRMSNorm based on FusedLayerNorm

* refactor duplicated kernels

* delete comments

* delete comments

* cleanup

* cleanup

* cleanup, fixed clobbering forward_affine_mixed_dtypes

* fix pybind naming and add MixedFused test

* undo skipping

* check elementwise_affine

* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Oof, nice catch, thanks

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>

Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>

* fix and generate docs for FusedRMSNorm (NVIDIA#1285)

* [FusedRMSNorm doc] document where epsilon is added (NVIDIA#1295)

* [FusedRMSNorm doc] add epsilon to formula

* correct

* better wording

* Fix some bugs

* Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs

* Fix NaN issues in FusedRMSNorm

* Update test_fused_layer_norm.py

* Skip test_fused_layer_norm.TestAutocastFusedRMSNorm on ROCm

* Use at::cuda::warp_size() instead of at::cuda::getCurrentDeviceProperties()->warpSize

Co-authored-by: eqy <eddiey@nvidia.com>
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>