diff --git a/tensor2tensor/serving/serving_utils.py b/tensor2tensor/serving/serving_utils.py index cfc2f3b5b..ba535ca88 100644 --- a/tensor2tensor/serving/serving_utils.py +++ b/tensor2tensor/serving/serving_utils.py @@ -23,6 +23,7 @@ import functools from googleapiclient import discovery import grpc +import numpy as np from tensor2tensor import problems as problems_lib # pylint: disable=unused-import from tensor2tensor.data_generators import text_encoder @@ -140,8 +141,12 @@ def _make_cloud_mlengine_request(examples): } } for ex in examples] } - prediction = api.projects().predict(body=input_data, name=parent).execute() - return prediction["predictions"] + response = api.projects().predict(body=input_data, name=parent).execute() + predictions = response["predictions"] + for prediction in predictions: + prediction["outputs"] = np.array([prediction["outputs"]]) + prediction["scores"] = np.array(prediction["scores"]) + return predictions return _make_cloud_mlengine_request