
Clarification of RMSNorm layer fusion #14

Open
tobiasvanderwerff opened this issue Oct 3, 2024 · 1 comment

@tobiasvanderwerff

Hi,

Thanks very much for your work and for publishing your code. I am currently working on integrating SpinQuant into torch/ao, and I would like to clarify something about the code that would help with my implementation.

In the paper, the following is mentioned in footnote 3:

In a pre-norm LLM like LLaMA, we can convert a transformer network into a rotation-invariant network by incorporating the RMSNorm scale parameters α into the weight matrix right after the RMSNorm layer.

In the code, this appears to be done in the fuse_layer_norms function.
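For reference, here is a minimal sketch of how I understand that fusion (my own illustration, not the repository's code): the RMSNorm scale gamma multiplies each channel of the normalized activations elementwise, so it can be absorbed column-wise into any linear layer that consumes those activations, after which the norm's weight is reset to ones:

import torch

def fuse_rmsnorm_scale(norm_weight: torch.Tensor, linears: list) -> None:
    # Fold gamma into every linear layer W that reads the normalized
    # activations: column j of W gets multiplied by gamma[j].
    for linear in linears:
        W = linear.weight
        W.data = (W.data.double() * norm_weight.double()).to(W.dtype)
    # After fusion, the RMSNorm scale acts as the identity.
    norm_weight.data = torch.ones_like(norm_weight.data)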

However, I also noticed that in that same function the embedding weights are modified in the following lines:

# Embedding fusion
for W in [model.model.embed_tokens]:
    W_ = W.weight.data.double()
    W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

Could you help me understand why this is done, i.e., why the mean is subtracted from the input embeddings? I don't see a connection to the RMSNorm layer fusion, so I must be missing something.
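For concreteness, this is a standalone toy version of that step as I read it (hypothetical sizes, not the repository's code); it simply makes every embedding row zero-mean along the hidden dimension:

import torch

embed = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)

W_ = embed.weight.data.double()
# Subtract each token's mean over the hidden dimension, so every
# embedding vector ends up zero-mean along dim=-1.
embed.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(embed.weight.data.dtype)

print(embed.weight.data.mean(dim=-1))  # all values are ~0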

Thanks in advance.

@testworldagain

When those lines (lines 42–45) are removed, the perplexity of Llama2-7b on wikitext2 goes from 6.8 down to 6.5.
