Multi predictor recognition
bgshih committed Dec 28, 2017
1 parent 947aab0 commit 2cd4e5b
Showing 12 changed files with 488 additions and 366 deletions.
229 changes: 16 additions & 213 deletions builders/attention_recognition_model_builder_test.py
@@ -23,7 +23,6 @@ def test_build_attention_model_single_branch(self):
            summarize_activations: false
          }
        }
        bidirectional_rnn {
          fw_bw_rnn_cell {
            lstm_cell {
@@ -41,236 +40,39 @@ def test_build_attention_model_single_branch(self):
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
        summarize_activations: true
      }
      attention_predictor {
      predictor {
        name: "ForwardPredictor"
        bahdanau_attention_predictor {
          reverse: false
          rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer { } }
            }
          }
          num_attention_units: 128
          max_num_steps: 10
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer {} }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
      }
      label_map {
        character_set {
          text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
          delimiter: ""
        }
        label_offset: 2
      }
      loss {
        sequence_cross_entropy_loss {
          sequence_normalize: false
          sample_normalize: true
        }
      }
    }
    """
    model_proto = model_pb2.Model()
    text_format.Merge(model_text_proto, model_proto)
    model_object = model_builder.build(model_proto, True)

    test_groundtruth_text_list = [
        tf.constant(b'hello', dtype=tf.string),
        tf.constant(b'world', dtype=tf.string)]
    model_object.provide_groundtruth(test_groundtruth_text_list)
    test_input_image = tf.random_uniform(
        shape=[2, 32, 100, 3], minval=0, maxval=255,
        dtype=tf.float32, seed=1)
    prediction_dict = model_object.predict(model_object.preprocess(test_input_image))
    loss = model_object.loss(prediction_dict)

    with self.test_session() as sess:
      sess.run([
          tf.global_variables_initializer(),
          tf.tables_initializer()])
      outputs = sess.run({'loss': loss})
      print(outputs['loss'])

  def test_build_attention_model_multi_branches(self):
    model_text_proto = """
    attention_recognition_model {
      feature_extractor {
        convnet {
          crnn_net {
            net_type: THREE_BRANCHES
            conv_hyperparams {
              op: CONV
              regularizer { l2_regularizer { weight: 1e-4 } }
              initializer { variance_scaling_initializer { } }
              batch_norm { }
            }
            summarize_activations: false
          }
        }
        bidirectional_rnn {
          fw_bw_rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer {} }
            }
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_output_units: 256
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer { } }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
        summarize_activations: true
      }
      attention_predictor {
        bahdanau_attention_predictor {
          rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer { } }
            }
          }
          num_attention_units: 128
          max_num_steps: 10
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer {} }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
      }
      label_map {
        character_set {
          text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
          delimiter: ""
        }
        label_offset: 2
      }
      loss {
        sequence_cross_entropy_loss {
          sequence_normalize: false
          sample_normalize: true
        }
      }
    }
    """
    model_proto = model_pb2.Model()
    text_format.Merge(model_text_proto, model_proto)
    model_object = model_builder.build(model_proto, True)

    test_groundtruth_text_list = [
        tf.constant(b'hello', dtype=tf.string),
        tf.constant(b'world', dtype=tf.string)]
    model_object.provide_groundtruth(test_groundtruth_text_list)
    test_input_image = tf.random_uniform(
        shape=[2, 32, 100, 3], minval=0, maxval=255,
        dtype=tf.float32, seed=1)
    prediction_dict = model_object.predict(model_object.preprocess(test_input_image))
    loss = model_object.loss(prediction_dict)

    with self.test_session() as sess:
      sess.run([
          tf.global_variables_initializer(),
          tf.tables_initializer()])
      outputs = sess.run({'loss': loss})
      print(outputs['loss'])

  def test_build_attention_model_multi_branches_multi_attention(self):
    model_text_proto = """
    attention_recognition_model {
      feature_extractor {
        convnet {
          crnn_net {
            net_type: THREE_BRANCHES
            conv_hyperparams {
              op: CONV
              regularizer { l2_regularizer { weight: 1e-4 } }
              initializer { variance_scaling_initializer { } }
              batch_norm { }
            }
            summarize_activations: false
          }
        }
        bidirectional_rnn {
          fw_bw_rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer {} }
          multi_attention: false
          beam_width: 1
          reverse: false
          label_map {
            character_set {
              text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
              delimiter: ""
            }
            label_offset: 2
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_output_units: 256
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer { } }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
        summarize_activations: true
      }
      attention_predictor {
        bahdanau_attention_predictor {
          rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer { } }
          loss {
            sequence_cross_entropy_loss {
              sequence_normalize: false
              sample_normalize: true
            }
          }
          num_attention_units: 128
          max_num_steps: 10
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer {} }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
          multi_attention: true
        }
      }
      label_map {
        character_set {
          text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
          delimiter: ""
        }
        label_offset: 2
      }
      loss {
        sequence_cross_entropy_loss {
          sequence_normalize: false
          sample_normalize: true
        }
      }
    }
@@ -296,5 +98,6 @@ def test_build_attention_model_multi_branches_multi_attention(self):
      outputs = sess.run({'loss': loss})
      print(outputs['loss'])


if __name__ == '__main__':
  tf.test.main()
60 changes: 16 additions & 44 deletions builders/model_builder.py
@@ -1,63 +1,35 @@
 import tensorflow as tf
 
 from rare.builders import feature_extractor_builder
 from rare.builders import loss_builder
-from rare.builders import hyperparams_builder
-
-from rare.meta_architectures import attention_recognition_model
-from rare.meta_architectures import ctc_recognition_model
+from rare.builders import predictor_builder
+from rare.meta_architectures import multi_predictors_recognition_model
 from rare.protos import model_pb2
 
-
 def build(config, is_training):
   if not isinstance(config, model_pb2.Model):
-    raise ValueError('config not of type '
-                     'model_pb2.Model')
+    raise ValueError('config not of type model_pb2.Model')
   model_oneof = config.WhichOneof('model_oneof')
-  if model_oneof == 'attention_recognition_model':
-    return _build_attention_recognition_model(config.attention_recognition_model, is_training)
-  elif model_oneof == 'ctc_recognition_model':
-    return _build_ctc_recognition_model(config.ctc_recognition_model, is_training)
+  if model_oneof == 'multi_predictors_recognition_model':
+    return _build_multi_predictors_recognition_model(
+      config.multi_predictors_recognition_model, is_training)
   else:
     raise ValueError('Unknown model_oneof: {}'.format(model_oneof))
 
-
-def _build_attention_recognition_model(config, is_training):
-  if not isinstance(config, model_pb2.AttentionRecognitionModel):
-    raise ValueError('config not of type model_pb2.AttentionRecognitionModel')
+def _build_multi_predictors_recognition_model(config, is_training):
+  if not isinstance(config, model_pb2.MultiPredictorsRecognitionModel):
+    raise ValueError('config not of type model_pb2.MultiPredictorsRecognitionModel')
   feature_extractor_object = feature_extractor_builder.build(
     config.feature_extractor,
     is_training=is_training
   )
-  predictor_object = _build_attention_predictor(
-    config.attention_predictor,
-    is_training=is_training)
-  label_map_object = label_map_builder.build(config.label_map)
-  loss_object = loss_builder.build(config.loss)
-
-  model_object = attention_recognition_model.AttentionRecognitionModel(
+  predictors_dict = {
+    predictor_config.name : predictor_builder.build(predictor_config, is_training=is_training)
+    for predictor_config in config.predictor
+  }
+  model_object = multi_predictors_recognition_model.MultiPredictorsRecognitionModel(
     feature_extractor=feature_extractor_object,
-    predictor=predictor_object,
-    label_map=label_map_object,
-    loss=loss_object,
-    is_training=is_training
+    predictors_dict=predictors_dict,
+    is_training=is_training,
   )
   return model_object
-
-def _build_ctc_recognition_model(config, is_training):
-  if not isinstance(config, model_pb2.CtcRecognitionModel):
-    raise ValueError('config not of type model_pb2.CtcRecognitionModel')
-  feature_extractor_object = feature_extractor_builder.build(
-    config.feature_extractor,
-    is_training=is_training
-  )
-  label_map_object = label_map_builder.build(config.label_map)
-  fc_hyperparams_object = hyperparams_builder.build(
-    config.fc_hyperparams,
-    is_training)
-  model_object = ctc_recognition_model.CtcRecognitionModel(
-    feature_extractor=feature_extractor_object,
-    fc_hyperparams=fc_hyperparams_object,
-    label_map=label_map_object,
-    is_training=is_training)
-  return model_object
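
After this change, a model config carries one or more named predictors rather than a single attention or CTC head. Below is a minimal sketch, not part of this commit, of how the new repeated predictor field might be parsed and inspected; it assumes the rare package as introduced here, and the "BackwardPredictor" name and two-predictor layout are illustrative only.

# Minimal sketch (not from this commit): parse a multi-predictor config and
# list the names that model_builder uses as keys for predictors_dict.
# Assumes rare.protos.model_pb2 from this repo; "BackwardPredictor" is a
# hypothetical second branch.
from google.protobuf import text_format

from rare.protos import model_pb2

model_text_proto = """
multi_predictors_recognition_model {
  feature_extractor {}
  predictor { name: "ForwardPredictor" }
  predictor { name: "BackwardPredictor" }
}
"""
model_proto = model_pb2.Model()
text_format.Merge(model_text_proto, model_proto)

config = model_proto.multi_predictors_recognition_model
# Each repeated `predictor` entry becomes one predictors_dict entry, keyed by name.
print([p.name for p in config.predictor])
# -> ['ForwardPredictor', 'BackwardPredictor']

Passing such a proto through model_builder.build would additionally require a complete feature_extractor and predictor configuration, as in the test file above.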