From 87a04deadb8350779f7be59efd919869f1c882fe Mon Sep 17 00:00:00 2001 From: Moses Hohman Date: Wed, 17 Jul 2024 13:32:20 -0500 Subject: [PATCH] Make problem_type condition consistent with num_labels condition The latter condition generally overrides the former, so this is more of a code reading issue. I'm not sure the bug would ever actually get triggered under normal use. --- src/transformers/pipelines/image_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index c54f372baa9d..8aaa66e6c458 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -171,9 +171,9 @@ def _forward(self, model_inputs): def postprocess(self, model_outputs, function_to_apply=None, top_k=5): if function_to_apply is None: - if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1: + if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1: function_to_apply = ClassificationFunction.SIGMOID - elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1: + elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1: function_to_apply = ClassificationFunction.SOFTMAX elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None: function_to_apply = self.model.config.function_to_apply