Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 193 additions & 4 deletions src/llamafactory/data/mm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,194 @@ def process_messages(
return messages


@dataclass
class Gemma4Plugin(BasePlugin):
r"""Plugin for the Gemma4 multimodal model."""

@override
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
results, fps_per_video, durations, frames_indices = [], [], [], []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
frames_indices.append(list(range(len(frames))))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
original_fps = float(video_stream.average_rate)
# for correctly calculate timestamps
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())

if video_stream.duration is None:
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
durations.append(float(video_stream.duration * video_stream.time_base))

frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)

return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}

@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, Union[list[int], "torch.Tensor"]]:
image_processor = getattr(processor, "image_processor", None)
video_processor = getattr(processor, "video_processor", None)
feature_extractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}

if len(images) != 0:
regularized = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(regularized, return_tensors="pt"))

if len(videos) != 0:
video_data = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
]
mm_inputs.update(
video_processor(
videos=video_data["videos"],
video_metadata=video_metadata,
return_tensors="pt",
return_metadata=True,
do_sample_frames=False,
)
)

if len(audios) != 0: # only for gemma4n
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]

mm_inputs.update(
feature_extractor(
audios,
padding="max_length",
return_tensors="pt",
)
)

return mm_inputs

@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)

boi_token: str = getattr(processor, "boi_token")
eoi_token: str = getattr(processor, "eoi_token")
boa_token: str = getattr(processor, "boa_token")
eoa_token: str = getattr(processor, "eoa_token")
image_token: str = getattr(processor, "image_token")
video_token: str = getattr(processor, "video_token")
audio_token: str = getattr(processor, "audio_token")

if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_image_soft_tokens: list[int] = list(
mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
)
num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
video_metadata = mm_inputs.get("video_metadata", [])
else:
num_image_soft_tokens = [1] * len(images)
num_video_soft_tokens = [1] * len(videos)
video_metadata = [None] * len(videos)

audio_iter = iter(audios)
image_iter = iter(num_image_soft_tokens)
video_iter = iter(zip(num_video_soft_tokens, video_metadata))

for message in messages:
content = message["content"]

while IMAGE_PLACEHOLDER in content:
n = next(image_iter)
content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)

while VIDEO_PLACEHOLDER in content:
num_soft_tokens_per_frame, metadata = next(video_iter)
if self.expand_mm_tokens:
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
video_str = " ".join(frame_strs)
else:
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)

while AUDIO_PLACEHOLDER in content:
current_audio = next(audio_iter)
if self.expand_mm_tokens:
num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
else:
audio_str = f"{boa_token}{audio_token}{eoa_token}"

content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

message["content"] = content

return messages

@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# Pop metadata keys that must not be passed to the model.
for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
"_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
mm_inputs.pop(key, None)

mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)

return mm_inputs


@dataclass
class InternVLPlugin(BasePlugin):
@override
Expand Down Expand Up @@ -1505,7 +1693,7 @@ def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "Regulariz
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
original_fps = float(video_stream.average_rate)
# for qwen3vl video timestamp calculation
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
Expand Down Expand Up @@ -1642,7 +1830,7 @@ def _get_mm_inputs(
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
]
mm_inputs.update(
Expand Down Expand Up @@ -1683,7 +1871,7 @@ def process_messages(
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
video_metadata = mm_inputs.get("video_metadata", {})
video_metadata = mm_inputs.get("video_metadata", [])

else:
image_grid_thw = [None] * len(images)
Expand Down Expand Up @@ -2206,8 +2394,9 @@ def process_messages(
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
"gemma3": Gemma3Plugin,
"glm4v": GLM4VPlugin,
"gemma3n": Gemma3nPlugin,
"gemma4": Gemma4Plugin,
"glm4v": GLM4VPlugin,
"intern_vl": InternVLPlugin,
"kimi_vl": KimiVLPlugin,
"llama4": Llama4Plugin,
Expand Down
49 changes: 49 additions & 0 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,55 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
)


register_template(
name="gemma4",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
format_observation=StringFormatter(
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
), # seem not consistent with the chattemplate
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
"gemma4",
image_token="<|image|>",
video_token="<|video|>",
),
template_class=ReasoningTemplate,
)


register_template(
name="gemma4n",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
format_observation=StringFormatter(
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
),
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
"gemma4",
image_token="<|image|>",
video_token="<|video|>",
audio_token="<|audio|>",
),
template_class=ReasoningTemplate,
)


register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
Expand Down
Loading
Loading