diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index bab0f650f3..9d96c2a1eb 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -103,6 +103,7 @@ gaudi_T5LayerSelfAttention_forward, gaudi_T5Stack_forward, gaudi_vit_self_attention_forward, + gaudi_swin_get_attn_mask, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2forctc_forward, @@ -121,6 +122,9 @@ def adapt_transformers_to_gaudi(): # Optimization tweak for ViT transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward + # Optimization tweak for Swin + transformers.models.swin.modeling_swin.SwinLayer.get_attn_mask = gaudi_swin_get_attn_mask + # Optimization tweak for Wav2Vec2 transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices = _gaudi_wav2vec2_compute_mask_indices # transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices = _gaudi_wav2vec2_sample_negative_indices diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 4232534590..f4066e5df5 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -109,6 +109,7 @@ gaudi_T5Stack_forward, ) from .vit import gaudi_vit_self_attention_forward +from .swin import gaudi_swin_get_attn_mask from .wav2vec2 import ( _gaudi_wav2vec2_compute_mask_indices, _gaudi_wav2vec2_mask_hidden_states, diff --git a/optimum/habana/transformers/models/swin/__init__.py b/optimum/habana/transformers/models/swin/__init__.py new file mode 100644 index 0000000000..59dbee4d5d --- /dev/null +++ b/optimum/habana/transformers/models/swin/__init__.py @@ -0,0 +1 @@ +from .modeling_swin import gaudi_swin_get_attn_mask diff --git a/optimum/habana/transformers/models/swin/modeling_swin.py b/optimum/habana/transformers/models/swin/modeling_swin.py new file mode 100644 index 0000000000..48b743439c --- /dev/null +++ b/optimum/habana/transformers/models/swin/modeling_swin.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Swin Transformer model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +from transformers.models.swin.modeling_swin import window_partition + +def gaudi_swin_get_attn_mask(self, height, width, dtype): + ''' + Copied from SwinLayer.get_attn_mask : https://github.com/huggingface/transformers/blob/main/src/transformers/models/swin/modeling_swin.py + The only difference is moving img_mask to hpu for performance + ''' + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu') + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + return attn_mask