Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class WanT2V720PConfig(WanT2V480PConfig):
class WanI2V480PConfig(WanT2V480PConfig, WanI2VCommonConfig):
"""Base configuration for Wan I2V 14B 480P pipeline architecture."""

max_area: int = 480 * 832
# WanConfig-specific parameters with defaults
task_type: ModelTaskType = ModelTaskType.I2V
# Precision for each component
Expand All @@ -130,6 +131,7 @@ def __post_init__(self) -> None:
class WanI2V720PConfig(WanI2V480PConfig):
"""Base configuration for Wan I2V 14B 720P pipeline architecture."""

max_area: int = 720 * 1280
# WanConfig-specific parameters with defaults

# Denoising stage
Expand Down
36 changes: 33 additions & 3 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import align_to
from sglang.multimodal_gen.utils import StoreBoolean, align_to

logger = init_logger(__name__)

Expand Down Expand Up @@ -137,6 +137,10 @@ class SamplingParams:
return_frames: bool = False
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
# if True, allow user params to override subclass-defined protected fields
override_protected_fields: bool = False
# whether to adjust num_frames for multi-GPU friendly splitting (default: True)
adjust_frames: bool = True

def _set_output_file_ext(self):
# add extension if needed
Expand Down Expand Up @@ -517,6 +521,25 @@ def add_cli_args(parser: Any) -> Any:
default=SamplingParams.return_trajectory_decoded,
help="Whether to return the decoded trajectory",
)
parser.add_argument(
"--override-protected-fields",
action="store_true",
default=SamplingParams.override_protected_fields,
help=(
"If set, allow user params to override fields defined in subclasses "
"(protected by default)."
),
)
parser.add_argument(
"--adjust-frames",
action=StoreBoolean,
default=SamplingParams.adjust_frames,
help=(
"Enable/disable adjusting num_frames to evenly split latent frames across GPUs "
"and satisfy model temporal constraints. Default: true. "
"Examples: --adjust-frames, --adjust-frames true, --adjust-frames false."
),
)
return parser

@classmethod
Expand All @@ -543,7 +566,7 @@ def get_cli_args(cls, args: argparse.Namespace):
def output_file_path(self):
return os.path.join(self.output_path, self.output_file_name)

def _merge_with_user_params(self, user_params):
def _merge_with_user_params(self, user_params: "SamplingParams"):
"""
Merges parameters from a user-provided SamplingParams object.

Expand All @@ -559,6 +582,11 @@ def _merge_with_user_params(self, user_params):
# user is not allowed to modify any param defined in the SamplingParams subclass
subclass_defined_fields = set(type(self).__annotations__.keys())

# global switch: if True, allow overriding protected fields
allow_override_protected = bool(
user_params.override_protected_fields or self.override_protected_fields
)

# Compare against current instance to avoid constructing a default instance
default_params = SamplingParams()

Expand All @@ -575,7 +603,9 @@ def _merge_with_user_params(self, user_params):
if field_name != "output_file_name"
else user_params.output_file_path is not None
)
if is_user_modified and field_name not in subclass_defined_fields:
if is_user_modified and (
allow_override_protected or field_name not in subclass_defined_fields
):
if hasattr(self, field_name):
setattr(self, field_name, user_value)

Expand Down
12 changes: 7 additions & 5 deletions python/sglang/multimodal_gen/runtime/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
diffusion models.
"""

import dataclasses

from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
Expand All @@ -26,11 +27,12 @@ def prepare_request(
Settle SamplingParams according to ServerArgs

"""
# Create a copy of inference args to avoid modifying the original
req = Req(
**shallow_asdict(sampling_params),
VSA_sparsity=server_args.VSA_sparsity,
)
# Create a copy of inference args to avoid modifying the original.
# Filter out fields not defined in Req to avoid unexpected-kw TypeError.
params_dict = shallow_asdict(sampling_params)
req_field_names = {f.name for f in dataclasses.fields(Req)}
filtered_params = {k: v for k, v in params_dict.items() if k in req_field_names}
req = Req(**filtered_params, VSA_sparsity=server_args.VSA_sparsity)
req.adjust_size(server_args)

if req.width <= 0 or req.height <= 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def preprocess_condition_image(
elif isinstance(server_args.pipeline_config, WanI2V480PConfig):
# TODO: could we merge with above?
# resize image only, Wan2.1 I2V
max_area = 720 * 1280
max_area = server_args.pipeline_config.max_area
aspect_ratio = condition_image_height / condition_image_width
mod_value = (
server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
Expand Down
Loading