Merged
125 commits
5d049c4
moved setup python
bryce13950 Feb 5, 2025
b08c241
added PR action
bryce13950 Feb 5, 2025
23a7be8
temporarily hardcoded version number
bryce13950 Feb 5, 2025
9a3d869
moved poetry
bryce13950 Feb 5, 2025
64e009c
Revert "temporarily hardcoded version number"
bryce13950 Feb 5, 2025
32e5c2f
Revert "added PR action"
bryce13950 Feb 5, 2025
fd38e0f
Merge pull request #855 from TransformerLensOrg/ci-release-update
bryce13950 Feb 5, 2025
d7f9eb1
Merge pull request #866 from TransformerLensOrg/dev
bryce13950 Feb 13, 2025
5e328e9
Merge pull request #870 from TransformerLensOrg/dev
bryce13950 Feb 18, 2025
e65fafb
Merge pull request #874 from TransformerLensOrg/dev
bryce13950 Feb 20, 2025
3212375
added full hf token authentication (#916)
bryce13950 Apr 30, 2025
d2f3f15
Fix LLama RoPE (#910)
mntss Apr 30, 2025
cae0a67
added conditional check for hugging face (#919)
bryce13950 May 5, 2025
c8f8bb4
created a separate list of models to test for public PRs (#920)
bryce13950 May 5, 2025
9f8f4d4
added alternative when hf token is not included (#921)
bryce13950 May 5, 2025
a8fd24f
shrunk loss test (#922)
bryce13950 May 5, 2025
4f97497
Fix broken test, per issue #913 (#914)
JasonBenn May 5, 2025
f6af70f
Fix loading on specific device (#906)
mntss May 5, 2025
26f9e9f
changed dictionary keys to work with the new model loading
bryce13950 May 8, 2025
c1f3c49
restored old loading
bryce13950 May 8, 2025
ca01763
moved new weight conversion module to new area
bryce13950 May 8, 2025
4eb017a
added new boot process with bridge
bryce13950 May 8, 2025
79bb036
isolated all adapters to their own directory
bryce13950 May 8, 2025
5bebdd0
simplified factory
bryce13950 May 8, 2025
e3e8c25
got things to boot properly
bryce13950 May 8, 2025
16803be
updated component mapping
bryce13950 May 8, 2025
a4f7b75
updated bridge printing
bryce13950 May 8, 2025
bb2e6d3
created a way to print out information of the model
bryce13950 May 9, 2025
bb1cd63
added gemma 3
bryce13950 May 9, 2025
2e98787
added initial generalized component base
bryce13950 May 9, 2025
88e204f
added new import
bryce13950 May 9, 2025
20aeba3
overrode forward and generate functions
bryce13950 May 9, 2025
d4d9e9a
finished setting up component adapters
bryce13950 May 9, 2025
baf0b8c
updated naming of bridge components
bryce13950 May 9, 2025
5d3ec52
renamed some things
bryce13950 May 9, 2025
5968549
added more components
bryce13950 May 9, 2025
4b77de4
made repr a bit cleaner
bryce13950 May 9, 2025
7e1131c
added some more wrapper functionality
bryce13950 May 11, 2025
ba13bac
got gemma 3 to run
bryce13950 May 11, 2025
0710900
generalized output
bryce13950 May 12, 2025
cbab325
added run with cache
bryce13950 May 13, 2025
12601de
added blocks
bryce13950 May 14, 2025
442005c
added some typing
bryce13950 May 14, 2025
7fdf11e
created test for translation
bryce13950 May 14, 2025
3e91c29
got mapping to work properly
bryce13950 May 14, 2025
30a371f
resolved string issue
bryce13950 May 14, 2025
5a1d7ea
generalized things a bit more
bryce13950 May 14, 2025
4a00195
injected adapter to generalized component
bryce13950 May 15, 2025
043fcfb
allowed returning the last part of the path
bryce13950 May 15, 2025
a3d2a27
got the model to run again with more generalized components
bryce13950 May 15, 2025
c3ec869
passed input through hook point
bryce13950 May 15, 2025
0217fb6
removed some print statements
bryce13950 May 15, 2025
aa3213b
injected bridge
bryce13950 May 15, 2025
b3bb122
updated typing
bryce13950 May 16, 2025
7bb77e9
translated some more architectures
bryce13950 May 16, 2025
c7b0aa0
created moe
bryce13950 May 16, 2025
0025f96
converted bert
bryce13950 May 16, 2025
55cd13a
added remaining component mapping
bryce13950 May 16, 2025
32403e8
removed extra functions
bryce13950 May 16, 2025
49868d7
cleaned up a bit
bryce13950 May 16, 2025
0bb9f78
cleaned up more conversions
bryce13950 May 16, 2025
c7ed3b7
migrated more architectures
bryce13950 May 16, 2025
8fcdcbc
imported some more components
bryce13950 May 16, 2025
846189a
additional improvements for new system
bryce13950 May 16, 2025
e4c8e87
separated types
bryce13950 May 16, 2025
44a81ba
removed tl config from new booting
bryce13950 May 16, 2025
5b3ccc1
added default config for some models
bryce13950 May 16, 2025
9aaa38b
registered additional architectures
bryce13950 May 16, 2025
1f92f4c
finished generate
bryce13950 May 17, 2025
6af4a5c
updated architecture
bryce13950 May 19, 2025
ca44f53
updated test
bryce13950 May 19, 2025
36717c9
created base config class
bryce13950 May 19, 2025
bdb18a8
fixed some tests
bryce13950 May 19, 2025
32d6a88
reverted type changes
bryce13950 May 19, 2025
77d7abb
fixed param in test
bryce13950 May 19, 2025
e0fa659
removed extra test and added config file
bryce13950 May 19, 2025
ee1306a
moved mock to centralized location
bryce13950 May 19, 2025
9a9ee92
moved factory to correct spot
bryce13950 May 19, 2025
d1829be
removed extra dataclass
bryce13950 May 19, 2025
c791b38
removed transformers coupling from bridge
bryce13950 May 19, 2025
969ac97
exposed block bridge from directory init
bryce13950 May 19, 2025
9df58e3
removed extra comments
bryce13950 May 20, 2025
f67afc9
removed abc parent class
bryce13950 May 20, 2025
613b47f
moved a couple things around
bryce13950 May 22, 2025
e8e1053
renamed directory
bryce13950 May 22, 2025
bbd9420
fixed some refactor issues
bryce13950 May 22, 2025
038a4ee
fixed some more refactor issues
bryce13950 May 22, 2025
fc223f6
fixed some more issues
bryce13950 May 22, 2025
737712c
resolved issue
bryce13950 May 22, 2025
72860b3
made transformer lens config closer to hugging face
bryce13950 May 22, 2025
66849ff
fixed some imports
bryce13950 May 22, 2025
b81e232
fixed some more tests
bryce13950 May 22, 2025
efe980c
removed extra test
bryce13950 May 22, 2025
163a921
fixed test
bryce13950 May 23, 2025
00a381b
removed old class
bryce13950 May 23, 2025
aa49459
restarted config and boot
bryce13950 May 23, 2025
9c48310
fixed default names
bryce13950 May 23, 2025
f6663a3
removed post init
bryce13950 May 24, 2025
dbfddfa
removed transformer lens config
bryce13950 May 24, 2025
c227699
reverted some changes
bryce13950 May 24, 2025
fe51e4a
removed old conversion step
bryce13950 May 24, 2025
f05bd0a
removed extra lines
bryce13950 May 24, 2025
007cbdc
removed extra line
bryce13950 May 24, 2025
b0eb198
removed extra params
bryce13950 May 24, 2025
896f7fe
removed extra config
bryce13950 May 24, 2025
f5e7153
ran format
bryce13950 May 24, 2025
7f808f5
fixed test
bryce13950 May 25, 2025
d848e2f
ran format
bryce13950 May 25, 2025
34ea8c4
fixed test
bryce13950 May 25, 2025
37f1c79
ran format
bryce13950 May 25, 2025
eef1278
fixed docstring
bryce13950 May 25, 2025
83a6278
ran format
bryce13950 May 25, 2025
e19bd4f
fixed some type issues
bryce13950 May 25, 2025
f05d2af
fixed some more typing issues
bryce13950 May 26, 2025
64b453d
fixed more mypy errors
bryce13950 May 26, 2025
7d6bc1a
ran format
bryce13950 May 26, 2025
585a070
fixed more typings
bryce13950 May 26, 2025
8159af4
ran format
bryce13950 May 26, 2025
3240a14
fixed more mypy issues
bryce13950 May 26, 2025
b758c81
ran format
bryce13950 May 26, 2025
3dc45b7
removed extra test
bryce13950 May 26, 2025
5df4570
fixed test
bryce13950 May 26, 2025
68c7687
fixed some typing
bryce13950 May 26, 2025
6d9cbda
ran format
bryce13950 May 26, 2025
e36b98c
Merge branch 'dev' into feature-model-adapter
bryce13950 May 26, 2025
21 changes: 21 additions & 0 deletions .github/workflows/checks.yml
@@ -69,6 +69,13 @@ jobs:
         run: |
           poetry check --lock
           poetry install --with dev
+      - name: Authenticate HuggingFace CLI
+        if: env.HF_TOKEN != ''
+        run: |
+          pip install huggingface_hub
+          huggingface-cli login --token $HF_TOKEN
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Unit Test
         run: make unit-test
         env:
@@ -108,6 +115,13 @@ jobs:
         run: make docstring-test
       - name: Type check
         run: poetry run mypy .
+      - name: Authenticate HuggingFace CLI
+        if: env.HF_TOKEN != ''
+        run: |
+          pip install huggingface_hub
+          huggingface-cli login --token $HF_TOKEN
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Test Suite with Coverage Report
         run: make coverage-report-test
         env:
@@ -198,6 +212,13 @@ jobs:
         with:
           name: test-coverage
           path: docs/source/_static/coverage
+      - name: Authenticate HuggingFace CLI
+        if: env.HF_TOKEN != ''
+        run: |
+          pip install huggingface_hub
+          huggingface-cli login --token $HF_TOKEN
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Build Docs
         run: poetry run build-docs
         env:
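The added step is gated so that forks without the secret still pass: it only runs when HF_TOKEN is non-empty. The same conditional login can be reproduced locally in Python. A minimal sketch, assuming HF_TOKEN is exported in the shell; login() is the huggingface_hub equivalent of the huggingface-cli call above:

    import os

    from huggingface_hub import login

    # Mirror the workflow's `if: env.HF_TOKEN != ''` guard: only authenticate
    # when a token is actually present, otherwise fall back to public models.
    token = os.environ.get("HF_TOKEN", "")
    if token:
        login(token=token)  # same effect as: huggingface-cli login --token $HF_TOKEN
    else:
        print("HF_TOKEN not set; gated models will be unavailable.")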
2 changes: 1 addition & 1 deletion tests/acceptance/test_evals.py
@@ -6,7 +6,7 @@
 
 @pytest.fixture(scope="module")
 def model():
-    return HookedTransformer.from_pretrained("gpt2-small")
+    return HookedTransformer.from_pretrained("gpt2-small", device="cpu")
 
 
 def test_basic_ioi_eval(model):
46 changes: 42 additions & 4 deletions tests/acceptance/test_hooked_transformer.py
@@ -21,7 +21,29 @@
 
 PYTHIA_MODEL_NAMES = [name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia")]
 
-model_names = [
+# Small subsets for basic testing
+TINY_STORIES_SMALL_MODELS = ["roneneldan/TinyStories-1M"]
+PYTHIA_SMALL_MODELS = ["EleutherAI/pythia-70m"]
+
+# Use full lists if HF_TOKEN is available, otherwise use small subsets
+TINY_STORIES_TEST_MODELS = (
+    TINY_STORIES_MODEL_NAMES if os.environ.get("HF_TOKEN", "") else TINY_STORIES_SMALL_MODELS
+)
+PYTHIA_TEST_MODELS = PYTHIA_MODEL_NAMES if os.environ.get("HF_TOKEN", "") else PYTHIA_SMALL_MODELS
+
+# Small models for basic testing
+PUBLIC_MODEL_NAMES = [
+    "attn-only-demo",
+    "gpt2-small",
+    "opt-125m",
+    "pythia-70m",
+    "tiny-stories-33M",
+    "microsoft/phi-1",
+    "google/gemma-2b",
+]
+
+# Full set of models to test
+FULL_MODEL_NAMES = [
     "attn-only-demo",
     "gpt2-small",
     "opt-125m",
@@ -42,7 +64,12 @@
     "google/gemma-2b",
     "google/gemma-7b",
 ]
+
+# Use full model list if HF_TOKEN is available, otherwise use public models only
+model_names = FULL_MODEL_NAMES if os.environ.get("HF_TOKEN", "") else PUBLIC_MODEL_NAMES
+
 text = "Hello world!"
+
 """
 # Code to regenerate loss store
 store = {}
@@ -52,7 +79,15 @@
     store[name] = loss.item()
 print(store)
 """
-loss_store = {
+
+# Loss values for minimal testing
+SMALL_LOSS_STORE = {
+    "gpt2-small": 5.331855773925781,
+    "pythia-70m": 4.659344673156738,
+}
+
+# Full set of loss values
+FULL_LOSS_STORE = {
     "attn-only-demo": 5.701841354370117,
     "gpt2-small": 5.331855773925781,
     "opt-125m": 6.159054279327393,
@@ -69,6 +104,9 @@
     "bloom-560m": 5.237126350402832,
 }
 
+# Use full store if HF_TOKEN is available, otherwise use small store
+loss_store = FULL_LOSS_STORE if os.environ.get("HF_TOKEN", "") else SMALL_LOSS_STORE
+
 no_processing = [
     ("solu-1l", 5.256411552429199),
     (
@@ -534,7 +572,7 @@ def edit_pos_embed(z, hook):
 
 
 def test_all_tinystories_models_exist():
-    for model in TINY_STORIES_MODEL_NAMES:
+    for model in TINY_STORIES_TEST_MODELS:
         try:
             AutoConfig.from_pretrained(model)
         except OSError:
@@ -545,7 +583,7 @@
 
 
 def test_all_pythia_models_exist():
-    for model in PYTHIA_MODEL_NAMES:
+    for model in PYTHIA_TEST_MODELS:
         try:
             AutoConfig.from_pretrained(model)
         except OSError:
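The env-gated lists above shrink the test matrix when no token is available. The same guard can also be applied per test. A minimal sketch of that pattern, assuming pytest; the marker name requires_hf_token is illustrative and not part of this PR:

    import os

    import pytest

    from transformer_lens import HookedTransformer

    # Skip token-dependent tests outright when no HuggingFace token is available.
    requires_hf_token = pytest.mark.skipif(
        not os.environ.get("HF_TOKEN", ""),
        reason="HF_TOKEN not set; gated models cannot be downloaded",
    )


    @requires_hf_token
    def test_gated_model_loads():
        model = HookedTransformer.from_pretrained("google/gemma-2b")
        assert model is not None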
24 changes: 21 additions & 3 deletions tests/acceptance/test_multi_gpu.py
@@ -43,7 +43,10 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
     loss_n_devices = model_n_devices(model_description_text, return_type="loss")
     elapsed_time_n_devices = time.time() - start_time_n_devices
 
-    gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
+    gpt2_text = (
+        "Natural language processing tasks, such as question answering, machine translation, reading comprehension, "
+        "and summarization, are typically approached with supervised learning on taskspecific datasets."
+    )
     gpt2_tokens = model_1_device.to_tokens(gpt2_text)
 
     gpt2_logits_1_device, gpt2_cache_1_device = model_1_device.run_with_cache(
@@ -55,7 +58,7 @@
 
     # Make sure the tensors in cache remain on their respective devices
     for i in range(model_n_devices.cfg.n_layers):
-        expected_device = get_best_available_device(model_n_devices.cfg.device)
+        expected_device = get_best_available_device(model_n_devices.cfg)
         cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
         assert cache_device == expected_device
 
@@ -80,10 +83,25 @@
     assert prop_device == pytest.approx(expected_prop_device, rel=0.20)
 
     print(
-        f"Number of devices: {n_devices}, Model loss (1 device): {loss_1_device}, Model loss ({n_devices} devices): {loss_n_devices}, Time taken (1 device): {elapsed_time_1_device:.4f} seconds, Time taken ({n_devices} devices): {elapsed_time_n_devices:.4f} seconds"
+        f"Number of devices: {n_devices}, Model loss (1 device): {loss_1_device}, Model loss ({n_devices} devices): {loss_n_devices}, "
+        f"Time taken (1 device): {elapsed_time_1_device:.4f} seconds, Time taken ({n_devices} devices): {elapsed_time_n_devices:.4f} seconds"
     )
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
+def test_load_model_on_target_device():
+    model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")
+    assert model.cfg.device == "cuda:1"
+
+    for name, param in model.named_parameters():
+        assert param.device == torch.device(
+            "cuda:1"
+        ), f"Parameter {name} is on {param.device} instead of cuda:1"
+
+    output = model("Hello world")
+    assert output.device == torch.device("cuda:1")
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
 def test_cache_device():
     model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")
14 changes: 12 additions & 2 deletions tests/acceptance/test_tokenizer_special_tokens.py
@@ -1,10 +1,15 @@
+import os
+
 from transformers import AutoTokenizer
 
 import transformer_lens.loading_from_pretrained as loading
 from transformer_lens import HookedTransformer, HookedTransformerConfig
 
-# Get's tedious typing these out everytime I want to sweep over all the distinct small models
-MODEL_TESTING_LIST = [
+# Small models for basic testing
+PUBLIC_MODEL_TESTING_LIST = ["gpt2-small", "opt-125m", "pythia-70m"]
+
+# Full set of models to test when HF_TOKEN is available
+FULL_MODEL_TESTING_LIST = [
     "solu-1l",
     "gpt2-small",
     "gpt-neo-125M",
@@ -14,6 +19,11 @@
     "pythia-70m",
 ]
 
+# Use full model list if HF_TOKEN is available, otherwise use public models only
+MODEL_TESTING_LIST = (
+    FULL_MODEL_TESTING_LIST if os.environ.get("HF_TOKEN", "") else PUBLIC_MODEL_TESTING_LIST
+)
+
 
 def test_d_vocab_from_tokenizer():
     cfg = HookedTransformerConfig(
114 changes: 114 additions & 0 deletions tests/integration/model_bridge/test_bridge.py
@@ -0,0 +1,114 @@
"""Integration tests for the model bridge functionality.

This module contains tests that verify the core functionality of the model bridge,
including model initialization, text generation, hooks, and caching.
"""

import pytest
import torch

from transformer_lens.boot import boot


def test_model_initialization():
"""Test that the model can be initialized correctly."""
model_name = "gpt2" # Use a smaller model for testing
bridge = boot(model_name)

assert bridge is not None, "Bridge should be initialized"
assert bridge.tokenizer is not None, "Tokenizer should be initialized"
assert isinstance(bridge.model, torch.nn.Module), "Model should be a PyTorch module"


def test_text_generation():
"""Test basic text generation functionality."""
model_name = "gpt2" # Use a smaller model for testing
bridge = boot(model_name)

prompt = "The quick brown fox jumps over the lazy dog"
output = bridge.generate(prompt, max_new_tokens=10)

assert isinstance(output, str), "Output should be a string"
assert len(output) > len(prompt), "Generated text should be longer than the prompt"


def test_hooks():
"""Test that hooks can be added and removed correctly."""
model_name = "gpt2" # Use a smaller model for testing
bridge = boot(model_name)

# Track if hook was called
hook_called = False

def test_hook(tensor, hook):
nonlocal hook_called
hook_called = True
return tensor

# Add hook to first attention layer
hook_name = "blocks.0.attn"
bridge.blocks[0].attn.add_hook(test_hook)

# Run model
prompt = "Test prompt"
bridge.generate(prompt, max_new_tokens=1)

# Verify hook was called
assert hook_called, "Hook should have been called"

# Remove hook
bridge.blocks[0].attn.remove_hooks()
hook_called = False

# Run model again
bridge.generate(prompt, max_new_tokens=1)

# Verify hook was not called
assert not hook_called, "Hook should not have been called after removal"


def test_cache():
"""Test that the cache functionality works correctly."""
model_name = "gpt2" # Use a smaller model for testing
bridge = boot(model_name)

prompt = "Test prompt"
output, cache = bridge.run_with_cache(prompt)

# Verify output and cache
assert isinstance(output, torch.Tensor), "Output should be a tensor"
assert isinstance(cache, dict), "Cache should be a dictionary"
assert len(cache) > 0, "Cache should contain activations"

# Verify cache contains some expected keys (using actual HuggingFace model structure)
# The exact keys depend on the model architecture, but we should have some basic ones
cache_keys = list(cache.keys())
assert any("wte" in key for key in cache_keys), "Cache should contain word token embeddings"
assert any("ln_f" in key for key in cache_keys), "Cache should contain final layer norm"
assert any("lm_head" in key for key in cache_keys), "Cache should contain language model head"

# Verify that cached tensors are actually tensors
for key, value in cache.items():
assert isinstance(value, torch.Tensor), f"Cache value for {key} should be a tensor"


def test_component_access():
"""Test that model components can be accessed correctly."""
model_name = "gpt2" # Use a smaller model for testing
bridge = boot(model_name)

# Test accessing various components
assert hasattr(bridge, "embed"), "Bridge should have embed component"
assert hasattr(bridge, "blocks"), "Bridge should have blocks component"
assert hasattr(bridge, "unembed"), "Bridge should have unembed component"

# Test accessing block components
block = bridge.blocks[0]
assert hasattr(block, "attn"), "Block should have attention component"
assert hasattr(block, "mlp"), "Block should have MLP component"
assert hasattr(block, "ln1"), "Block should have first layer norm"
assert hasattr(block, "ln2"), "Block should have second layer norm"


if __name__ == "__main__":
pytest.main([__file__])
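Taken together, these tests outline the intended user workflow for the bridge. A condensed usage sketch, limited to the API surface exercised above (boot, generate, run_with_cache, and per-component add_hook/remove_hooks); exact cache keys depend on the wrapped architecture:

    import torch

    from transformer_lens.boot import boot

    bridge = boot("gpt2")

    # Plain generation through the bridged HuggingFace model.
    print(bridge.generate("The quick brown fox", max_new_tokens=10))

    # Capture an attention activation with a component-level hook.
    captured = {}

    def save_attn(tensor, hook):
        captured["blocks.0.attn"] = tensor.detach()
        return tensor

    bridge.blocks[0].attn.add_hook(save_attn)
    logits, cache = bridge.run_with_cache("Hello world")
    bridge.blocks[0].attn.remove_hooks()

    assert isinstance(logits, torch.Tensor)
    print(list(cache.keys())[:5])  # e.g. keys containing wte, ln_f, lm_head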

This file was deleted.

36 changes: 36 additions & 0 deletions tests/mocks/models.py
@@ -0,0 +1,36 @@
"""Mock models for testing."""

import torch.nn as nn


class MockGemma3Model(nn.Module):
"""A mock implementation of the Gemma 3 model architecture for testing purposes.

This mock model replicates the key architectural components of Gemma 3:
- Embedding layer (embed_tokens)
- Multiple transformer layers with:
- Input and post-attention layer norms
- Self-attention with Q, K, V, O projections
- MLP with up, gate, and down projections
- Final layer norm
"""

def __init__(self):
super().__init__()
self.model = nn.Module()
self.model.embed_tokens = nn.Embedding(1000, 512)
self.model.layers = nn.ModuleList([nn.Module() for _ in range(2)])
for layer in self.model.layers:
layer.input_layernorm = nn.LayerNorm(512)
layer.post_attention_layernorm = nn.LayerNorm(512)
layer.self_attn = nn.Module()
layer.self_attn.q_proj = nn.Linear(512, 512)
layer.self_attn.k_proj = nn.Linear(512, 512)
layer.self_attn.v_proj = nn.Linear(512, 512)
layer.self_attn.o_proj = nn.Linear(512, 512)
layer.mlp = nn.Module()
layer.mlp.up_proj = nn.Linear(512, 2048)
layer.mlp.gate_proj = nn.Linear(512, 2048)
layer.mlp.down_proj = nn.Linear(2048, 512)
self.model.norm = nn.LayerNorm(512)
self.embed_tokens = self.model.embed_tokens # For shared embedding/unembedding
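A short sketch of how the mock might be exercised in a unit test; the test name and assertions are illustrative, only the module layout comes from the file above:

    from tests.mocks.models import MockGemma3Model

    def test_mock_gemma3_layout():
        model = MockGemma3Model()

        # Two transformer layers mirroring Gemma 3's HuggingFace module paths.
        assert len(model.model.layers) == 2
        first = model.model.layers[0]
        assert first.self_attn.q_proj.in_features == 512
        assert first.mlp.down_proj.out_features == 512
        # Embedding is shared between embedding and unembedding.
        assert model.embed_tokens is model.model.embed_tokens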