diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 28f40952f2cd..6043f5ddeb4d 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -91,8 +91,11 @@ class DynamicLayer(CacheLayerMixin):
 
     def lazy_initialization(self, key_states: torch.Tensor):
         self.dtype, self.device = key_states.dtype, key_states.device
-        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
-        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+        # Initialize with proper 4D shape: [batch_size, num_heads, 0, head_dim]
+        # This ensures torch.cat works correctly in torch.compile mode
+        batch_size, num_heads, _, head_dim = key_states.shape
+        self.keys = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
+        self.values = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
         self.is_initialized = True
 
     def update(
@@ -545,8 +548,14 @@ def update(
         if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
             self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
             self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
-            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
-            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+            # Reset to proper 4D empty tensors to ensure torch.cat works correctly in torch.compile mode
+            batch_size, num_heads, _, head_dim = key_states.shape
+            self.keys = torch.zeros(
+                (batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
+            )
+            self.values = torch.zeros(
+                (batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
+            )
         else:
             self.keys = torch.cat([self.keys, key_states], dim=-2)
             self.values = torch.cat([self.values, value_states], dim=-2)