Commit 6dd64ec

Sync attention predictor and LM

bgshih committed Dec 29, 2017
1 parent 7fb75b8 · commit 6dd64ec

Showing 5 changed files with 210 additions and 156 deletions.
37 changes: 18 additions & 19 deletions builders/predictor_builder.py
@@ -7,7 +7,7 @@
 from rare.builders import loss_builder
 from rare.builders import hyperparams_builder
 from rare.predictors import attention_predictor
-from rare.predictors import attention_predictor_with_lm
+# from rare.predictors import attention_predictor_with_lm


 def build(config, is_training):
@@ -21,26 +21,25 @@ def build(config, is_training):
     rnn_regularizer_object = hyperparams_builder._build_regularizer(predictor_config.rnn_regularizer)
     label_map_object = label_map_builder.build(predictor_config.label_map)
     loss_object = loss_builder.build(predictor_config.loss)
-    kwargs = {
-      'rnn_cell': rnn_cell_object,
-      'rnn_regularizer': rnn_regularizer_object,
-      'num_attention_units': predictor_config.num_attention_units,
-      'max_num_steps': predictor_config.max_num_steps,
-      'multi_attention': predictor_config.multi_attention,
-      'beam_width': predictor_config.beam_width,
-      'reverse': predictor_config.reverse,
-      'label_map': label_map_object,
-      'loss': loss_object,
-      'is_training': is_training,
-      'sync': predictor_config.sync
-    }
     if not predictor_config.HasField('lm_rnn_cell'):
-      predictor_class = attention_predictor.AttentionPredictor
+      lm_rnn_cell_object = None
     else:
-      predictor_class = attention_predictor_with_lm.AttentionPredictorWithLanguageModel
-      kwargs['lm_rnn_cell'] = _build_language_model_rnn_cell(predictor_config.lm_rnn_cell)
-
-    attention_predictor_object = predictor_class(**kwargs)
+      lm_rnn_cell_object = _build_language_model_rnn_cell(predictor_config.lm_rnn_cell)
+
+    attention_predictor_object = attention_predictor.AttentionPredictor(
+      rnn_cell=rnn_cell_object,
+      rnn_regularizer=rnn_regularizer_object,
+      num_attention_units=predictor_config.num_attention_units,
+      max_num_steps=predictor_config.max_num_steps,
+      multi_attention=predictor_config.multi_attention,
+      beam_width=predictor_config.beam_width,
+      reverse=predictor_config.reverse,
+      label_map=label_map_object,
+      loss=loss_object,
+      sync=predictor_config.sync,
+      lm_rnn_cell=lm_rnn_cell_object,
+      is_training=is_training
+    )
     return attention_predictor_object
   else:
     raise ValueError('Unknown predictor_oneof: {}'.format(predictor_oneof))
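With this change both configurations go through the single AttentionPredictor class, and lm_rnn_cell_object stays None when the proto has no lm_rnn_cell field. The helper _build_language_model_rnn_cell is not shown in this diff; below is a minimal sketch of what it plausibly does, assuming the lm_rnn_cell message stacks its repeated rnn_cell entries into one cell (the rnn_cell_builder module name is likewise an assumption, not confirmed by the diff):

# Hypothetical sketch only; this helper's body is not part of the commit.
from tensorflow.contrib import rnn
from rare.builders import rnn_cell_builder  # assumed module

def _build_language_model_rnn_cell(lm_rnn_cell_config):
  # One cell per repeated `rnn_cell` entry, stacked into a MultiRNNCell,
  # matching the two-layer lm_rnn_cell proto used in the tests below.
  rnn_cells = [rnn_cell_builder.build(cell_config)
               for cell_config in lm_rnn_cell_config.rnn_cell]
  return rnn.MultiRNNCell(rnn_cells)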
164 changes: 115 additions & 49 deletions builders/predictor_builder_test.py
@@ -25,7 +25,7 @@ def test_predictor_builder(self):
       reverse: false
       label_map {
         character_set {
-          text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+          text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
           delimiter: ""
         }
         label_offset: 2
@@ -74,7 +74,7 @@ def test_predictor_with_lm_builder(self):
       reverse: false
       label_map {
         character_set {
-          text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+          text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
           delimiter: ""
         }
         label_offset: 2
@@ -121,55 +121,121 @@ def test_predictor_with_lm_builder(self):
       sess_outputs = sess.run({'loss': loss})
       print(sess_outputs)
 
-  # def test_sync_predictor_builder(self):
-  #   predictor_text_proto = """
-  #   attention_predictor {
-  #     rnn_cell {
-  #       lstm_cell {
-  #         num_units: 256
-  #         forget_bias: 1.0
-  #         initializer { orthogonal_initializer { } }
-  #       }
-  #     }
-  #     rnn_regularizer { l2_regularizer { weight: 1e-4 } }
-  #     num_attention_units: 128
-  #     max_num_steps: 10
-  #     multi_attention: false
-  #     beam_width: 1
-  #     reverse: false
-  #     label_map {
-  #       character_set {
-  #         text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-  #         delimiter: ""
-  #       }
-  #       label_offset: 2
-  #     }
-  #     loss {
-  #       sequence_cross_entropy_loss {
-  #         sequence_normalize: false
-  #         sample_normalize: true
-  #       }
-  #     }
-  #     sync: true
-  #   }
-  #   """
-  #   predictor_proto = predictor_pb2.Predictor()
-  #   text_format.Merge(predictor_text_proto, predictor_proto)
-  #   predictor_object = predictor_builder.build(predictor_proto, True)
+  def test_sync_predictor_builder(self):
+    predictor_text_proto = """
+    attention_predictor {
+      rnn_cell {
+        lstm_cell {
+          num_units: 256
+          forget_bias: 1.0
+          initializer { orthogonal_initializer { } }
+        }
+      }
+      rnn_regularizer { l2_regularizer { weight: 1e-4 } }
+      num_attention_units: 128
+      max_num_steps: 10
+      multi_attention: false
+      beam_width: 1
+      reverse: false
+      label_map {
+        character_set {
+          text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+          delimiter: ""
+        }
+        label_offset: 2
+      }
+      loss {
+        sequence_cross_entropy_loss {
+          sequence_normalize: false
+          sample_normalize: true
+        }
+      }
+      sync: true
+    }
+    """
+    predictor_proto = predictor_pb2.Predictor()
+    text_format.Merge(predictor_text_proto, predictor_proto)
+    predictor_object = predictor_builder.build(predictor_proto, True)
 
-  #   feature_maps = [tf.random_uniform([2, 1, 10, 32], dtype=tf.float32)]
-  #   predictor_object.provide_groundtruth(
-  #     tf.constant([b'hello', b'world'], dtype=tf.string)
-  #   )
-  #   predictions_dict = predictor_object.predict(feature_maps)
-  #   loss = predictor_object.loss(predictions_dict)
+    feature_maps = [tf.random_uniform([2, 1, 10, 32], dtype=tf.float32)]
+    predictor_object.provide_groundtruth(
+      tf.constant([b'hello', b'world'], dtype=tf.string)
+    )
+    predictions_dict = predictor_object.predict(feature_maps)
+    loss = predictor_object.loss(predictions_dict)
 
-  #   with self.test_session() as sess:
-  #     sess.run([
-  #       tf.global_variables_initializer(),
-  #       tf.tables_initializer()])
-  #     sess_outputs = sess.run({'loss': loss})
-  #     print(sess_outputs)
+    with self.test_session() as sess:
+      sess.run([
+        tf.global_variables_initializer(),
+        tf.tables_initializer()])
+      sess_outputs = sess.run({'loss': loss})
+      print(sess_outputs)
+
+  def test_sync_predictor_lm_builder(self):
+    predictor_text_proto = """
+    attention_predictor {
+      rnn_cell {
+        lstm_cell {
+          num_units: 256
+          forget_bias: 1.0
+          initializer { orthogonal_initializer { } }
+        }
+      }
+      rnn_regularizer { l2_regularizer { weight: 1e-4 } }
+      num_attention_units: 128
+      max_num_steps: 10
+      multi_attention: false
+      beam_width: 1
+      reverse: false
+      label_map {
+        character_set {
+          text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+          delimiter: ""
+        }
+        label_offset: 2
+      }
+      loss {
+        sequence_cross_entropy_loss {
+          sequence_normalize: false
+          sample_normalize: true
+        }
+      }
+      lm_rnn_cell {
+        rnn_cell {
+          lstm_cell {
+            num_units: 256
+            forget_bias: 1.0
+            initializer { orthogonal_initializer { } }
+          }
+        }
+        rnn_cell {
+          lstm_cell {
+            num_units: 256
+            forget_bias: 1.0
+            initializer { orthogonal_initializer { } }
+          }
+        }
+      }
+      sync: true
+    }
+    """
+    predictor_proto = predictor_pb2.Predictor()
+    text_format.Merge(predictor_text_proto, predictor_proto)
+    predictor_object = predictor_builder.build(predictor_proto, True)
+
+    feature_maps = [tf.random_uniform([2, 1, 10, 32], dtype=tf.float32)]
+    predictor_object.provide_groundtruth(
+      tf.constant([b'hello', b'world'], dtype=tf.string)
+    )
+    predictions_dict = predictor_object.predict(feature_maps)
+    loss = predictor_object.loss(predictions_dict)
+
+    with self.test_session() as sess:
+      sess.run([
+        tf.global_variables_initializer(),
+        tf.tables_initializer()])
+      sess_outputs = sess.run({'loss': loss})
+      print(sess_outputs)
 

 if __name__ == '__main__':
   tf.test.main()
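Both new tests build the predictor from a text-format proto and evaluate one loss value in a session. Assuming the repository root (the directory containing the rare package) is on PYTHONPATH, which this diff does not show, they can be run directly:

python rare/builders/predictor_builder_test.py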
32 changes: 30 additions & 2 deletions core/sync_attention_wrapper.py
@@ -1,9 +1,33 @@
+from tensorflow.python.ops import array_ops
+from tensorflow.contrib import rnn
 from tensorflow.contrib import seq2seq
 from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _compute_attention


 class SyncAttentionWrapper(seq2seq.AttentionWrapper):

+  def __init__(self,
+               cell,
+               attention_mechanism,
+               attention_layer_size=None,
+               alignment_history=False,
+               cell_input_fn=None,
+               output_attention=True,
+               initial_cell_state=None,
+               name=None):
+    if not isinstance(cell, (rnn.LSTMCell, rnn.GRUCell)):
+      raise ValueError('SyncAttentionWrapper only supports LSTMCell and GRUCell, '
+                       'Got: {}'.format(cell))
+    super(SyncAttentionWrapper, self).__init__(
+        cell,
+        attention_mechanism,
+        attention_layer_size=attention_layer_size,
+        alignment_history=alignment_history,
+        cell_input_fn=cell_input_fn,
+        output_attention=output_attention,
+        initial_cell_state=initial_cell_state,
+        name=name
+    )
+
   def call(self, inputs, state):
     if not isinstance(state, seq2seq.AttentionWrapperState):
@@ -21,8 +45,12 @@ def call(self, inputs, state):
     all_attentions = []
     all_histories = []
     for i, attention_mechanism in enumerate(self._attention_mechanisms):
+      if isinstance(self._cell, rnn.LSTMCell):
+        rnn_cell_state = state.cell_state.h
+      else:
+        rnn_cell_state = state.cell_state
       attention, alignments = _compute_attention(
-          attention_mechanism, state.cell_state, previous_alignments[i],
+          attention_mechanism, rnn_cell_state, previous_alignments[i],
           self._attention_layers[i] if self._attention_layers else None)
       alignment_history = previous_alignment_history[i].write(
           state.time, alignments) if self._alignment_history else ()
@@ -36,7 +64,7 @@ def call(self, inputs, state):
     cell_inputs = self._cell_input_fn(inputs, attention)
     cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)
 
-    next_state = AttentionWrapperState(
+    next_state = seq2seq.AttentionWrapperState(
       time=state.time + 1,
       cell_state=next_cell_state,
       attention=attention,
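Two things are worth noting in the wrapper above. First, unlike the stock seq2seq.AttentionWrapper, which runs the cell and then computes attention from the cell's output, this call computes attention from the previous step's cell state and feeds it into the cell within the same step, hence the "sync" in the name. Second, the new isinstance branch is needed because an LSTMCell's state is an LSTMStateTuple whose hidden output h is the usable attention query, while a GRUCell's state is already a flat tensor. A minimal shape sketch (illustration only, not part of the commit):

import tensorflow as tf
from tensorflow.contrib import rnn

batch_size, num_units = 2, 256
lstm_state = rnn.LSTMCell(num_units).zero_state(batch_size, tf.float32)
gru_state = rnn.GRUCell(num_units).zero_state(batch_size, tf.float32)

query_for_lstm = lstm_state.h  # LSTMStateTuple(c, h): use the hidden output h
query_for_gru = gru_state      # already a [batch_size, num_units] tensor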
49 changes: 47 additions & 2 deletions predictors/attention_predictor.py
@@ -2,7 +2,11 @@
 import functools
 
 import tensorflow as tf
+from tensorflow.contrib import rnn
 from tensorflow.contrib import seq2seq
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.util import nest
+
 from rare.core import predictor
 from rare.core import sync_attention_wrapper
 from rare.core import loss
@@ -24,6 +28,7 @@ def __init__(self,
                label_map=None,
                loss=None,
                sync=False,
+               lm_rnn_cell=None,
                is_training=True):
     super(AttentionPredictor, self).__init__(is_training)
     self._rnn_cell = rnn_cell
@@ -35,6 +40,7 @@ def __init__(self,
     self._reverse = reverse
     self._label_map = label_map
     self._sync = sync
+    self._lm_rnn_cell = lm_rnn_cell
     self._loss = loss
 
     if not self._is_training and not self._beam_width > 0:
@@ -123,7 +129,7 @@ def provide_groundtruth(self, groundtruth_text, scope=None):
     else:
       decoder_inputs = tf.concat([start_labels, text_labels], axis=1)
       decoder_targets = tf.concat([text_labels, end_labels], axis=1)
-      decoder_lengths = text_lengths + 2
+      decoder_lengths = text_lengths + 1
     self._groundtruth_dict['decoder_inputs'] = decoder_inputs
     self._groundtruth_dict['decoder_targets'] = decoder_targets
     self._groundtruth_dict['decoder_lengths'] = decoder_lengths
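The length fix matches the tensor layout in the surrounding lines: for a groundtruth string of n characters, decoder_inputs and decoder_targets both have n + 1 steps. A worked example for the label 'hi' (illustration only):

# text_labels:     [h, i]         text_lengths = 2
# decoder_inputs:  [GO, h, i]     3 steps
# decoder_targets: [h, i, EOS]    3 steps
# decoder_lengths = text_lengths + 1 = 3  (the old + 2 overshot by one step)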
@@ -140,10 +146,15 @@ def postprocess(self, predictions_dict, scope=None):
   def _build_decoder_cell(self, feature_maps):
     attention_mechanism = self._build_attention_mechanism(feature_maps)
     wrapper_class = seq2seq.AttentionWrapper if not self._sync else sync_attention_wrapper.SyncAttentionWrapper
-    decoder_cell = wrapper_class(
+    attention_cell = wrapper_class(
       self._rnn_cell,
       attention_mechanism,
       output_attention=False)
+    if not self._lm_rnn_cell:
+      decoder_cell = attention_cell
+    else:
+      decoder_cell = ConcatOutputMultiRNNCell([attention_cell, self._lm_rnn_cell])
+
     return decoder_cell
 
   def _build_attention_mechanism(self, feature_maps):
@@ -197,3 +208,37 @@ def _build_decoder(self, decoder_cell, batch_size):
         output_layer=output_layer,
         length_penalty_weight=0.0)
     return decoder
+
+
+class ConcatOutputMultiRNNCell(rnn.MultiRNNCell):
+  """RNN cell composed of multiple RNN cells whose outputs are concatenated along depth."""
+
+  @property
+  def output_size(self):
+    return sum([cell.output_size for cell in self._cells])
+
+  def call(self, inputs, state):
+    cur_state_pos = 0
+    outputs = []
+    new_states = []
+    for i, cell in enumerate(self._cells):
+      with vs.variable_scope("cell_%d" % i):
+        if self._state_is_tuple:
+          if not nest.is_sequence(state):
+            raise ValueError(
+                "Expected state to be a tuple of length %d, but received: %s" %
+                (len(self.state_size), state))
+          cur_state = state[i]
+        else:
+          cur_state = array_ops.slice(state, [0, cur_state_pos],
+                                      [-1, cell.state_size])
+          cur_state_pos += cell.state_size
+        cur_output, new_state = cell(inputs, cur_state)
+        new_states.append(new_state)
+        outputs.append(cur_output)
+
+    new_states = (tuple(new_states) if self._state_is_tuple else
+                  array_ops.concat(new_states, 1))
+    output = tf.concat(outputs, -1)
+
+    return output, new_states
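Unlike a plain rnn.MultiRNNCell, which chains its sub-cells so each layer consumes the previous layer's output, ConcatOutputMultiRNNCell feeds the same inputs to every sub-cell and concatenates the per-cell outputs along depth; this is how the attention cell and the language-model cell run side by side in _build_decoder_cell. (Note that the non-tuple state branch uses array_ops, which the import hunk above does not add, so in this form it relies on state_is_tuple cells.) A self-contained usage sketch with made-up sizes, not taken from the commit:

import tensorflow as tf
from tensorflow.contrib import rnn

cell = ConcatOutputMultiRNNCell([rnn.LSTMCell(256), rnn.GRUCell(128)])
inputs = tf.zeros([2, 32])              # batch of 2, depth-32 step inputs
state = cell.zero_state(2, tf.float32)  # a tuple of per-cell states
output, new_state = cell(inputs, state)
print(cell.output_size)                 # 384 = 256 + 128
# A plain MultiRNNCell would pipe the LSTM output into the GRU and
# return only the GRU's depth-128 output.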