33 changes: 24 additions & 9 deletions src/transformers/models/beit/modeling_beit.py
@@ -959,7 +959,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return output
 
 
-class BeitPyramidPoolingModule(nn.ModuleList):
+class BeitPyramidPoolingBlock(nn.Module):
+    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
+        super().__init__()
+        self.layers = [
+            nn.AdaptiveAvgPool2d(pool_scale),
+            BeitConvModule(in_channels, channels, kernel_size=1),
+        ]
+        for i, layer in enumerate(self.layers):
+            self.add_module(str(i), layer)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state
+
+
+class BeitPyramidPoolingModule(nn.Module):
     """
     Pyramid Pooling Module (PPM) used in PSPNet.
 
@@ -979,17 +996,15 @@ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int
         self.align_corners = align_corners
         self.in_channels = in_channels
         self.channels = channels
-        for pool_scale in pool_scales:
-            self.append(
-                nn.Sequential(
-                    nn.AdaptiveAvgPool2d(pool_scale),
-                    BeitConvModule(self.in_channels, self.channels, kernel_size=1),
-                )
-            )
+        self.blocks = []
+        for i, pool_scale in enumerate(pool_scales):
+            block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
+            self.blocks.append(block)
+            self.add_module(str(i), block)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         ppm_outs = []
-        for ppm in self:
+        for ppm in self.blocks:
             ppm_out = ppm(x)
             upsampled_ppm_out = nn.functional.interpolate(
                 ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
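
A minimal sketch of how the refactored module can be exercised after this change; the pool scales, channel counts, and feature-map size are illustrative assumptions rather than values taken from a released BEiT checkpoint:

import torch
from transformers.models.beit.modeling_beit import BeitPyramidPoolingModule

# BeitPyramidPoolingModule is an internal class, so this import path may move between releases.
ppm = BeitPyramidPoolingModule(pool_scales=(1, 2, 3, 6), in_channels=768, channels=512, align_corners=False)

# Dummy feature map standing in for the backbone output: (batch, in_channels, height, width).
features = torch.randn(2, 768, 32, 32)
outputs = ppm(features)

# One output per pool scale, each bilinearly upsampled back to the input's spatial size.
assert len(outputs) == 4
assert all(out.shape == (2, 512, 32, 32) for out in outputs)

Iterating over self.blocks reproduces what the old nn.ModuleList subclass did, while BeitPyramidPoolingModule itself stays a plain nn.Module whose blocks are registered explicitly via add_module.
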
34 changes: 26 additions & 8 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -878,8 +878,26 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return output
 
 
+# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
+class Data2VecVisionPyramidPoolingBlock(nn.Module):
+    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
+        super().__init__()
+        self.layers = [
+            nn.AdaptiveAvgPool2d(pool_scale),
+            Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
+        ]
+        for i, layer in enumerate(self.layers):
+            self.add_module(str(i), layer)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state
+
+
 # Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
-class Data2VecVisionPyramidPoolingModule(nn.ModuleList):
+class Data2VecVisionPyramidPoolingModule(nn.Module):
     """
     Pyramid Pooling Module (PPM) used in PSPNet.
 
@@ -899,17 +917,17 @@ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int
         self.align_corners = align_corners
         self.in_channels = in_channels
         self.channels = channels
-        for pool_scale in pool_scales:
-            self.append(
-                nn.Sequential(
-                    nn.AdaptiveAvgPool2d(pool_scale),
-                    Data2VecVisionConvModule(self.in_channels, self.channels, kernel_size=1),
-                )
-            )
+        self.blocks = []
+        for i, pool_scale in enumerate(pool_scales):
+            block = Data2VecVisionPyramidPoolingBlock(
+                pool_scale=pool_scale, in_channels=in_channels, channels=channels
+            )
+            self.blocks.append(block)
+            self.add_module(str(i), block)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         ppm_outs = []
-        for ppm in self:
+        for ppm in self.blocks:
             ppm_out = ppm(x)
             upsampled_ppm_out = nn.functional.interpolate(
                 ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
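
The Data2VecVision copy behaves the same way. One detail worth noting: registering each block with self.add_module(str(i), block), and each inner layer under numeric names inside the block, appears to keep parameter keys in the same "<index>.<layer>" layout the previous nn.ModuleList/nn.Sequential structure produced, so existing checkpoints keyed that way should continue to load. A small sketch, again with illustrative sizes:

import torch
from transformers.models.data2vec.modeling_data2vec_vision import Data2VecVisionPyramidPoolingModule

ppm = Data2VecVisionPyramidPoolingModule(
    pool_scales=(1, 2, 3, 6), in_channels=768, channels=512, align_corners=False
)

# Every parameter key still starts with the block index ("0.", "1.", ...),
# mirroring the old ModuleList/Sequential naming.
keys = list(ppm.state_dict().keys())
assert all(key.split(".")[0] in {"0", "1", "2", "3"} for key in keys)
print(keys[:3])
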