# model.py (forked from xiaofengShi/howto100m)
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch as th
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Joint text-video embedding network.

    Projects video features and max-pooled word embeddings into a shared
    embedding space through gated embedding units, then scores every
    text-video pair with a dot product.
    """

    def __init__(
        self,
        embd_dim=1024,
        video_dim=2048,
        n_pair=1,
        we_dim=300,
        max_words=30,
        sentence_dim=-1,
        we=None,
    ):
        super(Net, self).__init__()
        # Sentence pooling maps word embeddings (we_dim) to a fixed-size
        # sentence vector; fall back to embd_dim when no explicit
        # sentence_dim is given.
        if sentence_dim <= 0:
            self.text_pooling = Sentence_Maxpool(we_dim, embd_dim)
        else:
            self.text_pooling = Sentence_Maxpool(we_dim, sentence_dim)
        self.GU_text = Gated_Embedding_Unit(
            self.text_pooling.out_dim, embd_dim, gating=True)
        self.GU_video = Gated_Embedding_Unit(
            video_dim, embd_dim, gating=True)
        self.n_pair = n_pair
        self.embd_dim = embd_dim
        self.we = we  # optional pretrained word-embedding matrix
        self.we_dim = we_dim

    def save_checkpoint(self, path):
        th.save(self.state_dict(), path)

    def load_checkpoint(self, path, cpu=False):
        if cpu:
            # Remap CUDA tensors to CPU when loading on a machine without a GPU.
            self.load_state_dict(th.load(
                path, map_location=lambda storage, loc: storage))
        else:
            self.load_state_dict(th.load(path))

    def forward(self, video, text):
        video = self.GU_video(video)
        text = self.GU_text(self.text_pooling(text))
        # Similarity matrix: entry (i, j) is the dot product of
        # text embedding i and video embedding j.
        return th.matmul(text, video.t())


class Gated_Embedding_Unit(nn.Module):
    """Linear projection followed by optional context gating and L2 norm."""

    def __init__(self, input_dimension, output_dimension, gating=True):
        super(Gated_Embedding_Unit, self).__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = Context_Gating(output_dimension)
        self.gating = gating

    def forward(self, x):
        x = self.fc(x)
        if self.gating:
            x = self.cg(x)
        # L2-normalize so the dot product in Net.forward is a cosine similarity.
        x = F.normalize(x)
        return x


class Sentence_Maxpool(nn.Module):
    """Projects word embeddings and max-pools over the word dimension."""

    def __init__(self, word_dimension, output_dim, relu=True):
        super(Sentence_Maxpool, self).__init__()
        self.fc = nn.Linear(word_dimension, output_dim)
        self.out_dim = output_dim
        self.relu = relu

    def forward(self, x):
        # x: (batch, max_words, word_dimension)
        x = self.fc(x)
        if self.relu:
            x = F.relu(x)
        # Max over the word axis yields one vector per sentence.
        return th.max(x, dim=1)[0]


class Context_Gating(nn.Module):
    """Context gating: scales each dimension of x by sigmoid(Wx + b)."""

    def __init__(self, dimension, add_batch_norm=False):
        super(Context_Gating, self).__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        x1 = self.fc(x)
        if self.add_batch_norm:
            x1 = self.batch_norm(x1)
        # A GLU over the concatenation [x, x1] computes x * sigmoid(x1).
        x = th.cat((x, x1), 1)
        return F.glu(x, 1)
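

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the input
# conventions implied by forward(): video features of shape (batch, video_dim)
# and text as padded word embeddings of shape (batch, max_words, we_dim).
# The shapes and hyperparameters below are illustrative only.
if __name__ == "__main__":
    net = Net(embd_dim=1024, video_dim=2048, we_dim=300, max_words=30)
    net.eval()

    video = th.randn(4, 2048)    # 4 clips with 2048-d visual features
    text = th.randn(4, 30, 300)  # 4 captions, 30 words, 300-d embeddings

    with th.no_grad():
        sim = net(video, text)   # (4, 4) text-video similarity matrix
    print(sim.shape)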