Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
gaudi_T5LayerSelfAttention_forward,
gaudi_T5Stack_forward,
gaudi_vit_self_attention_forward,
gaudi_swin_get_attn_mask,
gaudi_wav2vec2_encoder_forward,
gaudi_wav2vec2_forward,
gaudi_wav2vec2forctc_forward,
Expand All @@ -121,6 +122,9 @@ def adapt_transformers_to_gaudi():
# Optimization tweak for ViT
transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward

# Optimization tweak for Swin
transformers.models.swin.modeling_swin.SwinLayer.get_attn_mask = gaudi_swin_get_attn_mask

# Optimization tweak for Wav2Vec2
transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices = _gaudi_wav2vec2_compute_mask_indices
# transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices = _gaudi_wav2vec2_sample_negative_indices
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
gaudi_T5Stack_forward,
)
from .vit import gaudi_vit_self_attention_forward
from .swin import gaudi_swin_get_attn_mask
from .wav2vec2 import (
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/swin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .modeling_swin import gaudi_swin_get_attn_mask
54 changes: 54 additions & 0 deletions optimum/habana/transformers/models/swin/modeling_swin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Swin Transformer model."""

import math
from typing import Optional, Tuple, Union

import torch
from transformers.models.swin.modeling_swin import window_partition

def gaudi_swin_get_attn_mask(self, height, width, dtype):
'''
Copied from SwinLayer.get_attn_mask : https://github.com/huggingface/transformers/blob/main/src/transformers/models/swin/modeling_swin.py
The only difference is moving img_mask to hpu for performance
'''
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu')

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check if there is self.device?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didnt find self.device

height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1

mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None

return attn_mask