diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index e715bdaec8..09144e05a9 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1526,6 +1526,11 @@ def _generate_image_parallel( f"Use cfg_parallel_size=3 to enable image CFG in parallel mode." ) + # Ensure all ranks start with the same x_t (initial noise may differ + # across ranks when no per-request seed is set). + x_t = x_t.contiguous() + cfg_group.broadcast(x_t, src=0) + # Select this rank's branch inputs if cfg_rank == 0: # Gen branch: use main inputs directly @@ -1553,17 +1558,10 @@ def _generate_image_parallel( for i, t in enumerate(timesteps): timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) - if t > cfg_interval[0] and t <= cfg_interval[1]: - cfg_text_scale_ = cfg_text_scale - cfg_img_scale_ = cfg_img_scale - else: - cfg_text_scale_ = 1.0 - cfg_img_scale_ = 1.0 - - use_cfg_this_step = cfg_text_scale_ > 1.0 + use_cfg_this_step = t > cfg_interval[0] and t <= cfg_interval[1] and cfg_text_scale > 1.0 if use_cfg_this_step: - # Each rank computes its branch's velocity + # CFG interval: each rank computes its own branch local_v_t = self._forward_flow_single_branch( x_t=x_t, timestep=timestep, @@ -1579,46 +1577,34 @@ def _generate_image_parallel( packed_key_value_indexes=branch_key_value_indexes, ) - # All-gather velocities from all CFG ranks gathered = cfg_group.all_gather(local_v_t, separate_tensors=True) - - # Rank 0 combines with CFG formula - if cfg_rank == 0: - v_t = gathered[0] # gen branch - cfg_text_v_t = gathered[1] # text_cfg branch - cfg_img_v_t = gathered[2] if (use_cfg_img and len(gathered) > 2) else None - v_t = self._combine_cfg( - v_t, - cfg_text_v_t, - cfg_img_v_t, - cfg_text_scale_, - cfg_img_scale_, - cfg_renorm_type, - cfg_renorm_min, - ) - x_t = x_t - v_t.to(x_t.device) * dts[i] + v_t = self._combine_cfg( + gathered[0], + gathered[1], + gathered[2] if (use_cfg_img and len(gathered) > 2) else None, + cfg_text_scale, + cfg_img_scale, + cfg_renorm_type, + cfg_renorm_min, + ) else: - # Outside cfg_interval: only rank 0 computes (no CFG needed) - if cfg_rank == 0: - v_t = self._forward_flow_single_branch( - x_t=x_t, - timestep=timestep, - packed_vae_token_indexes=packed_vae_token_indexes, - packed_vae_position_ids=packed_vae_position_ids, - packed_text_ids=packed_text_ids, - packed_text_indexes=packed_text_indexes, - packed_indexes=packed_indexes, - packed_position_ids=packed_position_ids, - packed_seqlens=packed_seqlens, - key_values_lens=key_values_lens, - past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, - ) - x_t = x_t - v_t.to(x_t.device) * dts[i] - - # Broadcast updated x_t from rank 0 to all ranks - x_t = x_t.contiguous() - cfg_group.broadcast(x_t, src=0) + # Outside CFG interval: all ranks compute with gen inputs, no comm + v_t = self._forward_flow_single_branch( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) return unpacked_latent