Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/sglang/multimodal_gen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

from sglang.multimodal_gen.configs.pipelines import PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs import PipelineConfig
from sglang.multimodal_gen.configs.sample import SamplingParams
from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

from sglang.multimodal_gen.configs.pipelines.base import (
from sglang.multimodal_gen.configs.pipeline_configs.base import (
PipelineConfig,
SlidingTileAttnConfig,
)
from sglang.multimodal_gen.configs.pipelines.flux import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipelines.hunyuan import (
from sglang.multimodal_gen.configs.pipeline_configs.flux import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import (
FastHunyuanConfig,
HunyuanConfig,
)
from sglang.multimodal_gen.configs.pipelines.stepvideo import StepVideoT2VConfig
from sglang.multimodal_gen.configs.pipelines.wan import (
from sglang.multimodal_gen.configs.pipeline_configs.stepvideo import StepVideoT2VConfig
from sglang.multimodal_gen.configs.pipeline_configs.wan import (
SelfForcingWanT2V480PConfig,
WanI2V480PConfig,
WanI2V720PConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
T5Config,
)
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import (
from sglang.multimodal_gen.configs.pipeline_configs.base import (
ModelTaskType,
PipelineConfig,
preprocess_text,
)
from sglang.multimodal_gen.configs.pipelines.hunyuan import (
from sglang.multimodal_gen.configs.pipeline_configs.hunyuan import (
clip_postprocess_text,
clip_preprocess_text,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LlamaConfig,
)
from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs import PipelineConfig

PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig
from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig
from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import ModelTaskType, PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
ModelTaskType,
PipelineConfig,
)
from sglang.multimodal_gen.utils import calculate_dimensions


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits import StepVideoConfig
from sglang.multimodal_gen.configs.models.vaes import StepVideoVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
T5Config,
)
from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import ModelTaskType, PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
ModelTaskType,
PipelineConfig,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand Down
68 changes: 28 additions & 40 deletions python/sglang/multimodal_gen/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from sglang.multimodal_gen.configs.pipelines import (
from sglang.multimodal_gen.configs.pipeline_configs import (
FastHunyuanConfig,
FluxPipelineConfig,
HunyuanConfig,
Expand All @@ -25,12 +25,12 @@
WanT2V480PConfig,
WanT2V720PConfig,
)
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (
QwenImageEditPipelineConfig,
QwenImagePipelineConfig,
)
from sglang.multimodal_gen.configs.pipelines.wan import (
from sglang.multimodal_gen.configs.pipeline_configs.wan import (
FastWan2_1_T2V_480P_Config,
FastWan2_2_TI2V_5B_Config,
Wan2_2_I2V_A14B_Config,
Expand All @@ -55,7 +55,7 @@
WanT2V_1_3B_SamplingParams,
WanT2V_14B_SamplingParams,
)
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
Expand All @@ -74,49 +74,37 @@
def _discover_and_register_pipelines():
"""
Automatically discover and register all ComposedPipelineBase subclasses.
This function scans the 'sglang.multimodal_gen.runtime.architectures' package,
This function scans the 'sglang.multimodal_gen.runtime.pipelines' package,
finds modules with an 'EntryClass' attribute, and maps the class's 'pipeline_name'
to the class itself in a global registry.
"""
if _PIPELINE_REGISTRY: # E-run only once
if _PIPELINE_REGISTRY: # run only once
return

package_name = "sglang.multimodal_gen.runtime.architectures"
package_name = "sglang.multimodal_gen.runtime.pipelines"
package = importlib.import_module(package_name)

for _, pipeline_type_str, ispkg in pkgutil.iter_modules(package.__path__):
for _, module_name, ispkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
if not ispkg:
continue
pipeline_type_package_name = f"{package_name}.{pipeline_type_str}"
pipeline_type_package = importlib.import_module(pipeline_type_package_name)
for _, arch, ispkg_arch in pkgutil.iter_modules(pipeline_type_package.__path__):
if not ispkg_arch:
continue
arch_package_name = f"{pipeline_type_package_name}.{arch}"
arch_package = importlib.import_module(arch_package_name)
for _, module_name, ispkg_module in pkgutil.walk_packages(
arch_package.__path__, arch_package.__name__ + "."
):
if not ispkg_module:
pipeline_module = importlib.import_module(module_name)
if hasattr(pipeline_module, "EntryClass"):
entry_cls = pipeline_module.EntryClass
if not isinstance(entry_cls, list):
entry_cls_list = [entry_cls]
else:
entry_cls_list = entry_cls

for cls in entry_cls_list:
if hasattr(cls, "pipeline_name"):
if cls.pipeline_name in _PIPELINE_REGISTRY:
logger.warning(
f"Duplicate pipeline name '{cls.pipeline_name}' found. Overwriting."
)
_PIPELINE_REGISTRY[cls.pipeline_name] = cls
# else:
# logger.warning(
# f"Pipeline class {cls.__name__} does not have a 'pipeline_name' attribute."
# )
pipeline_module = importlib.import_module(module_name)
if hasattr(pipeline_module, "EntryClass"):
entry_cls = pipeline_module.EntryClass
entry_cls_list = (
[entry_cls] if not isinstance(entry_cls, list) else entry_cls
)

for cls in entry_cls_list:
if hasattr(cls, "pipeline_name"):
if cls.pipeline_name in _PIPELINE_REGISTRY:
logger.warning(
f"Duplicate pipeline name '{cls.pipeline_name}' found. Overwriting."
)
_PIPELINE_REGISTRY[cls.pipeline_name] = cls
logger.debug(
f"Registering pipelines complete, {len(_PIPELINE_REGISTRY)} pipelines registered"
)


# --- Part 2: Config Registration ---
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import torchvision
from einops import rearrange

from sglang.multimodal_gen.runtime.pipelines import Req
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.pipelines_core import Req
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch

# Suppress verbose logging from imageio, which is triggered when saving images.
logging.getLogger("imageio").setLevel(logging.WARNING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
post_process_sample,
)
from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.scheduler_client import scheduler_client
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
post_process_sample,
)
from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logging.getLogger("imageio_ffmpeg").setLevel(logging.WARNING)

from sglang.multimodal_gen.configs.sample.base import DataType, SamplingParams
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import shallow_asdict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

if TYPE_CHECKING:
from sglang.multimodal_gen.runtime.layers.attention import AttentionMetadata
from sglang.multimodal_gen.runtime.pipelines import Req
from sglang.multimodal_gen.runtime.pipelines_core import Req

logger = init_logger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
get_cfg_group,
get_tp_group,
)
from sglang.multimodal_gen.runtime.pipelines import Req, build_pipeline
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.pipelines_core import Req, build_pipeline
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs
from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch
from sglang.multimodal_gen.runtime.utils.logging_utils import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import zmq

from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.server_args import (
PortArgs,
ServerArgs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import zmq

from sglang.multimodal_gen.runtime.pipelines import Req
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.pipelines_core import Req
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.utils import init_logger

Expand Down
63 changes: 0 additions & 63 deletions python/sglang/multimodal_gen/runtime/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,64 +1 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
"""
Diffusion pipelines for sglang.multimodal_gen.

This package contains diffusion pipelines for generating videos and images.
"""

from typing import cast

from sglang.multimodal_gen.registry import get_model_info
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import Req
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
maybe_download_model,
verify_model_config_and_directory,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


class PipelineWithLoRA(LoRAPipeline, ComposedPipelineBase):
    """Pipeline type that combines LoRAPipeline and ComposedPipelineBase.

    The class adds no members of its own. It gives a single name to "a
    composed pipeline that also has the LoRA functionality", which
    build_pipeline uses as its return annotation and cast target.
    """


def build_pipeline(
    server_args: ServerArgs,
) -> PipelineWithLoRA:
    """
    Build the pipeline registered for ``server_args.model_path``.

    The model path is resolved to a registry entry with ``get_model_info``;
    the entry's ``pipeline_cls`` is then instantiated with the model path and
    the server args.

    NOTE(review): the previous docstring also listed "download the model from
    the hub" and "verify the model config and directory" as steps. Neither is
    performed directly in this function -- presumably they happen inside
    ``get_model_info`` or the pipeline constructor; confirm against those.

    Args:
        server_args: Server configuration; only ``model_path`` is read here,
            the whole object is forwarded to the pipeline constructor.

    Returns:
        The instantiated pipeline, typed (via ``cast``) as PipelineWithLoRA.

    Raises:
        ValueError: If ``get_model_info`` returns None for the model path.
    """
    path = server_args.model_path

    info = get_model_info(path)
    if info is None:
        raise ValueError(f"Unsupported model: {path}")

    # The registry entry decides which concrete pipeline class to construct.
    built = info.pipeline_cls(path, server_args)
    logger.info("Pipelines instantiated")

    # cast() is a no-op at runtime; it only narrows the type for checkers.
    return cast(PipelineWithLoRA, built)


# Names exported by `from sglang.multimodal_gen.runtime.pipelines import *`.
# PipelineWithLoRA is defined in this module but is not listed here.
__all__ = [
    "build_pipeline",
    "ComposedPipelineBase",
    "Req",
    "LoRAPipeline",
]
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines.stages import (
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"""


from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.stages import (
from sglang.multimodal_gen.runtime.pipelines_core.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
Expand Down
Loading