Skip to content

move img_mask@get_attn_mask() to hpu#102

Merged
libinta merged 2 commits into
habana-mainfrom
swin_attn_mask
Mar 14, 2024
Merged

move img_mask@get_attn_mask() to hpu#102
libinta merged 2 commits into
habana-mainfrom
swin_attn_mask

Conversation

@hsubramony
Copy link
Copy Markdown

@hsubramony hsubramony commented Mar 13, 2024

In certain 1x systems , we notice a drop in performance for swin models as get_attn_mask gets executed in cpu instead of hpu. This was seen after pytorch upgrade to 2.2
This fix allows img_mask in get_attn_mask() to be moved back hpu.

"t5",
"mistral",
"mixtral",
"swin",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to add this

@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have Hugginface only?

def gaudi_swin_get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device='hpu')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check if there is self.device?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didnt find self.device

@libinta libinta merged commit ae7fc93 into habana-main Mar 14, 2024
@astachowiczhabana
Copy link
Copy Markdown

huggingface#795

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants