diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 394d2aaa2fed..f81066bea6c2 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -269,9 +269,10 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None): ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) - >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) + >>> model = {model_class}.from_pretrained( + ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification" + ... ) - >>> num_labels = len(model.config.id2label) >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to( ... torch.float ... )