Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vllm_omni/diffusion/models/bagel/bagel_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ def generate_image(
if use_sp and use_cfg_text:
if return_trajectory_latents and len(timesteps) > 0:
trajectory_latents.append(x_t.clone())
for i, t in enumerate(timesteps):
for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if frame_condition_token_indexes is not None:
# Cond positions stay at t=0 (clean signal). Matches upstream
Expand Down Expand Up @@ -1816,7 +1816,7 @@ def generate_image(
if use_sp:
if return_trajectory_latents and len(timesteps) > 0:
trajectory_latents.append(x_t.clone())
for i, t in enumerate(timesteps):
for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if frame_condition_token_indexes is not None:
# Cond positions stay at t=0 (clean signal). Matches upstream
Expand Down Expand Up @@ -1867,7 +1867,7 @@ def generate_image(
if return_trajectory_latents and len(timesteps) > 0:
trajectory_latents.append(x_t.clone())

for i, t in enumerate(timesteps):
for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if frame_condition_token_indexes is not None:
# Cond positions stay at t=0 (clean signal). Matches upstream
Expand Down Expand Up @@ -2000,7 +2000,7 @@ def _generate_image_parallel(
if return_trajectory_latents and len(timesteps) > 0:
trajectory_latents.append(x_t.clone())

for i, t in enumerate(timesteps):
for i, t in enumerate(timesteps.tolist()): # host floats; a 0-d tensor t would sync each step
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if frame_condition_token_indexes is not None:
# Cond positions stay at t=0 (clean signal). Matches upstream
Expand Down
42 changes: 29 additions & 13 deletions vllm_omni/model_executor/models/bagel/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))
self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>"))
self._vae_token_mask: torch.Tensor | None = None
# Whether the current request packs any VAE / non-VAE tokens, refreshed
# in _adjust_positions_for_img2img. Cached as plain bools so the per-layer
# MoT routing can branch without calling .any() (which forces a device sync).
self._has_vae_tokens: bool = False
self._has_non_vae_tokens: bool = True
self.device = get_local_device()
self._install_mot_modules(config)

Expand Down Expand Up @@ -785,13 +790,18 @@ def _adjust_positions_for_img2img(

if not info_list:
self._vae_token_mask = None
self._has_vae_tokens = False
self._has_non_vae_tokens = True
return positions

boundaries = [0]
for i in range(1, len(positions)):
if positions[i] < positions[i - 1]:
# Copy positions to the host once: indexing the CUDA tensor element by
# element in the loop below would sync the device on every iteration.
pos_list = positions.tolist()
for i in range(1, len(pos_list)):
if pos_list[i] < pos_list[i - 1]:
boundaries.append(i)
boundaries.append(len(positions))
boundaries.append(len(pos_list))

num_requests = len(boundaries) - 1
new_positions = positions.clone()
Expand Down Expand Up @@ -864,7 +874,13 @@ def _adjust_positions_for_img2img(
rope = int(new_positions[end - 1].item()) + 1
self._ropes_pending.append({"ropes": [rope]})

self._vae_token_mask = vae_mask if vae_mask.any() else None
# Resolve mask occupancy once here (the only .any() syncs on this path)
# and cache it; the per-layer routing reads these flags instead of
# re-checking the mask on every decoder layer.
has_vae = bool(vae_mask.any())
self._vae_token_mask = vae_mask if has_vae else None
self._has_vae_tokens = has_vae
self._has_non_vae_tokens = bool((~vae_mask).any()) if has_vae else True
return new_positions

# ------------------------------------------------------------------
Expand Down Expand Up @@ -910,10 +926,10 @@ def _mot_forward(
# Final norm with MoT routing
if residual is not None:
hidden_states = hidden_states + residual
if vae_mask is not None and vae_mask.any():
if vae_mask is not None and self._has_vae_tokens:
out = torch.empty_like(hidden_states)
non_vae = ~vae_mask
if non_vae.any():
if self._has_non_vae_tokens:
out[non_vae] = qwen2_model.norm(hidden_states[non_vae])
out[vae_mask] = qwen2_model.norm_moe_gen(hidden_states[vae_mask])
hidden_states = out
Expand All @@ -931,7 +947,7 @@ def _mot_layer_forward(
vae_mask: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Single decoder-layer forward with MoT routing."""
if vae_mask is None or not vae_mask.any():
if vae_mask is None or not self._has_vae_tokens:
return layer(positions, hidden_states, residual)

non_vae = ~vae_mask
Expand All @@ -941,7 +957,7 @@ def _mot_layer_forward(
hidden_states = hidden_states + residual
residual = hidden_states
normed = torch.empty_like(hidden_states)
if non_vae.any():
if self._has_non_vae_tokens:
normed[non_vae] = layer.input_layernorm(hidden_states[non_vae])
normed[vae_mask] = layer.input_layernorm_moe_gen(hidden_states[vae_mask])
hidden_states = normed
Expand All @@ -953,14 +969,14 @@ def _mot_layer_forward(
hidden_states = hidden_states + residual
residual = hidden_states
normed = torch.empty_like(hidden_states)
if non_vae.any():
if self._has_non_vae_tokens:
normed[non_vae] = layer.post_attention_layernorm(hidden_states[non_vae])
normed[vae_mask] = layer.post_attention_layernorm_moe_gen(hidden_states[vae_mask])
hidden_states = normed

# ---- MLP (split) ----
mlp_out = torch.empty_like(hidden_states)
if non_vae.any():
if self._has_non_vae_tokens:
mlp_out[non_vae] = layer.mlp(hidden_states[non_vae])
mlp_out[vae_mask] = layer.mlp_moe_gen(hidden_states[vae_mask])
hidden_states = mlp_out
Expand All @@ -985,7 +1001,7 @@ def _mot_attn_forward(
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if non_vae.any():
if self._has_non_vae_tokens:
qkv_und, _ = attn.qkv_proj(hidden_states[non_vae])
qkv[non_vae] = qkv_und
qkv_gen, _ = attn.qkv_proj_moe_gen(hidden_states[vae_mask])
Expand All @@ -1001,7 +1017,7 @@ def _mot_attn_forward(

q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
if non_vae.any():
if self._has_non_vae_tokens:
q_out[non_vae] = attn.q_norm(q[non_vae])
k_out[non_vae] = attn.k_norm(k[non_vae])
q_out[vae_mask] = attn.q_norm_moe_gen(q[vae_mask])
Expand All @@ -1021,7 +1037,7 @@ def _mot_attn_forward(
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if non_vae.any():
if self._has_non_vae_tokens:
o_und, _ = attn.o_proj(attn_output[non_vae])
output[non_vae] = o_und
o_gen, _ = attn.o_proj_moe_gen(attn_output[vae_mask])
Expand Down
Loading