
Commit c8f8bb4

created a separate list of models to test for public PRs (#920)

* created a separate list of models to test for public PRs
* ran format
1 parent cae0a67 commit c8f8bb4
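In short, the change gates which model list the acceptance tests sweep over on whether a Hugging Face token is present in the environment. A minimal sketch of that selection pattern, using placeholder list contents (the full entries appear in the diffs below):

import os

# Placeholder lists for illustration; the actual entries are in the diffs below.
PUBLIC_MODEL_NAMES = ["gpt2-small", "opt-125m", "pythia-70m"]
FULL_MODEL_NAMES = PUBLIC_MODEL_NAMES + ["google/gemma-2b", "google/gemma-7b"]

# Public PRs run without secrets, so HF_TOKEN is empty and only the small public
# models are tested; CI runs that have an HF token exercise the full list,
# including gated models.
model_names = FULL_MODEL_NAMES if os.environ.get("HF_TOKEN", "") else PUBLIC_MODEL_NAMES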

File tree

3 files changed: +30 -4 lines


tests/acceptance/test_evals.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 @pytest.fixture(scope="module")
 def model():
-    return HookedTransformer.from_pretrained("gpt2-small")
+    return HookedTransformer.from_pretrained("gpt2-small", device="cpu")
 
 
 def test_basic_ioi_eval(model):

tests/acceptance/test_hooked_transformer.py

Lines changed: 17 additions & 1 deletion
@@ -21,7 +21,19 @@
 
 PYTHIA_MODEL_NAMES = [name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia")]
 
-model_names = [
+# Small models for basic testing
+PUBLIC_MODEL_NAMES = [
+    "attn-only-demo",
+    "gpt2-small",
+    "opt-125m",
+    "pythia-70m",
+    "tiny-stories-33M",
+    "microsoft/phi-1",
+    "google/gemma-2b",
+]
+
+# Full set of models to test
+FULL_MODEL_NAMES = [
     "attn-only-demo",
     "gpt2-small",
     "opt-125m",
@@ -42,6 +54,10 @@
     "google/gemma-2b",
     "google/gemma-7b",
 ]
+
+# Use full model list if HF_TOKEN is available, otherwise use public models only
+model_names = FULL_MODEL_NAMES if os.environ.get("HF_TOKEN", "") else PUBLIC_MODEL_NAMES
+
 text = "Hello world!"
 """
 # Code to regenerate loss store

tests/acceptance/test_tokenizer_special_tokens.py

Lines changed: 12 additions & 2 deletions
@@ -1,10 +1,15 @@
+import os
+
 from transformers import AutoTokenizer
 
 import transformer_lens.loading_from_pretrained as loading
 from transformer_lens import HookedTransformer, HookedTransformerConfig
 
-# Get's tedious typing these out everytime I want to sweep over all the distinct small models
-MODEL_TESTING_LIST = [
+# Small models for basic testing
+PUBLIC_MODEL_TESTING_LIST = ["gpt2-small", "opt-125m", "pythia-70m"]
+
+# Full set of models to test when HF_TOKEN is available
+FULL_MODEL_TESTING_LIST = [
     "solu-1l",
     "gpt2-small",
     "gpt-neo-125M",
@@ -14,6 +19,11 @@
     "pythia-70m",
 ]
 
+# Use full model list if HF_TOKEN is available, otherwise use public models only
+MODEL_TESTING_LIST = (
+    FULL_MODEL_TESTING_LIST if os.environ.get("HF_TOKEN", "") else PUBLIC_MODEL_TESTING_LIST
+)
+
 
 def test_d_vocab_from_tokenizer():
     cfg = HookedTransformerConfig(

0 commit comments
