From ff25e015f0c9337c65de7ad3827fce95539f0006 Mon Sep 17 00:00:00 2001 From: Harish Subramony Date: Tue, 12 Mar 2024 20:18:52 -0700 Subject: [PATCH 1/2] move img_mask@get_attn_mask() to hpu --- .../habana/transformers/generation/utils.py | 1 + optimum/habana/transformers/modeling_utils.py | 4 ++ .../habana/transformers/models/__init__.py | 1 + .../transformers/models/swin/__init__.py | 1 + .../transformers/models/swin/modeling_swin.py | 50 +++++++++++++++++++ 5 files changed, 57 insertions(+) create mode 100644 optimum/habana/transformers/models/swin/__init__.py create mode 100644 optimum/habana/transformers/models/swin/modeling_swin.py diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 755dec4516..944f840d6b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -76,6 +76,7 @@ "t5", "mistral", "mixtral", + "swin", ] diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index bab0f650f3..9d96c2a1eb 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -103,6 +103,7 @@ gaudi_T5LayerSelfAttention_forward, gaudi_T5Stack_forward, gaudi_vit_self_attention_forward, + gaudi_swin_get_attn_mask, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, gaudi_wav2vec2forctc_forward, @@ -121,6 +122,9 @@ def adapt_transformers_to_gaudi(): # Optimization tweak for ViT transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward + # Optimization tweak for Swin + transformers.models.swin.modeling_swin.SwinLayer.get_attn_mask = gaudi_swin_get_attn_mask + # Optimization tweak for Wav2Vec2 transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices = _gaudi_wav2vec2_compute_mask_indices # transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices = _gaudi_wav2vec2_sample_negative_indices diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 4232534590..f4066e5df5 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -109,6 +109,7 @@ gaudi_T5Stack_forward, ) from .vit import gaudi_vit_self_attention_forward +from .swin import gaudi_swin_get_attn_mask from .wav2vec2 import ( _gaudi_wav2vec2_compute_mask_indices, _gaudi_wav2vec2_mask_hidden_states, diff --git a/optimum/habana/transformers/models/swin/__init__.py b/optimum/habana/transformers/models/swin/__init__.py new file mode 100644 index 0000000000..59dbee4d5d --- /dev/null +++ b/optimum/habana/transformers/models/swin/__init__.py @@ -0,0 +1 @@ +from .modeling_swin import gaudi_swin_get_attn_mask diff --git a/optimum/habana/transformers/models/swin/modeling_swin.py b/optimum/habana/transformers/models/swin/modeling_swin.py new file mode 100644 index 0000000000..81644a4dc4 --- /dev/null +++ b/optimum/habana/transformers/models/swin/modeling_swin.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union + +import torch +from transformers.models.swin.modeling_swin import window_partition + +def gaudi_swin_get_attn_mask(self, height, width, dtype): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu') + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + return attn_mask From 2c0bbb0f49a6586b126f79d2fcfbcf14f347a819 Mon Sep 17 00:00:00 2001 From: Harish Subramony Date: Wed, 13 Mar 2024 08:03:54 -0700 Subject: [PATCH 2/2] review updates --- optimum/habana/transformers/generation/utils.py | 1 - optimum/habana/transformers/models/swin/modeling_swin.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 944f840d6b..755dec4516 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -76,7 +76,6 @@ "t5", "mistral", "mixtral", - "swin", ] diff --git a/optimum/habana/transformers/models/swin/modeling_swin.py b/optimum/habana/transformers/models/swin/modeling_swin.py index 81644a4dc4..48b743439c 100644 --- a/optimum/habana/transformers/models/swin/modeling_swin.py +++ b/optimum/habana/transformers/models/swin/modeling_swin.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" PyTorch Swin Transformer model.""" import math from typing import Optional, Tuple, Union @@ -21,6 +21,10 @@ from transformers.models.swin.modeling_swin import window_partition def gaudi_swin_get_attn_mask(self, height, width, dtype): + ''' + Copied from SwinLayer.get_attn_mask : https://github.com/huggingface/transformers/blob/main/src/transformers/models/swin/modeling_swin.py + The only difference is moving img_mask to hpu for performance + ''' if self.shift_size > 0: # calculate attention mask for SW-MSA img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu')