Skip to content

Commit d980038

Browse files
committed
creates directory for exported models, starts work on functions to export saved model
1 parent fde2e84 commit d980038

File tree

5 files changed

+93
-45
lines changed

5 files changed

+93
-45
lines changed

export/.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Ignore everything in this directory
2+
*
3+
# Except this file
4+
!.gitignore

main.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,30 @@
2727
"initializer": tf.initializers.random_uniform(minval=-0.1, maxval=0.1),
2828
}
2929

30-
# experiment = Experiment(
31-
# dataset=Dataset(
32-
# path=DATASETS.DEBUG_PATH,
33-
# parser=DATASETS.DEBUG_PARSER,
34-
# embedding=Embedding(path=EMBEDDINGS.DEBUG),
35-
# ),
36-
# model=MemNet(),
37-
# run_config=tf.estimator.RunConfig(tf_random_seed=1234),
38-
# )
39-
# experiment.run(job="train+eval", steps=1, start_tb=True)
40-
# # experiment.run(job="train", steps=200, hooks=[debug_hook])
4130
experiment = Experiment(
4231
dataset=Dataset(
43-
path=DATASETS.DONG2014_PATH,
44-
parser=DATASETS.DONG2014_PARSER,
45-
embedding=Embedding(path=EMBEDDINGS.GLOVE_TWITTER_100D),
32+
path=DATASETS.DEBUG_PATH,
33+
parser=DATASETS.DEBUG_PARSER,
34+
embedding=Embedding(path=EMBEDDINGS.DEBUG),
4635
),
47-
model=MemNet(),
48-
contd_tag="gold",
36+
model=Lstm(),
37+
contd_tag="gold_debug",
4938
# run_config=tf.estimator.RunConfig(tf_random_seed=1234),
5039
)
51-
experiment.run(job="train+eval", steps=16000, start_tb=True)
40+
# experiment.run(job="train+eval", steps=1, start_tb=True)
41+
# # experiment.run(job="train", steps=200, hooks=[debug_hook])
42+
# experiment = Experiment(
43+
# dataset=Dataset(
44+
# path=DATASETS.DONG2014_PATH,
45+
# parser=DATASETS.DONG2014_PARSER,
46+
# embedding=Embedding(path=EMBEDDINGS.GLOVE_TWITTER_25D),
47+
# ),
48+
# model=Lstm(),
49+
# contd_tag="gold",
50+
# # run_config=tf.estimator.RunConfig(tf_random_seed=1234),
51+
# )
52+
experiment.run(job="train+eval", steps=1)
53+
experiment.export_model()
5254
# experiment = Experiment(
5355
# dataset=Dataset(
5456
# path=DATASETS.NAKOV2016_PATH,

tsaplay/experiments/Experiment.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def __init__(
1111
):
1212
self.dataset = dataset
1313
self.model = model
14+
self.model_name = model.__class__.__name__
1415
if embedding is not None:
1516
self.embedding = embedding
1617
self.dataset.embedding = self.embedding
@@ -21,8 +22,9 @@ def __init__(
2122
)
2223
else:
2324
self.embedding = self.dataset.embedding
25+
self.contd_tag = contd_tag
2426
self.exp_dir = self._init_exp_dir(
25-
model=self.model, dataset=self.dataset, contd_tag=contd_tag
27+
model=self.model, dataset=self.dataset, contd_tag=self.contd_tag
2628
)
2729
if run_config is None:
2830
run_config = self.model.run_config
@@ -69,6 +71,17 @@ def run(
6971
debug_port=debug_port,
7072
)
7173

74+
def export_model(self):
75+
if self.contd_tag is None:
76+
print("No continue tag defined, nothing to export!")
77+
else:
78+
export_model_name = "_".join(
79+
[self.model_name.lower(), self.contd_tag]
80+
)
81+
export_dir = _join(getcwd(), "export", export_model_name)
82+
self.model.export(directory=export_dir)
83+
return
84+
7285
def _init_exp_dir(self, model, dataset, contd_tag):
7386
all_exps_path = _join(dirname(abspath(__file__)), "data")
7487
rel_model_path = _join(
@@ -93,15 +106,6 @@ def _init_exp_dir(self, model, dataset, contd_tag):
93106
)
94107
return exp_dir
95108

96-
def _init_run_config(self, exp_dir, run_config):
97-
summary_dir = _join(exp_dir, "tb_summary")
98-
if run_config is None:
99-
return tf.estimator.RunConfig(model_dir=summary_dir)
100-
elif run_config.model_dir is None:
101-
return run_config.replace(model_dir=summary_dir)
102-
else:
103-
return run_config
104-
105109
def _init_model_dir(self, exp_dir, run_config):
106110
summary_dir = _join(exp_dir, "tb_summary")
107111
if run_config.model_dir is None:

tsaplay/models/Model.py

+51-13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
eval_input_fn=None,
2424
eval_hooks=None,
2525
model_fn=None,
26+
serving_input_receiver_fn=None,
2627
):
2728
self.params = params
2829
self.feature_columns = feature_columns
@@ -32,6 +33,7 @@ def __init__(
3233
self.eval_hooks = eval_hooks
3334
self.model_fn = model_fn
3435
self.run_config = run_config
36+
self.serving_input_receiver_fn = serving_input_receiver_fn
3537

3638
@property
3739
def params(self):
@@ -75,6 +77,12 @@ def eval_hooks(self):
7577
self.eval_hooks = self._eval_hooks()
7678
return self.__eval_hooks
7779

80+
@property
81+
def serving_input_receiver_fn(self):
82+
if self.__serving_input_receiver_fn is None:
83+
self.serving_input_receiver_fn = self._serving_input_receiver_fn()
84+
return self.__serving_input_receiver_fn
85+
7886
@property
7987
def estimator(self):
8088
self.__estimator = tf.estimator.Estimator(
@@ -123,6 +131,10 @@ def eval_hooks(self, eval_hooks):
123131
eval_hooks = []
124132
self.__eval_hooks = eval_hooks
125133

134+
@serving_input_receiver_fn.setter
135+
def serving_input_receiver_fn(self, serving_input_receiver_fn):
136+
self.__serving_input_receiver_fn = serving_input_receiver_fn
137+
126138
@run_config.setter
127139
def run_config(self, run_config):
128140
if run_config is None:
@@ -156,6 +168,26 @@ def _train_hooks(self):
156168
def _eval_hooks(self):
157169
return []
158170

171+
def _serving_input_receiver_fn(self):
172+
feature_spec = {
173+
"x": tf.FixedLenFeature(
174+
dtype=tf.int64, shape=[self.params["max_seq_length"]]
175+
),
176+
"len": tf.FixedLenFeature(dtype=tf.int64, shape=[]),
177+
}
178+
179+
def default_serving_input_receiver_fn():
180+
serialized_tf_example = tf.placeholder(
181+
dtype=tf.string, shape=[None], name="input_example_tensor"
182+
)
183+
receiver_tensors = {"examples": serialized_tf_example}
184+
features = tf.parse_example(serialized_tf_example, feature_spec)
185+
return tf.estimator.export.ServingInputReceiver(
186+
features, receiver_tensors
187+
)
188+
189+
return default_serving_input_receiver_fn
190+
159191
def train(self, dataset, steps, distribution=None, hooks=[]):
160192
self._add_embedding_params(embedding=dataset.embedding)
161193
features, labels, stats = dataset.get_features_and_labels(
@@ -232,10 +264,17 @@ def train_and_eval(self, dataset, steps):
232264
{"duration": duration_dict, **eval_stats},
233265
)
234266

267+
def export(self, directory):
268+
self.estimator.export_savedmodel(
269+
directory, self.serving_input_receiver_fn, strip_default_attrs=True
270+
)
271+
235272
def _wrap_model_fn(self, _model_fn):
236273
@wraps(_model_fn)
237274
def wrapper(features, labels, mode, params):
238275
spec = _model_fn(features, labels, mode, params)
276+
if mode == ModeKeys.PREDICT:
277+
return spec
239278
std_metrics = {
240279
"accuracy": tf.metrics.accuracy(
241280
labels=labels,
@@ -264,18 +303,19 @@ def wrapper(features, labels, mode, params):
264303
tf.summary.scalar("accuracy", std_metrics["accuracy"][1])
265304
tf.summary.scalar("auc", std_metrics["auc"][1])
266305
if mode == ModeKeys.EVAL:
267-
attn_hook = SaveAttentionWeightVectorHook(
268-
labels=labels,
269-
predictions=spec.predictions["class_ids"],
270-
targets=features["target"]["lit"],
271-
summary_writer=tf.summary.FileWriterCache.get(
272-
join(self.run_config.model_dir, "eval")
273-
),
274-
n_picks=self.params.get("n_attn_heatmaps", 5),
275-
n_hops=self.params.get("n_hops"),
276-
)
277306
all_eval_hooks = spec.evaluation_hooks or []
278-
all_eval_hooks += [attn_hook]
307+
if features.get("target") is not None:
308+
attn_hook = SaveAttentionWeightVectorHook(
309+
labels=labels,
310+
predictions=spec.predictions["class_ids"],
311+
targets=features["target"]["lit"],
312+
summary_writer=tf.summary.FileWriterCache.get(
313+
join(self.run_config.model_dir, "eval")
314+
),
315+
n_picks=self.params.get("n_attn_heatmaps", 5),
316+
n_hops=self.params.get("n_hops"),
317+
)
318+
all_eval_hooks += [attn_hook]
279319
all_metrics = spec.eval_metric_ops or {}
280320
all_metrics.update(std_metrics)
281321
return spec._replace(
@@ -299,8 +339,6 @@ def wrapper(features, labels, mode, params):
299339
all_training_hooks += [logging_hook]
300340
return spec._replace(training_hooks=all_training_hooks)
301341

302-
return spec
303-
304342
return wrapper
305343

306344
def _export_statistics(self, dataset_stats=None, steps=None):

tsaplay/models/Tang2016a/common.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ def lstm_input_fn(
3333
features["mappings"]["right"],
3434
)
3535
]
36-
sen_map, sen_len = prep_features_for_dataset(mappings=sentences)
36+
sen_map, sen_len = prep_features_for_dataset(
37+
mappings=sentences, max_seq_length=max_seq_length
38+
)
39+
sentence = wrap_mapping_length_literal(sen_map, sen_len)
3740
labels = make_labels_dataset_from_list(labels)
3841

39-
dataset = tf.data.Dataset.from_tensor_slices((sen_map, sen_len, labels))
40-
dataset = dataset.map(
41-
lambda sentence, length, label: ({"x": sentence, "len": length}, label)
42-
)
42+
dataset = tf.data.Dataset.zip((sentence, labels))
4343

4444
iterator = prep_dataset_and_get_iterator(
4545
dataset=dataset,

0 commit comments

Comments
 (0)