Skip to content

[WIP] [Refactor]Add Base Class for Diffusion Pipelines#2811

Draft
alex-jw-brooks wants to merge 11 commits intovllm-project:mainfrom
alex-jw-brooks:base_pipeline
Draft

[WIP] [Refactor]Add Base Class for Diffusion Pipelines#2811
alex-jw-brooks wants to merge 11 commits intovllm-project:mainfrom
alex-jw-brooks:base_pipeline

Conversation

@alex-jw-brooks
Copy link
Copy Markdown
Contributor

@alex-jw-brooks alex-jw-brooks commented Apr 15, 2026

Purpose

For this RFC: #2189
As a consequence of this refactor, will also fix issues with default params not refreshing properly when num_inference_steps isn't set: #2189. If this ends up merged before #2240, we won't need that PR anymore.

Current State

  • every diffusion pipeline inherits from VllmDiffusionPipeline
  • all VllmDiffusionPipeline instances have a sampling_param_defaults property which resolves to a DiffusionParamOverrides object, setting default values for sampling params on a per pipeline basis
  • DiffusionParamOverrides can be merged with OmniDiffusionSamplingParams with the priority passed by user > pipeline default (in DiffusionParamOverrides) > OmniDiffusionSamplingParams default
    • Rather than having a sentinel for every dataclass field, this is accomplished with a track_init_args decorator, which just saves the names of the fields that were passed by the user when initializing the dataclass. This is needed for the following case:
      • If the default for foo in the OmniDiffusionSamplingParams is 0, but a pipeline sets foo=1 in its DiffusionParamOverrides, resolving should give 1
      • However, if the user actually passes 0, it should override foo in the pipeline to set a final value of 0. I.e., we need to be able to distinguish between when a user passes a default value in OmniDiffusionSamplingParams vs when it's set by default.
  • When we execute_model, we resolve the sampling params immediately so that when we go to refresh the cache, it behaves correctly

Things Left

  • Take another pass through models and ensure that all params that should be in the params are in the DiffusionParamOverrides
  • Add additional helpers for getting different components (i.e., should have helpers/proprties for getting the transformer/Vae etc so that we don't have to discover attributes as frequently)
    • We can also add additional properties using those ^ to check for things like support for sequence parallelism etc, e.g pipe.supports_sequence_parallel can essentially be boiled down to hasattr(pipe.get_transformer(), "_sp_plan"), etc
  • Need to ensure behavior is consistent with the online path (mostly have tested with offline so far)

Test Plan

  • Tests have been added for some of the merging utils to ensure the sampling param logic (These are in tests/diffusion/inputs/test_data.py)
  • We should add fast generic tests for all pipelines where possible. For example, we can be check that all sampling param overrides are resolvable into valid OmniDiffusionSamplingParams quickly without loading model weights. (These are in tests/diffusion/models/test_base.py)

This is joint work with @vraiti , thanks for the help! 🙂

Will move this out of draft once it's cleaned up and tested since it's quite messy atm, but opening the draft PR in case people have comments on the direction. FYI @wtomin @SamitHuang @lishunyang12 @asukaqaq-s @hsliuustc0106 @fhfuih

alex-jw-brooks and others added 11 commits April 14, 2026 04:52
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Add a track_init_args decorator that wraps __init__ to record which
kwargs the caller explicitly passed.  merge_with_def_params() checks
_init_kwargs to fill in pipeline defaults only for fields the caller
never touched, correctly preserving explicitly-set falsy values like
0 or False.

_convert_dataclasses_to_dict() in entrypoints/utils.py is updated to
use field iteration instead of asdict(), skipping non-init fields and
None values matching field defaults.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants