reuse_modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
class Fusion(nn.Module):
    """ Crazy multi-modal fusion: negative squared difference plus relu'd sum
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # found through grad student descent ;)
        return - (x - y)**2 + F.relu(x + y)
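
# Usage sketch (shapes are illustrative assumptions, not from the original
# file): Fusion is purely elementwise, so x and y only need to broadcast
# against each other, e.g. two (batch, dim) feature tensors:
#   fuse = Fusion()
#   z = fuse(torch.randn(4, 512), torch.randn(4, 512))  # -> shape (4, 512)
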
class FCNet(nn.Module):
    def __init__(self, in_size, out_size, activate=None, drop=0.0):
        super(FCNet, self).__init__()
        # dim=None normalises the weight over the whole tensor at once
        self.lin = weight_norm(nn.Linear(in_size, out_size), dim=None)
        self.drop_value = drop
        self.drop = nn.Dropout(drop)
        # lower-case the name in case upper-case characters are used by mistake
        self.activate = activate.lower() if (activate is not None) else None
        if self.activate == 'relu':
            self.ac_fn = nn.ReLU()
        elif self.activate == 'sigmoid':
            self.ac_fn = nn.Sigmoid()
        elif self.activate == 'tanh':
            self.ac_fn = nn.Tanh()
        elif self.activate is not None:
            raise ValueError('unsupported activation: %s' % activate)

    def forward(self, x):
        if self.drop_value > 0:
            x = self.drop(x)
        x = self.lin(x)
        if self.activate is not None:
            x = self.ac_fn(x)
        return x
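
# Minimal smoke test (a sketch; the sizes below are arbitrary assumptions,
# not taken from the original repo): fuse two feature tensors elementwise,
# then project them through a dropout + weight-normed linear + ReLU layer.
if __name__ == '__main__':
    fuse = Fusion()
    fc = FCNet(in_size=512, out_size=256, activate='relu', drop=0.1)
    x = torch.randn(4, 512)
    y = torch.randn(4, 512)
    out = fc(fuse(x, y))
    print(out.shape)  # expected: torch.Size([4, 256])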