-
Notifications
You must be signed in to change notification settings - Fork 11
/
model.py
139 lines (118 loc) · 4.87 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# This module is mostly borrowed from Olivecrona's original implementation
# (https://github.com/MarcusOlivecrona/REINVENT).
# We adapted it to PyTorch 0.4.0.
#
# Shuangjia Zheng, Aug 2018
#
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_struct import Variable
class MultiGRU(nn.Module):
    """Three stacked GRU cells between an embedding layer and a linear
    output layer that maps back to the size of the vocabulary.

    Args:
        voc_size: Number of tokens in the vocabulary. Used as the number of
            embedding rows and as the width of the output logits.
        embedding_size: Width of the token embeddings. Defaults to 128.
        hidden_size: Width of the hidden state of each GRU cell.
            Defaults to 512.
    """

    def __init__(self, voc_size, embedding_size=128, hidden_size=512):
        super(MultiGRU, self).__init__()
        # Attribute names are kept as-is so existing state dicts still load.
        self.embedding = nn.Embedding(voc_size, embedding_size)
        self.gru_1 = nn.GRUCell(embedding_size, hidden_size)
        self.gru_2 = nn.GRUCell(hidden_size, hidden_size)
        self.gru_3 = nn.GRUCell(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, voc_size)

    def forward(self, x, h):
        """Run one time step through the three GRU cells.

        Args:
            x: Token indices for the current step, one per batch element.
            h: Hidden states; ``h[0]``, ``h[1]`` and ``h[2]`` are fed to the
                first, second and third GRU cell respectively.

        Returns:
            A tuple ``(logits, h_out)``: the output of the linear layer
            applied to the last cell's state, and the three new hidden
            states stacked along a new leading dimension.
        """
        embedded = self.embedding(x)
        h_1 = self.gru_1(embedded, h[0])
        h_2 = self.gru_2(h_1, h[1])
        h_3 = self.gru_3(h_2, h[2])
        logits = self.linear(h_3)
        # Stacking the per-layer states keeps the result on the same device
        # and dtype as the cells' outputs, without writing in place into a
        # separately allocated buffer.
        return logits, torch.stack((h_1, h_2, h_3))

    def init_h(self, batch_size):
        """Return an all-zero initial state of shape (3, batch_size, hidden).

        The tensor is created from the model's own parameters so it lives on
        the same device, and has the same dtype, as the network.
        """
        return self.linear.weight.new_zeros(
            3, batch_size, self.gru_1.hidden_size)
class RNN():
    """Implements the Prior and Agent RNN.

    Needs a vocabulary object: ``voc.vocab_size`` sizes the network, and
    ``voc.vocab`` is indexed with ``'START'`` and ``'END'`` to obtain the
    indices of those tokens.
    """

    def __init__(self, voc):
        self.rnn = MultiGRU(voc.vocab_size)
        # The network is moved to the GPU whenever one is available.
        if torch.cuda.is_available():
            self.rnn.cuda()
        self.voc = voc

    def likelihood(self, target):
        """
        Retrieves the likelihood of a given batch of sequences.

        Args:
            target: (batch_size, sequence_length) A batch of sequences of
                token indices.

        Outputs:
            log_probs : (batch_size) Log likelihood for each example.
            entropy: (batch_size) The entropies for the sequences. Not
                currently used.
        """
        batch_size, seq_length = target.size()
        start_token = Variable(torch.zeros(batch_size, 1).long())
        start_token[:] = self.voc.vocab['START']
        # Network input is the target shifted right by one, with START in
        # front: step t sees token t-1 and is scored against token t.
        x = torch.cat((start_token, target[:, :-1]), 1)
        h = self.rnn.init_h(batch_size)

        log_probs = Variable(torch.zeros(batch_size))
        entropy = Variable(torch.zeros(batch_size))
        for step in range(seq_length):
            logits, h = self.rnn(x[:, step], h)
            # logits are (batch, vocabulary); normalise over the vocabulary
            # axis explicitly instead of relying on the deprecated implicit
            # dimension choice.
            log_prob = F.log_softmax(logits, dim=1)
            prob = F.softmax(logits, dim=1)
            log_probs += NLLLoss(log_prob, target[:, step])
            entropy += -torch.sum((log_prob * prob), 1)
        return log_probs, entropy

    def sample(self, batch_size, max_length=150):
        """
        Sample a batch of sequences.

        Args:
            batch_size : Number of sequences to sample.
            max_length: Maximum length of the sequences.

        Outputs:
            seqs: (batch_size, seq_length) The sampled sequences.
            log_probs : (batch_size) Log likelihood for each sequence.
            entropy: (batch_size) The entropies for the sequences. Not
                currently used.

        NOTE: rows that have already produced END are not masked. They keep
        being sampled, and keep contributing to ``log_probs`` and
        ``entropy``, until every row has produced END or ``max_length``
        steps have been taken.
        """
        start_token = Variable(torch.zeros(batch_size).long())
        start_token[:] = self.voc.vocab['START']
        h = self.rnn.init_h(batch_size)
        x = start_token

        sequences = []
        log_probs = Variable(torch.zeros(batch_size))
        finished = torch.zeros(batch_size).byte()
        entropy = Variable(torch.zeros(batch_size))
        if torch.cuda.is_available():
            finished = finished.cuda()

        for step in range(max_length):
            logits, h = self.rnn(x, h)
            # Explicit vocabulary axis; see likelihood().
            prob = F.softmax(logits, dim=1)
            log_prob = F.log_softmax(logits, dim=1)
            x = torch.multinomial(prob, 1).view(-1)
            sequences.append(x.view(-1, 1))
            log_probs += NLLLoss(log_prob, x)
            entropy += -torch.sum((log_prob * prob), 1)

            # Feed the sampled tokens back in as the next input.
            x = Variable(x.data)
            EOS_sampled = (x == self.voc.vocab['END']).data
            # A row is finished once it has produced END at any step.
            finished = torch.ge(finished + EOS_sampled, 1)
            if torch.prod(finished) == 1:
                break

        sequences = torch.cat(sequences, 1)
        return sequences.data, log_probs, entropy
def NLLLoss(inputs, targets):
    """
    Per-example log probability of the target class.

    Despite the name, the value is NOT negated: for every row it is the
    entry of `inputs` selected by `targets`, i.e. the log likelihood of the
    target class, returned per example rather than reduced over the batch.

    Args:
        inputs : (batch_size, num_classes) *Log probabilities of each class*
        targets: (batch_size) *Target class index* (integer tensor on the
                 same device as `inputs`)

    Outputs:
        loss : (batch_size) *Log probability of the target class for each
               example*
    """
    # Picking the target column with gather avoids building a dense one-hot
    # mask, follows the device of `inputs` instead of guessing from CUDA
    # availability, and does not turn a -inf log probability of a non-target
    # class into NaN (0 * -inf) the way a one-hot multiply would.
    index = targets.detach().reshape(-1, 1)
    return inputs.gather(1, index).squeeze(1)