Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/e2e/offline_inference/test_nextstep_text2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import sys
from pathlib import Path

import pytest
import torch

from tests.utils import hardware_test
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

# ruff: noqa: E402
# This file lives in tests/e2e/offline_inference/, so for the resolved path
# parents[0] is offline_inference, parents[1] is e2e, parents[2] is tests and
# parents[3] is the repository root. Using parents[2] would put the ``tests``
# directory on sys.path instead of the root that contains ``vllm_omni``.
REPO_ROOT = Path(__file__).resolve().parents[3]
if str(REPO_ROOT) not in sys.path:
    # Prepend so the in-tree checkout is found when it is not already importable.
    sys.path.insert(0, str(REPO_ROOT))

# Imported after the sys.path adjustment above (hence the E402 suppression).
from vllm_omni import Omni

# NOTE(review): presumably read by the test harness to free GPU memory
# between tests -- confirm against the consumer of this variable.
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"


@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
@hardware_test(
    res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"},
    num_cards={"cuda": 1, "rocm": 2, "xpu": 2},
)
def test_nextstep_text2img(run_level):
    """Generate a single 256x256 image with NextStep-1.1 and validate the output.

    The test is skipped when ``run_level`` is ``"core_model"``. On XPU the
    pipeline is created with ``tensor_parallel_size=2``; on other platforms
    no parallel config is passed. Generation uses 4 inference steps and a
    generator seeded with 42 on the current platform's device type.

    Checks performed on the first output:
      * ``final_output_type`` is ``"image"``;
      * ``request_output[0]`` is an ``OmniRequestOutput`` carrying exactly
        one image;
      * the image has the requested width and height.

    The ``Omni`` instance is always closed in ``finally`` (when it was
    created and exposes ``close``), including when an assertion fails.
    """
    if run_level == "core_model":
        pytest.skip("NextStep text-to-image test does not run at the core_model run level")

    omni_kwargs = {
        "model": "stepfun-ai/NextStep-1.1",
        "model_class_name": "NextStep11Pipeline",
    }
    if current_omni_platform.is_xpu():
        omni_kwargs["parallel_config"] = DiffusionParallelConfig(tensor_parallel_size=2)

    # high resolution may cause OOM on L4
    height = 256
    width = 256

    m = None
    try:
        m = Omni(**omni_kwargs)
        outputs = m.generate(
            "a photo of a cat sitting on a laptop keyboard",
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                num_inference_steps=4,
                guidance_scale=7.5,
                guidance_scale_2=1.0,
                generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
                num_outputs_per_prompt=1,
            ),
        )

        first_output = outputs[0]
        assert first_output.final_output_type == "image"
        assert getattr(first_output, "request_output", None), "no request_output on NextStep output"

        req_out = first_output.request_output[0]
        assert isinstance(req_out, OmniRequestOutput), "request_output[0] is not OmniRequestOutput"
        assert getattr(req_out, "images", None), "no images in NextStep request_output"

        images = req_out.images
        assert len(images) == 1, f"expected 1 image, got {len(images)}"
        assert images[0].width == width
        assert images[0].height == height
    finally:
        # No except/print/re-raise here: pytest already reports the exception
        # and traceback, so only the resource cleanup is needed.
        if m is not None and hasattr(m, "close"):
            m.close()
3 changes: 1 addition & 2 deletions vllm_omni/entrypoints/omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs):
if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
pipeline_class = "BagelPipeline"
elif model_type == "nextstep":
if od_config.model_class_name is None:
pipeline_class = "NextStep11Pipeline"
pipeline_class = "NextStep11Pipeline"
elif model_type == "glm-image" or "GlmImageForConditionalGeneration" in architectures:
pipeline_class = "GlmImagePipeline"
elif architectures and len(architectures) == 1:
Expand Down