From 613b59131f2a70eb8351a1706bfe88c869ad5925 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 9 Oct 2024 16:47:11 +0200 Subject: [PATCH] [Inference Client] Add task parameters and a maintenance script of these parameters (#2561) * Add additional parameters to Inference Client tasks * Add and run task params generation script * Add back missing test * Add comments to parameters generation script * Fix shared classes imports + text-to-speech task * Satisfy end-of-file-fixer hook * Move helper function to avoid duplicates across scripts * Rename helper function for more clarity * Fix bug in node traversing * Add comments * improve docstring formatting * fixes post-review * Remove aliases from reference package * Fix small bug and regenerate task parameters --- .github/workflows/python-quality.yml | 1 + Makefile | 8 +- .../en/package_reference/inference_types.md | 2 - .../ko/package_reference/inference_types.md | 2 - setup.py | 1 + src/huggingface_hub/__init__.py | 16 + src/huggingface_hub/inference/_client.py | 541 ++++++++++++++--- .../inference/_generated/_async_client.py | 539 ++++++++++++++--- .../inference/_generated/types/__init__.py | 28 +- .../_generated/types/audio_classification.py | 5 +- .../types/automatic_speech_recognition.py | 8 +- .../_generated/types/image_classification.py | 5 +- .../_generated/types/image_to_text.py | 8 +- .../_generated/types/text_classification.py | 15 +- .../_generated/types/text_to_audio.py | 4 +- .../_generated/types/text_to_speech.py | 4 +- .../_generated/types/video_classification.py | 4 +- tests/test_inference_async_client.py | 4 +- tests/test_inference_client.py | 7 +- utils/generate_async_inference_client.py | 15 +- utils/generate_inference_types.py | 89 +-- utils/generate_task_parameters.py | 548 ++++++++++++++++++ utils/helpers.py | 22 + 23 files changed, 1653 insertions(+), 223 deletions(-) create mode 100644 utils/generate_task_parameters.py diff --git a/.github/workflows/python-quality.yml b/.github/workflows/python-quality.yml index acb4e9a7fe..039c146ed1 100644 --- a/.github/workflows/python-quality.yml +++ b/.github/workflows/python-quality.yml @@ -42,6 +42,7 @@ jobs: - run: .venv/bin/python utils/check_static_imports.py - run: .venv/bin/python utils/generate_async_inference_client.py - run: .venv/bin/python utils/generate_inference_types.py + - run: .venv/bin/python utils/generate_task_parameters.py # Run type checking at least on huggingface_hub root file to check all modules # that can be lazy-loaded actually exist. diff --git a/Makefile b/Makefile index 341f8f9e12..c00fc30d4e 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ quality: python utils/check_contrib_list.py python utils/check_static_imports.py python utils/generate_async_inference_client.py + mypy src style: @@ -20,11 +21,14 @@ style: python utils/check_static_imports.py --update python utils/generate_async_inference_client.py --update -inference_types_check: +inference_check: python utils/generate_inference_types.py + python utils/generate_task_parameters.py -inference_types_update: +inference_update: python utils/generate_inference_types.py --update + python utils/generate_task_parameters.py --update + repocard: python utils/push_repocard_examples.py diff --git a/docs/source/en/package_reference/inference_types.md b/docs/source/en/package_reference/inference_types.md index f313294662..aa63c64b68 100644 --- a/docs/source/en/package_reference/inference_types.md +++ b/docs/source/en/package_reference/inference_types.md @@ -398,5 +398,3 @@ This part of the lib is still under development and will be improved in future r [[autodoc]] huggingface_hub.ZeroShotObjectDetectionInputData [[autodoc]] huggingface_hub.ZeroShotObjectDetectionOutputElement - - diff --git a/docs/source/ko/package_reference/inference_types.md b/docs/source/ko/package_reference/inference_types.md index ef4a62a570..393481e10f 100644 --- a/docs/source/ko/package_reference/inference_types.md +++ b/docs/source/ko/package_reference/inference_types.md @@ -397,5 +397,3 @@ rendered properly in your Markdown viewer. [[autodoc]] huggingface_hub.ZeroShotObjectDetectionInputData [[autodoc]] huggingface_hub.ZeroShotObjectDetectionOutputElement - - diff --git a/setup.py b/setup.py index e13aa28f88..373e3119cf 100644 --- a/setup.py +++ b/setup.py @@ -95,6 +95,7 @@ def get_version() -> str: extras["quality"] = [ "ruff>=0.5.0", "mypy==1.5.1", + "libcst==1.4.0", ] extras["all"] = extras["testing"] + extras["quality"] + extras["typing"] diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 99eacb71ba..807f265709 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -281,9 +281,11 @@ "inference._generated.types": [ "AudioClassificationInput", "AudioClassificationOutputElement", + "AudioClassificationOutputTransform", "AudioClassificationParameters", "AudioToAudioInput", "AudioToAudioOutputElement", + "AutomaticSpeechRecognitionEarlyStoppingEnum", "AutomaticSpeechRecognitionGenerationParameters", "AutomaticSpeechRecognitionInput", "AutomaticSpeechRecognitionOutput", @@ -328,6 +330,7 @@ "FillMaskParameters", "ImageClassificationInput", "ImageClassificationOutputElement", + "ImageClassificationOutputTransform", "ImageClassificationParameters", "ImageSegmentationInput", "ImageSegmentationOutputElement", @@ -336,6 +339,7 @@ "ImageToImageOutput", "ImageToImageParameters", "ImageToImageTargetSize", + "ImageToTextEarlyStoppingEnum", "ImageToTextGenerationParameters", "ImageToTextInput", "ImageToTextOutput", @@ -361,6 +365,7 @@ "Text2TextGenerationParameters", "TextClassificationInput", "TextClassificationOutputElement", + "TextClassificationOutputTransform", "TextClassificationParameters", "TextGenerationInput", "TextGenerationInputGenerateParameters", @@ -373,6 +378,7 @@ "TextGenerationStreamOutput", "TextGenerationStreamOutputStreamDetails", "TextGenerationStreamOutputToken", + "TextToAudioEarlyStoppingEnum", "TextToAudioGenerationParameters", "TextToAudioInput", "TextToAudioOutput", @@ -381,6 +387,7 @@ "TextToImageOutput", "TextToImageParameters", "TextToImageTargetSize", + "TextToSpeechEarlyStoppingEnum", "TextToSpeechGenerationParameters", "TextToSpeechInput", "TextToSpeechOutput", @@ -394,6 +401,7 @@ "TranslationParameters", "VideoClassificationInput", "VideoClassificationOutputElement", + "VideoClassificationOutputTransform", "VideoClassificationParameters", "VisualQuestionAnsweringInput", "VisualQuestionAnsweringInputData", @@ -796,9 +804,11 @@ def __dir__(): from .inference._generated.types import ( AudioClassificationInput, # noqa: F401 AudioClassificationOutputElement, # noqa: F401 + AudioClassificationOutputTransform, # noqa: F401 AudioClassificationParameters, # noqa: F401 AudioToAudioInput, # noqa: F401 AudioToAudioOutputElement, # noqa: F401 + AutomaticSpeechRecognitionEarlyStoppingEnum, # noqa: F401 AutomaticSpeechRecognitionGenerationParameters, # noqa: F401 AutomaticSpeechRecognitionInput, # noqa: F401 AutomaticSpeechRecognitionOutput, # noqa: F401 @@ -843,6 +853,7 @@ def __dir__(): FillMaskParameters, # noqa: F401 ImageClassificationInput, # noqa: F401 ImageClassificationOutputElement, # noqa: F401 + ImageClassificationOutputTransform, # noqa: F401 ImageClassificationParameters, # noqa: F401 ImageSegmentationInput, # noqa: F401 ImageSegmentationOutputElement, # noqa: F401 @@ -851,6 +862,7 @@ def __dir__(): ImageToImageOutput, # noqa: F401 ImageToImageParameters, # noqa: F401 ImageToImageTargetSize, # noqa: F401 + ImageToTextEarlyStoppingEnum, # noqa: F401 ImageToTextGenerationParameters, # noqa: F401 ImageToTextInput, # noqa: F401 ImageToTextOutput, # noqa: F401 @@ -876,6 +888,7 @@ def __dir__(): Text2TextGenerationParameters, # noqa: F401 TextClassificationInput, # noqa: F401 TextClassificationOutputElement, # noqa: F401 + TextClassificationOutputTransform, # noqa: F401 TextClassificationParameters, # noqa: F401 TextGenerationInput, # noqa: F401 TextGenerationInputGenerateParameters, # noqa: F401 @@ -888,6 +901,7 @@ def __dir__(): TextGenerationStreamOutput, # noqa: F401 TextGenerationStreamOutputStreamDetails, # noqa: F401 TextGenerationStreamOutputToken, # noqa: F401 + TextToAudioEarlyStoppingEnum, # noqa: F401 TextToAudioGenerationParameters, # noqa: F401 TextToAudioInput, # noqa: F401 TextToAudioOutput, # noqa: F401 @@ -896,6 +910,7 @@ def __dir__(): TextToImageOutput, # noqa: F401 TextToImageParameters, # noqa: F401 TextToImageTargetSize, # noqa: F401 + TextToSpeechEarlyStoppingEnum, # noqa: F401 TextToSpeechGenerationParameters, # noqa: F401 TextToSpeechInput, # noqa: F401 TextToSpeechOutput, # noqa: F401 @@ -909,6 +924,7 @@ def __dir__(): TranslationParameters, # noqa: F401 VideoClassificationInput, # noqa: F401 VideoClassificationOutputElement, # noqa: F401 + VideoClassificationOutputTransform, # noqa: F401 VideoClassificationParameters, # noqa: F401 VisualQuestionAnsweringInput, # noqa: F401 VisualQuestionAnsweringInputData, # noqa: F401 diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 63204da6d4..38b37b71e3 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -64,6 +64,7 @@ ) from huggingface_hub.inference._generated.types import ( AudioClassificationOutputElement, + AudioClassificationOutputTransform, AudioToAudioOutputElement, AutomaticSpeechRecognitionOutput, ChatCompletionInputGrammarType, @@ -81,9 +82,12 @@ SummarizationOutput, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, + TextClassificationOutputTransform, TextGenerationInputGrammarType, TextGenerationOutput, TextGenerationStreamOutput, + TextToImageTargetSize, + TextToSpeechEarlyStoppingEnum, TokenClassificationOutputElement, ToolElement, TranslationOutput, @@ -92,6 +96,7 @@ ZeroShotImageClassificationOutputElement, ) from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status +from huggingface_hub.utils._deprecation import _deprecate_arguments if TYPE_CHECKING: @@ -318,6 +323,8 @@ def audio_classification( audio: ContentT, *, model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["AudioClassificationOutputTransform"] = None, ) -> List[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. @@ -330,6 +337,10 @@ def audio_classification( The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for audio classification will be used. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): + The function to apply to the output. Returns: `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -352,7 +363,19 @@ def audio_classification( ] ``` """ - response = self.post(data=audio, model=model, task="audio-classification") + parameters = {"function_to_apply": function_to_apply, "top_k": top_k} + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, send audio as raw data + data = audio + payload: Optional[Dict[str, Any]] = None + else: + # Or some parameters are provided -> send audio as base64 encoded string + data = None + payload = {"inputs": _b64_encode(audio)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, data=data, model=model, task="audio-classification") return AudioClassificationOutputElement.parse_obj_as_list(response) def audio_to_audio( @@ -903,6 +926,14 @@ def document_question_answering( question: str, *, model: Optional[str] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + lang: Optional[str] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + word_boxes: Optional[List[Union[List[float], str]]] = None, ) -> List[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. @@ -916,7 +947,29 @@ def document_question_answering( The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. - + doc_stride (`int`, *optional*): + If the words in the document are too long to fit with the question for the model, it will + be split in several chunks with some overlap. This argument controls the size of that + overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + lang (`str`, *optional*): + Language to use while running OCR. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using doc_stride as + overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Can return less + than top_k answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str]]`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will + skip the OCR step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. @@ -926,15 +979,29 @@ def document_question_answering( `HTTPError`: If the request fails with an HTTP error status code other than HTTP 503. + Example: ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") - [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)] + [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)] ``` """ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + parameters = { + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "lang": lang, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + "word_boxes": word_boxes, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post(json=payload, model=model, task="document-question-answering") return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) @@ -959,7 +1026,7 @@ def feature_extraction( a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. Defaults to None. normalize (`bool`, *optional*): - Whether to normalize the embeddings or not. Defaults to None. + Whether to normalize the embeddings or not. Only available on server powered by Text-Embedding-Inference. prompt_name (`str`, *optional*): The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. @@ -968,7 +1035,7 @@ def feature_extraction( then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the prompt text will be prepended before any text to encode. truncate (`bool`, *optional*): - Whether to truncate the embeddings or not. Defaults to None. + Whether to truncate the embeddings or not. Only available on server powered by Text-Embedding-Inference. truncation_direction (`Literal["Left", "Right"]`, *optional*): Which side of the input should be truncated when `truncate=True` is passed. @@ -994,19 +1061,27 @@ def feature_extraction( ``` """ payload: Dict = {"inputs": text} - if normalize is not None: - payload["normalize"] = normalize - if prompt_name is not None: - payload["prompt_name"] = prompt_name - if truncate is not None: - payload["truncate"] = truncate - if truncation_direction is not None: - payload["truncation_direction"] = truncation_direction + parameters = { + "normalize": normalize, + "prompt_name": prompt_name, + "truncate": truncate, + "truncation_direction": truncation_direction, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post(json=payload, model=model, task="feature-extraction") np = _import_numpy() return np.array(_bytes_to_dict(response), dtype="float32") - def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]: + def fill_mask( + self, + text: str, + *, + model: Optional[str] = None, + targets: Optional[List[str]] = None, + top_k: Optional[int] = None, + ) -> List[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). @@ -1016,8 +1091,13 @@ def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskO model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - Defaults to None. - + targets (`List[str]`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up + in the whole vocabulary. If the provided targets are not in the model vocab, they will be + tokenized and the first resulting token will be used (with a warning, and that might be + slower). + top_k (`int`, *optional*): + When passed, overrides the number of predictions to return. Returns: `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. @@ -1039,7 +1119,12 @@ def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskO ] ``` """ - response = self.post(json={"inputs": text}, model=model, task="fill-mask") + payload: Dict = {"inputs": text} + parameters = {"targets": targets, "top_k": top_k} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, model=model, task="fill-mask") return FillMaskOutputElement.parse_obj_as_list(response) def image_classification( @@ -1047,6 +1132,8 @@ def image_classification( image: ContentT, *, model: Optional[str] = None, + function_to_apply: Optional[Literal["sigmoid", "softmax", "none"]] = None, + top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. @@ -1057,7 +1144,10 @@ def image_classification( model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. - + function_to_apply (`Literal["sigmoid", "softmax", "none"]`, *optional*): + The function to apply to the output scores. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. Returns: `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. @@ -1072,10 +1162,23 @@ def image_classification( >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") - [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...] + [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] ``` """ - response = self.post(data=image, model=model, task="image-classification") + parameters = {"function_to_apply": function_to_apply, "top_k": top_k} + + if all(parameter is None for parameter in parameters.values()): + data = image + payload: Optional[Dict[str, Any]] = None + + else: + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + + response = self.post(json=payload, data=data, model=model, task="image-classification") return ImageClassificationOutputElement.parse_obj_as_list(response) def image_segmentation( @@ -1083,6 +1186,10 @@ def image_segmentation( image: ContentT, *, model: Optional[str] = None, + mask_threshold: Optional[float] = None, + overlap_mask_area_threshold: Optional[float] = None, + subtask: Optional[Literal["instance", "panoptic", "semantic"]] = None, + threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. @@ -1099,7 +1206,14 @@ def image_segmentation( model (`str`, *optional*): The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. - + mask_threshold (`float`, *optional*): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*): + Mask overlap threshold to eliminate small, disconnected segments. + subtask (`Literal["instance", "panoptic", "semantic"]`, *optional*): + Segmentation task to be performed, depending on model capabilities. + threshold (`float`, *optional*): + Probability threshold to filter out predicted masks. Returns: `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. @@ -1113,11 +1227,28 @@ def image_segmentation( ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> client.image_segmentation("cat.jpg"): + >>> client.image_segmentation("cat.jpg") [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] ``` """ - response = self.post(data=image, model=model, task="image-segmentation") + parameters = { + "mask_threshold": mask_threshold, + "overlap_mask_area_threshold": overlap_mask_area_threshold, + "subtask": subtask, + "threshold": threshold, + } + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image + data = image + payload: Optional[Dict[str, Any]] = None + else: + # if parameters are provided, the image needs to be a base64-encoded string + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, data=data, model=model, task="image-segmentation") output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] @@ -1197,7 +1328,7 @@ def image_to_image( data = image payload: Optional[Dict[str, Any]] = None else: - # Or an image + some parameters => use base64 encoding + # if parameters are provided, the image needs to be a base64-encoded string data = None payload = {"inputs": _b64_encode(image)} for key, value in parameters.items(): @@ -1328,10 +1459,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None: return models_by_task def object_detection( - self, - image: ContentT, - *, - model: Optional[str] = None, + self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None ) -> List[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. @@ -1348,7 +1476,8 @@ def object_detection( model (`str`, *optional*): The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. - + threshold (`float`, *optional*): + The probability necessary to make a prediction. Returns: `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. @@ -1368,13 +1497,37 @@ def object_detection( [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ - # detect objects - response = self.post(data=image, model=model, task="object-detection") + parameters = { + "threshold": threshold, + } + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image + data = image + payload: Optional[Dict[str, Any]] = None + else: + # if parameters are provided, the image needs to be a base64-encoded string + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, data=data, model=model, task="object-detection") return ObjectDetectionOutputElement.parse_obj_as_list(response) def question_answering( - self, question: str, context: str, *, model: Optional[str] = None - ) -> QuestionAnsweringOutputElement: + self, + question: str, + context: str, + *, + model: Optional[str] = None, + align_to_words: Optional[bool] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. @@ -1386,10 +1539,31 @@ def question_answering( model (`str`): The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. - + align_to_words (`bool`, *optional*): + Attempts to align the answer to real words. Improves quality on space separated + languages. Might hurt on non-space-separated languages (like Japanese or Chinese). + doc_stride (`int`, *optional*): + If the context is too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using docStride as + overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. Returns: - [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer. - + Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. + When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. @@ -1401,17 +1575,30 @@ def question_answering( >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") - QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara') + QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) ``` """ - + parameters = { + "align_to_words": align_to_words, + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + } payload: Dict[str, Any] = {"question": question, "context": context} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post( json=payload, model=model, task="question-answering", ) - return QuestionAnsweringOutputElement.parse_obj_as_instance(response) + # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. + output = QuestionAnsweringOutputElement.parse_obj(response) + return output def sentence_similarity( self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None @@ -1460,12 +1647,23 @@ def sentence_similarity( ) return _bytes_to_list(response) + @_deprecate_arguments( + version="0.29", + deprecated_args=["parameters"], + custom_message=( + "The `parameters` argument is deprecated and will be removed in a future version. " + "Provide individual parameters instead: `clean_up_tokenization_spaces`, `generate_parameters`, and `truncation`." + ), + ) def summarization( self, text: str, *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. @@ -1478,8 +1676,13 @@ def summarization( for more details. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. - + Inference Endpoint. If not provided, the default recommended model for summarization will be used. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. @@ -1500,11 +1703,25 @@ def summarization( payload: Dict[str, Any] = {"inputs": text} if parameters is not None: payload["parameters"] = parameters + else: + parameters = { + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "generate_parameters": generate_parameters, + "truncation": truncation, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post(json=payload, model=model, task="summarization") return SummarizationOutput.parse_obj_as_list(response)[0] def table_question_answering( - self, table: Dict[str, Any], query: str, *, model: Optional[str] = None + self, + table: Dict[str, Any], + query: str, + *, + model: Optional[str] = None, + parameters: Optional[Dict[str, Any]] = None, ) -> TableQuestionAnsweringOutputElement: """ Retrieve the answer to a question from information given in a table. @@ -1518,6 +1735,8 @@ def table_question_answering( model (`str`): The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. + parameters (`Dict[str, Any]`, *optional*): + Additional inference parameters. Defaults to None. Returns: [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. @@ -1538,11 +1757,15 @@ def table_question_answering( TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ + payload: Dict[str, Any] = { + "query": query, + "table": table, + } + + if parameters is not None: + payload["parameters"] = parameters response = self.post( - json={ - "query": query, - "table": table, - }, + json=payload, model=model, task="table-question-answering", ) @@ -1633,7 +1856,14 @@ def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = No response = self.post(json={"table": table}, model=model, task="tabular-regression") return _bytes_to_list(response) - def text_classification(self, text: str, *, model: Optional[str] = None) -> List[TextClassificationOutputElement]: + def text_classification( + self, + text: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["TextClassificationOutputTransform"] = None, + ) -> List[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. @@ -1644,6 +1874,10 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. Defaults to None. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"TextClassificationOutputTransform"`, *optional*): + The function to apply to the output. Returns: `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. @@ -1665,7 +1899,15 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List ] ``` """ - response = self.post(json={"inputs": text}, model=model, task="text-classification") + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "function_to_apply": function_to_apply, + "top_k": top_k, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, model=model, task="text-classification") return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload @@ -2174,6 +2416,9 @@ def text_to_image( num_inference_steps: Optional[float] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, + scheduler: Optional[str] = None, + target_size: Optional[TextToImageTargetSize] = None, + seed: Optional[int] = None, **kwargs, ) -> "Image": """ @@ -2202,7 +2447,14 @@ def text_to_image( usually at the expense of lower image quality. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. If not provided, the default recommended text-to-image model will be used. + Defaults to None. + scheduler (`str`, *optional*): + Override the scheduler with a compatible one. + target_size (`TextToImageTargetSize`, *optional*): + The size in pixel of the output image + seed (`int`, *optional*): + Seed for the random number generator. Returns: `Image`: The generated image. @@ -2236,6 +2488,9 @@ def text_to_image( "width": width, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, + "scheduler": scheduler, + "target_size": target_size, + "seed": seed, **kwargs, } for key, value in parameters.items(): @@ -2244,7 +2499,28 @@ def text_to_image( response = self.post(json=payload, model=model, task="text-to-image") return _bytes_to_image(response) - def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: + def text_to_speech( + self, + text: str, + *, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, + epsilon_cutoff: Optional[float] = None, + eta_cutoff: Optional[float] = None, + max_length: Optional[int] = None, + max_new_tokens: Optional[int] = None, + min_length: Optional[int] = None, + min_new_tokens: Optional[int] = None, + num_beam_groups: Optional[int] = None, + num_beams: Optional[int] = None, + penalty_alpha: Optional[float] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + use_cache: Optional[bool] = None, + ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. @@ -2253,7 +2529,56 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: The text to synthesize. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. + Defaults to None. + do_sample (`bool`, *optional*): + Whether to use sampling instead of greedy decoding when generating new tokens. + early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"`, *optional*): + Controls the stopping condition for beam-based methods. + epsilon_cutoff (`float`, *optional*): + If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + max_length (`int`, *optional*): + The maximum length (in tokens) of the generated text, including the input. + max_new_tokens (`int`, *optional*): + The maximum number of tokens to generate. Takes precedence over maxLength. + min_length (`int`, *optional*): + The minimum length (in tokens) of the generated text, including the input. + min_new_tokens (`int`, *optional*): + The minimum number of tokens to generate. Takes precedence over maxLength. + num_beam_groups (`int`, *optional*): + Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + num_beams (`int`, *optional*): + Number of beams to use for beam search. + penalty_alpha (`float`, *optional*): + The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + temperature (`float`, *optional*): + The value used to modulate the next token probabilities. + top_k (`int`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*): + If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + typical_p (`float`, *optional*): + Local typicality measures how similar the conditional probability of predicting a target token next is + to the expected conditional probability of predicting a random token next, given the partial text + already generated. If set to float < 1, the smallest set of the most locally typical tokens with + probabilities that add up to typical_p or higher are kept for generation. See [this + paper](https://hf.co/papers/2202.00666) for more details. + use_cache (`bool`, *optional*): + Whether the model should use the past last key/values attentions to speed up decoding Returns: `bytes`: The generated audio. @@ -2274,10 +2599,39 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: >>> Path("hello_world.flac").write_bytes(audio) ``` """ - return self.post(json={"inputs": text}, model=model, task="text-to-speech") + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "do_sample": do_sample, + "early_stopping": early_stopping, + "epsilon_cutoff": epsilon_cutoff, + "eta_cutoff": eta_cutoff, + "max_length": max_length, + "max_new_tokens": max_new_tokens, + "min_length": min_length, + "min_new_tokens": min_new_tokens, + "num_beam_groups": num_beam_groups, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "typical_p": typical_p, + "use_cache": use_cache, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = self.post(json=payload, model=model, task="text-to-speech") + return response def token_classification( - self, text: str, *, model: Optional[str] = None + self, + text: str, + *, + model: Optional[str] = None, + aggregation_strategy: Optional[Literal["none", "simple", "first", "average", "max"]] = None, + ignore_labels: Optional[List[str]] = None, + stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: """ Perform token classification on the given text. @@ -2290,6 +2644,12 @@ def token_classification( The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. + aggregation_strategy (`Literal["none", "simple", "first", "average", "max"]`, *optional*): + The strategy used to fuse tokens based on model predictions. + ignore_labels (`List[str]`, *optional*): + A list of labels to ignore. + stride (`int`, *optional*): + The number of overlapping tokens between chunks when splitting the input text. Returns: `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. @@ -2324,6 +2684,14 @@ def token_classification( ``` """ payload: Dict[str, Any] = {"inputs": text} + parameters = { + "aggregation_strategy": aggregation_strategy, + "ignore_labels": ignore_labels, + "stride": stride, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post( json=payload, model=model, @@ -2332,7 +2700,15 @@ def token_classification( return TokenClassificationOutputElement.parse_obj_as_list(response) def translation( - self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None + self, + text: str, + *, + model: Optional[str] = None, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. @@ -2341,7 +2717,6 @@ def translation( your specific use case. Source and target languages usually depend on the model. However, it is possible to specify source and target languages for certain models. If you are working with one of these models, you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. - You can find this information in the model card. Args: text (`str`): @@ -2351,9 +2726,15 @@ def translation( a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. Defaults to None. src_lang (`str`, *optional*): - Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`. + The source language of the text. Required for models that can translate from multiple languages. tgt_lang (`str`, *optional*): - Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`. + Target language to translate to. Required for models that can translate to multiple languages. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + The truncation strategy to use. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. Returns: [`TranslationOutput`]: The generated translated text. @@ -2388,11 +2769,17 @@ def translation( if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") - - # If both `src_lang` and `tgt_lang` are given, pass them to the request body - payload: Dict = {"inputs": text} - if src_lang and tgt_lang: - payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang} + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "src_lang": src_lang, + "tgt_lang": tgt_lang, + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "truncation": truncation, + "generate_parameters": generate_parameters, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = self.post(json=payload, model=model, task="translation") return TranslationOutput.parse_obj_as_list(response)[0] @@ -2402,6 +2789,7 @@ def visual_question_answering( question: str, *, model: Optional[str] = None, + top_k: Optional[int] = None, ) -> List[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. @@ -2415,7 +2803,10 @@ def visual_question_answering( The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. - + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. @@ -2440,6 +2831,8 @@ def visual_question_answering( ``` """ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + if top_k is not None: + payload.setdefault("parameters", {})["top_k"] = top_k response = self.post(json=payload, model=model, task="visual-question-answering") return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) @@ -2470,7 +2863,7 @@ def zero_shot_classification( The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. Returns: `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -2547,7 +2940,12 @@ def zero_shot_classification( ] def zero_shot_image_classification( - self, image: ContentT, labels: List[str], *, model: Optional[str] = None + self, + image: ContentT, + labels: List[str], + *, + model: Optional[str] = None, + hypothesis_template: Optional[str] = None, ) -> List[ZeroShotImageClassificationOutputElement]: """ Provide input image and text labels to predict text labels for the image. @@ -2559,8 +2957,10 @@ def zero_shot_image_classification( List of string possible labels. There must be at least 2 labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. - + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `labels` to attempt the text classification by replacing the + placeholder with the candidate labels. Returns: `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -2586,8 +2986,13 @@ def zero_shot_image_classification( if len(labels) < 2: raise ValueError("You must specify at least 2 classes to compare.") + payload = { + "inputs": {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}, + } + if hypothesis_template is not None: + payload.setdefault("parameters", {})["hypothesis_template"] = hypothesis_template response = self.post( - json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}}, + json=payload, model=model, task="zero-shot-image-classification", ) diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 83e5070b30..8a1384a671 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -50,6 +50,7 @@ ) from huggingface_hub.inference._generated.types import ( AudioClassificationOutputElement, + AudioClassificationOutputTransform, AudioToAudioOutputElement, AutomaticSpeechRecognitionOutput, ChatCompletionInputGrammarType, @@ -67,9 +68,12 @@ SummarizationOutput, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, + TextClassificationOutputTransform, TextGenerationInputGrammarType, TextGenerationOutput, TextGenerationStreamOutput, + TextToImageTargetSize, + TextToSpeechEarlyStoppingEnum, TokenClassificationOutputElement, ToolElement, TranslationOutput, @@ -78,6 +82,7 @@ ZeroShotImageClassificationOutputElement, ) from huggingface_hub.utils import build_hf_headers +from huggingface_hub.utils._deprecation import _deprecate_arguments from .._common import _async_yield_from, _import_aiohttp @@ -351,6 +356,8 @@ async def audio_classification( audio: ContentT, *, model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["AudioClassificationOutputTransform"] = None, ) -> List[AudioClassificationOutputElement]: """ Perform audio classification on the provided audio content. @@ -363,6 +370,10 @@ async def audio_classification( The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for audio classification will be used. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): + The function to apply to the output. Returns: `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -386,7 +397,19 @@ async def audio_classification( ] ``` """ - response = await self.post(data=audio, model=model, task="audio-classification") + parameters = {"function_to_apply": function_to_apply, "top_k": top_k} + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, send audio as raw data + data = audio + payload: Optional[Dict[str, Any]] = None + else: + # Or some parameters are provided -> send audio as base64 encoded string + data = None + payload = {"inputs": _b64_encode(audio)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, data=data, model=model, task="audio-classification") return AudioClassificationOutputElement.parse_obj_as_list(response) async def audio_to_audio( @@ -945,6 +968,14 @@ async def document_question_answering( question: str, *, model: Optional[str] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + lang: Optional[str] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + word_boxes: Optional[List[Union[List[float], str]]] = None, ) -> List[DocumentQuestionAnsweringOutputElement]: """ Answer questions on document images. @@ -958,7 +989,29 @@ async def document_question_answering( The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. - + doc_stride (`int`, *optional*): + If the words in the document are too long to fit with the question for the model, it will + be split in several chunks with some overlap. This argument controls the size of that + overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + lang (`str`, *optional*): + Language to use while running OCR. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using doc_stride as + overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Can return less + than top_k answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str]]`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will + skip the OCR step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. @@ -968,16 +1021,30 @@ async def document_question_answering( `aiohttp.ClientResponseError`: If the request fails with an HTTP error status code other than HTTP 503. + Example: ```py # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") - [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)] + [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)] ``` """ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + parameters = { + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "lang": lang, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + "word_boxes": word_boxes, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post(json=payload, model=model, task="document-question-answering") return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) @@ -1002,7 +1069,7 @@ async def feature_extraction( a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. Defaults to None. normalize (`bool`, *optional*): - Whether to normalize the embeddings or not. Defaults to None. + Whether to normalize the embeddings or not. Only available on server powered by Text-Embedding-Inference. prompt_name (`str`, *optional*): The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. @@ -1011,7 +1078,7 @@ async def feature_extraction( then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the prompt text will be prepended before any text to encode. truncate (`bool`, *optional*): - Whether to truncate the embeddings or not. Defaults to None. + Whether to truncate the embeddings or not. Only available on server powered by Text-Embedding-Inference. truncation_direction (`Literal["Left", "Right"]`, *optional*): Which side of the input should be truncated when `truncate=True` is passed. @@ -1038,19 +1105,27 @@ async def feature_extraction( ``` """ payload: Dict = {"inputs": text} - if normalize is not None: - payload["normalize"] = normalize - if prompt_name is not None: - payload["prompt_name"] = prompt_name - if truncate is not None: - payload["truncate"] = truncate - if truncation_direction is not None: - payload["truncation_direction"] = truncation_direction + parameters = { + "normalize": normalize, + "prompt_name": prompt_name, + "truncate": truncate, + "truncation_direction": truncation_direction, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post(json=payload, model=model, task="feature-extraction") np = _import_numpy() return np.array(_bytes_to_dict(response), dtype="float32") - async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]: + async def fill_mask( + self, + text: str, + *, + model: Optional[str] = None, + targets: Optional[List[str]] = None, + top_k: Optional[int] = None, + ) -> List[FillMaskOutputElement]: """ Fill in a hole with a missing word (token to be precise). @@ -1060,8 +1135,13 @@ async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[Fil model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - Defaults to None. - + targets (`List[str]`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up + in the whole vocabulary. If the provided targets are not in the model vocab, they will be + tokenized and the first resulting token will be used (with a warning, and that might be + slower). + top_k (`int`, *optional*): + When passed, overrides the number of predictions to return. Returns: `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated probability, token reference, and completed text. @@ -1084,7 +1164,12 @@ async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[Fil ] ``` """ - response = await self.post(json={"inputs": text}, model=model, task="fill-mask") + payload: Dict = {"inputs": text} + parameters = {"targets": targets, "top_k": top_k} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, model=model, task="fill-mask") return FillMaskOutputElement.parse_obj_as_list(response) async def image_classification( @@ -1092,6 +1177,8 @@ async def image_classification( image: ContentT, *, model: Optional[str] = None, + function_to_apply: Optional[Literal["sigmoid", "softmax", "none"]] = None, + top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ Perform image classification on the given image using the specified model. @@ -1102,7 +1189,10 @@ async def image_classification( model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. - + function_to_apply (`Literal["sigmoid", "softmax", "none"]`, *optional*): + The function to apply to the output scores. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. Returns: `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. @@ -1118,10 +1208,23 @@ async def image_classification( >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") - [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...] + [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] ``` """ - response = await self.post(data=image, model=model, task="image-classification") + parameters = {"function_to_apply": function_to_apply, "top_k": top_k} + + if all(parameter is None for parameter in parameters.values()): + data = image + payload: Optional[Dict[str, Any]] = None + + else: + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + + response = await self.post(json=payload, data=data, model=model, task="image-classification") return ImageClassificationOutputElement.parse_obj_as_list(response) async def image_segmentation( @@ -1129,6 +1232,10 @@ async def image_segmentation( image: ContentT, *, model: Optional[str] = None, + mask_threshold: Optional[float] = None, + overlap_mask_area_threshold: Optional[float] = None, + subtask: Optional[Literal["instance", "panoptic", "semantic"]] = None, + threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ Perform image segmentation on the given image using the specified model. @@ -1145,7 +1252,14 @@ async def image_segmentation( model (`str`, *optional*): The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. - + mask_threshold (`float`, *optional*): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*): + Mask overlap threshold to eliminate small, disconnected segments. + subtask (`Literal["instance", "panoptic", "semantic"]`, *optional*): + Segmentation task to be performed, depending on model capabilities. + threshold (`float`, *optional*): + Probability threshold to filter out predicted masks. Returns: `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. @@ -1160,11 +1274,28 @@ async def image_segmentation( # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() - >>> await client.image_segmentation("cat.jpg"): + >>> await client.image_segmentation("cat.jpg") [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] ``` """ - response = await self.post(data=image, model=model, task="image-segmentation") + parameters = { + "mask_threshold": mask_threshold, + "overlap_mask_area_threshold": overlap_mask_area_threshold, + "subtask": subtask, + "threshold": threshold, + } + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image + data = image + payload: Optional[Dict[str, Any]] = None + else: + # if parameters are provided, the image needs to be a base64-encoded string + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, data=data, model=model, task="image-segmentation") output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] @@ -1245,7 +1376,7 @@ async def image_to_image( data = image payload: Optional[Dict[str, Any]] = None else: - # Or an image + some parameters => use base64 encoding + # if parameters are provided, the image needs to be a base64-encoded string data = None payload = {"inputs": _b64_encode(image)} for key, value in parameters.items(): @@ -1383,10 +1514,7 @@ async def _fetch_framework(framework: str) -> None: return models_by_task async def object_detection( - self, - image: ContentT, - *, - model: Optional[str] = None, + self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None ) -> List[ObjectDetectionOutputElement]: """ Perform object detection on the given image using the specified model. @@ -1403,7 +1531,8 @@ async def object_detection( model (`str`, *optional*): The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. - + threshold (`float`, *optional*): + The probability necessary to make a prediction. Returns: `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. @@ -1424,13 +1553,37 @@ async def object_detection( [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ - # detect objects - response = await self.post(data=image, model=model, task="object-detection") + parameters = { + "threshold": threshold, + } + if all(parameter is None for parameter in parameters.values()): + # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image + data = image + payload: Optional[Dict[str, Any]] = None + else: + # if parameters are provided, the image needs to be a base64-encoded string + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, data=data, model=model, task="object-detection") return ObjectDetectionOutputElement.parse_obj_as_list(response) async def question_answering( - self, question: str, context: str, *, model: Optional[str] = None - ) -> QuestionAnsweringOutputElement: + self, + question: str, + context: str, + *, + model: Optional[str] = None, + align_to_words: Optional[bool] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: """ Retrieve the answer to a question from a given text. @@ -1442,10 +1595,31 @@ async def question_answering( model (`str`): The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. - + align_to_words (`bool`, *optional*): + Attempts to align the answer to real words. Improves quality on space separated + languages. Might hurt on non-space-separated languages (like Japanese or Chinese). + doc_stride (`int`, *optional*): + If the context is too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using docStride as + overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. Returns: - [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer. - + Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. + When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. Raises: [`InferenceTimeoutError`]: If the model is unavailable or the request times out. @@ -1458,17 +1632,30 @@ async def question_answering( >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") - QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara') + QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) ``` """ - + parameters = { + "align_to_words": align_to_words, + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + } payload: Dict[str, Any] = {"question": question, "context": context} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post( json=payload, model=model, task="question-answering", ) - return QuestionAnsweringOutputElement.parse_obj_as_instance(response) + # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. + output = QuestionAnsweringOutputElement.parse_obj(response) + return output async def sentence_similarity( self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None @@ -1518,12 +1705,23 @@ async def sentence_similarity( ) return _bytes_to_list(response) + @_deprecate_arguments( + version="0.29", + deprecated_args=["parameters"], + custom_message=( + "The `parameters` argument is deprecated and will be removed in a future version. " + "Provide individual parameters instead: `clean_up_tokenization_spaces`, `generate_parameters`, and `truncation`." + ), + ) async def summarization( self, text: str, *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. @@ -1536,8 +1734,13 @@ async def summarization( for more details. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. - + Inference Endpoint. If not provided, the default recommended model for summarization will be used. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. @@ -1559,11 +1762,25 @@ async def summarization( payload: Dict[str, Any] = {"inputs": text} if parameters is not None: payload["parameters"] = parameters + else: + parameters = { + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "generate_parameters": generate_parameters, + "truncation": truncation, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post(json=payload, model=model, task="summarization") return SummarizationOutput.parse_obj_as_list(response)[0] async def table_question_answering( - self, table: Dict[str, Any], query: str, *, model: Optional[str] = None + self, + table: Dict[str, Any], + query: str, + *, + model: Optional[str] = None, + parameters: Optional[Dict[str, Any]] = None, ) -> TableQuestionAnsweringOutputElement: """ Retrieve the answer to a question from information given in a table. @@ -1577,6 +1794,8 @@ async def table_question_answering( model (`str`): The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. + parameters (`Dict[str, Any]`, *optional*): + Additional inference parameters. Defaults to None. Returns: [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. @@ -1598,11 +1817,15 @@ async def table_question_answering( TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ + payload: Dict[str, Any] = { + "query": query, + "table": table, + } + + if parameters is not None: + payload["parameters"] = parameters response = await self.post( - json={ - "query": query, - "table": table, - }, + json=payload, model=model, task="table-question-answering", ) @@ -1696,7 +1919,12 @@ async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str return _bytes_to_list(response) async def text_classification( - self, text: str, *, model: Optional[str] = None + self, + text: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["TextClassificationOutputTransform"] = None, ) -> List[TextClassificationOutputElement]: """ Perform text classification (e.g. sentiment-analysis) on the given text. @@ -1708,6 +1936,10 @@ async def text_classification( The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. Defaults to None. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"TextClassificationOutputTransform"`, *optional*): + The function to apply to the output. Returns: `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. @@ -1730,7 +1962,15 @@ async def text_classification( ] ``` """ - response = await self.post(json={"inputs": text}, model=model, task="text-classification") + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "function_to_apply": function_to_apply, + "top_k": top_k, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, model=model, task="text-classification") return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload @@ -2240,6 +2480,9 @@ async def text_to_image( num_inference_steps: Optional[float] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, + scheduler: Optional[str] = None, + target_size: Optional[TextToImageTargetSize] = None, + seed: Optional[int] = None, **kwargs, ) -> "Image": """ @@ -2268,7 +2511,14 @@ async def text_to_image( usually at the expense of lower image quality. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. If not provided, the default recommended text-to-image model will be used. + Defaults to None. + scheduler (`str`, *optional*): + Override the scheduler with a compatible one. + target_size (`TextToImageTargetSize`, *optional*): + The size in pixel of the output image + seed (`int`, *optional*): + Seed for the random number generator. Returns: `Image`: The generated image. @@ -2303,6 +2553,9 @@ async def text_to_image( "width": width, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, + "scheduler": scheduler, + "target_size": target_size, + "seed": seed, **kwargs, } for key, value in parameters.items(): @@ -2311,7 +2564,28 @@ async def text_to_image( response = await self.post(json=payload, model=model, task="text-to-image") return _bytes_to_image(response) - async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: + async def text_to_speech( + self, + text: str, + *, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, + epsilon_cutoff: Optional[float] = None, + eta_cutoff: Optional[float] = None, + max_length: Optional[int] = None, + max_new_tokens: Optional[int] = None, + min_length: Optional[int] = None, + min_new_tokens: Optional[int] = None, + num_beam_groups: Optional[int] = None, + num_beams: Optional[int] = None, + penalty_alpha: Optional[float] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + use_cache: Optional[bool] = None, + ) -> bytes: """ Synthesize an audio of a voice pronouncing a given text. @@ -2320,7 +2594,56 @@ async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> byt The text to synthesize. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. + Defaults to None. + do_sample (`bool`, *optional*): + Whether to use sampling instead of greedy decoding when generating new tokens. + early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"`, *optional*): + Controls the stopping condition for beam-based methods. + epsilon_cutoff (`float`, *optional*): + If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + max_length (`int`, *optional*): + The maximum length (in tokens) of the generated text, including the input. + max_new_tokens (`int`, *optional*): + The maximum number of tokens to generate. Takes precedence over maxLength. + min_length (`int`, *optional*): + The minimum length (in tokens) of the generated text, including the input. + min_new_tokens (`int`, *optional*): + The minimum number of tokens to generate. Takes precedence over maxLength. + num_beam_groups (`int`, *optional*): + Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + num_beams (`int`, *optional*): + Number of beams to use for beam search. + penalty_alpha (`float`, *optional*): + The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + temperature (`float`, *optional*): + The value used to modulate the next token probabilities. + top_k (`int`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*): + If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + typical_p (`float`, *optional*): + Local typicality measures how similar the conditional probability of predicting a target token next is + to the expected conditional probability of predicting a random token next, given the partial text + already generated. If set to float < 1, the smallest set of the most locally typical tokens with + probabilities that add up to typical_p or higher are kept for generation. See [this + paper](https://hf.co/papers/2202.00666) for more details. + use_cache (`bool`, *optional*): + Whether the model should use the past last key/values attentions to speed up decoding Returns: `bytes`: The generated audio. @@ -2342,10 +2665,39 @@ async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> byt >>> Path("hello_world.flac").write_bytes(audio) ``` """ - return await self.post(json={"inputs": text}, model=model, task="text-to-speech") + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "do_sample": do_sample, + "early_stopping": early_stopping, + "epsilon_cutoff": epsilon_cutoff, + "eta_cutoff": eta_cutoff, + "max_length": max_length, + "max_new_tokens": max_new_tokens, + "min_length": min_length, + "min_new_tokens": min_new_tokens, + "num_beam_groups": num_beam_groups, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "typical_p": typical_p, + "use_cache": use_cache, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + response = await self.post(json=payload, model=model, task="text-to-speech") + return response async def token_classification( - self, text: str, *, model: Optional[str] = None + self, + text: str, + *, + model: Optional[str] = None, + aggregation_strategy: Optional[Literal["none", "simple", "first", "average", "max"]] = None, + ignore_labels: Optional[List[str]] = None, + stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: """ Perform token classification on the given text. @@ -2358,6 +2710,12 @@ async def token_classification( The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. + aggregation_strategy (`Literal["none", "simple", "first", "average", "max"]`, *optional*): + The strategy used to fuse tokens based on model predictions. + ignore_labels (`List[str]`, *optional*): + A list of labels to ignore. + stride (`int`, *optional*): + The number of overlapping tokens between chunks when splitting the input text. Returns: `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. @@ -2393,6 +2751,14 @@ async def token_classification( ``` """ payload: Dict[str, Any] = {"inputs": text} + parameters = { + "aggregation_strategy": aggregation_strategy, + "ignore_labels": ignore_labels, + "stride": stride, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post( json=payload, model=model, @@ -2401,7 +2767,15 @@ async def token_classification( return TokenClassificationOutputElement.parse_obj_as_list(response) async def translation( - self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None + self, + text: str, + *, + model: Optional[str] = None, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ Convert text from one language to another. @@ -2410,7 +2784,6 @@ async def translation( your specific use case. Source and target languages usually depend on the model. However, it is possible to specify source and target languages for certain models. If you are working with one of these models, you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. - You can find this information in the model card. Args: text (`str`): @@ -2420,9 +2793,15 @@ async def translation( a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. Defaults to None. src_lang (`str`, *optional*): - Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`. + The source language of the text. Required for models that can translate from multiple languages. tgt_lang (`str`, *optional*): - Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`. + Target language to translate to. Required for models that can translate to multiple languages. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + The truncation strategy to use. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. Returns: [`TranslationOutput`]: The generated translated text. @@ -2458,11 +2837,17 @@ async def translation( if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") - - # If both `src_lang` and `tgt_lang` are given, pass them to the request body - payload: Dict = {"inputs": text} - if src_lang and tgt_lang: - payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang} + payload: Dict[str, Any] = {"inputs": text} + parameters = { + "src_lang": src_lang, + "tgt_lang": tgt_lang, + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "truncation": truncation, + "generate_parameters": generate_parameters, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value response = await self.post(json=payload, model=model, task="translation") return TranslationOutput.parse_obj_as_list(response)[0] @@ -2472,6 +2857,7 @@ async def visual_question_answering( question: str, *, model: Optional[str] = None, + top_k: Optional[int] = None, ) -> List[VisualQuestionAnsweringOutputElement]: """ Answering open-ended questions based on an image. @@ -2485,7 +2871,10 @@ async def visual_question_answering( The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. - + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. @@ -2511,6 +2900,8 @@ async def visual_question_answering( ``` """ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + if top_k is not None: + payload.setdefault("parameters", {})["top_k"] = top_k response = await self.post(json=payload, model=model, task="visual-question-answering") return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) @@ -2541,7 +2932,7 @@ async def zero_shot_classification( The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. Returns: `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -2620,7 +3011,12 @@ async def zero_shot_classification( ] async def zero_shot_image_classification( - self, image: ContentT, labels: List[str], *, model: Optional[str] = None + self, + image: ContentT, + labels: List[str], + *, + model: Optional[str] = None, + hypothesis_template: Optional[str] = None, ) -> List[ZeroShotImageClassificationOutputElement]: """ Provide input image and text labels to predict text labels for the image. @@ -2632,8 +3028,10 @@ async def zero_shot_image_classification( List of string possible labels. There must be at least 2 labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed - Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. - + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `labels` to attempt the text classification by replacing the + placeholder with the candidate labels. Returns: `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. @@ -2660,8 +3058,13 @@ async def zero_shot_image_classification( if len(labels) < 2: raise ValueError("You must specify at least 2 classes to compare.") + payload = { + "inputs": {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}, + } + if hypothesis_template is not None: + payload.setdefault("parameters", {})["hypothesis_template"] = hypothesis_template response = await self.post( - json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}}, + json=payload, model=model, task="zero-shot-image-classification", ) diff --git a/src/huggingface_hub/inference/_generated/types/__init__.py b/src/huggingface_hub/inference/_generated/types/__init__.py index 057a491f46..caa46d05fc 100644 --- a/src/huggingface_hub/inference/_generated/types/__init__.py +++ b/src/huggingface_hub/inference/_generated/types/__init__.py @@ -6,10 +6,12 @@ from .audio_classification import ( AudioClassificationInput, AudioClassificationOutputElement, + AudioClassificationOutputTransform, AudioClassificationParameters, ) from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement from .automatic_speech_recognition import ( + AutomaticSpeechRecognitionEarlyStoppingEnum, AutomaticSpeechRecognitionGenerationParameters, AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput, @@ -59,11 +61,18 @@ from .image_classification import ( ImageClassificationInput, ImageClassificationOutputElement, + ImageClassificationOutputTransform, ImageClassificationParameters, ) from .image_segmentation import ImageSegmentationInput, ImageSegmentationOutputElement, ImageSegmentationParameters from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize -from .image_to_text import ImageToTextGenerationParameters, ImageToTextInput, ImageToTextOutput, ImageToTextParameters +from .image_to_text import ( + ImageToTextEarlyStoppingEnum, + ImageToTextGenerationParameters, + ImageToTextInput, + ImageToTextOutput, + ImageToTextParameters, +) from .object_detection import ( ObjectDetectionBoundingBox, ObjectDetectionInput, @@ -84,7 +93,12 @@ TableQuestionAnsweringOutputElement, ) from .text2text_generation import Text2TextGenerationInput, Text2TextGenerationOutput, Text2TextGenerationParameters -from .text_classification import TextClassificationInput, TextClassificationOutputElement, TextClassificationParameters +from .text_classification import ( + TextClassificationInput, + TextClassificationOutputElement, + TextClassificationOutputTransform, + TextClassificationParameters, +) from .text_generation import ( TextGenerationInput, TextGenerationInputGenerateParameters, @@ -98,9 +112,16 @@ TextGenerationStreamOutputStreamDetails, TextGenerationStreamOutputToken, ) -from .text_to_audio import TextToAudioGenerationParameters, TextToAudioInput, TextToAudioOutput, TextToAudioParameters +from .text_to_audio import ( + TextToAudioEarlyStoppingEnum, + TextToAudioGenerationParameters, + TextToAudioInput, + TextToAudioOutput, + TextToAudioParameters, +) from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters, TextToImageTargetSize from .text_to_speech import ( + TextToSpeechEarlyStoppingEnum, TextToSpeechGenerationParameters, TextToSpeechInput, TextToSpeechOutput, @@ -115,6 +136,7 @@ from .video_classification import ( VideoClassificationInput, VideoClassificationOutputElement, + VideoClassificationOutputTransform, VideoClassificationParameters, ) from .visual_question_answering import ( diff --git a/src/huggingface_hub/inference/_generated/types/audio_classification.py b/src/huggingface_hub/inference/_generated/types/audio_classification.py index f828c980cb..f02447e3a2 100644 --- a/src/huggingface_hub/inference/_generated/types/audio_classification.py +++ b/src/huggingface_hub/inference/_generated/types/audio_classification.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] +AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass @@ -18,7 +18,8 @@ class AudioClassificationParameters(BaseInferenceType): Additional inference parameters for Audio Classification """ - function_to_apply: Optional["ClassificationOutputTransform"] = None + function_to_apply: Optional["AudioClassificationOutputTransform"] = None + """The function to apply to the output.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" diff --git a/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py b/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py index 29323bf2a9..cfd35cfcb0 100644 --- a/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +++ b/src/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -EarlyStoppingEnum = Literal["never"] +AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"] @dataclass @@ -20,7 +20,7 @@ class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType): do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" - early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability @@ -40,11 +40,11 @@ class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType): max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None - """The maximum number of tokens to generate. Takes precedence over maxLength.""" + """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None - """The minimum number of tokens to generate. Takes precedence over maxLength.""" + """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. diff --git a/src/huggingface_hub/inference/_generated/types/image_classification.py b/src/huggingface_hub/inference/_generated/types/image_classification.py index 91b24d2c0b..3f47bb0acd 100644 --- a/src/huggingface_hub/inference/_generated/types/image_classification.py +++ b/src/huggingface_hub/inference/_generated/types/image_classification.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] +ImageClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass @@ -18,7 +18,8 @@ class ImageClassificationParameters(BaseInferenceType): Additional inference parameters for Image Classification """ - function_to_apply: Optional["ClassificationOutputTransform"] = None + function_to_apply: Optional["ImageClassificationOutputTransform"] = None + """The function to apply to the output.""" top_k: Optional[int] = None """When specified, limits the output to the top K most probable classes.""" diff --git a/src/huggingface_hub/inference/_generated/types/image_to_text.py b/src/huggingface_hub/inference/_generated/types/image_to_text.py index 0ebb9a9bc6..0af33e89d5 100644 --- a/src/huggingface_hub/inference/_generated/types/image_to_text.py +++ b/src/huggingface_hub/inference/_generated/types/image_to_text.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -EarlyStoppingEnum = Literal["never"] +ImageToTextEarlyStoppingEnum = Literal["never"] @dataclass @@ -20,7 +20,7 @@ class ImageToTextGenerationParameters(BaseInferenceType): do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" - early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability @@ -40,11 +40,11 @@ class ImageToTextGenerationParameters(BaseInferenceType): max_length: Optional[int] = None """The maximum length (in tokens) of the generated text, including the input.""" max_new_tokens: Optional[int] = None - """The maximum number of tokens to generate. Takes precedence over maxLength.""" + """The maximum number of tokens to generate. Takes precedence over max_length.""" min_length: Optional[int] = None """The minimum length (in tokens) of the generated text, including the input.""" min_new_tokens: Optional[int] = None - """The minimum number of tokens to generate. Takes precedence over maxLength.""" + """The minimum number of tokens to generate. Takes precedence over min_length.""" num_beam_groups: Optional[int] = None """Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. diff --git a/src/huggingface_hub/inference/_generated/types/text_classification.py b/src/huggingface_hub/inference/_generated/types/text_classification.py index bf61a4eebc..830fd6bbd1 100644 --- a/src/huggingface_hub/inference/_generated/types/text_classification.py +++ b/src/huggingface_hub/inference/_generated/types/text_classification.py @@ -9,18 +9,23 @@ from .base import BaseInferenceType -ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] +TextClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass class TextClassificationParameters(BaseInferenceType): - """Additional inference parameters - Additional inference parameters for Text Classification + """ + Additional inference parameters for Text Classification. """ - function_to_apply: Optional["ClassificationOutputTransform"] = None + function_to_apply: Optional["TextClassificationOutputTransform"] = None + """ + The function to apply to the output. + """ top_k: Optional[int] = None - """When specified, limits the output to the top K most probable classes.""" + """ + When specified, limits the output to the top K most probable classes. + """ @dataclass diff --git a/src/huggingface_hub/inference/_generated/types/text_to_audio.py b/src/huggingface_hub/inference/_generated/types/text_to_audio.py index dd8369de4b..e9a26d0431 100644 --- a/src/huggingface_hub/inference/_generated/types/text_to_audio.py +++ b/src/huggingface_hub/inference/_generated/types/text_to_audio.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -EarlyStoppingEnum = Literal["never"] +TextToAudioEarlyStoppingEnum = Literal["never"] @dataclass @@ -20,7 +20,7 @@ class TextToAudioGenerationParameters(BaseInferenceType): do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" - early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + early_stopping: Optional[Union[bool, "TextToAudioEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability diff --git a/src/huggingface_hub/inference/_generated/types/text_to_speech.py b/src/huggingface_hub/inference/_generated/types/text_to_speech.py index 30e0b1d7d8..fa96e885ee 100644 --- a/src/huggingface_hub/inference/_generated/types/text_to_speech.py +++ b/src/huggingface_hub/inference/_generated/types/text_to_speech.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -EarlyStoppingEnum = Literal["never"] +TextToSpeechEarlyStoppingEnum = Literal["never"] @dataclass @@ -20,7 +20,7 @@ class TextToSpeechGenerationParameters(BaseInferenceType): do_sample: Optional[bool] = None """Whether to use sampling instead of greedy decoding when generating new tokens.""" - early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None """Controls the stopping condition for beam-based methods.""" epsilon_cutoff: Optional[float] = None """If set to float strictly between 0 and 1, only tokens with a conditional probability diff --git a/src/huggingface_hub/inference/_generated/types/video_classification.py b/src/huggingface_hub/inference/_generated/types/video_classification.py index 0c5a9d55a8..a32249dc12 100644 --- a/src/huggingface_hub/inference/_generated/types/video_classification.py +++ b/src/huggingface_hub/inference/_generated/types/video_classification.py @@ -9,7 +9,7 @@ from .base import BaseInferenceType -ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] +VideoClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] @dataclass @@ -20,7 +20,7 @@ class VideoClassificationParameters(BaseInferenceType): frame_sampling_rate: Optional[int] = None """The sampling rate used to select frames from the video.""" - function_to_apply: Optional["ClassificationOutputTransform"] = None + function_to_apply: Optional["VideoClassificationOutputTransform"] = None num_frames: Optional[int] = None """The number of sampled frames to consider for classification.""" top_k: Optional[int] = None diff --git a/tests/test_inference_async_client.py b/tests/test_inference_async_client.py index f77409f88c..418ea95e48 100644 --- a/tests/test_inference_async_client.py +++ b/tests/test_inference_async_client.py @@ -275,7 +275,9 @@ def test_sync_vs_async_signatures() -> None: # Check that the async method is async async_method = getattr(async_client, name) - assert inspect.iscoroutinefunction(async_method) + # Since some methods are decorated with @_deprecate_arguments, we need to unwrap the async method to get the actual coroutine function + # TODO: Remove this once the @_deprecate_arguments decorator is removed from the AsyncInferenceClient methods. + assert inspect.iscoroutinefunction(inspect.unwrap(async_method)) # Check that expected inputs and outputs are the same sync_sig = inspect.signature(sync_method) diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 6d6fe56b5a..a4fa971a2b 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -49,10 +49,7 @@ from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS from huggingface_hub.errors import HfHubHTTPError, ValidationError from huggingface_hub.inference._client import _open_as_binary -from huggingface_hub.inference._common import ( - _stream_chat_completion_response, - _stream_text_generation_response, -) +from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response from huggingface_hub.utils import build_hf_headers from .testing_utils import with_production_testing @@ -194,7 +191,7 @@ class InferenceClientVCRTest(InferenceClientTest): Tips when adding new tasks: - Most of the time, we only test that the return values are correct. We don't always test the actual output of the model. - - In the CI, VRC replay is always on. If you want to test locally against the server, you can use the `--vcr-mode` + - In the CI, VRC replay is always on. If you want to test locally against the server, you can use the `--vcr-record` and `--disable-vcr` command line options. See https://pytest-vcr.readthedocs.io/en/latest/configuration/. - If you get rate-limited locally, you can use your own token when initializing InferenceClient. /!\\ WARNING: if you do so, you must delete the token from the cassette before committing! diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index 17818d1858..832049ad5d 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -15,13 +15,11 @@ """Contains a tool to generate `src/huggingface_hub/inference/_generated/_async_client.py`.""" import argparse -import os import re -import tempfile from pathlib import Path from typing import NoReturn -from ruff.__main__ import find_ruff_bin +from helpers import format_source_code ASYNC_CLIENT_FILE_PATH = ( @@ -77,17 +75,6 @@ def generate_async_client_code(code: str) -> str: return code -def format_source_code(code: str) -> str: - """Apply formatter on a generated source code.""" - with tempfile.TemporaryDirectory() as tmpdir: - filepath = Path(tmpdir) / "async_client.py" - filepath.write_text(code) - ruff_bin = find_ruff_bin() - os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "check", str(filepath), "--fix", "--quiet"]) - os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "format", str(filepath), "--quiet"]) - return filepath.read_text() - - def check_async_client(update: bool) -> NoReturn: """Check AsyncInferenceClient is correctly defined and consistent with InferenceClient. diff --git a/utils/generate_inference_types.py b/utils/generate_inference_types.py index fe5223426f..23ab7c7b9c 100644 --- a/utils/generate_inference_types.py +++ b/utils/generate_inference_types.py @@ -15,14 +15,11 @@ """Contains a tool to generate `src/huggingface_hub/inference/_generated/types`.""" import argparse -import os import re -import tempfile from pathlib import Path from typing import Dict, List, Literal, NoReturn -from helpers import check_and_update_file_content -from ruff.__main__ import find_ruff_bin +from helpers import check_and_update_file_content, format_source_code huggingface_hub_folder_path = Path(__file__).parents[1] / "src" / "huggingface_hub" @@ -56,6 +53,16 @@ re.VERBOSE | re.MULTILINE, ) +TYPE_ALIAS_REGEX = re.compile( + r""" + ^(?!\s) # to make sure the line does not start with whitespace (top-level) + (\w+) + \s*=\s* + (.+) + $ + """, + re.VERBOSE | re.MULTILINE, +) OPTIONAL_FIELD_REGEX = re.compile(r": Optional\[(.+)\]$", re.MULTILINE) @@ -78,6 +85,7 @@ re.MULTILINE | re.VERBOSE | re.DOTALL, ) + # List of classes that are shared across multiple modules # This is used to fix the naming of the classes (to make them unique by task) SHARED_CLASSES = [ @@ -86,6 +94,7 @@ "ClassificationOutput", "GenerationParameters", "TargetSize", + "EarlyStoppingEnum", ] REFERENCE_PACKAGE_EN_CONTENT = """ @@ -130,6 +139,27 @@ """ +def _replace_class_name(content: str, cls: str, new_cls: str) -> str: + """ + Replace the class name `cls` with the new class name `new_cls` in the content. + """ + pattern = rf""" + (? str: content = content.replace( "\nfrom dataclasses import", "\nfrom .base import BaseInferenceType\nfrom dataclasses import" @@ -144,8 +174,9 @@ def _delete_empty_lines(content: str) -> str: def _fix_naming_for_shared_classes(content: str, module_name: str) -> str: for cls in SHARED_CLASSES: - cls_definition = f"\nclass {cls}" - + # No need to fix the naming of a shared class if it's not used in the module + if cls not in content: + continue # Update class definition # Very hacky way to build "AudioClassificationOutputElement" instead of "ClassificationOutput" new_cls = "".join(part.capitalize() for part in module_name.split("_")) @@ -157,18 +188,8 @@ def _fix_naming_for_shared_classes(content: str, module_name: str) -> str: if new_cls.endswith("ClassificationOutput"): # to get "AudioClassificationOutputElement" new_cls += "Element" - new_cls_definition = "\nclass " + new_cls - content = content.replace(cls_definition, new_cls_definition) - - # Update regular class usage - regular_cls = f": {cls}\n" - new_regular_cls = f": {new_cls}\n" - content = content.replace(regular_cls, new_regular_cls) - - # Update optional class usage - optional_cls = f"Optional[{cls}]" - new_optional_cls = f"Optional[{new_cls}]" - content = content.replace(optional_cls, new_optional_cls) + content = _replace_class_name(content, cls, new_cls) + return content @@ -201,6 +222,15 @@ def _list_dataclasses(content: str) -> List[str]: return INHERITED_DATACLASS_REGEX.findall(content) +def _list_shared_aliases(content: str) -> List[str]: + """List all shared class aliases defined in the module.""" + all_aliases = TYPE_ALIAS_REGEX.findall(content) + shared_class_pattern = r"(\w+(?:" + "|".join(re.escape(cls) for cls in SHARED_CLASSES) + r"))$" + shared_class_regex = re.compile(shared_class_pattern) + aliases = [alias_class for alias_class, _ in all_aliases if shared_class_regex.search(alias_class)] + return aliases + + def fix_inference_classes(content: str, module_name: str) -> str: content = _inherit_from_base(content) content = _delete_empty_lines(content) @@ -227,17 +257,6 @@ def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]] return MAIN_INIT_PY_REGEX.sub(f'"inference._generated.types": [{dataclasses_str}]', content) -def format_source_code(code: str) -> str: - """Apply formatter on the generated source code.""" - with tempfile.TemporaryDirectory() as tmpdir: - filepath = Path(tmpdir) / "tmp.py" - filepath.write_text(code) - ruff_bin = find_ruff_bin() - os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "check", str(filepath), "--fix", "--quiet"]) - os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "format", str(filepath), "--quiet"]) - return filepath.read_text() - - def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str: """Generate the reference package content.""" @@ -259,11 +278,12 @@ def generate_reference_package(dataclasses: Dict[str, List[str]], language: Lite def check_inference_types(update: bool) -> NoReturn: - """Check AsyncInferenceClient is correctly defined and consistent with InferenceClient. + """Check and update inference types. This script is used in the `make style` and `make quality` checks. """ dataclasses = {} + aliases = {} for file in INFERENCE_TYPES_FOLDER_PATH.glob("*.py"): if file.name in IGNORE_FILES: continue @@ -272,21 +292,20 @@ def check_inference_types(update: bool) -> NoReturn: fixed_content = fix_inference_classes(content, module_name=file.stem) formatted_content = format_source_code(fixed_content) - dataclasses[file.stem] = _list_dataclasses(formatted_content) - + aliases[file.stem] = _list_shared_aliases(formatted_content) check_and_update_file_content(file, formatted_content, update) - init_py_content = create_init_py(dataclasses) + all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses.keys()} + init_py_content = create_init_py(all_classes) init_py_content = format_source_code(init_py_content) init_py_file = INFERENCE_TYPES_FOLDER_PATH / "__init__.py" check_and_update_file_content(init_py_file, init_py_content, update) main_init_py_content = MAIN_INIT_PY_FILE.read_text() - updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, dataclasses) + updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, all_classes) updated_main_init_py_content = format_source_code(updated_main_init_py_content) check_and_update_file_content(MAIN_INIT_PY_FILE, updated_main_init_py_content, update) - reference_package_content_en = generate_reference_package(dataclasses, "en") check_and_update_file_content(REFERENCE_PACKAGE_EN_PATH, reference_package_content_en, update) diff --git a/utils/generate_task_parameters.py b/utils/generate_task_parameters.py new file mode 100644 index 0000000000..fa80c2f7a0 --- /dev/null +++ b/utils/generate_task_parameters.py @@ -0,0 +1,548 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility script to check and update the InferenceClient task methods arguments and docstrings +based on the tasks input parameters. + +What this script does: +- [x] detect missing parameters in method signature +- [x] add missing parameters to methods signature +- [ ] detect outdated parameters in method signature +- [ ] update outdated parameters in method signature + +- [x] detect missing parameters in method docstrings +- [x] add missing parameters to methods docstrings +- [ ] detect outdated parameters in method docstrings +- [ ] update outdated parameters in method docstrings + +- [ ] detect when parameter not used in method implementation +- [ ] update method implementation when parameter not used +Related resources: +- https://github.com/huggingface/huggingface_hub/issues/2063 +- https://github.com/huggingface/huggingface_hub/issues/2557 +- https://github.com/huggingface/huggingface_hub/pull/2561 +""" + +import argparse +import builtins +import inspect +import re +import textwrap +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, NoReturn, Optional, Set + +import libcst as cst +from helpers import format_source_code +from libcst.codemod import CodemodContext +from libcst.codemod.visitors import GatherImportsVisitor + +from huggingface_hub.inference._client import InferenceClient + + +# Paths to project files +BASE_DIR = Path(__file__).parents[1] / "src" / "huggingface_hub" +INFERENCE_TYPES_PATH = BASE_DIR / "inference" / "_generated" / "types" +INFERENCE_CLIENT_FILE = BASE_DIR / "inference" / "_client.py" + +DEFAULT_MODULE = "huggingface_hub.inference._generated.types" + + +# Temporary solution to skip tasks where there is no Parameters dataclass or the schema needs to be updated +TASKS_TO_SKIP = [ + "chat_completion", + "depth_estimation", + "audio_to_audio", + "feature_extraction", + "sentence_similarity", + "table_question_answering", + "automatic_speech_recognition", + "image_to_text", + "image_to_image", +] + +PARAMETERS_DATACLASS_REGEX = re.compile( + r""" + ^@dataclass + \nclass\s(\w+Parameters)\(BaseInferenceType\): + """, + re.VERBOSE | re.MULTILINE, +) + +#### NODE VISITORS + + +class DataclassFieldCollector(cst.CSTVisitor): + """A visitor that collects fields (parameters) from a dataclass.""" + + def __init__(self, dataclass_name: str): + self.dataclass_name = dataclass_name + self.parameters: Dict[str, Dict[str, str]] = {} + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + """Visit class definitions to find the target dataclass.""" + + if node.name.value == self.dataclass_name: + body_statements = node.body.body + for index, field in enumerate(body_statements): + # Check if the statement is a simple statement (like a variable declaration) + if isinstance(field, cst.SimpleStatementLine): + for stmt in field.body: + # Check if it's an annotated assignment (typical for dataclass fields) + if isinstance(stmt, cst.AnnAssign) and isinstance(stmt.target, cst.Name): + param_name = stmt.target.value + param_type = cst.Module([]).code_for_node(stmt.annotation.annotation) + docstring = self._extract_docstring(body_statements, index) + self.parameters[param_name] = { + "type": param_type, + "docstring": docstring, + } + + @staticmethod + def _extract_docstring(body_statements: List[cst.CSTNode], field_index: int) -> str: + """Extract the docstring following a field definition.""" + if field_index + 1 < len(body_statements): + # Check if the next statement is a simple statement (like a string) + next_stmt = body_statements[field_index + 1] + if isinstance(next_stmt, cst.SimpleStatementLine): + for stmt in next_stmt.body: + # Check if the statement is a string expression (potential docstring) + if isinstance(stmt, cst.Expr) and isinstance(stmt.value, cst.SimpleString): + return stmt.value.evaluated_value.strip() + # No docstring found or there's no statement after the field + return "" + + +class ModulesCollector(cst.CSTVisitor): + """Visitor that maps type names to their defining modules.""" + + def __init__(self): + self.type_to_module = {} + + def visit_ClassDef(self, node: cst.ClassDef): + """Map class definitions to the current module.""" + self.type_to_module[node.name.value] = DEFAULT_MODULE + + def visit_ImportFrom(self, node: cst.ImportFrom): + """Map imported types to their modules.""" + if node.module: + module_name = node.module.value + for alias in node.names: + self.type_to_module[alias.name.value] = module_name + + +class ArgumentsCollector(cst.CSTVisitor): + """Collects existing argument names from a method.""" + + def __init__(self, method_name: str): + self.method_name = method_name + self.existing_args: Set[str] = set() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value == self.method_name: + self.existing_args.update( + param.name.value + for param in node.params.params + node.params.kwonly_params + if param.name.value != "self" + ) + + +#### TREE TRANSFORMERS + + +class AddParameters(cst.CSTTransformer): + """Updates a method by adding missing parameters and updating the docstring.""" + + def __init__(self, method_name: str, missing_params: Dict[str, Dict[str, str]]): + self.method_name = method_name + self.missing_params = missing_params + self.found_method = False + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + if original_node.name.value == self.method_name: + self.found_method = True + new_params = self._update_parameters(updated_node.params) + updated_body = self._update_docstring(updated_node.body) + return updated_node.with_changes(params=new_params, body=updated_body) + return updated_node + + def _update_parameters(self, params: cst.Parameters) -> cst.Parameters: + new_kwonly_params = list(params.kwonly_params) + existing_args = {param.name.value for param in params.params + params.kwonly_params} + + for param_name, param_info in self.missing_params.items(): + if param_name not in existing_args: + annotation = cst.Annotation(annotation=cst.parse_expression(param_info["type"])) + new_param = cst.Param( + name=cst.Name(param_name), + annotation=annotation, + default=cst.Name("None"), + ) + new_kwonly_params.append(new_param) + + return params.with_changes(kwonly_params=new_kwonly_params) + + def _update_docstring(self, body: cst.IndentedBlock) -> cst.IndentedBlock: + if not isinstance(body.body[0], cst.SimpleStatementLine) or not isinstance(body.body[0].body[0], cst.Expr): + return body + + docstring_expr = body.body[0].body[0] + if not isinstance(docstring_expr.value, cst.SimpleString): + return body + + docstring = docstring_expr.value.evaluated_value + updated_docstring = self._update_docstring_content(docstring) + new_docstring = cst.SimpleString(f'"""{updated_docstring}"""') + new_body = [body.body[0].with_changes(body=[docstring_expr.with_changes(value=new_docstring)])] + list( + body.body[1:] + ) + return body.with_changes(body=new_body) + + def _update_docstring_content(self, docstring: str) -> str: + docstring_lines = docstring.split("\n") + + # Step 1: find the right insertion index + args_index = next((i for i, line in enumerate(docstring_lines) if line.strip().lower() == "args:"), None) + # If there is no "Args:" section, insert it after the first section that is not empty and not a sub-section + if args_index is None: + insertion_index = next( + ( + i + for i, line in enumerate(docstring_lines) + if line.strip().lower() in ("returns:", "raises:", "examples:", "example:") + ), + len(docstring_lines), + ) + docstring_lines.insert(insertion_index, "Args:") + args_index = insertion_index + insertion_index += 1 + else: + # Find the next section (in this order: Returns, Raises, Example(s)) + next_section_index = next( + ( + i + for i, line in enumerate(docstring_lines) + if line.strip().lower() in ("returns:", "raises:", "example:", "examples:") + ), + None, + ) + if next_section_index is not None: + # If there's a blank line before "Returns:", insert before that blank line + if next_section_index > 0 and docstring_lines[next_section_index - 1].strip() == "": + insertion_index = next_section_index - 1 + else: + # If there's no blank line, insert at the "Returns:" line and add a blank line after insertion + insertion_index = next_section_index + docstring_lines.insert(insertion_index, "") + else: + # If there's no next section, insert at the end + insertion_index = len(docstring_lines) + + # Step 2: format the parameter docstring + # Calculate the base indentation + base_indentation = docstring_lines[args_index][ + : len(docstring_lines[args_index]) - len(docstring_lines[args_index].lstrip()) + ] + param_indentation = base_indentation + " " # Indent parameters under "Args:" + description_indentation = param_indentation + " " # Indent descriptions under parameter names + + param_docs = [] + for param_name, param_info in self.missing_params.items(): + param_type_str = param_info["type"].replace("Optional[", "").rstrip("]") + optional_str = "*optional*" if "Optional[" in param_info["type"] else "" + param_docstring = (param_info.get("docstring") or "").strip() + + # Clean up the docstring to remove extra spaces + param_docstring = " ".join(param_docstring.split()) + + # Prepare the parameter line + param_line = f"{param_indentation}{param_name} (`{param_type_str}`, {optional_str}):" + + # Wrap the parameter docstring + wrapped_description = textwrap.fill( + param_docstring, + width=119, + initial_indent=description_indentation, + subsequent_indent=description_indentation, + ) + + # Combine parameter line and description + if param_docstring: + param_doc = f"{param_line}\n{wrapped_description}" + else: + param_doc = param_line + + param_docs.append(param_doc) + + # Step 3: insert the new parameter docs into the docstring + docstring_lines[insertion_index:insertion_index] = param_docs + return "\n".join(docstring_lines) + + +class AddImports(cst.CSTTransformer): + """Transformer that adds import statements to the module.""" + + def __init__(self, imports_to_add: List[cst.BaseStatement]): + self.imports_to_add = imports_to_add + self.added = False + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + """Insert the import statements into the module.""" + # If imports were already added, don't add them again + if self.added: + return updated_node + insertion_index = 0 + # Find the index where to insert the imports: make sure the imports are inserted before any code and after all imports (not necessary, we can remove/simplify this part) + for idx, stmt in enumerate(updated_node.body): + if not isinstance(stmt, cst.SimpleStatementLine): + insertion_index = idx + break + elif not isinstance(stmt.body[0], (cst.Import, cst.ImportFrom)): + insertion_index = idx + break + # Insert the imports + new_body = ( + list(updated_node.body[:insertion_index]) + + list(self.imports_to_add) + + list(updated_node.body[insertion_index:]) + ) + self.added = True + return updated_node.with_changes(body=new_body) + + +#### UTILS + + +def check_missing_parameters( + inference_client_module: cst.Module, + parameters_module: cst.Module, + method_name: str, + parameter_type_name: str, +) -> Dict[str, Dict[str, str]]: + # Get parameters from the parameters module + params_collector = DataclassFieldCollector(parameter_type_name) + parameters_module.visit(params_collector) + parameters = params_collector.parameters + + # Get existing arguments from the method + method_argument_collector = ArgumentsCollector(method_name) + inference_client_module.visit(method_argument_collector) + existing_args = method_argument_collector.existing_args + missing_params = {k: v for k, v in parameters.items() if k not in existing_args} + return missing_params + + +def get_imports_to_add( + parameters: Dict[str, Dict[str, str]], + parameters_module: cst.Module, + inference_client_module: cst.Module, +) -> Dict[str, List[str]]: + """ + Get the needed imports for missing parameters. + + Args: + parameters (Dict[str, Dict[str, str]]): Dictionary of parameters with their type and docstring. + eg: {"function_to_apply": {"type": "ClassificationOutputTransform", "docstring": "Function to apply to the input."}} + parameters_module (cst.Module): The module where the parameters are defined. + inference_client_module (cst.Module): The module of the inference client. + + Returns: + Dict[str, List[str]]: A dictionary mapping modules to list of types to import. + eg: {"huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} + """ + # Collect all type names from parameter annotations + types_to_import = set() + for param_info in parameters.values(): + types_to_import.update(_collect_type_hints_from_annotation(param_info["type"])) + + # Gather existing imports in the inference client module + context = CodemodContext() + gather_visitor = GatherImportsVisitor(context) + inference_client_module.visit(gather_visitor) + + # Map types to their defining modules in the parameters module + module_collector = ModulesCollector() + parameters_module.visit(module_collector) + + # Determine which imports are needed + needed_imports = {} + for type_name in types_to_import: + types_to_modules = module_collector.type_to_module + module = types_to_modules.get(type_name, DEFAULT_MODULE) + # Maybe no need to check that since the code formatter will handle duplicate imports? + if module not in gather_visitor.object_mapping or type_name not in gather_visitor.object_mapping[module]: + needed_imports.setdefault(module, []).append(type_name) + return needed_imports + + +def _generate_import_statements(import_dict: Dict[str, List[str]]) -> str: + """ + Generate import statements from a dictionary of needed imports. + + Args: + import_dict (Dict[str, List[str]]): Dictionary mapping modules to list of types to import. + eg: {"typing": ["List", "Dict"], "huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} + + Returns: + str: The import statements as a string. + """ + import_statements = [] + for module, imports in import_dict.items(): + if imports: + import_list = ", ".join(imports) + import_statements.append(f"from {module} import {import_list}") + else: + import_statements.append(f"import {module}") + return "\n".join(import_statements) + + +# TODO: Needs to be improved, maybe using `typing.get_type_hints` instead (we gonna need to access the method though)? +def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: + """ + Collect type hints from an annotation string. + + Args: + annotation_str (str): The annotation string. + + Returns: + Set[str]: A set of type hints. + """ + type_string = annotation_str.replace(" ", "") + builtin_types = {d for d in dir(builtins) if isinstance(getattr(builtins, d), type)} + types = re.findall(r"\w+|'[^']+'|\"[^\"]+\"", type_string) + extracted_types = {t.strip("\"'") for t in types if t.strip("\"'") not in builtin_types} + return extracted_types + + +def _get_parameter_type_name(method_name: str) -> Optional[str]: + file_path = INFERENCE_TYPES_PATH / f"{method_name}.py" + if not file_path.is_file(): + print(f"File not found: {file_path}") + return None + + content = file_path.read_text(encoding="utf-8") + match = PARAMETERS_DATACLASS_REGEX.search(content) + + return match.group(1) if match else None + + +def _parse_module_from_file(filepath: Path) -> Optional[cst.Module]: + try: + code = filepath.read_text(encoding="utf-8") + return cst.parse_module(code) + except FileNotFoundError: + print(f"File not found: {filepath}") + except cst.ParserSyntaxError as e: + print(f"Syntax error while parsing {filepath}: {e}") + return None + + +def _check_parameters(method_params: Dict[str, str], update: bool) -> NoReturn: + """ + Check if task methods have missing parameters and update the InferenceClient source code if needed. + + Args: + method_params (Dict[str, str]): Dictionary mapping method names to their parameters dataclass names. + update (bool): Whether to update the InferenceClient source code if missing parameters are found. + """ + merged_imports = defaultdict(set) + logs = [] + inference_client_filename = INFERENCE_CLIENT_FILE + # Read and parse the inference client module + inference_client_module = _parse_module_from_file(inference_client_filename) + modified_module = inference_client_module + has_changes = False + for method_name, parameter_type_name in method_params.items(): + parameters_filename = INFERENCE_TYPES_PATH / f"{method_name}.py" + + # Read and parse the parameters module + parameters_module = _parse_module_from_file(parameters_filename) + + # Check if the method has missing parameters + missing_params = check_missing_parameters(modified_module, parameters_module, method_name, parameter_type_name) + if not missing_params: + continue + if update: + ## Get missing imports to add + needed_imports = get_imports_to_add(missing_params, parameters_module, modified_module) + for module, imports_to_add in needed_imports.items(): + merged_imports[module].update(imports_to_add) + # Update method parameters and docstring + modified_module = modified_module.visit(AddParameters(method_name, missing_params)) + has_changes = True + else: + logs.append(f"❌ Missing parameters found in `{method_name}`.") + + if has_changes: + if merged_imports: + import_statements = _generate_import_statements(merged_imports) + imports_to_add = cst.parse_module(import_statements).body + # Update inference client module with the missing imports + modified_module = modified_module.visit(AddImports(imports_to_add)) + # Format the updated source code + formatted_source_code = format_source_code(modified_module.code) + INFERENCE_CLIENT_FILE.write_text(formatted_source_code) + + if len(logs) > 0: + for log in logs: + print(log) + print( + "❌ Mismatch between between parameters defined in tasks methods signature in " + "`./src/huggingface_hub/inference/_client.py` and parameters defined in " + "`./src/huggingface_hub/inference/_generated/types.py \n" + "Please run `make inference_update` or `python utils/generate_task_parameters.py --update" + ) + exit(1) + else: + if update: + print( + "✅ InferenceClient source code has been updated in" + " `./src/huggingface_hub/inference/_client.py`.\n Please make sure the changes are" + " accurate and commit them." + ) + else: + print("✅ All good!") + exit(0) + + +def update_inference_client(update: bool): + print(f"🙈 Skipping the following tasks: {TASKS_TO_SKIP}") + # Get all tasks from the ./src/huggingface_hub/inference/_generated/types/ + tasks = set() + for file in INFERENCE_TYPES_PATH.glob("*.py"): + if file.stem not in TASKS_TO_SKIP: + tasks.add(file.stem) + + # Construct a mapping between method names and their parameters dataclass names + method_params = {} + for method_name, _ in inspect.getmembers(InferenceClient, predicate=inspect.isfunction): + if method_name.startswith("_") or method_name not in tasks: + continue + parameter_type_name = _get_parameter_type_name(method_name) + if parameter_type_name is not None: + method_params[method_name] = parameter_type_name + _check_parameters(method_params, update=update) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--update", + action="store_true", + help=("Whether to update `./src/huggingface_hub/inference/_client.py` if parameters are missing."), + ) + args = parser.parse_args() + update_inference_client(update=args.update) diff --git a/utils/helpers.py b/utils/helpers.py index 34f12e11f3..a2dcecef6e 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -14,10 +14,16 @@ # limitations under the License. """Contains helpers used by the scripts in `./utils`.""" +import subprocess +import tempfile from pathlib import Path +from ruff.__main__ import find_ruff_bin + def check_and_update_file_content(file: Path, expected_content: str, update: bool): + # Ensure the expected content ends with a newline to satisfy end-of-file-fixer hook + expected_content = expected_content.rstrip("\n") + "\n" content = file.read_text() if file.exists() else None if content != expected_content: if update: @@ -26,3 +32,19 @@ def check_and_update_file_content(file: Path, expected_content: str, update: boo else: print(f"❌ Expected content mismatch in {file}.") exit(1) + + +def format_source_code(code: str) -> str: + """Format the generated source code using Ruff.""" + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "tmp.py" + filepath.write_text(code) + ruff_bin = find_ruff_bin() + if not ruff_bin: + raise FileNotFoundError("Ruff executable not found.") + try: + subprocess.run([ruff_bin, "check", str(filepath), "--fix", "--quiet"], check=True) + subprocess.run([ruff_bin, "format", str(filepath), "--quiet"], check=True) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Error running Ruff: {e}") + return filepath.read_text()