From 574be6fd86c5aa99dd7b60c746d4166def5698e0 Mon Sep 17 00:00:00 2001 From: Ccc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 17 Mar 2023 11:39:48 +0800 Subject: [PATCH] [WIP][Feature] Add ViT-Adapter Model (#2762) --- configs/vit_adapter/README.md | 16 + ...t_vit_adapter_tiny_ade20k_512x512_160k.yml | 66 +++ paddleseg/models/__init__.py | 1 + paddleseg/models/backbones/__init__.py | 3 +- paddleseg/models/backbones/vit_adapter.py | 420 ++++++++++++++++ .../models/layers/ms_deformable_attention.py | 159 ++++++ paddleseg/models/layers/vit_adapter_layers.py | 461 ++++++++++++++++++ paddleseg/models/losses/cross_entropy_loss.py | 11 +- paddleseg/models/upernet_vit_adapter.py | 277 +++++++++++ 9 files changed, 1410 insertions(+), 4 deletions(-) create mode 100644 configs/vit_adapter/README.md create mode 100644 configs/vit_adapter/upernet_vit_adapter_tiny_ade20k_512x512_160k.yml create mode 100644 paddleseg/models/backbones/vit_adapter.py create mode 100644 paddleseg/models/layers/ms_deformable_attention.py create mode 100644 paddleseg/models/layers/vit_adapter_layers.py create mode 100644 paddleseg/models/upernet_vit_adapter.py diff --git a/configs/vit_adapter/README.md b/configs/vit_adapter/README.md new file mode 100644 index 0000000000..ff904971fc --- /dev/null +++ b/configs/vit_adapter/README.md @@ -0,0 +1,16 @@ +# Vision Transformer Adapter for Dense Predictions + +## Reference + +> Chen, Zhe, Yuchen Duan, Wenhai Wang, Junjun He, Tong Lu, Jifeng Dai, and Yu Qiao. "Vision Transformer Adapter for Dense Predictions." arXiv preprint arXiv:2205.08534 (2022). + +## Prerequisites + +Download ms_deform_attn.zip (https://paddleseg.bj.bcebos.com/dygraph/customized_ops/ms_deform_attn.zip), then follow its README to install the ms_deform_attn library.
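+
+After installation, you can quickly confirm that the compiled op can be found. This is a minimal sanity check; `ms_deform_attn` is the module imported at runtime by `paddleseg/models/layers/ms_deformable_attention.py`:
+
+```python
+# The custom op package should be importable and expose the ms_deform_attn
+# kernel used by the MSDeformAttn layer.
+import ms_deform_attn
+print('ms_deform_attn op available:', hasattr(ms_deform_attn, 'ms_deform_attn'))
+```
+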
+## Performance + +### ADE20K + +| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | +|-|-|-|-|-|-|-|-| +|UPerNetViTAdapter|ViT-Adapter-Tiny|512x512|160000|41.90%|-|-|[model](https://paddleseg.bj.bcebos.com/dygraph/ade20k/upernet_vit_adapter_tiny_ade20k_512x512_160k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/ade20k/upernet_vit_adapter_tiny_ade20k_512x512_160k/train_log.txt) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=88173046bd09f61da5f48db66baddd7d)| diff --git a/configs/vit_adapter/upernet_vit_adapter_tiny_ade20k_512x512_160k.yml b/configs/vit_adapter/upernet_vit_adapter_tiny_ade20k_512x512_160k.yml new file mode 100644 index 0000000000..fbc2110a29 --- /dev/null +++ b/configs/vit_adapter/upernet_vit_adapter_tiny_ade20k_512x512_160k.yml @@ -0,0 +1,66 @@ +_base_: '../_base_/ade20k.yml' + +batch_size: 4 # total batch size is 16 +iters: 160000 + +train_dataset: + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + - type: RandomPaddingCrop + crop_size: [512, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.4 + contrast_range: 0.4 + saturation_range: 0.4 + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +val_dataset: + transforms: + - type: Resize + target_size: [2048, 512] + keep_ratio: True + size_divisor: 32 + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +test_config: + is_slide: True + crop_size: [512, 512] + stride: [341, 341] + +optimizer: + _inherited_: False + type: AdamW + weight_decay: 0.01 + +lr_scheduler: + type: PolynomialDecay + learning_rate: 6.0e-5 + end_lr: 0 + power: 1.0 + warmup_iters: 1500 + warmup_start_lr: 1.0e-6 + +loss: + types: + - type: CrossEntropyLoss + avg_non_ignore: False + coef: [1, 0.4] + +model: + type: UPerNetViTAdapter + backbone: + type: ViTAdapter_Tiny + pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/deit_tiny_patch16_224.zip + backbone_indices: [0, 1, 2, 3] + channels: 512 + pool_scales: [1, 2, 3, 6] + dropout_ratio: 0.1 + aux_loss: True + aux_channels: 256 \ No newline at end of file diff --git a/paddleseg/models/__init__.py b/paddleseg/models/__init__.py index 9436ae97c1..e385105217 100644 --- a/paddleseg/models/__init__.py +++ b/paddleseg/models/__init__.py @@ -67,6 +67,7 @@ from .mscale_ocrnet import MscaleOCRNet from .topformer import TopFormer from .rtformer import RTFormer +from .upernet_vit_adapter import UPerNetViTAdapter from .lpsnet import LPSNet from .maskformer import MaskFormer from .segnext import SegNeXt diff --git a/paddleseg/models/backbones/__init__.py b/paddleseg/models/backbones/__init__.py index 69d7feb5dd..caecdc6487 100644 --- a/paddleseg/models/backbones/__init__.py +++ b/paddleseg/models/backbones/__init__.py @@ -27,5 +27,6 @@ from .cae import * from .top_transformer import * from .uhrnet import * +from .vit_adapter import * from .hrformer import * -from .mscan import * \ No newline at end of file +from .mscan import * diff --git a/paddleseg/models/backbones/vit_adapter.py b/paddleseg/models/backbones/vit_adapter.py new file mode 100644 index 0000000000..6d5366d0f2 --- /dev/null +++ b/paddleseg/models/backbones/vit_adapter.py @@ -0,0 +1,420 @@ +# This file is heavily based on https://github.com/czczup/ViT-Adapter + +import math +from functools import partial + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg.cvlibs import manager 
+from paddleseg.utils import utils, logger +from paddleseg.cvlibs.param_init import normal_init, trunc_normal_init, constant_init +from paddleseg.models.backbones.transformer_utils import to_2tuple, DropPath +from paddleseg.models.layers.vit_adapter_layers import ( + SpatialPriorModule, InteractionBlock, deform_inputs) +from paddleseg.models.layers.ms_deformable_attention import MSDeformAttn + +__all__ = ['ViTAdapter', 'ViTAdapter_Tiny'] + + +class PatchEmbed(nn.Layer): + """2D Image to Patch Embedding.""" + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose([0, 2, 1]) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + x_shape = paddle.shape(x) + N, C = x_shape[1], x_shape[2] + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_scale=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity( + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = self.create_parameter( + shape=(dim, ), + default_initializer=paddle.nn.initializer.Constant(value=1.)) + self.gamma2 = self.create_parameter( + shape=(dim, ), + default_initializer=paddle.nn.initializer.Constant(value=1.)) + + def forward(self, x, H, W): + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Layer): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + layer_scale=True, + embed_layer=PatchEmbed, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + act_layer=nn.GELU, + pretrained=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_channels (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = self.create_parameter( + shape=(1, num_patches + self.num_tokens, embed_dim), + default_initializer=paddle.nn.initializer.Constant(value=0.)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = np.linspace(0, drop_path_rate, + depth) # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + layer_scale=layer_scale) for i in range(depth) + ]) + + self.pretrained = pretrained + self.init_weight() + + def init_weight(self): + utils.load_pretrained_model(self, self.pretrained) + + 
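+# ViTAdapter (below) wraps the plain VisionTransformer defined above: a
+# SpatialPriorModule extracts convolutional features c1..c4 at strides
+# 4/8/16/32, and a stack of InteractionBlocks exchanges information between
+# the ViT tokens and the flattened c2..c4 tokens through multi-scale
+# deformable attention (the Injector writes into the ViT tokens, the Extractor
+# reads back into the conv tokens). forward() reshapes the tokens to 2D maps,
+# optionally adds the resized ViT features (add_vit_feature), and returns four
+# SyncBN-normalized feature maps at strides 4, 8, 16 and 32 for the UPerNet head.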
+@manager.BACKBONES.add_component +class ViTAdapter(VisionTransformer): + """ The ViT-Adapter + """ + + def __init__(self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0., + interaction_indexes=None, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=1.0, + add_vit_feature=True, + pretrained=None, + use_extra_extractor=True, + *args, + **kwargs): + + super().__init__( + num_heads=num_heads, pretrained=pretrained, *args, **kwargs) + + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + self.feat_channels = [embed_dim] * 4 + + self.level_embed = self.create_parameter( + shape=(3, embed_dim), + default_initializer=paddle.nn.initializer.Constant(value=0.)) + self.spm = SpatialPriorModule( + inplanes=conv_inplane, embed_dim=embed_dim) + self.interactions = nn.Sequential(*[ + InteractionBlock( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=((True if i == len(interaction_indexes) - 1 else + False) and use_extra_extractor)) + for i in range(len(interaction_indexes)) + ]) + self.up = nn.Conv2DTranspose(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_init(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_init(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + constant_init(m.bias, value=0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, (nn.BatchNorm2D, + nn.SyncBatchNorm)): + constant_init(m.bias, value=0) + constant_init(m.weight, value=1.0) + elif isinstance(m, nn.Conv2D) or isinstance(m, nn.Conv2DTranspose): + fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + fan_out //= m._groups + normal_init(m.weight, std=math.sqrt(2.0 / fan_out)) + if m.bias is not None: + constant_init(m.bias, value=0) + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + [1, self.pretrain_size[0] // 16, self.pretrain_size[1] // 16, + -1]).transpose([0, 3, 1, 2]) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\ + reshape([1, -1, H * W]).transpose([0, 2, 1]) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = paddle.concat([c2, c3, c4], axis=1) + + # Patch Embedding forward + x, H, W = self.patch_embed(x) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H, W) + x = self.pos_drop(x + pos_embed) + + # Interaction + outs = list() + for i, layer in enumerate(self.interactions): + 
indexes = self.interaction_indexes[i] + x, c = layer(x, c, self.blocks[indexes[0]:indexes[-1] + 1], + deform_inputs1, deform_inputs2, H, W) + outs.append(x.transpose([0, 2, 1]).reshape([bs, dim, H, W])) + + # Split & Reshape + c2 = c[:, 0:c2.shape[1], :] + c3 = c[:, c2.shape[1]:c2.shape[1] + c3.shape[1], :] + c4 = c[:, c2.shape[1] + c3.shape[1]:, :] + + c2 = c2.transpose([0, 2, 1]).reshape([bs, dim, H * 2, W * 2]) + c3 = c3.transpose([0, 2, 1]).reshape([bs, dim, H, W]) + c4 = c4.transpose([0, 2, 1]).reshape([bs, dim, H // 2, W // 2]) + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + x1 = F.interpolate( + x1, scale_factor=4, mode='bilinear', align_corners=False) + x2 = F.interpolate( + x2, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.interpolate( + x4, scale_factor=0.5, mode='bilinear', align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] + + +@manager.BACKBONES.add_component +def ViTAdapter_Tiny(**kwargs): + return ViTAdapter( + num_heads=3, + patch_size=16, + embed_dim=192, + depth=12, + mlp_ratio=4, + drop_path_rate=0.1, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + cffn_ratio=0.25, + deform_ratio=1.0, + interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]], + **kwargs) \ No newline at end of file diff --git a/paddleseg/models/layers/ms_deformable_attention.py b/paddleseg/models/layers/ms_deformable_attention.py new file mode 100644 index 0000000000..8af9f36679 --- /dev/null +++ b/paddleseg/models/layers/ms_deformable_attention.py @@ -0,0 +1,159 @@ +# This file is heavily based on https://github.com/czczup/ViT-Adapter +import math +import warnings + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg.cvlibs import param_init +from paddleseg.cvlibs.param_init import constant_init, xavier_uniform + + +class MSDeformAttn(nn.Layer): + def __init__(self, + d_model=256, + n_levels=4, + n_heads=8, + n_points=4, + ratio=1.0): + """Multi-Scale Deformable Attention Module. + + Args: + d_model(int, optional): The hidden dimension. Default: 256 + n_levels(int, optional): The number of feature levels. Default: 4 + n_heads(int, optional): The number of attention heads. Default: 8 + n_points(int, optional): The number of sampling points per attention head per feature level. Default: 4 + ratio (float, optional): The ratio of channels for Linear. 
Default: 1.0 + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, ' + 'but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 + # which is more efficient in our CUDA implementation + if not self._is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make " + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = 64 + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.ratio = ratio + + self.sampling_offsets = nn.Linear(d_model, + n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, + n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, int(d_model * ratio)) + self.output_proj = nn.Linear(int(d_model * ratio), d_model) + + self._reset_parameters() + + @staticmethod + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError('invalid input for _is_power_of_2: {} (type: {})'. + format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + def _reset_parameters(self): + constant_init(self.sampling_offsets.weight, value=0.) + thetas = paddle.arange( + self.n_heads, dtype='float32') * (2.0 * math.pi / self.n_heads) + grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max( + -1, keepdim=True)[0]).reshape([self.n_heads, 1, 1, 2]).tile( + [1, self.n_levels, self.n_points, 1]) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + + grid_init = grid_init.reshape([-1]) + self.sampling_offsets.bias = self.create_parameter( + shape=grid_init.shape, + default_initializer=paddle.nn.initializer.Assign(grid_init)) + self.sampling_offsets.bias.stop_gradient = True + + constant_init(self.attention_weights.weight, value=0.) + constant_init(self.attention_weights.bias, value=0.) + xavier_uniform(self.value_proj.weight) + constant_init(self.value_proj.bias, value=0.) + xavier_uniform(self.output_proj.weight) + constant_init(self.output_proj.bias, value=0.) 
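+
+    # How _reset_parameters() seeds the deformable attention:
+    # - sampling_offsets: weights are zeroed and the bias is set to n_heads
+    #   evenly spaced directions (angle 2*pi*k/n_heads, L-inf normalized),
+    #   tiled over levels and scaled by (point index + 1), so the i-th sampling
+    #   point of each head starts i+1 steps away from the reference point.
+    #   The bias is created with stop_gradient=True, so these initial offsets
+    #   stay fixed in this implementation.
+    # - attention_weights: zero-initialized, i.e. uniform weights after softmax.
+    # - value_proj / output_proj: Xavier-uniform weights with zero bias.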
+ + def forward(self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None): + """ + Args: + query: (N, Length_{query}, C) + reference_points: (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + input_flatten: (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + input_spatial_shapes: (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + input_level_start_index: (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + input_padding_mask: (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + Returns: + output (N, Length_{query}, C) + """ + + def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] + ).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = masked_fill(value, input_padding_mask[..., None], float(0)) + + value = value.reshape([ + N, Len_in, self.n_heads, + int(self.ratio * self.d_model) // self.n_heads + ]) + sampling_offsets = self.sampling_offsets(query).reshape( + [N, Len_q, self.n_heads, self.n_levels, self.n_points, 2]) + attention_weights = self.attention_weights(query).reshape( + [N, Len_q, self.n_heads, self.n_levels * self.n_points]) + attention_weights = F.softmax(attention_weights, -1).\ + reshape([N, Len_q, self.n_heads, self.n_levels, self.n_points]) + + if reference_points.shape[-1] == 2: + offset_normalizer = paddle.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.' + .format(reference_points.shape[-1])) + try: + import ms_deform_attn + except: + print( + "Import ms_deform_attn failed. 
Please download the following file and refer to " + "the readme to install ms_deform_attn lib: " + "https://paddleseg.bj.bcebos.com/dygraph/customized_ops/ms_deform_attn.zip" + ) + exit() + output = ms_deform_attn.ms_deform_attn( + value, input_spatial_shapes, input_level_start_index, + sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/paddleseg/models/layers/vit_adapter_layers.py b/paddleseg/models/layers/vit_adapter_layers.py new file mode 100644 index 0000000000..6735331db9 --- /dev/null +++ b/paddleseg/models/layers/vit_adapter_layers.py @@ -0,0 +1,461 @@ +# This file is heavily based on https://github.com/czczup/ViT-Adapter + +import math +import warnings +from functools import partial + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddleseg.models.backbones.transformer_utils import DropPath +from paddleseg.models.layers.ms_deformable_attention import MSDeformAttn + + +def get_reference_points(spatial_shapes): + reference_points_list = [] + for _, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = paddle.meshgrid( + paddle.linspace( + 0.5, H_ - 0.5, H_, dtype='float32'), + paddle.linspace( + 0.5, W_ - 0.5, W_, dtype='float32')) + ref_y = ref_y.reshape([1, -1]) / H_ + ref_x = ref_x.reshape([1, -1]) / W_ + ref = paddle.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = paddle.concat(reference_points_list, 1) + reference_points = paddle.unsqueeze(reference_points, axis=2) + return reference_points + + +def deform_inputs(x): + _, _, h, w = x.shape + spatial_shapes = paddle.to_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], + dtype='int64') + level_start_index = paddle.concat((paddle.zeros( + (1, ), dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 16, w // 16)]) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = paddle.to_tensor([(h // 16, w // 16)], dtype='int64') + level_start_index = paddle.concat((paddle.zeros( + (1, ), dtype='int64'), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)]) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class DWConv(nn.Layer): + """ + The specific DWConv unsed in ConvFFN. + """ + + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0:16 * n, :].transpose([0, 2, 1]).reshape( + [B, C, H * 2, W * 2]) + x2 = x[:, 16 * n:20 * n, :].transpose([0, 2, 1]).reshape([B, C, H, W]) + x3 = x[:, 20 * n:, :].transpose([0, 2, 1]).reshape( + [B, C, H // 2, W // 2]) + x1 = self.dwconv(x1).flatten(2).transpose([0, 2, 1]) + x2 = self.dwconv(x2).flatten(2).transpose([0, 2, 1]) + x3 = self.dwconv(x3).flatten(2).transpose([0, 2, 1]) + x = paddle.concat([x1, x2, x3], axis=1) + return x + + +class ConvFFN(nn.Layer): + """ + The implementation of ConvFFN unsed in Extractor. 
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Extractor(nn.Layer): + """ + The Extractor module in ViT-Adapter. + """ + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0., + drop_path=0., + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6)): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, + n_levels=n_levels, + n_heads=num_heads, + n_points=n_points, + ratio=deform_ratio) + self.with_cffn = with_cffn + if with_cffn: + self.ffn = ConvFFN( + in_features=dim, + hidden_features=int(dim * cffn_ratio), + drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index, H, W): + attn = self.attn( + self.query_norm(query), reference_points, + self.feat_norm(feat), spatial_shapes, level_start_index, None) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + +class Injector(nn.Layer): + """ + The Injector module in ViT-Adapter. + """ + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + init_values=0.): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, + n_levels=n_levels, + n_heads=num_heads, + n_points=n_points, + ratio=deform_ratio) + self.gamma = self.create_parameter( + shape=(dim, ), + default_initializer=paddle.nn.initializer.Constant( + value=init_values)) + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index): + attn = self.attn( + self.query_norm(query), reference_points, + self.feat_norm(feat), spatial_shapes, level_start_index, None) + return query + self.gamma * attn + + +class InteractionBlock(nn.Layer): + """ + Combine the Extractor, Extractor and ViT Blocks. 
+ """ + + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + + for _, blk in enumerate(blocks): + x = blk(x, H, W) + + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + + return x, c + + +class InteractionBlockWithCls(nn.Layer): + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial( + nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + x = paddle.concat((cls, x), axis=1) + for _, blk in enumerate(blocks): + x = blk(x, H, W) + cls, x = x[:, :1, ], x[:, 1:, ] + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], 
+ H=H, + W=W) + return x, c, cls + + +class SpatialPriorModule(nn.Layer): + def __init__(self, inplanes=64, embed_dim=384): + super().__init__() + + self.stem = nn.Sequential(*[ + nn.Conv2D( + 3, + inplanes, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False), nn.SyncBatchNorm(inplanes), nn.ReLU(), + nn.Conv2D( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False), nn.SyncBatchNorm(inplanes), nn.ReLU(), + nn.Conv2D( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False), nn.SyncBatchNorm(inplanes), nn.ReLU(), + nn.MaxPool2D( + kernel_size=3, stride=2, padding=1) + ]) + self.conv2 = nn.Sequential(*[ + nn.Conv2D( + inplanes, + 2 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False), nn.SyncBatchNorm(2 * inplanes), nn.ReLU() + ]) + self.conv3 = nn.Sequential(*[ + nn.Conv2D( + 2 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False), nn.SyncBatchNorm(4 * inplanes), nn.ReLU() + ]) + self.conv4 = nn.Sequential(*[ + nn.Conv2D( + 4 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False), nn.SyncBatchNorm(4 * inplanes), nn.ReLU() + ]) + self.fc1 = nn.Conv2D( + inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) + self.fc2 = nn.Conv2D( + 2 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) + self.fc3 = nn.Conv2D( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) + self.fc4 = nn.Conv2D( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True) + + def forward(self, x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + c2 = c2.reshape([bs, dim, -1]).transpose([0, 2, 1]) # 8s + c3 = c3.reshape([bs, dim, -1]).transpose([0, 2, 1]) # 16s + c4 = c4.reshape([bs, dim, -1]).transpose([0, 2, 1]) # 32s + + return c1, c2, c3, c4 diff --git a/paddleseg/models/losses/cross_entropy_loss.py b/paddleseg/models/losses/cross_entropy_loss.py index c934a0a5b4..b1cfb3a624 100644 --- a/paddleseg/models/losses/cross_entropy_loss.py +++ b/paddleseg/models/losses/cross_entropy_loss.py @@ -33,6 +33,7 @@ class CrossEntropyLoss(nn.Layer): top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0]. When its value < 1.0, only compute the loss for the top k percent pixels (e.g., the top 20% pixels). This is useful for hard pixel mining. Default ``1.0``. + avg_non_ignore (bool, optional): Whether the loss is only averaged over non-ignored value of pixels. Default: True. data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``. """ @@ -40,10 +41,12 @@ def __init__(self, weight=None, ignore_index=255, top_k_percent_pixels=1.0, + avg_non_ignore=True, data_format='NCHW'): super(CrossEntropyLoss, self).__init__() self.ignore_index = ignore_index self.top_k_percent_pixels = top_k_percent_pixels + self.avg_non_ignore = avg_non_ignore self.EPS = 1e-8 self.data_format = data_format if weight is not None: @@ -107,10 +110,12 @@ def _post_process_loss(self, logit, label, semantic_weights, loss): Returns: (Tensor): The average loss. 
""" - mask = label != self.ignore_index - mask = paddle.cast(mask, 'float32') - label.stop_gradient = True + if self.avg_non_ignore: + mask = paddle.cast(label != self.ignore_index, dtype='float32') + else: + mask = paddle.ones(label.shape, dtype='float32') mask.stop_gradient = True + label.stop_gradient = True if loss.ndim > mask.ndim: loss = paddle.squeeze(loss, axis=-1) diff --git a/paddleseg/models/upernet_vit_adapter.py b/paddleseg/models/upernet_vit_adapter.py new file mode 100644 index 0000000000..cb9dcfd28f --- /dev/null +++ b/paddleseg/models/upernet_vit_adapter.py @@ -0,0 +1,277 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg import utils +from paddleseg.cvlibs import manager +from paddleseg.models import layers + + +@manager.MODELS.add_component +class UPerNetViTAdapter(nn.Layer): + """ + The UPerNetViTAdapter implementation based on PaddlePaddle. + + The original article refers to + Chen, Zhe, Yuchen Duan, Wenhai Wang, Junjun He, Tong Lu, Jifeng Dai, and Yu Qiao. + "Vision Transformer Adapter for Dense Predictions." + (https://arxiv.org/abs/2205.08534). + + The implementation is based on https://github.com/czczup/ViT-Adapter + + Args: + num_classes (int): The unique number of target classes. + backbone (nn.Layer): The backbone network. + backbone_indices (tuple | list): The values indicate the indices of output of backbone. + channels (int, optional): The channels of inter layers in upernet head. Default: 512. + pool_scales (list, optional): The scales in PPM. Default: [1, 2, 3, 6]. + dropout_ratio (float, optional): The dropout ratio for upernet head. Default: 0.1. + aux_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True. + aux_channels (int, optional): The channels of inter layers in auxiliary head. Default: 256. + align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even, + e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False. + pretrained (str, optional): The path or url of pretrained model. Default: None. 
+ """ + + def __init__(self, + num_classes, + backbone, + backbone_indices, + channels=512, + pool_scales=[1, 2, 3, 6], + dropout_ratio=0.1, + aux_loss=True, + aux_channels=256, + align_corners=False, + pretrained=None): + super().__init__() + self.backbone = backbone + self.backbone_indices = backbone_indices + self.align_corners = align_corners + + in_channels = [self.backbone.feat_channels[i] for i in backbone_indices] + self.head = UPerNetHead( + num_classes=num_classes, + in_channels=in_channels, + channels=channels, + pool_scales=pool_scales, + dropout_ratio=dropout_ratio, + aux_loss=aux_loss, + aux_channels=aux_channels, + align_corners=align_corners) + + self.pretrained = pretrained + self.init_weight() + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + def forward(self, x): + feats = self.backbone(x) + feats = [feats[i] for i in self.backbone_indices] + logit_list = self.head(feats) + logit_list = [ + F.interpolate( + logit, + paddle.shape(x)[2:], + mode='bilinear', + align_corners=self.align_corners) for logit in logit_list + ] + return logit_list + + +class ConvBNReLU(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias_attr=False, + **kwargs): + super().__init__() + self.conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size, + bias_attr=bias_attr, + **kwargs) + self.bn = nn.BatchNorm2D(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class PPM(nn.Layer): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple | list): Pooling scales used in PPM. + in_channels (int): Input channels. + channels (int): Output Channels after modules, before conv_seg. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, align_corners): + super().__init__() + self.pool_scales = pool_scales + self.in_channels = in_channels + self.channels = channels + self.align_corners = align_corners + self.stages = nn.LayerList() + for pool_scale in pool_scales: + self.stages.append( + nn.Sequential( + nn.AdaptiveAvgPool2D(output_size=(pool_scale, pool_scale)), + ConvBNReLU( + in_channels=in_channels, + out_channels=channels, + kernel_size=1))) + + def forward(self, x): + ppm_outs = [] + for ppm in self.stages: + ppm_out = ppm(x) + upsampled_ppm_out = F.interpolate( + ppm_out, + paddle.shape(x)[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class UPerNetHead(nn.Layer): + """ + This head is the implementation of "Unified Perceptual Parsing for Scene Understanding". + This is heavily based on https://github.com/czczup/ViT-Adapter + + Args: + num_classes (int): The unique number of target classes. + in_channels (list[int]): The channels of input features. + channels (int, optional): The channels of inter layers in upernet head. Default: 512. + pool_scales (list, optional): The scales in PPM. Default: [1, 2, 3, 6]. + dropout_ratio (float, optional): The dropout ratio for upernet head. Default: 0.1. + aux_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True. + aux_channels (int, optional): The channels of inter layers in auxiliary head. Default: 256. + align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even, + e.g. 
1024x512, otherwise it is True, e.g. 769x769. Default: False. + """ + + def __init__(self, + num_classes, + in_channels, + channels, + pool_scales=[1, 2, 3, 6], + dropout_ratio=0.1, + aux_loss=False, + aux_channels=256, + align_corners=False): + super().__init__() + self.align_corners = align_corners + + # PSP Module + self.psp_modules = PPM(pool_scales, + in_channels[-1], + channels, + align_corners=align_corners) + self.bottleneck = ConvBNReLU( + in_channels[-1] + len(pool_scales) * channels, + channels, + 3, + padding=1) + # FPN Module + self.lateral_convs = nn.LayerList() + self.fpn_convs = nn.LayerList() + for ch in in_channels[:-1]: # skip the top layer + l_conv = ConvBNReLU(ch, channels, 1) + fpn_conv = ConvBNReLU(channels, channels, 3, padding=1) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvBNReLU( + len(in_channels) * channels, channels, 3, padding=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2D(dropout_ratio) + else: + self.dropout = None + self.conv_seg = nn.Conv2D(channels, num_classes, kernel_size=1) + + self.aux_loss = aux_loss + if self.aux_loss: + self.aux_conv = ConvBNReLU( + in_channels[2], aux_channels, 3, padding=1) + self.aux_conv_seg = nn.Conv2D( + aux_channels, num_classes, kernel_size=1) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = paddle.concat(psp_outs, axis=1) + output = self.bottleneck(psp_outs) + return output + + def forward(self, inputs): + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + upsampled = F.interpolate( + laterals[i], + paddle.shape(laterals[i - 1])[2:], + mode='bilinear', + align_corners=self.align_corners) + laterals[i - 1] = laterals[i - 1] + upsampled + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + fpn_outs.append(laterals[-1]) # append psp feature + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = F.interpolate( + fpn_outs[i], + size=paddle.shape(fpn_outs[0])[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = paddle.concat(fpn_outs, axis=1) + output = self.fpn_bottleneck(fpn_outs) + + if self.dropout is not None: + output = self.dropout(output) + output = self.conv_seg(output) + logits_list = [output] + + if self.aux_loss and self.training: + aux_output = self.aux_conv(inputs[2]) + aux_output = self.aux_conv_seg(aux_output) + logits_list.append(aux_output) + + return logits_list
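+
+
+# A minimal usage sketch (kept as a comment, not part of the exported API): it
+# builds the tiny variant from this PR and runs a forward pass. It assumes the
+# ms_deform_attn custom op is installed and uses ADE20K's 150 classes, matching
+# configs/vit_adapter/upernet_vit_adapter_tiny_ade20k_512x512_160k.yml.
+#
+#   import paddle
+#   from paddleseg.models import UPerNetViTAdapter
+#   from paddleseg.models.backbones import ViTAdapter_Tiny
+#
+#   backbone = ViTAdapter_Tiny()
+#   model = UPerNetViTAdapter(
+#       num_classes=150,
+#       backbone=backbone,
+#       backbone_indices=[0, 1, 2, 3])
+#   model.eval()
+#   logits = model(paddle.rand([1, 3, 512, 512]))  # a list; one tensor in eval mode
+#   print(logits[0].shape)                         # [1, 150, 512, 512]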