Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,12 @@ class OmniDiffusionConfig:
# Step mode settings
step_execution: bool = False

# Default device for the torch.Generator used in random number generation
# (e.g., "cpu", "cuda"). It applies only when a request provides a seed without
# its own generator_device; a per-request generator_device takes precedence.
# Setting this to "cpu" is intended to keep seeded results reproducible across
# different GPU hardware.
default_generator_device: str | None = None

@property
def is_moe(self) -> bool:
num_experts = self.tf_model_config.get("num_experts", None)
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/worker/diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if req.sampling_params.generator is None and req.sampling_params.seed is not None:
if req.sampling_params.generator_device is not None:
gen_device = req.sampling_params.generator_device
elif self.od_config.default_generator_device is not None:
gen_device = self.od_config.default_generator_device
elif self.device.type == "cpu":
gen_device = "cpu"
else:
Expand Down Expand Up @@ -361,6 +363,8 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner
if state.sampling.generator is None and state.sampling.seed is not None:
if state.sampling.generator_device is not None:
gen_device = state.sampling.generator_device
elif self.od_config.default_generator_device is not None:
gen_device = self.od_config.default_generator_device
elif self.device.type == "cpu":
gen_device = "cpu"
else:
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),
"quantization": kwargs.get("quantization", None),
"enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False),
"default_generator_device": kwargs.get("default_generator_device", None),
**(
{
"profiler_config": asdict(kwargs["profiler_config"])
Expand Down
10 changes: 10 additions & 0 deletions vllm_omni/entrypoints/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,16 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)

# Default generator device for reproducible random number generation
omni_config_group.add_argument(
"--default-generator-device",
type=str,
default=None,
help="Default device for torch.Generator in diffusion models (e.g., 'cpu', 'cuda'). "
"When set, this device will be used for random number generation if not overridden "
"per-request. Using 'cpu' ensures reproducible results across different GPU hardware.",
)
return serve_parser


Expand Down