118 changes: 64 additions & 54 deletions src/transformers/processing_utils.py
@@ -23,7 +23,7 @@
import typing
import warnings
from pathlib import Path
from typing import Any, Callable, Optional, TypedDict, Union
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union

import numpy as np
import typing_extensions
@@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
return_assistant_tokens_mask: Optional[bool] = False


class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
class ChatTemplateLoadKwargs(TypedDict, total=False):
"""
Keyword arguments for processor chat templates.
Keyword arguments used to load multimodal data in processor chat templates.

tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
@@ -415,13 +411,26 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs):
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
"""

tokenize: Optional[bool] = False
return_dict: Optional[bool] = False
num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
sample_indices_fn: Optional[Callable] = None
load_audio_from_video: Optional[bool] = False


class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
"""
Keyword arguments for the processor's `apply_chat_template`.

tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
"""

tokenize: Optional[bool] = False
return_dict: Optional[bool] = False


class AllKwargsForChatTemplate(
@@ -1236,11 +1245,11 @@ def __call__(

def _process_messages_for_chat_template(
self,
conversation: list[list[dict[str, str]]],
batch_images: list[ImageInput],
batch_videos: list[VideoInput],
batch_video_metadata: list[list[dict[str, any]]],
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
conversation: List[List[Dict[str, str]]],
batch_images: List[ImageInput],
batch_videos: List[VideoInput],
batch_video_metadata: List[List[Dict[str, any]]],
**mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
):
"""
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
@@ -1311,18 +1320,18 @@ def apply_chat_template(
)

# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
# and for multimodal chat template
# and for multimodal data loading. Everything else will be used in `__call__`
tokenizer_template_kwargs = {}
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, tokenizer_value)
default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, default_value)
tokenizer_template_kwargs[tokenizer_key] = value

chat_template_kwargs = {}
for key in ProcessorChatTemplateKwargs.__annotations__.keys():
processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
value = kwargs.pop(key, processor_value)
chat_template_kwargs[key] = value
mm_load_kwargs = {}
for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
value = kwargs.pop(mm_load_key, default_value)
mm_load_kwargs[mm_load_key] = value
Comment on lines +1330 to +1334

Member:
Not really related to this PR, but this code to match kwargs with the TypedDicts feels quite long and confusing, especially if it's used multiple times. Is there some cleaner way to populate a dict with default values that can be overridden by kwargs - maybe use custom classes/dataclasses instead, or add a helper method to the TypedDicts that gets inherited?

Member Author:
Yeah, this is getting out of hand. I am planning to refactor this and the new video loading a bit in subsequent PRs. In general it looks like we are now over-using TypedDict for things it isn't really meant for, so I will consider doing something else.
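The refactor hinted at above could, for example, pull the repeated default-then-override loops into a single helper. A minimal sketch, not part of this PR — the helper name `_merge_kwargs_with_defaults` is hypothetical, and it only relies on the fact that these TypedDict subclasses declare their defaults as class attributes:

from typing import Any


def _merge_kwargs_with_defaults(typed_dict_cls, kwargs: dict) -> dict[str, Any]:
    """Collect every key declared on `typed_dict_cls`, taking the value from
    `kwargs` when the caller passed one (and popping it), otherwise falling
    back to the class-level default declared on the TypedDict."""
    merged = {}
    for key in typed_dict_cls.__annotations__:
        default = getattr(typed_dict_cls, key, None)
        merged[key] = kwargs.pop(key, default)
    return merged


# Hypothetical usage replacing the two near-identical loops above:
# tokenizer_template_kwargs = _merge_kwargs_with_defaults(TokenizerChatTemplateKwargs, kwargs)
# mm_load_kwargs = _merge_kwargs_with_defaults(ChatTemplateLoadKwargs, kwargs)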


if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
@@ -1333,13 +1342,8 @@ def apply_chat_template(
is_batched = False
conversations = [conversation]

num_frames = chat_template_kwargs.get("num_frames")
video_fps = chat_template_kwargs.get("video_fps")
video_load_backend = chat_template_kwargs.get("video_load_backend")
tokenize = chat_template_kwargs.get("tokenize")
return_dict = chat_template_kwargs.get("return_dict")
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
sampling_rate = chat_template_kwargs.pop("sampling_rate")
tokenize = kwargs.pop("tokenize", False)
return_dict = kwargs.pop("return_dict", False)

if tokenize:
batch_images, batch_videos = [], []
@@ -1369,31 +1373,37 @@
if key in vision_info and vision_info["type"] == "video"
]

# Audio models do not accept nested list of audios (yet!)
for fname in audio_fnames:
batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
for fname in image_fnames:
images.append(load_image(fname))
for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
)
else:
video, metadata = load_video(
fname,
num_frames=num_frames,
fps=video_fps,
backend=video_load_backend,
sample_indices_fn=sample_indices_fn,
)
videos.append(video)
video_metadata.append(metadata)

# Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
if not mm_load_kwargs["load_audio_from_video"]:
for fname in audio_fnames:
batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
else:
for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
audios = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
)
else:
video, metadata = load_video(
fname,
num_frames=mm_load_kwargs["num_frames"],
fps=mm_load_kwargs["video_fps"],
backend=mm_load_kwargs["video_load_backend"],
sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
)
audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])
batch_audios.append(audios)
videos.append(video)
video_metadata.append(metadata)

# Currently all processors can accept nested list of batches, but not flat list of visuals
# So we'll make a batched list of images and let the processor handle it
@@ -1409,7 +1419,7 @@
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**chat_template_kwargs,
**mm_load_kwargs,
)

prompt = self.tokenizer.apply_chat_template(
@@ -1438,7 +1448,7 @@ def apply_chat_template(
text=prompt,
images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None,
audios=batch_audios if batch_audios else None,
audio=batch_audios if batch_audios else None,
**kwargs,
)
if return_dict:
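For context, a caller-side sketch of how the new load-time kwargs flow through `apply_chat_template` — this is not taken from the PR; the checkpoint name and file path are placeholders for any processor that accepts both audio and video inputs:

from transformers import AutoProcessor

# Placeholder checkpoint: substitute any model whose processor handles both audio and video.
processor = AutoProcessor.from_pretrained("org/audio-video-chat-model")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "path": "sample_demo_1.mp4"},  # placeholder local path
            {"type": "text", "text": "What is being said in this clip?"},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    video_fps=1,                 # ChatTemplateLoadKwargs: sample ~1 frame per second
    video_load_backend="pyav",   # ChatTemplateLoadKwargs: video decoding backend
    load_audio_from_video=True,  # new kwarg from this PR: also extract the audio track
)
# `inputs` then holds input_ids/attention_mask plus the processor's video and audio features.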
69 changes: 65 additions & 4 deletions tests/test_processing_common.py
@@ -1097,10 +1097,7 @@ def test_chat_template_video_custom_sampling(self):
{
"role": "user",
"content": [
{
"type": "video",
"path": video_file_path,
},
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
@@ -1189,6 +1186,70 @@ def _process_messages_for_chat_template(
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)

@require_librosa
@require_av
def test_audio_chat_template_from_video(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")

signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest(f"{self.processor_class} does not suport video inputs")

if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")

video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "Which of these animals is making the sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is a cow."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "Is it the same sound?"},
],
},
]

formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1) # batch size=1

out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
load_audio_from_video=True,
)
self.assertTrue(self.audio_input_name in out_dict)
self.assertTrue(self.videos_input_name in out_dict)

# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
self.assertEqual(len(out_dict[self.videos_input_name]), 1) # 1 video in the conversation

@require_librosa
def test_audio_chat_template_single(self):
processor = self.get_processor()