Evaluate with lexicons
bgshih committed Feb 2, 2018
1 parent 8cd08ae commit 03627dc
Showing 6 changed files with 51 additions and 19 deletions.
2 changes: 0 additions & 2 deletions core/spatial_transformer.py
@@ -62,8 +62,6 @@ def _localize(self, preprocessed_images):
     conv_output = tf.reshape(conv_output, [batch_size, -1])
     with arg_scope(self._fc_hyperparams):
       fc1 = fully_connected(conv_output, 512)
-      # fc2 = fully_connected(fc1, 2 * k, activation_fn=None, normalizer_fn=None)
-      # ctrl_pts = tf.sigmoid(fc2)
       fc2_weights_initializer = tf.zeros_initializer()
       fc2_biases_initializer = tf.constant_initializer(self._init_bias)
       fc2 = fully_connected(0.1 * fc1, 2 * k,
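
Context for the initializer change above: with zero-initialized weights, the rewritten fc2 initially outputs its bias for any input, so the spatial transformer starts from the fixed control points encoded by self._init_bias rather than from sigmoid outputs of randomly initialized weights. A minimal NumPy sketch of that property (the shapes, k, and bias value are hypothetical):

import numpy as np

k = 20                                            # hypothetical number of control points
fc1 = np.random.randn(2, 512).astype(np.float32)  # stand-in for the fc1 activations
W = np.zeros((512, 2 * k), np.float32)            # tf.zeros_initializer()
b = np.full(2 * k, 0.05, np.float32)              # tf.constant_initializer(init_bias)

fc2 = 0.1 * fc1 @ W + b      # the 0.1 scaling only matters once W moves away from zero
assert np.allclose(fc2, b)   # initial output is exactly the bias pattern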
2 changes: 1 addition & 1 deletion core/standard_fields.py
@@ -6,6 +6,7 @@ class InputDataFields(object):
   filename = 'filename'
   groundtruth_text = 'groundtruth_text'
   groundtruth_keypoints = 'groundtruth_keypoints'
+  lexicon = 'lexicon'
 
 
 class TfExampleFields(object):
@@ -19,5 +20,4 @@ class TfExampleFields(object):
   source_id = 'image/source_id'
   transcript = 'image/transcript'
   lexicon = 'image/lexicon'
-  lexicon_2 = 'image/lexicon_2'
   keypoints = 'image/keypoints'
15 changes: 14 additions & 1 deletion data_decoders/tf_example_decoder.py
@@ -26,6 +26,8 @@ def __init__(self):
         tf.FixedLenFeature((), tf.string, default_value=''),
       fields.TfExampleFields.keypoints: \
         tf.VarLenFeature(tf.float32),
+      fields.TfExampleFields.lexicon: \
+        tf.FixedLenFeature((), tf.string, default_value=''),
     }
     self.items_to_handlers = {
       fields.InputDataFields.image: \
@@ -39,7 +41,12 @@ def __init__(self):
       fields.InputDataFields.groundtruth_text: \
         slim_example_decoder.Tensor(fields.TfExampleFields.transcript),
       fields.InputDataFields.groundtruth_keypoints: \
-        slim_example_decoder.Tensor(fields.TfExampleFields.keypoints)
+        slim_example_decoder.Tensor(fields.TfExampleFields.keypoints),
+      fields.InputDataFields.lexicon: \
+        slim_example_decoder.ItemHandlerCallback(
+          [fields.TfExampleFields.lexicon],
+          self._split_lexicon
+        )
     }
 
   def Decode(self, tf_example_string_tensor):
@@ -64,3 +71,9 @@ def Decode(self, tf_example_string_tensor):
       tensor_dict[fields.InputDataFields.groundtruth_keypoints] = normalized_keypoints
 
     return tensor_dict
+
+  def _split_lexicon(self, keys_to_tensors):
+    joined_lexicon = keys_to_tensors[fields.TfExampleFields.lexicon]
+    lexicon_sparse = tf.string_split([joined_lexicon], delimiter='\t')
+    lexicon = tf.sparse_tensor_to_dense(lexicon_sparse, default_value='')[0]
+    return lexicon
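
As a sanity check on _split_lexicon above, a minimal TF 1.x sketch (words hypothetical) of the tab-split round trip the decoder performs:

import tensorflow as tf  # TF 1.x, matching the APIs above

joined = tf.constant('apple\torange\tpear')          # lexicon stored as one tab-joined string
sparse = tf.string_split([joined], delimiter='\t')   # SparseTensor of individual words
dense = tf.sparse_tensor_to_dense(sparse, default_value='')[0]

with tf.Session() as sess:
  print(sess.run(dense))  # [b'apple' b'orange' b'pear']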
29 changes: 26 additions & 3 deletions evaluator.py
@@ -1,6 +1,8 @@
 import os
 import logging
 import tensorflow as tf
+import numpy as np
+import editdistance
 
 from rare.core import preprocessor
 from rare.core import prefetcher
@@ -17,7 +19,8 @@
 def _extract_prediction_tensors(model,
                                 create_input_dict_fn,
                                 data_preprocessing_steps,
-                                ignore_groundtruth=False):
+                                ignore_groundtruth=False,
+                                evaluate_with_lexicon=False):
   # input queue
   input_dict = create_input_dict_fn()
   prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
@@ -35,13 +38,32 @@ def _extract_prediction_tensors(model,
   predictions_dict = model.predict(tf.expand_dims(preprocessed_image, 0))
   recognitions = model.postprocess(predictions_dict)
 
+  def _lexicon_search(lexicon, word):
+    edit_distances = []
+    for lex_word in lexicon:
+      edit_distances.append(editdistance.eval(lex_word.lower(), word.lower()))
+    edit_distances = np.asarray(edit_distances, dtype=np.int)
+    argmin = np.argmin(edit_distances)
+    return lexicon[argmin]
+
+  if evaluate_with_lexicon:
+    lexicon = input_dict[fields.InputDataFields.lexicon]
+    recognition_text = tf.py_func(
+        _lexicon_search,
+        [lexicon, recognitions['text'][0]],
+        tf.string,
+        stateful=False,
+    )
+  else:
+    recognition_text = recognitions['text'][0]
+
   tensor_dict = {
       'original_image': original_image,
       'original_image_shape': original_image_shape,
       'preprocessed_image_shape': preprocessed_image_shape,
       'filename': preprocessed_input_dict[fields.InputDataFields.filename],
       'groundtruth_text': input_dict[fields.InputDataFields.groundtruth_text],
-      'recognition_text': recognitions['text'][0],
+      'recognition_text': recognition_text,
   }
   if 'control_points' in predictions_dict:
     tensor_dict.update({
@@ -64,7 +86,8 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config,
       model=model,
       create_input_dict_fn=create_input_dict_fn,
       data_preprocessing_steps=data_preprocessing_steps,
-      ignore_groundtruth=eval_config.ignore_groundtruth)
+      ignore_groundtruth=eval_config.ignore_groundtruth,
+      evaluate_with_lexicon=eval_config.eval_with_lexicon)
 
   summary_writer = tf.summary.FileWriter(eval_dir)
 
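
For intuition, the lexicon search above snaps each raw prediction to the lexicon entry with the smallest case-insensitive edit distance (ties go to the earlier entry), wrapped in tf.py_func so it runs as a graph op. A plain-Python sketch with hypothetical words:

import editdistance  # same package the evaluator imports

def lexicon_search(lexicon, word):
  # Return the lexicon entry with minimum Levenshtein distance to `word`.
  distances = [editdistance.eval(w.lower(), word.lower()) for w in lexicon]
  return lexicon[distances.index(min(distances))]

print(lexicon_search(['apple', 'orange', 'pear'], 'APLE'))  # -> 'apple'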
3 changes: 3 additions & 0 deletions protos/eval.proto
@@ -48,6 +48,9 @@ message EvalConfig {
   // Whether to evaluate instance masks.
   optional bool eval_instance_masks = 12 [default=false];
 
+  // Whether to evaluate with lexicon
+  optional bool eval_with_lexicon = 15 [default=false];
+
   // data preprocessing steps
   repeated PreprocessingStep data_preprocessing_steps = 13;
 }
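
A hedged usage sketch: assuming the proto compiles to a rare.protos.eval_pb2 module (module path assumed, not shown in this commit), the new flag can be set from a text proto:

from google.protobuf import text_format
from rare.protos import eval_pb2  # assumed location of the generated module

eval_config = text_format.Parse('eval_with_lexicon: true', eval_pb2.EvalConfig())
print(eval_config.eval_with_lexicon)  # True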
19 changes: 7 additions & 12 deletions tools/create_iiit5k_tfrecord.py
@@ -16,7 +16,7 @@
 FLAGS = flags.FLAGS
 
 
-def create_iiit5k_subset(output_path, train_subset=True):
+def create_iiit5k_subset(output_path, train_subset=True, lexicon_index=None):
   writer = tf.python_io.TFRecordWriter(output_path)
 
   mat_file_name = 'traindata.mat' if train_subset else 'testdata.mat'
@@ -28,12 +28,8 @@ def create_iiit5k_subset(output_path, train_subset=True):
   for entry in tqdm(entries):
     image_rel_path = str(entry[0][0])
     groundtruth_text = str(entry[1][0])
-    lexicon = [str(t[0]) for t in entry[2].flatten()]
-
-    if train_subset:
-      lexicon_2 = []
-    else:
-      lexicon_2 = [str(t[0]) for t in entry[3].flatten()]
+    if not train_subset:
+      lexicon = [str(t[0]) for t in entry[lexicon_index].flatten()]
 
     image_path = os.path.join(FLAGS.data_dir, image_rel_path)
     with open(image_path, 'rb') as f:
@@ -53,15 +49,14 @@ def create_iiit5k_subset(output_path, train_subset=True):
       fields.TfExampleFields.transcript: \
         dataset_util.bytes_feature(groundtruth_text.encode('utf-8')),
       fields.TfExampleFields.lexicon: \
-        dataset_util.bytes_feature(('\t'.join(lexicon)).encode('utf-8')),
-      fields.TfExampleFields.lexicon_2: \
-        dataset_util.bytes_feature(('\t'.join(lexicon_2)).encode('utf-8'))
+        dataset_util.bytes_feature(('\t'.join(lexicon)).encode('utf-8'))
     }))
     writer.write(example.SerializeToString())
 
   writer.close()
 
 
 if __name__ == '__main__':
-  create_iiit5k_subset('iiit5k_train.tfrecord', train_subset=True)
-  # create_iiit5k_subset('iiit5k_test.tfrecord', train_subset=False)
+  # create_iiit5k_subset('data/iiit5k_train.tfrecord', train_subset=True)
+  create_iiit5k_subset('data/iiit5k_test_50.tfrecord', train_subset=False, lexicon_index=2)
+  # create_iiit5k_subset('data/iiit5k_test_1k.tfrecord', train_subset=False, lexicon_index=3)
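
For reference, a sketch of the IIIT5K mat layout that the lexicon_index arguments above rely on: each test entry carries the image name, the ground truth, a 50-word lexicon at index 2, and a 1,000-word lexicon at index 3 (key name and exact field order assumed from the dataset's testdata.mat):

import scipy.io as sio

# Assumed entry layout: [image name, ground truth, 50-word lexicon, 1k-word lexicon]
entries = sio.loadmat('testdata.mat')['testdata'].flatten()
entry = entries[0]
print(str(entry[0][0]), str(entry[1][0]))           # image path and transcript
print([str(t[0]) for t in entry[2].flatten()][:5])  # first few 50-word-lexicon entries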
