Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
30a09b3
updating loading in head detector demo to use transformer bridge
degenfabian Aug 18, 2025
c3c8d82
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Aug 20, 2025
6186819
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Aug 22, 2025
4829e3e
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Aug 26, 2025
b375e94
updated name
bryce13950 Aug 26, 2025
d7a6591
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 4, 2025
a96522e
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 5, 2025
d54b0de
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 6, 2025
b47ecbf
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 7, 2025
1dfe986
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 10, 2025
e106bdb
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 10, 2025
28a4b30
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 12, 2025
351359e
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 12, 2025
8f3dc89
Merge remote-tracking branch 'origin/dev-3.x' into head_detector_demo…
bryce13950 Sep 12, 2025
ea37933
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 10, 2025
542295a
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 13, 2025
116cb28
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 14, 2025
f9623b0
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 14, 2025
0d0debb
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 15, 2025
7bacf8d
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 15, 2025
7293287
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 15, 2025
e789e73
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 16, 2025
0f358ba
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 16, 2025
ce03721
updated installation source
bryce13950 Oct 16, 2025
291a7c3
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 16, 2025
d2fca2a
removed interactive from ci
bryce13950 Oct 16, 2025
580b6ae
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 16, 2025
309f372
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 16, 2025
64f3f5c
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 17, 2025
9b27323
Merge remote-tracking branch 'origin/dev-3.x-folding' into head_detec…
bryce13950 Oct 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ jobs:
- "BERT"
- "Exploratory_Analysis_Demo"
# - "Grokking_Demo"
# - "Head_Detector_Demo"
- "Head_Detector_Demo"
# - "Interactive_Neuroscope"
# - "LLaMA"
# - "LLaMA2_GPU_Quantized"
Expand Down
29 changes: 16 additions & 13 deletions demos/Head_Detector_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down Expand Up @@ -316,7 +316,7 @@
" ipython.magic(\"autoreload 2\")\n",
"\n",
"if IN_COLAB or IN_GITHUB:\n",
" %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
" %pip install transformer_lens\n",
" # Install Neel's personal plotting utils\n",
" %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
" # Install another version of node that makes PySvelte work way faster\n",
Expand All @@ -329,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"id": "LBjE0qm6Ahyf"
},
Expand All @@ -338,7 +338,7 @@
"# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
"import plotly.io as pio\n",
"\n",
"if IN_COLAB or not DEBUG_MODE:\n",
"if IN_COLAB or not DEVELOPMENT_MODE:\n",
"    # Plotly graphics show up in either Colab or VSCode depending on the renderer, which is bad for developing demos - hence the development mode switch.\n",
" pio.renderers.default = \"colab\"\n",
"else:\n",
Expand All @@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "ScWILAgIGt5O"
},
Expand All @@ -359,7 +359,8 @@
"from tqdm import tqdm\n",
"\n",
"import transformer_lens\n",
"from transformer_lens import HookedTransformer, ActivationCache\n",
"from transformer_lens import ActivationCache\n",
"from transformer_lens.model_bridge import TransformerBridge\n",
"from neel_plotly import line, imshow, scatter"
]
},
Expand Down Expand Up @@ -479,7 +480,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "5ikyL8-S7u2Z"
},
Expand All @@ -493,7 +494,6 @@
"import numpy as np\n",
"import torch\n",
"\n",
"from transformer_lens import HookedTransformer, ActivationCache\n",
"# from transformer_lens.utils import is_lower_triangular, is_square\n",
"\n",
"HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n",
Expand All @@ -515,7 +515,7 @@
"\n",
"\n",
"def detect_head(\n",
" model: HookedTransformer,\n",
" model: TransformerBridge,\n",
" seq: Union[str, List[str]],\n",
" detection_pattern: Union[torch.Tensor, HeadName],\n",
" heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n",
Expand Down Expand Up @@ -566,14 +566,16 @@
" --------\n",
" .. code-block:: python\n",
"\n",
" >>> from transformer_lens import HookedTransformer, utils\n",
" >>> from transformer_lens import utils\n",
" >>> from transformer_lens.model_bridge import TransformerBridge\n",
" >>> from transformer_lens.head_detector import detect_head\n",
" >>> import plotly.express as px\n",
"\n",
" >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
" >>> px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
"\n",
" >>> model = HookedTransformer.from_pretrained(\"gpt2-small\")\n",
" >>> model = TransformerBridge.boot_transformers(\"gpt2\")\n",
" >>> model.enable_compatibility_mode()\n",
" >>> sequence = \"This is a test sequence. This is a test sequence.\"\n",
"\n",
" >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n",
Expand Down Expand Up @@ -777,7 +779,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand All @@ -802,7 +804,8 @@
}
],
"source": [
"model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)"
"model = TransformerBridge.boot_transformers(\"gpt2\", device=device)\n",
"model.enable_compatibility_mode()"
]
},
{
Expand Down
Loading