Skip to content

Commit 05f7584

Browse files
committed
Tweak default serving output
1 parent 60a41fa commit 05f7584

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

src/transformers/modeling_tf_utils.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,24 +1277,25 @@ def serving_output(self, output):
12771277
"""
12781278
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
12791279
"""
1280-
if isinstance(output, ModelOutput):
1281-
for key in output.keys():
1282-
if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
1283-
output[key] = None
1284-
elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
1285-
output[key] = None
1286-
elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
1287-
output[key] = None
1288-
elif key == "cross_attentions" and not (
1289-
getattr(self.config, "output_attentions", False)
1290-
and getattr(self.config, "add_cross_attention", False)
1291-
):
1292-
output[key] = None
1293-
if isinstance(output[key], (tuple, list)):
1294-
try:
1295-
output[key] = tf.convert_to_tensor(output[key])
1296-
except (ValueError, tf.errors.InvalidArgumentError):
1297-
pass # Layers may not have the same dimensions
1280+
if not isinstance(output, ModelOutput):
1281+
return output
1282+
for key in output.keys():
1283+
if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
1284+
output[key] = None
1285+
elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
1286+
output[key] = None
1287+
elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
1288+
output[key] = None
1289+
elif key == "cross_attentions" and not (
1290+
getattr(self.config, "output_attentions", False)
1291+
and getattr(self.config, "add_cross_attention", False)
1292+
):
1293+
output[key] = None
1294+
if isinstance(output[key], (tuple, list)):
1295+
try:
1296+
output[key] = tf.convert_to_tensor(output[key])
1297+
except (ValueError, tf.errors.InvalidArgumentError):
1298+
pass # Layers may not have the same dimensions
12981299
return output
12991300

13001301
def can_generate(self) -> bool:

0 commit comments

Comments
 (0)