Skip to content
Open
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
bcfcd83
updated loading in patchscopes generation demo to use transformer bridge
degenfabian Aug 18, 2025
b6a8339
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Aug 20, 2025
36bda1b
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Aug 22, 2025
b6297b3
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Aug 26, 2025
47730ca
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 4, 2025
205e572
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 5, 2025
925a3ce
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 6, 2025
0344c4b
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 7, 2025
c0cc76c
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 10, 2025
4a3ba3b
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 10, 2025
cbbfd68
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 12, 2025
8ed8a37
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 12, 2025
8816721
Merge remote-tracking branch 'origin/dev-3.x' into patchscopes_genera…
bryce13950 Sep 12, 2025
e476d41
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 10, 2025
283f976
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 13, 2025
50557cc
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 14, 2025
1b06b0c
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 14, 2025
dfb80f5
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 15, 2025
7fc4134
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 15, 2025
281e5eb
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 15, 2025
a91167a
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 16, 2025
dc87645
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 16, 2025
4033489
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 16, 2025
c2e3b8f
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 16, 2025
af3e640
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 16, 2025
f2593ec
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 17, 2025
a3b69ef
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Oct 23, 2025
ac6033b
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Nov 12, 2025
d1a9f75
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Nov 12, 2025
57d3a48
Merge remote-tracking branch 'origin/dev-3.x-folding' into patchscope…
bryce13950 Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions demos/Patchscopes_Generation_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"from typing import List, Callable, Tuple, Union\n",
"from functools import partial\n",
"from jaxtyping import Float\n",
"from transformer_lens import HookedTransformer\n",
"from transformer_lens.model_bridge import TransformerBridge\n",
"from transformer_lens.ActivationCache import ActivationCache\n",
"import transformer_lens.utils as utils\n",
"from transformer_lens.hook_points import (\n",
Expand Down Expand Up @@ -148,7 +148,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -217,7 +217,8 @@
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# I'm using an M2 macbook air, so I use CPU for better support\n",
"model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n",
"model = TransformerBridge.boot_transformers(\"gpt2\", device=\"cpu\")\n",
"model.enable_compatibility_mode()\n",
"model.eval()"
]
},
Expand Down Expand Up @@ -263,17 +264,17 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n",
"def get_source_representation(prompts: List[str], layer_id: int, model: TransformerBridge, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n",
" \"\"\"Get source hidden representation represented by (S, i, M, l)\n",
" \n",
" Args:\n",
" - prompts (List[str]): a list of source prompts\n",
" - layer_id (int): the layer id of the model\n",
" - model (HookedTransformer): the source model\n",
" - model (TransformerBridge): the source model\n",
" - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n",
"\n",
" Returns:\n",
Expand Down Expand Up @@ -325,19 +326,19 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n",
"def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n",
"def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: TransformerBridge, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n",
" \"\"\"Feed the source hidden representation to the target model\n",
" \n",
" Args:\n",
" - source_rep (torch.Tensor): the source hidden representation\n",
" - prompt (List[str]): the target prompt\n",
" - f (Callable): the mapping function\n",
" - model (HookedTransformer): the target model\n",
" - model (TransformerBridge): the target model\n",
" - layer_id (int): the layer id of the target model\n",
" - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n",
" \"\"\"\n",
Expand Down Expand Up @@ -417,11 +418,11 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n",
"def generate_with_patching(model: TransformerBridge, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n",
" temp_prompts = prompts\n",
" input_tokens = model.to_tokens(temp_prompts)\n",
" for _ in range(max_new_tokens):\n",
Expand Down
Loading