From c5a722d5949035094fe1d2dd4ca87cd29ce7fc09 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 31 Mar 2023 10:16:53 -0600 Subject: [PATCH] MixedFusedRMSNorm Export Fix (#6296) (#6299) * Added RMSLayerNorm to export_utils * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: David Mosallanezhad Co-authored-by: David Co-authored-by: David Mosallanezhad Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nemo/utils/export_utils.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 3fa7b1322aad..f4328b960143 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -209,6 +209,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm + from apex.normalization import MixedFusedRMSNorm from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -223,33 +224,22 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: """ p = next(n.parameters()) + if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine + n_state = n.state_dict() elif isinstance(n, FastLayerNorm): shape, eps, affine = n.weight.shape, n.epsilon, True + n_state = n.state_dict() + elif isinstance(n, MixedFusedRMSNorm): + shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine + tmp_n_state = n.state_dict() + n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])} else: return None mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype) - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedLayerNorm pytorch module to replace - Returns: - Equivalent LayerNorm module - """ - if not isinstance(n, RowParallelLinear): - raise ValueError("This function can only change the RowParallelLinear module.") - - dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) - n_state = n.state_dict() mod.load_state_dict(n_state) return mod @@ -296,6 +286,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: "RowParallelLinear": replace_ParallelLinear, "ColumnParallelLinear": replace_ParallelLinear, "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, + "MixedFusedRMSNorm": replace_FusedLayerNorm, } except Exception as e: