Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

做mask时候不用划分成9份吧, 4份就可以?附验证代码 #194

Open
jmjkx opened this issue Apr 20, 2022 · 12 comments
Open

做mask时候不用划分成9份吧, 4份就可以?附验证代码 #194

jmjkx opened this issue Apr 20, 2022 · 12 comments

Comments

@jmjkx
Copy link

jmjkx commented Apr 20, 2022

本质上只要保证新窗口内的各个patch有来源的区分性就可以,作者通过 mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 来得到一个来源图,那完全可以划分成4份就可以了啊。

489d3d064a5c802c33e0e66c4a6ddde
这是验证代码

import torch


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows



window_size = 7
H, W = 56, 56
shift_size = window_size//2


img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

#### 以上划分9个窗口
####################################################################################################
#### 以下划分4个窗口

img_mask1 = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices1 = (slice(0, -shift_size),
            slice(-shift_size, None))
w_slices1 = (slice(0, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices1:
    for w in w_slices1:
        img_mask1[:, h, w, :] = cnt
        cnt += 1

mask_windows1 = window_partition(img_mask1, window_size)  # nW, window_size, window_size, 1
mask_windows1 = mask_windows1.view(-1, window_size * window_size)
attn_mask1 = mask_windows1.unsqueeze(1) - mask_windows1.unsqueeze(2)
attn_mask1 = attn_mask1.masked_fill(attn_mask1 != 0, float(-100.0)).masked_fill(attn_mask1 == 0, float(0.0))
t = attn_mask == attn_mask1
print(t.sum() == t.flatten(0).shape[0])

结果是true,是否说明直接划分四个区域就行了呢?

@jmjkx jmjkx changed the title 做mask时候不用划分成9份吧, 4份就可以? 做mask时候不用划分成9份吧, 4份就可以?附验证代码 Apr 20, 2022
@lifan724
Copy link

lifan724 commented May 3, 2022

Me too,最近精读代码想的和你一样

@ain-soph
Copy link

ain-soph commented Jul 7, 2022

I think it might get more concern from authors if you translate this issue into English.

(建议把issue翻译成英文)

@jmjkx
Copy link
Author

jmjkx commented Aug 11, 2022

I think it might get more concern from authors if you translate this issue into English.

(建议把issue翻译成英文)

哈哈, 好吧, 我看作者是亚研院那几个国内兄弟, 懒得搞英文了(还是菜), 哈哈。

@jmjkx
Copy link
Author

jmjkx commented Aug 11, 2022

Me too,最近精读代码想的和你一样

就是分块分多了, 但是还是赞叹构思太巧妙了, 瑕不掩瑜。

@CHENHUI-X
Copy link

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

@jmjkx
Copy link
Author

jmjkx commented Sep 1, 2022

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 

这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.

@CHENHUI-X
Copy link

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 

这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.

嗯嗯,谢谢大佬!

@ain-soph
Copy link

@jmjkx 我之前在pytorch/vision#6246 里添加了SwinV2到torchvision里面。

你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。

@jmjkx
Copy link
Author

jmjkx commented Oct 24, 2022

@jmjkx 我之前在pytorch/vision#6246 里添加了SwinV2到torchvision里面。

你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。

好的, 这几天抽时间写个英文的。 在 torchvision repo 提么? 还是在原来微软作者那里提

@ain-soph
Copy link

@jmjkx 我看这个微软的repo好像作者已经不维护了吧。
你可以在torchvision提一个,看看maintainer们愿不愿意接受。

@jmjkx
Copy link
Author

jmjkx commented Oct 24, 2022

@jmjkx 我看这个微软的repo好像作者已经不维护了吧。 你可以在torchvision提一个,看看maintainer们愿不愿意接受。

好的好的, 谢谢

@ResetSun
Copy link

ResetSun commented Aug 7, 2023

请问如果我想在kv上做spatial reduction的话,这个mask该怎么变呢?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

5 participants