Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions tests/e2e/online_serving/test_ltx2_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Comprehensive tests of diffusion features that are available in online serving mode
and are supported by the LTX2 model.
- Lightricks/LTX-2
- rootonchair/LTX-2-19b-distilled

Coverage:
- Cache-DiT
- CFG-Parallel
- Ulysses-SP
- Ring-Attn

assert_diffusion_response validates successful generation
"""

import os

import pytest

from tests.helpers.mark import hardware_marks
from tests.helpers.media import (
generate_synthetic_image,
)
from tests.helpers.runtime import (
OmniServer,
OmniServerParams,
OpenAIClientHandler,
)

PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
NEGATIVE_PROMPT = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"})
PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"}, num_cards=2)

LTX2_MODELS = [
("Lightricks/LTX-2", "LTX2Pipeline"),
("Lightricks/LTX-2", "LTX2ImageToVideoPipeline"),
("rootonchair/LTX-2-19b-distilled", "LTX2TwoStagesPipeline"),
("rootonchair/LTX-2-19b-distilled", "LTX2ImageToVideoTwoStagesPipeline"),
]

PARALLEL_CONFIGS = [
("cfg_parallel", ["--cfg-parallel-size", "2"]),
("ulysses_sp", ["--usp", "2"]),
("ring_atten", ["--ring", "2"]),
("hsdp", ["--use-hsdp", "--hsdp-shard-size", "2"]),
]


def _get_ltx2_feature_cases():
cases = []

# Single-card: Cache-DiT (applies to all models)
for model_path, model_cls_name in LTX2_MODELS:
cases.append(
pytest.param(
OmniServerParams(
model=model_path,
model_class_name=model_cls_name,
server_args=["--cache-backend", "cache_dit", "--enable-layerwise-offload"],
),
id=f"{model_cls_name}_cache_dit",
marks=SINGLE_CARD_FEATURE_MARKS,
)
)

# Multi-card features
for model_path, model_cls_name in LTX2_MODELS:
for feat_id, server_args in PARALLEL_CONFIGS:
cases.append(
pytest.param(
OmniServerParams(
model=model_path,
model_class_name=model_cls_name,
server_args=server_args,
),
id=f"{model_cls_name}_{feat_id}",
marks=PARALLEL_FEATURE_MARKS,
)
)

return cases


@pytest.mark.advanced_model
@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_ltx2_feature_cases(),
indirect=True,
)
def test_ltx2_diffusion_features(
omni_server: OmniServer,
openai_client: OpenAIClientHandler,
):
model_path = omni_server.model
model_name = os.path.basename(os.path.normpath(model_path))
model_class_name = omni_server.model_class_name
is_distilled_model = model_name == "LTX-2-19b-distilled"
is_i2v = "ImageToVideo" in model_class_name

form_data = {
"prompt": PROMPT,
"negative_prompt": NEGATIVE_PROMPT,
"height": 768,
"width": 512,
"num_inference_steps": 2,
"guidance_scale": 1.0 if is_distilled_model else 4.0,
"seed": 42,
}

request_config = {
"model": model_path,
"form_data": form_data,
"model_class_name": model_class_name, # use for validate diffusion response for two-stages pipeline
}

if is_i2v:
request_config["image_reference"] = f"data:image/jpeg;base64,{generate_synthetic_image(758, 512)['base64']}"

openai_client.send_video_diffusion_request(request_config)
6 changes: 6 additions & 0 deletions tests/helpers/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def assert_video_diffusion_response(
expected_width = _maybe_int(form_data.get("width"))
expected_height = _maybe_int(form_data.get("height"))
expected_fps = _maybe_int(form_data.get("fps"))
model_class_name = request_config.get("model_class_name", None)
if model_class_name is not None and (
model_class_name == "LTX2TwoStagesPipeline" or model_class_name == "LTX2ImageToVideoTwoStagesPipeline"
):
expected_height *= 2
expected_width *= 2

for vid_bytes in response.videos:
assert_video_valid(
Expand Down
2 changes: 2 additions & 0 deletions tests/helpers/fixtures/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,15 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st
port=port,
env_dict=params.env_dict,
use_omni=params.use_omni,
model_class_name=params.model_class_name,
)
if port
else OmniServer(
model,
server_args,
env_dict=params.env_dict,
use_omni=params.use_omni,
model_class_name=params.model_class_name,
)
) as server:
print("OmniServer started successfully")
Expand Down
5 changes: 5 additions & 0 deletions tests/helpers/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class OmniServerParams(NamedTuple):
use_stage_cli: bool = False
init_timeout: int | None = None
stage_init_timeout: int | None = None # None: fixture supplies default (600 s)
model_class_name: str | None = None


class OmniServer:
Expand All @@ -131,13 +132,15 @@ def __init__(
port: int | None = None,
env_dict: dict[str, str] | None = None,
use_omni: bool = True,
model_class_name: str | None = None,
) -> None:
run_forced_gpu_cleanup_round()
cleanup_dist_env_and_memory()
self.model = model
self.serve_args = serve_args
self.env_dict = env_dict
self.use_omni = use_omni
self.model_class_name = model_class_name
self.proc: subprocess.Popen | None = None
self.host = "127.0.0.1"
self.port = get_open_port() if port is None else port
Expand All @@ -161,6 +164,8 @@ def _start_server(self) -> None:
]
if self.use_omni:
cmd.append("--omni")
if self.model_class_name:
self.serve_args.extend(["--model-class-name", self.model_class_name])
cmd += self.serve_args

print(f"Launching OmniServer with: {' '.join(cmd)}")
Expand Down
Loading