diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index f98c87166ca0..12653ba6e039 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -46,13 +46,36 @@ def __call__(self, sequences, labels, hypothesis_template): class ZeroShotClassificationPipeline(ChunkPipeline): """ NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural - language inference) tasks. + language inference) tasks. Equivalent of `text-classification` pipelines, but these models don't require a + hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is + **much** more flexible. Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model config's :attr:*~transformers.PretrainedConfig.label2id*. + Example: + + ```python + >>> from transformers import pipeline + + >>> oracle = pipeline(model="facebook/bart-large-mnli") + >>> answers = oracle( + ... "I have a problem with my iphone that needs to be resolved asap!!", + ... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"], + ... ) + {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]} + + >>> oracle( + ... "I have a problem with my iphone that needs to be resolved asap!!", + ... candidate_labels=["english", "german"], + ... ) + {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['english', 'german'], 'scores': [0.814, 0.186]} + ``` + + [Learn more about the basics of using a pipeline in the [pipeline tutorial]](../pipeline_tutorial) + This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier: `"zero-shot-classification"`.