diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index ec9a61e90099..a3d26b789c64 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -1533,7 +1533,7 @@ def _generate(
         # 2. Define model inputs
         input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
         # inputs_ids now has to be defined and cannot be None anymore
-        batch_size = input_ids.shape[0]
+        batch_size = shape_list(input_ids)[0]
 
         # 3. Prepare other model kwargs
         if output_attentions is not None:
@@ -1702,7 +1702,8 @@ def _generate(
 
     @staticmethod
     def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
-        return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+        shape = shape_list(tensor)
+        return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
 
     def _prepare_attention_mask_for_generation(
         self,
@@ -2162,7 +2163,7 @@ def greedy_search(
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
 
         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)
 
         # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2432,7 +2433,7 @@ def sample(
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
 
         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = shape_list(input_ids)
 
         # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2678,18 +2679,16 @@ def beam_search(
 
         def flatten_beam_dim(tensor, batch_axis=0):
             """Flattens the first two dimensions of a non-scalar array."""
+            shape = shape_list(tensor)
             return tf.reshape(
                 tensor,
-                tensor.shape[:batch_axis]
-                + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
-                + tensor.shape[batch_axis + 2 :],
+                shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
             )
 
         def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
             """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
-            return tf.reshape(
-                tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
-            )
+            shape = shape_list(tensor)
+            return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])
 
         def gather_beams(nested, beam_indices, batch_axis=0):
             """Gathers the beam slices indexed by beam_indices into new beam array."""
@@ -2748,7 +2747,7 @@ def gather_fn(tensor):
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
 
         # 3. init tensors to use for "xla-compileable" generate function
-        batch_size, num_beams, cur_len = input_ids.shape
+        batch_size, num_beams, cur_len = shape_list(input_ids)
 
         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2894,7 +2893,7 @@ def beam_search_body_fn(
             eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
             did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
                 tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
-                eos_in_next_token.shape,
+                shape_list(eos_in_next_token),
             )
 
             # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
@@ -2917,7 +2916,7 @@ def beam_search_body_fn(
             topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             beams_in_batch_are_full = (
                 tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
+                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
                 )
                 & early_stopping
             )
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 31338e56f972..d63b1b32733e 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -73,6 +73,7 @@
     TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     BertConfig,
     TFAutoModel,
+    TFAutoModelForSeq2SeqLM,
     TFAutoModelForSequenceClassification,
     TFBertModel,
     TFSharedEmbeddings,
@@ -2163,6 +2164,46 @@ def test_checkpoint_sharding_local(self):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
 
+    def test_generate_tf_function_export(self):
+        test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        max_length = 8
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_length,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
+        dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+            for batch_size in range(1, len(dummy_input_ids) + 1):
+                inputs = {
+                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
+                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+
 
 @require_tf
 @is_staging_test
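
Note on the recurring substitution above: replacing `tensor.shape` with `shape_list(tensor)` is what makes `generate` traceable with a dynamic batch dimension. When the function is traced under `tf.function` with an input signature like `TensorSpec((None, max_length))`, `tensor.shape[0]` is `None` at trace time, so any shape arithmetic built from it (e.g. the `tf.ones(...)` padding tensors) fails; `shape_list` returns the static dimension where it is known and a `tf.shape` slice where it is not. A minimal sketch of the idea, using only plain TensorFlow (the real helper ships with transformers and also handles extra input types):

import tensorflow as tf

def shape_list(tensor):
    """Hybrid shape: static ints where known, scalar tf.shape() slices elsewhere."""
    dynamic = tf.shape(tensor)       # always defined, even during tracing
    static = tensor.shape.as_list()  # contains None for symbolic dimensions
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=(tf.TensorSpec((None, 8), tf.int32),))
def pad_rows(input_ids):
    # input_ids.shape[0] would be None here and crash tf.ones;
    # shape_list(input_ids)[0] is a scalar tensor, valid for any batch size.
    batch_size = shape_list(input_ids)[0]
    return tf.ones((batch_size, 4), dtype=tf.int32)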
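The new test also doubles as an export recipe: wrap the model in a `tf.Module`, expose `generate` through a `@tf.function` with an explicit `input_signature` (dynamic batch axis, fixed sequence length) and `jit_compile=True`, then save it with `tf.saved_model.save` as the `serving_default` signature. A consumption sketch, assuming a SavedModel exported exactly as in the test; the `export_dir` path and the tokenizer pairing are illustrative, not part of this diff:

import tensorflow as tf
from transformers import AutoTokenizer

export_dir = "/tmp/tiny_t5_generate"  # hypothetical: wherever tf.saved_model.save() wrote to
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

serving_func = tf.saved_model.load(export_dir).signatures["serving_default"]

# Inputs must match the exported TensorSpec names, dtype (int32), and (None, 8) shapes.
enc = tokenizer(["some source text"], padding="max_length", truncation=True, max_length=8, return_tensors="tf")
out = serving_func(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(tokenizer.batch_decode(out["sequences"], skip_special_tokens=True))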