From 0830c6d7b2c6ee16084a4f65e6773753b6c1ec1d Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 11 Apr 2022 12:44:15 +0200 Subject: [PATCH] Fix code sample --- src/transformers/utils/doc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 39508e18d222..8f0caf825bba 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -723,9 +723,10 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None): >>> logits = model(**inputs).logits >>> # retrieve index of {mask} - >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] + >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0]) + >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index) - >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) + >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1) >>> tokenizer.decode(predicted_token_id) {expected_output} ```