diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 39508e18d222..8f0caf825bba 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -723,9 +723,10 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None): >>> logits = model(**inputs).logits >>> # retrieve index of {mask} - >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] + >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0]) + >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index) - >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) + >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1) >>> tokenizer.decode(predicted_token_id) {expected_output} ```