Skip to content

Commit

Permalink
MixedFusedRMSNorm Export Fix (#6296) (#6299)
Browse files Browse the repository at this point in the history
* Added RMSLayerNorm to export_utils



* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: David Mosallanezhad <[email protected]>
Co-authored-by: David <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Mar 31, 2023
1 parent 6fb33d4 commit c5a722d
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):

try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
from apex.normalization import MixedFusedRMSNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
    """Replace an Apex fused norm module with a plain nn.LayerNorm for ONNX export.

    Handles FusedLayerNorm, MixedFusedLayerNorm, FastLayerNorm and
    MixedFusedRMSNorm; any other module type is left untouched.

    Args:
        n: the candidate pytorch module to replace.
    Returns:
        An equivalent nn.LayerNorm on the same device/dtype as ``n``'s
        parameters, or None when ``n`` is not one of the supported types.
    """
    # Sample one parameter only to mirror the source module's device and dtype.
    p = next(n.parameters())

    if isinstance(n, (FusedLayerNorm, MixedFusedLayerNorm)):
        shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
        n_state = n.state_dict()
    elif isinstance(n, FastLayerNorm):
        # FastLayerNorm always has weight+bias (affine=True) and names its
        # epsilon attribute `epsilon`, not `eps`.
        shape, eps, affine = n.weight.shape, n.epsilon, True
        n_state = n.state_dict()
    elif isinstance(n, MixedFusedRMSNorm):
        # RMSNorm has no bias term; synthesize a zero bias so the state dict
        # matches the keys nn.LayerNorm.load_state_dict expects.
        shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
        tmp_n_state = n.state_dict()
        n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
    else:
        return None

    mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
    # BUG FIX: do NOT re-read n.state_dict() here. The stray
    # `n_state = n.state_dict()` that previously sat between the two lines
    # below clobbered the patched dict built in the MixedFusedRMSNorm branch
    # (which has no 'bias' key on the source module), breaking load_state_dict.
    mod.load_state_dict(n_state)
    return mod

def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
    """
    Replaces Apex's RowParallelLinear with LinearWithBiasSkip. This is required for ONNX export.
    Args:
        n: the RowParallelLinear pytorch module to replace
    Returns:
        Equivalent LinearWithBiasSkip module
    Raises:
        ValueError: if ``n`` is not a RowParallelLinear instance.
    """
    if not isinstance(n, RowParallelLinear):
        raise ValueError("This function can only change the RowParallelLinear module.")

    # Place the replacement on the same device as the original's parameters.
    dev = next(n.parameters()).device
    mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)

    # Copy weights/bias over so the export is numerically identical.
    n_state = n.state_dict()
    mod.load_state_dict(n_state)
    return mod

Expand Down Expand Up @@ -296,6 +286,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
"RowParallelLinear": replace_ParallelLinear,
"ColumnParallelLinear": replace_ParallelLinear,
"FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
"MixedFusedRMSNorm": replace_FusedLayerNorm,
}

except Exception as e:
Expand Down

0 comments on commit c5a722d

Please sign in to comment.