diff --git a/vllm_omni/diffusion/models/flux2/flux2_transformer.py b/vllm_omni/diffusion/models/flux2/flux2_transformer.py index 116e499b0ee..0a4452197f8 100644 --- a/vllm_omni/diffusion/models/flux2/flux2_transformer.py +++ b/vllm_omni/diffusion/models/flux2/flux2_transformer.py @@ -578,9 +578,11 @@ def __init__( guidance_embeds: bool = True, ): super().__init__() + self.guidance_embeds = guidance_embeds self.stacked_params_mapping = None self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.config = SimpleNamespace( patch_size=patch_size, in_channels=in_channels, diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 1da0f0cdaf9..c5bf9b77d9e 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -928,6 +928,7 @@ def forward( self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False + guidance_tensor = None # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1017,6 +1018,11 @@ def forward( ) self._num_timesteps = len(timesteps) + # guidance_embeds is a bool flag: build the per-sample guidance tensor only when it is enabled, else keep guidance_tensor=None + if self.transformer.guidance_embeds: + guidance_tensor = torch.full([1], self.guidance_scale, device=device, dtype=torch.float32) + guidance_tensor = guidance_tensor.expand(latents.shape[0]) + # 7. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -1038,7 +1044,7 @@ def forward( noise_pred = self.transformer( hidden_states=latent_model_input, # (B, image_seq_len, C) timestep=timestep / 1000, - guidance=None, + guidance=guidance_tensor, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 img_ids=latent_image_ids, # B, image_seq_len, 4