diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 154afdb211f3..4a7b40409383 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -684,7 +684,7 @@ def forward( >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 224, 224] ```"""