Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 0 additions & 203 deletions python/sglang/multimodal_gen/configs/configs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

import dataclasses
from enum import Enum
from typing import Any, Optional

from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import FlexibleArgumentParser, StoreBoolean

logger = init_logger(__name__)

Expand Down Expand Up @@ -57,202 +53,3 @@ def from_string(cls, value: str) -> "VideoLoaderType":
def choices(cls) -> list[str]:
    """Get all available choices as strings for argparse."""
    # Iterating an Enum class yields its members; collect each member's
    # string value so the list can be passed directly to argparse `choices=`.
    return [video_loader.value for video_loader in cls]


@dataclasses.dataclass
class PreprocessConfig:
    """Configuration for preprocessing operations.

    Holds all knobs for the dataset-preprocessing pipeline (dataset location,
    dataloader sizing, output-file batching, and video decode/sampling
    parameters).  CLI flags are registered via :meth:`add_cli_args` and a
    populated instance is built from parsed arguments via :meth:`from_kwargs`.
    """

    # Model and dataset configuration
    model_path: str = ""  # path to the model used during preprocessing
    dataset_path: str = ""  # required at runtime; validated in check_preprocess_config
    dataset_type: DatasetType = DatasetType.HF
    dataset_output_dir: str = "./output"  # where processed samples are written

    # Dataloader configuration
    dataloader_num_workers: int = 1  # 0 => load data in the main process
    preprocess_video_batch_size: int = 2  # per-device dataloader batch size

    # Saver configuration
    samples_per_file: int = 64  # samples per output file; must be > 0
    flush_frequency: int = 256  # how often to save to parquet files; must be > 0

    # Video processing parameters
    video_loader_type: VideoLoaderType = VideoLoaderType.TORCHCODEC
    max_height: int = 480
    max_width: int = 848
    num_frames: int = 163
    video_length_tolerance_range: float = 2.0
    train_fps: int = 30
    speed_factor: float = 1.0
    drop_short_ratio: float = 1.0  # ratio for dropping short videos
    do_temporal_sample: bool = False

    # Model configuration
    training_cfg_rate: float = 0.0  # classifier-free-guidance dropout rate used in training

    # framework configuration
    seed: int = 42  # RNG seed for reproducibility

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, prefix: str = "preprocess"
    ) -> FlexibleArgumentParser:
        """Add preprocessing configuration arguments to the parser.

        Args:
            parser: Parser to register the arguments on.
            prefix: Namespace prefix for every flag (e.g. ``--preprocess.model-path``).
                An empty/whitespace prefix registers un-prefixed flags.

        Returns:
            The same parser, for chaining.
        """
        # Only emit the dot separator when a non-blank prefix was supplied,
        # so an empty prefix yields plain flags like `--model-path`.
        prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else ""

        preprocess_args = parser.add_argument_group("Preprocessing Arguments")
        # Model & Dataset
        preprocess_args.add_argument(
            f"--{prefix_with_dot}model-path",
            type=str,
            default=PreprocessConfig.model_path,
            help="Path to the model for preprocessing",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}dataset-path",
            type=str,
            default=PreprocessConfig.dataset_path,
            help="Path to the dataset directory for preprocessing",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}dataset-type",
            type=str,
            choices=DatasetType.choices(),
            default=PreprocessConfig.dataset_type.value,
            help="Type of the dataset",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}dataset-output-dir",
            type=str,
            default=PreprocessConfig.dataset_output_dir,
            help="The output directory where the dataset will be written.",
        )

        # Dataloader
        preprocess_args.add_argument(
            f"--{prefix_with_dot}dataloader-num-workers",
            type=int,
            default=PreprocessConfig.dataloader_num_workers,
            help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}preprocess-video-batch-size",
            type=int,
            default=PreprocessConfig.preprocess_video_batch_size,
            help="Batch size (per device) for the training dataloader.",
        )

        # Saver
        preprocess_args.add_argument(
            f"--{prefix_with_dot}samples-per-file",
            type=int,
            default=PreprocessConfig.samples_per_file,
            help="Number of samples per output file",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}flush-frequency",
            type=int,
            default=PreprocessConfig.flush_frequency,
            help="How often to save to parquet files",
        )

        # Video processing parameters
        preprocess_args.add_argument(
            f"--{prefix_with_dot}video-loader-type",
            type=str,
            choices=VideoLoaderType.choices(),
            default=PreprocessConfig.video_loader_type.value,
            help="Type of the video loader",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}max-height",
            type=int,
            default=PreprocessConfig.max_height,
            help="Maximum height for video processing",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}max-width",
            type=int,
            default=PreprocessConfig.max_width,
            help="Maximum width for video processing",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}num-frames",
            type=int,
            default=PreprocessConfig.num_frames,
            help="Number of frames to process",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}video-length-tolerance-range",
            type=float,
            default=PreprocessConfig.video_length_tolerance_range,
            help="Video length tolerance range",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}train-fps",
            type=int,
            default=PreprocessConfig.train_fps,
            help="Training FPS",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}speed-factor",
            type=float,
            default=PreprocessConfig.speed_factor,
            help="Speed factor for video processing",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}drop-short-ratio",
            type=float,
            default=PreprocessConfig.drop_short_ratio,
            help="Ratio for dropping short videos",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}do-temporal-sample",
            # StoreBoolean: project-local argparse action for boolean flags.
            action=StoreBoolean,
            default=PreprocessConfig.do_temporal_sample,
            help="Whether to do temporal sampling",
        )

        # Model Training configuration
        preprocess_args.add_argument(
            f"--{prefix_with_dot}training-cfg-rate",
            type=float,
            default=PreprocessConfig.training_cfg_rate,
            help="Training CFG rate",
        )
        preprocess_args.add_argument(
            f"--{prefix_with_dot}seed",
            type=int,
            default=PreprocessConfig.seed,
            help="Seed for random number generator",
        )

        return parser

    @classmethod
    def from_kwargs(cls, kwargs: dict[str, Any]) -> Optional["PreprocessConfig"]:
        """Create PreprocessConfig from keyword arguments.

        String values for the enum-typed fields are coerced to their enum
        members before the config is populated.  Returns ``None`` when
        ``update_config_from_args`` reports that no matching arguments were
        applied.

        NOTE(review): with ``pop_args=True`` this presumably removes the
        consumed ``preprocess.*`` keys from *kwargs* (i.e. mutates the caller's
        dict) — confirm against ``update_config_from_args`` in configs/utils.
        """
        if "dataset_type" in kwargs and isinstance(kwargs["dataset_type"], str):
            kwargs["dataset_type"] = DatasetType.from_string(kwargs["dataset_type"])
        if "video_loader_type" in kwargs and isinstance(
            kwargs["video_loader_type"], str
        ):
            kwargs["video_loader_type"] = VideoLoaderType.from_string(
                kwargs["video_loader_type"]
            )

        # Start from defaults, then overlay any "preprocess."-prefixed kwargs.
        preprocess_config = cls()
        if not update_config_from_args(
            preprocess_config, kwargs, prefix="preprocess", pop_args=True
        ):
            return None
        return preprocess_config

    def check_preprocess_config(self) -> None:
        """Validate fields required for preprocess mode.

        Raises:
            ValueError: If ``dataset_path`` is unset, or ``samples_per_file``
                or ``flush_frequency`` is not strictly positive.
        """
        if self.dataset_path == "":
            raise ValueError("dataset_path must be set for preprocess mode")
        if self.samples_per_file <= 0:
            raise ValueError("samples_per_file must be greater than 0")
        if self.flush_frequency <= 0:
            raise ValueError("flush_frequency must be greater than 0")
9 changes: 0 additions & 9 deletions python/sglang/multimodal_gen/runtime/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from enum import Enum
from typing import Any, Optional

from sglang.multimodal_gen.configs.configs import PreprocessConfig
from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig, STA_Mode
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
Expand Down Expand Up @@ -251,7 +250,6 @@ class ServerArgs:
dist_timeout: int | None = None # timeout for torch.distributed

pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)
preprocess_config: PreprocessConfig | None = None

# LoRA parameters
# (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated.
Expand Down Expand Up @@ -626,9 +624,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
# Add pipeline configuration arguments
PipelineConfig.add_cli_args(parser)

# Add preprocessing configuration arguments
PreprocessConfig.add_cli_args(parser)

# Logging
parser.add_argument(
"--log-level",
Expand Down Expand Up @@ -734,9 +729,6 @@ def from_dict(cls, kwargs: dict[str, Any]) -> "ServerArgs":
pipeline_config = PipelineConfig.from_kwargs(kwargs)
logger.debug(f"Using PipelineConfig: {type(pipeline_config)}")
server_args_kwargs["pipeline_config"] = pipeline_config
elif attr == "preprocess_config":
preprocess_config = PreprocessConfig.from_kwargs(kwargs)
server_args_kwargs["preprocess_config"] = preprocess_config
elif attr in kwargs:
server_args_kwargs[attr] = kwargs[attr]

Expand Down Expand Up @@ -772,7 +764,6 @@ def from_kwargs(cls, **kwargs: Any) -> "ServerArgs":
kwargs["workload_type"] = WorkloadType.from_string(kwargs["workload_type"])

kwargs["pipeline_config"] = PipelineConfig.from_kwargs(kwargs)
kwargs["preprocess_config"] = PreprocessConfig.from_kwargs(kwargs)
return cls(**kwargs)

@staticmethod
Expand Down
Loading