Upgrade Transformers to v4.43.x (#727)
Changes required for sync:
- Re-copy Llama & BEiT attention implementations
- Add CLIP SDPA & FlashAttention-2 attention classes
- Fix the tie_weights method
- Upgrade the torch version in tests

---------

Co-authored-by: Leon Engländer <[email protected]>
calpt and lenglaender authored Aug 4, 2024
1 parent 8ddbcc8 commit bc90220
Showing 8 changed files with 236 additions and 25 deletions.
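
Before the per-file diffs, a minimal, illustrative sanity check of the upgraded environment (a sketch only, not part of the commit; the version values come from the workflow and setup.py diffs below, and the exact patch releases installed may differ):

# Illustrative sanity check: confirm the installed stack matches the new pins
# used below (torch==2.3 in CI, transformers~=4.43.3 in setup.py).
import torch
import transformers

print("torch:", torch.__version__)                # CI now installs 2.3
print("transformers:", transformers.__version__)  # expected to be a 4.43.x release
assert transformers.__version__.startswith("4.43"), "this sync targets transformers 4.43.x"
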
8 changes: 4 additions & 4 deletions .github/workflows/tests_torch.yml
@@ -39,7 +39,7 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
- name: Install
run: |
- pip install torch==2.1.2
+ pip install torch==2.3
pip install .[quality]
- name: Check Quality and Repo Consistency
run: |
@@ -62,7 +62,7 @@ jobs:
${{ runner.os }}-pip-
- name: Install
run: |
- pip install torch==2.1.2
+ pip install torch==2.3
pip install .[sklearn,testing,sentencepiece]
- name: Test
run: |
@@ -85,7 +85,7 @@ jobs:
${{ runner.os }}-pip-
- name: Install
run: |
- pip install torch==2.1.2
+ pip install torch==2.3
pip install .[sklearn,testing,sentencepiece]
- name: Test
run: |
@@ -108,7 +108,7 @@ jobs:
${{ runner.os }}-pip-
- name: Install
run: |
- pip install torch==2.1.2
+ pip install torch==2.3
pip install .[sklearn,testing,sentencepiece]
pip install conllu seqeval
- name: Test Examples
2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 523 files
4 changes: 2 additions & 2 deletions setup.py
@@ -57,8 +57,8 @@
"sphinx-intl==2.1.0",
"sphinx-multiversion==0.2.4",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.42.4",
"torch",
"transformers~=4.43.3",
]


2 changes: 2 additions & 0 deletions src/adapters/heads/model_mixin.py
@@ -134,6 +134,8 @@ def tie_weights(self):
self = getattr(self, self.base_model_prefix)
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)

+ super().tie_weights()

def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
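
The added super().tie_weights() call ensures the standard Transformers weight-tying logic still runs for flex-head models after the head-specific handling above. A minimal sketch of the intended effect, assuming the flexible-heads API (AutoAdapterModel, add_causal_lm_head) and a checkpoint with tie_word_embeddings=True; the checkpoint and head name are illustrative, not part of this diff:

from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("gpt2")  # gpt2 ties word embeddings by default
model.add_causal_lm_head("lm")
model.tie_weights()  # now also delegates to the parent Transformers implementation

inp = model.get_input_embeddings()
out = model.get_output_embeddings()  # assumed to resolve to the active head's projection, if any
if out is not None:
    print("tied:", out.weight.data_ptr() == inp.weight.data_ptr())
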
11 changes: 8 additions & 3 deletions src/adapters/models/beit/modeling_beit.py
@@ -35,6 +35,7 @@ def forward(
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
+ resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

@@ -51,9 +52,11 @@

# Add relative position bias if present.
if self.relative_position_bias is not None:
+ height, width = resolution
+ window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(
-     interpolate_pos_encoding, attention_scores.shape[2]
- ).unsqueeze(0)
+     window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+ )

# Add shared relative position bias if provided.
if relative_position_bias is not None:
@@ -89,15 +92,17 @@ def forward(
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
- relative_position_bias: Optional[BeitRelativePositionBias] = None,
+ relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
+ resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
+ resolution=resolution,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
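
For context on the new resolution argument: the re-copied BEiT attention now derives the relative-position window from the input resolution instead of from the attention-score shape. A worked example with illustrative values (not taken from this diff):

# Illustrative values: a 224x224 input with the default BEiT patch size of 16.
height, width = 224, 224          # now passed down to attention as `resolution`
patch_size = 16                   # BeitConfig.patch_size
window_size = (height // patch_size, width // patch_size)
print(window_size)                # (14, 14) patch grid for the relative position bias
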
167 changes: 166 additions & 1 deletion src/adapters/models/clip/modeling_clip.py
@@ -21,11 +21,25 @@
import torch.utils.checkpoint
from torch import nn

- from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer
+ from transformers.models.clip.modeling_clip import (
+     CLIPAttention,
+     CLIPEncoderLayer,
+     CLIPFlashAttention2,
+     CLIPSdpaAttention,
+ )
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
+ from transformers.utils import is_flash_attn_2_available, logging


+ if is_flash_attn_2_available():
+     from transformers.modeling_flash_attention_utils import _flash_attention_forward

from .mixin_clip import CLIPAttentionAdaptersMixin, CLIPEncoderLayerAdaptersMixin


+ logger = logging.get_logger(__name__)


class CLIPAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPAttention):
def forward(
self,
@@ -46,9 +60,11 @@ def forward(
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

+ # >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
+ # >>> END AH Changes <<<

key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
@@ -115,6 +131,155 @@ def forward(
return attn_output, attn_weights_reshaped


class CLIPFlashAttention2WithAdapters(CLIPAttentionAdaptersMixin, CLIPFlashAttention2):
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False

batch_size, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# >>> START AH Changes <<<
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
# >>> END AH Changes <<<

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

dropout_rate = self.dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)

attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights


class CLIPSdpaAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPSdpaAttention):
# Adapted from CLIPAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)

# CLIP text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask

bsz, tgt_len, embed_dim = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
key_states, value_states, attn_mask = self.prefix_tuning(key_states, value_states, hidden_states, attn_mask)
# >>> END AH Changes <<<

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None


class CLIPEncoderLayerWithAdapters(CLIPEncoderLayerAdaptersMixin, CLIPEncoderLayer):
def forward(
self,
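
With the two classes above in place, CLIP models loaded through adapters can use the SDPA and FlashAttention-2 code paths of Transformers 4.43. A minimal usage sketch, assuming adapters.init() swaps in the *WithAdapters attention classes and that the standard Transformers attn_implementation argument selects the backend; the checkpoint and adapter names are illustrative:

import adapters
from transformers import CLIPModel

# Select the SDPA backend; on a supported GPU with flash-attn installed,
# attn_implementation="flash_attention_2" would pick the flash path instead.
model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",  # illustrative checkpoint
    attn_implementation="sdpa",
)
adapters.init(model)  # assumed to replace attention modules with the *WithAdapters variants
model.add_adapter("demo")
model.set_active_adapters("demo")
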
(Diffs for the 2 remaining changed files were not loaded and are not shown here.)
