Skip to content
Merged
23 changes: 21 additions & 2 deletions vllm_omni/diffusion/worker/diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,28 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput:
not getattr(req, "skip_cache_refresh", False)
and self.cache_backend is not None
and self.cache_backend.is_enabled()
and req.sampling_params.num_inference_steps is not None
):
self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps)
# FIXME (Alex): When num_inference_steps is None, we defer to
# pipelines for default, but don't refresh the cache; the right
# way to do this is to merge the sampling params first.
#
# For now, if num_inference_steps is not set, we pass 0 so that
# TeaCache can still be refreshed while satisfying the refresh()
# parameter signature. Force-refreshing TeaCache this way is safe
# because its refresh does not use num_inference_steps at all (i.e.,
# it just resets state and clears stale residuals).
num_inference_steps = req.sampling_params.num_inference_steps
if self.od_config.cache_backend == "tea_cache" and num_inference_steps is None:
num_inference_steps = 0

if num_inference_steps is not None:
self.cache_backend.refresh(self.pipeline, num_inference_steps)
else:
logger.warning(
"Failed to refresh the diffusion transformer cache; backend %s "
"currently requires num_inference_steps to be passed explicitly",
self.od_config.cache_backend,
)

is_primary = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
if is_primary:
Expand Down
Loading