-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPoseNetB6.py
83 lines (65 loc) · 2.93 KB
/
PoseNetB6.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Author: Anurag Ranjan
# Copyright (c) 2019, Anurag Ranjan
# All rights reserved.
# based on github.com/ClementPinard/SfMLearner-Pytorch
import torch
import torch.nn as nn
def conv(in_planes, out_planes, kernel_size=3):
    """Downsampling block: stride-2 Conv2d followed by an in-place ReLU.

    'Same'-style padding of (kernel_size-1)//2 keeps the spatial size at
    exactly half the input (for even input sizes).
    """
    pad = (kernel_size - 1) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes,
                  kernel_size=kernel_size, stride=2, padding=pad),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def upconv(in_planes, out_planes):
    """Upsampling block: 2x transposed convolution followed by in-place ReLU.

    kernel 4 / stride 2 / padding 1 exactly doubles the spatial resolution.
    """
    deconv = nn.ConvTranspose2d(in_planes, out_planes,
                                kernel_size=4, stride=2, padding=1)
    return nn.Sequential(deconv, nn.ReLU(inplace=True))
class PoseNetB6(nn.Module):
    """Pose regression CNN (SfMLearner-style).

    Takes a target frame plus ``nb_ref_imgs`` reference frames, concatenated
    along the channel axis, and regresses one 6-DoF pose vector
    (3 translation + 3 rotation parameters) per reference image.

    Args:
        nb_ref_imgs: number of reference images fed alongside the target.
    """

    def __init__(self, nb_ref_imgs=2):
        super(PoseNetB6, self).__init__()
        self.nb_ref_imgs = nb_ref_imgs
        # Channel widths of the 8 stride-2 encoder stages.
        conv_planes = [16, 32, 64, 128, 256, 256, 256, 256]
        # Input channels: 3 (RGB) per image, target + references concatenated.
        self.conv1 = conv(3 * (1 + self.nb_ref_imgs), conv_planes[0], kernel_size=7)
        self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5)
        self.conv3 = conv(conv_planes[1], conv_planes[2])
        self.conv4 = conv(conv_planes[2], conv_planes[3])
        self.conv5 = conv(conv_planes[3], conv_planes[4])
        self.conv6 = conv(conv_planes[4], conv_planes[5])
        self.conv7 = conv(conv_planes[5], conv_planes[6])
        self.conv8 = conv(conv_planes[6], conv_planes[7])
        # 1x1 conv producing 6 pose parameters per reference image.
        self.pose_pred = nn.Conv2d(conv_planes[7], 6 * self.nb_ref_imgs,
                                   kernel_size=1, padding=0)

    def init_weights(self):
        """Xavier-initialize all conv weights and zero their biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # FIX: nn.init.xavier_uniform (no underscore) was deprecated
                # and removed from torch.nn.init; use the in-place variant.
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()

    def init_mask_weights(self):
        """Re-initialize deconv layers and the mask-prediction heads.

        ``pred_mask1..pred_mask6`` are not defined by this class — presumably
        they are added by a subclass that predicts explainability masks.
        """
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
        # FIX: the original accessed self.pred_mask1..6 unconditionally and
        # raised AttributeError on a plain PoseNetB6; skip absent heads.
        for name in ("pred_mask1", "pred_mask2", "pred_mask3",
                     "pred_mask4", "pred_mask5", "pred_mask6"):
            head = getattr(self, name, None)
            if head is None:
                continue
            for m in head.modules():
                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                    nn.init.xavier_uniform_(m.weight.data)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, target_image, ref_imgs):
        """Predict relative poses.

        Args:
            target_image: tensor of shape (B, 3, H, W).
            ref_imgs: list of ``nb_ref_imgs`` tensors, each (B, 3, H, W).

        Returns:
            Tensor of shape (B, nb_ref_imgs, 6): 6-DoF pose per reference
            image, scaled by 0.01 to keep initial predictions small.
        """
        assert(len(ref_imgs) == self.nb_ref_imgs)
        # Concatenate target + references along channels (avoid shadowing
        # the builtin name `input`).
        x = torch.cat([target_image] + list(ref_imgs), 1)
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv4 = self.conv4(out_conv3)
        out_conv5 = self.conv5(out_conv4)
        out_conv6 = self.conv6(out_conv5)
        out_conv7 = self.conv7(out_conv6)
        out_conv8 = self.conv8(out_conv7)
        pose = self.pose_pred(out_conv8)
        # Global average pool over the remaining spatial dimensions.
        pose = pose.mean(3).mean(2)
        pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6)
        return pose