diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 4039e06ac1e..4231d4cc638 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1747,7 +1747,7 @@ def generate_image( if use_sp and use_cfg_text: if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) if frame_condition_token_indexes is not None: # Cond positions stay at t=0 (clean signal). Matches upstream @@ -1816,7 +1816,7 @@ def generate_image( if use_sp: if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) if frame_condition_token_indexes is not None: # Cond positions stay at t=0 (clean signal). Matches upstream @@ -1867,7 +1867,7 @@ def generate_image( if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) if frame_condition_token_indexes is not None: # Cond positions stay at t=0 (clean signal). Matches upstream @@ -2000,7 +2000,7 @@ def _generate_image_parallel( if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) - for i, t in enumerate(timesteps): + for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) if frame_condition_token_indexes is not None: # Cond positions stay at t=0 (clean signal). Matches upstream diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py index 4bc318e30a8..282abed23de 100644 --- a/vllm_omni/model_executor/models/bagel/bagel.py +++ b/vllm_omni/model_executor/models/bagel/bagel.py @@ -476,6 +476,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>")) self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>")) self._vae_token_mask: torch.Tensor | None = None + # Whether the current request packs any VAE / non-VAE tokens, refreshed + # in _adjust_positions_for_img2img. Cached as plain bools so the per-layer + # MoT routing can branch without calling .any() (which forces a device sync). + self._has_vae_tokens: bool = False + self._has_non_vae_tokens: bool = True self.device = get_local_device() self._install_mot_modules(config) @@ -785,13 +790,18 @@ def _adjust_positions_for_img2img( if not info_list: self._vae_token_mask = None + self._has_vae_tokens = False + self._has_non_vae_tokens = True return positions boundaries = [0] - for i in range(1, len(positions)): - if positions[i] < positions[i - 1]: + # Copy positions to the host once: indexing the CUDA tensor element by + # element in the loop below would sync the device on every iteration. + pos_list = positions.tolist() + for i in range(1, len(pos_list)): + if pos_list[i] < pos_list[i - 1]: boundaries.append(i) - boundaries.append(len(positions)) + boundaries.append(len(pos_list)) num_requests = len(boundaries) - 1 new_positions = positions.clone() @@ -864,7 +874,13 @@ def _adjust_positions_for_img2img( rope = int(new_positions[end - 1].item()) + 1 self._ropes_pending.append({"ropes": [rope]}) - self._vae_token_mask = vae_mask if vae_mask.any() else None + # Resolve mask occupancy once here (the only .any() syncs on this path) + # and cache it; the per-layer routing reads these flags instead of + # re-checking the mask on every decoder layer. + has_vae = bool(vae_mask.any()) + self._vae_token_mask = vae_mask if has_vae else None + self._has_vae_tokens = has_vae + self._has_non_vae_tokens = bool((~vae_mask).any()) if has_vae else True return new_positions # ------------------------------------------------------------------ @@ -910,10 +926,10 @@ def _mot_forward( # Final norm with MoT routing if residual is not None: hidden_states = hidden_states + residual - if vae_mask is not None and vae_mask.any(): + if vae_mask is not None and self._has_vae_tokens: out = torch.empty_like(hidden_states) non_vae = ~vae_mask - if non_vae.any(): + if self._has_non_vae_tokens: out[non_vae] = qwen2_model.norm(hidden_states[non_vae]) out[vae_mask] = qwen2_model.norm_moe_gen(hidden_states[vae_mask]) hidden_states = out @@ -931,7 +947,7 @@ def _mot_layer_forward( vae_mask: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: """Single decoder-layer forward with MoT routing.""" - if vae_mask is None or not vae_mask.any(): + if vae_mask is None or not self._has_vae_tokens: return layer(positions, hidden_states, residual) non_vae = ~vae_mask @@ -941,7 +957,7 @@ def _mot_layer_forward( hidden_states = hidden_states + residual residual = hidden_states normed = torch.empty_like(hidden_states) - if non_vae.any(): + if self._has_non_vae_tokens: normed[non_vae] = layer.input_layernorm(hidden_states[non_vae]) normed[vae_mask] = layer.input_layernorm_moe_gen(hidden_states[vae_mask]) hidden_states = normed @@ -953,14 +969,14 @@ def _mot_layer_forward( hidden_states = hidden_states + residual residual = hidden_states normed = torch.empty_like(hidden_states) - if non_vae.any(): + if self._has_non_vae_tokens: normed[non_vae] = layer.post_attention_layernorm(hidden_states[non_vae]) normed[vae_mask] = layer.post_attention_layernorm_moe_gen(hidden_states[vae_mask]) hidden_states = normed # ---- MLP (split) ---- mlp_out = torch.empty_like(hidden_states) - if non_vae.any(): + if self._has_non_vae_tokens: mlp_out[non_vae] = layer.mlp(hidden_states[non_vae]) mlp_out[vae_mask] = layer.mlp_moe_gen(hidden_states[vae_mask]) hidden_states = mlp_out @@ -985,7 +1001,7 @@ def _mot_attn_forward( device=hidden_states.device, dtype=hidden_states.dtype, ) - if non_vae.any(): + if self._has_non_vae_tokens: qkv_und, _ = attn.qkv_proj(hidden_states[non_vae]) qkv[non_vae] = qkv_und qkv_gen, _ = attn.qkv_proj_moe_gen(hidden_states[vae_mask]) @@ -1001,7 +1017,7 @@ def _mot_attn_forward( q_out = torch.empty_like(q) k_out = torch.empty_like(k) - if non_vae.any(): + if self._has_non_vae_tokens: q_out[non_vae] = attn.q_norm(q[non_vae]) k_out[non_vae] = attn.k_norm(k[non_vae]) q_out[vae_mask] = attn.q_norm_moe_gen(q[vae_mask]) @@ -1021,7 +1037,7 @@ def _mot_attn_forward( device=hidden_states.device, dtype=hidden_states.dtype, ) - if non_vae.any(): + if self._has_non_vae_tokens: o_und, _ = attn.o_proj(attn_output[non_vae]) output[non_vae] = o_und o_gen, _ = attn.o_proj_moe_gen(attn_output[vae_mask])