74 changes: 44 additions & 30 deletions poetry.lock

Some generated files are not rendered by default.

15 changes: 1 addition & 14 deletions tests/unit/test_loading_from_pretrained_utilities.py
@@ -1,7 +1,5 @@
 from unittest import mock
 
-import pytest
-
 from transformer_lens import HookedTransformer
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 from transformer_lens.loading_from_pretrained import fill_missing_keys
@@ -17,7 +15,7 @@ def get_default_config():
 
 
 @mock.patch("logging.warning")
-def test_fill_missing_keys(mock_warning):
+def test_fill_missing_keys(mock_warning: mock.MagicMock):
     cfg = get_default_config()
     model = HookedTransformer(cfg)
     default_state_dict = model.state_dict()
@@ -59,14 +57,3 @@ def test_fill_missing_keys_no_missing_keys():
     filled_state_dict = fill_missing_keys(model, default_state_dict)
 
     assert filled_state_dict == default_state_dict
-
-
-# Failures
-
-
-def test_fill_missing_keys_raises_error_on_invalid_model():
-    invalid_model = None
-    default_state_dict = {}
-
-    with pytest.raises(AttributeError):
-        fill_missing_keys(invalid_model, default_state_dict)
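As an aside on the pattern the updated test relies on: `@mock.patch("logging.warning")` injects a `MagicMock` as the test's first argument, which the new `mock_warning: mock.MagicMock` annotation simply makes explicit. A minimal, self-contained sketch of that pattern (the `load_config` function below is hypothetical, not part of transformer_lens):

import logging
from unittest import mock


def load_config(path):
    # Hypothetical loader that warns and falls back to defaults when the file is missing.
    logging.warning("Missing config at %s, using defaults", path)
    return {}


@mock.patch("logging.warning")
def test_load_config_warns(mock_warning: mock.MagicMock):
    load_config("does/not/exist.yaml")
    # The patched logging.warning is a MagicMock that records every call made to it.
    mock_warning.assert_called_once()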
5 changes: 3 additions & 2 deletions transformer_lens/ActivationCache.py
@@ -890,11 +890,12 @@ def stack_neuron_results(
         if not isinstance(pos_slice, Slice):
             pos_slice = Slice(pos_slice)
 
-        neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply(
+        neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
             torch.arange(self.model.cfg.d_mlp), dim=0
         )
-        if type(neuron_labels) == int:
+        if isinstance(neuron_labels, int):
             neuron_labels = np.array([neuron_labels])
+
         for l in range(layer):
             # Note that this has shape batch x pos x head_index x d_model
             components.append(
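Two small things change in this hunk: the annotation is spelled with `Union[...]` rather than PEP 604's `X | Y` (which requires Python 3.10 wherever the expression is actually evaluated), and the type check switches to `isinstance`, which also accepts subclasses of `int`. A short illustrative sketch, using a hypothetical helper rather than library code:

from typing import Union

import numpy as np
import torch


def as_label_array(labels: Union[torch.Tensor, np.ndarray, int]) -> np.ndarray:
    # Mirrors the neuron_labels normalisation above: a bare int becomes a 1-element array.
    if isinstance(labels, int):  # also matches int subclasses such as bool; type(labels) == int would not
        return np.array([labels])
    if isinstance(labels, torch.Tensor):
        return labels.cpu().numpy()
    return labels


print(as_label_array(5))                # [5]
print(as_label_array(torch.arange(3)))  # [0 1 2]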
22 changes: 13 additions & 9 deletions transformer_lens/HookedEncoder.py
@@ -8,7 +8,7 @@
 
 import logging
 import os
-from typing import Dict, List, Optional, Tuple, Union, cast, overload
+from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast, overload
 
 import torch
 from einops import repeat
@@ -32,6 +32,8 @@
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 from transformer_lens.utilities import devices
 
+T = TypeVar("T", bound="HookedEncoder")
+
 
 class HookedEncoder(HookedRootModule):
     """
@@ -332,17 +334,19 @@ def to( # type: ignore
     ):
         return devices.move_to_and_update_config(self, device_or_dtype, print_details)
 
-    def cuda(self):
-        # Wrapper around cuda that also changes self.cfg.device
-        return self.to("cuda")
+    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
+        if isinstance(device, int):
+            return self.to(f"cuda:{device}")
+        elif device is None:
+            return self.to("cuda")
+        else:
+            return self.to(device)
 
-    def cpu(self):
-        # Wrapper around cuda that also changes self.cfg.device
+    def cpu(self: T) -> T:
         return self.to("cpu")
 
-    def mps(self):
-        # Wrapper around cuda that also changes self.cfg.device
-        return self.to("mps")
+    def mps(self: T) -> T:
+        return self.to(torch.device("mps"))
 
     @classmethod
     def from_pretrained(
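The `T = TypeVar("T", bound="HookedEncoder")` plus `self: T ... -> T` annotations exist so that type checkers see `.cuda()`, `.cpu()` and `.mps()` returning the concrete (sub)class they were called on, while the optional `device` argument mirrors `torch.nn.Module.cuda`. A self-contained sketch of the pattern, using stand-in class names rather than TransformerLens classes:

from typing import Optional, TypeVar, Union

import torch

T = TypeVar("T", bound="Base")


class Base(torch.nn.Module):
    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        # An int selects a specific GPU, None means the default CUDA device,
        # and a torch.device is passed through unchanged.
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        return self.to("cpu")


class Child(Base):
    pass


# Because T is bound and returned, a type checker infers Child here rather than Base.
model: Child = Child().cpu()

HookedEncoderDecoder and HookedTransformer below apply the same TypeVar pattern.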
22 changes: 13 additions & 9 deletions transformer_lens/HookedEncoderDecoder.py
@@ -10,7 +10,7 @@
 import os
 from itertools import chain
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union, cast, overload
+from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast, overload
 
 import torch
 import tqdm
@@ -29,6 +29,8 @@
 from transformer_lens.utilities import devices
 from transformer_lens.utils import sample_logits
 
+T = TypeVar("T", bound="HookedEncoderDecoder")
+
 
 class HookedEncoderDecoder(HookedRootModule):
     """
@@ -507,17 +509,19 @@ def to( # type: ignore
     ):
         return devices.move_to_and_update_config(self, device_or_dtype, print_details)
 
-    def cuda(self):
-        # Wrapper around cuda that also changes self.cfg.device
-        return self.to("cuda")
+    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
+        if isinstance(device, int):
+            return self.to(f"cuda:{device}")
+        elif device is None:
+            return self.to("cuda")
+        else:
+            return self.to(device)
 
-    def cpu(self):
-        # Wrapper around cuda that also changes self.cfg.device
+    def cpu(self: T) -> T:
         return self.to("cpu")
 
-    def mps(self):
-        # Wrapper around cuda that also changes self.cfg.device
-        return self.to("mps")
+    def mps(self: T) -> T:
+        return self.to(torch.device("mps"))
 
     @classmethod
     def from_pretrained(
23 changes: 14 additions & 9 deletions transformer_lens/HookedTransformer.py
@@ -9,6 +9,8 @@
 a deeper understanding of the internal workings of transformers like GPT-2.
 """
 
+from __future__ import annotations
+
 import logging
 import os
 from typing import (
@@ -1079,17 +1081,20 @@ def to( # type: ignore
     ):
         return devices.move_to_and_update_config(self, device_or_dtype, print_details)
 
-    def cuda(self):
-        """Wrapper around cuda that also changes `self.cfg.device`."""
-        return self.to("cuda")
+    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
+        # TODO: Add support for kwargs
+        if isinstance(device, int):
+            return self.to(f"cuda:{device}")
+        elif device is None:
+            return self.to("cuda")
+        else:
+            return self.to(device)
 
-    def cpu(self):
-        """Wrapper around cuda that also changes `self.cfg.device`."""
-        return self.to("cpu")
+    def cpu(self: T) -> T:
+        return self.to(torch.device("cpu"))
 
-    def mps(self):
-        """Wrapper around mps that also changes `self.cfg.device`."""
-        return self.to("mps")
+    def mps(self: T) -> T:
+        return self.to(torch.device("mps"))
 
     def move_model_modules_to_device(self):
         self.embed.to(devices.get_best_available_device(self.cfg))
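A hedged usage sketch of the new signatures (the model name is only an example; as with the `to` wrapper above, these calls route through `devices.move_to_and_update_config`, so `model.cfg.device` is kept in sync):

import torch

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

if torch.cuda.is_available():
    model = model.cuda()                        # same as model.to("cuda")
    model = model.cuda(0)                       # int index maps to model.to("cuda:0")
    model = model.cuda(torch.device("cuda:0"))  # a torch.device is passed through unchanged
else:
    model = model.cpu()                         # now model.to(torch.device("cpu"))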