Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,10 @@ def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None:
self.patch_size = config.patch_size
self.text_len = config.text_len

# Could have been a class attribute, but the type of block is decided based on
# attn_backend, therefore we put it here.
self._repeated_blocks = []

# 1. Patch & position embedding
self.patch_embedding = PatchEmbed(
in_chans=config.in_channels,
Expand All @@ -718,6 +722,7 @@ def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None:
if (attn_backend and attn_backend.lower() == "video_sparse_attn")
else WanTransformerBlock
)
self._repeated_blocks.append(transformer_block.__name__)
self.blocks = nn.ModuleList(
[
transformer_block(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,40 @@ def _maybe_enable_torch_compile(self, module: object) -> None:
except ImportError:
pass
mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")
logger.info(f"Compiling transformer with mode: {mode}")
# TODO(triple-mu): support customized fullgraph and dynamic in the future
module.compile(mode=mode, fullgraph=False, dynamic=None)
fullgraph = False
dynamic = None
regional = self.server_args.regional_compile
logger.info(
f"Compiling transformer with mode: {mode}, fullgraph: {fullgraph}, dynamic: {dynamic}, regional: {regional}"
)
if regional:
self.regionally_compile(
module, mode=mode, fullgraph=fullgraph, dynamic=dynamic
)
else:
# TODO(triple-mu): support customized fullgraph and dynamic in the future
module.compile(mode=mode, fullgraph=fullgraph, dynamic=dynamic)

def regionally_compile(self, module, *args, **kwargs):
repeated_blocks = getattr(module, "_repeated_blocks", None)

if not repeated_blocks:
logger.warning(
"`_repeated_blocks` attribute is empty. "
f"Set `_repeated_blocks` for the class `{module.__class__.__name__}` to benefit from faster compilation. "
)
return

has_compiled_region = False
for submod in module.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(*args, **kwargs)
has_compiled_region = True

if not has_compiled_region:
logger.warning(
f"Regional compilation failed because {repeated_blocks} classes are not found in the model."
)

def _maybe_enable_cache_dit(self, num_inference_steps: int, batch: Req) -> None:
"""Enable cache-dit on the transformers if configured (idempotent).
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/multimodal_gen/runtime/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class ServerArgs:

# Compilation
enable_torch_compile: bool = False
regional_compile: bool = False

# warmup
warmup: bool = False
Expand Down Expand Up @@ -588,6 +589,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Use torch.compile to speed up DiT inference."
+ "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)",
)
parser.add_argument(
"--regional-compile",
action=StoreBoolean,
default=ServerArgs.regional_compile,
help="Apply regional compile on repeated blocks to reduce compile time.",
)

# warmup
parser.add_argument(
Expand Down
Loading