From 35e1a6434f10db5d5cd0835fb2bfaee8c4febf63 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Thu, 29 Aug 2024 17:25:27 +0300 Subject: [PATCH 01/54] add sdpa to OPT --- src/transformers/models/opt/modeling_opt.py | 245 +++++++++++++------- 1 file changed, 162 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f058171778e..1c0466a74e23 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,7 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -59,6 +59,8 @@ _SEQ_CLASS_EXPECTED_LOSS = 1.71 _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" +def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor: + return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() class OPTLearnedPositionalEmbedding(nn.Embedding): """ @@ -116,47 +118,26 @@ def __init__( self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj + def _update_key_and_values(self, hidden_states, key_value_states, past_key_value, is_cross_attention, bsz): if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = _shape(self.k_proj(key_value_states), -1, bsz,self.num_heads,self.head_dim) + value_states = _shape(self.v_proj(key_value_states), -1, bsz,self.num_heads,self.head_dim) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _shape(self.k_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = 
self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = _shape(self.k_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -167,9 +148,33 @@ def forward( # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) + return past_key_value,key_states,value_states + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = _shape(query_states, tgt_len, bsz,self.num_heads,self.head_dim).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -274,36 +279,9 @@ def forward( bsz, _, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) + query_states = self.q_proj(hidden_states) # TODO check if scaling is needed? is this a bug? # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) query_length = query_states.shape[1] tgt_len = key_states.shape[-2] @@ -359,9 +337,77 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value + +class OPTSdpaAttention(OPTAttention): + """ + OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of sdpa + attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + position_ids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once("""OPTModel is using SDPA attention, which currently does not support output_attentions=True. + failing back to eager attention. remove warning using attn_implementation="eager".""") + + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + key_value_states=key_value_states, + ) # TODO after merge add position_ids=position_ids + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) * self.scaling + query_states = _shape(query_states, -1, bsz, self.num_heads, self.head_dim) + + # get key, value proj + past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) + + # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + OPT_ATTENTION_CLASSES = { "eager": OPTAttention, "flash_attention_2": OptFlashAttention2, + "sdpa": OPTSdpaAttention, } @@ -488,6 +534,7 @@ class OPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = self.config.init_std @@ -604,6 +651,7 @@ def __init__(self, config: OPTConfig): self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -615,6 +663,54 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value + + def _update_causal_mask(self, + inputs_embeds: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): + """ + Updates the causal mask for the decoder. + """ + batch_size,seq_length = input_shape + mask_seq_length = past_key_values_length + seq_length + if self._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + + return causal_attention_mask, attention_mask + + if self._use_sdpa and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + is_training=self.training, + ): + return None,None + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + + if self._use_sdpa: + causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask,input_shape,inputs_embeds,past_key_values_length) + else: + causal_attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + + return causal_attention_mask, attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -696,32 +792,15 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - batch_size, seq_length = input_shape past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - + + causal_attention_mask, attention_mask = 
self._update_causal_mask(inputs_embeds, + input_shape, + past_key_values_length, + attention_mask, + output_attentions) # embed positions - if self._use_flash_attention_2: - # 2d mask is passed through the layers - causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - attention_mask = ( - torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if attention_mask is None - else attention_mask - ) - else: - # 4d mask is passed through the layers - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)" - ) - causal_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) From 908e39bd85b8c2fc47263dd33388a46d3bff112d Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Thu, 29 Aug 2024 17:41:06 +0300 Subject: [PATCH 02/54] chore: remove redundant whitespace in OPTDecoder class --- src/transformers/models/opt/modeling_opt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 1c0466a74e23..1f9ef3e600d0 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -778,6 +778,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") From c84a4dd1a165508f6fd0a111020da7106fa89d99 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Mon, 2 Sep 2024 18:50:15 +0300 Subject: [PATCH 03/54] fixup --- src/transformers/models/opt/modeling_opt.py | 88 +++++++++++---------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 1f9ef3e600d0..10e67dd61028 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,7 +22,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -59,9 +63,11 @@ _SEQ_CLASS_EXPECTED_LOSS = 1.71 _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor: return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() + class OPTLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. 
@@ -118,7 +124,6 @@ def __init__( self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) - def _update_key_and_values(self, hidden_states, key_value_states, past_key_value, is_cross_attention, bsz): if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions @@ -126,18 +131,18 @@ def _update_key_and_values(self, hidden_states, key_value_states, past_key_value value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = _shape(self.k_proj(key_value_states), -1, bsz,self.num_heads,self.head_dim) - value_states = _shape(self.v_proj(key_value_states), -1, bsz,self.num_heads,self.head_dim) + key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) elif past_key_value is not None: # reuse k, v, self_attention - key_states = _shape(self.k_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) - value_states = _shape(self.v_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) + key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = _shape(self.k_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) - value_states = _shape(self.v_proj(hidden_states), -1, bsz,self.num_heads,self.head_dim) + key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -148,7 +153,7 @@ def _update_key_and_values(self, hidden_states, key_value_states, past_key_value # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - return past_key_value,key_states,value_states + return past_key_value, key_states, value_states def forward( self, @@ -170,11 +175,12 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) - + past_key_value, key_states, value_states = self._update_key_and_values( + hidden_states, key_value_states, past_key_value, is_cross_attention, bsz + ) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = _shape(query_states, tgt_len, bsz,self.num_heads,self.head_dim).view(*proj_shape) + query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -279,9 +285,11 @@ def forward( bsz, _, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) # TODO check if scaling is needed? is this a bug? + query_states = self.q_proj(hidden_states) # TODO check if scaling is needed? is this a bug? 
# get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) + past_key_value, key_states, value_states = self._update_key_and_values( + hidden_states, key_value_states, past_key_value, is_cross_attention, bsz + ) query_length = query_states.shape[1] tgt_len = key_states.shape[-2] @@ -337,7 +345,6 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value - class OPTSdpaAttention(OPTAttention): """ OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. @@ -358,7 +365,7 @@ def forward( if output_attentions: logger.warning_once("""OPTModel is using SDPA attention, which currently does not support output_attentions=True. failing back to eager attention. remove warning using attn_implementation="eager".""") - + return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, @@ -366,7 +373,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, key_value_states=key_value_states, - ) # TODO after merge add position_ids=position_ids + ) # TODO after merge add position_ids=position_ids is_cross_attention = key_value_states is not None bsz, q_len, _ = hidden_states.size() @@ -375,7 +382,9 @@ def forward( query_states = _shape(query_states, -1, bsz, self.num_heads, self.head_dim) # get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values(hidden_states, key_value_states, past_key_value, is_cross_attention, bsz) + past_key_value, key_states, value_states = self._update_key_and_values( + hidden_states, key_value_states, past_key_value, is_cross_attention, bsz + ) # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous @@ -383,7 +392,6 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False @@ -663,18 +671,18 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - - def _update_causal_mask(self, - inputs_embeds: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - ): + def _update_causal_mask( + self, + inputs_embeds: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + ): """ Updates the causal mask for the decoder. 
""" - batch_size,seq_length = input_shape + batch_size, seq_length = input_shape mask_seq_length = past_key_values_length + seq_length if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -694,7 +702,7 @@ def _update_causal_mask(self, past_key_values_length=past_key_values_length, is_training=self.training, ): - return None,None + return None, None if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) @@ -703,11 +711,15 @@ def _update_causal_mask(self, f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - + if self._use_sdpa: - causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask,input_shape,inputs_embeds,past_key_values_length) + causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) else: - causal_attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) return causal_attention_mask, attention_mask @@ -778,7 +790,6 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -794,14 +805,11 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - causal_attention_mask, attention_mask = self._update_causal_mask(inputs_embeds, - input_shape, - past_key_values_length, - attention_mask, - output_attentions) + + causal_attention_mask, attention_mask = self._update_causal_mask( + inputs_embeds, input_shape, past_key_values_length, attention_mask, output_attentions + ) # embed positions - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) From be32f920f2e2a5e3d713a7f39963cefbe0c52460 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 4 Sep 2024 15:20:05 +0300 Subject: [PATCH 04/54] bug fix --- src/transformers/models/opt/modeling_opt.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 10e67dd61028..00de69f768c5 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -285,7 +285,7 @@ def forward( bsz, _, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) # TODO check if scaling is needed? is this a bug? + query_states = self.q_proj(hidden_states) # get key, value proj past_key_value, key_states, value_states = self._update_key_and_values( hidden_states, key_value_states, past_key_value, is_cross_attention, bsz @@ -403,6 +403,10 @@ def forward( attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, + # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T + # so we need to scale to remove scaling in SDPA to have similar results with eager. 
+ # Maybe needs a change in the model to remove scaling in query projection + scale=1.0, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -694,16 +698,7 @@ def _update_causal_mask( ) return causal_attention_mask, attention_mask - - if self._use_sdpa and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - is_training=self.training, - ): - return None, None - + if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) elif attention_mask.shape[1] != mask_seq_length: From 80639948e6e67e114158dd007b584e169001fb08 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 4 Sep 2024 15:20:45 +0300 Subject: [PATCH 05/54] add sdpa and attention generate test --- tests/models/opt/test_modeling_opt.py | 50 ++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 83721f1281f4..fbd9e72bfbad 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -21,7 +21,7 @@ import timeout_decorator # noqa from transformers import OPTConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -322,6 +322,54 @@ def test_opt_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + max_new_tokens = 30 + _, input_dict = self.model_tester.prepare_config_and_inputs() + model_sdpa = OPTForCausalLM.from_pretrained("facebook/opt-125M", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="sdpa", + ).to(torch_device) + + input_dict["attention_mask"] = torch.ones_like(input_dict["input_ids"]) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = OPTForCausalLM.from_pretrained( + "facebook/opt-125M", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for _, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for _, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + res_eager = model_eager.generate(**input_dict, max_new_tokens=max_new_tokens, do_sample=False) + res_sdpa = model_sdpa.generate(**input_dict, max_new_tokens=max_new_tokens, do_sample=False) + + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{res_eager} \nvs\n{res_sdpa}", + ) + @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge 
cases.") def test_model_parallelism(self): super().test_model_parallelism() From 248029a60096364cee2be722318b99ea11c4fa5e Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 4 Sep 2024 15:22:32 +0300 Subject: [PATCH 06/54] fixup --- src/transformers/models/opt/modeling_opt.py | 9 ++++----- tests/models/opt/test_modeling_opt.py | 12 ++++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 00de69f768c5..bc5a9f733a04 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -23,7 +23,6 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) @@ -285,7 +284,7 @@ def forward( bsz, _, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) + query_states = self.q_proj(hidden_states) # get key, value proj past_key_value, key_states, value_states = self._update_key_and_values( hidden_states, key_value_states, past_key_value, is_cross_attention, bsz @@ -404,9 +403,9 @@ def forward( dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T - # so we need to scale to remove scaling in SDPA to have similar results with eager. + # so we need to scale to remove scaling in SDPA to have similar results with eager. # Maybe needs a change in the model to remove scaling in query projection - scale=1.0, + scale=1.0, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -698,7 +697,7 @@ def _update_causal_mask( ) return causal_attention_mask, attention_mask - + if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) elif attention_mask.shape[1] != mask_seq_length: diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index fbd9e72bfbad..430d9b51fe2a 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -21,7 +21,14 @@ import timeout_decorator # noqa from transformers import OPTConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, require_torch_sdpa, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_torch_fp16, + require_torch_sdpa, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -330,7 +337,8 @@ def test_eager_matches_sdpa_generate(self): """ max_new_tokens = 30 _, input_dict = self.model_tester.prepare_config_and_inputs() - model_sdpa = OPTForCausalLM.from_pretrained("facebook/opt-125M", + model_sdpa = OPTForCausalLM.from_pretrained( + "facebook/opt-125M", torch_dtype=torch.float16, low_cpu_mem_usage=True, attn_implementation="sdpa", From b66e3d856c31f57d6182dc1f887b31804a7dee2c Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 8 Sep 2024 11:58:33 +0300 Subject: [PATCH 07/54] Refactor OPTAttention forward method for improved readability and maintainability --- src/transformers/models/opt/modeling_opt.py | 131 +++++++++++++------- 1 file changed, 87 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 
bc5a9f733a04..2d1849b44760 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -63,10 +63,6 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" -def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor: - return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() - - class OPTLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -123,25 +119,47 @@ def __init__( self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) - def _update_key_and_values(self, hidden_states, key_value_states, past_key_value, is_cross_attention, bsz): + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) - value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
@@ -152,34 +170,9 @@ def _update_key_and_values(self, hidden_states, key_value_states, past_key_value # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - return past_key_value, key_states, value_states - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values( - hidden_states, key_value_states, past_key_value, is_cross_attention, bsz - ) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim).view(*proj_shape) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -286,9 +279,34 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) # get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values( - hidden_states, key_value_states, past_key_value, is_cross_attention, bsz - ) + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) query_length = query_states.shape[1] tgt_len = key_states.shape[-2] @@ -378,12 +396,37 @@ def forward( bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) * self.scaling - query_states = _shape(query_states, -1, bsz, self.num_heads, self.head_dim) + query_states = self._shape(query_states, -1, bsz) # get key, value proj - past_key_value, key_states, value_states = self._update_key_and_values( - hidden_states, key_value_states, past_key_value, is_cross_attention, bsz - ) + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous From 579d60e86d3e29d725fd76ae06c38c96a15d5850 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 8 Sep 2024 12:00:43 +0300 Subject: [PATCH 08/54] undo refactor for _shape and key,val states --- src/transformers/models/opt/modeling_opt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2d1849b44760..8a918ee2bc03 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -415,8 +415,8 @@ def forward( value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
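
[Editor's note, not part of the patch series: the patches above add an `sdpa` attention path to OPT and the accompanying tests select it with `attn_implementation="sdpa"` on the `facebook/opt-350m` checkpoint via `OPTForCausalLM` and `GPT2Tokenizer`. The snippet below is a minimal usage sketch based on that test code, not on any new API introduced here; the prompt string is illustrative only.]

```python
from transformers import GPT2Tokenizer, OPTForCausalLM

# Select the SDPA attention path added by these patches.
# Passing attn_implementation="eager" instead keeps the original OPTAttention,
# which is what the updated model testers do to keep the eager path under test.
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = OPTForCausalLM.from_pretrained("facebook/opt-350m", attn_implementation="sdpa")

inputs = tokenizer("Today I am in Paris and", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
```
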
From b1053765f3dfd50cea628e793cff57a9605b5b06 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 10 Sep 2024 12:53:08 +0300 Subject: [PATCH 09/54] add OPT to doc, fixup didn't find it for some reason --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 3517d93bfc8a..ebddee9dd3b9 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -263,7 +263,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) - +* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) From c349632841c71e1f0f775fb220816b5760809ec2 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 10 Sep 2024 13:02:47 +0300 Subject: [PATCH 10/54] change order --- docs/source/en/perf_infer_gpu_one.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index ebddee9dd3b9..aa7fb74f11e3 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -229,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) @@ -263,7 +264,6 @@ For now, Transformers supports SDPA inference and training for the following arc * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) -* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) From 6dba8b0049efe591f3557844dbcd2a4a22a41e74 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 11 Sep 2024 12:51:13 +0300 Subject: [PATCH 11/54] change default attn_implemntation in testing to eager --- tests/models/opt/test_modeling_opt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 430d9b51fe2a..6c82ea57b67a 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -142,6 +142,7 @@ def get_config(self): embed_dim=self.embed_dim, is_encoder_decoder=False, word_embed_proj_dim=self.word_embed_proj_dim, + attn_implementation="eager", ) def get_pipeline_config(self): From 1d21751f49b6f35ef3652f018c3892b4dc37c367 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies 
Date: Mon, 16 Sep 2024 13:30:14 +0300 Subject: [PATCH 12/54] [run-slow] opt From 7233fda57c83d468f7a6936a5617424294408d27 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 17 Sep 2024 13:49:55 +0300 Subject: [PATCH 13/54] change test_eager_matches_sdpa_generate to the one llama --- tests/models/opt/test_modeling_opt.py | 34 +++++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 6c82ea57b67a..fb3d09117446 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -337,15 +337,22 @@ def test_eager_matches_sdpa_generate(self): Overwritting the common test as the test is flaky on tiny models """ max_new_tokens = 30 - _, input_dict = self.model_tester.prepare_config_and_inputs() + + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125M") + + texts = [ + "hi here's a longer context, getting longer and", + "Hello this is a very long sentence my friend, very long for real", + "Today I am in Paris and", + ] + model_sdpa = OPTForCausalLM.from_pretrained( "facebook/opt-125M", torch_dtype=torch.float16, low_cpu_mem_usage=True, attn_implementation="sdpa", ).to(torch_device) - - input_dict["attention_mask"] = torch.ones_like(input_dict["input_ids"]) + print(model_sdpa.config.eos_token_id) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") @@ -370,14 +377,21 @@ def test_eager_matches_sdpa_generate(self): if not has_sdpa: raise ValueError("The SDPA model should have SDPA attention layers") - res_eager = model_eager.generate(**input_dict, max_new_tokens=max_new_tokens, do_sample=False) - res_sdpa = model_sdpa.generate(**input_dict, max_new_tokens=max_new_tokens, do_sample=False) + for padding_side in ["left", "right"]: + tokenizer.padding_side = padding_side + tokenizer.pad_token = tokenizer.eos_token - torch.testing.assert_close( - res_eager, - res_sdpa, - msg=f"\n{res_eager} \nvs\n{res_sdpa}", - ) + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + + with self.subTest(f"{padding_side}"): + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + ) @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.") def test_model_parallelism(self): From 9bacdeb21fd628bf798eae267ff448f8dfba3e17 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 17 Sep 2024 14:08:19 +0300 Subject: [PATCH 14/54] Update default attention implementation in testing common --- tests/models/opt/test_modeling_tf_opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py index 158baa4ce65e..284ecdd49de3 100644 --- a/tests/models/opt/test_modeling_tf_opt.py +++ b/tests/models/opt/test_modeling_tf_opt.py @@ -42,7 +42,7 @@ def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=No @require_tf class TFOPTModelTester: config_cls = OPTConfig - config_updates = {} + config_updates = {"attn_implementation": "eager"} hidden_act = "gelu" def __init__( From 5b38f7846af380da7fade33c0d287d5ec917fe55 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 17 Sep 2024 14:27:24 +0300 Subject: [PATCH 15/54] 
[run-slow] opt From 3f24a0479a5a8a9aa58cd2d46071b56d48e7c8f5 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 17 Sep 2024 15:51:11 +0300 Subject: [PATCH 16/54] remove uneeded print --- tests/models/opt/test_modeling_opt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index fb3d09117446..cf2f4c07da03 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -352,7 +352,6 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, attn_implementation="sdpa", ).to(torch_device) - print(model_sdpa.config.eos_token_id) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") From 2efd25aa698da424e7697564388ca2df38450cc3 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Tue, 17 Sep 2024 15:51:39 +0300 Subject: [PATCH 17/54] [run-slow] opt From bdd9cb235c576c85f84fb545111aa94f5873c826 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 18 Sep 2024 12:21:19 +0300 Subject: [PATCH 18/54] refactor model testers to have attn_implementation="eager" --- tests/models/opt/test_modeling_flax_opt.py | 3 +++ tests/models/opt/test_modeling_opt.py | 4 +++- tests/models/opt/test_modeling_tf_opt.py | 5 ++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py index ef94633f22a8..5ebf23d86a32 100644 --- a/tests/models/opt/test_modeling_flax_opt.py +++ b/tests/models/opt/test_modeling_flax_opt.py @@ -70,6 +70,7 @@ def __init__( embed_dim=16, word_embed_proj_dim=16, initializer_range=0.02, + attn_implemetation="eager", ): self.parent = parent self.batch_size = batch_size @@ -92,6 +93,7 @@ def __init__( self.word_embed_proj_dim = word_embed_proj_dim self.initializer_range = initializer_range self.is_encoder_decoder = False + self.attn_implementation = attn_implemetation def prepare_config_and_inputs(self): input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size) @@ -114,6 +116,7 @@ def prepare_config_and_inputs(self): word_embed_proj_dim=self.word_embed_proj_dim, initializer_range=self.initializer_range, use_cache=False, + attn_implementation=self.attn_implementation, ) inputs_dict = prepare_opt_inputs_dict(config, input_ids) return config, inputs_dict diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index cf2f4c07da03..4ecaed432c11 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -90,6 +90,7 @@ def __init__( num_labels=3, word_embed_proj_dim=16, type_sequence_label_size=2, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -113,6 +114,7 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.word_embed_proj_dim = word_embed_proj_dim self.is_encoder_decoder = False + self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( @@ -142,7 +144,7 @@ def get_config(self): embed_dim=self.embed_dim, is_encoder_decoder=False, word_embed_proj_dim=self.word_embed_proj_dim, - attn_implementation="eager", + attn_implementation=self.attn_implementation, ) def get_pipeline_config(self): diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py index 284ecdd49de3..39c38170e3f6 100644 --- a/tests/models/opt/test_modeling_tf_opt.py 
+++ b/tests/models/opt/test_modeling_tf_opt.py @@ -42,7 +42,7 @@ def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=No @require_tf class TFOPTModelTester: config_cls = OPTConfig - config_updates = {"attn_implementation": "eager"} + config_updates = {} hidden_act = "gelu" def __init__( @@ -66,6 +66,7 @@ def __init__( bos_token_id=0, embed_dim=16, word_embed_proj_dim=16, + attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -87,6 +88,7 @@ def __init__( self.embed_dim = embed_dim self.word_embed_proj_dim = word_embed_proj_dim self.is_encoder_decoder = False + self.attn_implementation = attn_implementation def prepare_config_and_inputs_for_common(self): input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size) @@ -108,6 +110,7 @@ def prepare_config_and_inputs_for_common(self): embed_dim=self.embed_dim, word_embed_proj_dim=self.word_embed_proj_dim, is_encoder_decoder=False, + attn_implementation=self.attn_implementation, **self.config_updates, ) inputs_dict = prepare_opt_inputs_dict(config, input_ids) From f80e3b3c61627749e6832b0e0a207d896319d3ae Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Wed, 18 Sep 2024 12:23:05 +0300 Subject: [PATCH 19/54] [run-slow] opt From 7ea22eb75e2eb60861652244a2811fe13b26e0fa Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 11:40:58 +0300 Subject: [PATCH 20/54] convert test_eager_matches_sdpa_generate to opt-350M --- tests/models/opt/test_modeling_opt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 4ecaed432c11..eff2ff67f8f9 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -340,7 +340,7 @@ def test_eager_matches_sdpa_generate(self): """ max_new_tokens = 30 - tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125M") + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350M") texts = [ "hi here's a longer context, getting longer and", @@ -349,7 +349,7 @@ def test_eager_matches_sdpa_generate(self): ] model_sdpa = OPTForCausalLM.from_pretrained( - "facebook/opt-125M", + "facebook/opt-350M", torch_dtype=torch.float16, low_cpu_mem_usage=True, attn_implementation="sdpa", @@ -358,7 +358,7 @@ def test_eager_matches_sdpa_generate(self): self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") model_eager = OPTForCausalLM.from_pretrained( - "facebook/opt-125M", + "facebook/opt-350M", torch_dtype=torch.float16, low_cpu_mem_usage=True, attn_implementation="eager", From b5547e7ce6b68c1bc2d4a637d896b41e9a90178f Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 11:46:42 +0300 Subject: [PATCH 21/54] bug fix when creating mask for opt --- src/transformers/models/opt/modeling_opt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8a918ee2bc03..d524b8c6895a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -748,8 +748,7 @@ def _update_causal_mask( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - - if self._use_sdpa: + if self._use_sdpa and not output_attentions: causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, 
past_key_values_length ) From 668e291cede51c99b79e7e90922ebc6a2b1860cb Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 11:55:56 +0300 Subject: [PATCH 22/54] [run-slow] opt From d9d3bb3ab58e88ab454d4b40411ba300090972e0 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 15:04:41 +0300 Subject: [PATCH 23/54] if layer head mask default to eager --- src/transformers/models/opt/modeling_opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index d524b8c6895a..83884e0080e0 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -379,7 +379,7 @@ def forward( output_attentions: bool = False, position_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: + if output_attentions or layer_head_mask is not None: logger.warning_once("""OPTModel is using SDPA attention, which currently does not support output_attentions=True. failing back to eager attention. remove warning using attn_implementation="eager".""") From 388d663f2c0717a6d344ac632a9499fb78ebee97 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 15:09:58 +0300 Subject: [PATCH 24/54] if head mask is not none fall to eager --- src/transformers/models/opt/modeling_opt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 83884e0080e0..ceaa186d3cc5 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -723,6 +723,7 @@ def _update_causal_mask( input_shape: Tuple[int, int], past_key_values_length: int, attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, ): """ @@ -748,7 +749,7 @@ def _update_causal_mask( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - if self._use_sdpa and not output_attentions: + if self._use_sdpa and not output_attentions and head_mask is None: causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -843,7 +844,7 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 causal_attention_mask, attention_mask = self._update_causal_mask( - inputs_embeds, input_shape, past_key_values_length, attention_mask, output_attentions + inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions ) # embed positions From e735ec4b9177733a9da2e78a241c776adf46dd55 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies Date: Sun, 22 Sep 2024 15:10:27 +0300 Subject: [PATCH 25/54] [run-slow] opt From f94d5742b02ccaacc078353d28c9a31b7d30f9f1 Mon Sep 17 00:00:00 2001 From: Avishai Elmakies <36810152+avishaiElmakies@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:05:26 +0300 Subject: [PATCH 26/54] Update src/transformers/models/opt/modeling_opt.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/opt/modeling_opt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/opt/modeling_opt.py 
b/src/transformers/models/opt/modeling_opt.py index ceaa186d3cc5..f393e0e6e2b2 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -380,8 +380,10 @@ def forward( position_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions or layer_head_mask is not None: - logger.warning_once("""OPTModel is using SDPA attention, which currently does not support output_attentions=True. - failing back to eager attention. remove warning using attn_implementation="eager".""") + logger.warning_once( + "OPTModel is using SDPA attention, which currently does not support output_attentions=True." + 'failing back to eager attention. remove warning using attn_implementation="eager".' + ) return super().forward( hidden_states=hidden_states, From e734d9d1d2c39a861f21513c868a1161f76e2802 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Mon, 23 Sep 2024 10:21:17 +0200 Subject: [PATCH 27/54] Clean up Unpack imports (#33631) clean up Unpack imports --- .../grounding_dino/processing_grounding_dino.py | 10 +--------- src/transformers/models/llava/processing_llava.py | 8 +------- .../llava_onevision/processing_llava_onevision.py | 12 +----------- .../models/pixtral/processing_pixtral.py | 8 +------- .../models/qwen2_vl/processing_qwen2_vl.py | 11 +---------- tests/test_processing_common.py | 1 + 6 files changed, 6 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/grounding_dino/processing_grounding_dino.py b/src/transformers/models/grounding_dino/processing_grounding_dino.py index 00c183338be0..2b5769928518 100644 --- a/src/transformers/models/grounding_dino/processing_grounding_dino.py +++ b/src/transformers/models/grounding_dino/processing_grounding_dino.py @@ -17,20 +17,12 @@ """ import pathlib -import sys from typing import Dict, List, Optional, Tuple, Union from ...image_processing_utils import BatchFeature from ...image_transforms import center_to_corners_format from ...image_utils import AnnotationFormat, ImageInput -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin - - -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput from ...utils import TensorType, is_torch_available diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 28a9410e6cbf..8a9597892c60 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -16,21 +16,15 @@ Processor class for Llava. 
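# Hedged sketch, not part of the patch above: the import pattern this cleanup converges on.
# `Unpack` is re-exported by `transformers.processing_utils`, so processor modules no longer
# need the version-gated typing / typing_extensions imports. Class names here are illustrative.
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack


class MyProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {"text_kwargs": {"padding": False}}


class MyProcessor(ProcessorMixin):
    def __call__(self, images=None, text=None, **kwargs: Unpack[MyProcessorKwargs]):
        # user kwargs are merged against the typed defaults via self._merge_kwargs(...)
        ...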
""" -import sys from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py index 2db0ba50c210..f9d550e789d8 100644 --- a/src/transformers/models/llava_onevision/processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py @@ -18,22 +18,12 @@ import math import os -import sys from typing import Iterable, List, Union - -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - from ...feature_extraction_utils import BatchFeature from ...image_processing_utils import select_best_resolution from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array -from ...processing_utils import ( - ProcessingKwargs, - ProcessorMixin, -) +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging from ..auto import AutoImageProcessor diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 1b07aa02771d..d336760c6d9c 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -16,21 +16,15 @@ Processor class for Pixtral. 
""" -import sys from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image -from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 591b82f053c8..48516e6aa31d 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -23,18 +23,9 @@ from typing import List, Union - -try: - from typing import Unpack -except ImportError: - from typing_extensions import Unpack - from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput -from ...processing_utils import ( - ProcessingKwargs, - ProcessorMixin, -) +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index a51c1d200eb0..0a4abe8656e8 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -21,6 +21,7 @@ import numpy as np from transformers.models.auto.processing_auto import processor_class_from_name +from transformers.processing_utils import Unpack from transformers.testing_utils import ( check_json_file_has_correct_format, require_torch, From 34593ba967a77e480932d501c0c240aa8640d387 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Mon, 23 Sep 2024 11:49:16 +0200 Subject: [PATCH 28/54] Fix DPT /Dinov2 sdpa regression on main (#33660) * fallback to eager if output attentions. * fix copies --- src/transformers/models/dinov2/modeling_dinov2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index ebe322618a2d..bae21dacb95b 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -231,7 +231,6 @@ def forward( return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2 class Dinov2SdpaSelfAttention(Dinov2SelfAttention): def __init__(self, config: Dinov2Config) -> None: super().__init__(config) @@ -240,6 +239,16 @@ def __init__(self, config: Dinov2Config) -> None: def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) From 6889d696c4976687d2aa5fed1a520f9f5128cbbf Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:38:52 +0200 Subject: [PATCH 29/54] handle dependency errors in check_imports (#33622) * handle dependency errors in check_imports * change log level to warning --- src/transformers/dynamic_module_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 07cb5940dc4b..4e0e1dd34302 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -183,8 +183,15 @@ def check_imports(filename: Union[str, os.PathLike]) -> List[str]: for imp in imports: try: importlib.import_module(imp) - except ImportError: - missing_packages.append(imp) + except ImportError as exception: + logger.warning(f"Encountered exception while importing {imp}: {exception}") + # Some packages can fail with an ImportError because of a dependency issue. + # This check avoids hiding such errors. + # See https://github.com/huggingface/transformers/issues/33604 + if "No module named" in str(exception): + missing_packages.append(imp) + else: + raise if len(missing_packages) > 0: raise ImportError( From d488c33b62e4f1bc1453e2879506746f4f9d24ce Mon Sep 17 00:00:00 2001 From: chengchengpei <5881383+chengchengpei@users.noreply.github.com> Date: Mon, 23 Sep 2024 03:54:58 -0700 Subject: [PATCH 30/54] add back self.max_position_embeddings = config.max_position_embeddings (#33550) * add back self.max_position_embeddings = config.max_position_embeddings * fix-copies --- src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index aafecb95b6aa..1e79115d3470 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -310,6 +310,7 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bc06b406bf43..c9ee7b5f57a1 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -388,6 +388,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // 
self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout From 99909159e3e47d079f5159d17673b1ea51c720a8 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 23 Sep 2024 19:07:15 +0800 Subject: [PATCH 31/54] Fix Llava conversion for LlavaQwen2ForCausalLM with Clip vision tower (#33613) fix llavaqwen2 model conversion --- .../models/llava/convert_llava_weights_to_hf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py index 9841b7cb3d19..b8d936e8cc44 100644 --- a/src/transformers/models/llava/convert_llava_weights_to_hf.py +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -76,7 +76,9 @@ def load_original_state_dict(model_id): if "lm_head.weight" not in original_state_dict: original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone() - del original_state_dict["model.image_newline"] # not used in the original implementation because "merge_type=flat" + if "model.image_newline" in original_state_dict: + # not used in the original implementation because "merge_type=flat" + del original_state_dict["model.image_newline"] return original_state_dict @@ -107,7 +109,7 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o image_processor = AutoImageProcessor.from_pretrained(vision_model_id) processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) - if "Qwen" in text_model_id: + if "siglip" in vision_model_id: vision_config = SiglipVisionConfig( hidden_size=1152, image_size=384, @@ -128,8 +130,9 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o # llms-lab interleeave models do not use any selection startegy except for last hidden state if "Qwen" in text_model_id: config.image_token_index = 151646 - config.vision_feature_select_strategy = "full" - config.vision_feature_layer = -1 + if "siglip" in vision_model_id: + config.vision_feature_select_strategy = "full" + config.vision_feature_layer = -1 else: config.pad_token_id = 32001 config.image_token_index = 32000 From 3720ecab98fee9b163f6bd7b4164bc5a43b59882 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:47:32 -0400 Subject: [PATCH 32/54] Uniformize kwargs for Udop processor and update docs (#33628) * Add optional kwargs and uniformize udop * cleanup Unpack * nit Udop --- src/transformers/models/udop/modeling_udop.py | 2 +- .../models/udop/processing_udop.py | 159 ++++++++++-------- tests/models/udop/test_processor_udop.py | 37 ++-- 3 files changed, 110 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 972248daaae5..6f7b6cf06049 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1790,7 +1790,7 @@ def forward( >>> # one can use the various task prefixes (prompts) used during pre-training >>> # e.g. the task prefix for DocVQA is "Question answering. " >>> question = "Question answering. What is the date on the form?" 
- >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt") + >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt") >>> # autoregressive generation >>> predicted_ids = model.generate(**encoding) diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index 2902541d6f5b..ddd5d484a988 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -18,10 +18,38 @@ from typing import List, Optional, Union +from transformers import logging + +from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +logger = logging.get_logger(__name__) + + +class UdopTextKwargs(TextKwargs, total=False): + word_labels: Optional[Union[List[int], List[List[int]]]] + boxes: Union[List[List[int]], List[List[List[int]]]] + + +class UdopProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: UdopTextKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } class UdopProcessor(ProcessorMixin): @@ -49,6 +77,8 @@ class UdopProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv3ImageProcessor" tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast") + # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + optional_call_args = ["text_pair"] def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) @@ -57,28 +87,16 @@ def __call__( self, images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, - boxes: Union[List[List[int]], List[List[List[int]]]] = None, - word_labels: Optional[Union[List[int], List[List[int]]]] = None, - text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair_target: Optional[ - Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] - ] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchEncoding: + # The following is to capture `text_pair` argument that may be passed as a positional argument. 
+ # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, + # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 + # This behavior is only needed for backward compatibility and will be removed in future versions. + # + *args, + audio=None, + videos=None, + **kwargs: Unpack[UdopProcessorKwargs], + ) -> BatchFeature: """ This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and @@ -93,6 +111,20 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ # verify input + output_kwargs = self._merge_kwargs( + UdopProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False) + return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False) + text_target = output_kwargs["text_kwargs"].get("text_target", None) + if self.image_processor.apply_ocr and (boxes is not None): raise ValueError( "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." @@ -103,69 +135,47 @@ def __call__( "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." ) - if return_overflowing_tokens is True and return_offsets_mapping is False: + if return_overflowing_tokens and not return_offsets_mapping: raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") if text_target is not None: # use the processor to prepare the targets of UDOP return self.tokenizer( - text_target=text_target, - text_pair_target=text_pair_target, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + **output_kwargs["text_kwargs"], ) else: # use the processor to prepare the inputs of UDOP # first, apply the image processor - features = self.image_processor(images=images, return_tensors=return_tensors) + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + features_words = features.pop("words", None) + features_boxes = features.pop("boxes", None) + + output_kwargs["text_kwargs"].pop("text_target", None) + output_kwargs["text_kwargs"].pop("text_pair_target", None) + output_kwargs["text_kwargs"]["text_pair"] = text_pair + output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes + output_kwargs["text_kwargs"]["word_labels"] = word_labels # second, apply the tokenizer if text is not None and self.image_processor.apply_ocr and text_pair is None: if isinstance(text, str): text = [text] # add batch dimension (as the image processor always adds a 
batch dimension) - text_pair = features["words"] + output_kwargs["text_kwargs"]["text_pair"] = features_words encoded_inputs = self.tokenizer( - text=text if text is not None else features["words"], - text_pair=text_pair if text_pair is not None else None, - boxes=boxes if boxes is not None else features["boxes"], - word_labels=word_labels, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + text=text if text is not None else features_words, + **output_kwargs["text_kwargs"], ) # add pixel values - pixel_values = features.pop("pixel_values") if return_overflowing_tokens is True: - pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"]) - encoded_inputs["pixel_values"] = pixel_values + features["pixel_values"] = self.get_overflowing_images( + features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"] + ) + features.update(encoded_inputs) - return encoded_inputs + return features # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images def get_overflowing_images(self, images, overflow_to_sample_mapping): @@ -198,7 +208,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property - # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names def model_input_names(self): - return ["input_ids", "bbox", "attention_mask", "pixel_values"] + return ["pixel_values", "input_ids", "bbox", "attention_mask"] diff --git a/tests/models/udop/test_processor_udop.py b/tests/models/udop/test_processor_udop.py index 749ec7c3d6df..621b761b5f17 100644 --- a/tests/models/udop/test_processor_udop.py +++ b/tests/models/udop/test_processor_udop.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
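# Hedged usage sketch, not part of the patch: the uniformized UdopProcessor call signature,
# mirroring the updated docstring and tests. apply_ocr=False is assumed because words and
# boxes are supplied by the caller; the blank image is purely illustrative.
from PIL import Image

from transformers import LayoutLMv3ImageProcessor, UdopProcessor, UdopTokenizer

image_processor = LayoutLMv3ImageProcessor(apply_ocr=False)
tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large")
processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer)

image = Image.new("RGB", (224, 224), color="white")
question = "What's his name?"
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]

# words are now passed as `text_pair` and boxes as a keyword argument
encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
print(sorted(encoding.keys()))  # attention_mask, bbox, input_ids, pixel_values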
-import json -import os import shutil import tempfile import unittest @@ -34,7 +32,7 @@ require_torch, slow, ) -from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available, is_torch_available +from transformers.utils import cached_property, is_pytesseract_available, is_torch_available from ...test_processing_common import ProcessorTesterMixin @@ -55,20 +53,19 @@ class UdopProcessorTest(ProcessorTesterMixin, unittest.TestCase): tokenizer_class = UdopTokenizer rust_tokenizer_class = UdopTokenizerFast - maxDiff = None processor_class = UdopProcessor + maxDiff = None def setUp(self): - image_processor_map = { - "do_resize": True, - "size": 224, - "apply_ocr": True, - } - self.tmpdirname = tempfile.mkdtemp() - self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(image_processor_map) + "\n") + image_processor = LayoutLMv3ImageProcessor( + do_resize=True, + size=224, + apply_ocr=True, + ) + tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large") + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname) self.tokenizer_pretrained_name = "microsoft/udop-large" @@ -80,15 +77,15 @@ def setUp(self): def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) + def get_image_processor(self, **kwargs): + return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast: return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]: return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)] - def get_image_processor(self, **kwargs): - return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs) - def tearDown(self): shutil.rmtree(self.tmpdirname) @@ -153,7 +150,7 @@ def test_model_input_names(self): input_str = "lower newer" image_input = self.prepare_image_inputs() - inputs = processor(text=input_str, images=image_input) + inputs = processor(images=image_input, text=input_str) self.assertListEqual(list(inputs.keys()), processor.model_input_names) @@ -472,7 +469,7 @@ def test_processor_case_5(self): question = "What's his name?" 
words = ["hello", "world"] boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] - input_processor = processor(images[0], question, words, boxes, return_tensors="pt") + input_processor = processor(images[0], question, text_pair=words, boxes=boxes, return_tensors="pt") # verify keys expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] @@ -488,7 +485,9 @@ def test_processor_case_5(self): questions = ["How old is he?", "what's the time"] words = [["hello", "world"], ["my", "name", "is", "niels"]] boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] - input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt") + input_processor = processor( + images, questions, text_pair=words, boxes=boxes, padding=True, return_tensors="pt" + ) # verify keys expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] From 9b11d28973efc257b8326954c608a841b3680646 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 23 Sep 2024 18:28:36 +0100 Subject: [PATCH 33/54] Generation: deprecate `PreTrainedModel` inheriting from `GenerationMixin` (#33203) --- src/transformers/generation/utils.py | 34 ++++++------------ src/transformers/modeling_utils.py | 36 ++++++++++++++----- .../models/albert/modeling_albert.py | 3 +- src/transformers/models/auto/auto_factory.py | 35 ++++++++++++++++++ src/transformers/models/bark/modeling_bark.py | 3 +- src/transformers/models/bart/modeling_bart.py | 5 +-- src/transformers/models/bert/modeling_bert.py | 3 +- .../modeling_bert_generation.py | 3 +- .../models/big_bird/modeling_big_bird.py | 3 +- .../modeling_bigbird_pegasus.py | 5 +-- .../models/biogpt/modeling_biogpt.py | 3 +- .../models/blenderbot/modeling_blenderbot.py | 5 +-- .../modeling_blenderbot_small.py | 5 +-- src/transformers/models/blip/modeling_blip.py | 3 +- .../models/blip/modeling_blip_text.py | 3 +- .../models/blip_2/modeling_blip_2.py | 3 +- .../models/bloom/modeling_bloom.py | 3 +- .../models/camembert/modeling_camembert.py | 3 +- .../models/chameleon/modeling_chameleon.py | 3 +- src/transformers/models/clvp/modeling_clvp.py | 6 ++-- .../models/codegen/modeling_codegen.py | 3 +- .../models/cohere/modeling_cohere.py | 3 +- .../models/cpmant/modeling_cpmant.py | 3 +- src/transformers/models/ctrl/modeling_ctrl.py | 3 +- .../models/data2vec/modeling_data2vec_text.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 3 +- .../models/electra/modeling_electra.py | 3 +- .../models/ernie/modeling_ernie.py | 3 +- .../models/falcon/modeling_falcon.py | 3 +- .../falcon_mamba/modeling_falcon_mamba.py | 3 +- .../models/flaubert/modeling_flaubert.py | 3 +- src/transformers/models/fsmt/modeling_fsmt.py | 3 +- src/transformers/models/fuyu/modeling_fuyu.py | 3 +- src/transformers/models/gemma/diff_gemma.py | 3 +- .../models/gemma/modeling_gemma.py | 3 +- src/transformers/models/gemma2/diff_gemma2.py | 3 +- .../models/gemma2/modeling_gemma2.py | 3 +- src/transformers/models/git/modeling_git.py | 3 +- src/transformers/models/gpt2/modeling_gpt2.py | 5 +-- .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +- .../models/gpt_neo/modeling_gpt_neo.py | 3 +- .../models/gpt_neox/modeling_gpt_neox.py | 3 +- .../modeling_gpt_neox_japanese.py | 3 +- src/transformers/models/gptj/modeling_gptj.py | 3 +- .../models/granite/modeling_granite.py | 3 +- .../models/granitemoe/modeling_granitemoe.py | 3 +- .../models/idefics2/modeling_idefics2.py | 5 +-- .../models/imagegpt/modeling_imagegpt.py | 3 +- .../instructblip/modeling_instructblip.py | 3 +- 
.../diff_instructblipvideo.py | 3 +- .../modeling_instructblipvideo.py | 3 +- .../models/jamba/modeling_jamba.py | 3 +- .../models/jetmoe/modeling_jetmoe.py | 3 +- .../models/kosmos2/modeling_kosmos2.py | 5 +-- src/transformers/models/led/modeling_led.py | 3 +- .../models/llama/modeling_llama.py | 3 +- .../models/llava/modeling_llava.py | 5 +-- .../models/llava_next/modeling_llava_next.py | 5 +-- .../llava_next_video/diff_llava_next_video.py | 3 +- .../modeling_llava_next_video.py | 5 +-- .../modeling_llava_onevision.py | 5 +-- .../models/longt5/modeling_longt5.py | 3 +- .../models/m2m_100/modeling_m2m_100.py | 3 +- .../models/mamba/modeling_mamba.py | 3 +- .../models/mamba2/modeling_mamba2.py | 3 +- .../models/marian/modeling_marian.py | 5 +-- .../models/mbart/modeling_mbart.py | 5 +-- .../megatron_bert/modeling_megatron_bert.py | 3 +- .../models/mistral/modeling_mistral.py | 3 +- .../models/mixtral/modeling_mixtral.py | 3 +- src/transformers/models/mpt/modeling_mpt.py | 3 +- src/transformers/models/mt5/modeling_mt5.py | 3 +- .../models/musicgen/modeling_musicgen.py | 15 +++++--- .../modeling_musicgen_melody.py | 15 +++++--- src/transformers/models/mvp/modeling_mvp.py | 5 +-- .../models/nemotron/modeling_nemotron.py | 3 +- .../models/nllb_moe/modeling_nllb_moe.py | 3 +- src/transformers/models/olmo/modeling_olmo.py | 3 +- .../models/olmoe/modeling_olmoe.py | 3 +- .../models/openai/modeling_openai.py | 3 +- src/transformers/models/opt/modeling_opt.py | 3 +- .../models/paligemma/modeling_paligemma.py | 3 +- .../models/pegasus/modeling_pegasus.py | 5 +-- .../models/pegasus_x/modeling_pegasus_x.py | 3 +- .../models/persimmon/modeling_persimmon.py | 3 +- src/transformers/models/phi/modeling_phi.py | 3 +- src/transformers/models/phi3/modeling_phi3.py | 3 +- .../models/pix2struct/modeling_pix2struct.py | 3 +- .../models/plbart/modeling_plbart.py | 5 +-- .../models/pop2piano/modeling_pop2piano.py | 3 +- .../models/prophetnet/modeling_prophetnet.py | 5 +-- .../models/qwen2/modeling_qwen2.py | 3 +- .../qwen2_audio/modeling_qwen2_audio.py | 5 +-- .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 3 +- .../modeling_recurrent_gemma.py | 3 +- .../models/reformer/modeling_reformer.py | 3 +- .../models/rembert/modeling_rembert.py | 3 +- .../models/roberta/modeling_roberta.py | 3 +- .../modeling_roberta_prelayernorm.py | 3 +- .../models/roc_bert/modeling_roc_bert.py | 3 +- .../models/roformer/modeling_roformer.py | 3 +- src/transformers/models/rwkv/modeling_rwkv.py | 3 +- .../seamless_m4t/modeling_seamless_m4t.py | 5 +-- .../modeling_seamless_m4t_v2.py | 5 +-- .../speech_to_text/modeling_speech_to_text.py | 3 +- .../models/stablelm/modeling_stablelm.py | 3 +- .../models/starcoder2/modeling_starcoder2.py | 3 +- .../modeling_switch_transformers.py | 3 +- src/transformers/models/t5/modeling_t5.py | 3 +- .../models/trocr/modeling_trocr.py | 3 +- src/transformers/models/udop/modeling_udop.py | 3 +- src/transformers/models/umt5/modeling_umt5.py | 3 +- .../video_llava/modeling_video_llava.py | 5 +-- .../models/vipllava/modeling_vipllava.py | 5 +-- .../models/whisper/generation_whisper.py | 4 +-- .../models/whisper/modeling_whisper.py | 3 +- src/transformers/models/xglm/modeling_xglm.py | 3 +- src/transformers/models/xlm/modeling_xlm.py | 3 +- .../xlm_roberta/modeling_xlm_roberta.py | 3 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 3 +- .../models/xlnet/modeling_xlnet.py | 3 +- src/transformers/models/xmod/modeling_xmod.py | 3 +- tests/generation/test_utils.py | 9 
+++++ tests/models/auto/test_modeling_auto.py | 18 ++++++++++ tests/utils/test_modeling_utils.py | 27 ++++++++++++++ 126 files changed, 407 insertions(+), 184 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2fe92d3e3ed6..c1aa338a7d8f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -34,13 +34,6 @@ ) from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput -from ..models.auto import ( - MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, - MODEL_FOR_VISION_2_SEQ_MAPPING, -) from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie from ..utils import ( @@ -1117,26 +1110,21 @@ def _validate_model_class(self): Confirms that the model class is compatible with generation. If not, raises an exception that points to the right class to use. """ + # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from + # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can + # safely call `GenerationMixin.generate` if not is_torchdynamo_compiling() and not self.can_generate(): - generate_compatible_mappings = [ - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, - MODEL_FOR_VISION_2_SEQ_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + terminations_with_generation_support = [ + "ForCausalLM", + "ForConditionalGeneration", + "ForSpeechSeq2Seq", + "ForVision2Seq", ] - generate_compatible_classes = set() - for model_mapping in generate_compatible_mappings: - supported_models = model_mapping.get(type(self.config), default=None) - if supported_models is not None: - generate_compatible_classes.add(supported_models.__name__) - exception_message = ( + raise TypeError( f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " - "it doesn't have a language model head." + "it doesn't have a language model head. Classes that support generation often end in one of these " + f"names: {terminations_with_generation_support}." ) - if generate_compatible_classes: - exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}" - raise TypeError(exception_message) def _validate_assistant(self, assistant_model): if assistant_model is None: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d40697666360..6fff23f6b6df 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) -def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): +def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): try: return next(parameter.parameters()).device except StopIteration: @@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].device -def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): +def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): """ Returns the first parameter dtype (can be non-floating) or asserts if none were found. 
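# Hedged sketch, not part of the patch: the migration this deprecation asks for. A model
# that should keep `.generate()` now inherits GenerationMixin explicitly, after
# PreTrainedModel; the class name below is illustrative.
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin


class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    ...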
""" @@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: return first_tuple[1].dtype -def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): +def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): """ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. """ @@ -1309,6 +1309,7 @@ def floating_point_ops( return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) +# TODO (joao): remove `GenerationMixin` inheritance in v4.50 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -1638,11 +1639,30 @@ def can_generate(cls) -> bool: Returns: `bool`: Whether this model can generate sequences with `.generate()`. """ - # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. - # Alternativelly, the model can also have a custom `generate` function. - if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): - return False - return True + # Directly inherits `GenerationMixin` -> can generate + if "GenerationMixin" in str(cls.__bases__): + return True + # Model class overwrites `generate` (e.g. time series models) -> can generate + if str(cls.__name__) in str(cls.generate): + return True + # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + # was how we detected whether a model could generate. + if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): + logger.warning_once( + f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " + "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " + "to call `generate` and other related functions." + "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " + "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" + "\n - If you are the owner of the model architecture code, please modify your model class such that " + "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." + "\n - If you are not the owner of the model architecture class, please contact the model code owner " + "to update it." 
+ ) + return True + # Otherwise, can't generate + return False @classmethod def _check_and_enable_flash_attn_2( diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index dca1fe7f6002..bfd8e38687ac 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutput, @@ -983,7 +984,7 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ) -class AlbertForMaskedLM(AlbertPreTrainedModel): +class AlbertForMaskedLM(AlbertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 220ae97f5073..7809b2a6cc2c 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -29,12 +29,17 @@ extract_commit_hash, find_adapter_config_file, is_peft_available, + is_torch_available, logging, requires_backends, ) from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings +if is_torch_available(): + from ...generation import GenerationMixin + + logger = logging.get_logger(__name__) @@ -428,6 +433,7 @@ def from_config(cls, config, **kwargs): model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) cls.register(config.__class__, model_class, exist_ok=True) _ = kwargs.pop("code_revision", None) + model_class = add_generation_mixin_to_remote_model(model_class) return model_class._from_config(config, **kwargs) elif type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) @@ -549,6 +555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) _ = hub_kwargs.pop("code_revision", None) cls.register(config.__class__, model_class, exist_ok=True) + model_class = add_generation_mixin_to_remote_model(model_class) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) @@ -698,6 +705,34 @@ def getattribute_from_module(module, attr): raise ValueError(f"Could not find {attr} in {transformers_module}!") +def add_generation_mixin_to_remote_model(model_class): + """ + Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model. + + This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make + `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded + from the Hub may not have the `generate` method after we remove the inheritance. + """ + # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing + if "torch.nn.modules.module.Module" not in str(model_class.__mro__): + return model_class + + # 2. If it already **directly** inherits from GenerationMixin, do nothing + if "GenerationMixin" in str(model_class.__bases__): + return model_class + + # 3. 
Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or + # `prepare_inputs_for_generation` method. + has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate")) + has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation")) + if has_custom_generate or has_custom_prepare_inputs: + model_class_with_generation_mixin = type( + model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} + ) + return model_class_with_generation_mixin + return model_class + + class _LazyAutoMapping(OrderedDict): """ " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 5aad7b23a8a6..3102ada542d5 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import functional as F +from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, BarkEosPrioritizerLogitsProcessor, @@ -546,7 +547,7 @@ def device(self) -> torch.device: # GPT2-like autoregressive model -class BarkCausalModel(BarkPreTrainedModel): +class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): config_class = BarkSubModelConfig def __init__(self, config): diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fa928d05caa8..2e4e6dcaeb2d 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1557,7 +1558,7 @@ def forward( @add_start_docstrings( "The BART Model with a language modeling head. 
Can be used for summarization.", BART_START_DOCSTRING ) -class BartForConditionalGeneration(BartPreTrainedModel): +class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs): """, BART_START_DOCSTRING, ) -class BartForCausalLM(BartPreTrainedModel): +class BartForCausalLM(BartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 93d6d469b512..b62746da5c6f 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1280,7 +1281,7 @@ def forward( @add_start_docstrings( """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING ) -class BertLMHeadModel(BertPreTrainedModel): +class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index a5fb3d053115..8496d1f6072f 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -863,7 +864,7 @@ def _tie_weights(self): """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_GENERATION_START_DOCSTRING, ) -class BertGenerationDecoder(BertGenerationPreTrainedModel): +class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index a6b1660d5ae1..41045cb5f000 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING ) -class BigBirdForCausalLM(BigBirdPreTrainedModel): +class 
BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9f8e3cd19cd8..e26dce1edfc2 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2436,7 +2437,7 @@ def forward( BIGBIRD_PEGASUS_START_DOCSTRING, ) # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS -class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 16f7aab5c3df..7ad1dcbd661c 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -719,7 +720,7 @@ def forward( @add_start_docstrings( """BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING ) -class BioGptForCausalLM(BioGptPreTrainedModel): +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 12d259fde71e..4ea5926d854c 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1196,7 +1197,7 @@ def forward( @add_start_docstrings( "The Blenderbot Model with a language modeling head. 
Can be used for summarization.", BLENDERBOT_START_DOCSTRING ) -class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): +class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill -class BlenderbotForCausalLM(BlenderbotPreTrainedModel): +class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index aa0e38bd8e91..3e378f483a31 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1163,7 +1164,7 @@ def forward( "The BlenderbotSmall Model with a language modeling head. Can be used for summarization.", BLENDERBOT_SMALL_START_DOCSTRING, ) -class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1349,7 +1350,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M -class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 2392961037f2..aef9b8cebec9 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -24,6 +24,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1035,7 +1036,7 @@ def forward( """, BLIP_START_DOCSTRING, ) -class BlipForConditionalGeneration(BlipPreTrainedModel): +class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): config_class = BlipConfig _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] main_input_name = "pixel_values" diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index a800ba89825d..78384e6ce2f7 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -23,6 +23,7 @@ from torch.nn import 
CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -808,7 +809,7 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 -class BlipTextLMHeadModel(BlipTextPreTrainedModel): +class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8c3b5254ea8b..0b33572a689c 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -2006,7 +2007,7 @@ def forward( """, BLIP_2_START_DOCSTRING, ) -class Blip2ForConditionalGeneration(Blip2PreTrainedModel): +class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index b5b221b6b37f..0992a5519f95 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -860,7 +861,7 @@ def _update_causal_mask( """, BLOOM_START_DOCSTRING, ) -class BloomForCausalLM(BloomPreTrainedModel): +class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BloomConfig): diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 0d12c800c156..95540f96d3b6 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1544,7 +1545,7 @@ def forward( """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, FacebookAI/roberta-base->almanach/camembert-base -class CamembertForCausalLM(CamembertPreTrainedModel): +class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca95..c631181f00c5 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ 
b/src/transformers/models/chameleon/modeling_chameleon.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1496,7 +1497,7 @@ def _update_causal_mask( "Chameleon Model with a head on top used for outputting logits for next token prediction.", CHAMELEON_START_DOCSTRING, ) -class ChameleonForConditionalGeneration(ChameleonPreTrainedModel): +class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 479b0fac2b04..f438226064ec 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation import GenerationConfig +from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1278,7 +1278,7 @@ def forward( "The CLVP decoder model with a language modelling head on top.", CLVP_START_DOCSTRING, ) -class ClvpForCausalLM(ClvpPreTrainedModel): +class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) @@ -1509,7 +1509,7 @@ def _reorder_cache( "together to filter out the best speech_ids.", CLVP_START_DOCSTRING, ) -class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): +class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin): config_class = ClvpConfig def __init__(self, config: ClvpConfig): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index be57838975c0..7d6f64d6461a 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -702,7 +703,7 @@ def _update_causal_mask( """, CODEGEN_START_DOCSTRING, ) -class CodeGenForCausalLM(CodeGenPreTrainedModel): +class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index cb1b3f885798..12586af23f0d 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1068,7 +1069,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere -class 
CohereForCausalLM(CoherePreTrainedModel): +class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Ignore copy diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index c8a313505251..964d0bbfd145 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging @@ -736,7 +737,7 @@ def forward( """, CPMANT_START_DOCSTRING, ) -class CpmAntForCausalLM(CpmAntPreTrainedModel): +class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: CpmAntConfig): diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index bbf3b10a62ec..6d921621d47d 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer @@ -503,7 +504,7 @@ def forward( """, CTRL_START_DOCSTRING, ) -class CTRLLMHeadModel(CTRLPreTrainedModel): +class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a41fdfb56ed1..fcddeab7a595 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -866,7 +867,7 @@ def forward( @add_start_docstrings( """Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING ) -class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel): +class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7263713c0840..46de60e24f1a 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ 
-1227,7 +1228,7 @@ def _update_causal_mask( @add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING) -class DbrxForCausalLM(DbrxPreTrainedModel): +class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): def __init__(self, config: DbrxConfig): super().__init__(config) self.transformer = DbrxModel(config) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dd017170bef9..a200d716d451 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, @@ -1524,7 +1525,7 @@ def forward( @add_start_docstrings( """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING ) -class ElectraForCausalLM(ElectraPreTrainedModel): +class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin): _tied_weights_keys = ["generator_lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 6a0a26a5cbe5..6d81c97da023 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1081,7 +1082,7 @@ def forward( @add_start_docstrings( """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING ) -class ErnieForCausalLM(ErniePreTrainedModel): +class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 9a37fe22e177..270845c20aae 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -25,6 +25,7 @@ from ...activations import get_activation from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1239,7 +1240,7 @@ def _update_causal_mask( "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).", FALCON_START_DOCSTRING, ) -class FalconForCausalLM(FalconPreTrainedModel): +class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: FalconConfig): diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index f682f75f222e..011197d98542 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -25,6 +25,7 
@@ from ...activations import ACT2FN from ...cache_utils import MambaCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -717,7 +718,7 @@ def forward( FALCONMAMBA_START_DOCSTRING, ) # Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache -class FalconMambaForCausalLM(FalconMambaPreTrainedModel): +class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 50c6f7ede222..ef1501e78035 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -644,7 +645,7 @@ def forward( FLAUBERT_START_DOCSTRING, ) # Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class FlaubertWithLMHeadModel(FlaubertPreTrainedModel): +class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["pred_layer.proj.weight"] def __init__(self, config): diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 179408aba38e..4d50f9bb5925 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -35,6 +35,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( BaseModelOutput, @@ -1173,7 +1174,7 @@ def set_output_embeddings(self, value): @add_start_docstrings( "The FSMT Model with a language modeling head. 
Can be used for summarization.", FSMT_START_DOCSTRING ) -class FSMTForConditionalGeneration(PretrainedFSMTModel): +class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 089313b03b7b..0aabbf6b3654 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from torch import nn +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.auto.modeling_auto import AutoModelForCausalLM @@ -145,7 +146,7 @@ def _init_weights(self, module): "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.", FUYU_START_DOCSTRING, ) -class FuyuForCausalLM(FuyuPreTrainedModel): +class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): def __init__(self, config: FuyuConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index 36f2d1c594ab..dcc43bc74aec 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -527,7 +528,7 @@ def forward( # Example where we ony modify the docstring and call super -class GemmaForCausalLM(LlamaForCausalLM): +class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dd4c899d13d4..8d9bb88686de 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -988,7 +989,7 @@ def _update_causal_mask( return causal_mask -class GemmaForCausalLM(GemmaPreTrainedModel): +class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 30f371a1b612..a66ce3160b5f 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -33,6 +33,7 @@ ) from ...cache_utils import Cache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging @@ -473,7 +474,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(GemmaForCausalLM): +class 
Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index be964c9aed01..6b55500739b4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -931,7 +932,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(Gemma2PreTrainedModel): +class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 7471eec811a0..59d3a406ec35 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...file_utils import ModelOutput +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1324,7 +1325,7 @@ def forward( @add_start_docstrings( """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING ) -class GitForCausalLM(GitPreTrainedModel): +class GitForCausalLM(GitPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8dfbfb906444..e99f4b126246 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1182,7 +1183,7 @@ def forward( """, GPT2_START_DOCSTRING, ) -class GPT2LMHeadModel(GPT2PreTrainedModel): +class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): @@ -1384,7 +1385,7 @@ def _reorder_cache( """, GPT2_START_DOCSTRING, ) -class GPT2DoubleHeadsModel(GPT2PreTrainedModel): +class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0f927a72469d..ca1c03fcd9f9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1040,7 +1041,7 @@ def forward( """, GPT_BIGCODE_START_DOCSTRING, ) -class 
GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 28309f7738eb..2fae1753154c 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -917,7 +918,7 @@ def _update_causal_mask( """, GPT_NEO_START_DOCSTRING, ) -class GPTNeoForCausalLM(GPTNeoPreTrainedModel): +class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 274c571fa893..c1b2aa899985 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -30,6 +30,7 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1110,7 +1111,7 @@ def _update_causal_mask( @add_start_docstrings( """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING ) -class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): +class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 048e108a8ec2..3db2099511bc 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS @@ -815,7 +816,7 @@ def _update_causal_mask( """GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""", GPT_NEOX_JAPANESE_START_DOCSTRING, ) -class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel): +class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] def __init__(self, config): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 84f6d985f764..9eeb26c5e403 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation 
import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1011,7 +1012,7 @@ def _update_causal_mask( """, GPTJ_START_DOCSTRING, ) -class GPTJForCausalLM(GPTJPreTrainedModel): +class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f62de411a4fa..9a8d4570e7be 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1004,7 +1005,7 @@ def _update_causal_mask( return causal_mask -class GraniteForCausalLM(GranitePreTrainedModel): +class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 3ac462bdad34..d724485990b9 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1234,7 +1235,7 @@ def _update_causal_mask( return causal_mask -class GraniteMoeForCausalLM(GraniteMoePreTrainedModel): +class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: GraniteMoeConfig): diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 41be300095e7..9273d91ac401 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -23,11 +23,12 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ... import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -1450,7 +1451,7 @@ def forward( """The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. 
""", IDEFICS2_START_DOCSTRING, ) -class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel): +class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 5d59a4ed90e4..a027876b43d3 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -880,7 +881,7 @@ def forward( """, IMAGEGPT_START_DOCSTRING, ) -class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): +class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: ImageGPTConfig): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index ba77afe9f7c2..dff897f59d2d 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1283,7 +1284,7 @@ def forward( """, INSTRUCTBLIP_START_DOCSTRING, ) -class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel): +class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): config_class = InstructBlipConfig main_input_name = "pixel_values" diff --git a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py index 506da83c5322..be569abc9137 100644 --- a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py @@ -45,6 +45,7 @@ InstructBlipVisionModel, ) +from ...generation import GenerationMixin from ...utils import logging @@ -128,7 +129,7 @@ class InstructBlipVideoQFormerModel(InstructBlipQFormerModel): pass -class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): +class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin): def forward( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 8cb813e0ac57..bcc299b1ba78 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -30,6 +30,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1292,7 +1293,7 @@ def forward( """, INSTRUCTBLIPVIDEO_START_DOCSTRING, ) -class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel): +class 
InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 60e1670a3c27..4b8630efbfa9 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1424,7 +1425,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba -class JambaForCausalLM(JambaPreTrainedModel): +class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: JambaConfig): diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7c4394d0e1a1..e9c069604991 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1202,7 +1203,7 @@ def _update_causal_mask( return causal_mask -class JetMoeForCausalLM(JetMoePreTrainedModel): +class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 69641790b2db..90e21ed2f558 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1521,7 +1522,7 @@ def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel): +class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2TextConfig _tied_weights_keys = ["lm_head.weight"] @@ -1864,7 +1865,7 @@ def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel): +class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2Config main_input_name = "pixel_values" _tied_weights_keys = ["text_model.lm_head.weight"] diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 41b6c0a2bea2..f96bfd82b526 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -2298,7 +2299,7 @@ def forward( 
@add_start_docstrings( "The LED Model with a language modeling head. Can be used for summarization.", LED_START_DOCSTRING ) -class LEDForConditionalGeneration(LEDPreTrainedModel): +class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin): base_model_prefix = "led" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0bc44f314b5e..73b6bcd8b4a4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1101,7 +1102,7 @@ def _update_causal_mask( return causal_mask -class LlamaForCausalLM(LlamaPreTrainedModel): +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index eb1c55341b07..092008873d1e 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -237,7 +238,7 @@ def _supports_sdpa(self): """The LLAVA model which consists of a vision backbone and a language model.""", LLAVA_START_DOCSTRING, ) -class LlavaForConditionalGeneration(LlavaPreTrainedModel): +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bf76921090b2..a96b0d894204 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -23,10 +23,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -349,7 +350,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_START_DOCSTRING, ) -class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): +class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaNextConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py index e765dfb95cc3..c5ca2bf00324 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py @@ -29,6 +29,7 @@ image_size_to_num_patches, ) +from ...generation import GenerationMixin from ...utils import ( logging, replace_return_docstrings, @@ -218,7 +219,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): pass -class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): +class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) self.vision_resampler = LlavaNextVideoPooler(config) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 589bf346ceeb..7ad9e0769eb3 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -29,10 +29,11 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -387,7 +388,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, ) -class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel): +class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin): def __init__( self, config: LlavaNextVideoConfig, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 593500c2e404..948efbc922b7 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -23,10 +23,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, logging, @@ -358,7 +359,7 @@ def _init_weights(self, module): """The LLaVA-Onevision model which consists of a vision backbone and a language model.""", LLAVA_ONEVISION_START_DOCSTRING, ) -class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel): +class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaOnevisionConfig): super().__init__(config) self.vision_tower = AutoModel.from_config( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b2a6ed11ca57..8f9385c0fe76 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1900,7 +1901,7 @@ def forward( @add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) -class LongT5ForConditionalGeneration(LongT5PreTrainedModel): +class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 23a855fff256..86a4378da29c 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1342,7 +1343,7 @@ def forward( @add_start_docstrings( "The M2M100 Model with a language modeling head. 
Can be used for summarization.", M2M_100_START_DOCSTRING ) -class M2M100ForConditionalGeneration(M2M100PreTrainedModel): +class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 14a3dea1d1cc..6bed1caab23a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -657,7 +658,7 @@ def forward( """, MAMBA_START_DOCSTRING, ) -class MambaForCausalLM(MambaPreTrainedModel): +class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 19d53437130e..01074af38a51 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -932,7 +933,7 @@ def forward( """, MAMBA2_START_DOCSTRING, ) -class Mamba2ForCausalLM(Mamba2PreTrainedModel): +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): _tied_weights_keys = [] def __init__(self, config): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 2045f673540f..cb26bb11e094 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1224,7 +1225,7 @@ def forward( @add_start_docstrings( "The Marian Model with a language modeling head. 
Can be used for summarization.", MARIAN_START_DOCSTRING ) -class MarianMTModel(MarianPreTrainedModel): +class MarianMTModel(MarianPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ "final_logits_bias", @@ -1504,7 +1505,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en -class MarianForCausalLM(MarianPreTrainedModel): +class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9455f21b2073..3f2d6cb8e2ba 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1526,7 +1527,7 @@ def forward( "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.", MBART_START_DOCSTRING, ) -class MBartForConditionalGeneration(MBartPreTrainedModel): +class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"] @@ -1967,7 +1968,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25 -class MBartForCausalLM(MBartPreTrainedModel): +class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 16641655e203..20506f91bcbc 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1110,7 +1111,7 @@ def forward( """MegatronBert Model with a `language modeling` head on top for CLM fine-tuning.""", MEGATRON_BERT_START_DOCSTRING, ) -class MegatronBertForCausalLM(MegatronBertPreTrainedModel): +class MegatronBertForCausalLM(MegatronBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder"] def __init__(self, config): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 80992734046a..ffa1a18307e9 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs 
import ( BaseModelOutputWithPast, @@ -950,7 +951,7 @@ def _update_causal_mask( return causal_mask -class MistralForCausalLM(MistralPreTrainedModel): +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index fcc0d66e19c4..a1786fbb17e3 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1186,7 +1187,7 @@ def _update_causal_mask( return causal_mask -class MixtralForCausalLM(MixtralPreTrainedModel): +class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 85579636dcc4..9c826c370b75 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -24,6 +24,7 @@ from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -500,7 +501,7 @@ def forward( """, MPT_START_DOCSTRING, ) -class MptForCausalLM(MptPreTrainedModel): +class MptForCausalLM(MptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: MptConfig): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 54943cf982dd..6a7406f11b5b 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1550,7 +1551,7 @@ def forward( @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING) -class MT5ForConditionalGeneration(MT5PreTrainedModel): +class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f720faac038e..3109c4fc2431 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, +) 
from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1206,7 +1211,7 @@ def forward( "The MusicGen decoder model with a language modelling head on top.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForCausalLM(MusicgenPreTrainedModel): +class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenDecoderConfig): super().__init__(config) @@ -1658,7 +1663,7 @@ def generate( "for music generation tasks with one or both of text and audio prompts.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForConditionalGeneration(PreTrainedModel): +class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenConfig base_model_prefix = "encoder_decoder" main_input_name = "input_ids" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a8a8fe960989..c8345870b253 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, +) from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1117,7 +1122,7 @@ def forward( MUSICGEN_MELODY_START_DOCSTRING, ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody,MusicGen->Musicgen Melody -class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): +class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) @@ -1585,7 +1590,7 @@ def generate( decoder (`Optional[MusicgenMelodyForCausalLM]`, *optional*): MusicGen Melody decoder used to generate audio codes. """, ) -class MusicgenMelodyForConditionalGeneration(PreTrainedModel): +class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenMelodyConfig main_input_name = "input_ids" supports_gradient_checkpointing = True diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 319f1760cef9..c47c4b26b539 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1351,7 +1352,7 @@ def forward( @add_start_docstrings( "The MVP Model with a language modeling head. 
Can be used for various text generation tasks.", MVP_START_DOCSTRING ) -class MvpForConditionalGeneration(MvpPreTrainedModel): +class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: MvpConfig): @@ -1791,7 +1792,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class MvpForCausalLM(MvpPreTrainedModel): +class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4d079b4dde10..aa699853d557 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -980,7 +981,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronForCausalLM(NemotronPreTrainedModel): +class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 2bec0fb84dce..c33844da0f55 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1604,7 +1605,7 @@ def forward( @add_start_docstrings( "The NllbMoe Model with a language modeling head. 
Can be used for summarization.", NLLB_MOE_START_DOCSTRING ) -class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): +class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 03fa524532a0..a44b7d2a0a4c 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1022,7 +1023,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo -class OlmoForCausalLM(OlmoPreTrainedModel): +class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 2cbde7dc8631..d30cace3a705 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1173,7 +1174,7 @@ def _update_causal_mask( return causal_mask -class OlmoeForCausalLM(OlmoePreTrainedModel): +class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 2b24850f3f0c..0aa02a6f5d84 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, SequenceSummary from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer @@ -524,7 +525,7 @@ def forward( """, OPENAI_GPT_START_DOCSTRING, ) -class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index f393e0e6e2b2..8742f2628907 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1009,7 +1010,7 @@ def forward( ) -class OPTForCausalLM(OPTPreTrainedModel): +class 
OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 48fffb6b428d..b5fddce1d6a9 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -22,6 +22,7 @@ from torch import nn from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -302,7 +303,7 @@ def _supports_sdpa(self): """The PALIGEMMA model which consists of a vision backbone and a language model.""", PALIGEMMA_START_DOCSTRING, ) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): def __init__(self, config: PaliGemmaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 42cef3a63558..03d1574e9be2 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1244,7 +1245,7 @@ def forward( @add_start_docstrings( "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING ) -class PegasusForConditionalGeneration(PegasusPreTrainedModel): +class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1456,7 +1457,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class PegasusForCausalLM(PegasusPreTrainedModel): +class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 6d9072777bf6..77c0b32e6433 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1464,7 +1465,7 @@ def forward( @add_start_docstrings("The PEGASUS-X for conditional generation (e.g. 
summarization).", PEGASUS_X_START_DOCSTRING) -class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): +class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a6fd2284afb6..4f122e14284d 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -847,7 +848,7 @@ def _update_causal_mask( return causal_mask -class PersimmonForCausalLM(PersimmonPreTrainedModel): +class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4d0a076b5f9a..03ed19bc34ac 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1139,7 +1140,7 @@ def _update_causal_mask( return causal_mask -class PhiForCausalLM(PhiPreTrainedModel): +class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e0ca84be1848..12ee9f017f81 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1160,7 +1161,7 @@ def _update_causal_mask( return causal_mask -class Phi3ForCausalLM(Phi3PreTrainedModel): +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 94d882c80566..f209d7d88287 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -1553,7 +1554,7 @@ def forward( "A conditional generation model with a language modeling head. 
Can be used for sequence generation tasks.", PIX2STRUCT_START_DOCSTRING, ) -class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): +class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): config_class = Pix2StructConfig main_input_name = "flattened_patches" _tied_weights_keys = ["decoder.lm_head.weight"] diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93d91e160089..d15e079770a3 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1254,7 +1255,7 @@ def forward( "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.", PLBART_START_DOCSTRING, ) -class PLBartForConditionalGeneration(PLBartPreTrainedModel): +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1568,7 +1569,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base -class PLBartForCausalLM(PLBartPreTrainedModel): +class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c769cff3c454..e6488898e8a9 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,6 +25,7 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1001,7 +1002,7 @@ def forward(self, feature, index_value, embedding_offset): @add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING) -class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel): +class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: Pop2PianoConfig): diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 96fa2e2c12e5..7d23088f6e57 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,6 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1856,7 +1857,7 @@ def forward( "The ProphetNet Model with a language modeling head. 
Can be used for sequence generation tasks.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): +class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] def __init__(self, config: ProphetNetConfig): @@ -2073,7 +2074,7 @@ def get_decoder(self): " language modeling.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForCausalLM(ProphetNetPreTrainedModel): +class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "prophetnet.word_embeddings.weight", "prophetnet.decoder.word_embeddings.weight", diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1e79115d3470..10c0b6f38669 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1078,7 +1079,7 @@ def _update_causal_mask( return causal_mask -class Qwen2ForCausalLM(Qwen2PreTrainedModel): +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 14235bf0aaf6..bf48e1c6a97e 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -22,10 +22,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -855,7 +856,7 @@ def forward(self, audio_features): """The QWEN2AUDIO model which consists of a audio backbone and a language model.""", QWEN2AUDIO_START_DOCSTRING, ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel): +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index c9ee7b5f57a1..1b28e9baf25f 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1253,7 +1254,7 @@ def _update_causal_mask( return causal_mask -class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel): +class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 02716153fdda..938ec4d5e423 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1416,7 +1417,7 @@ def _update_causal_mask( """ -class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel): +class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a8f076fad79c..e04929489984 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import PreTrainedModel @@ -777,7 +778,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma -class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel): +class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, 
config): diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 2e98a07217e6..37b675539e66 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward @@ -2183,7 +2184,7 @@ def _pad_to_mult_of_chunk_length( @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING) -class ReformerModelWithLMHead(ReformerPreTrainedModel): +class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 31f7e3dce454..99016c1be429 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1002,7 +1003,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING ) -class RemBertForCausalLM(RemBertPreTrainedModel): +class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index bbf16ec039b4..91500e1926d7 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1003,7 +1004,7 @@ def forward( @add_start_docstrings( """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING ) -class RobertaForCausalLM(RobertaPreTrainedModel): +class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 95657c260dc7..9ed9b11d9431 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import 
GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -855,7 +856,7 @@ def forward( ROBERTA_PRELAYERNORM_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer -class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel): +class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index c4efbf16323e..2969f7f1a3d0 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1403,7 +1404,7 @@ def prepare_inputs_for_generation( @add_start_docstrings( """RoCBert Model with a `language modeling` head on top for CLM fine-tuning.""", ROC_BERT_START_DOCSTRING ) -class RoCBertForCausalLM(RoCBertPreTrainedModel): +class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 69588ff743a0..c98b525abe08 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -1033,7 +1034,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING ) -class RoFormerForCausalLM(RoFormerPreTrainedModel): +class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 7dec1f26e1a3..8361afbf727b 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -751,7 +752,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): """, RWKV_START_DOCSTRING, ) -class RwkvForCausalLM(RwkvPreTrainedModel): +class 
RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin): _tied_weights_keys = ["head.weight"] def __init__(self, config): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index ba8230ec509d..8e226d92a105 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2150,7 +2151,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. """, ) -class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel): +class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2664,7 +2665,7 @@ def remove_weight_norm(self): "The text-to-text SeamlessM4T Model transformer which can be used for T2TT.", SEAMLESS_M4T_START_DOCSTRING, ) -class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel): +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 2d1fde8eed69..aa710ad95266 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2439,7 +2440,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
""", ) -class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2922,7 +2923,7 @@ def remove_weight_norm(self): SEAMLESS_M4T_V2_START_DOCSTRING, ) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToText with SeamlessM4T->SeamlessM4Tv2,SeamlessM4Tv2Tokenizer->SeamlessM4TTokenizer, SeamlessM4Tv2Processor->SeamlessM4TProcessor -class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 8353a172b212..bdd532fa25e8 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1207,7 +1208,7 @@ def forward( "The Speech2Text Model with a language modeling head. Can be used for summarization.", SPEECH_TO_TEXT_START_DOCSTRING, ) -class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): +class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 13641ecb37f2..463a30fabe77 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1123,7 +1124,7 @@ def _update_causal_mask( # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm -class StableLmForCausalLM(StableLmPreTrainedModel): +class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5eaf50f090fa..079ad1298fb9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1053,7 +1054,7 @@ def _update_causal_mask( # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with 
QWEN2->STARCODER2,Qwen2->Starcoder2 -class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): +class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c5797d4573b7..96b6c7334b15 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -1456,7 +1457,7 @@ def forward( @add_start_docstrings( """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING ) -class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): +class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: SwitchTransformersConfig): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a90101924c5b..43e3f3afa4a8 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1542,7 +1543,7 @@ def forward( @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class T5ForConditionalGeneration(T5PreTrainedModel): +class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 04eb40ab2a2f..67b97cf9c852 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -736,7 +737,7 @@ def forward(self, *args, **kwargs): " [`VisionEncoderDecoder`].", TROCR_START_DOCSTRING, ) -class TrOCRForCausalLM(TrOCRPreTrainedModel): +class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 6f7b6cf06049..c621b742323d 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,6 +34,7 @@ ) from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel 
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -1679,7 +1680,7 @@ def forward( This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.""", UDOP_START_DOCSTRING, ) -class UdopForConditionalGeneration(UdopPreTrainedModel): +class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 3271689540b9..a7d1e5bacc65 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1101,7 +1102,7 @@ def forward( @add_start_docstrings("""UMT5 Model with a `language modeling` head on top.""", UMT5_START_DOCSTRING) -class UMT5ForConditionalGeneration(UMT5PreTrainedModel): +class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 9ae80be65ae4..7c7cfec20959 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -239,7 +240,7 @@ def _supports_sdpa(self): """The VideoLlava model which consists of a vision backbone and a language model.""", VIDEO_LLAVA_START_DOCSTRING, ) -class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.video_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 53a321369719..95129d46bbd8 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -240,7 +241,7 @@ def _supports_sdpa(self): VIPLLAVA_START_DOCSTRING, ) # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava -class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): +class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VipLlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 8012c3c1bbfc..7a4e9487288e 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -25,7 +25,7 @@ from transformers.cache_utils import EncoderDecoderCache -from ...generation.configuration_utils import GenerationConfig +from ...generation import GenerationConfig, GenerationMixin from ...generation.logits_process import ( LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, @@ -172,7 +172,7 @@ def _pad_to_max_length( return sequences -class WhisperGenerationMixin: +class WhisperGenerationMixin(GenerationMixin): def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): """ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b82b978e5e6d..93ec57fcf4b4 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, @@ -1915,7 +1916,7 @@ def forward(self, *args, **kwargs): """, WHISPER_START_DOCSTRING, ) -class WhisperForCausalLM(WhisperPreTrainedModel): +class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin): _tied_weights_keys = ["proj_out.weight"] main_input_name = "input_ids" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4f1693583494..3090bc2973cd 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -696,7 +697,7 @@ def forward( """, XGLM_START_DOCSTRING, ) -class XGLMForCausalLM(XGLMPreTrainedModel): +class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 
280383630987..3acec2353b69 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -657,7 +658,7 @@ def forward(self, x, y=None): """, XLM_START_DOCSTRING, ) -class XLMWithLMHeadModel(XLMPreTrainedModel): +class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin): _tied_weights_keys = ["pred_layer.proj.weight"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 2adae33fbd50..a153f0946893 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1006,7 +1007,7 @@ def forward( XLM_ROBERTA_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel): +class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index f86abf823e90..0c384ad45c52 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -986,7 +987,7 @@ def forward( """XLM-RoBERTa-XL Model with a `language modeling` head on top for CLM fine-tuning.""", XLM_ROBERTA_XL_START_DOCSTRING, ) -class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel): +class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 5d424ebe12dd..7681fbafad6d 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( @@ -1286,7 +1287,7 @@ def forward( """, XLNET_START_DOCSTRING, ) -class XLNetLMHeadModel(XLNetPreTrainedModel): +class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_loss.weight"] def __init__(self, config): diff --git 
a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index b1ca8116a72a..71474cc9c45b 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -956,7 +957,7 @@ def forward( "X-MOD Model with a `language modeling` head on top for CLM fine-tuning.", XMOD_START_DOCSTRING, ) -class XmodForCausalLM(XmodPreTrainedModel): +class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2f8e60c79151..600942a7ac08 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2099,6 +2099,15 @@ def test_assisted_decoding_with_num_logits_to_keep(self): ) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + @pytest.mark.generate + def test_inherits_generation_mixin(self): + """ + Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel` + to inherit it. + """ + for model_class in self.all_generative_model_classes: + self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape config = config.text_config if hasattr(config, "text_config") else config diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 39770b091bef..95d716898343 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -67,6 +67,7 @@ BertModel, FunnelBaseModel, FunnelModel, + GenerationMixin, GPT2Config, GPT2LMHeadModel, ResNetBackbone, @@ -571,3 +572,20 @@ def test_dynamic_saving_from_local_repo(self): _ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True) self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file()) self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file()) + + def test_custom_model_patched_generation_inheritance(self): + """ + Tests that our inheritance patching for generate-compatible models works as expected. Without this feature, + old Hub models lose the ability to call `generate`. + """ + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True + ) + self.assertTrue(model.__class__.__name__ == "NewModelForCausalLM") + + # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to + # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present. + self.assertTrue(isinstance(model, GenerationMixin)) + # More precisely, it directly inherits from GenerationMixin. 
This check would fail prior to v4.45 (inheritance + # patching was added in v4.45) + self.assertTrue("GenerationMixin" in str(model.__class__.__bases__)) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2130ed4b7c88..5155647059f1 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -90,6 +90,7 @@ BertConfig, BertModel, CLIPTextModel, + GenerationMixin, PreTrainedModel, T5Config, T5ForConditionalGeneration, @@ -1715,6 +1716,32 @@ def test_isin_mps_friendly(self): torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) ) + def test_can_generate(self): + """Tests the behavior of `PreTrainedModel.can_generate` method.""" + # 1 - By default, a model CAN'T generate + self.assertFalse(BertModel.can_generate()) + + # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly + class DummyBertWithMixin(BertModel, GenerationMixin): + pass + + self.assertTrue(DummyBertWithMixin.can_generate()) + + # 3 - Alternatively, a model can implement a `generate` method + class DummyBertWithGenerate(BertModel): + def generate(self): + pass + + self.assertTrue(DummyBertWithGenerate.can_generate()) + + # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited + # `GenerationMixin`) + class DummyBertWithPrepareInputs(BertModel): + def prepare_inputs_for_generation(self): + pass + + self.assertTrue(DummyBertWithPrepareInputs.can_generate()) + def test_save_and_load_config_with_custom_generation(self): """ Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter From d3f8417e2b759ca6bc2607443be6bb8c0f00470c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 24 Sep 2024 17:40:56 +0800 Subject: [PATCH 34/54] Enable BNB multi-backend support (#31098) * enable cpu bnb path * fix style * fix code style * fix 4 bit path * Update src/transformers/utils/import_utils.py Co-authored-by: Aarni Koskela * add multi backend refactor tests * fix style * tweak 4bit quantizer + fix corresponding tests * tweak 8bit quantizer + *try* fixing corresponding tests * fix dequant bnb 8bit * account for Intel CPU in variability of expected outputs * enable cpu and xpu device map * further tweaks to account for Intel CPU * fix autocast to work with both cpu + cuda * fix comments * fix comments * switch to testing_utils.torch_device * allow for xpu in multi-gpu tests * fix tests 4bit for CPU NF4 * fix bug with is_torch_xpu_available needing to be called as func * avoid issue where test reports attr err due to other failure * fix formatting * fix typo from resolving of merge conflict * polish based on last PR review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix CI * Update src/transformers/integrations/integration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/integrations/integration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix error log * fix error msg * add \n in error log * make quality * rm bnb cuda restriction in doc * cpu model don't need dispatch * fix doc * fix style * check cuda avaliable in testing * fix tests * Update docs/source/en/model_doc/chameleon.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update docs/source/en/model_doc/llava_next.md Co-authored-by: Aarni Koskela * Update 
tests/quantization/bnb/test_4bit.py Co-authored-by: Aarni Koskela * Update tests/quantization/bnb/test_4bit.py Co-authored-by: Aarni Koskela * fix doc * fix check multibackends * fix import sort * remove check torch in bnb * docs: update bitsandbytes references with multi-backend info * docs: fix small mistakes in bnb paragraph * run formatting * reveret bnb check * move bnb multi-backend check to import_utils * Update src/transformers/utils/import_utils.py Co-authored-by: Aarni Koskela * fix bnb check * minor fix for bnb * check lib first * fix code style * Revert "run formatting" This reverts commit ac108c6d6b34f45a5745a736ba57282405cfaa61. * fix format * give warning when bnb version is low and no cuda found] * fix device assignment check to be multi-device capable * address akx feedback on get_avlbl_dev fn * revert partially, as we don't want the function that public, as docs would be too much (enforced) --------- Co-authored-by: Aarni Koskela Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/llm_tutorial_optimization.md | 2 +- docs/source/en/model_doc/chameleon.md | 12 ++- docs/source/en/model_doc/llava_next.md | 12 ++- docs/source/en/model_doc/llava_next_video.md | 12 ++- docs/source/en/model_doc/llava_onevision.md | 14 ++- docs/source/en/model_doc/mixtral.md | 2 +- docs/source/en/model_doc/video_llava.md | 12 ++- docs/source/en/model_memory_anatomy.md | 2 +- docs/source/en/perf_train_gpu_one.md | 2 +- docs/source/en/quantization/bitsandbytes.md | 8 ++ docs/source/en/quantization/overview.md | 16 ++- src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/bitsandbytes.py | 98 ++++++++++++++++- .../quantizers/quantizer_bnb_4bit.py | 22 +++- .../quantizers/quantizer_bnb_8bit.py | 23 ++-- src/transformers/testing_utils.py | 48 ++++++++- src/transformers/utils/__init__.py | 32 ++++++ src/transformers/utils/import_utils.py | 22 +++- tests/quantization/bnb/test_4bit.py | 100 ++++++++++++------ tests/quantization/bnb/test_mixed_int8.py | 95 +++++++++++------ 20 files changed, 436 insertions(+), 100 deletions(-) diff --git a/docs/source/en/llm_tutorial_optimization.md b/docs/source/en/llm_tutorial_optimization.md index a675a6de39a2..9d3d8ad6ba8b 100644 --- a/docs/source/en/llm_tutorial_optimization.md +++ b/docs/source/en/llm_tutorial_optimization.md @@ -181,7 +181,7 @@ for every matrix multiplication. Dequantization and re-quantization is performed Therefore, inference time is often **not** reduced when using quantized weights, but rather increases. Enough theory, let's give it a try! To quantize the weights with Transformers, you need to make sure that -the [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) library is installed. +the [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes) library is installed. ```bash !pip install bitsandbytes diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 28ec01ad6158..2fa9c1db866c 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -128,7 +128,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza ### Quantization using Bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. 
First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply change the snippet above with: ```python from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index d0558be76467..f04827cc7d5f 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -233,7 +233,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza ### Quantization using Bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes`, and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply change the snippet above with: ```python from transformers import LlavaNextForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md index 48e50f950621..fe905dfb7932 100644 --- a/docs/source/en/model_doc/llava_next_video.md +++ b/docs/source/en/model_doc/llava_next_video.md @@ -205,7 +205,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment on resource-constrained cases. -First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. 
Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: +First, make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Then simply load the quantized model by adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: ```python diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index 64a127abca4c..717784da738d 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -264,9 +264,19 @@ processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spac ## Model optimization -### Quantization using Bitsandbytes +### Quantization using bitsandbytes -The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Simply change the snippet above with: ```python from transformers import LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md index 26eff8ec21ad..71c7d7921ef0 100644 --- a/docs/source/en/model_doc/mixtral.md +++ b/docs/source/en/model_doc/mixtral.md @@ -141,7 +141,7 @@ The Flash Attention-2 model uses also a more memory efficient cache slicing mech As the Mixtral model has 45 billion parameters, that would require about 90GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), a single A100 with 40GB of RAM is enough to fit the entire model, as in that case only about 27 GB of RAM is required. 
-Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the BitsAndyBytes quantization (but refer to [this page](../quantization.md) for other quantization methods): +Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods): ```python >>> import torch diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md index f098e82a1776..1c4b5b4b874d 100644 --- a/docs/source/en/model_doc/video_llava.md +++ b/docs/source/en/model_doc/video_llava.md @@ -139,7 +139,17 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. his allows for efficient deployment on resource-constrained cases. -First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: +First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a GPU/accelerator that is supported by the library. + + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + +Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below: ```python diff --git a/docs/source/en/model_memory_anatomy.md b/docs/source/en/model_memory_anatomy.md index c1d9d4c54bc7..44c197aae5cf 100644 --- a/docs/source/en/model_memory_anatomy.md +++ b/docs/source/en/model_memory_anatomy.md @@ -233,7 +233,7 @@ Let's look at the details. **Optimizer States:** - 8 bytes * number of parameters for normal AdamW (maintains 2 states) -- 2 bytes * number of parameters for 8-bit AdamW optimizers like [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) +- 2 bytes * number of parameters for 8-bit AdamW optimizers like [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) - 4 bytes * number of parameters for optimizers like SGD with momentum (maintains only 1 state) **Gradients** diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md index c90f2ca58483..364fc46544c6 100644 --- a/docs/source/en/perf_train_gpu_one.md +++ b/docs/source/en/perf_train_gpu_one.md @@ -284,7 +284,7 @@ training_args = TrainingArguments(per_device_train_batch_size=4, optim="adamw_bn However, we can also use a third-party implementation of the 8-bit optimizer for demonstration purposes to see how that can be integrated. 
-First, follow the installation guide in the GitHub [repo](https://github.com/TimDettmers/bitsandbytes) to install the `bitsandbytes` library +First, follow the installation guide in the GitHub [repo](https://github.com/bitsandbytes-foundation/bitsandbytes) to install the `bitsandbytes` library that implements the 8-bit Adam optimizer. Next you need to initialize the optimizer. This involves two steps: diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md index 334b6145e537..e9447555e824 100644 --- a/docs/source/en/quantization/bitsandbytes.md +++ b/docs/source/en/quantization/bitsandbytes.md @@ -38,6 +38,14 @@ pip install --upgrade accelerate transformers + + +bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 9eb74793a127..97bb0cf53263 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -49,7 +49,7 @@ Use the table below to help you decide which quantization method to use. |-------------------------------------|-------------------------|-----|----------|----------------|-----------------------|-------------------------|----------------|-------------------------------------|--------------|------------------------|---------------------------------------------| | [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 / 2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM | | [AWQ](./awq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ | -| [bitsandbytes](./bitsandbytes) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 4 / 8 | 🟢 | 🟢 | 🟢 | https://github.com/TimDettmers/bitsandbytes | +| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 * | 🟢 | 🟡 * | 🔴 ** | 🔴 (soon!) | 4 / 8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes | | [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ | | GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp | | [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ | @@ -57,3 +57,17 @@ Use the table below to help you decide which quantization method to use. | [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | | [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | + + + +\* bitsandbytes is being refactored to support multiple backends beyond CUDA. 
Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + +We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. + + + + + +\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships. + + diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 0a28ff022a53..00bbcf2d060f 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -31,6 +31,7 @@ "replace_with_bnb_linear", "set_module_8bit_tensor_to_device", "set_module_quantized_tensor_to_device", + "validate_bnb_backend_availability", ], "deepspeed": [ "HfDeepSpeedConfig", @@ -124,6 +125,7 @@ replace_with_bnb_linear, set_module_8bit_tensor_to_device, set_module_quantized_tensor_to_device, + validate_bnb_backend_availability, ) from .deepspeed import ( HfDeepSpeedConfig, diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index f37ca9a2650b..2501261b55e0 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -6,7 +6,15 @@ from packaging import version -from ..utils import is_accelerate_available, is_bitsandbytes_available, logging +from ..utils import ( + get_available_devices, + is_accelerate_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_ipex_available, + is_torch_available, + logging, +) if is_bitsandbytes_available(): @@ -332,7 +340,7 @@ def get_keys_to_not_convert(model): # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 -def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): +def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None): """ Helper function to dequantize 4bit or 8bit bnb weights. 
@@ -350,7 +358,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): logger.warning_once( f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" ) - return output_tensor + return output_tensor.to(dtype) if state.SCB is None: state.SCB = weight.SCB @@ -361,7 +369,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): if state.CxB is None: state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) - return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype) def _create_accelerate_new_hook(old_hook): @@ -383,6 +391,7 @@ def _create_accelerate_new_hook(old_hook): def _dequantize_and_replace( model, + dtype, modules_to_not_convert=None, current_key_name=None, quantization_config=None, @@ -422,7 +431,7 @@ def _dequantize_and_replace( else: state = None - new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state)) + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state)) if bias is not None: new_module.bias = bias @@ -441,6 +450,7 @@ def _dequantize_and_replace( if len(list(module.children())) > 0: _, has_been_replaced = _dequantize_and_replace( module, + dtype, modules_to_not_convert, current_key_name, quantization_config, @@ -458,6 +468,7 @@ def dequantize_and_replace( ): model, has_been_replaced = _dequantize_and_replace( model, + model.dtype, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config, ) @@ -468,3 +479,80 @@ def dequantize_and_replace( ) return model + + +def _validate_bnb_multi_backend_availability(raise_exception): + import bitsandbytes as bnb + + bnb_supported_devices = getattr(bnb, "supported_torch_devices", set()) + available_devices = get_available_devices() + + if available_devices == {"cpu"} and not is_ipex_available(): + from importlib.util import find_spec + + if find_spec("intel_extension_for_pytorch"): + logger.warning( + "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible." + ) + + available_devices.discard("cpu") # Only Intel CPU is supported by BNB at the moment + + if not available_devices.intersection(bnb_supported_devices): + if raise_exception: + bnb_supported_devices_with_info = set( # noqa: C401 + '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)' + if device == "cpu" + else device + for device in bnb_supported_devices + ) + err_msg = ( + f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. 
" + "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + + logger.error(err_msg) + raise RuntimeError(err_msg) + + logger.warning("No supported devices found for bitsandbytes multi-backend.") + return False + + logger.debug("Multi-backend validation successful.") + return True + + +def _validate_bnb_cuda_backend_availability(raise_exception): + if not is_torch_available(): + return False + + import torch + + if not torch.cuda.is_available(): + log_msg = ( + "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. " + "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend" + ) + if raise_exception: + logger.error(log_msg) + raise RuntimeError(log_msg) + + logger.warning(log_msg) + return False + + logger.debug("CUDA backend validation successful.") + return True + + +def validate_bnb_backend_availability(raise_exception=False): + """ + Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not. + """ + if not is_bitsandbytes_available(): + if importlib.util.find_spec("bitsandbytes") and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.43.1"): + return _validate_bnb_cuda_backend_availability(raise_exception) + return False + + if is_bitsandbytes_multi_backend_available(): + return _validate_bnb_multi_backend_availability(raise_exception) + return _validate_bnb_cuda_backend_availability(raise_exception) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 827ca310f35a..73e7664aeb88 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -29,6 +29,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_xpu_available, logging, ) @@ -65,8 +66,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") if not is_accelerate_available(): raise ImportError( f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" @@ -76,6 +75,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ..integrations import validate_bnb_backend_availability + from ..utils import is_bitsandbytes_multi_backend_available + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): raise ValueError( "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" @@ -91,7 +96,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_lm_head = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -255,10 +262,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + elif is_torch_xpu_available(): + device_map = {"": f"xpu:{torch.xpu.current_device()}"} + else: + device_map = {"": "cpu"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + f"Setting device_map to {device_map}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index dbfceac2de86..65d97716d02c 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -27,6 +27,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_xpu_available, logging, ) from .quantizers_utils import get_module_from_name @@ -64,9 +65,6 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") - if not is_accelerate_available(): raise ImportError( f"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" @@ -76,6 +74,12 @@ def validate_environment(self, *args, **kwargs): "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) + from ..integrations import validate_bnb_backend_availability + from ..utils import is_bitsandbytes_multi_backend_available + + bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available() + validate_bnb_backend_availability(raise_exception=True) + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): raise ValueError( "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" @@ -91,7 +95,9 @@ def validate_environment(self, *args, **kwargs): device_map_without_lm_head = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled: + pass + elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -127,10 +133,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + elif is_torch_xpu_available(): + device_map = {"": f"xpu:{torch.xpu.current_device()}"} + else: + device_map = {"": "cpu"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + f"Setting device_map to {device_map}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index e0608acfeb8a..2cc0fa571089 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -61,6 +61,7 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_bs4_available, is_cv2_available, is_cython_available, @@ -224,6 +225,17 @@ def parse_int_from_env(key, default=None): _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) +def get_device_count(): + import torch + + if is_torch_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = torch.cuda.device_count() + + return num_devices + + def is_pt_tf_cross_test(test_case): """ Decorator marking a test as a test that control interactions between PyTorch and TensorFlow. 
@@ -331,6 +343,29 @@ def tooslow(test_case): return unittest.skip(reason="test is too slow")(test_case) +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + def custom_tokenizers(test_case): """ Decorator marking a test for a custom tokenizer. @@ -738,9 +773,9 @@ def require_torch_multi_gpu(test_case): if not is_torch_available(): return unittest.skip(reason="test requires PyTorch")(test_case) - import torch + device_count = get_device_count() - return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case) def require_torch_multi_accelerator(test_case): @@ -947,6 +982,15 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. + """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + def require_torch_accelerator(test_case): """Decorator marking a test that requires an accessible accelerator and PyTorch.""" return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")( diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index eee350349f55..93976c237556 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import lru_cache +from typing import FrozenSet + from huggingface_hub import get_full_repo_name # for backward compatibility from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility from packaging import version @@ -118,6 +121,7 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, is_bs4_available, is_coloredlogs_available, is_cv2_available, @@ -277,3 +281,31 @@ def check_min_version(min_version): + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other " "versions of HuggingFace Transformers." ) + + +@lru_cache() +def get_available_devices() -> FrozenSet[str]: + """ + Returns a frozenset of devices available for the current PyTorch installation. 
+ """ + devices = {"cpu"} # `cpu` is always supported as a device in PyTorch + + if is_torch_cuda_available(): + devices.add("cuda") + + if is_torch_mps_available(): + devices.add("mps") + + if is_torch_xpu_available(): + devices.add("xpu") + + if is_torch_npu_available(): + devices.add("npu") + + if is_torch_mlu_available(): + devices.add("mlu") + + if is_torch_musa_available(): + devices.add("musa") + + return frozenset(devices) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ad8b649aaa4e..289dd02fdd52 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -849,15 +849,29 @@ def is_torch_xpu_available(check_device=False): return hasattr(torch, "xpu") and torch.xpu.is_available() +@lru_cache() def is_bitsandbytes_available(): - if not is_torch_available(): + if not is_torch_available() or not _bitsandbytes_available: return False - # bitsandbytes throws an error if cuda is not available - # let's avoid that by adding a simple check import torch - return _bitsandbytes_available and torch.cuda.is_available() + # `bitsandbytes` versions older than 0.43.1 eagerly require CUDA at import time, + # so those versions of the library are practically only available when CUDA is too. + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.1"): + return torch.cuda.is_available() + + # Newer versions of `bitsandbytes` can be imported on systems without CUDA. + return True + + +def is_bitsandbytes_multi_backend_available() -> bool: + if not is_bitsandbytes_available(): + return False + + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) def is_flash_attn_2_available(): diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 785402b3f798..0ac9b3d82fc7 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -30,12 +30,13 @@ pipeline, ) from transformers.testing_utils import ( + apply_skip_if_not_implemented, is_bitsandbytes_available, is_torch_available, require_accelerate, require_bitsandbytes, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_torch_multi_gpu, slow, torch_device, @@ -85,7 +86,7 @@ def forward(self, input, *args, **kwargs): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class Base4bitTest(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -111,6 +112,7 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) +@apply_skip_if_not_implemented class Bnb4BitTest(Base4bitTest): def setUp(self): super().setUp() @@ -206,7 +208,7 @@ def test_rwkv_4bit(self): tok = AutoTokenizer.from_pretrained(model_id) text = "Hello my name is" - input_ids = tok.encode(text, return_tensors="pt").to(0) + input_ids = tok.encode(text, return_tensors="pt").to(torch_device) _ = model.generate(input_ids, max_new_tokens=30) @@ -217,7 +219,9 @@ def test_generate_quality(self): the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
""" encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = self.model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = self.model_4bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -234,7 +238,7 @@ def test_generate_quality_config(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") output_sequences = model_4bit_from_config.generate( - input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10 + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -252,7 +256,9 @@ def test_generate_quality_dequantize(self): model_4bit.dequantize() encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_4bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -267,15 +273,18 @@ def test_device_assignment(self): self.assertEqual(self.model_4bit.device.type, "cpu") self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) - # Move back to CUDA device - self.model_4bit.to(0) - self.assertEqual(self.model_4bit.device, torch.device(0)) - self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + if torch.cuda.is_available(): + # Move back to CUDA device + self.model_4bit.to("cuda") + self.assertEqual(self.model_4bit.device.type, "cuda") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. - Checks also if other models are casted correctly. + Test whether attempting to change the device or cast the dtype of a model + after converting it to 4-bit precision will raise an appropriate error. + The test ensures that such operations are prohibited on 4-bit models + to prevent invalid conversions. """ # Moving with `to` or `cuda` is not supported with versions < 0.43.2. 
@@ -297,25 +306,24 @@ def test_device_and_dtype_assignment(self): self.model_4bit.to(torch.float16) with self.assertRaises(ValueError): - # Tries with a `dtype` and `device` - self.model_4bit.to(device="cuda:0", dtype=torch.float16) - - with self.assertRaises(ValueError): - # Tries with a cast + # Tries to cast the 4-bit model to float32 using `float()` self.model_4bit.float() with self.assertRaises(ValueError): - # Tries with a cast + # Tries to cast the 4-bit model to float16 using `half()` self.model_4bit.half() # Test if we did not break anything + self.model_4bit.to(torch.device(torch_device)) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") self.model_fp16 = self.model_fp16.to(torch.float32) - _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) - # Check that this does not throw an error - _ = self.model_fp16.cuda() + if torch.cuda.is_available(): + # Check that this does not throw an error + _ = self.model_fp16.cuda() # Check this does not throw an error _ = self.model_fp16.to("cpu") @@ -344,8 +352,9 @@ def test_bnb_4bit_wrong_config(self): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class Bnb4BitT5Test(unittest.TestCase): @classmethod def setUpClass(cls): @@ -375,14 +384,14 @@ def test_inference_without_keep_in_fp32(self): # test with `google-t5/t5-small` model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_4bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) T5ForConditionalGeneration._keep_in_fp32_modules = modules @@ -400,17 +409,18 @@ def test_inference_with_keep_in_fp32(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_4bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) +@apply_skip_if_not_implemented class Classes4BitModelTest(Base4bitTest): def setUp(self): super().setUp() @@ -460,6 +470,7 @@ def test_correct_head_class(self): self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter) +@apply_skip_if_not_implemented class Pipeline4BitTest(Base4bitTest): def setUp(self): super().setUp() @@ -469,7 +480,8 @@ def tearDown(self): TearDown function needs to be called at the end of each test to free the GPU memory and cache, 
also to avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ - del self.pipe + if hasattr(self, "pipe"): + del self.pipe gc.collect() torch.cuda.empty_cache() @@ -484,7 +496,12 @@ def test_pipeline(self): self.pipe = pipeline( "text-generation", model=self.model_name, - model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16}, + model_kwargs={ + "device_map": "auto", + "load_in_4bit": True, + # float16 isn't supported on CPU, use bfloat16 instead + "torch_dtype": torch.bfloat16 if torch_device == "cpu" else torch.float16, + }, max_new_tokens=self.MAX_NEW_TOKENS, ) @@ -494,6 +511,7 @@ def test_pipeline(self): @require_torch_multi_gpu +@apply_skip_if_not_implemented class Bnb4bitTestMultiGpu(Base4bitTest): def setUp(self): super().setUp() @@ -515,10 +533,13 @@ def test_multi_gpu_loading(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Second real batch - output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model_parallel.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) +@apply_skip_if_not_implemented class Bnb4BitTestTraining(Base4bitTest): def setUp(self): self.model_name = "facebook/opt-350m" @@ -531,7 +552,10 @@ def test_training(self): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True) - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + if torch.cuda.is_available(): + self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + else: + self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) for param in model.parameters(): param.requires_grad = False # freeze the model - train adapters later @@ -547,10 +571,10 @@ def test_training(self): module.v_proj = LoRALayer(module.v_proj, rank=16) # Step 3: dummy batch - batch = self.tokenizer("Test batch ", return_tensors="pt").to(0) + batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device) # Step 4: Check if the gradient is not None - with torch.cuda.amp.autocast(): + with torch.autocast(torch_device): out = model.forward(**batch) out.logits.norm().backward() @@ -562,6 +586,7 @@ def test_training(self): self.assertTrue(module.weight.grad is None) +@apply_skip_if_not_implemented class Bnb4BitGPT2Test(Bnb4BitTest): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187 @@ -570,8 +595,9 @@ class Bnb4BitGPT2Test(Bnb4BitTest): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class BaseSerializationTest(unittest.TestCase): model_name = "facebook/opt-125m" input_text = "Mars colonists' favorite meals are" @@ -635,7 +661,9 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa d1[k].quant_state.as_dict().values(), ): if isinstance(v0, torch.Tensor): - self.assertTrue(torch.equal(v0, v1.to(v0.device))) + # The absmax will not be saved in the quant_state when using NF4 in CPU + if v0.numel() != 0: + self.assertTrue(torch.equal(v0, v1.to(v0.device))) else: self.assertTrue(v0 == v1) @@ -659,6 +687,7 @@ def _decode(token): ) +@apply_skip_if_not_implemented class 
ExtendedSerializationTest(BaseSerializationTest): """ tests more combinations of parameters @@ -706,8 +735,9 @@ class GPTSerializationTest(BaseSerializationTest): @require_bitsandbytes @require_accelerate -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow +@apply_skip_if_not_implemented class Bnb4BitTestBasicConfigTest(unittest.TestCase): def test_load_in_4_and_8_bit_fails(self): with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"): diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index ca3f043c749a..5a99ab32e42b 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -30,14 +30,17 @@ pipeline, ) from transformers.testing_utils import ( + apply_skip_if_not_implemented, is_accelerate_available, + is_bitsandbytes_available, is_torch_available, require_accelerate, require_bitsandbytes, require_torch, - require_torch_gpu, + require_torch_gpu_if_bnb_not_multi_backend_enabled, require_torch_multi_gpu, slow, + torch_device, ) @@ -77,10 +80,14 @@ def forward(self, input, *args, **kwargs): return self.module(input, *args, **kwargs) + self.adapter(input) +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class BaseMixedInt8Test(unittest.TestCase): # We keep the constants inside the init function and model loading inside setUp function @@ -108,6 +115,7 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) +@apply_skip_if_not_implemented class MixedInt8Test(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -240,7 +248,6 @@ def test_llm_skip(self): r""" A simple test to check if `llm_int8_skip_modules` works as expected """ - import bitsandbytes as bnb quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"]) seq_classification_model = AutoModelForSequenceClassification.from_pretrained( @@ -263,7 +270,9 @@ def test_generate_quality(self): the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
""" encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = self.model_8bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -280,7 +289,7 @@ def test_generate_quality_config(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") output_sequences = model_8bit_from_config.generate( - input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10 + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -298,7 +307,9 @@ def test_generate_quality_dequantize(self): model_8bit.dequantize() encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_8bit.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -319,8 +330,10 @@ def test_raise_if_config_and_load_in_8bit(self): def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. - Checks also if other models are casted correctly. + Test whether attempting to change the device or cast the dtype of a model + after converting it to 8-bit precision will raise an appropriate error. + The test ensures that such operations are prohibited on 8-bit models + to prevent invalid conversions. 
""" with self.assertRaises(ValueError): # Tries with `str` @@ -332,21 +345,21 @@ def test_device_and_dtype_assignment(self): with self.assertRaises(ValueError): # Tries with a `device` - self.model_8bit.to(torch.device("cuda:0")) + self.model_8bit.to(torch.device(torch_device)) with self.assertRaises(ValueError): - # Tries with a `device` + # Tries to cast the 8-bit model to float32 using `float()` self.model_8bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries to cast the 4-bit model to float16 using `half()` self.model_8bit.half() # Test if we did not break anything encoded_input = self.tokenizer(self.input_text, return_tensors="pt") self.model_fp16 = self.model_fp16.to(torch.float32) - _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) # Check this does not throw an error _ = self.model_fp16.to("cpu") @@ -385,7 +398,9 @@ def test_int8_serialization(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -410,7 +425,9 @@ def test_int8_serialization_regression(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -435,7 +452,9 @@ def test_int8_serialization_sharded(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model_from_saved.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -455,7 +474,7 @@ def test_int8_from_pretrained(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -463,7 +482,7 @@ def test_int8_from_pretrained(self): @require_bitsandbytes @require_accelerate @require_torch -@require_torch_gpu +@require_torch_gpu_if_bnb_not_multi_backend_enabled @slow class MixedInt8T5Test(unittest.TestCase): @classmethod @@ -494,14 +513,14 @@ def test_inference_without_keep_in_fp32(self): # test with `google-t5/t5-small` model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = 
model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) T5ForConditionalGeneration._keep_in_fp32_modules = modules @@ -511,7 +530,6 @@ def test_inference_with_keep_in_fp32(self): `flan-t5-small` uses `T5DenseGatedActDense` whereas `google-t5/t5-small` uses `T5DenseReluDense`. We need to test both cases. """ - import bitsandbytes as bnb from transformers import T5ForConditionalGeneration @@ -521,14 +539,14 @@ def test_inference_with_keep_in_fp32(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) def test_inference_with_keep_in_fp32_serialized(self): @@ -538,7 +556,6 @@ def test_inference_with_keep_in_fp32_serialized(self): `flan-t5-small` uses `T5DenseGatedActDense` whereas `google-t5/t5-small` uses `T5DenseReluDense`. We need to test both cases. """ - import bitsandbytes as bnb from transformers import T5ForConditionalGeneration @@ -553,14 +570,14 @@ def test_inference_with_keep_in_fp32_serialized(self): # there was a bug with decoders - this test checks that it is fixed self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) # test with `flan-t5-small` model = T5ForConditionalGeneration.from_pretrained( self.dense_act_model_name, load_in_8bit=True, device_map="auto" ) - encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) + encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) _ = model.generate(**encoded_input) @@ -614,6 +631,7 @@ def test_correct_head_class(self): self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter) +@apply_skip_if_not_implemented class MixedInt8TestPipeline(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -623,7 +641,8 @@ def tearDown(self): TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ - del self.pipe + if hasattr(self, "pipe"): + del self.pipe gc.collect() torch.cuda.empty_cache() @@ -648,6 +667,7 @@ def test_pipeline(self): @require_torch_multi_gpu +@apply_skip_if_not_implemented class MixedInt8TestMultiGpu(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -669,11 +689,14 @@ def test_multi_gpu_loading(self): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Second real batch - output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model_parallel.generate( + input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10 + ) self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_multi_gpu +@apply_skip_if_not_implemented class MixedInt8TestCpuGpu(BaseMixedInt8Test): def setUp(self): super().setUp() @@ -683,7 +706,7 @@ def check_inference_correctness(self, model): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Check the exactness of the results - output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) # Get the generation output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True) @@ -819,6 +842,7 @@ def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self): self.check_inference_correctness(model_8bit) +@apply_skip_if_not_implemented class MixedInt8TestTraining(BaseMixedInt8Test): def setUp(self): self.model_name = "facebook/opt-350m" @@ -831,7 +855,10 @@ def test_training(self): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True) - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + if torch.cuda.is_available(): + self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) + else: + self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) for param in model.parameters(): param.requires_grad = False # freeze the model - train adapters later @@ -847,10 +874,10 @@ def test_training(self): module.v_proj = LoRALayer(module.v_proj, rank=16) # Step 3: dummy batch - batch = self.tokenizer("Test batch ", return_tensors="pt").to(0) + batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device) # Step 4: Check if the gradient is not None - with torch.cuda.amp.autocast(): + with torch.autocast(torch_device): out = model.forward(**batch) out.logits.norm().backward() @@ -862,6 +889,7 @@ def test_training(self): self.assertTrue(module.weight.grad is None) +@apply_skip_if_not_implemented class MixedInt8GPT2Test(MixedInt8Test): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357 @@ -870,6 +898,9 @@ class MixedInt8GPT2Test(MixedInt8Test): EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I'm a fan of the") # Expected values on a A10 EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I am a member of the") + # Expected values on Intel CPU + EXPECTED_OUTPUTS.add("Hello my name is John Doe. I am a man. I am") + EXPECTED_OUTPUTS.add("Hello my name is John, and I'm a writer. 
I'm") def test_int8_from_pretrained(self): r""" @@ -887,6 +918,6 @@ def test_int8_from_pretrained(self): # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) From 52a0a757c4e38448675fe80540317f038e486f7f Mon Sep 17 00:00:00 2001 From: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:35:23 +0200 Subject: [PATCH 35/54] Fix error string after refactoring into get_chat_template (#33652) * Fix error string after refactoring into get_chat_template * Take suggestion from CR Co-authored-by: Matt --------- Co-authored-by: Matt --- src/transformers/tokenization_utils_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b4490578a709..c6467bb7d7f7 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1966,10 +1966,9 @@ def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional # priority: `chat_template` argument > `tokenizer.chat_template` if self.chat_template is not None: chat_template = self.chat_template - else: raise ValueError( - "Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template " + "Cannot use chat template functions because tokenizer.chat_template is not set and no template " "argument was passed! For information about writing templates and setting the " "tokenizer.chat_template attribute, please see the documentation at " "https://huggingface.co/docs/transformers/main/en/chat_templating" From 400927e7796a6ae457b5d70f0007db92926f1d13 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Tue, 24 Sep 2024 09:10:51 -0400 Subject: [PATCH 36/54] uniformize git processor (#33668) * uniformize git processor * update doctring --- src/transformers/models/git/modeling_git.py | 2 +- src/transformers/models/git/processing_git.py | 64 +++++++++++-------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 59d3a406ec35..2d90b82069fd 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1191,7 +1191,7 @@ def forward( >>> text = "this is an image of two cats" - >>> inputs = processor(text, images=image, return_tensors="pt") + >>> inputs = processor(images=image, text=text, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 98649c644e72..3744d81a0aca 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -16,8 +16,16 @@ Image/Text processor class for GIT """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, 
_validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class GitProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class GitProcessor(ProcessorMixin): @@ -42,7 +50,14 @@ def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[GitProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -51,13 +66,13 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -68,7 +83,7 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -76,29 +91,26 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - tokenizer_kwargs, image_processor_kwargs = {}, {} - if kwargs: - tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys} - image_processor_kwargs = { - k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys - } - if text is None and images is None: raise ValueError("You have to specify either text or images. 
Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + GitProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """ From 3b0d24c85c54051f77239f49be8b04a309e0fca9 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:54:07 +0200 Subject: [PATCH 37/54] Modular `transformers`: modularity and inheritance for new model additions (#33248) * update exampel * update * push the converted diff files for testing and ci * correct one example * fix class attributes and docstring * nits * oups * fixed config! * update * nitd * class attributes are not matched against the other, this is missing * fixed overwriting self.xxx now onto the attributes I think * partial fix, now order with docstring * fix docstring order? * more fixes * update * fix missing docstrings! * examples don't all work yet * fixup * nit * updated * hick * update * delete * update * update * update * fix * all default * no local import * fix more diff * some fix related to "safe imports" * push fixed * add helper! * style * add a check * all by default * add the * update * FINALLY! * nit * fix config dependencies * man that is it * fix fix * update diffs * fix the last issue * re-default to all * alll the fixes * nice * fix properties vs setter * fixup * updates * update dependencies * make sure to install what needs to be installed * fixup * quick fix for now * fix! * fixup * update * update * updates * whitespaces * nit * fix * simplify everything, and make it file agnostic (should work for image processors) * style * finish fixing all import issues * fixup * empty modeling should not be written! * Add logic to find who depends on what * update * cleanup * update * update gemma to support positions * some small nits * this is the correct docstring for gemma2 * fix merging of docstrings * update * fixup * update * take doc into account * styling * update * fix hidden activation * more fixes * final fixes! * fixup * fixup instruct blip video * update * fix bugs * align gemma2 with the rest as well * updats * revert * update * more reversiom * grind * more * arf * update * order will matter * finish del stuff * update * rename to modular * fixup * nits * update makefile * fixup * update order of the checks! 
* fix * fix docstring that has a call inside * fiix conversion check * style * add some initial documentation * update * update doc * some fixup * updates * yups * Mostly todo gimme a minut * update * fixup * revert some stuff * Review docs for the modular transformers (#33472) Docs * good update * fixup * mmm current updates lead to this code * okay, this fixes it * cool * fixes * update * nit * updates * nits * fix doc * update * revert bad changes * update * updates * proper update * update * update? * up * update * cool * nits * nits * bon bon * fix * ? * minimise changes * update * update * update * updates? * fixed gemma2 * kind of a hack * nits * update * remove `diffs` in favor of `modular` * fix make fix copies --------- Co-authored-by: Lysandre Debut --- .circleci/config.yml | 5 +- Makefile | 2 + docs/source/en/_toctree.yml | 4 + docs/source/en/modular_transformers.md | 121 ++ examples/diff-conversion/README.md | 20 - examples/modular-transformers/README.md | 20 + .../configuration_dummy.py | 0 .../configuration_my_new_model.py | 196 +++ .../configuration_my_new_model2.py | 97 ++ .../configuration_new_model.py | 134 +++ .../configuration_super.py | 0 .../convert_examples.sh | 2 +- .../modular-transformers/modeling_dummy.py | 1053 ++++++++++++++++ .../modeling_dummy_bert.py | 1038 ++++++++++++++++ .../modeling_my_new_model2.py | 1059 +++++++++++++++++ .../modular-transformers/modeling_super.py | 953 +++++++++++++++ .../modular_dummy.py} | 3 +- .../modular_dummy_bert.py | 27 + .../modular_my_new_model.py} | 5 +- .../modular_my_new_model2.py} | 0 .../modular_new_model.py} | 7 +- .../modular-transformers/modular_roberta.py | 20 + .../modular_super.py} | 3 +- setup.py | 4 +- src/transformers/dependency_versions_table.py | 2 + .../models/gemma/configuration_gemma.py | 8 +- .../models/gemma/modeling_gemma.py | 340 +++--- .../gemma/{diff_gemma.py => modular_gemma.py} | 338 +++++- .../models/gemma2/configuration_gemma2.py | 49 +- .../models/gemma2/modeling_gemma2.py | 212 ++-- .../{diff_gemma2.py => modular_gemma2.py} | 490 +++++++- .../configuration_instructblipvideo.py | 24 +- .../modeling_instructblipvideo.py | 179 ++- ...pvideo.py => modular_instructblipvideo.py} | 189 +-- .../configuration_llava_next_video.py | 17 +- .../modeling_llava_next_video.py | 28 +- ...t_video.py => modular_llava_next_video.py} | 53 +- .../modeling_llava_onevision.py | 1 + utils/check_modular_conversion.py | 76 ++ utils/create_dependency_mapping.py | 69 ++ ...onverter.py => modular_model_converter.py} | 434 +++++-- 41 files changed, 6504 insertions(+), 778 deletions(-) create mode 100644 docs/source/en/modular_transformers.md delete mode 100644 examples/diff-conversion/README.md create mode 100644 examples/modular-transformers/README.md create mode 100644 examples/modular-transformers/configuration_dummy.py create mode 100644 examples/modular-transformers/configuration_my_new_model.py create mode 100644 examples/modular-transformers/configuration_my_new_model2.py create mode 100644 examples/modular-transformers/configuration_new_model.py create mode 100644 examples/modular-transformers/configuration_super.py rename examples/{diff-conversion => modular-transformers}/convert_examples.sh (83%) create mode 100644 examples/modular-transformers/modeling_dummy.py create mode 100644 examples/modular-transformers/modeling_dummy_bert.py create mode 100644 examples/modular-transformers/modeling_my_new_model2.py create mode 100644 examples/modular-transformers/modeling_super.py rename 
examples/{diff-conversion/diff_dummy.py => modular-transformers/modular_dummy.py} (97%) create mode 100644 examples/modular-transformers/modular_dummy_bert.py rename examples/{diff-conversion/diff_my_new_model.py => modular-transformers/modular_my_new_model.py} (84%) rename examples/{diff-conversion/diff_my_new_model2.py => modular-transformers/modular_my_new_model2.py} (100%) rename examples/{diff-conversion/diff_new_model.py => modular-transformers/modular_new_model.py} (85%) create mode 100644 examples/modular-transformers/modular_roberta.py rename examples/{diff-conversion/diff_super.py => modular-transformers/modular_super.py} (97%) rename src/transformers/models/gemma/{diff_gemma.py => modular_gemma.py} (67%) rename src/transformers/models/gemma2/{diff_gemma2.py => modular_gemma2.py} (50%) rename src/transformers/models/instructblipvideo/{diff_instructblipvideo.py => modular_instructblipvideo.py} (76%) rename src/transformers/models/llava_next_video/{diff_llava_next_video.py => modular_llava_next_video.py} (95%) create mode 100644 utils/check_modular_conversion.py create mode 100644 utils/create_dependency_mapping.py rename utils/{diff_model_converter.py => modular_model_converter.py} (60%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9932156aa969..ca2afc67c10e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -137,7 +137,7 @@ jobs: parallelism: 1 steps: - checkout - - run: uv pip install -e . + - run: uv pip install -e ".[quality]" - run: name: Show installed libraries and their versions command: pip freeze | tee installed.txt @@ -162,13 +162,14 @@ jobs: parallelism: 1 steps: - checkout - - run: uv pip install -e . + - run: uv pip install -e ".[quality]" - run: name: Show installed libraries and their versions command: pip freeze | tee installed.txt - store_artifacts: path: ~/transformers/installed.txt - run: python utils/check_copies.py + - run: python utils/check_modular_conversion.py - run: python utils/check_table.py - run: python utils/check_dummies.py - run: python utils/check_repo.py diff --git a/Makefile b/Makefile index d3998327cc71..710c555b74f6 100644 --- a/Makefile +++ b/Makefile @@ -36,6 +36,7 @@ autogenerate_code: deps_table_update repo-consistency: python utils/check_copies.py + python utils/check_modular_conversion.py python utils/check_table.py python utils/check_dummies.py python utils/check_repo.py @@ -80,6 +81,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency fix-copies: python utils/check_copies.py --fix_and_overwrite + python utils/check_modular_conversion.py --fix_and_overwrite python utils/check_table.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite python utils/check_doctest_list.py --fix_and_overwrite diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9f6b1e1782e8..482974a837de 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -5,6 +5,8 @@ title: Quick tour - local: installation title: Installation + - local: add_new_model + title: Adding a new model to `transformers` title: Get started - sections: - local: pipeline_tutorial @@ -149,6 +151,8 @@ title: Interoperability with GGUF files - local: tiktoken title: Interoperability with TikToken files + - local: modular_transformers + title: Modularity in `transformers` title: Developer guides - sections: - local: quantization/overview diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md new file mode 100644 index 
000000000000..33d2bb948348 --- /dev/null +++ b/docs/source/en/modular_transformers.md @@ -0,0 +1,121 @@ +# Modular transformers + +`transformers` is an opinionated framework; our philosophy is defined in the following [conceptual guide](./philosophy). + +The core of that philosophy is exemplified by the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) +aspect of the library. This component's downside is that it limits the inheritance and importability of components from +files to others in the toolkit. + +As a result, model components tend to be repeated across many files. There are as many attention layers defined +in `transformers` as there are models, and a significant number of those are identical to each other. +The unfortunate consequence is that independent implementations tend to diverge as fixes and changes get applied +to specific parts of the code. + +In order to balance this issue, we introduced the concept of "copies" across the library. By adding a comment indicating +that code is a copy of another, we can enforce through CI and local commands that copies do not diverge. However, +while the complexity is low, this is often quite tedious to do. + +And, finally, this contributes to adding a significant overhead to contributing models which we would like to remove. +This approach often requires model contributions to add modeling code (~1k lines), processor (~500 lines), tests, docs, +etc. Model contribution PRs rarely add less than 3-5k lines of code, with much of this code being boilerplate. + +This raises the bar for contributions, and with Modular Transformers, we're aiming to lower the bar to a much more +acceptable point. + +## What is it? + +Modular Transformers introduces the concept of a "modular" file to a model folder. This modular file accepts code +that isn't typically accepted in modeling/processing files, as it allows importing from neighbouring models as well +as inheritance from classes to others. + +This modular file defines models, processors, and the configuration class that would otherwise be defined in their +respective modules. + +Finally, this feature introduces a new `linter` which will "unravel" the modular file into the "single model, single +file" directory structure. These files will get auto-generated every time the script is run; reducing the required +contributions to the modular file, and therefore only to the changes between the contributed model and others. + +Model users will end up importing and using the single-file interface, so no change is expected here. Doing this, we +hope to combine the best of both worlds: enabling simple contributions while sticking to our philosophy. + +This is therefore a replacement for the `# Copied from` markers, and previously contributed models can be expected to +be moved to the new Modular Transformers format in the coming months. + +### Details + +The "linter", which unravels the inheritance and creates all single-files from the modular file, will flatten the +inheritance while trying to be invisible to Python users. At this time, the linter flattens a **single** level of +inheritance. + +For example: +- If a configuration class inherits from another and adds/deletes an argument, the generated file will either directly + reference it (in case of addition) or completely remove it (in case of deletion). +- If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically + inferred. 
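+For instance, such a modular definition could be as small as the following sketch (the import path and
+class names are illustrative only, mirroring the RoBERTa example further down):
+
+```python
+# Hypothetical modular_gemma.py: the class body is inherited wholesale from Llama,
+# and the linter expands it into a standalone modeling_gemma.py.
+from ..llama.modeling_llama import LlamaModel
+
+
+class GemmaModel(LlamaModel):
+    pass
+```
+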
All submodules will be automatically inferred from the superclass. + +You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular` +file, and the corresponding files will be created for you. + +### Enforcement + +[TODO] We are introducing a new test, that makes sure the generated content matches what is present in the `modular_xxxx.py` + +### Examples + +Here is a quick example with BERT and RoBERTa. The two models are intimately related: their modeling implementation +differs solely by a change in the embedding layer. + +Instead of redefining the model entirely, here is what the `modular_roberta.py` file looks like for the modeling & +configuration classes (for the sake of the example, the tokenizer is ignored at this time as very different). + +```python +from torch import nn +from ..bert.configuration_bert import BertConfig +from ..bert.modeling_bert import ( + BertModel, + BertEmbeddings, + BertForMaskedLM +) + +# The RoBERTa config is identical to BERT's config +class RobertaConfig(BertConfig): + model_type = 'roberta' + +# We redefine the embeddings here to highlight the padding ID difference, and we redefine the position embeddings +class RobertaEmbeddings(BertEmbeddings): + def __init__(self, config): + super().__init__(config()) + + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + +# The RoBERTa model is identical to the BERT model, except for the embedding layer. +# We redefine the embeddings above, so here there is no need to do additional work +class RobertaModel(BertModel): + def __init__(self, config): + super().__init__(config) + self.embeddings = RobertaEmbeddings(config) + + +# The heads now only need to redefine the model inside to the correct `RobertaModel` +class RobertaForMaskedLM(BertForMaskedLM): + def __init__(self, config): + super().__init__(config) + self.model = RobertaModel(config) +``` + +Note that if you do not use the dependency that you defined, you will have the following error: + +```bash +ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used + when you define `BertModel`, as it is one of it's direct dependencies. Make sure + you use it in the `__init__` function. +``` + +Additionally, you may find a list of examples here: + +## What it is not + +It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual. \ No newline at end of file diff --git a/examples/diff-conversion/README.md b/examples/diff-conversion/README.md deleted file mode 100644 index a575a83b015c..000000000000 --- a/examples/diff-conversion/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# Using the `diff_converter` linter - -`pip install libcst` is a must! - -# `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs - -The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular `diff` file like `diff_gemma.py` into a `single model single file`. - -Examples of possible usage are available in the `examples/diff-conversion`, or `diff_gemma` for a full model usage. 
- -`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"` - -## How it works -We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies. - -The code from the `diff` file and the class dependency mapping are "merged" to produce the single model single file. -We use ruff to automatically remove the potential duplicate imports. - -## Why we use libcst instead of the native AST? -AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst` \ No newline at end of file diff --git a/examples/modular-transformers/README.md b/examples/modular-transformers/README.md new file mode 100644 index 000000000000..4eba1d03aebc --- /dev/null +++ b/examples/modular-transformers/README.md @@ -0,0 +1,20 @@ +# Using the `modular_converter` linter + +`pip install libcst` is a must! + +# `sh examples/modular-transformers/convert_examples.sh` to get the converted outputs + +The modular converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular file like `modular_gemma.py` into a `single model single file`. + +Examples of possible usage are available in the `examples/modular-transformers`, or `modular_gemma` for a full model usage. + +`python utils/modular_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/modular-transformers/modular_my_new_model2.py"` + +## How it works +We use the `libcst` parser to produce an AST representation of the `modular_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the difference dependencies. + +The code from the `modular` file and the class dependency mapping are "merged" to produce the single model single file. +We use ruff to automatically remove the potential duplicate imports. + +## Why we use libcst instead of the native AST? +AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst` \ No newline at end of file diff --git a/examples/modular-transformers/configuration_dummy.py b/examples/modular-transformers/configuration_dummy.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py new file mode 100644 index 000000000000..d7c946dbe318 --- /dev/null +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -0,0 +1,196 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class MyNewModelConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MyNewModelModel`].
It is used to instantiate an MyNewModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MyNewModel-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MyNewModelModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens, + MyNewModel 2 up to 4096, CodeMyNewModel up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'my_new_model3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. 
If None, it will default to hidden_size // num_heads + new_param (`int`, *optional*, defaults to `False`): + A fun new parameter + + ```python + >>> from transformers import MyNewModelModel, MyNewModelConfig + + >>> # Initializing a MyNewModel my_new_model-7b style configuration + >>> configuration = MyNewModelConfig() + + >>> # Initializing a model from the my_new_model-7b style configuration + >>> model = MyNewModelModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "my_new_model" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=True, + head_dim=None, + new_param=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.mlp_bias = mlp_bias + self.new_param = new_param diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py new file mode 100644 index 000000000000..b940d8d93b30 --- /dev/null +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -0,0 +1,97 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class MyNewModel2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. 
It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "my_new_model2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/examples/modular-transformers/configuration_new_model.py b/examples/modular-transformers/configuration_new_model.py new file mode 100644 index 000000000000..7d57f9fe25b0 --- /dev/null +++ b/examples/modular-transformers/configuration_new_model.py @@ -0,0 +1,134 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. 
If any change should be done, please apply the change to the +# diff.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Example where we only want to overwrite the defaults of an init + +from transformers import PretrainedConfig + + +class NewModelConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NewModelModel`]. It is used to instantiate an NewModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the NewModel-7B. + e.g. [google/new_model-7b](https://huggingface.co/google/new_model-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the NewModel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`NewModelModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. 
+ bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import NewModelModel, NewModelConfig + >>> # Initializing a NewModel new_model-7b style configuration + >>> configuration = NewModelConfig() + >>> # Initializing a model from the new_model-7b style configuration + >>> model = NewModelModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "new_model" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256030, + hidden_size=64, + intermediate_size=90, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=1500, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def num_heads(self): + return self.num_attention_heads diff --git a/examples/modular-transformers/configuration_super.py b/examples/modular-transformers/configuration_super.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/diff-conversion/convert_examples.sh b/examples/modular-transformers/convert_examples.sh similarity index 83% rename from examples/diff-conversion/convert_examples.sh rename to examples/modular-transformers/convert_examples.sh index 1cfdc3e33cdf..4af31f1b4268 100644 --- a/examples/diff-conversion/convert_examples.sh +++ b/examples/modular-transformers/convert_examples.sh @@ -1,7 +1,7 @@ #!/bin/bash # Iterate over each file in the current directory -for file in examples/diff-conversion/diff_*; do +for file in examples/modular-transformers/modular_*; do # Check if it's a regular file if [ -f "$file" ]; then # Call the Python script with the file name as an argument diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py new file mode 100644 index 000000000000..5dd76c603035 --- /dev/null +++ b/examples/modular-transformers/modeling_dummy.py @@ -0,0 +1,1053 @@ +# 
🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from math import log +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from .configuration_dummy import DummyConfig + + +def _pre_process_input(input_ids): + print(log(input_ids)) + return input_ids + + +logger = logging.get_logger(__name__) + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
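+        # (Here "inverted" means the additive form: 0.0 where attention is allowed and a large negative
+        # value (`min_dtype`) where it is masked, so the tensor can be added directly to the attention
+        # scores. The `else` branch below builds exactly this form from a 2D padding mask.)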
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class DummyRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DummyRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DummyRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[DummyConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DummyRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DummyMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DummyAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DummyConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = DummyRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class DummyFlashAttention2(DummyAttention): + """ + Dummy flash attention module. This module inherits from `DummyAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
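+        # Illustrative note: with q_seqlen=1 and k_seqlen=8 (decoding with a KV cache), a bottom-right aligned
+        # causal mask lets the single query attend to all 8 keys, whereas a top-left aligned mask would only expose
+        # the first key. The flag below records which behavior the installed flash_attn version has, and is passed
+        # to `_flash_attention_forward` as `use_top_left_mask` so it can compensate.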
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(DummyRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class DummySdpaAttention(DummyAttention): + """ + Dummy attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DummyAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from DummyAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DummyModel is using DummySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
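+        # Concretely: when no explicit mask is needed (causal_mask is None) and more than one query token is
+        # processed, `is_causal=True` lets SDPA build the causal mask internally and select its fused kernels;
+        # for single-token decoding (q_len == 1) causal masking is a no-op, so `is_causal` stays False.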
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +DUMMY_ATTENTION_CLASSES = { + "eager": DummyAttention, + "flash_attention_2": DummyFlashAttention2, + "sdpa": DummySdpaAttention, +} + + +class DummyDecoderLayer(nn.Module): + def __init__(self, config: DummyConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DUMMY_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = DummyMLP(config) + self.input_layernorm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DUMMY_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DummyConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Dummy Model outputting raw hidden-states without any specific head on top.", + DUMMY_START_DOCSTRING, +) +class DummyPreTrainedModel(PreTrainedModel): + config_class = DummyConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DummyDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DUMMY_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare Dummy Model outputting raw hidden-states without any specific head on top.", + DUMMY_START_DOCSTRING, +) +class DummyModel(DummyPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DummyDecoderLayer`] + + Args: + config: DummyConfig + """ + + def __init__(self, config: DummyConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DummyDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DummyRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DUMMY_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + input_ids = _pre_process_input(input_ids) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py new file mode 100644 index 000000000000..bdedd1f5f5a2 --- /dev/null +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -0,0 +1,1038 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. 
One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + logging, +) +from .configuration_dummy_bert import DummyBertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-dummy_bert/dummy_bert-base-uncased" +_CONFIG_FOR_DOC = "DummyBertConfig" + + +def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class DummyBertEmbeddings(nn.Module): + """Construct the embeddings from word, 
position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class DummyBertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = 
nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
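+        # Shapes: query_layer is (batch, num_heads, q_len, head_size) and key_layer.transpose(-1, -2) is
+        # (batch, num_heads, head_size, k_len), so attention_scores comes out as (batch, num_heads, q_len, k_len).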
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in DummyBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class DummyBertSdpaSelfAttention(DummyBertSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from DummyBertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. 
`model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "DummyBertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
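+        # Put differently: `is_causal=True` is only safe for decoder self-attention with no explicit mask and more
+        # than one query token; cross-attention and single-token decoding keep `is_causal=False`.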
+ is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class DummyBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +DUMMY_BERT_SELF_ATTENTION_CLASSES = { + "eager": DummyBertSelfAttention, + "sdpa": DummyBertSdpaSelfAttention, +} + + +class DummyBertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = DUMMY_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = DummyBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class DummyBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + 
hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class DummyBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DummyBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = DummyBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = DummyBertAttention(config, position_embedding_type="absolute") + self.intermediate = DummyBertIntermediate(config) + self.output = DummyBertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, 
self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class DummyBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([DummyBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class DummyBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class DummyBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DummyBertConfig + load_tf_weights = load_tf_weights_in_dummy_bert + base_model_prefix = "dummy_bert" + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +DUMMY_BERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DummyBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DUMMY_BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DummyBert Model transformer outputting raw hidden-states without any specific head on top.", + DUMMY_BERT_START_DOCSTRING, +) +class DummyBertModel(DummyBertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _no_split_modules = ["DummyBertEmbeddings", "DummyBertLayer"] + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = DummyBertEmbeddings(config) + self.encoder = DummyBertEncoder(config) + + self.pooler = DummyBertPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DUMMY_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
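This branch is where the SDPA path diverges from the eager path: with `attn_implementation="sdpa"`, absolute position embeddings, no head mask and no requested attention weights, the 2D padding mask is expanded by the SDPA-specific helpers instead of `get_extended_attention_mask`. A minimal, self-contained sketch of what that expansion produces is below (plain PyTorch with illustrative shapes; `expand_padding_mask` is a made-up name for illustration, not the library helper):

import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    """[bsz, src_len] with 1 = real token, 0 = padding  ->  [bsz, 1, tgt_len, src_len] additive mask."""
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
    # kept positions contribute 0.0 to the attention scores, padded ones the most negative representable value
    return torch.zeros(bsz, 1, tgt_len, src_len, dtype=dtype).masked_fill(expanded == 0, torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0]])   # one sequence of length 4 whose last token is padding
print(expand_padding_mask(mask, torch.float32, tgt_len=4)[0, 0, 0])
# tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38])

On the decoder side, the SDPA helper may also drop the mask entirely (returning `None`) when a purely causal pattern suffices, letting SDPA apply causality internally via `is_causal`.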
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py new file mode 100644 index 000000000000..fea7994a53ee --- /dev/null +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -0,0 +1,1059 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. 
One of our CI enforces this
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+)
+from .configuration_my_new_model2 import MyNewModel2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    min_dtype: float,
+    cache_position: torch.Tensor,
+    batch_size: int,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+        dtype (`torch.dtype`):
+            The dtype to use for the 4D attention mask.
+        device (`torch.device`):
+            The device to place the 4D attention mask on.
+        min_dtype (`float`):
+            The minimum value representable with the dtype `dtype`.
+        cache_position (`torch.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+            Batch size.
+    """
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
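+        # ("inverted" form here means an additive mask: 0.0 where attention is allowed, `min_dtype` where it is masked)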
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class MyNewModel2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst MyNewModel2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class MyNewModel2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MyNewModel2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + logger.warning_once( + "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" + "MyNewModel2's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" + "`config.hidden_activation` if you want to override this behaviour.\n" + "See https://github.com/huggingface/transformers/pull/29402 for more details." 
+ ) + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MyNewModel2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MyNewModel2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = MyNewModel2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 
2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MyNewModel2FlashAttention2(MyNewModel2Attention): + """ + MyNewModel2 flash attention module. This module inherits from `MyNewModel2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
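+        # the eager/SDPA paths keep tensors as (batch, num_heads, seq_len, head_dim); the flash kernels expect (batch, seq_len, num_heads, head_dim)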
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MyNewModel2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MyNewModel2SdpaAttention(MyNewModel2Attention): + """ + MyNewModel2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MyNewModel2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MyNewModel2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MyNewModel2Model is using MyNewModel2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
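To make that dispatch concrete, here is a minimal standalone sketch (illustrative shapes only, not the model code): when no additive mask is passed and the query covers more than one position, `is_causal=True` lets `scaled_dot_product_attention` construct the causal mask internally and pick its fused kernels; otherwise the pre-built 4D mask is passed via `attn_mask`.

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
q, k, v = (torch.randn(bsz, num_heads, q_len, head_dim) for _ in range(3))

causal_mask = None                               # e.g. no padding and no pre-built 4D mask
is_causal = causal_mask is None and q_len > 1    # same predicate as the assignment that follows in the model code

out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
)
print(out.shape)  # torch.Size([2, 4, 8, 16])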
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MY_NEW_MODEL2_ATTENTION_CLASSES = { + "eager": MyNewModel2Attention, + "flash_attention_2": MyNewModel2FlashAttention2, + "sdpa": MyNewModel2SdpaAttention, +} + + +class MyNewModel2DecoderLayer(nn.Module): + def __init__(self, config: MyNewModel2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MY_NEW_MODEL2_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = MyNewModel2MLP(config) + self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MY_NEW_MODEL2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MyNewModel2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", + MY_NEW_MODEL2_START_DOCSTRING, +) +class MyNewModel2PreTrainedModel(PreTrainedModel): + config_class = MyNewModel2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MyNewModel2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", + MY_NEW_MODEL2_START_DOCSTRING, +) +class MyNewModel2Model(MyNewModel2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MyNewModel2DecoderLayer`] + + Args: + config: MyNewModel2Config + """ + + def __init__(self, config: MyNewModel2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # MyNewModel2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is 
not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +@add_start_docstrings( + """ + The MyNewModel2 Model transformer with a sequence classification head on top (linear layer). + + [`MyNewModel2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + MY_NEW_MODEL2_START_DOCSTRING, +) +class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MyNewModel2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, 
self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py new file mode 100644 index 000000000000..d91bdb1820c2 --- /dev/null +++ b/examples/modular-transformers/modeling_super.py @@ -0,0 +1,953 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from .configuration_super import SuperConfig + + +logger = logging.get_logger(__name__) + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
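+        # "Inverted form" here means masked positions already hold the large negative fill value and visible positions hold 0, so the mask can be used as-is.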
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class SuperRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + SuperRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class SuperRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[SuperConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class SuperMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class SuperAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = SuperRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class SuperFlashAttention2(SuperAttention):
+    """
+    Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
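+        # In other words, the flag below is only True for flash_attn<2.1; newer versions already produce the bottom-right aligned mask we need.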
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(SuperRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class SuperSdpaAttention(SuperAttention): + """ + Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from SuperAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
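+        # The q_len > 1 check matters for decoding: with a single query token there is nothing in the future to mask, so is_causal can stay False.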
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +SUPER_ATTENTION_CLASSES = { + "eager": SuperAttention, + "flash_attention_2": SuperFlashAttention2, + "sdpa": SuperSdpaAttention, +} + + +class SuperDecoderLayer(nn.Module): + def __init__(self, config: SuperConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = SuperMLP(config) + self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +SUPER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SuperConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Super Model outputting raw hidden-states without any specific head on top.", + SUPER_START_DOCSTRING, +) +class SuperPreTrainedModel(PreTrainedModel): + config_class = SuperConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SuperDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +SUPER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare Super Model outputting raw hidden-states without any specific head on top.", + SUPER_START_DOCSTRING, +) +class SuperModel(SuperPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuperDecoderLayer`] + + Args: + config: SuperConfig + """ + + def __init__(self, config: SuperConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [SuperDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = SuperRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + out = super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + ) + out.logits *= 2**4 + return out + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask diff --git a/examples/diff-conversion/diff_dummy.py b/examples/modular-transformers/modular_dummy.py similarity index 97% rename from examples/diff-conversion/diff_dummy.py rename to examples/modular-transformers/modular_dummy.py index c5fd57f9f66e..33dc38d0b447 100644 --- a/examples/diff-conversion/diff_dummy.py +++ b/examples/modular-transformers/modular_dummy.py @@ -3,10 +3,11 @@ import torch -from transformers import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaModel +from ...cache_utils import Cache + def _pre_process_input(input_ids): print(log(input_ids)) diff --git a/examples/modular-transformers/modular_dummy_bert.py b/examples/modular-transformers/modular_dummy_bert.py new file mode 100644 index 000000000000..7a83a2e0ed2f --- /dev/null +++ b/examples/modular-transformers/modular_dummy_bert.py @@ -0,0 +1,27 @@ +from typing import List, Optional, Tuple, Union + +import torch + +from transformers.models.bert.modeling_bert import BertModel + +from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions + + +class DummyBertModel(BertModel): + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + return super().forward(input_ids) diff --git a/examples/diff-conversion/diff_my_new_model.py b/examples/modular-transformers/modular_my_new_model.py similarity index 84% rename from examples/diff-conversion/diff_my_new_model.py rename to examples/modular-transformers/modular_my_new_model.py index dddcc1d61c11..c1ea8b0a7249 100644 --- a/examples/diff-conversion/diff_my_new_model.py +++ b/examples/modular-transformers/modular_my_new_model.py @@ -5,10 +5,11 @@ # here there is no `ARG` so we are gonna take parent doc class MyNewModelConfig(LlamaConfig): r""" - mlp_bias (`bool`, *optional*, defaults to `False`) + new_param (`int`, *optional*, defaults to `False`): + A fun new parameter """ def __init__(self, mlp_bias=True, new_param=0, **super_kwargs): + super().__init__(self, **super_kwargs) self.mlp_bias = mlp_bias self.new_param = new_param - super().__init__(self, **super_kwargs) diff --git a/examples/diff-conversion/diff_my_new_model2.py b/examples/modular-transformers/modular_my_new_model2.py similarity index 100% rename from examples/diff-conversion/diff_my_new_model2.py rename to examples/modular-transformers/modular_my_new_model2.py diff --git a/examples/diff-conversion/diff_new_model.py b/examples/modular-transformers/modular_new_model.py similarity index 85% rename from examples/diff-conversion/diff_new_model.py rename to examples/modular-transformers/modular_new_model.py index 1486d40c6cdb..166c7955c1b5 100644 --- a/examples/diff-conversion/diff_new_model.py +++ b/examples/modular-transformers/modular_new_model.py @@ -26,5 +26,10 @@ def __init__( rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, + **kwargs, ): - super().__init__(self) + super().__init__(self, **kwargs) + + @property + def num_heads(self): + return self.num_attention_heads diff --git a/examples/modular-transformers/modular_roberta.py b/examples/modular-transformers/modular_roberta.py new file mode 100644 index 000000000000..a3e0218f9320 --- /dev/null +++ b/examples/modular-transformers/modular_roberta.py @@ -0,0 +1,20 @@ +import torch.nn as nn + +from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel + + +class RobertaEmbeddings(BertEmbeddings): + def __init__(self, config): + super().__init__(config) + self.pad_token_id = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, config.pad_token_id + ) + + +class RobertaModel(BertModel): + def __init__(self, config): + super().__init__(self, config) + # Error out here. Why? Because `RobertaEmbeddings` is defined but not used. + # no, because it's defined, and RobertaModel should use RobertaEmbedding + # here if initialized that way it won't use the new embedding. 
diff --git a/examples/diff-conversion/diff_super.py b/examples/modular-transformers/modular_super.py similarity index 97% rename from examples/diff-conversion/diff_super.py rename to examples/modular-transformers/modular_super.py index 160f067ee01b..59909a41e4dc 100644 --- a/examples/diff-conversion/diff_super.py +++ b/examples/modular-transformers/modular_super.py @@ -2,10 +2,11 @@ import torch -from transformers import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaModel +from ...cache_utils import Cache + # example where we need some deps and some functions class SuperModel(LlamaModel): diff --git a/setup.py b/setup.py index 14a80d3321be..6ea9b192618e 100644 --- a/setup.py +++ b/setup.py @@ -192,6 +192,8 @@ "urllib3<2.0.0", "uvicorn", "pytest-rich", + "libcst", + "rich", ] @@ -345,7 +347,7 @@ def run(self): extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"] extras["ruff"] = deps_list("ruff") -extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3") +extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "urllib3", "libcst", "rich") extras["all"] = ( extras["tf"] diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index c199884a1960..2634a7b6b3f2 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -97,4 +97,6 @@ "urllib3": "urllib3<2.0.0", "uvicorn": "uvicorn", "pytest-rich": "pytest-rich", + "libcst": "libcst", + "rich": "rich", } diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index e8de9ddcee2e..3ab61c522eff 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. @@ -21,7 +21,7 @@ # limitations under the License. -from transformers import PretrainedConfig +from ...configuration_utils import PretrainedConfig class GemmaConfig(PretrainedConfig): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8d9bb88686de..948dd8287b61 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. 
One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. @@ -39,7 +39,6 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -51,63 +50,6 @@ from .configuration_gemma import GemmaConfig -logger = logging.get_logger(__name__) - - -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -128,7 +70,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) +logger = logging.get_logger(__name__) class GemmaRotaryEmbedding(nn.Module): @@ -159,30 +101,6 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" @@ -212,6 +130,30 @@ def forward(self, x, position_ids): return cos, sin +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + logger.warning_once( + "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" + "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" + "`config.hidden_activation` if you want to override this behaviour.\n" + "See https://github.com/huggingface/transformers/pull/29402 for more details." 
+ ) + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -358,6 +300,94 @@ def forward( return attn_output, attn_weights, past_key_value +class GemmaSdpaAttention(GemmaAttention): + """ + Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from GemmaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
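As context for the `rotate_half` helper carried over above: `apply_rotary_pos_emb`, which the new SDPA forward calls, applies the rotation in the usual Llama-style way. The sketch below is a self-contained illustration with toy shapes; the `unsqueeze_dim=1` broadcasting convention is an assumption based on that convention, not something introduced by this patch.

```python
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input (same as the helper above)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: (batch, seq_len, head_dim); unsqueeze so they broadcast over the head dim
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# toy shapes: batch=1, heads=2, seq_len=4, head_dim=8
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
cos = torch.randn(1, 4, 8)
sin = torch.randn(1, 4, 8)
q_rot, k_rot = apply_rotary_pos_emb_sketch(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 8])
```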
+ if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + class GemmaFlashAttention2(GemmaAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays @@ -458,7 +488,6 @@ def forward( is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -468,92 +497,57 @@ def forward( return attn_output, attn_weights, past_key_value -class GemmaSdpaAttention(GemmaAttention): - """ - Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - # Adapted from GemmaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
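The `GemmaSdpaAttention.forward` added above boils down to one `torch.nn.functional.scaled_dot_product_attention` call. The following sketch isolates just the dispatch logic (explicit `attn_mask` versus `is_causal=True`, plus the contiguity workaround for the CUDA memory-efficient backend); shapes and tensors are toy values, not code from the patch.

```python
import torch
import torch.nn.functional as F

def sdpa_sketch(query, key, value, causal_mask=None, dropout_p=0.0, training=False):
    # Workaround: with a custom attn_mask on CUDA, the memory-efficient backend
    # (torch==2.1.2) needs contiguous q/k/v (pytorch issue 112577).
    if query.device.type == "cuda" and causal_mask is not None:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

    # is_causal=True only when no explicit mask is given and we are not decoding a single token;
    # the if/else (rather than an inline conditional) keeps torch.compile's dynamic shapes happy.
    is_causal = True if causal_mask is None and query.shape[-2] > 1 else False

    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=causal_mask,
        dropout_p=dropout_p if training else 0.0,
        is_causal=is_causal,
    )

q = torch.randn(1, 2, 5, 8)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)
out = sdpa_sketch(q, k, v)    # causal prefill path (is_causal=True)
print(out.shape)              # torch.Size([1, 2, 5, 8])
```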
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - return attn_output, None, past_key_value + return causal_mask GEMMA_ATTENTION_CLASSES = { @@ -567,9 +561,7 @@ class GemmaDecoderLayer(nn.Module): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = GemmaMLP(config) self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -830,9 +822,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False + return_legacy_cache = False # noqa: F841 if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + return_legacy_cache = True # noqa: F841 if past_key_values is None: past_key_values = DynamicCache() else: @@ -975,6 +967,7 @@ def _update_causal_mask( cache_position=cache_position, batch_size=input_tensor.shape[0], ) + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -1149,6 +1142,7 @@ def prepare_inputs_for_generation( position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
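To make the relocated `_prepare_4d_causal_attention_mask_with_cache_position` easier to follow, here is a tiny standalone reproduction of its mask-building branch against a static-cache-style target length. The helper name and the toy sizes are ours; the logic mirrors the function above.

```python
import torch

def build_causal_mask_sketch(sequence_length, target_length, cache_position, dtype=torch.float32):
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)            # block attention to future positions
    # positions beyond the last cached token are masked out as well
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    return mask[None, None, :, :]                      # (1, 1, q_len, kv_len)

# prefill of 3 tokens against a static cache of length 5
mask = build_causal_mask_sketch(3, 5, cache_position=torch.arange(3))
print((mask == 0).long())   # 1 = attendable, 0 = masked with min_dtype
# tensor([[[[1, 0, 0, 0, 0],
#           [1, 1, 0, 0, 0],
#           [1, 1, 1, 0, 0]]]])
```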
position_ids = position_ids.clone(memory_format=torch.contiguous_format) @@ -1230,7 +1224,7 @@ def set_input_embeddings(self, value): @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/modular_gemma.py similarity index 67% rename from src/transformers/models/gemma/diff_gemma.py rename to src/transformers/models/gemma/modular_gemma.py index dcc43bc74aec..ca89b6cf2a6d 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -21,8 +21,15 @@ from torch import nn from torch.nn import CrossEntropyLoss -from transformers import PretrainedConfig -from transformers.models.llama.modeling_llama import ( +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import is_torchdynamo_compiling, logging +from ..llama.modeling_llama import ( + LlamaDecoderLayer, LlamaFlashAttention2, LlamaForCausalLM, LlamaForSequenceClassification, @@ -32,14 +39,6 @@ repeat_kv, ) -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import CausalLMOutputWithPast -from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import logging - logger = logging.get_logger(__name__) @@ -216,6 +215,35 @@ def forward(self, x, position_ids, seq_len=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): + """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): + """GemmaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + class GemmaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -340,8 +368,95 @@ def forward( return attn_output, attn_weights, past_key_value -# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention? -class GemmaFlashAttention2(LlamaFlashAttention2): +class GemmaSdpaAttention(GemmaAttention): + """ + Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from GemmaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
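The two rotary-scaling variants carried into `modular_gemma.py` differ only in how they stretch the context: linear scaling divides the position ids by `scaling_factor`, while dynamic NTK re-derives the frequency base once the sequence exceeds the original window. A rough numeric sketch of both rules (attribute names mirror the classes above; the concrete values are made up):

```python
import torch

dim, base, max_position_embeddings, scaling_factor = 8, 10000.0, 16, 2.0

# Linear scaling: the model sees "compressed" positions, e.g. position 19 -> 9.5
position_ids = torch.arange(20, dtype=torch.float32)
scaled_positions = position_ids / scaling_factor

# Dynamic NTK: once seq_len exceeds the original window, grow the base so the
# lowest rotary frequencies stretch to cover the longer context
seq_len = position_ids.shape[0]
if seq_len > max_position_embeddings:
    ntk_base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
else:
    ntk_base = base
inv_freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(scaled_positions[-1].item())  # 9.5
print(f"{ntk_base:.1f}")            # ≈ 17170.7, up from 10000.0
print(inv_freq.shape)               # torch.Size([4])
```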
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays untouched. 
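Both SDPA forwards above expand the key/value heads with `repeat_kv` before the attention kernel. `repeat_kv` itself is untouched by this patch; the sketch below only illustrates the expansion it performs, using the standard expand/reshape formulation.

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# 2 KV heads shared by 8 query heads: each KV head is repeated 4 times
kv = torch.randn(1, 2, 5, 16)
expanded = repeat_kv_sketch(kv, n_rep=4)
print(expanded.shape)                               # torch.Size([1, 8, 5, 16])
print(torch.equal(expanded[0, 0], expanded[0, 3]))  # True: copies of the same KV head
```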
The only required change would be on the forward pass where it needs to correctly call the public API of @@ -427,12 +542,12 @@ def forward( value_states, attention_mask, q_len, + position_ids=position_ids, dropout=dropout_rate, sliding_window=getattr(self, "sliding_window", None), is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -442,7 +557,95 @@ def forward( return attn_output, attn_weights, past_key_value +GEMMA_ATTENTION_CLASSES = { + "eager": GemmaAttention, + "flash_attention_2": GemmaFlashAttention2, + "sdpa": GemmaSdpaAttention, +} + + +class GemmaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__(config) + self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.mlp = GemmaMLP(config) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + class GemmaModel(LlamaModel): + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet! + self.post_init() + def forward( self, input_ids: torch.LongTensor = None, @@ -455,7 +658,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -513,22 +716,72 @@ def forward( normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer - return super().forward( - causal_mask, - position_ids, - past_key_values, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - cache_position, - input_ids=None, - inputs_embeds=hidden_states, + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add 
hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) # Example where we ony modify the docstring and call super -class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin): +class GemmaForCausalLM(LlamaForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.post_init() + def forward( self, input_ids: torch.LongTensor = None, @@ -542,18 +795,9 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM @@ -589,10 +833,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -618,8 +870,14 @@ def forward( class GemmaForSequenceClassification(LlamaForSequenceClassification): - pass + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.post_init() class GemmaForTokenClassification(LlamaForTokenClassification): - pass + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.post_init() diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 7da541207bfe..6f4b2eaf2a45 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. 
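The `num_logits_to_keep` change to `GemmaForCausalLM.forward` above avoids materializing the full `(batch, seq_len, vocab)` logits during decoding, when only the last position matters. A toy illustration of the slicing (sizes are made up; `num_logits_to_keep=0`, the default above, keeps every position):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 1000, 64
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 128, hidden_size)   # (batch, seq_len, hidden)

num_logits_to_keep = 1  # while decoding, only the last token's logits are needed
logits_last = lm_head(hidden_states[:, -num_logits_to_keep:, :])
logits_full = lm_head(hidden_states[:, -0:, :])     # -0: slices the whole sequence, i.e. keep everything

print(logits_last.shape)  # torch.Size([2, 1, 1000])
print(logits_full.shape)  # torch.Size([2, 128, 1000])
```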
If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. @@ -19,7 +19,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from transformers import PretrainedConfig + + +from ...configuration_utils import PretrainedConfig class Gemma2Config(PretrainedConfig): @@ -53,7 +55,8 @@ class Gemma2Config(PretrainedConfig): head_dim (`int`, *optional*, defaults to 256): The attention head dimension. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the decoder. + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): @@ -77,16 +80,17 @@ class Gemma2Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. 
+ ```python >>> from transformers import Gemma2Model, Gemma2Config - >>> # Initializing a Gemma2 gemma2-9b style configuration + >>> # Initializing a Gemma2 gemma2-7b style configuration >>> configuration = Gemma2Config() - >>> # Initializing a model from the gemma2-9b style configuration + >>> # Initializing a model from the gemma2-7b style configuration >>> model = Gemma2Model(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -94,6 +98,7 @@ class Gemma2Config(PretrainedConfig): model_type = "gemma2" keys_to_ignore_at_inference = ["past_key_values"] + cache_implementation = "hybrid" def __init__( self, @@ -116,12 +121,19 @@ def __init__( rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, - final_logit_softcapping=30.0, - attn_logit_softcapping=50.0, query_pre_attn_scalar=224, sliding_window=4096, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -130,23 +142,14 @@ def __init__( self.num_attention_heads = num_attention_heads self.head_dim = head_dim self.num_key_value_heads = num_key_value_heads - self.hidden_activation = hidden_activation self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.attn_logit_softcapping = attn_logit_softcapping - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.final_logit_softcapping = final_logit_softcapping + self.hidden_activation = hidden_activation self.query_pre_attn_scalar = query_pre_attn_scalar self.sliding_window = sliding_window - self.cache_implementation = "hybrid" + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6b55500739b4..22438ccc80a6 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
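With `cache_implementation = "hybrid"` now a class attribute and `super().__init__` moved to the top of `Gemma2Config.__init__`, the value is visible without building a full config. A quick sanity-check sketch with a deliberately tiny configuration (the small sizes are arbitrary test values, not Gemma 2 defaults, and assume a transformers version that already ships Gemma 2):

```python
from transformers import Gemma2Config

# class attribute: visible without instantiating anything
print(Gemma2Config.cache_implementation)  # hybrid

tiny = Gemma2Config(
    vocab_size=512,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    sliding_window=8,
)
print(tiny.cache_implementation, tiny.final_logit_softcapping, tiny.attn_logit_softcapping)
# hybrid 30.0 50.0
```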
@@ -22,13 +22,14 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn as nn import torch.utils.checkpoint -from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -39,7 +40,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, @@ -49,66 +49,6 @@ from .configuration_gemma2 import Gemma2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) - - -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - - class Gemma2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -129,6 +69,24 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class Gemma2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +logger = logging.get_logger(__name__) + + class Gemma2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -191,21 +149,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class Gemma2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
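`Gemma2MLP` above is, like `GemmaMLP`, a gated feed-forward block: the up projection is gated element-wise by the activated gate projection before being projected back down. A minimal standalone sketch with toy sizes (`gelu_pytorch_tanh` is the activation the config defaults to):

```python
import torch
import torch.nn as nn
from transformers.activations import ACT2FN

class GatedMLPSketch(nn.Module):
    def __init__(self, hidden_size=64, intermediate_size=128, hidden_activation="gelu_pytorch_tanh"):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
        # act(gate(x)) gates up(x) element-wise, then project back to hidden_size
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = GatedMLPSketch()
out = mlp(torch.randn(2, 5, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```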
The hidden states go from (batch, @@ -253,12 +196,12 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.rotary_emb = Gemma2RotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -502,9 +445,11 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: @@ -515,6 +460,7 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -533,6 +479,59 @@ def forward( return attn_output, None, past_key_value +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + GEMMA2_ATTENTION_CLASSES = { "eager": Gemma2Attention, "flash_attention_2": Gemma2FlashAttention2, @@ -543,19 +542,16 @@ def forward( class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() - self.config = config self.hidden_size = config.hidden_size - self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.config = config self.is_sliding = not bool(layer_idx % 2) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -567,6 +563,25 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
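`GEMMA2_ATTENTION_CLASSES` (and its Gemma counterpart earlier in the patch) is what makes the new SDPA module selectable through the usual `attn_implementation` switch at load time. A typical way to exercise it, with an illustrative checkpoint name:

```python
import torch
from transformers import AutoModelForCausalLM

# "sdpa" routes every decoder layer's self_attn to Gemma2SdpaAttention;
# "eager" and "flash_attention_2" select the other entries of GEMMA2_ATTENTION_CLASSES.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",            # illustrative checkpoint name
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
print(type(model.model.layers[0].self_attn).__name__)  # Gemma2SdpaAttention
```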
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor if self.config._attn_implementation == "flash_attention_2": @@ -580,6 +595,7 @@ def forward( attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) if attention_mask.shape[-1] <= 1: # when decoding attention_mask = attention_mask[:, :, :, -self.sliding_window :] + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -711,13 +727,20 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`HybridCache`, *optional*): + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other - cache classes. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` @@ -812,8 +835,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # Instantiate an empty cache if needed. 
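Every other Gemma 2 layer (`is_sliding`) restricts attention to the most recent `sliding_window` key positions, and the hunk above applies that restriction by masking out anything older. The `sliding_window_mask` construction itself is elided from the hunk, so the sketch below builds an equivalent banded mask from a simple position-difference rule, purely as an illustration:

```python
import torch

def sliding_window_allowed(q_len: int, kv_len: int, sliding_window: int) -> torch.Tensor:
    """Boolean (q_len, kv_len) mask: True where a query may attend to a key."""
    q_pos = torch.arange(kv_len - q_len, kv_len).unsqueeze(-1)  # absolute query positions
    k_pos = torch.arange(kv_len).unsqueeze(0)
    causal = k_pos <= q_pos                            # no peeking at the future
    within_window = (q_pos - k_pos) < sliding_window   # drop keys older than the window
    return causal & within_window

print(sliding_window_allowed(q_len=6, kv_len=6, sliding_window=3).long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
```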
- if use_cache and past_key_values is None: + if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape past_key_values = HybridCache( self.config, @@ -828,6 +850,7 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -844,6 +867,7 @@ def forward( normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -880,7 +904,6 @@ def forward( hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1009,6 +1032,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" + if self.training and self.config._attn_implementation != "eager": logger.warning_once( "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " @@ -1187,10 +1211,10 @@ def set_input_embeddings(self, value): @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py similarity index 50% rename from src/transformers/models/gemma2/diff_gemma2.py rename to src/transformers/models/gemma2/modular_gemma2.py index a66ce3160b5f..7aca6650961e 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,30 +13,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch +import torch.nn as nn import torch.utils.checkpoint from torch.nn import CrossEntropyLoss -from transformers.models.gemma.configuration_gemma import GemmaConfig -from transformers.models.gemma.modeling_gemma import ( +from ...activations import ACT2FN +from ...cache_utils import Cache, HybridCache +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, +) +from ..gemma.modeling_gemma import ( GemmaAttention, GemmaDecoderLayer, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, + GemmaPreTrainedModel, GemmaRMSNorm, + _prepare_4d_causal_attention_mask_with_cache_position, apply_rotary_pos_emb, repeat_kv, ) -from ...cache_utils import Cache -from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging - if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -45,33 +56,230 @@ logger = logging.get_logger(__name__) -class Gemma2Config(GemmaConfig): - cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better +class Gemma2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma2-7B. + e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma2Model`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. 
+ head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. 
+ + ```python + >>> from transformers import Gemma2Model, Gemma2Config + >>> # Initializing a Gemma2 gemma2-7b style configuration + >>> configuration = Gemma2Config() + >>> # Initializing a model from the gemma2-7b style configuration + >>> model = Gemma2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma2" + keys_to_ignore_at_inference = ["past_key_values"] + cache_implementation = "hybrid" def __init__( self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, query_pre_attn_scalar=224, sliding_window=4096, final_logit_softcapping=30.0, - **super_kwargs, + attn_logit_softcapping=50.0, + **kwargs, ): - super().__init__(self, **super_kwargs) + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation self.query_pre_attn_scalar = query_pre_attn_scalar self.sliding_window = sliding_window - self.cache_implementation = "hybrid" self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping class Gemma2RMSNorm(GemmaRMSNorm): pass +class Gemma2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + class Gemma2Attention(GemmaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = 
hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if self.config.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.config.attn_logit_softcapping + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class Gemma2FlashAttention2(Gemma2Attention): @@ -119,9 +327,19 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if attention_mask is not None: + seq_len = attention_mask.shape[1] + key_states = key_states[:, :, :seq_len] + value_states = value_states[:, :, :seq_len] + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. 
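For reference, the `attn_logit_softcapping` block in the eager `Gemma2Attention.forward` above is equivalent to the small standalone helper below; the helper name and the sample values are illustrative only and not part of the patch:

```python
import torch

def tanh_softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash raw attention logits into the open interval (-cap, cap),
    # mirroring the pattern used above: scores = tanh(scores / cap) * cap
    return torch.tanh(scores / cap) * cap

# With Gemma2's default attn_logit_softcapping=50.0, an extreme logit of 200.0
# becomes roughly 50 * tanh(4.0) ≈ 49.97, keeping the softmax numerically tame.
print(tanh_softcap(torch.tensor([200.0]), 50.0))
```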
query_states = query_states.transpose(1, 2) @@ -156,7 +374,6 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING attn_output = _flash_attention_forward( query_states, key_states, @@ -166,7 +383,9 @@ def forward( dropout=dropout_rate, softmax_scale=self.scaling, is_causal=self.is_causal, + sliding_window=self.sliding_window, use_top_left_mask=self._flash_attn_uses_top_left_mask, + softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -227,7 +446,12 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -269,8 +493,9 @@ def forward( class Gemma2DecoderLayer(GemmaDecoderLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__(config, layer_idx) - - self.is_sliding = bool(layer_idx % 2) + self.config = config + self.is_sliding = not bool(layer_idx % 2) + self.mlp = Gemma2MLP(config) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -286,11 +511,18 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - attention_mask = attention_mask * torch.tril( - torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1]) - ) - if cache_position[0] > 0: - attention_mask = attention_mask[:, -self.sliding_window :] + # Flash-attn is a 2D tensor + if self.config._attn_implementation == "flash_attention_2": + if past_key_value is not None: # when decoding + attention_mask = attention_mask[:, -self.sliding_window :] + else: + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] residual = hidden_states @@ -326,13 +558,38 @@ def forward( return outputs -class Gemma2Model(GemmaModel): +class Gemma2PreTrainedModel(GemmaPreTrainedModel): + _supports_quantized_cache = False + + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. 
+        """
+        config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
+
+        # if using the default path -> swap sdpa by eager
+        if not hard_check_only and config._attn_implementation == "sdpa":
+            config._attn_implementation = "eager"
+
+        return config
+
+
+class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
+    def __init__(self, config: Gemma2Config):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.post_init()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -361,8 +618,21 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        if use_cache and past_key_values is None and not self.training:
+            batch_size, seq_len, _ = inputs_embeds.shape
+            past_key_values = HybridCache(
+                self.config,
+                batch_size=batch_size,
+                max_cache_len=seq_len,
+                device=self.device,
+                dtype=inputs_embeds.dtype,
+            )
+
         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -437,50 +707,50 @@ def _update_causal_mask(
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
-        past_key_values: Cache,
+        past_key_values: HybridCache,
         output_attentions: bool,
     ):
+        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
+        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
+        # to cut out the trailing zeros of keys/values used in the static cache. This workaround should be compile compatible
+        # as it doesn't cause dynamic control issues.
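As a minimal sketch of the key/value slicing this comment describes (the same pattern as the `key_states[:, :, :seq_len]` slice added to `Gemma2FlashAttention2.forward` above); all tensor sizes below are toy values, not from the patch:

```python
import torch

# A statically allocated cache is padded up to max_cache_len, so positions that
# have not been written yet are zeros. The 2D attention mask carries the true
# sequence length, which is used to drop those trailing entries before flash-attn.
batch, num_heads, max_cache_len, head_dim = 1, 2, 8, 4
key_states = torch.zeros(batch, num_heads, max_cache_len, head_dim)
key_states[:, :, :5] = torch.randn(batch, num_heads, 5, head_dim)  # 5 real positions

attention_mask = torch.ones(batch, 5)      # shape (batch, seq_len)
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]    # (1, 2, 5, 4): trailing zero rows removed
```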
if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None + return attention_mask dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if past_key_values is not None: + if isinstance(past_key_values, HybridCache): target_length = past_key_values.get_max_length() else: - target_length = attention_mask.shape[-1] + target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - if attention_mask is not None and attention_mask.dim() == 4: - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) return causal_mask -class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin): +class Gemma2ForCausalLM(GemmaForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = Gemma2Model(config) + self.post_init() + def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[HybridCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -488,18 +758,9 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM @@ -514,12 +775,17 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" 
```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -535,15 +801,23 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping + # TODO: remove the float() operation in v4.46 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -567,10 +841,94 @@ def forward( attentions=outputs.attentions, ) + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s + # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride + # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the + # batch size = 1 case, `position_ids` is already contiguous but with varying stride + # which retriggers a capture. 
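To make the on-the-fly `position_ids` construction in `prepare_inputs_for_generation` above concrete, here is a small worked example with a toy, left-padded attention mask (values illustrative only):

```python
import torch

# One batch entry: two padding tokens followed by three real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)  # tensor([[1, 1, 0, 1, 2]]) -> real tokens get positions 0, 1, 2
```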
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if ( + isinstance(past_key_values, HybridCache) + and attention_mask.ndim == 2 + and not self.config._attn_implementation == "flash_attention_2" + ): + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + class Gemma2ForSequenceClassification(GemmaForSequenceClassification): - pass + def __init__(self, config): + super().__init__(config) + self.model = Gemma2Model(config) + self.post_init() class Gemma2ForTokenClassification(GemmaForTokenClassification): - pass + def __init__(self, config): + super().__init__(config) + self.model = Gemma2Model(config) + self.post_init() diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index 051e8e218071..02672bdce83a 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. @@ -24,9 +24,7 @@ from ...configuration_utils import PretrainedConfig from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from ...utils import ( - logging, -) +from ...utils import logging from ..auto import CONFIG_MAPPING @@ -36,8 +34,8 @@ class InstructBlipVideoVisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to - instantiate a Instructblipvideo vision encoder according to the specified arguments, defining the model architecture. 
- Instantiating a configuration defaults will yield a similar configuration to that of the Instructblipvideo + instantiate a InstructBlipVideo vision encoder according to the specified arguments, defining the model architecture. + Instantiating a configuration defaults will yield a similar configuration to that of the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the @@ -58,7 +56,7 @@ class InstructBlipVideoVisionConfig(PretrainedConfig): The size (resolution) of each patch. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer + `"relu"`, `"selu"` and `"gelu_new"` `"gelu"` are supported. to 1e-5): The epsilon used by the layer normalization layers. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. @@ -137,9 +135,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], class InstructBlipVideoQFormerConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to - instantiate a Instructblipvideo Querying Transformer (Q-Former) model according to the specified arguments, defining the + instantiate a InstructBlipVideo Querying Transformer (Q-Former) model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the Instructblipvideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) + the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -189,7 +187,7 @@ class InstructBlipVideoQFormerConfig(PretrainedConfig): ```python >>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel - >>> # Initializing a Instructblipvideo Salesforce/instruct-blip-flan-t5 style configuration + >>> # Initializing a InstructBlipVideo Salesforce/instruct-blip-flan-t5 style configuration >>> configuration = InstructBlipVideoQFormerConfig() >>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration @@ -360,7 +358,7 @@ def from_vision_qformer_text_configs( **kwargs, ): r""" - Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a Instructblipvideo vision model, Q-Former and + Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and language model configurations. 
Returns: diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index bcc299b1ba78..e165fb411af8 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -354,6 +353,21 @@ def _init_weights(self, module): module.bias.data.zero_() +INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -371,6 +385,71 @@ def _init_weights(self, module): Whether to interpolate the pre-trained position encodings. """ +INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an + encoder-decoder language model (like T5) is used. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) + + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. +""" + # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo class InstructBlipVideoEncoder(nn.Module): @@ -459,87 +538,6 @@ def forward( ) -INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. 
- - qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided - to serve as text prompt, which the Q-Former model will encode. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be - provided to serve as text prompt, which the language model can continue. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an - encoder-decoder language model (like T5) is used. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) - - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - Only relevant in case an encoder-decoder language model (like T5) is used. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. -""" - - # Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): main_input_name = "pixel_values" @@ -1089,7 +1087,7 @@ def forward( class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel): """ - Querying Transformer (Q-Former), used in Instructblipvideo. Slightly modified from BLIP-2 as it also takes the + Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the instruction as input. 
""" @@ -1285,7 +1283,7 @@ def forward( @add_start_docstrings( """ - Instructblipvideo Model for generating text given an image and an optional text prompt. The model consists of a vision + InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision encoder, Querying Transformer (Q-Former) and a language model. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue @@ -1358,7 +1356,7 @@ def _preprocess_accelerate(self): hf_device_map = self.hf_device_map if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: - # warn users about unexpected behavior when using multi-GPU + Instructblipvideo + `accelerate`. + # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`. logger.warning( "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." @@ -1505,7 +1503,6 @@ def forward( ) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -1584,11 +1581,11 @@ def generate( interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: - """ + r""" Overrides `generate` function to be able to use the model as a conditional generator. Args: - pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or + pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): The sequence used as a prompt to be fed to the Q-Former module. 
diff --git a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py similarity index 76% rename from src/transformers/models/instructblipvideo/diff_instructblipvideo.py rename to src/transformers/models/instructblipvideo/modular_instructblipvideo.py index be569abc9137..2128f25df662 100644 --- a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -21,32 +21,18 @@ from torch.nn import CrossEntropyLoss from transformers.models.instructblip.configuration_instructblip import ( - InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig, ) from transformers.models.instructblip.modeling_instructblip import ( - InstructBlipAttention, - InstructBlipEncoder, - InstructBlipEncoderLayer, InstructBlipForConditionalGeneration, InstructBlipForConditionalGenerationModelOutput, - InstructBlipMLP, - InstructBlipPreTrainedModel, - InstructBlipQFormerAttention, - InstructBlipQFormerEmbeddings, - InstructBlipQFormerEncoder, - InstructBlipQFormerIntermediate, - InstructBlipQFormerLayer, - InstructBlipQFormerModel, - InstructBlipQFormerOutput, - InstructBlipQFormerSelfOutput, - InstructBlipVisionEmbeddings, - InstructBlipVisionModel, ) -from ...generation import GenerationMixin +from ...configuration_utils import PretrainedConfig +from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from ...utils import logging +from ..auto import CONFIG_MAPPING logger = logging.get_logger(__name__) @@ -60,76 +46,132 @@ class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig): pass -class InstructBlipVideoConfig(InstructBlipConfig): - pass - - -@dataclass -class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput): - pass - - -class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings): - pass - - -class InstructBlipVideoAttention(InstructBlipAttention): - pass - - -class InstructBlipVideoMLP(InstructBlipMLP): - pass - - -class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer): - pass - - -class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel): - pass - - -class InstructBlipVideoEncoder(InstructBlipEncoder): - pass - +class InstructBlipVideoConfig(PretrainedConfig): + r""" + [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a + [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified + arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with + the defaults will yield a similar configuration to that of the Instructblipvideo + [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture. -class InstructBlipVideoVisionModel(InstructBlipVisionModel): - pass + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`]. + qformer_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`]. + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize any [`PretrainedConfig`]. 
+ num_query_tokens (`int`, *optional*, defaults to 32): + The number of query tokens passed through the Transformer. -class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput): - pass + video_token_index (`int`, *optional*): + Token index of special video token. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: -class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention): - pass + ```python + >>> from transformers import ( + ... InstructBlipVideoVisionConfig, + ... InstructBlipVideoQFormerConfig, + ... OPTConfig, + ... InstructBlipVideoConfig, + ... InstructBlipVideoForConditionalGeneration, + ... ) + >>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration + >>> configuration = InstructBlipVideoConfig() -class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate): - pass + >>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration + >>> model = InstructBlipVideoForConditionalGeneration(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config -class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput): - pass + >>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig + >>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations + >>> vision_config = InstructBlipVideoVisionConfig() + >>> qformer_config = InstructBlipVideoQFormerConfig() + >>> text_config = OPTConfig() -class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer): - pass + >>> config = InstructBlipVideoConfig.from_text_vision_configs(vision_config, qformer_config, text_config) + ```""" + model_type = "instructblipvideo" -class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder): - pass + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + video_token_index=None, + **kwargs, + ): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.") + + if qformer_config is None: + qformer_config = {} + logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.") + + if text_config is None: + text_config = {} + logger.info("text_config is None. 
Initializing the text config with default values (`OPTConfig`).") + + self.vision_config = InstructBlipVideoVisionConfig(**vision_config) + self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config) + text_model_type = text_config["model_type"] if "model_type" in text_config else "opt" + self.text_config = CONFIG_MAPPING[text_model_type](**text_config) + + self.tie_word_embeddings = self.text_config.tie_word_embeddings + self.is_encoder_decoder = self.text_config.is_encoder_decoder + + self.num_query_tokens = num_query_tokens + self.video_token_index = video_token_index + self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size + self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + + @classmethod + def from_vision_qformer_text_configs( + cls, + vision_config: InstructBlipVideoVisionConfig, + qformer_config: InstructBlipVideoQFormerConfig, + text_config: PretrainedConfig, + **kwargs, + ): + r""" + Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and + language model configurations. + Returns: + [`InstructBlipVideoConfig`]: An instance of a configuration object + """ -class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings): - pass + return cls( + vision_config=vision_config.to_dict(), + qformer_config=qformer_config.to_dict(), + text_config=text_config.to_dict(), + **kwargs, + ) -class InstructBlipVideoQFormerModel(InstructBlipQFormerModel): +@dataclass +class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput): pass -class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin): +class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): def forward( self, pixel_values: torch.FloatTensor, @@ -146,15 +188,6 @@ def forward( interpolate_pos_encoding: bool = False, ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size - - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., - config.vocab_size]` - - Returns: - - Examples: - ```python >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration >>> import torch @@ -339,11 +372,11 @@ def generate( interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: - """ + r""" Overrides `generate` function to be able to use the model as a conditional generator. Args: - pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or + pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): The sequence used as a prompt to be fed to the Q-Former module. 
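As a usage sketch of the `from_vision_qformer_text_configs` classmethod defined above, a composed config can be built from default sub-configs; this is illustrative and requires no checkpoint:

```python
from transformers import (
    InstructBlipVideoConfig,
    InstructBlipVideoQFormerConfig,
    InstructBlipVideoVisionConfig,
    OPTConfig,
)

# Compose a full InstructBlipVideo config from its three sub-configs
vision_config = InstructBlipVideoVisionConfig()
qformer_config = InstructBlipVideoQFormerConfig()
text_config = OPTConfig()

config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
```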
diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index 3f310565b437..1631c1018306 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. @@ -20,17 +20,10 @@ # limitations under the License. -from transformers import PretrainedConfig - -from ...utils import ( - logging, -) +from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING -logger = logging.get_logger(__name__) - - class LlavaNextVideoConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an @@ -48,7 +41,7 @@ class LlavaNextVideoConfig(PretrainedConfig): ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32001): - The image token index to encode the image prompt. + The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 7ad9e0769eb3..b73b2f6994d9 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -130,6 +129,12 @@ def unpad_image(tensor, original_size): Returns: `torch.Tensor`: The unpadded image tensor. 
""" + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + original_size = original_size.tolist() original_height, original_width = original_size current_height, current_width = tensor.shape[1:] @@ -180,6 +185,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. @@ -191,6 +197,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None @@ -455,7 +462,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration._merge_input_ids_with_image_features def _merge_input_ids_with_image_features( self, image_features, @@ -695,7 +701,7 @@ def _merge_input_ids_with_image_features( return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids - def pack_image_features(self, image_features, image_sizes, image_newline=None): + def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -704,6 +710,8 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None): List of image feature tensor, each contains all the visual feature of all patches. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. image_newline (`torch.Tensor` of shape `(embed_dim)`) New line embedding vector. 
Returns: @@ -718,8 +726,14 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None): base_image_feature = image_feature[0] image_feature = image_feature[1:] height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size - if height * width != base_image_feature.shape[0]: + + if vision_feature_select_strategy == "default": + expected_num_patches = height * width + elif vision_feature_select_strategy == "full": + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: raise ValueError("The number of patches is not consistent with the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py similarity index 95% rename from src/transformers/models/llava_next_video/diff_llava_next_video.py rename to src/transformers/models/llava_next_video/modular_llava_next_video.py index c5ca2bf00324..f0ec4578e488 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -21,15 +21,13 @@ import torch.utils.checkpoint from torch import nn -from transformers import PretrainedConfig from transformers.models.llava_next.modeling_llava_next import ( LlavaNextCausalLMOutputWithPast, LlavaNextForConditionalGeneration, - LlavaNextMultiModalProjector, image_size_to_num_patches, ) -from ...generation import GenerationMixin +from ...configuration_utils import PretrainedConfig from ...utils import ( logging, replace_return_docstrings, @@ -56,18 +54,8 @@ class LlavaNextVideoConfig(PretrainedConfig): The config object or dictionary of the text backbone. ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. - video_token_index (`int`, *optional*, defaults to 32000): - The video token index to encode the image prompt. image_token_index (`int`, *optional*, defaults to 32001): - The image token index to encode the image prompt. - spatial_pool_mode (`str`, *optional*, defaults to `"average"`): - Pooling mode to use for videos. Can be "average", "max" or "conv". - spatial_pool_stride (`int`, *optional*, defaults to 2): - Stride used in the pooling layer for videos. - image_seq_length (`int`, *optional*, defaults to 576): - Sequence length of one image embedding. - video_seq_length (`int`, *optional*, defaults to 288): - Sequence length of one video embedding. + The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): @@ -81,6 +69,16 @@ class LlavaNextVideoConfig(PretrainedConfig): of the form `(height, width)`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. + video_token_index (`int`, *optional*, defaults to 32000): + The video token index to encode the image prompt. + spatial_pool_mode (`str`, *optional*, defaults to `"average"`): + Pooling mode to use for videos. Can be "average", "max" or "conv". + spatial_pool_stride (`int`, *optional*, defaults to 2): + Stride used in the pooling layer for videos. 
+ image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + video_seq_length (`int`, *optional*, defaults to 288): + Sequence length of one video embedding. Example: @@ -178,7 +176,13 @@ def __init__( @dataclass class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast): - pass + """ + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. + video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + video_hidden_states: Optional[torch.FloatTensor] = None class LlavaNextVideoPooler(nn.Module): @@ -215,11 +219,7 @@ def forward(self, image_features): return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() -class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): - pass - - -class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin): +class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) self.vision_resampler = LlavaNextVideoPooler(config) @@ -287,6 +287,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -298,6 +300,10 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -329,7 +335,7 @@ def forward( ... frames.append(frame) ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) - >>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto) + >>> model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", device_map="auto") >>> processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf") >>> prompt = "USER: