Skip to content

Commit

Permalink
Replace RMSNorm by nn.RMSNorm (#1464)
Browse files Browse the repository at this point in the history
In this PR we replace torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) implementation by nn.RMSNorm, and we bump the PyTorch pin to capture the massive speed up (30x-40x) to RMSNorm on MPS backend introduced in pytorch/pytorch#145301

Preliminary benchmarks on an M1 Pro with 16GB RAM, show a 33% speed up on token generation when running Llama 3.2 1B with 4-bit quantization

Motivation: Token generation on MPS backend is currently CPU bound, because of MPSGraph overhead. Surprisingly, the ops that are impacting performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of those op calls on the MPS backend, has at least 20us of CPU overhead. Also, these ops dominate the graph. For example, in aggregate, these ops are called 770 times for each token, when running Llama 3.2 1B. Compare that to SDPA which is called only 33 times, and linear which is called 113 times.
- mul is called 275 times per token
- copy_ is called 202 times per token
- add is called 97 times per token
- where is called 34 times per token
- mean is called 33 times per token
- rsqrt is called 33 times per token
- sub is called 32 times per token
- cat is called 32 times per token
- stack is called 32 times per token

Currently, torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) operation is basically implemented like this:
```
norm = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
output = norm(x.float()).type_as(x) * weight
```
This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul` and calls to `aten::rsqrt`, `aten::mean` and `aten::add`. RMSNorm is called 33 times for each token. Hence, RMSNorm contributes 5 * 33 = 165 of those 770 op calls.
  • Loading branch information
manuelcandales authored Jan 24, 2025
1 parent f4ae60f commit 42c52bf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
6 changes: 3 additions & 3 deletions install/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ echo "Using pip executable: $PIP_EXECUTABLE"
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
PYTORCH_NIGHTLY_VERSION=dev20250119
PYTORCH_NIGHTLY_VERSION=dev20250124

# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20250119
VISION_NIGHTLY_VERSION=dev20250124

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20250119
TUNE_NIGHTLY_VERSION=dev20250124

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
Expand Down
20 changes: 3 additions & 17 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
self.layers[str(layer_id)] = TransformerBlock(config)

if config.stage_idx == config.n_stages - 1:
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.output.weight = self.tok_embeddings.weight
Expand Down Expand Up @@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
# None for llama architecture, set for granite architectures
self.residual_multiplier = (
config.residual_multiplier
Expand Down Expand Up @@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight


def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
# Check for the presence of the required keys
required_keys = {
Expand Down

0 comments on commit 42c52bf

Please sign in to comment.