[WIP] Decoder
bgshih committed Dec 4, 2017
1 parent f5b1016 commit 8594dec
Showing 9 changed files with 196 additions and 10 deletions.
Empty file added builders/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions builders/decoder_builder.py
@@ -0,0 +1,26 @@
from rare.decoders import AttentionDecoder
from rare.protos import decoder_pb2
from rare.builders import rnn_cell_builder


def build(decoder_config):
  if not isinstance(decoder_config, decoder_pb2.Decoder):
    raise ValueError('decoder_config not of type '
                     'decoder_pb2.Decoder')
  decoder_oneof = decoder_config.WhichOneof('decoder_oneof')

  if decoder_oneof == 'attention_decoder':
    attention_decoder_config = decoder_config.attention_decoder

    rnn_cell = rnn_cell_builder.build(attention_decoder_config.rnn_cell)
    num_attention_units = attention_decoder_config.num_attention_units
    attention_conv_kernel_size = attention_decoder_config.attention_conv_kernel_size

    attention_decoder_object = AttentionDecoder(
      rnn_cell,
      num_attention_units,
      attention_conv_kernel_size
    )
    return attention_decoder_object

  else:
    raise ValueError('Unknown decoder_oneof: {}'.format(decoder_oneof))
27 changes: 27 additions & 0 deletions builders/decoder_builder_test.py
@@ -0,0 +1,27 @@
import tensorflow as tf

from google.protobuf import text_format
from rare.builders import decoder_builder
from rare.protos import decoder_pb2


class DecoderBuilderTest(tf.test.TestCase):

  def test_build_decoder(self):
    decoder_text_proto = """
    attention_decoder {
      rnn_cell {
        gru_cell {
          num_units: 256
        }
      }
      num_attention_units: 256
      attention_conv_kernel_size: 5
    }
    """
    decoder_proto = decoder_pb2.Decoder()
    text_format.Merge(decoder_text_proto, decoder_proto)
    decoder_object = decoder_builder.build(decoder_proto)

    # AttentionDecoder keeps these values as private attributes
    self.assertEqual(decoder_object._num_attention_units, 256)
    self.assertEqual(decoder_object._attention_conv_kernel_size, 5)
30 changes: 30 additions & 0 deletions builders/rnn_cell_builder.py
@@ -0,0 +1,30 @@
import tensorflow as tf
from rare.protos import rnn_cell_pb2


def build(rnn_cell_config):
  if not isinstance(rnn_cell_config, rnn_cell_pb2.RnnCell):
    raise ValueError('rnn_cell_config not of type '
                     'rnn_cell_pb2.RnnCell')
  rnn_cell_oneof = rnn_cell_config.WhichOneof('rnn_cell_oneof')

  if rnn_cell_oneof == 'lstm_cell':
    lstm_cell_config = rnn_cell_config.lstm_cell

    lstm_cell_object = tf.contrib.rnn.LSTMCell(
      lstm_cell_config.num_units,
      use_peepholes=lstm_cell_config.use_peepholes,
      forget_bias=lstm_cell_config.forget_bias
    )
    return lstm_cell_object

  elif rnn_cell_oneof == 'gru_cell':
    gru_cell_config = rnn_cell_config.gru_cell

    gru_cell_object = tf.contrib.rnn.GRUCell(
      gru_cell_config.num_units
    )
    return gru_cell_object

  else:
    raise ValueError('Unknown rnn_cell_oneof: {}'.format(rnn_cell_oneof))
37 changes: 37 additions & 0 deletions builders/rnn_cell_builder_test.py
@@ -0,0 +1,37 @@
import tensorflow as tf

from google.protobuf import text_format
from rare.builders import rnn_cell_builder
from rare.protos import rnn_cell_pb2


class RnnCellBuilderTest(tf.test.TestCase):

  def test_build_lstm_cell(self):
    rnn_cell_text_proto = """
    lstm_cell {
      num_units: 1024
      use_peepholes: true
      forget_bias: 1.5
    }
    """
    rnn_cell_proto = rnn_cell_pb2.RnnCell()
    text_format.Merge(rnn_cell_text_proto, rnn_cell_proto)
    rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto)
    # LSTMCell state is an LSTMStateTuple of (c, h) by default
    self.assertEqual(rnn_cell_object.state_size,
                     tf.contrib.rnn.LSTMStateTuple(1024, 1024))

  def test_build_gru_cell(self):
    rnn_cell_text_proto = """
    gru_cell {
      num_units: 1024
    }
    """
    rnn_cell_proto = rnn_cell_pb2.RnnCell()
    text_format.Merge(rnn_cell_text_proto, rnn_cell_proto)
    rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto)

    self.assertEqual(rnn_cell_object.state_size, 1024)


if __name__ == '__main__':
  tf.test.main()
Empty file added core/__init__.py
Empty file.
51 changes: 41 additions & 10 deletions decoders/attention_decoder.py
@@ -4,24 +4,56 @@

class AttentionDecoder():
  def __init__(self,
               rnn_cell,
               num_attention_units,
               attention_conv_kernel_size,
               is_training=True):
    self._rnn_cell = rnn_cell
    self._num_attention_units = num_attention_units
    self._attention_conv_kernel_size = attention_conv_kernel_size
    self._is_training = is_training

  def predict(self, feature_map, num_steps, decoder_inputs=None):
    if not self._is_training:
      raise RuntimeError('predict should only be called during training')
    if isinstance(feature_map, list):
      feature_map = feature_map[-1]

    batch_size = feature_map.get_shape()[0].value
    feature_map_size = tf.shape(feature_map)[1:3]

    initial_attention = tf.zeros(feature_map_size, tf.float32)
    initial_state = self._rnn_cell.zero_state(batch_size, tf.float32)
    initial_output = self._output_embedding_fn(tf.tile([symbols.GO], [batch_size]))

    outputs_list = []

    last_state = initial_state
    last_attention = initial_attention
    last_output = initial_output
    for i in range(num_steps):
      with tf.variable_scope('PredictStep_{}'.format(i), reuse=(i > 0)):
        output, new_state, attention_weights = self._predict_step(
            feature_map,
            last_state,
            last_attention,
            last_output
        )
      outputs_list.append(output)
      last_state = new_state
      last_attention = attention_weights
      last_output = self._output_embedding_fn(decoder_inputs[:, i])
    outputs = tf.concat(outputs_list, axis=1)  # => [batch_size, num_steps, output_dims]
    return outputs

  def _predict_step(self, feature_map, last_state, last_attention, last_output):
    """
    Args:
      feature_map: a float32 tensor with shape [batch_size, map_height, map_width, depth]
      last_state: a float32 tensor with shape [batch_size, ]
      last_attention: a float32 tensor with shape [batch_size, map_height, map_width, depth]
    """

    batch_size = feature_map.get_shape()[0].value
    feature_map_depth = feature_map.get_shape()[3].value
    if batch_size is None or feature_map_depth is None:
@@ -40,7 +72,6 @@
        kernel_size=self._attention_conv_kernel_size,
        stride=1,
        biases_initializer=None
    ) # => [batch_size, map_height, map_width, num_attention_units]
    ws = fully_connected(
        last_state,
@@ -57,7 +88,6 @@
        1,
        activation_fn=None,
        biases_initializer=None
    ) # => [batch_size, map_height, map_width, 1]
    attention_scores_flat = tf.reshape(
        tf.squeeze(attention_scores, axis=3),
@@ -79,5 +109,6 @@
        axis=2
    ) # [1, map_depth]

    output, new_state = self._rnn_cell(glimpse, last_state)

    return output, new_state, attention_weights
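For readers following the new _predict_step (parts of which are collapsed between the hunks above), the following is a minimal, self-contained sketch of the same style of content-based attention, written against the TF 1.x tf.contrib.layers API used in this file. It is an illustrative assumption, not code from this commit: the function name attention_step is made up, and the last_attention / last_output inputs used by the real decoder are omitted for brevity.

import tensorflow as tf
from tensorflow.contrib import layers


def attention_step(feature_map, last_state, num_attention_units, kernel_size):
  """One content-based attention step (illustrative sketch only).

  feature_map: float32 [batch, H, W, depth]
  last_state:  float32 [batch, state_size]
  Returns (glimpse [batch, depth], attention_weights [batch, H, W]).
  """
  # Project the feature map and the decoder state into a common space.
  vh = layers.conv2d(feature_map, num_attention_units,
                     kernel_size=kernel_size, stride=1,
                     activation_fn=None, biases_initializer=None)  # [batch, H, W, units]
  ws = layers.fully_connected(last_state, num_attention_units,
                              activation_fn=None)                  # [batch, units]
  ws = ws[:, tf.newaxis, tf.newaxis, :]                            # broadcast over H and W

  # Score every spatial location, then normalize with a softmax over H*W.
  scores = layers.conv2d(tf.nn.tanh(vh + ws), 1, kernel_size=1,
                         activation_fn=None, biases_initializer=None)  # [batch, H, W, 1]
  batch_size = tf.shape(feature_map)[0]
  weights_flat = tf.nn.softmax(tf.reshape(scores, [batch_size, -1]))
  attention_weights = tf.reshape(weights_flat, tf.shape(scores)[:3])   # [batch, H, W]

  # Glimpse: attention-weighted sum of the feature vectors over all locations.
  glimpse = tf.reduce_sum(
      feature_map * attention_weights[:, :, :, tf.newaxis], axis=[1, 2])  # [batch, depth]
  return glimpse, attention_weights

The glimpse plays the role of the RNN-cell input in _predict_step, and the attention weights are what the decoder feeds back in as last_attention on the next step.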
16 changes: 16 additions & 0 deletions protos/decoder.proto
@@ -0,0 +1,16 @@
syntax = "proto2";
package rare.protos;

import "rare/protos/rnn_cell.proto";

message Decoder {
  oneof decoder_oneof {
    AttentionDecoder attention_decoder = 1;
  }
}

message AttentionDecoder {
  optional RnnCell rnn_cell = 1;
  optional uint32 num_attention_units = 2 [default=128];
  optional uint32 attention_conv_kernel_size = 3 [default=5];
}
19 changes: 19 additions & 0 deletions protos/rnn_cell.proto
@@ -0,0 +1,19 @@
syntax = "proto2";
package rare.protos;

message RnnCell {
  oneof rnn_cell_oneof {
    LstmCell lstm_cell = 1;
    GruCell gru_cell = 2;
  }
}

message LstmCell {
  optional uint32 num_units = 1 [default=128];
  optional bool use_peepholes = 2 [default=false];
  optional float forget_bias = 3 [default=1.0];
}

message GruCell {
  optional uint32 num_units = 1 [default=128];
}
