diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3e16a0810723..26c219c58173 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -959,7 +959,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return output -class BeitPyramidPoolingModule(nn.ModuleList): +class BeitPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + BeitConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class BeitPyramidPoolingModule(nn.Module): """ Pyramid Pooling Module (PPM) used in PSPNet. @@ -979,17 +996,15 @@ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int self.align_corners = align_corners self.in_channels = in_channels self.channels = channels - for pool_scale in pool_scales: - self.append( - nn.Sequential( - nn.AdaptiveAvgPool2d(pool_scale), - BeitConvModule(self.in_channels, self.channels, kernel_size=1), - ) - ) + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels) + self.blocks.append(block) + self.add_module(str(i), block) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: ppm_outs = [] - for ppm in self: + for ppm in self.blocks: ppm_out = ppm(x) upsampled_ppm_out = nn.functional.interpolate( ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 0e286a773d31..c1efdf210666 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -878,8 +878,26 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return output +# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision +class Data2VecVisionPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + Data2VecVisionConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + # Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision -class Data2VecVisionPyramidPoolingModule(nn.ModuleList): +class Data2VecVisionPyramidPoolingModule(nn.Module): """ Pyramid Pooling Module (PPM) used in PSPNet. @@ -899,17 +917,17 @@ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int self.align_corners = align_corners self.in_channels = in_channels self.channels = channels - for pool_scale in pool_scales: - self.append( - nn.Sequential( - nn.AdaptiveAvgPool2d(pool_scale), - Data2VecVisionConvModule(self.in_channels, self.channels, kernel_size=1), - ) + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = Data2VecVisionPyramidPoolingBlock( + pool_scale=pool_scale, in_channels=in_channels, channels=channels ) + self.blocks.append(block) + self.add_module(str(i), block) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: ppm_outs = [] - for ppm in self: + for ppm in self.blocks: ppm_out = ppm(x) upsampled_ppm_out = nn.functional.interpolate( ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners