diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index d78df7471..ae20258f8 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -229,7 +229,7 @@ jobs:
           - "BERT"
           - "Exploratory_Analysis_Demo"
           # - "Grokking_Demo"
-          # - "Head_Detector_Demo"
+          - "Head_Detector_Demo"
           # - "Interactive_Neuroscope"
           # - "LLaMA"
           # - "LLaMA2_GPU_Quantized"
diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb
index 33c9b09d8..6f27d1af0 100644
--- a/demos/Head_Detector_Demo.ipynb
+++ b/demos/Head_Detector_Demo.ipynb
@@ -59,7 +59,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
@@ -316,7 +316,7 @@
     "    ipython.magic(\"autoreload 2\")\n",
     "\n",
     "if IN_COLAB or IN_GITHUB:\n",
-    "    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
+    "    %pip install transformer_lens\n",
     "    # Install Neel's personal plotting utils\n",
     "    %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
     "    # Install another version of node that makes PySvelte work way faster\n",
@@ -329,7 +329,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {
     "id": "LBjE0qm6Ahyf"
    },
@@ -338,7 +338,7 @@
     "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
     "import plotly.io as pio\n",
     "\n",
-    "if IN_COLAB or not DEBUG_MODE:\n",
+    "if IN_COLAB or not DEVELOPMENT_MODE:\n",
     "    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n",
     "    pio.renderers.default = \"colab\"\n",
     "else:\n",
@@ -347,7 +347,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {
     "id": "ScWILAgIGt5O"
    },
@@ -359,7 +359,8 @@
     "from tqdm import tqdm\n",
     "\n",
     "import transformer_lens\n",
-    "from transformer_lens import HookedTransformer, ActivationCache\n",
+    "from transformer_lens import ActivationCache\n",
+    "from transformer_lens.model_bridge import TransformerBridge\n",
     "from neel_plotly import line, imshow, scatter"
    ]
   },
@@ -479,7 +480,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {
     "id": "5ikyL8-S7u2Z"
    },
@@ -493,7 +494,6 @@
     "import numpy as np\n",
     "import torch\n",
     "\n",
-    "from transformer_lens import HookedTransformer, ActivationCache\n",
     "# from transformer_lens.utils import is_lower_triangular, is_square\n",
     "\n",
     "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n",
@@ -515,7 +515,7 @@
     "\n",
     "\n",
     "def detect_head(\n",
-    "    model: HookedTransformer,\n",
+    "    model: TransformerBridge,\n",
     "    seq: Union[str, List[str]],\n",
     "    detection_pattern: Union[torch.Tensor, HeadName],\n",
     "    heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n",
@@ -566,14 +566,16 @@
     "    --------\n",
     "    .. code-block:: python\n",
     "\n",
-    "        >>> from transformer_lens import HookedTransformer, utils\n",
+    "        >>> from transformer_lens import utils\n",
+    "        >>> from transformer_lens.model_bridge import TransformerBridge\n",
     "        >>> from transformer_lens.head_detector import detect_head\n",
     "        >>> import plotly.express as px\n",
     "\n",
     "        >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
     "        >>>     px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
     "\n",
-    "        >>> model = HookedTransformer.from_pretrained(\"gpt2-small\")\n",
+    "        >>> model = TransformerBridge.boot_transformers(\"gpt2\")\n",
+    "        >>> model.enable_compatibility_mode()\n",
     "        >>> sequence = \"This is a test sequence. This is a test sequence.\"\n",
     "\n",
     "        >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n",
@@ -777,7 +779,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
@@ -802,7 +804,8 @@
     }
    ],
    "source": [
-    "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)"
+    "model = TransformerBridge.boot_transformers(\"gpt2\", device=device)\n",
+    "model.enable_compatibility_mode()"
    ]
   },
   {
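Taken together, the notebook hunks apply a single migration pattern: drop the HookedTransformer import, load the model through TransformerBridge, and turn on compatibility mode so the HookedTransformer-style interface that detect_head and the rest of the demo rely on keeps working. A minimal sketch of the new loading path, using only the calls that appear in the hunks above (boot_transformers and enable_compatibility_mode); the device-selection line is an illustrative assumption, since the notebook defines device in a cell not shown in this diff:

    import torch
    from transformer_lens.model_bridge import TransformerBridge

    # Illustrative assumption: the demo sets up `device` in an earlier cell.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # boot_transformers loads the Hugging Face checkpoint ("gpt2" here,
    # replacing the "gpt2-small" alias previously passed to HookedTransformer).
    model = TransformerBridge.boot_transformers("gpt2", device=device)

    # Compatibility mode exposes the HookedTransformer-style API that
    # detect_head() and the demo's caching/plotting cells expect.
    model.enable_compatibility_mode()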