@@ -23,6 +23,7 @@ def __init__(
23
23
eval_input_fn = None ,
24
24
eval_hooks = None ,
25
25
model_fn = None ,
26
+ serving_input_receiver_fn = None ,
26
27
):
27
28
self .params = params
28
29
self .feature_columns = feature_columns
@@ -32,6 +33,7 @@ def __init__(
32
33
self .eval_hooks = eval_hooks
33
34
self .model_fn = model_fn
34
35
self .run_config = run_config
36
+ self .serving_input_receiver_fn = serving_input_receiver_fn
35
37
36
38
@property
37
39
def params (self ):
@@ -75,6 +77,12 @@ def eval_hooks(self):
75
77
self .eval_hooks = self ._eval_hooks ()
76
78
return self .__eval_hooks
77
79
80
+ @property
81
+ def serving_input_receiver_fn (self ):
82
+ if self .__serving_input_receiver_fn is None :
83
+ self .serving_input_receiver_fn = self ._serving_input_receiver_fn ()
84
+ return self .__serving_input_receiver_fn
85
+
78
86
@property
79
87
def estimator (self ):
80
88
self .__estimator = tf .estimator .Estimator (
@@ -123,6 +131,10 @@ def eval_hooks(self, eval_hooks):
123
131
eval_hooks = []
124
132
self .__eval_hooks = eval_hooks
125
133
134
+ @serving_input_receiver_fn .setter
135
+ def serving_input_receiver_fn (self , serving_input_receiver_fn ):
136
+ self .__serving_input_receiver_fn = serving_input_receiver_fn
137
+
126
138
@run_config .setter
127
139
def run_config (self , run_config ):
128
140
if run_config is None :
@@ -156,6 +168,26 @@ def _train_hooks(self):
156
168
def _eval_hooks (self ):
157
169
return []
158
170
171
+ def _serving_input_receiver_fn (self ):
172
+ feature_spec = {
173
+ "x" : tf .FixedLenFeature (
174
+ dtype = tf .int64 , shape = [self .params ["max_seq_length" ]]
175
+ ),
176
+ "len" : tf .FixedLenFeature (dtype = tf .int64 , shape = []),
177
+ }
178
+
179
+ def default_serving_input_receiver_fn ():
180
+ serialized_tf_example = tf .placeholder (
181
+ dtype = tf .string , shape = [None ], name = "input_example_tensor"
182
+ )
183
+ receiver_tensors = {"examples" : serialized_tf_example }
184
+ features = tf .parse_example (serialized_tf_example , feature_spec )
185
+ return tf .estimator .export .ServingInputReceiver (
186
+ features , receiver_tensors
187
+ )
188
+
189
+ return default_serving_input_receiver_fn
190
+
159
191
def train (self , dataset , steps , distribution = None , hooks = []):
160
192
self ._add_embedding_params (embedding = dataset .embedding )
161
193
features , labels , stats = dataset .get_features_and_labels (
@@ -232,10 +264,17 @@ def train_and_eval(self, dataset, steps):
232
264
{"duration" : duration_dict , ** eval_stats },
233
265
)
234
266
267
+ def export (self , directory ):
268
+ self .estimator .export_savedmodel (
269
+ directory , self .serving_input_receiver_fn , strip_default_attrs = True
270
+ )
271
+
235
272
def _wrap_model_fn (self , _model_fn ):
236
273
@wraps (_model_fn )
237
274
def wrapper (features , labels , mode , params ):
238
275
spec = _model_fn (features , labels , mode , params )
276
+ if mode == ModeKeys .PREDICT :
277
+ return spec
239
278
std_metrics = {
240
279
"accuracy" : tf .metrics .accuracy (
241
280
labels = labels ,
@@ -264,18 +303,19 @@ def wrapper(features, labels, mode, params):
264
303
tf .summary .scalar ("accuracy" , std_metrics ["accuracy" ][1 ])
265
304
tf .summary .scalar ("auc" , std_metrics ["auc" ][1 ])
266
305
if mode == ModeKeys .EVAL :
267
- attn_hook = SaveAttentionWeightVectorHook (
268
- labels = labels ,
269
- predictions = spec .predictions ["class_ids" ],
270
- targets = features ["target" ]["lit" ],
271
- summary_writer = tf .summary .FileWriterCache .get (
272
- join (self .run_config .model_dir , "eval" )
273
- ),
274
- n_picks = self .params .get ("n_attn_heatmaps" , 5 ),
275
- n_hops = self .params .get ("n_hops" ),
276
- )
277
306
all_eval_hooks = spec .evaluation_hooks or []
278
- all_eval_hooks += [attn_hook ]
307
+ if features .get ("target" ) is not None :
308
+ attn_hook = SaveAttentionWeightVectorHook (
309
+ labels = labels ,
310
+ predictions = spec .predictions ["class_ids" ],
311
+ targets = features ["target" ]["lit" ],
312
+ summary_writer = tf .summary .FileWriterCache .get (
313
+ join (self .run_config .model_dir , "eval" )
314
+ ),
315
+ n_picks = self .params .get ("n_attn_heatmaps" , 5 ),
316
+ n_hops = self .params .get ("n_hops" ),
317
+ )
318
+ all_eval_hooks += [attn_hook ]
279
319
all_metrics = spec .eval_metric_ops or {}
280
320
all_metrics .update (std_metrics )
281
321
return spec ._replace (
@@ -299,8 +339,6 @@ def wrapper(features, labels, mode, params):
299
339
all_training_hooks += [logging_hook ]
300
340
return spec ._replace (training_hooks = all_training_hooks )
301
341
302
- return spec
303
-
304
342
return wrapper
305
343
306
344
def _export_statistics (self , dataset_stats = None , steps = None ):
0 commit comments