Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 57 additions & 46 deletions vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.logger import init_logger
Expand Down Expand Up @@ -67,7 +69,9 @@ def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]:
return post_process_func


class FluxKontextPipeline(nn.Module, FluxPipelineMixin, SupportImageInput):
class FluxKontextPipeline(
nn.Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin
):
"""FLUX.1-Kontext pipeline for image editing with text guidance."""

support_image_input = True
Expand Down Expand Up @@ -148,6 +152,10 @@ def __init__(
self._callback_tensor_inputs = ["latents", "prompt_embeds"]
self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16

self.setup_diffusion_pipeline_profiler(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
)

def _get_t5_prompt_embeds(
self,
prompt: str | list[str] = None,
Expand Down Expand Up @@ -635,58 +643,61 @@ def forward(

# 5. Denoising loop
self.scheduler.set_begin_index(0)
for i, t in enumerate(timesteps):
if self.interrupt:
continue

latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)

noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]

if do_true_cfg:
neg_noise_pred = self.transformer(
with self.progress_bar(total=len(timesteps)) as pbar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)

noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
noise_pred = noise_pred[:, : latents.size(1)]

if do_true_cfg:
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

pbar.update()
if output_type == "latent":
image = latents
else:
Expand Down
95 changes: 52 additions & 43 deletions vllm_omni/diffusion/models/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
Expand Down Expand Up @@ -331,7 +333,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator =
raise AttributeError("Could not access latents of provided encoder_output")


class Flux2Pipeline(nn.Module, SupportImageInput):
class Flux2Pipeline(nn.Module, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
"""Flux2 pipeline for text-to-image generation."""

_callback_tensor_inputs = ["latents", "prompt_embeds"]
Expand Down Expand Up @@ -389,6 +391,10 @@ def __init__(
self._guidance_scale = None
self._attention_kwargs = None
self._num_timesteps = None

self.setup_diffusion_pipeline_profiler(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
)
self._current_timestep = None
self._interrupt = False

Expand Down Expand Up @@ -1027,48 +1033,51 @@ def forward(
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
for i, t in enumerate(timesteps):
if self.interrupt:
continue

self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(latents.dtype)

latent_model_input = latents.to(self.transformer.dtype)
latent_image_ids = latent_ids

if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)

noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=guidance_tensor,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

noise_pred = noise_pred[:, : latents.size(1) :]

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype and torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
with self.progress_bar(total=len(timesteps)) as pbar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(latents.dtype)

latent_model_input = latents.to(self.transformer.dtype)
latent_image_ids = latent_ids

if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)

noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=guidance_tensor,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

noise_pred = noise_pred[:, : latents.size(1) :]

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype and torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

pbar.update()

self._current_timestep = None

Expand Down
Loading
Loading