Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 100 additions & 7 deletions python/sglang/multimodal_gen/configs/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import dataclasses
import hashlib
import json
import math
import os.path
import re
import time
import unicodedata
import uuid
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
Expand Down Expand Up @@ -137,7 +137,7 @@ class SamplingParams:
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep

def set_output_file_ext(self):
def _set_output_file_ext(self):
# add extension if needed
if not any(
self.output_file_name.endswith(ext)
Expand All @@ -147,7 +147,7 @@ def set_output_file_ext(self):
f"{self.output_file_name}.{self.data_type.get_default_extension()}"
)

def set_output_file_name(self):
def _set_output_file_name(self):
# settle output_file_name
if (
self.output_file_name is None
Expand Down Expand Up @@ -178,7 +178,7 @@ def set_output_file_name(self):
self.output_file_name = _sanitize_filename(self.output_file_name)

# Ensure a proper extension is present
self.set_output_file_ext()
self._set_output_file_ext()

def __post_init__(self) -> None:
assert self.num_frames >= 1
Expand All @@ -195,6 +195,93 @@ def check_sampling_param(self):
if self.prompt_path and not self.prompt_path.endswith(".txt"):
raise ValueError("prompt_path must be a txt file")

def adjust(
    self,
    server_args: ServerArgs,
) -> None:
    """Final in-place adjustment of sampling parameters.

    Called after the pretrained defaults have been merged with user-supplied
    params. Validates the prompt and frame count, settles ``num_frames`` and
    ``data_type`` for image models, rounds the frame count so the latent
    frames divide evenly across GPUs for video models, settles the output
    file name, and logs the final parameters.

    Args:
        server_args: Server configuration; provides ``pipeline_config`` and
            ``num_gpus``.

    Raises:
        TypeError: if ``self.prompt`` is not a string.
        ValueError: if ``self.num_frames`` is not positive.
    """
    pipeline_config = server_args.pipeline_config
    if not isinstance(self.prompt, str):
        raise TypeError(f"`prompt` must be a string, but got {type(self.prompt)}")

    # Process negative prompt
    if self.negative_prompt is not None and not self.negative_prompt.isspace():
        # avoid stripping default negative prompt: ' ' for qwen-image
        self.negative_prompt = self.negative_prompt.strip()

    # Validate dimensions
    # NOTE(review): only num_frames is actually checked here, although the
    # error message also mentions height and width — confirm whether
    # height/width validation is handled elsewhere.
    if self.num_frames <= 0:
        raise ValueError(
            f"height, width, and num_frames must be positive integers, got "
            f"height={self.height}, width={self.width}, "
            f"num_frames={self.num_frames}"
        )

    if pipeline_config.task_type.is_image_gen():
        # settle num_frames: an image is a single frame
        logger.debug(f"Setting num_frames to 1 because this is a image-gen model")
        self.num_frames = 1
        self.data_type = DataType.IMAGE
    else:
        # Adjust number of frames based on number of GPUs for video task
        use_temporal_scaling_frames = (
            pipeline_config.vae_config.use_temporal_scaling_frames
        )
        num_frames = self.num_frames
        num_gpus = server_args.num_gpus
        temporal_scale_factor = (
            pipeline_config.vae_config.arch_config.temporal_compression_ratio
        )

        # Map pixel frames -> latent frames. The two formulas correspond to
        # the two VAE temporal-compression schemes.
        if use_temporal_scaling_frames:
            orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1
        else:  # stepvideo only: 17 pixel frames <-> 3 latent frames
            orig_latent_num_frames = self.num_frames // 17 * 3

        if orig_latent_num_frames % server_args.num_gpus != 0:
            # Adjust latent frames to be divisible by number of GPUs
            if self.num_frames_round_down:
                # Ensure we have at least 1 batch per GPU
                new_latent_num_frames = (
                    max(1, (orig_latent_num_frames // num_gpus)) * num_gpus
                )
            else:
                # Round up to the next multiple of num_gpus
                new_latent_num_frames = (
                    math.ceil(orig_latent_num_frames / num_gpus) * num_gpus
                )

            if use_temporal_scaling_frames:
                # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
                new_num_frames = (
                    new_latent_num_frames - 1
                ) * temporal_scale_factor + 1
            else:  # stepvideo only
                # Find the least common multiple of 3 and num_gpus
                divisor = math.lcm(3, num_gpus)
                # Round up to the nearest multiple of this LCM
                # NOTE(review): this rounds up even when num_frames_round_down
                # is set — confirm that is intended for stepvideo.
                new_latent_num_frames = (
                    (new_latent_num_frames + divisor - 1) // divisor
                ) * divisor
                # Convert back to actual frames using the StepVideo formula
                new_num_frames = new_latent_num_frames // 3 * 17

            logger.info(
                "Adjusting number of frames from %s to %s based on number of GPUs (%s)",
                self.num_frames,
                new_num_frames,
                server_args.num_gpus,
            )
            self.num_frames = new_num_frames

    # Give the pipeline a final say on the frame count (model-specific
    # constraints live in pipeline_config).
    self.num_frames = server_args.pipeline_config.adjust_num_frames(
        self.num_frames
    )

    # Settle the output file name (and extension), then log the final params.
    self._set_output_file_name()
    self.log(server_args=server_args)

def update(self, source_dict: dict[str, Any]) -> None:
for key, value in source_dict.items():
if hasattr(self, key):
Expand All @@ -220,9 +307,15 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
sampling_params = cls(**kwargs)
return sampling_params

def from_user_sampling_params(self, user_params):
sampling_params = deepcopy(self)
sampling_params._merge_with_user_params(user_params)
@staticmethod
def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs):
    """Build fully-settled sampling params from user-supplied arguments.

    Loads the pretrained defaults for *model_path*, overlays the user
    overrides given in ``*args``/``**kwargs``, then runs the final
    ``adjust`` pass against *server_args* before returning the result.
    """
    merged = SamplingParams.from_pretrained(model_path)
    overrides = SamplingParams(*args, **kwargs)

    merged._merge_with_user_params(overrides)
    merged.adjust(server_args)

    return merged

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,15 @@ def generate(
else DataType.VIDEO
)
pretrained_sampling_params.data_type = data_type
pretrained_sampling_params.set_output_file_name()
pretrained_sampling_params._set_output_file_name()
pretrained_sampling_params.adjust(self.server_args)

requests: list[Req] = []
for output_idx, p in enumerate(prompts):
current_sampling_params = deepcopy(pretrained_sampling_params)
current_sampling_params.prompt = p
requests.append(
prepare_request(
p,
server_args=self.server_args,
sampling_params=current_sampling_params,
)
Expand Down Expand Up @@ -310,21 +310,11 @@ def generate(
continue
for output_idx, sample in enumerate(output_batch.output):
num_outputs = len(output_batch.output)
output_file_name = req.output_file_name
if num_outputs > 1 and output_file_name:
base, ext = os.path.splitext(output_file_name)
output_file_name = f"{base}_{output_idx}{ext}"

save_path = (
os.path.join(req.output_path, output_file_name)
if output_file_name
else None
)
frames = self.post_process_sample(
sample,
fps=req.fps,
save_output=req.save_output,
save_file_path=save_path,
save_file_path=req.output_file_path(num_outputs, output_idx),
data_type=req.data_type,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ def _build_sampling_params_from_request(
) -> SamplingParams:
width, height = _parse_size(size)
ext = _choose_ext(output_format, background)

server_args = get_global_server_args()
sampling_params = SamplingParams.from_pretrained(server_args.model_path)

# Build user params
user_params = SamplingParams(
sampling_params = SamplingParams.from_user_sampling_params_args(
model_path=server_args.model_path,
request_id=request_id,
prompt=prompt,
image_path=image_path,
Expand All @@ -70,18 +68,9 @@ def _build_sampling_params_from_request(
height=height,
num_outputs_per_prompt=max(1, min(int(n or 1), 10)),
save_output=True,
server_args=server_args,
output_file_name=f"{request_id}.{ext}",
)

# Let SamplingParams auto-generate a file name, then force desired extension
sampling_params = sampling_params.from_user_sampling_params(user_params)
if not sampling_params.output_file_name:
sampling_params.output_file_name = request_id
if not sampling_params.output_file_name.endswith(f".{ext}"):
# strip any existing extension and apply desired one
base = sampling_params.output_file_name.rsplit(".", 1)[0]
sampling_params.output_file_name = f"{base}.{ext}"

sampling_params.log(server_args)
return sampling_params


Expand All @@ -107,7 +96,6 @@ def _build_req_from_sampling(s: SamplingParams) -> Req:
async def generations(
request: ImageGenerationsRequest,
):

request_id = generate_request_id()
sampling = _build_sampling_params_from_request(
request_id=request_id,
Expand All @@ -118,7 +106,6 @@ async def generations(
background=request.background,
)
batch = prepare_request(
prompt=request.prompt,
server_args=get_global_server_args(),
sampling_params=sampling,
)
Expand Down Expand Up @@ -175,7 +162,6 @@ async def edits(
background: Optional[str] = Form("auto"),
user: Optional[str] = Form(None),
):

request_id = generate_request_id()
# Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided)
images = image or image_array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
router = APIRouter(prefix="/v1/videos", tags=["videos"])


# NOTE(mick): the sampling params need to be further adjusted
# FIXME: duplicated with the one in `image_api.py`
def _build_sampling_params_from_request(
request_id: str, request: VideoGenerationsRequest
) -> SamplingParams:
Expand All @@ -56,9 +58,8 @@ def _build_sampling_params_from_request(
request.num_frames if request.num_frames is not None else derived_num_frames
)
server_args = get_global_server_args()
# TODO: should we cache this sampling_params?
sampling_params = SamplingParams.from_pretrained(server_args.model_path)
user_params = SamplingParams(
sampling_params = SamplingParams.from_user_sampling_params_args(
model_path=server_args.model_path,
request_id=request_id,
prompt=request.prompt,
num_frames=num_frames,
Expand All @@ -67,10 +68,10 @@ def _build_sampling_params_from_request(
height=height,
image_path=request.input_reference,
save_output=True,
server_args=server_args,
output_file_name=request_id,
)
sampling_params = sampling_params.from_user_sampling_params(user_params)
sampling_params.set_output_file_name()
sampling_params.log(server_args)

return sampling_params


Expand Down Expand Up @@ -195,7 +196,6 @@ async def create_video(

# Build Req for scheduler
batch = prepare_request(
prompt=req.prompt,
server_args=get_global_server_args(),
sampling_params=sampling_params,
)
Expand Down
Loading
Loading