Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b56a5cd

Browse files
Drunkarlukaszkaiser
authored andcommitted
Modify serving utils (#1495)
* Add decode logic for model with return_beams=True. * Add print logic for model with return_beams=True.
1 parent a4071d6 commit b56a5cd

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

tensor2tensor/serving/query.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,26 @@ def main(_):
9090
outputs = serving_utils.predict([inputs], problem, request_fn)
9191
outputs, = outputs
9292
output, score = outputs
93-
print_str = """
93+
if len(score.shape) > 0:
94+
print_str = """
95+
Input:
96+
{inputs}
97+
98+
Output (Scores [{score}]):
99+
{output}
100+
"""
101+
score_text = ",".join(["{:.3f}".format(s) for s in score])
102+
print(print_str.format(inputs=inputs, output=output, score=score_text))
103+
else:
104+
print_str = """
94105
Input:
95106
{inputs}
96107
97108
Output (Score {score:.3f}):
98109
{output}
99-
"""
100-
print(print_str.format(inputs=inputs, output=output, score=score))
110+
"""
111+
print(print_str.format(inputs=inputs, output=output, score=score))
112+
101113
if FLAGS.inputs_once:
102114
break
103115

tensor2tensor/serving/serving_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ def _encode(inputs, encoder, add_eos=True):
9494

9595

9696
def _decode(output_ids, output_decoder):
97-
return output_decoder.decode(output_ids, strip_extraneous=True)
97+
if len(output_ids.shape) > 1:
98+
return [output_decoder.decode(o, strip_extraneous=True) for o in output_ids]
99+
else:
100+
return output_decoder.decode(output_ids, strip_extraneous=True)
98101

99102

100103

0 commit comments

Comments
 (0)