From 42c52bf3cc128c4dd99a78d1c50b1e33859cd0e5 Mon Sep 17 00:00:00 2001
From: Manuel Candales <42380156+manuelcandales@users.noreply.github.com>
Date: Fri, 24 Jan 2025 12:51:30 -0500
Subject: [PATCH] Replace RMSNorm by nn.RMSNorm (#1464)

In this PR we replace torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) implementation with nn.RMSNorm, and we bump the PyTorch pin to pick up the massive (30x-40x) speed-up to RMSNorm on the MPS backend introduced in https://github.com/pytorch/pytorch/pull/145301.

Preliminary benchmarks on an M1 Pro with 16GB RAM show a 33% speed-up in token generation when running Llama 3.2 1B with 4-bit quantization.

Motivation: Token generation on the MPS backend is currently CPU bound because of MPSGraph overhead. Surprisingly, the ops that impact performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of those op calls has at least 20us of CPU overhead on the MPS backend (a profiling sketch for gathering such numbers is included at the end of this message). These ops also dominate the graph: in aggregate, they are called 770 times for each token when running Llama 3.2 1B. Compare that to SDPA, which is called only 33 times, and linear, which is called 113 times.

- mul is called 275 times per token
- copy_ is called 202 times per token
- add is called 97 times per token
- where is called 34 times per token
- mean is called 33 times per token
- rsqrt is called 33 times per token
- sub is called 32 times per token
- cat is called 32 times per token
- stack is called 32 times per token

Currently, torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) operation is basically implemented like this:

```
x_fp32 = x.float()
norm = x_fp32 * torch.rsqrt(torch.mean(x_fp32 * x_fp32, dim=-1, keepdim=True) + eps)
output = norm.type_as(x) * weight
```

This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul` and one call each to `aten::rsqrt`, `aten::mean` and `aten::add`. RMSNorm is called 33 times for each token. Hence, RMSNorm contributes 6 * 33 = 198 of those 770 op calls.
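
For illustration only (not part of the patch), here is a minimal sketch checking that nn.RMSNorm computes the same thing as the formula above. The `dim` and `eps` values are made up for the example; torchchat takes them from `TransformerArgs`:

```
import torch
import torch.nn as nn

# Hypothetical sizes for illustration; torchchat reads dim and eps from TransformerArgs.
dim, eps = 2048, 1e-5
x = torch.randn(2, 8, dim)

# The formula the old torchchat module computed (its weight parameter is initialized to ones).
weight = torch.ones(dim)
ref = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps) * weight

# nn.RMSNorm with the same eps; its learnable weight also defaults to ones.
norm = nn.RMSNorm(dim, eps=eps)
out = norm(x)

# Expect True, up to floating-point rounding.
print(torch.allclose(out, ref, atol=1e-5))
```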
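
The per-token op counts above were gathered by profiling generation on the MPS backend. Below is a rough, hypothetical sketch of how such counts can be collected with `torch.profiler`; it profiles a single nn.RMSNorm call as a stand-in, since wiring up a full torchchat decode step is out of scope here:

```
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# Stand-in workload; in torchchat the profiled call would be one decode step
# of the Transformer on the "mps" device.
device = "mps" if torch.backends.mps.is_available() else "cpu"
norm = nn.RMSNorm(2048, eps=1e-5).to(device)
x = torch.randn(1, 1, 2048, device=device)

ops_of_interest = {
    "aten::mul", "aten::copy_", "aten::add", "aten::where", "aten::mean",
    "aten::rsqrt", "aten::sub", "aten::cat", "aten::stack",
}

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        norm(x)

# Report how many times each op of interest was dispatched and its total CPU time (us).
for evt in prof.key_averages():
    if evt.key in ops_of_interest:
        print(f"{evt.key}: count={evt.count}, cpu_time_total={evt.cpu_time_total}us")
```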
---
 install/install_requirements.sh |  6 +++---
 torchchat/model.py              | 20 +++-----------------
 2 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/install/install_requirements.sh b/install/install_requirements.sh
index 264c3496d..360ba1801 100755
--- a/install/install_requirements.sh
+++ b/install/install_requirements.sh
@@ -51,13 +51,13 @@ echo "Using pip executable: $PIP_EXECUTABLE"
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20250119
+PYTORCH_NIGHTLY_VERSION=dev20250124
 
 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20250119
+VISION_NIGHTLY_VERSION=dev20250124
 
 # Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20250119
+TUNE_NIGHTLY_VERSION=dev20250124
 
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
diff --git a/torchchat/model.py b/torchchat/model.py
index c01ff1262..ce7dcb5e4 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
             self.layers[str(layer_id)] = TransformerBlock(config)
 
         if config.stage_idx == config.n_stages - 1:
-            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
             if config.tie_word_embeddings:
                 self.output.weight = self.tok_embeddings.weight
@@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.feed_forward = FeedForward(config)
-        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
-        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
         # None for llama architecture, set for granite architectures
         self.residual_multiplier = (
             config.residual_multiplier
@@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-
-    def forward(self, x: Tensor) -> Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
 def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
     # Check for the presence of the required keys
     required_keys = {