From 81ef938762a9699158acf1a0495bac10eff6fc67 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 26 Mar 2026 17:56:59 +0000 Subject: [PATCH 1/6] add workaround for teacache with num_inference_steps=None Signed-off-by: Alex Brooks --- .../worker/diffusion_model_runner.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 535f053c388..258b5cbe024 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -263,9 +263,25 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: not getattr(req, "skip_cache_refresh", False) and self.cache_backend is not None and self.cache_backend.is_enabled() - and req.sampling_params.num_inference_steps is not None ): - self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) + # FIXME (Alex): When num_inference_steps is None, we defer to + # pipelines for default, but don't refresh the cache; the right + # way to do this is to merge the sampling params first, but + # for now we allow teacache to refresh either way since it does + # not depend on the num_inference_steps. + num_inference_steps = ( + req.sampling_params.num_inference_steps or 0 + if self.od_config.cache_backend == "tea_cache" + else req.sampling_params.num_inference_steps + ) + if num_inference_steps is not None: + self.cache_backend.refresh(self.pipeline, num_inference_steps) + else: + logger.warning( + "Failed to refresh the diffusion transformer cache; backend %s " + "currently requires num_inference_steps to be passed explicitly", + self.cache_backend, + ) is_primary = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 if is_primary: From 0418588938932c2a038b825992ea4378ee68fec6 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 26 Mar 2026 18:18:03 +0000 Subject: [PATCH 2/6] fix log Signed-off-by: Alex Brooks --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 258b5cbe024..7961d44edf6 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -280,7 +280,7 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: logger.warning( "Failed to refresh the diffusion transformer cache; backend %s " "currently requires num_inference_steps to be passed explicitly", - self.cache_backend, + self.od_config.cache_backend, ) is_primary = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 From c6801b34ecd4840e0698eba7aeaf56879230725a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 2 Apr 2026 09:46:58 -0600 Subject: [PATCH 3/6] Update vllm_omni/diffusion/worker/diffusion_model_runner.py Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com> Signed-off-by: Alex Brooks --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 7961d44edf6..41d8d4c0d6f 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -273,6 +273,10 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: req.sampling_params.num_inference_steps or 0 if self.od_config.cache_backend == "tea_cache" else req.sampling_params.num_inference_steps + num_inference_steps = ( + req.sampling_params.num_inference_steps if req.sampling_params.num_inference_steps is not None else 0 + if self.od_config.cache_backend == "tea_cache" + else req.sampling_params.num_inference_steps ) if num_inference_steps is not None: self.cache_backend.refresh(self.pipeline, num_inference_steps) From 2dc2eac9d86848936b983351116d47eaacda4c9a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 2 Apr 2026 15:49:23 +0000 Subject: [PATCH 4/6] clarify comment Signed-off-by: Alex Brooks --- .../diffusion/worker/diffusion_model_runner.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 41d8d4c0d6f..cec9df2d316 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -266,15 +266,16 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: ): # FIXME (Alex): When num_inference_steps is None, we defer to # pipelines for default, but don't refresh the cache; the right - # way to do this is to merge the sampling params first, but - # for now we allow teacache to refresh either way since it does - # not depend on the num_inference_steps. + # way to do this is to merge the sampling params first. + # + # For now, if num_inference_steps is not set, we pass 0 to allow + # TeaCache to refresh to align with the param signature. This is + # okay to force refresh TeaCache because the refresh does not use + # num_inference_steps at all. num_inference_steps = ( - req.sampling_params.num_inference_steps or 0 - if self.od_config.cache_backend == "tea_cache" - else req.sampling_params.num_inference_steps - num_inference_steps = ( - req.sampling_params.num_inference_steps if req.sampling_params.num_inference_steps is not None else 0 + req.sampling_params.num_inference_steps + if req.sampling_params.num_inference_steps is not None + else 0 if self.od_config.cache_backend == "tea_cache" else req.sampling_params.num_inference_steps ) From 911c86d09fa85628af7c3065f434e22de468a0a4 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 2 Apr 2026 15:50:40 +0000 Subject: [PATCH 5/6] doc Signed-off-by: Alex Brooks --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index cec9df2d316..311c1d388fb 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -271,7 +271,8 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: # For now, if num_inference_steps is not set, we pass 0 to allow # TeaCache to refresh to align with the param signature. This is # okay to force refresh TeaCache because the refresh does not use - # num_inference_steps at all. + # num_inference_steps at all (i.e., just resets state and clears + # stale residuals). num_inference_steps = ( req.sampling_params.num_inference_steps if req.sampling_params.num_inference_steps is not None From e1fb78ef1db512690341a716793f3828c5ca8592 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 2 Apr 2026 15:54:58 +0000 Subject: [PATCH 6/6] simplify Signed-off-by: Alex Brooks --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 311c1d388fb..36c3266e8ee 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -273,13 +273,10 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: # okay to force refresh TeaCache because the refresh does not use # num_inference_steps at all (i.e., just resets state and clears # stale residuals). - num_inference_steps = ( - req.sampling_params.num_inference_steps - if req.sampling_params.num_inference_steps is not None - else 0 - if self.od_config.cache_backend == "tea_cache" - else req.sampling_params.num_inference_steps - ) + num_inference_steps = req.sampling_params.num_inference_steps + if self.od_config.cache_backend == "tea_cache" and num_inference_steps is None: + num_inference_steps = 0 + if num_inference_steps is not None: self.cache_backend.refresh(self.pipeline, num_inference_steps) else: