diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 0426776beed1..d3e97c1f7b2f 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Llava isn't meant for training from scratch - only
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index ad76561df54f..487ef7789e63 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNext isn't meant for training from scratch - only
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index e3264dfd91e1..a900805eb8d4 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVideoVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 8a693e56f80b..5ba20ff52788 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = False
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of PaliGemmaisn't meant for training from scratch - only
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index cb54c433fde8..330ef62e56fb 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = (
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index c5f856e78745..0687880f0bb7 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VipLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of VipLlava isn't meant for training from scratch - only
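
Usage sketch of what the `_supports_cache_class = True` flag enables: generation can take an explicit Cache object (e.g. DynamicCache) instead of the legacy tuple-of-tuples `past_key_values`. The checkpoint, image URL, and prompt below are illustrative assumptions and are not part of this diff.

# Illustrative sketch only, not part of the diff above.
# Assumes the llava-hf/llava-1.5-7b-hf checkpoint and a COCO sample image for demonstration.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, DynamicCache, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # assumed example image
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

# Pass a Cache instance explicitly; with _supports_cache_class = True the model
# can consume it directly rather than requiring the legacy cache format.
output = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))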