-
Notifications
You must be signed in to change notification settings - Fork 32k
Add video classification pipeline #20151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2f39872
8d80c25
f161a90
a7e7cfd
7ebe5fe
9325079
aa44054
cf4f421
80bf526
8ed1b7d
2f3c93f
8f2a52c
349b780
e51ebb6
36f1fe1
623ceec
7c88d92
68b25e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| from io import BytesIO | ||
| from typing import List, Union | ||
|
|
||
| import requests | ||
|
|
||
| from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends | ||
| from .base import PIPELINE_INIT_ARGS, Pipeline | ||
|
|
||
|
|
||
| if is_decord_available(): | ||
| import numpy as np | ||
|
|
||
| from decord import VideoReader | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
@add_end_docstrings(PIPELINE_INIT_ARGS)
class VideoClassificationPipeline(Pipeline):
    """
    Video classification pipeline using any `AutoModelForVideoClassification`. This pipeline predicts the class of a
    video.

    This video classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"video-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=video-classification).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Video decoding relies on decord; fail early if the backend is missing.
        requires_backends(self, "decord")
        self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING)

    def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
        """
        Split the call-time keyword arguments between the pipeline stages.

        Only arguments that were explicitly provided (i.e. are not `None`) are forwarded, so that the defaults
        declared on `preprocess` and `postprocess` apply otherwise.

        Returns:
            A tuple `(preprocess_params, forward_params, postprocess_params)`; `forward_params` is always empty.
        """
        preprocess_params = {}
        if frame_sampling_rate is not None:
            preprocess_params["frame_sampling_rate"] = frame_sampling_rate
        if num_frames is not None:
            preprocess_params["num_frames"] = num_frames

        postprocess_params = {}
        if top_k is not None:
            postprocess_params["top_k"] = top_k

        return preprocess_params, {}, postprocess_params

    def __call__(self, videos: Union[str, List[str]], **kwargs):
        """
        Assign labels to the video(s) passed as inputs.

        Args:
            videos (`str`, `List[str]`):
                The pipeline handles two types of videos:

                - A string containing a http link pointing to a video
                - A string containing a local path to a video

                The pipeline accepts either a single video or a batch of videos, which must then be passed as a list
                of strings. Videos in a batch must all be in the same format: all as http links or all as local
                paths.
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
            num_frames (`int`, *optional*, defaults to `self.model.config.num_frames`):
                The number of frames sampled from the video to run the classification on. If not provided, will default
                to the number of frames specified in the model configuration.
            frame_sampling_rate (`int`, *optional*, defaults to 1):
                The sampling rate used to select frames from the video. If not provided, will default to 1, i.e. every
                frame will be used.

        Return:
            For each video, a list of dictionaries (one per returned label, at most `top_k` of them). If the input is
            a list of several videos, the results correspond to the videos in the same order.

            The dictionaries contain the following keys:

            - **label** (`str`) -- The label identified by the model.
            - **score** (`float`) -- The score attributed by the model for that label.
        """
        return super().__call__(videos, **kwargs)

    def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
        """
        Load a video, sample frames from it and run them through the feature extractor.

        Args:
            video (`str`):
                Either an `http://` / `https://` URL or a path readable by decord's `VideoReader`.
            num_frames (`int`, *optional*):
                Number of frames to sample. Defaults to `self.model.config.num_frames`.
            frame_sampling_rate (`int`, *optional*, defaults to 1):
                Spacing factor between sampled frames.

        Returns:
            The output of `self.feature_extractor` for the sampled frames.

        Raises:
            `requests.HTTPError`: If `video` is a URL and the server answers with an error status.
        """
        if num_frames is None:
            num_frames = self.model.config.num_frames

        if video.startswith(("http://", "https://")):
            response = requests.get(video)
            # Surface HTTP errors here instead of handing an error page to the video decoder.
            response.raise_for_status()
            video = BytesIO(response.content)

        videoreader = VideoReader(video)
        videoreader.seek(0)

        # Evenly spaced indices over the first `num_frames * frame_sampling_rate` frames.
        start_idx = 0
        end_idx = num_frames * frame_sampling_rate - 1
        indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)

        # The feature extractor is given a list of per-frame arrays.
        frames = list(videoreader.get_batch(indices).asnumpy())

        return self.feature_extractor(frames, return_tensors=self.framework)

    def _forward(self, model_inputs):
        """Run the model on the preprocessed inputs and return its raw outputs."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=5):
        """
        Turn the model logits into the `top_k` highest-scoring labels.

        Args:
            model_outputs:
                Output of `_forward`; must expose a `logits` attribute.
            top_k (`int`, *optional*, defaults to 5):
                Number of labels to return, capped at `self.model.config.num_labels`.

        Returns:
            A list of dictionaries with the keys `"score"` and `"label"`, built from the first element of the
            softmaxed logits.

        Raises:
            `ValueError`: If `self.framework` is not `"pt"`.
        """
        top_k = min(top_k, self.model.config.num_labels)

        if self.framework == "pt":
            probs = model_outputs.logits.softmax(-1)[0]
            scores, ids = probs.topk(top_k)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.