diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 4f23d9793615..f1ad269c00c5 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -144,9 +144,9 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ - batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) - windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) - windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) return windows diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 882c8bb025d2..f2b7c493b6ab 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -490,9 +490,9 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ - batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) - windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) - windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) return windows diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 588a4200fb1e..f44fafd6e531 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -219,9 +219,9 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ - batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) - windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) - windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) return windows diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 926a7dd27679..8c9b24266fb3 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -227,9 +227,9 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ - batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) - windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) - windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) return windows