-
Notifications
You must be signed in to change notification settings - Fork 5
/
single_proc_qanta.py
165 lines (117 loc) · 4.29 KB
/
single_proc_qanta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from numpy import *
from util.gen_util import *
from util.math_util import *
from util.dtree_util import *
from rnn.propagation import *
from rnn.gradient_check import *
import cPickle
## this class is mainly for developing and checking model correctness
## the parallelized version (qanta.py) should be used for large datasets
# does forward propagation and computes error, no backprop
# useful for gradient check
def objective(rng, data, params, d, len_voc, rel_list, lambdas):
    """Forward-propagate every tree in *data* and return the regularized cost.

    No backpropagation is performed, which makes this the cheap "f(x)"
    half of a numerical gradient check (see gradient_check).

    Args:
        rng: numpy RandomState consumed by forward_prop.
        data: iterable of dependency trees (dtree_util trees).
        params: flat (rolled) parameter vector.
        d: word-embedding dimensionality.
        len_voc: vocabulary size.
        rel_list: dependency relation names keying rel_dict.
        lambdas: [lambda_W, lambda_L] regularization strengths.

    Returns:
        float: average per-node error plus L2 regularization cost.
    """
    params = unroll_params(params, d, len_voc, rel_list)
    (rel_dict, Wv, b, L) = params

    error_sum = 0.0
    # total node count across all trees; used to average the error
    tree_size = 0.0

    for tree in data:
        nodes = tree.get_nodes()

        # look up the embedding column for every node and for the answer
        for node in nodes:
            node.vec = L[:, node.ind].reshape((d, 1))
        tree.ans_vec = L[:, tree.ans_ind].reshape((d, 1))

        forward_prop(rng, params, tree, d)
        error_sum += tree.error()
        tree_size += len(nodes)

    # L2 regularization over relation matrices, Wv, and the embedding matrix
    # (the bias b is deliberately not regularized)
    [lambda_W, lambda_L] = lambdas
    reg_cost = 0.0
    for key in rel_list:
        reg_cost += 0.5 * lambda_W * sum(rel_dict[key] ** 2)
    reg_cost += 0.5 * lambda_W * sum(Wv ** 2)
    reg_cost += 0.5 * lambda_L * sum(L ** 2)

    cost = error_sum / tree_size + reg_cost
    return cost
# does both forward and backprop
def objective_and_grad(rng, data, params, d, len_voc, rel_list, lambdas):
    """Forward- and back-propagate every tree; return (cost, gradient).

    Args:
        rng: numpy RandomState consumed by forward_prop.
        data: iterable of dependency trees (dtree_util trees).
        params: flat (rolled) parameter vector.
        d: word-embedding dimensionality.
        len_voc: vocabulary size.
        rel_list: dependency relation names keying rel_dict.
        lambdas: [lambda_W, lambda_L] regularization strengths.

    Returns:
        (float, ndarray): regularized cost and the flat (rolled) gradient,
        both averaged over the total number of nodes.
    """
    params = unroll_params(params, d, len_voc, rel_list)
    grads = init_dtrnn_grads(rel_list, d, len_voc)
    (rel_dict, Wv, b, L) = params

    error_sum = 0.0
    # total node count across all trees; used to average error and grads
    tree_size = 0

    for index, tree in enumerate(data):
        nodes = tree.get_nodes()

        # look up the embedding column for every node and for the answer
        for node in nodes:
            node.vec = L[:, node.ind].reshape((d, 1))
        tree.ans_vec = L[:, tree.ans_ind].reshape((d, 1))

        forward_prop(rng, params, tree, d)
        error_sum += tree.error()
        tree_size += len(nodes)

        # accumulate gradients in place; L (params[-1]) is handled via grads
        backprop(params[:-1], tree, d, len_voc, grads)

    # average the gradients and add L2 regularization terms
    # (the bias gradient grads[2] is averaged but not regularized)
    [lambda_W, lambda_L] = lambdas
    reg_cost = 0.0
    for key in rel_list:
        reg_cost += 0.5 * lambda_W * sum(rel_dict[key] ** 2)
        grads[0][key] = grads[0][key] / tree_size
        grads[0][key] += lambda_W * rel_dict[key]

    reg_cost += 0.5 * lambda_W * sum(Wv ** 2)
    grads[1] = grads[1] / tree_size
    grads[1] += lambda_W * Wv

    grads[2] = grads[2] / tree_size

    reg_cost += 0.5 * lambda_L * sum(L ** 2)
    grads[3] = grads[3] / tree_size
    grads[3] += lambda_L * L

    cost = error_sum / tree_size + reg_cost
    grad = roll_params(grads, rel_list)
    return cost, grad
# loads a small dataset and checks the gradients
if __name__ == '__main__':
# word embedding dimension
d = 5
rng = random.RandomState(0)
# regularization lambdas: [lambda_W, lambda_L]
lambdas = [1e-4, 1e-3]
# load small dataset for developing
trees = cPickle.load(open('data/toy_dtrees', 'rb'))
# populate vocabulary, relation, and answer list
vocab = []
ans_list = []
rel_list = []
for tree in trees:
for node in tree.get_nodes():
word = node.word.lower()
if word not in vocab:
vocab.append(word)
for ind, rel in node.kids:
if rel not in rel_list:
rel_list.append(rel)
if tree.ans.lower() not in vocab:
vocab.append(tree.ans.lower())
ans_ind = vocab.index(tree.ans.lower())
if ans_ind not in ans_list:
ans_list.append(ans_ind)
ans_list = array(ans_list)
# we don't need the "extra" root node
rel_list.remove('root')
print 'found', len(rel_list), 'dependency relations:'
print rel_list
# generate params / We
params = gen_dtrnn_params(rng, d, rel_list)
rel_list = params[0].keys()
orig_We = gen_rand_we(rng, len(vocab), d)
# add We matrix to params
params += (orig_We, )
r = roll_params(params, rel_list)
dim = r.shape[0]
# add vocab lookup to leaves / answer
print 'adding lookup'
for tree in trees:
for node in tree.get_nodes():
node.ind = vocab.index(node.word.lower())
tree.ans_ind = vocab.index(tree.ans)
tree.ans_list = ans_list[ans_list != tree.ans_ind]
# check gradient
f_args = [trees, r, d, len(vocab), rel_list, lambdas]
gradient_check(objective_and_grad, objective, dim, f_args, rng)