Conversation

@mntss (Contributor) commented Apr 15, 2025

Description

The current Llama 3.x implementations use incorrect scaling factors for RoPE. The reference implementation of the llama3 rope_type in Hugging Face transformers is here:

https://github.com/huggingface/transformers/blob/3165eb7c2808832d0de86c8f508d9da6b2124044/src/transformers/modeling_rope_utils.py#L407-L410
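
For context, the llama3 rope_type does not apply a single linear scale to the rotary frequencies: inverse frequencies below a low-frequency cutoff are divided by the scaling factor, those above a high-frequency cutoff are left unchanged, and the band in between is smoothly interpolated. Below is a minimal sketch of that scaling, following the linked transformers code; the default values are taken from the Llama-3.2-1B config and should be treated as assumptions:

import math
import torch

def llama3_scale_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 32.0,          # rope_scaling["factor"] in Llama-3.2-1B
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,   # original_max_position_embeddings
) -> torch.Tensor:
    """Apply llama3-style scaling to the base RoPE inverse frequencies."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) are scaled down by `factor`;
    # short wavelengths (high frequencies) are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # The medium band is linearly interpolated between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)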

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Testing

I used this script for testing:

# %%
import os
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM
from datasets import load_dataset

torch.set_grad_enabled(False)

# %%
# Load models
dtype = torch.float32
model_tl = HookedTransformer.from_pretrained_no_processing("meta-llama/Llama-3.2-1B", dtype=dtype)
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=dtype)

# %%
# Load dataset
dataset = load_dataset("EleutherAI/fineweb-edu-dedup-10b", split="train")

# %%
# Create dataloader helper function
def create_dataloader(dataset, model, batch_size=64, max_length=256):
    """Create a dataloader with the dataset."""
    
    def tokenize_collate_fn(batch):
        texts = [item["text"] for item in batch]
        tokenized = model.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return tokenized.to(model.cfg.device, non_blocking=True)
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=tokenize_collate_fn,
    )

# %%
# Create dataloader and get a batch
dataloader = create_dataloader(
    dataset.take(1024), model_tl, batch_size=16
)
batch = next(iter(dataloader))

# %%
# Move models and data to CUDA
model_tl = model_tl.to('cuda').eval()
model_hf = model_hf.to('cuda').eval()
batch = batch.to('cuda')

# %%
# Run both models
with torch.no_grad():
    logits_hf = model_hf(batch.input_ids).logits
    logits_tl = model_tl(batch.input_ids)

# %%
# Compare logits (optional - to verify models produce similar results)
logits_diff = logits_hf - logits_tl
print(f"Logits difference max: {logits_diff.abs().max()}")
print(f"Logits difference mean: {logits_diff.abs().mean()}")

# %%
# Compare top predictions
top_k_hf = logits_hf.argsort(dim=-1, descending=True)[:, :, :5]
top_k_tl = logits_tl.argsort(dim=-1, descending=True)[:, :, :5]

# Check if top predictions match
top_matches = (top_k_hf == top_k_tl).float().mean()
print(f"Top-5 prediction match rate: {top_matches}")

Before:

Logits difference max: 2.952045440673828
Logits difference mean: 0.05180485174059868
Top-5 prediction match rate: 0.899121105670929

After:

Logits difference max: 0.00018334388732910156
Logits difference mean: 8.470005013805348e-06
Top-5 prediction match rate: 1.0
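
The residual difference after the fix is float32 numerical noise, so the comparison can also be folded into a hard assertion at the end of the script. The tolerances here are assumptions chosen to sit just above the post-fix numbers:

# %%
# Tolerance check: passes with the fix applied, fails without it.
torch.testing.assert_close(logits_tl, logits_hf, rtol=1e-4, atol=1e-3)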

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@bryce13950 bryce13950 changed the base branch from main to dev April 30, 2025 20:38
@bryce13950 bryce13950 merged commit d2f3f15 into TransformerLensOrg:dev Apr 30, 2025
10 of 13 checks passed
bryce13950 added a commit that referenced this pull request May 26, 2025
* moved setup python

* added PR action

* temporarily hardcoded version number

* moved poetry

* Revert "temporarily hardcoded version number"

This reverts commit 23a7be8.

* Revert "added PR action"

This reverts commit b08c241.

* added full hf token authentication (#916)

* Fix Llama RoPE (#910)

Co-authored-by: Bryce Meyer <[email protected]>

* added conditional check for hugging face (#919)

* created a separate list of models to test for public PRs (#920)

* created a separate list of models to test for public PRs

* ran format

* added alternative when hf token is not included (#921)

* shrunk loss test (#922)

* Fix broken test, per issue #913 (#914)

Co-authored-by: Bryce Meyer <[email protected]>

* Fix loading on specific device (#906)

* Fix loading on specific device

* format

---------

Co-authored-by: Bryce Meyer <[email protected]>

* changed dictionary keys to work with the new model loading

* restored old loading

* moved new weight conversion module to new area

* added new boot process with bridge

* isolated all adapters to their own directory

* simplified factory

* got things to boot properly

* updated component mapping

* updated bridge printing

* created a way to print out information of the model

* added gemma 3

* added initial generalized component base

* added new import

* overrode forward and generate function

* finished setting up component adapters

* updated naming of bridge components

* renamed some things

* added more components

* made repr a bit cleaner

* added some more wrapper functionality

* got gemma 3 to run

* generalized output

* added run with cache

* added blocks

* added some typing

* created test for testing translation

* got mapping to work properly

* resolved string issue

* generalized more things a bit more

* injected adapter to generalized component

* allowed returning the last part of the path

* got the model to run again with more generalized components

* passed input through hook point

* removed some print statements

* injected bridge

* updated typing

* translated some more architectures

* created moe

* converted bert

* added remaining component mapping

* removed extra functions

* cleaned up a bit

* cleaned up more conversions

* migrated more architectures

* imported some more components

* additional improvements for new system

* separated types

* removed tl config from new booting

* added default config for some models

* registered additional architectures

* finished generate

* updated architecture

* updated test

* created base config class

* fixed some tests

* reverted type changes

* fixed param in test

* removed extra test and added config file

* moved mock to centralized location

* moved factory to correct spot

* removed extra dataclass

* removed transformers coupling from bridge

* exposed block bridge from directory init

* removed extra comments

* removed abc parent class

* moved a couple things around

* renamed directory

* fixed some refactor issues

* fixed some more refactor issue

* fixed some more issues

* resolved issue

* made transformer lens config closer to hugging face

* fixed some imports

* fixed some more tests

* removed extra test

* fixed test

* removed old class

* restored config and boot

* fixed default names

* removed post init

* removed transformer lens config

* reverted some changes

* removed old conversion step

* removed extra lines

* removed extra line

* removed extra params

* removed extra config

* ran format

* fixed test

* ran format

* fixed test

* ran format

* fixed docstring

* ran format

* fixed some type issues

* fixed some more typing issues

* fixed more mypy errors

* ran format

* fixed more typings

* ran format

* fixed more mypy issues

* ran format

* removed extra test

* fixed test

* fixed some typing

* ran format

---------

Co-authored-by: Mateusz Piotrowski <[email protected]>
Co-authored-by: Jason Benn <[email protected]>
