Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions verl_omni/pipelines/qwen_image_flow_grpo/vllm_omni_rollout_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,26 @@ def prepare_encode(
req_scheduler = copy.deepcopy(self.scheduler)
req_scheduler.set_begin_index(0)

# Resolve SDE / log-prob knobs from sampling extra_args so that the
# step-execution path mirrors ``forward()``'s rollout behaviour.
extra = sampling.extra_args or {}
noise_level = _coalesce_not_none(extra.get("noise_level", None), 0.7)
sde_window_size = _coalesce_not_none(extra.get("sde_window_size", None), None)
sde_window_range = _coalesce_not_none(extra.get("sde_window_range", None), (0, 5))
sde_type = _coalesce_not_none(extra.get("sde_type", None), "sde")
logprobs = _coalesce_not_none(extra.get("logprobs", None), True)
if sde_window_size is not None:
start = torch.randint(
sde_window_range[0],
sde_window_range[1] - sde_window_size + 1,
(1,),
generator=generator,
device=self.device,
).item()
sde_window = (start, start + sde_window_size)
else:
sde_window = (0, len(timesteps) - 1)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The default `sde_window` end index is set to `len(timesteps) - 1`, which excludes the final denoising step from the collected trajectory. Both the `diffuse` loop and the new `step_scheduler` use the condition `i < sde_window[1]` to decide whether to collect latents and log-probabilities. Because the window ends at `len(timesteps) - 1`, the last latent (the final output) and the last step's log-probability are never appended to `all_latents` and `all_log_probs`, which leaves the trajectory used for RL training incomplete. The end index should be changed to `len(timesteps)` so that the window covers the full range of steps. Note that the same logic exists in the `forward` method (line 772) and should ideally be updated there as well for consistency, although that line is outside the current diff hunks.

Suggested change
sde_window = (0, len(timesteps) - 1)
sde_window = (0, len(timesteps))


state.prompt_embeds = prompt_embeds
state.prompt_embeds_mask = prompt_embeds_mask
state.negative_prompt_embeds = negative_prompt_embeds
Expand All @@ -305,6 +325,19 @@ def prepare_encode(
state.txt_seq_lens = txt_seq_lens
state.negative_txt_seq_lens = negative_txt_seq_lens
state.sampling.cfg_normalize = True
# Persist the resolved generator so ``step_scheduler`` (executed
# one step at a time by the step-execution engine) keeps drawing
# from the same RNG stream as ``forward()``.
state.sampling.generator = generator
# Rollout / SDE state consumed by ``step_scheduler`` and packaged
# into ``custom_output`` by ``post_decode``.
state.sde_window = sde_window
state.noise_level = noise_level
state.sde_type = sde_type
state.logprobs = logprobs
state.all_latents = []
state.all_log_probs = []
state.all_timesteps = []

return state

Expand Down Expand Up @@ -435,6 +468,106 @@ def diffuse(
all_timesteps = torch.stack(all_timesteps).unsqueeze(0).expand(latents.shape[0], -1)
return latents, all_latents, all_log_probs, all_timesteps

def step_scheduler(
    self,
    state: DiffusionRequestState,
    noise_pred: torch.Tensor,
    **kwargs: Any,
) -> None:
    """Run a single scheduler step and record the SDE rollout trajectory.

    This is the step-execution counterpart of one iteration of
    :meth:`diffuse`. According to the original note on this override, the
    default ``QwenImagePipeline.step_scheduler`` performs a plain
    ``scheduler.step`` with no SDE noise or log-prob bookkeeping, so running
    with ``step_execution=True`` would otherwise lose ``all_latents`` /
    ``all_log_probs`` / ``all_timesteps``. Overriding it keeps the step-mode
    and request-mode trajectories equivalent.

    Args:
        state: Per-request state. Reads ``step_index``, ``timesteps``,
            ``sde_window``, ``noise_level``, ``sde_type``, ``logprobs``,
            ``scheduler`` and ``sampling.generator``; updates ``latents``,
            ``step_index`` and the three ``all_*`` trajectory lists in place.
        noise_pred: Model prediction forwarded to ``state.scheduler.step``.
        **kwargs: Accepted for interface compatibility and discarded.

    Returns:
        None. Nothing is modified when ``self.interrupt`` is truthy.
    """
    del kwargs
    if self.interrupt:
        return

    step = state.step_index
    current_timestep = state.timesteps[step]
    window = state.sde_window
    window_start = window[0]
    window_end = window[1]

    at_window_start = step == window_start
    inside_window = window_start <= step < window_end

    # The latents entering the window are recorded once, before any SDE
    # noise is applied, so the trajectory starts from the pre-step sample.
    if at_window_start:
        state.all_latents.append(state.latents)

    # Noise is injected on the window's first step and on every later step
    # that is still inside the window; all other steps are noise-free.
    cur_noise_level = state.noise_level if (at_window_start or inside_window) else 0.0

    next_latents, step_log_prob, _, _ = state.scheduler.step(
        noise_pred,
        current_timestep,
        state.latents,
        generator=state.sampling.generator,
        noise_level=cur_noise_level,
        sde_type=state.sde_type,
        return_logprobs=state.logprobs,
        return_dict=False,
    )
    state.latents = next_latents

    # Only steps inside the half-open window contribute to the trajectory.
    if inside_window:
        state.all_latents.append(next_latents)
        state.all_log_probs.append(step_log_prob)
        state.all_timesteps.append(current_timestep)

    state.step_index = state.step_index + 1

def post_decode(
    self,
    state: DiffusionRequestState,
    **kwargs: Any,
) -> DiffusionOutput:
    """Decode via the parent class, attach the rollout trajectory, move to CPU.

    The parent ``post_decode`` produces the output object. If it is a
    :class:`DiffusionOutput`, this override:

    * passes ``output.output`` through ``_maybe_to_cpu``; and
    * fills ``output.custom_output`` with the stacked trajectory
      (``all_latents``, ``all_log_probs``, ``all_timesteps``) and the prompt
      embeddings / masks stored on ``state``, each passed through
      ``_maybe_to_cpu``.

    Rationale carried over from the original note: in ``step_execution``
    mode the worker ships this output across an inter-process MessageQueue
    to the ``vLLMOmniHttpServer`` actor, so tensors are moved to CPU to keep
    the receiving process from initialising a stray CUDA context on GPU 0.
    The trajectory fields mirror what :meth:`forward` produces so that
    downstream consumers (``vllm_omni_async_server.generate`` ->
    ``embeds_padding_2_no_padding``) receive real tensors rather than
    ``None``.

    Args:
        state: Per-request state holding ``all_latents``, ``all_log_probs``,
            ``all_timesteps``, ``latents`` and the prompt embedding fields.
        **kwargs: Forwarded unchanged to the parent ``post_decode``.

    Returns:
        The parent's output; augmented as described above when it is a
        ``DiffusionOutput``, otherwise returned untouched.
    """
    output = super().post_decode(state, **kwargs)
    if not isinstance(output, DiffusionOutput):
        return output

    latent_history = state.all_latents
    log_prob_history = state.all_log_probs
    timestep_history = state.all_timesteps

    # Each trajectory field stays None when nothing was recorded for it.
    stacked_latents = None
    if latent_history:
        stacked_latents = torch.stack(latent_history, dim=1)

    # Log-probs are stacked only when the first recorded entry is not None.
    stacked_log_probs = None
    if log_prob_history and log_prob_history[0] is not None:
        stacked_log_probs = torch.stack(log_prob_history, dim=1)

    stacked_timesteps = None
    if timestep_history:
        # Expand the recorded timesteps along a new leading dimension sized
        # by state.latents.shape[0].
        timestep_row = torch.stack(timestep_history).unsqueeze(0)
        stacked_timesteps = timestep_row.expand(state.latents.shape[0], -1)

    payload = {
        "all_latents": stacked_latents,
        "all_log_probs": stacked_log_probs,
        "all_timesteps": stacked_timesteps,
        "prompt_embeds": state.prompt_embeds,
        "prompt_embeds_mask": state.prompt_embeds_mask,
        "negative_prompt_embeds": state.negative_prompt_embeds,
        "negative_prompt_embeds_mask": state.negative_prompt_embeds_mask,
    }

    output.output = _maybe_to_cpu(output.output)
    output.custom_output = {key: _maybe_to_cpu(value) for key, value in payload.items()}
    return output

def forward(
self,
req: OmniDiffusionRequest,
Expand Down