
[BugFix] Refresh TeaCache when num_inference_steps=None#2240

Open
alex-jw-brooks wants to merge 6 commits into vllm-project:main from alex-jw-brooks:flux2_tc_fix

Conversation

@alex-jw-brooks
Contributor

Purpose

Related to #2194

The proper fix for the above issue is to merge the sampling params to get the correct num_inference_steps, but this PR adds a short-term workaround for TeaCache, whose refresh does not depend on num_inference_steps. It also adds logging when the cache fails to reset, while I work on the more general fix.

This is needed because the warmup initializes TeaCache, which replaces forward() and can cause bad behavior when running text-to-image on models that also accept image inputs. E.g., for Flux2Klein:

from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

if __name__ == "__main__":
    omni = Omni(
        model="black-forest-labs/FLUX.2-klein-4B",
        cache_backend="tea_cache",
    )

    prompt = "A cat sitting on a windowsill"

    # If you specify num_inference_steps, you will see the second cache refresh
    # (after warmup), but if you don't pass it, you won't, since refresh won't
    # be called.
    sampling_params = OmniDiffusionSamplingParams(
        # Not specifying num_inference_steps will crash forward
    )

    outputs = omni.generate(prompt, sampling_params)
    outputs[0].images[0].save("meow.png")

Not refreshing before entering the forward pass blows up, because the new modulated inputs don't have an image component while the previous (stale) ones do:

ERROR 03-26 18:11:45 [diffusion_worker.py:481]   File "/home/alex-jw-brooks/vllm-omni/vllm_omni/diffusion/cache/teacache/hook.py", line 222, in _should_compute_full_transformer
ERROR 03-26 18:11:45 [diffusion_worker.py:481]     (modulated_inp - state.previous_modulated_input).abs().mean()
ERROR 03-26 18:11:45 [diffusion_worker.py:481]      ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ERROR 03-26 18:11:45 [diffusion_worker.py:481] RuntimeError: The size of tensor a (4096) must match the size of tensor b (8192) at non-singleton dimension 1
...
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78]   File "/home/alex-jw-brooks/vllm-omni/vllm_omni/entrypoints/async_omni_diffusion.py", line 309, in generate
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78]     raise RuntimeError(f"Diffusion generation failed: {e}") from e
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78] RuntimeError: Diffusion generation failed: The size of tensor a (4096) must match the size of tensor b (8192) at non-singleton dimension 1
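The failure mode can be reproduced in miniature; this is a hedged sketch using NumPy in place of torch, with shapes taken from the traceback above (variable names are hypothetical, not the actual hook code). A stale cached modulated input still contains image tokens (8192 positions), while the fresh text-only one has 4096, so the elementwise subtraction in the cache-distance check cannot broadcast:

```python
import numpy as np

# Shapes from the traceback: fresh text-only modulated input vs. a stale
# cached one that still includes image tokens (hypothetical stand-ins).
fresh_modulated = np.zeros((1, 4096, 64))  # current TTI request
stale_cached = np.zeros((1, 8192, 64))     # left over from warmup

try:
    # Mirrors (modulated_inp - state.previous_modulated_input).abs().mean()
    diff = np.abs(fresh_modulated - stale_cached).mean()
except ValueError as e:
    print(f"shape mismatch: {e}")
```

Refreshing the cache between requests clears the stale tensor, so the first timestep of the new request skips this comparison entirely.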

This PR allows TeaCache to refresh in this case, and adds a log if we can't refresh the cache, until the more correct fix lands.
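The workaround can be sketched roughly as follows (all names here are hypothetical stand-ins, not the real vllm-omni runner/backend API): when num_inference_steps is None, we still refresh TeaCache, passing 0 as a falsy placeholder, since TeaCache's refresh ignores the value and only resets hook state.

```python
class FakeTeaCacheBackend:
    """Hypothetical stand-in whose refresh only resets internal state."""
    name = "tea_cache"

    def __init__(self):
        self.state = {"previous_modulated_input": "stale-from-warmup"}

    def refresh(self, num_inference_steps):
        # The int argument is accepted but unused: refresh just clears state.
        self.state = {"previous_modulated_input": None}


def maybe_refresh_cache(backend, num_inference_steps):
    """Sketch of the guard: refresh TeaCache even without a step count."""
    if num_inference_steps is None:
        if backend.name == "tea_cache":
            # 0 is only a placeholder; TeaCache's refresh ignores it.
            backend.refresh(num_inference_steps=0)
            return True
        print("Cache not refreshed: num_inference_steps is None")
        return False
    backend.refresh(num_inference_steps=num_inference_steps)
    return True


backend = FakeTeaCacheBackend()
maybe_refresh_cache(backend, None)  # refreshes despite the missing value
```

Other cache backends (e.g. DiTCache) genuinely need the real step count, which is why they still have to wait for the sampling-params merge fix.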

@Gaohan123 @wtomin @fhfuih could you please take a look?

@fhfuih
Contributor

fhfuih commented Mar 27, 2026

EDIT: Sorry, I actually missed your PR description. My understanding was correct; I jumped right into your code 😂

Thanks for the PR. A quick question: if I understand it correctly, this PR is only a quick fix. It force-sets the number of inference steps to 0: not None, but still falsy. This passes the check during cache refreshing and also yields to pipeline-specific overrides.

And the more complete fix is on your cache_refresh branch?

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Mar 27, 2026

Hey @fhfuih! No worries 😆 but yes. My understanding of the flow is:

  • The TeaCache hook gets initialized in load_model, which also creates the StateManager etc. for the cache.
  • When we run requests, we run the _WrappedForward, which calls the hook's new forward (here).
  • The new forward for TeaCache (this) runs the extractor, then gets the TeaCache state or creates a new one. After that, it checks the state here to see if it's the first timestep, and compares against the previous modulated state if it isn't.

For TeaCache, the refresh does not depend on the timesteps; it just resets the TeaCache state (i.e., num_inference_steps isn't passed anywhere here). The value of 0 is just a placeholder I chose because the arg is an int; in the TeaCache case it doesn't matter, since all refresh does is clear the state.

Since refresh isn't currently being called, the state is stale from the last execute_model call, so instead of creating a new state on the first timestep, we get the old one and fall through this check.

So this is okay as a short-term fix for TeaCache's behavior, but the other branch will fix it more correctly by passing the actual num_inference_steps value, which we need in order to reset DiTCache correctly 🙂
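The flow above can be sketched roughly as follows (all names are hypothetical stand-ins for the hook/state-manager code, not the real vllm-omni API): the state manager hands out per-key state objects, and refresh simply clears them, taking num_inference_steps only to satisfy the int-typed signature:

```python
from dataclasses import dataclass


@dataclass
class TeaCacheState:
    # A tensor in the real code; None means "first timestep, compute fully".
    previous_modulated_input: object = None


class StateManager:
    """Hypothetical stand-in for the cache's per-model state manager."""

    def __init__(self):
        self._states = {}

    def get_or_create(self, key):
        # Reuses existing state; this is how stale warmup state leaks into
        # the next request when refresh is never called.
        return self._states.setdefault(key, TeaCacheState())

    def reset(self):
        self._states.clear()


def refresh(manager, num_inference_steps):
    # The step count is accepted (the arg is an int) but unused: TeaCache's
    # refresh only clears state, so 0 works as a placeholder.
    manager.reset()


manager = StateManager()
state = manager.get_or_create("transformer")
state.previous_modulated_input = "stale-from-warmup"

# Without refresh(), the next request would get this stale state back and
# compare against a previous_modulated_input of the wrong shape.
refresh(manager, 0)
assert manager.get_or_create("transformer").previous_modulated_input is None
```

This also shows why the fix is TeaCache-specific: a backend whose reset logic actually consumed num_inference_steps could not safely be handed a placeholder.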

@alex-jw-brooks alex-jw-brooks changed the title Refresh TeaCache when num_inference_steps=None [BugFix] Refresh TeaCache when num_inference_steps=None Mar 27, 2026
Comment on lines +266 to +270
# FIXME (Alex): When num_inference_steps is None, we defer to
# pipelines for default, but don't refresh the cache; the right
# way to do this is to merge the sampling params first, but
# for now we allow teacache to refresh either way since it does
# not depend on the num_inference_steps.
Contributor


For the comment, maybe we can explain why we need 0 instead of None for now: which places' logic it works around, and which behavior/bugfix TeaCache requires.

Contributor

@fhfuih fhfuih left a comment


Thanks for the explanation. All looks good to me. @SamitHuang @ZJY0516, could you also have a look and decide whether to merge this hotfix, since it is related to a previous relevant PR? Thanks!

Collaborator

@lishunyang12 lishunyang12 left a comment


Looks reasonable as a short-term workaround; left a couple of nits.

Comment thread vllm_omni/diffusion/worker/diffusion_model_runner.py Outdated
# pipelines for default, but don't refresh the cache; the right
# way to do this is to merge the sampling params first, but
# for now we allow teacache to refresh either way since it does
# not depend on the num_inference_steps.
Collaborator


Nit: the comment explains that teacache doesn't depend on num_inference_steps, but it'd be more useful to add a one-liner about why 0 is safe — i.e. TeaCacheBackend.refresh() ignores the value entirely (just resets hook state). That addresses @fhfuih's earlier feedback too.

Contributor Author

@alex-jw-brooks alex-jw-brooks Apr 3, 2026


Yup! Added a clearer explanation.

alex-jw-brooks and others added 6 commits April 14, 2026 05:02
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>