Skip to content

Commit

Permalink
[Feature] Add GloRe (PaddlePaddle#1951)
Browse files Browse the repository at this point in the history
  • Loading branch information
aigcliu authored Apr 15, 2022
1 parent e667705 commit 3fd33d5
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 0 deletions.
21 changes: 21 additions & 0 deletions configs/glore/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Graph-Based Global Reasoning Networks

## Reference

> Chen, Yunpeng, Marcus Rohrbach, Zhicheng Yan, Yan Shuicheng, Jiashi Feng, and Yannis Kalantidis. "Graph-based global reasoning networks." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 433-442. 2019.

## Performance

### Cityscapes

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|GloRe|ResNet50_OS8|1024x512|80000|78.26%|78.61%|78.72%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/glore_resnet50_os8_cityscapes_1024x512_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/glore_resnet50_os8_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=de754e39ac9de4d2e951915c2334d6ec) |


### Pascal VOC 2012 + Aug

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|GloRe|ResNet50_OS8|512x512|40000|80.16%|80.35%|80.40%|[model](https://bj.bcebos.com/paddleseg/dygraph/pascal_voc12/glore_resnet50_os8_voc12aug_512x512_40k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pascal_voc12/glore_resnet50_os8_voc12aug_512x512_40k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=e40c1dd8d4fcbf2dcda01242dec9d9b5) |
23 changes: 23 additions & 0 deletions configs/glore/glore_resnet50_os8_cityscapes_1024x512_80k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
_base_: '../_base_/cityscapes.yml'

batch_size: 2
iters: 80000

learning_rate:
decay:
end_lr: 1.0e-5

loss:
types:
- type: CrossEntropyLoss
coef: [1, 0.4]

model:
type: GloRe
backbone:
type: ResNet50_vd
output_stride: 8
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
enable_auxiliary_loss: True
align_corners: False
pretrained: null
17 changes: 17 additions & 0 deletions configs/glore/glore_resnet50_os8_voc12aug_512x512_40k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_: '../_base_/pascal_voc12aug.yml'


model:
type: GloRe
backbone:
type: ResNet50_vd
output_stride: 8
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
enable_auxiliary_loss: True
align_corners: False
pretrained: null

loss:
types:
- type: CrossEntropyLoss
coef: [1, 0.4]
1 change: 1 addition & 0 deletions paddleseg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@
from .bisenetv1 import BiseNetV1
from .fastfcn import FastFCN
from .pfpnnet import PFPNNet
from .glore import GloRe
199 changes: 199 additions & 0 deletions paddleseg/models/glore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils


@manager.MODELS.add_component
class GloRe(nn.Layer):
"""
The GloRe implementation based on PaddlePaddle.
The original article refers to:
Chen, Yunpeng, et al. "Graph-Based Global Reasoning Networks"
(https://arxiv.org/pdf/1811.12814.pdf)
Args:
num_classes (int): The unique number of target classes.
backbone (Paddle.nn.Layer): Backbone network, currently support Resnet50/101.
backbone_indices (tuple, optional): Two values in the tuple indicate the indices of output of backbone.
gru_channels (int, optional): The number of input channels in GloRe Unit. Default: 512.
gru_num_state (int, optional): The number of states in GloRe Unit. Default: 128.
gru_num_node (tuple, optional): The number of nodes in GloRe Unit. Default: Default: 128.
enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
pretrained (str, optional): The path or url of pretrained model. Default: None.
"""

def __init__(self,
num_classes,
backbone,
backbone_indices=(2, 3),
gru_channels=512,
gru_num_state=128,
gru_num_node=64,
enable_auxiliary_loss=True,
align_corners=False,
pretrained=None):
super().__init__()

self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]

self.head = GloReHead(num_classes, backbone_indices, backbone_channels,
gru_channels, gru_num_state, gru_num_node,
enable_auxiliary_loss)
self.align_corners = align_corners
self.pretrained = pretrained
self.init_weight()

def forward(self, x):
feat_list = self.backbone(x)
logit_list = self.head(feat_list)
return [
F.interpolate(
logit,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list
]

def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)


class GloReHead(nn.Layer):

def __init__(self,
num_classes,
backbone_indices,
backbone_channels,
gru_channels=512,
gru_num_state=128,
gru_num_node=64,
enable_auxiliary_loss=True):
super().__init__()

in_channels = backbone_channels[1]
self.conv_bn_relu = layers.ConvBNReLU(
in_channels, gru_channels, 1, bias_attr=False)
self.gru_module = GruModule(
num_input=gru_channels,
num_state=gru_num_state,
num_node=gru_num_node)

self.dropout = nn.Dropout(0.1)
self.classifier = nn.Conv2D(512, num_classes, kernel_size=1)
self.auxlayer = layers.AuxLayer(
in_channels=backbone_channels[0],
inter_channels=backbone_channels[0] // 4,
out_channels=num_classes)

self.backbone_indices = backbone_indices
self.enable_auxiliary_loss = enable_auxiliary_loss

def forward(self, feat_list):

logit_list = []
x = feat_list[self.backbone_indices[1]]

feature = self.conv_bn_relu(x)
gru_output = self.gru_module(feature)
output = self.dropout(gru_output)
logit = self.classifier(output)
logit_list.append(logit)

if self.enable_auxiliary_loss:
low_level_feat = feat_list[self.backbone_indices[0]]
auxiliary_logit = self.auxlayer(low_level_feat)
logit_list.append(auxiliary_logit)

return logit_list


class GCN(nn.Layer):
def __init__(self, num_state, num_node, bias=False):
super(GCN, self).__init__()
self.conv1 = nn.Conv1D(num_node, num_node, kernel_size=1)
self.relu = nn.ReLU()
self.conv2 = nn.Conv1D(
num_state, num_state, kernel_size=1, bias_attr=bias)

def forward(self, x):
h = self.conv1(paddle.transpose(x, perm=(0, 2, 1)))
h = paddle.transpose(h, perm=(0, 2, 1))
h = h + x
h = self.relu(self.conv2(h))
return h


class GruModule(nn.Layer):
def __init__(self,
num_input=512,
num_state=128,
num_node=64,
normalize=False):
super(GruModule, self).__init__()
self.normalize = normalize
self.num_state = num_state
self.num_node = num_node
self.reduction_dim = nn.Conv2D(num_input, num_state, kernel_size=1)
self.projection_mat = nn.Conv2D(num_input, num_node, kernel_size=1)
self.gcn = GCN(num_state=self.num_state, num_node=self.num_node)
self.extend_dim = nn.Conv2D(
self.num_state, num_input, kernel_size=1, bias_attr=False)
self.extend_bn = nn.SyncBatchNorm(num_input, epsilon=1e-4)

def forward(self, input):
n, c, h, w = input.shape
# B, C, H, W
reduction_dim = self.reduction_dim(input)
# B, N, H, W
mat_B = self.projection_mat(input)
# B, C, H*W
reshaped_reduction = paddle.reshape(
reduction_dim, shape=[n, self.num_state, h * w])
# B, N, H*W
reshaped_B = paddle.reshape(mat_B, shape=[n, self.num_node, h * w])
# B, N, H*W
reproject = reshaped_B
# B, C, N
node_state_V = paddle.matmul(
reshaped_reduction, paddle.transpose(
reshaped_B, perm=[0, 2, 1]))

if self.normalize:
node_state_V = node_state_V * (1. / reshaped_reduction.shape[2])

# B, C, N
gcn_out = self.gcn(node_state_V)
# B, C, H*W
Y = paddle.matmul(gcn_out, reproject)
# B, C, H, W
Y = paddle.reshape(Y, shape=[n, self.num_state, h, w])
Y_extend = self.extend_dim(Y)
Y_extend = self.extend_bn(Y_extend)

out = input + Y_extend
return out

0 comments on commit 3fd33d5

Please sign in to comment.