diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bb9c4565d114..02f1cc939e8b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1069,6 +1069,8 @@ def update( A tuple containing the updated key and value states. """ cache_position = cache_kwargs.get("cache_position") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1080,8 +1082,6 @@ def update( # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place # operation, that avoids copies and uses less memory. try: - # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one - cache_position.to(device=k_out.device) k_out.index_copy_(2, cache_position, key_states) v_out.index_copy_(2, cache_position, value_states) except NotImplementedError: