25 changes: 12 additions & 13 deletions src/transformers/generation_tf_utils.py
@@ -1533,7 +1533,7 @@ def _generate(
# 2. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
# inputs_ids now has to be defined and cannot be None anymore
- batch_size = input_ids.shape[0]
+ batch_size = shape_list(input_ids)[0]

# 3. Prepare other model kwargs
if output_attentions is not None:
@@ -1702,7 +1702,8 @@ def _generate(

@staticmethod
def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
- return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+ shape = shape_list(tensor)
+ return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))

def _prepare_attention_mask_for_generation(
self,
@@ -2162,7 +2163,7 @@ def greedy_search(
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

# 3. init tensors to use for "xla-compileable" generate function
- batch_size, cur_len = input_ids.shape
+ batch_size, cur_len = shape_list(input_ids)

# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2432,7 +2433,7 @@ def sample(
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

# 3. init tensors to use for "xla-compileable" generate function
- batch_size, cur_len = input_ids.shape
+ batch_size, cur_len = shape_list(input_ids)

# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
@@ -2678,18 +2679,16 @@ def beam_search(

def flatten_beam_dim(tensor, batch_axis=0):
"""Flattens the first two dimensions of a non-scalar array."""
+ shape = shape_list(tensor)
return tf.reshape(
tensor,
- tensor.shape[:batch_axis]
- + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
- + tensor.shape[batch_axis + 2 :],
+ shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
)

def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
- return tf.reshape(
- tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
- )
+ shape = shape_list(tensor)
+ return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])

def gather_beams(nested, beam_indices, batch_axis=0):
"""Gathers the beam slices indexed by beam_indices into new beam array."""
@@ -2748,7 +2747,7 @@ def gather_fn(tensor):
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

# 3. init tensors to use for "xla-compileable" generate function
- batch_size, num_beams, cur_len = input_ids.shape
+ batch_size, num_beams, cur_len = shape_list(input_ids)

# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
@@ -2894,7 +2893,7 @@ def beam_search_body_fn(
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
- eos_in_next_token.shape,
+ shape_list(eos_in_next_token),
)

# non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
@@ -2917,7 +2916,7 @@
topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
beams_in_batch_are_full = (
tf.broadcast_to(
- tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
+ tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
)
& early_stopping
)
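Every hunk in this file makes the same substitution: static `tensor.shape` reads become `shape_list(tensor)`. When `generate` is traced inside a `tf.function` with a dynamic dimension (for example `TensorSpec((None, max_length), ...)`), the static shape reports that dimension as `None`, which then fails once it is used to build new shapes (e.g. `tf.ones((batch_size, ...))`) or in arithmetic such as `tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]`. `shape_list` avoids this by falling back to the dynamic shape for exactly those dimensions. A minimal sketch of such a helper, assuming a tensor of known rank (`shape_list_sketch` is a hypothetical stand-in, not the library's exact implementation):

```python
import tensorflow as tf


def shape_list_sketch(tensor: tf.Tensor) -> list:
    """Return the tensor's shape, using static dims where known and dynamic ones otherwise."""
    static = tensor.shape.as_list()  # Python ints, or None for symbolic dims
    dynamic = tf.shape(tensor)  # scalar tensors, always defined during tracing
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
```

Because the result mixes Python ints with scalar tensors, it can be indexed, sliced, and multiplied the same way the old `.shape` expressions were, while remaining valid under `jit_compile=True`.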
41 changes: 41 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -73,6 +73,7 @@
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModel,
+ TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
@@ -2163,6 +2164,46 @@ def test_checkpoint_sharding_local(self):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

def test_generate_tf_function_export(self):
test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
max_length = 8

class DummyModel(tf.Module):
def __init__(self, model):
super(DummyModel, self).__init__()
self.model = model

@tf.function(
input_signature=(
tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
),
jit_compile=True,
)
def serving(self, input_ids, attention_mask):
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_length,
return_dict_in_generate=True,
)
return {"sequences": outputs["sequences"]}

dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
dummy_model = DummyModel(model=test_model)
with tempfile.TemporaryDirectory() as tmp_dir:
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
for batch_size in range(1, len(dummy_input_ids) + 1):
inputs = {
"input_ids": tf.constant(dummy_input_ids[:batch_size]),
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
}
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)


@require_tf
@is_staging_test
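The new test exercises the shape fixes end to end: `generate` is wrapped in a `tf.function` with `jit_compile=True` and a fixed-length, batch-size-agnostic input signature, exported as a SavedModel, and the exported signature is checked against eager `generate` for batch sizes 1 and 2. The same property can be used without a SavedModel by XLA-compiling `generate` directly. A minimal sketch, assuming the tiny test checkpoint also ships a tokenizer (the prompt text and decode step are illustrative, not part of the test):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

# XLA wants stable shapes, so pad every batch to the same length (8, as in the test above)
# to avoid re-compiling whenever the input length changes.
xla_generate = tf.function(model.generate, jit_compile=True)
inputs = tokenizer(
    ["translate English to German: hello"],
    padding="max_length",
    max_length=8,
    return_tensors="tf",
)
sequences = xla_generate(**inputs, max_new_tokens=8)
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))
```

The first call pays the XLA compilation cost; later calls with the same input shapes reuse the compiled program, which is also why the exported serving signature pins the sequence length while leaving the batch dimension dynamic.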