diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index f5f082c33..42752e7bf 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -4,7 +4,7 @@
 import torch
 
 from transformer_lens.HookedTransformer import HookedTransformer
-from transformer_lens.utilities.devices import get_device_for_block_index
+from transformer_lens.utilities.devices import get_best_available_device
 
 
 @pytest.fixture
@@ -19,36 +19,6 @@ def gpt2_medium_on_4_devices():
     return model
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
-def test_get_device_for_block_index(gpt2_medium_on_4_devices):
-    config = gpt2_medium_on_4_devices.cfg
-    n_layers = config.n_layers
-    n_devices = config.n_devices
-    layers_per_device = n_layers // n_devices
-    config_device = torch.device(config.device)
-
-    # Test with default device (config.device)
-    for i in range(n_layers):
-        expected_device = torch.device(config_device.type, i // layers_per_device)
-        assert get_device_for_block_index(i, config) == expected_device
-
-    # Test with explicit device
-    device_override = "cuda"
-    for i in range(n_layers):
-        expected_device = torch.device(device_override, i // layers_per_device)
-        assert get_device_for_block_index(i, config, device_override) == expected_device
-
-    # Test with explicit torch.device object
-    device_override_obj = torch.device("cuda")
-    for i in range(n_layers):
-        expected_device = torch.device(device_override_obj.type, i // layers_per_device)
-        assert get_device_for_block_index(i, config, device_override_obj) == expected_device
-
-    # Test when index is out of bounds
-    # with pytest.raises(IndexError):
-    #     get_device_for_block_index(n_layers, config)
-
-
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
 @pytest.mark.parametrize("n_devices", [1, 2, 3, 4])
 def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
@@ -85,7 +55,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
 
     # Make sure the tensors in cache remain on their respective devices
     for i in range(model_n_devices.cfg.n_layers):
-        expected_device = get_device_for_block_index(i, cfg=model_n_devices.cfg)
+        expected_device = get_best_available_device(model_n_devices.cfg)
         cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
         assert cache_device == expected_device
diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py
new file mode 100644
index 000000000..a2a02a119
--- /dev/null
+++ b/tests/unit/utilities/test_devices.py
@@ -0,0 +1,66 @@
+from unittest.mock import Mock
+
+import torch
+
+from transformer_lens.utilities.devices import (
+    calculate_available_device_cuda_memory,
+    determine_available_memory_for_available_devices,
+    sort_devices_based_on_available_memory,
+)
+
+
+def mock_available_devices(memory_stats: list[tuple[int, int]]):
+    torch.cuda.device_count = Mock(return_value=len(memory_stats))
+
+    def device_props_return(*args, **kwargs):
+        total_memory = memory_stats[args[0]][0]
+        device_props = Mock()
+        device_props.total_memory = total_memory
+        return device_props
+
+    def memory_allocated_return(*args, **kwargs):
+        return memory_stats[args[0]][1]
+
+    torch.cuda.get_device_properties = Mock(side_effect=device_props_return)
+    torch.cuda.memory_allocated = Mock(side_effect=memory_allocated_return)
+
+
+def test_calculate_available_device_cuda_memory():
+    mock_available_devices([(80, 40)])
+
+    result = calculate_available_device_cuda_memory(0)
+    assert result == 40
+
+
+def test_determine_available_memory_for_available_devices():
+    mock_available_devices(
+        [
+            (80, 60),
+            (80, 15),
+            (80, 40),
+        ]
+    )
+
+    result = determine_available_memory_for_available_devices(3)
+
+    assert result == [
+        (0, 20),
+        (1, 65),
+        (2, 40),
+    ]
+
+
+def test_sort_devices_based_on_available_memory():
+    devices = [
+        (0, 20),
+        (1, 65),
+        (2, 40),
+    ]
+
+    result = sort_devices_based_on_available_memory(devices)
+
+    assert result == [
+        (1, 65),
+        (2, 40),
+        (0, 20),
+    ]
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index fe7a4de05..fcd2ab4c1 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1091,17 +1091,17 @@ def mps(self):
         return self.to("mps")
 
     def move_model_modules_to_device(self):
-        self.embed.to(devices.get_device_for_block_index(0, self.cfg))
-        self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg))
+        self.embed.to(devices.get_best_available_device(self.cfg))
+        self.hook_embed.to(devices.get_best_available_device(self.cfg))
         if self.cfg.positional_embedding_type != "rotary":
-            self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
-            self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
+            self.pos_embed.to(devices.get_best_available_device(self.cfg))
+            self.hook_pos_embed.to(devices.get_best_available_device(self.cfg))
         if hasattr(self, "ln_final"):
-            self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
-            self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
+            self.ln_final.to(devices.get_best_available_device(self.cfg))
+            self.unembed.to(devices.get_best_available_device(self.cfg))
         for i, block in enumerate(self.blocks):
-            block.to(devices.get_device_for_block_index(i, self.cfg))
+            block.to(devices.get_best_available_device(self.cfg))
 
     @classmethod
     def from_pretrained(
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 009d2cfb8..3b69bc738 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -279,6 +279,12 @@ def forward(
             w = einops.rearrange(
                 self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
             )
+
+            if self.b_O.device != w.device:
+                w = w.to(self.b_O.device)
+            if self.b_O.device != z.device:
+                z = z.to(self.b_O.device)
+
             out = F.linear(
                 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                 w,
@@ -552,6 +558,10 @@ def apply_rotary(
         attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
     ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
         # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
+
+        if x.device != self.rotary_sin.device:
+            x = x.to(self.rotary_sin.device)
+
         x_pos = x.size(1)
         x_rot = x[..., : self.cfg.rotary_dim]
         x_pass = x[..., self.cfg.rotary_dim :]
diff --git a/transformer_lens/components/mlps/gated_mlp.py b/transformer_lens/components/mlps/gated_mlp.py
index 438e9cda1..1386a157c 100644
--- a/transformer_lens/components/mlps/gated_mlp.py
+++ b/transformer_lens/components/mlps/gated_mlp.py
@@ -50,6 +50,8 @@ def forward(
         self, x: Float[torch.Tensor, "batch pos d_model"]
     ) -> Float[torch.Tensor, "batch pos d_model"]:
         # Technically, all these einsums could be done with a single matmul, but this is more readable.
+        if self.W_gate.device != x.device:
+            x = x.to(self.W_gate.device)
         pre_act = self.hook_pre(
             torch.matmul(x, self.W_gate)  # batch pos d_model, d_model d_mlp -> batch pos d_mlp
         )  # [batch, pos, d_mlp]
diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py
index 26d5c7c63..9867fa626 100644
--- a/transformer_lens/components/rms_norm.py
+++ b/transformer_lens/components/rms_norm.py
@@ -42,4 +42,8 @@ def forward(
             (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
         )
         x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
+
+        if x.device != self.w.device:
+            self.to(x.device)
+
         return x * self.w
diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py
index 469fe66e1..dcce1586a 100644
--- a/transformer_lens/components/transformer_block.py
+++ b/transformer_lens/components/transformer_block.py
@@ -173,6 +173,10 @@ def forward(
             # is added to the residual stream"
             attn_out = self.ln1_post(attn_out)
         attn_out = self.hook_attn_out(attn_out)
+
+        if resid_pre.device != attn_out.device:
+            resid_pre = resid_pre.to(attn_out.device)
+
         if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
             resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
             mlp_in = (
diff --git a/transformer_lens/utilities/attention.py b/transformer_lens/utilities/attention.py
index 3deb23a19..e981c5b84 100644
--- a/transformer_lens/utilities/attention.py
+++ b/transformer_lens/utilities/attention.py
@@ -15,8 +15,15 @@ def simple_attn_linear(
     b: Float[torch.Tensor, "head_index d_head"],
 ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
     """Linear layer for attention calculation."""
+
+    if input.device != w.device:
+        w = w.to(input.device)
+    if input.device != b.device:
+        b = b.to(input.device)
+
     w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
     b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
+
     return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index f7de5d3c7..dbc97c080 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -13,6 +13,93 @@
 
 import transformer_lens
 
+AvailableDeviceMemory = list[tuple[int, int]]
+"""
+This type is passed around between different CUDA memory operations.
+The first entry of each tuple will be the device index.
+The second entry will be how much memory is currently available.
+"""
+
+
+def calculate_available_device_cuda_memory(i: int) -> int:
+    """Calculates how much memory is available at this moment for the device at the indicated index.
+
+    Args:
+        i (int): The index of the device to check
+
+    Returns:
+        int: How much memory is available, in bytes
+    """
+    total = torch.cuda.get_device_properties(i).total_memory
+    allocated = torch.cuda.memory_allocated(i)
+    return total - allocated
+
+
+def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
+    """Gets all available CUDA devices with their currently available memory calculated.
+
+    Args:
+        max_devices (int): The maximum number of devices to check
+
+    Returns:
+        AvailableDeviceMemory: The list of all available devices with memory precalculated
+    """
+    devices = []
+    for i in range(max_devices):
+        devices.append((i, calculate_available_device_cuda_memory(i)))
+
+    return devices
+
+
+def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
+    """Sorts the available devices so that those with the most available memory come first.
+
+    Args:
+        devices (AvailableDeviceMemory): All available devices with memory calculated
+
+    Returns:
+        AvailableDeviceMemory: The same devices, sorted with the most available memory first
+    """
+    return sorted(devices, key=lambda x: x[1], reverse=True)
+
+
+def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
+    """Gets the CUDA device with the most available memory.
+
+    Args:
+        max_devices (Optional[int]): The maximum number of devices to consider; defaults to all
+            visible CUDA devices
+
+    Raises:
+        EnvironmentError: If there are no available devices, this will error out
+
+    Returns:
+        torch.device: The specific device that should be used
+    """
+    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
+    devices = determine_available_memory_for_available_devices(max_devices)
+
+    if len(devices) <= 0:
+        raise EnvironmentError(
+            "TransformerLens has been configured to use CUDA, but no available devices are present"
+        )
+
+    sorted_devices = sort_devices_based_on_available_memory(devices=devices)
+
+    return torch.device("cuda", sorted_devices[0][0])
+
+
+def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
+    """Gets the best available device based on the passed-in configuration.
+
+    Args:
+        cfg (HookedTransformerConfig): The model configuration; cfg.device selects the device
+            type and cfg.n_devices caps how many CUDA devices are considered
+
+    Returns:
+        torch.device: The best available device
+    """
+    assert cfg.device is not None
+    device = torch.device(cfg.device)
+
+    if device.type == "cuda":
+        return get_best_available_cuda_device(cfg.n_devices)
+    else:
+        return device
+
 
 def get_device_for_block_index(
     index: int,
@@ -25,6 +112,7 @@ def get_device_for_block_index(
     This function assists in distributing model layers across multiple devices.
     The distribution is based on the configuration's number of layers (cfg.n_layers) and
     devices (cfg.n_devices).
+
     Args:
         index (int): Model layer index.
         cfg (HookedTransformerConfig): Model and device configuration.
@@ -33,6 +121,11 @@
 
     Returns:
         torch.device: The device for the specified layer index.
+
+    Deprecated:
+        This function does not take available device memory into account, which is needed for
+        proper multi-GPU support. Use get_best_available_device instead when running models
+        across multiple devices. This will be removed in 3.0.
     """
     assert cfg.device is not None
    layers_per_device = cfg.n_layers // cfg.n_devices
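
For reference, a minimal sketch (not part of the patch) of how the new helpers in transformer_lens/utilities/devices.py compose. It assumes a CUDA machine and an already-constructed HookedTransformerConfig named model_cfg, which is not defined here:

import torch

from transformer_lens.utilities.devices import (
    determine_available_memory_for_available_devices,
    get_best_available_device,
    sort_devices_based_on_available_memory,
)

# Free memory per CUDA device, as (device_index, available_bytes) pairs.
free = determine_available_memory_for_available_devices(torch.cuda.device_count())

# Sorted so the device with the most free memory comes first;
# get_best_available_cuda_device picks ranked[0] internally.
ranked = sort_devices_based_on_available_memory(free)

# Returns the emptiest CUDA device when model_cfg.device is a CUDA device,
# and torch.device(model_cfg.device) otherwise.
best = get_best_available_device(model_cfg)  # model_cfg: assumed HookedTransformerConfig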