@@ -1277,24 +1277,25 @@ def serving_output(self, output):
12771277 """
12781278 Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
12791279 """
1280- if isinstance (output , ModelOutput ):
1281- for key in output .keys ():
1282- if key .endswith ("hidden_states" ) and not getattr (self .config , "output_hidden_states" , False ):
1283- output [key ] = None
1284- elif key .endswith ("attentions" ) and not getattr (self .config , "output_attentions" , False ):
1285- output [key ] = None
1286- elif key == "past_key_values" and not getattr (self .config , "use_cache" , False ):
1287- output [key ] = None
1288- elif key == "cross_attentions" and not (
1289- getattr (self .config , "output_attentions" , False )
1290- and getattr (self .config , "add_cross_attention" , False )
1291- ):
1292- output [key ] = None
1293- if isinstance (output [key ], (tuple , list )):
1294- try :
1295- output [key ] = tf .convert_to_tensor (output [key ])
1296- except (ValueError , tf .errors .InvalidArgumentError ):
1297- pass # Layers may not have the same dimensions
1280+ if not isinstance (output , ModelOutput ):
1281+ return output
1282+ for key in output .keys ():
1283+ if key .endswith ("hidden_states" ) and not getattr (self .config , "output_hidden_states" , False ):
1284+ output [key ] = None
1285+ elif key .endswith ("attentions" ) and not getattr (self .config , "output_attentions" , False ):
1286+ output [key ] = None
1287+ elif key == "past_key_values" and not getattr (self .config , "use_cache" , False ):
1288+ output [key ] = None
1289+ elif key == "cross_attentions" and not (
1290+ getattr (self .config , "output_attentions" , False )
1291+ and getattr (self .config , "add_cross_attention" , False )
1292+ ):
1293+ output [key ] = None
1294+ if isinstance (output [key ], (tuple , list )):
1295+ try :
1296+ output [key ] = tf .convert_to_tensor (output [key ])
1297+ except (ValueError , tf .errors .InvalidArgumentError ):
1298+ pass # Layers may not have the same dimensions
12981299 return output
12991300
13001301 def can_generate (self ) -> bool :
0 commit comments