Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
gaudi_T5LayerSelfAttention_forward,
gaudi_T5Stack_forward,
gaudi_vit_self_attention_forward,
gaudi_swin_get_attn_mask,
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
)
Expand All @@ -143,6 +144,9 @@ def adapt_transformers_to_gaudi():
# Optimization tweak for ViT
transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward

# Optimization tweak for Swin
transformers.models.swin.modeling_swin.SwinLayer.get_attn_mask = gaudi_swin_get_attn_mask

# Optimization tweak for Wav2Vec2
transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices = _gaudi_wav2vec2_compute_mask_indices
# transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices = _gaudi_wav2vec2_sample_negative_indices
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
gaudi_T5Stack_forward,
)
from .vit import gaudi_vit_self_attention_forward
from .swin import gaudi_swin_get_attn_mask
from .wav2vec2 import (
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/swin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .modeling_swin import gaudi_swin_get_attn_mask
54 changes: 54 additions & 0 deletions optimum/habana/transformers/models/swin/modeling_swin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gaudi (HPU) override of the PyTorch Swin Transformer attention-mask helper."""

import math
from typing import Optional, Tuple, Union

import torch
from transformers.models.swin.modeling_swin import window_partition

def gaudi_swin_get_attn_mask(self, height, width, dtype):
    """
    Build the shifted-window (SW-MSA) attention mask directly on the HPU.

    Copied from SwinLayer.get_attn_mask:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/swin/modeling_swin.py
    The only difference is that ``img_mask`` is allocated on the ``hpu``
    device for performance.

    Args:
        height: Height used for the ``(1, height, width, 1)`` region map.
        width: Width used for the ``(1, height, width, 1)`` region map.
        dtype: Torch dtype of the region map (and therefore of the mask).

    Returns:
        ``None`` when ``self.shift_size`` is not positive. Otherwise a tensor
        of shape ``(num_windows, window_size**2, window_size**2)`` holding
        ``0.0`` for position pairs in the same region and ``-100.0`` for pairs
        in different regions.
    """
    # Without a window shift no masking is required.
    if not self.shift_size > 0:
        return None

    win = self.window_size
    shift = self.shift_size

    # Region map: every pixel is labelled with the id of the region it
    # belongs to once the cyclic shift has been applied.
    region_map = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu')

    # The same three bands split both the height and the width axes.
    bands = (
        slice(0, -win),
        slice(-win, -shift),
        slice(-shift, None),
    )
    # Row-major order over (height band, width band) gives region ids 0..8.
    band_pairs = [(rows, cols) for rows in bands for cols in bands]
    for region_id, (rows, cols) in enumerate(band_pairs):
        region_map[:, rows, cols, :] = region_id

    # One row of region ids per window (assumes window_partition returns a
    # tensor that can be viewed as (-1, win * win) -- it is defined upstream).
    per_window = window_partition(region_map, win)
    per_window = per_window.view(-1, win * win)

    # Pairwise id difference: zero exactly when both positions share a region.
    pairwise = per_window.unsqueeze(1) - per_window.unsqueeze(2)
    crosses_regions = pairwise != 0
    same_region = pairwise == 0

    attn_mask = pairwise.masked_fill(crosses_regions, float(-100.0))
    attn_mask = attn_mask.masked_fill(same_region, float(0.0))
    return attn_mask