Skip to content

[Inference Client] fix param docstring and deprecate labels param in zero-shot classification tasks #2668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,23 +2769,33 @@ def visual_question_answering(
response = self.post(json=payload, model=model, task="visual-question-answering")
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

@_deprecate_arguments(
version="0.30.0",
deprecated_args=["labels"],
custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
)
def zero_shot_classification(
self,
text: str,
labels: List[str],
# temporarily keeping it optional for backward compatibility.
candidate_labels: List[str] = None, # type: ignore
*,
multi_label: Optional[bool] = False,
hypothesis_template: Optional[str] = None,
model: Optional[str] = None,
# deprecated argument
labels: List[str] = None, # type: ignore
) -> List[ZeroShotClassificationOutputElement]:
"""
Provide as input a text and a set of candidate labels to classify the input text.

Args:
text (`str`):
The input text to classify.
labels (`List[str]`):
List of strings. Each string is the verbalization of a possible label for the input text.
candidate_labels (`List[str]`):
The set of possible class labels to classify the text into.
labels (`List[str]`, *optional*):
(deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
multi_label (`bool`, *optional*):
Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
the label likelihoods for each sequence is 1. If true, the labels are considered independent and
Expand Down Expand Up @@ -2852,9 +2862,17 @@ def zero_shot_classification(
]
```
"""

# handle deprecation
if labels is not None:
if candidate_labels is not None:
raise ValueError(
"Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
)
candidate_labels = labels
elif candidate_labels is None:
raise ValueError("Must specify `candidate_labels`")
parameters = {
"candidate_labels": labels,
"candidate_labels": candidate_labels,
"multi_label": multi_label,
"hypothesis_template": hypothesis_template,
}
Expand All @@ -2870,28 +2888,39 @@ def zero_shot_classification(
for label, score in zip(output["labels"], output["scores"])
]

@_deprecate_arguments(
version="0.30.0",
deprecated_args=["labels"],
custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
)
def zero_shot_image_classification(
self,
image: ContentT,
labels: List[str],
# temporarily keeping it optional for backward compatibility.
candidate_labels: Optional[List[str]] = None,
*,
model: Optional[str] = None,
hypothesis_template: Optional[str] = None,
# deprecated argument
labels: Optional[List[str]] = None, # type: ignore
) -> List[ZeroShotImageClassificationOutputElement]:
"""
Provide input image and text labels to predict text labels for the image.

Args:
image (`Union[str, Path, bytes, BinaryIO]`):
The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
labels (`List[str]`):
List of string possible labels. There must be at least 2 labels.
candidate_labels (`List[str]`):
The candidate labels for this image. There must be at least 2 labels.
labels (`List[str]`, *optional*):
(deprecated) List of string possible labels. There must be at least 2 labels.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used.
hypothesis_template (`str`, *optional*):
The sentence used in conjunction with candidateLabels to attempt the text classification by replacing
the placeholder with the candidate labels.

Returns:
`List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

Expand All @@ -2913,11 +2942,20 @@ def zero_shot_image_classification(
[ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
```
"""
# handle deprecation
if labels is not None:
if candidate_labels is not None:
raise ValueError(
"Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
)
candidate_labels = labels
elif candidate_labels is None:
raise ValueError("Must specify `candidate_labels`")
# Raise ValueError if input is less than 2 labels
if len(labels) < 2:
if len(candidate_labels) < 2:
raise ValueError("You must specify at least 2 classes to compare.")
parameters = {
"candidate_labels": labels,
"candidate_labels": candidate_labels,
"hypothesis_template": hypothesis_template,
}
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
Expand Down
58 changes: 48 additions & 10 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2838,23 +2838,33 @@ async def visual_question_answering(
response = await self.post(json=payload, model=model, task="visual-question-answering")
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

@_deprecate_arguments(
version="0.30.0",
deprecated_args=["labels"],
custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
)
async def zero_shot_classification(
self,
text: str,
labels: List[str],
# temporarily keeping it optional for backward compatibility.
candidate_labels: List[str] = None, # type: ignore
*,
multi_label: Optional[bool] = False,
hypothesis_template: Optional[str] = None,
model: Optional[str] = None,
# deprecated argument
labels: List[str] = None, # type: ignore
) -> List[ZeroShotClassificationOutputElement]:
"""
Provide as input a text and a set of candidate labels to classify the input text.

Args:
text (`str`):
The input text to classify.
labels (`List[str]`):
List of strings. Each string is the verbalization of a possible label for the input text.
candidate_labels (`List[str]`):
The set of possible class labels to classify the text into.
labels (`List[str]`, *optional*):
(deprecated) List of strings. Each string is the verbalization of a possible label for the input text.
multi_label (`bool`, *optional*):
Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of
the label likelihoods for each sequence is 1. If true, the labels are considered independent and
Expand Down Expand Up @@ -2923,9 +2933,17 @@ async def zero_shot_classification(
]
```
"""

# handle deprecation
if labels is not None:
if candidate_labels is not None:
raise ValueError(
"Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
)
candidate_labels = labels
elif candidate_labels is None:
raise ValueError("Must specify `candidate_labels`")
parameters = {
"candidate_labels": labels,
"candidate_labels": candidate_labels,
"multi_label": multi_label,
"hypothesis_template": hypothesis_template,
}
Expand All @@ -2941,28 +2959,39 @@ async def zero_shot_classification(
for label, score in zip(output["labels"], output["scores"])
]

@_deprecate_arguments(
version="0.30.0",
deprecated_args=["labels"],
custom_message="`labels` has been renamed to `candidate_labels` and will be removed in huggingface_hub>=0.30.0.",
)
async def zero_shot_image_classification(
self,
image: ContentT,
labels: List[str],
# temporarily keeping it optional for backward compatibility.
candidate_labels: Optional[List[str]] = None,
*,
model: Optional[str] = None,
hypothesis_template: Optional[str] = None,
# deprecated argument
labels: Optional[List[str]] = None, # type: ignore
) -> List[ZeroShotImageClassificationOutputElement]:
"""
Provide input image and text labels to predict text labels for the image.

Args:
image (`Union[str, Path, bytes, BinaryIO]`):
The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
labels (`List[str]`):
List of string possible labels. There must be at least 2 labels.
candidate_labels (`List[str]`):
The candidate labels for this image. There must be at least 2 labels.
labels (`List[str]`, *optional*):
(deprecated) List of string possible labels. There must be at least 2 labels.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used.
hypothesis_template (`str`, *optional*):
The sentence used in conjunction with candidateLabels to attempt the text classification by replacing
the placeholder with the candidate labels.

Returns:
`List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence.

Expand All @@ -2985,11 +3014,20 @@ async def zero_shot_image_classification(
[ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...]
```
"""
# handle deprecation
if labels is not None:
if candidate_labels is not None:
raise ValueError(
"Cannot specify both `labels` and `candidate_labels`. Use `candidate_labels` instead."
)
candidate_labels = labels
elif candidate_labels is None:
raise ValueError("Must specify `candidate_labels`")
# Raise ValueError if input is less than 2 labels
if len(labels) < 2:
if len(candidate_labels) < 2:
raise ValueError("You must specify at least 2 classes to compare.")
parameters = {
"candidate_labels": labels,
"candidate_labels": candidate_labels,
"hypothesis_template": hypothesis_template,
}
payload = _prepare_payload(image, parameters=parameters, expect_binary=True)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)
from huggingface_hub.utils import build_hf_headers

from .testing_utils import with_production_testing
from .testing_utils import expect_deprecation, with_production_testing


# Avoid call to hf.co/api/models in VCRed tests
Expand Down Expand Up @@ -636,6 +636,7 @@ def test_visual_question_answering(self) -> None:
VisualQuestionAnsweringOutputElement(label=None, score=0.01777094043791294, answer="man"),
]

@expect_deprecation("zero_shot_classification")
def test_zero_shot_classification_single_label(self) -> None:
output = self.client.zero_shot_classification(
"A new model offers an explanation for how the Galilean satellites formed around the solar system's"
Expand All @@ -654,6 +655,7 @@ def test_zero_shot_classification_single_label(self) -> None:
],
)

@expect_deprecation("zero_shot_classification")
def test_zero_shot_classification_multi_label(self) -> None:
output = self.client.zero_shot_classification(
"A new model offers an explanation for how the Galilean satellites formed around the solar system's"
Expand Down
20 changes: 16 additions & 4 deletions utils/check_task_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
param_name = stmt.target.value
param_type = cst.Module([]).code_for_node(stmt.annotation.annotation)
docstring = self._extract_docstring(body_statements, index)
# Check if there's a default value
has_default = stmt.value is not None
default_value = cst.Module([]).code_for_node(stmt.value) if has_default else None

self.parameters[param_name] = {
"type": param_type,
"docstring": docstring,
"has_default": has_default,
"default_value": default_value,
}

@staticmethod
Expand Down Expand Up @@ -306,7 +312,7 @@ def _update_parameters(self, params: cst.Parameters) -> cst.Parameters:
new_param = cst.Param(
name=cst.Name(param_name),
annotation=annotation,
default=cst.Name("None"),
default=param_info["default_value"],
)
new_kwonly_params.append(new_param)
# Return the updated parameters object with new and updated parameters
Expand Down Expand Up @@ -381,10 +387,16 @@ def _format_param_docstring(
) -> List[str]:
"""Format the docstring lines for a single parameter."""
# Extract and format the parameter type
param_type = param_info["type"].replace("Optional[", "").rstrip("]")
optional_str = "*optional*" if "Optional[" in param_info["type"] else ""
param_type = param_info["type"]
if param_type.startswith("Optional["):
param_type = param_type[len("Optional[") : -1] # Remove Optional[ and closing ]
optional_str = ", *optional*"
else:
optional_str = ""

# Create the parameter line with type and optionality
param_line = f"{param_indent}{param_name} (`{param_type}`, {optional_str}):"
param_line = f"{param_indent}{param_name} (`{param_type}`{optional_str}):"

# Get and clean up the parameter description
param_desc = (param_info.get("docstring") or "").strip()
param_desc = " ".join(param_desc.split())
Expand Down
Loading