
Commit a8f3ea4

Add files via upload
1 parent 4e17417 commit a8f3ea4

File tree

6 files changed

+729 -1 lines changed

C2AFNet.py

+350
@@ -0,0 +1,350 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


class Separable_conv(nn.Module):
    """Depthwise-separable 3x3 convolution."""
    def __init__(self, inp, oup):
        super(Separable_conv, self).__init__()

        self.conv = nn.Sequential(
            # depthwise 3x3 convolution
            nn.Conv2d(inp, inp, kernel_size=3, stride=1, padding=1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            # pointwise 1x1 convolution
            nn.Conv2d(inp, oup, kernel_size=1),
        )

    def forward(self, x):
        return self.conv(x)


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class vgg_rgb(nn.Module):
    """VGG16-BN feature extractor for the RGB stream."""
    def __init__(self, pretrained=True):
        super(vgg_rgb, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # first block, 224*224*64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # [:6]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, 3, 1, 1),  # second block, 112*112*128
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),  # [6:13]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, 3, 1, 1),  # third block, 56*56*256
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),  # [13:23]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, 3, 1, 1),  # fourth block, 28*28*512
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),  # [23:33]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, 3, 1, 1),  # fifth block, 14*14*512
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),  # [33:43]
        )

        if pretrained:
            # load the torchvision VGG16-BN checkpoint and copy every weight
            # whose key matches this module's state dict
            pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn'])
            model_dict = {}
            state_dict = self.state_dict()
            for k, v in pretrained_vgg.items():
                if k in state_dict:
                    model_dict[k] = v
            state_dict.update(model_dict)
            self.load_state_dict(state_dict)

    def forward(self, rgb):
        A1 = self.features[:6](rgb)
        A2 = self.features[6:13](A1)
        A3 = self.features[13:23](A2)
        A4 = self.features[23:33](A3)
        A5 = self.features[33:43](A4)
        return A1, A2, A3, A4, A5


class vgg_depth(nn.Module):
    """VGG16-BN feature extractor for the depth stream."""
    def __init__(self, pretrained=True):
        super(vgg_depth, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # first block, 224*224*64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # [:6]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, 3, 1, 1),  # second block, 112*112*128
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),  # [6:13]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, 3, 1, 1),  # third block, 56*56*256
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),  # [13:23]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, 3, 1, 1),  # fourth block, 28*28*512
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),  # [23:33]
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, 3, 1, 1),  # fifth block, 14*14*512
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),  # [33:43]
        )

        if pretrained:
            # same matching-key weight copy from the VGG16-BN checkpoint
            pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn'])
            model_dict = {}
            state_dict = self.state_dict()
            for k, v in pretrained_vgg.items():
                if k in state_dict:
                    model_dict[k] = v
            state_dict.update(model_dict)
            self.load_state_dict(state_dict)

    def forward(self, thermal):
        A1_d = self.features[:6](thermal)
        A2_d = self.features[6:13](A1_d)
        A3_d = self.features[13:23](A2_d)
        A4_d = self.features[23:33](A3_d)
        A5_d = self.features[33:43](A4_d)
        return A1_d, A2_d, A3_d, A4_d, A5_d


class Hsigmoid(nn.Module):
    """Hard sigmoid: relu6(x + 3) / 6."""
    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3., inplace=self.inplace) / 6.


class Spatical_Fuse_attention3_GHOST(nn.Module):  # output stays rgb: x is rgb, y is depth; with an identity shortcut
    def __init__(self, in_channels):
        super(Spatical_Fuse_attention3_GHOST, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, 3, 1, 1)
        self.active = Hsigmoid()

    def forward(self, x, y):
        # single-channel spatial attention map computed from y, applied to x
        input_y = self.conv(y)
        input_y = self.active(input_y)
        return x + x * input_y


class Channel_Fuse_attention2(nn.Module):  # output stays depth: x is depth, y is rgb; with an identity shortcut
    def __init__(self, channel, reduction=4):
        super(Channel_Fuse_attention2, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.Linear(channel // reduction, channel, bias=False),
            Hsigmoid()
        )

    def forward(self, x, y):
        # squeeze-and-excitation style channel attention computed from y, applied to x
        b, c, _, _ = x.size()
        y = self.avg_pool(y).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x + x * y.expand_as(x)


class Gatefusion3(nn.Module):
    def __init__(self, channel):
        super(Gatefusion3, self).__init__()
        self.channel = channel
        self.gate = nn.Sigmoid()

    def forward(self, x, y, fusion_up):
        # gate the two modalities against each other, then reweight the fused
        # map by its difference from the upsampled higher-level fusion
        first_fusion = torch.cat((x, y), dim=1)
        gate_fusion = self.gate(first_fusion)
        gate_fusion = torch.split(gate_fusion, self.channel, dim=1)
        fusion_x = gate_fusion[0] * x + x
        fusion_y = gate_fusion[1] * y + y
        fusion = fusion_x + fusion_y
        fusion = torch.abs(fusion - fusion_up) * fusion + fusion
        return fusion


class Gatefusion3_fusionup(nn.Module):
    def __init__(self, channel):
        super(Gatefusion3_fusionup, self).__init__()
        self.channel = channel
        self.gate = nn.Sigmoid()

    def forward(self, x, y):
        # same gating as Gatefusion3, without a higher-level fusion input
        first_fusion = torch.cat((x, y), dim=1)
        gate_fusion = self.gate(first_fusion)
        gate_fusion = torch.split(gate_fusion, self.channel, dim=1)
        fusion_x = gate_fusion[0] * x + x
        fusion_y = gate_fusion[1] * y + y
        fusion = fusion_x + fusion_y
        return fusion


class CCAFNet(nn.Module):
    def __init__(self):
        super(CCAFNet, self).__init__()
        # rgb and depth encoders
        self.rgb_pretrained = vgg_rgb()
        self.depth_pretrained = vgg_depth()

        # rgb fuse model: depth-guided spatial attention at the shallow levels
        self.SAG1 = Spatical_Fuse_attention3_GHOST(64)
        self.SAG2 = Spatical_Fuse_attention3_GHOST(128)
        self.SAG3 = Spatical_Fuse_attention3_GHOST(256)

        # depth fuse model: rgb-guided channel attention at the deep levels
        self.CAG4 = Channel_Fuse_attention2(512)
        self.CAG5 = Channel_Fuse_attention2(512)

        self.gatefusion5 = Gatefusion3_fusionup(512)
        self.gatefusion4 = Gatefusion3(512)
        self.gatefusion3 = Gatefusion3(256)
        self.gatefusion2 = Gatefusion3(128)
        self.gatefusion1 = Gatefusion3(64)

        # upsample model (decoder)
        self.upsample1 = nn.Sequential(nn.Conv2d(288, 144, 3, 1, 1), nn.BatchNorm2d(144), nn.ReLU())
        self.upsample2 = nn.Sequential(nn.Conv2d(448, 224, 3, 1, 1), nn.BatchNorm2d(224), nn.ReLU(),
                                       nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample3 = nn.Sequential(nn.Conv2d(640, 320, 3, 1, 1), nn.BatchNorm2d(320), nn.ReLU(),
                                       nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample4 = nn.Sequential(nn.Conv2d(768, 384, 3, 1, 1), nn.BatchNorm2d(384), nn.ReLU(),
                                       nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample5 = nn.Sequential(nn.Conv2d(512, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
                                       nn.UpsamplingBilinear2d(scale_factor=2))

        # comparison branches: upsample each level's fusion for the next, lower level
        self.upsample5_4 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
                                         nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample4_3 = nn.Sequential(nn.Conv2d(768, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
                                         nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample3_2 = nn.Sequential(nn.Conv2d(640, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
                                         nn.UpsamplingBilinear2d(scale_factor=2))
        self.upsample2_1 = nn.Sequential(nn.Conv2d(448, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(),
                                         nn.UpsamplingBilinear2d(scale_factor=2))

        # 1x1 prediction heads: final map plus four side outputs
        self.conv = nn.Conv2d(144, 1, 1)
        self.conv2 = nn.Conv2d(224, 1, 1)
        self.conv3 = nn.Conv2d(320, 1, 1)
        self.conv4 = nn.Conv2d(384, 1, 1)
        self.conv5 = nn.Conv2d(256, 1, 1)

    def forward(self, rgb, depth):
        # rgb features
        A1, A2, A3, A4, A5 = self.rgb_pretrained(rgb)
        # depth features
        A1_d, A2_d, A3_d, A4_d, A5_d = self.depth_pretrained(depth)

        SAG1_R = self.SAG1(A1, A1_d)
        SAG2_R = self.SAG2(A2, A2_d)
        SAG3_R = self.SAG3(A3, A3_d)

        CAG5_D = self.CAG5(A5_d, A5)
        CAG4_D = self.CAG4(A4_d, A4)

        F5 = self.gatefusion5(A5, CAG5_D)
        F5_UP = self.upsample5_4(F5)
        F5 = self.upsample5(F5)  # 14*14 -> 28*28
        F4 = self.gatefusion4(A4, CAG4_D, F5_UP)
        F4 = torch.cat((F4, F5), dim=1)
        F4_UP = self.upsample4_3(F4)
        F4 = self.upsample4(F4)  # 28*28 -> 56*56
        F3 = self.gatefusion3(SAG3_R, A3_d, F4_UP)
        F3 = torch.cat((F3, F4), dim=1)
        F3_UP = self.upsample3_2(F3)
        F3 = self.upsample3(F3)  # 56*56 -> 112*112
        F2 = self.gatefusion2(SAG2_R, A2_d, F3_UP)
        F2 = torch.cat((F2, F3), dim=1)
        F2_UP = self.upsample2_1(F2)
        F2 = self.upsample2(F2)  # 112*112 -> 224*224
        F1 = self.gatefusion1(SAG1_R, A1_d, F2_UP)
        F1 = torch.cat((F1, F2), dim=1)
        F1 = self.upsample1(F1)  # 224*224, no further upsampling
        out = self.conv(F1)

        out5 = self.conv5(F5)
        out4 = self.conv4(F4)
        out3 = self.conv3(F3)
        out2 = self.conv2(F2)

        if self.training:
            return out, out2, out3, out4, out5
        return out


if __name__ == '__main__':
    model = CCAFNet()
    rgb = torch.randn(1, 3, 224, 224)
    depth = torch.randn(1, 3, 224, 224)
    out = model(rgb, depth)  # training mode by default, so this is the 5-tuple
    for i in out:
        print(i.shape)
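
Note on training-time use: the training loop itself is not part of this commit, but since forward returns five side outputs in training mode (the final 224*224 map plus four coarser maps), a deep-supervision loss over all five is the natural consumer. The following is a minimal sketch under that assumption; the module name C2AFNet, the nearest-neighbour resizing of the ground truth, and the BCE-with-logits loss are choices of this sketch, not confirmed by the repository:

import torch
import torch.nn.functional as F
from C2AFNet import CCAFNet  # module name taken from the uploaded file

model = CCAFNet()  # training mode by default, so forward returns five maps
rgb = torch.randn(2, 3, 224, 224)
depth = torch.randn(2, 3, 224, 224)
gt = torch.randint(0, 2, (2, 1, 224, 224)).float()  # hypothetical binary saliency mask

outs = model(rgb, depth)  # (out, out2, out3, out4, out5): logits at 224, 224, 112, 56, 28
loss = sum(
    F.binary_cross_entropy_with_logits(
        pred, F.interpolate(gt, size=pred.shape[2:], mode='nearest'))  # match each side output's resolution
    for pred in outs)
loss.backward()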

README.md

+13 -1
@@ -1 +1,13 @@
# CCAFNet

Code and results for CCAFNet (TMM), an RGB-D salient object detection network.<br>
'CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images'

# Citation

@inproceedings{fan2020bbsnet, <br>
title={BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network},<br>
author={Fan, Deng-Ping and Zhai, Yingjie and Borji, Ali and Yang, Jufeng and Shao, Ling},<br>
booktitle={ECCV},<br>
year={2020}}<br>

# Acknowledgement

This project is implemented based on the code of 'Cascaded Partial Decoder for Fast and Accurate Salient Object Detection' (CVPR 2019) by Wu et al. and 'BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network' by Fan et al.

config.py

+20
@@ -0,0 +1,20 @@
import argparse

parser = argparse.ArgumentParser()
# train/val
parser.add_argument('--epoch', type=int, default=200, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
parser.add_argument('--trainsize', type=int, default=224, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate')
parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
parser.add_argument('--train_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/train/NJUNLPR', help='the train images root')
parser.add_argument('--val_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/val', help='the val images root')
parser.add_argument('--save_path', type=str, default='/media/zy/shuju/RGBDweight/PVTbackbone_SC2/', help='the path to save models and logs')
# test (predict)
parser.add_argument('--testsize', type=int, default=224, help='testing size')
parser.add_argument('--test_path', type=str, default='/home/zy/PycharmProjects/zy/newdata/test/', help='test dataset path')
# parser.add_argument('--test_path', type=str, default='/home/zy/PycharmProjects/zy/DUT-RGBD/test_data/', help='test dataset path')
opt = parser.parse_args()
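
The script that consumes these flags is not included in this commit; below is a minimal sketch of the usual wiring, assuming the model file and class names from above. The checkpoint loading and the scheduler are illustrative assumptions, not the authors' training script:

import os
import torch
from config import opt  # parses the command-line flags defined above
from C2AFNet import CCAFNet  # module/class names from the uploaded files

os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id  # select the GPU before CUDA initializes
model = CCAFNet()
if opt.load is not None:
    model.load_state_dict(torch.load(opt.load))  # resume from a checkpoint (hypothetical)
optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=opt.decay_epoch, gamma=opt.decay_rate)  # decay lr by decay_rate every decay_epoch epochs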
