-
Notifications
You must be signed in to change notification settings - Fork 194
/
Copy pathbidirectional_rnn.py
94 lines (79 loc) · 3.26 KB
/
bidirectional_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tensorflow as tf
from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.layers import fully_connected
class DynamicBidirectionalRnn(object):
  """Bidirectional RNN over batch-major inputs via tf.nn.bidirectional_dynamic_rnn.

  Concatenates forward/backward outputs along the feature axis, optionally
  applies kernel-weight regularization, an FC output projection, and
  per-timestep activation summaries.
  """

  def __init__(self,
               fw_cell,
               bw_cell,
               rnn_regularizer=None,
               num_output_units=None,
               fc_hyperparams=None,
               summarize_activations=False):
    """Constructor.

    Args:
      fw_cell: forward RNNCell.
      bw_cell: backward RNNCell.
      rnn_regularizer: optional regularizer applied to the cells' kernel
        weights. If None, no regularization loss is added.
      num_output_units: optional int; if truthy, outputs are projected to
        this many units with a ReLU fully-connected layer.
      fc_hyperparams: arg_scope hyperparameters for the FC projection.
      summarize_activations: if True, add a histogram summary per timestep.
    """
    self._fw_cell = fw_cell
    self._bw_cell = bw_cell
    self._rnn_regularizer = rnn_regularizer
    self._num_output_units = num_output_units
    self._fc_hyperparams = fc_hyperparams
    self._summarize_activations = summarize_activations

  def predict(self, inputs, scope=None):
    """Runs the bidirectional RNN.

    Args:
      inputs: float32 tensor, batch-major [batch, time, depth] (time_major=False).
      scope: optional variable scope.

    Returns:
      Tensor [batch, time, out_depth]: concatenated fw/bw outputs, optionally
      projected to num_output_units.
    """
    with tf.variable_scope(scope, 'BidirectionalRnn', [inputs]) as scope:
      (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
          self._fw_cell, self._bw_cell, inputs,
          time_major=False, dtype=tf.float32)
      rnn_outputs = tf.concat([output_fw, output_bw], axis=2)

      # Regularize only the kernel (not bias) weights of both cells.
      # Guard: rnn_regularizer defaults to None, and
      # apply_regularization(None, ...) would fail.
      if self._rnn_regularizer is not None:
        filter_kernels = lambda weights: [
            w for w in weights if w.op.name.endswith('kernel')]
        tf.contrib.layers.apply_regularization(
            self._rnn_regularizer,
            filter_kernels(self._fw_cell.trainable_weights))
        tf.contrib.layers.apply_regularization(
            self._rnn_regularizer,
            filter_kernels(self._bw_cell.trainable_weights))

      # Truthiness check instead of `> 0`: num_output_units defaults to None,
      # and `None > 0` raises TypeError on Python 3.
      if self._num_output_units:
        with arg_scope(self._fc_hyperparams):
          rnn_outputs = fully_connected(
              rnn_outputs, self._num_output_units, activation_fn=tf.nn.relu)

      if self._summarize_activations:
        # Requires a statically known time dimension.
        max_time = rnn_outputs.get_shape()[1].value
        for t in range(max_time):
          activation_t = rnn_outputs[:, t, :]
          tf.summary.histogram(
              'Activations/{}/Step_{}'.format(scope.name, t), activation_t)
      return rnn_outputs
class StaticBidirectionalRnn(object):
  """Bidirectional RNN over batch-major inputs via tf.nn.static_bidirectional_rnn.

  Unstacks the time axis, runs the static bidirectional RNN (fw/bw outputs
  come pre-concatenated per step), optionally applies kernel-weight
  regularization, an FC output projection, and per-timestep summaries.
  """

  def __init__(self,
               fw_cell,
               bw_cell,
               rnn_regularizer=None,
               num_output_units=None,
               fc_hyperparams=None,
               summarize_activations=False):
    """Constructor.

    Args:
      fw_cell: forward RNNCell.
      bw_cell: backward RNNCell.
      rnn_regularizer: optional regularizer applied to the cells' kernel
        weights. If None, no regularization loss is added.
      num_output_units: optional int; if truthy, outputs are projected to
        this many units with a ReLU fully-connected layer.
      fc_hyperparams: arg_scope hyperparameters for the FC projection.
      summarize_activations: if True, add a histogram summary per timestep.
    """
    self._fw_cell = fw_cell
    self._bw_cell = bw_cell
    self._rnn_regularizer = rnn_regularizer
    self._num_output_units = num_output_units
    self._fc_hyperparams = fc_hyperparams
    self._summarize_activations = summarize_activations

  def predict(self, inputs, scope=None):
    """Runs the bidirectional RNN.

    Args:
      inputs: float32 tensor, batch-major [batch, time, depth]; the time
        dimension must be statically known (required by tf.unstack).
      scope: optional variable scope.

    Returns:
      Tensor [batch, time, out_depth]: per-step concatenated fw/bw outputs,
      optionally projected to num_output_units.
    """
    with tf.variable_scope(scope, 'BidirectionalRnn', [inputs]) as scope:
      inputs_list = tf.unstack(inputs, axis=1)
      outputs_list, _, _ = tf.nn.static_bidirectional_rnn(
          self._fw_cell,
          self._bw_cell,
          inputs_list,
          dtype=tf.float32)

      # Regularize only the kernel (not bias) weights of both cells.
      # Guard: rnn_regularizer defaults to None, and
      # apply_regularization(None, ...) would fail.
      if self._rnn_regularizer is not None:
        filter_kernels = lambda weights: [
            w for w in weights if w.op.name.endswith('kernel')]
        tf.contrib.layers.apply_regularization(
            self._rnn_regularizer,
            filter_kernels(self._fw_cell.trainable_weights))
        tf.contrib.layers.apply_regularization(
            self._rnn_regularizer,
            filter_kernels(self._bw_cell.trainable_weights))

      # Restack the per-step outputs into a batch-major tensor.
      rnn_outputs = tf.stack(outputs_list, axis=1)

      # Truthiness check instead of `> 0`: num_output_units defaults to None,
      # and `None > 0` raises TypeError on Python 3.
      if self._num_output_units:
        with arg_scope(self._fc_hyperparams):
          rnn_outputs = fully_connected(
              rnn_outputs,
              self._num_output_units,
              activation_fn=tf.nn.relu)

      # Fix: summarize_activations was accepted but never used here, unlike
      # the Dynamic variant — mirror its per-timestep histogram summaries.
      if self._summarize_activations:
        max_time = rnn_outputs.get_shape()[1].value
        for t in range(max_time):
          activation_t = rnn_outputs[:, t, :]
          tf.summary.histogram(
              'Activations/{}/Step_{}'.format(scope.name, t), activation_t)
      return rnn_outputs