34 changes: 2 additions & 32 deletions tests/acceptance/test_multi_gpu.py
@@ -4,7 +4,7 @@
import torch

from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.utilities.devices import get_device_for_block_index
from transformer_lens.utilities.devices import get_best_available_device


@pytest.fixture
@@ -19,36 +19,6 @@ def gpt2_medium_on_4_devices():
return model


@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
def test_get_device_for_block_index(gpt2_medium_on_4_devices):
config = gpt2_medium_on_4_devices.cfg
n_layers = config.n_layers
n_devices = config.n_devices
layers_per_device = n_layers // n_devices
config_device = torch.device(config.device)

# Test with default device (config.device)
for i in range(n_layers):
expected_device = torch.device(config_device.type, i // layers_per_device)
assert get_device_for_block_index(i, config) == expected_device

# Test with explicit device
device_override = "cuda"
for i in range(n_layers):
expected_device = torch.device(device_override, i // layers_per_device)
assert get_device_for_block_index(i, config, device_override) == expected_device

# Test with explicit torch.device object
device_override_obj = torch.device("cuda")
for i in range(n_layers):
expected_device = torch.device(device_override_obj.type, i // layers_per_device)
assert get_device_for_block_index(i, config, device_override_obj) == expected_device

# Test when index is out of bounds
# with pytest.raises(IndexError):
# get_device_for_block_index(n_layers, config)


@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
@pytest.mark.parametrize("n_devices", [1, 2, 3, 4])
def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
@@ -85,7 +55,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):

# Make sure the tensors in cache remain on their respective devices
for i in range(model_n_devices.cfg.n_layers):
expected_device = get_device_for_block_index(i, cfg=model_n_devices.cfg)
expected_device = get_best_available_device(model_n_devices.cfg.device)
cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
assert cache_device == expected_device

66 changes: 66 additions & 0 deletions tests/unit/utilities/test_devices.py
@@ -0,0 +1,66 @@
from unittest.mock import Mock

import torch

from transformer_lens.utilities.devices import (
calculate_available_device_cuda_memory,
determine_available_memory_for_available_devices,
sort_devices_based_on_available_memory,
)


def mock_available_devices(memory_stats: list[tuple[int, int]]):
torch.cuda.device_count = Mock(return_value=len(memory_stats))

def device_props_return(*args, **kwargs):
total_memory = memory_stats[args[0]][0]
device_props = Mock()
device_props.total_memory = total_memory
return device_props

def memory_allocated_return(*args, **kwargs):
return memory_stats[args[0]][1]

torch.cuda.get_device_properties = Mock(side_effect=device_props_return)
torch.cuda.memory_allocated = Mock(side_effect=memory_allocated_return)


def test_calculate_available_device_cuda_memory():
mock_available_devices([(80, 40)])

result = calculate_available_device_cuda_memory(0)
assert result == 40


def test_determine_available_memory_for_available_devices():
mock_available_devices(
[
(80, 60),
(80, 15),
(80, 40),
]
)

result = determine_available_memory_for_available_devices(3)

assert result == [
(0, 20),
(1, 65),
(2, 40),
]


def test_sort_devices_based_on_available_memory():
devices = [
(0, 20),
(1, 65),
(2, 40),
]

result = sort_devices_based_on_available_memory(devices)

assert result == [
(1, 65),
(2, 40),
(0, 20),
]
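A hedged sketch (not part of this PR): the mock_available_devices helper above could also drive an end-to-end check of the selection logic added in transformer_lens/utilities/devices.py below. The test name and assertion here are illustrative only.

import torch

from transformer_lens.utilities.devices import get_best_available_cuda_device


def test_get_best_available_cuda_device_prefers_most_free_memory():
    # Reuses the mock_available_devices helper defined above. Device 1 has the
    # most free memory (80 - 15 = 65), so it should be selected.
    mock_available_devices([(80, 60), (80, 15), (80, 40)])

    assert get_best_available_cuda_device(3) == torch.device("cuda", 1)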
14 changes: 7 additions & 7 deletions transformer_lens/HookedTransformer.py
@@ -1091,17 +1091,17 @@ def mps(self):
return self.to("mps")

def move_model_modules_to_device(self):
self.embed.to(devices.get_device_for_block_index(0, self.cfg))
self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg))
self.embed.to(devices.get_best_available_device(self.cfg))
self.hook_embed.to(devices.get_best_available_device(self.cfg))
if self.cfg.positional_embedding_type != "rotary":
self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
self.pos_embed.to(devices.get_best_available_device(self.cfg))
self.hook_pos_embed.to(devices.get_best_available_device(self.cfg))

if hasattr(self, "ln_final"):
self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
self.ln_final.to(devices.get_best_available_device(self.cfg))
self.unembed.to(devices.get_best_available_device(self.cfg))
for i, block in enumerate(self.blocks):
block.to(devices.get_device_for_block_index(i, self.cfg))
block.to(devices.get_best_available_device(self.cfg))

@classmethod
def from_pretrained(
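With this change, devices.get_best_available_device re-checks free memory each time a module is moved, so modules land on whichever CUDA device is least loaded at that moment rather than being assigned by layer index. A hedged sketch for inspecting the resulting placement (assumes a multi-GPU machine and that from_pretrained accepts n_devices, as the acceptance tests above suggest):

import torch

from transformer_lens import HookedTransformer

if torch.cuda.device_count() >= 2:
    model = HookedTransformer.from_pretrained("gpt2-medium", n_devices=torch.cuda.device_count())
    # Print where each transformer block ended up; blocks may spread across GPUs
    # as earlier blocks consume memory on the initially best device.
    for i, block in enumerate(model.blocks):
        print(i, next(block.parameters()).device)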
10 changes: 10 additions & 0 deletions transformer_lens/components/abstract_attention.py
@@ -279,6 +279,12 @@ def forward(
w = einops.rearrange(
self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
)

if self.b_O.device != w.device:
w = w.to(self.b_O.device)
if self.b_O.device != z.device:
z = z.to(self.b_O.device)

out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
@@ -552,6 +558,10 @@ def apply_rotary(
attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
# Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)

if x.device != self.rotary_sin.device:
x = x.to(self.rotary_sin.device)

x_pos = x.size(1)
x_rot = x[..., : self.cfg.rotary_dim]
x_pass = x[..., self.cfg.rotary_dim :]
2 changes: 2 additions & 0 deletions transformer_lens/components/mlps/gated_mlp.py
@@ -50,6 +50,8 @@ def forward(
self, x: Float[torch.Tensor, "batch pos d_model"]
) -> Float[torch.Tensor, "batch pos d_model"]:
# Technically, all these einsums could be done with a single matmul, but this is more readable.
if self.W_gate.device != x.device:
x = x.to(self.W_gate.device)
pre_act = self.hook_pre(
torch.matmul(x, self.W_gate) # batch pos d_model, d_model d_mlp -> batch pos d_mlp
) # [batch, pos, d_mlp]
4 changes: 4 additions & 0 deletions transformer_lens/components/rms_norm.py
@@ -42,4 +42,8 @@ def forward(
(x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
)
x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length]

if x.device != self.w.device:
self.to(x.device)

return x * self.w
4 changes: 4 additions & 0 deletions transformer_lens/components/transformer_block.py
@@ -173,6 +173,10 @@ def forward(
# is added to the residual stream"
attn_out = self.ln1_post(attn_out)
attn_out = self.hook_attn_out(attn_out)

if resid_pre.device != attn_out.device:
resid_pre = resid_pre.to(attn_out.device)

if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model]
mlp_in = (
7 changes: 7 additions & 0 deletions transformer_lens/utilities/attention.py
@@ -15,8 +15,15 @@ def simple_attn_linear(
b: Float[torch.Tensor, "head_index d_head"],
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
"""Linear layer for attention calculation."""

if input.device != w.device:
w = w.to(input.device)
if input.device != b.device:
b = b.to(input.device)

w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")

return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])


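The component hunks above all apply the same defensive pattern: before a matmul, any tensor that may live on a different GPU is moved onto a common device. A standalone, hedged sketch of that pattern (align_devices is an illustrative helper, not a TransformerLens API):

import torch
import torch.nn.functional as F


def align_devices(reference: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Move any tensor not already on the reference tensor's device, so the
    # matmul below cannot fail with a cross-device RuntimeError.
    return tuple(t if t.device == reference.device else t.to(reference.device) for t in tensors)


x = torch.randn(2, 8)  # activations, possibly on one device
w = torch.randn(4, 8)  # weights, possibly on another
b = torch.zeros(4)

w, b = align_devices(x, w, b)
out = F.linear(x, w, b)  # safe: everything now shares x's device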
93 changes: 93 additions & 0 deletions transformer_lens/utilities/devices.py
@@ -13,6 +13,93 @@

import transformer_lens

AvailableDeviceMemory = list[tuple[int, int]]
"""
This type is passed around between different CUDA memory operations.
The first entry of each tuple will be the device index.
The second entry will be how much memory is currently available.
"""


def calculate_available_device_cuda_memory(i: int) -> int:
"""Calculates how much memory is available at this moment for the device at the indicated index

Args:
i (int): The index of the CUDA device to check

Returns:
int: How much memory is available, in bytes
"""
total = torch.cuda.get_device_properties(i).total_memory
allocated = torch.cuda.memory_allocated(i)
return total - allocated


def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
"""Gets all available CUDA devices with their current memory calculated

Returns:
AvailableDeviceMemory: The list of all available devices with memory precalculated
"""
devices = []
for i in range(max_devices):
devices.append((i, calculate_available_device_cuda_memory(i)))

return devices


def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
"""Sorts all available devices with devices with the most available memory returned first

Args:
devices (AvailableDeviceMemory): All available devices with memory calculated

Returns:
AvailableDeviceMemory: The same list of devices, sorted so that those with the most
available memory come first
"""
return sorted(devices, key=lambda x: x[1], reverse=True)


def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
"""Gets whichever cuda device has the most available amount of memory for use

Raises:
EnvironmentError: If there are no available devices, this will error out

Returns:
torch.device: The specific device that should be used
"""
max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
devices = determine_available_memory_for_available_devices(max_devices)

if len(devices) <= 0:
raise EnvironmentError(
"TransformerLens has been configured to use CUDA, but no available devices are present"
)

sorted_devices = sort_devices_based_on_available_memory(devices=devices)

return torch.device("cuda", sorted_devices[0][0])


def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
"""Gets the best available device to be used based on the passed in arguments

Args:
device (Union[torch.device, str]): Either the existing torch device or the string identifier

Returns:
torch.device: The best available device
"""
assert cfg.device is not None
device = torch.device(cfg.device)

if device.type == "cuda":
return get_best_available_cuda_device(cfg.n_devices)
else:
return device


def get_device_for_block_index(
index: int,
@@ -25,6 +112,7 @@ def get_device_for_block_index(
This function assists in distributing model layers across multiple devices. The distribution
is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).


Args:
index (int): Model layer index.
cfg (HookedTransformerConfig): Model and device configuration.
@@ -33,6 +121,11 @@

Returns:
torch.device: The device for the specified layer index.

Deprecated:
This function does not take into account several factors needed for multi-GPU support. You should
now use get_best_available_device in order to run models properly across multiple devices.
It will be removed in 3.0
"""
assert cfg.device is not None
layers_per_device = cfg.n_layers // cfg.n_devices
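Putting the new utilities together, a hedged usage sketch (the config values are illustrative; it assumes HookedTransformerConfig accepts these fields):

import torch

from transformer_lens import HookedTransformerConfig
from transformer_lens.utilities import devices

# Only device and n_devices matter for the selection logic; the rest are
# minimal placeholder values for a valid config.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    n_ctx=32,
    d_head=32,
    act_fn="relu",
    device="cuda" if torch.cuda.is_available() else "cpu",
    n_devices=max(torch.cuda.device_count(), 1),
)

# With a CUDA device configured, this returns the GPU with the most free memory;
# otherwise it simply returns the configured device (e.g. cpu or mps).
best = devices.get_best_available_device(cfg)
print(best)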