diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 33bcbcda64b1..ccfd0836f82f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1087,8 +1087,8 @@ def update( cache_position = cache_kwargs.get("cache_position") self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] + k_out = self.key_cache[layer_idx].clone() + v_out = self.value_cache[layer_idx].clone() if cache_position is None: k_out.copy_(key_states)