Collaborator

I don't really understand why we need to add so many things to make this work.

Collaborator Author

The main problem is that vllm-omni does not officially support passing prompt_token_ids as input.
The main pipeline logic still tries to tokenize the prompt.
That is why the custom pipeline needs to align with the new prepare_encode function added for step-wise execution.
For this to be stably compatible, support on the vllm-omni side would be better.
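
A minimal sketch of the situation described above, using toy stand-ins rather than the real vllm-omni classes: the base pipeline always tokenizes text, so a subclass has to override `prepare_encode` to accept pre-tokenized ids. `BasePipeline`, `PreTokenizedPipeline`, and `toy_tokenize` are hypothetical names used only for illustration.

```python
from typing import Any


def toy_tokenize(text: str) -> list[int]:
    """Stand-in tokenizer (assumption): one id per whitespace-separated word."""
    return [len(word) for word in text.split()]


class BasePipeline:
    """Hypothetical base pipeline that always tokenizes the text prompt."""

    def prepare_encode(self, prompt: dict[str, Any]) -> list[int]:
        return toy_tokenize(prompt["prompt"])


class PreTokenizedPipeline(BasePipeline):
    """Hypothetical override: prefer caller-supplied ids, else fall back to text."""

    def prepare_encode(self, prompt: dict[str, Any]) -> list[int]:
        ids = prompt.get("prompt_ids")
        if ids is not None:
            return list(ids)
        if prompt.get("prompt"):
            # Warm-up style requests only carry text, so reuse the parent path.
            return super().prepare_encode(prompt)
        raise ValueError("need either 'prompt_ids' or a text 'prompt'")
```

With upstream support for pre-tokenized input, an override like this would presumably no longer be needed.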

Collaborator

I see. So vllm-omni does not truly support --skip-tokenizer-init (https://docs.vllm.ai/en/stable/configuration/engine_args/#modelconfig), which would let it accept prompt_token_ids.

Can we make a feature request for this?

Collaborator Author

Yes, I will work on it

Collaborator

Does this conflict with the change in #66, which renames prompt_ids → prompt_token_ids for vllm-omni 0.20+?
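
If both names need to coexist during the transition, one option is a small accessor that accepts either key. This is only a sketch: `get_prompt_token_ids` is a hypothetical helper, and the key names are taken from the comment above rather than checked against either codebase.

```python
from typing import Any


def get_prompt_token_ids(prompt: dict[str, Any]) -> Any:
    """Return token ids from either key, preferring the newer name.

    Assumption: newer vllm-omni uses ``prompt_token_ids`` where older
    code used ``prompt_ids``, as stated in the review comment.
    """
    for key in ("prompt_token_ids", "prompt_ids"):
        value = prompt.get(key)
        if value is not None:
            return value
    return None
```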

Collaborator

@SamitHuang SamitHuang May 21, 2026

Consider compatibility with algorithms besides FlowGRPO. Currently, training with mixgrpo is not converging well (the reward mean increases more slowly than before); it seems this PR disables the SDE window in mixgrpo.
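
One way to check this concern is to log which denoising steps actually receive SDE noise. The sketch below restates, in plain Python, the gating rule used by `step_scheduler` in this PR (noise is nonzero only for `start <= i < end`); the function names are hypothetical.

```python
def noise_levels(num_steps: int, sde_window: tuple[int, int], noise_level: float) -> list[float]:
    """Noise level applied at each step: nonzero only inside [start, end)."""
    start, end = sde_window
    return [noise_level if start <= i < end else 0.0 for i in range(num_steps)]


def count_sde_steps(num_steps: int, sde_window: tuple[int, int], noise_level: float = 0.7) -> int:
    """Number of steps that are sampled stochastically."""
    return sum(1 for level in noise_levels(num_steps, sde_window, noise_level) if level > 0.0)
```

In this PR the window falls back to `(0, len(timesteps) - 1)` when `sde_window_size` is not set, so a count close to the total number of steps would suggest that the mixgrpo window settings are not reaching the pipeline.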

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import os
 from typing import Any, Literal

@@ -20,6 +21,7 @@
 from vllm_omni.diffusion.distributed.utils import get_local_device
 from vllm_omni.diffusion.models.qwen_image import QwenImagePipeline
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.worker.utils import DiffusionRequestState

 from verl_omni.pipelines.model_base import VllmOmniPipelineBase
 from verl_omni.pipelines.schedulers import FlowMatchSDEDiscreteScheduler
@@ -146,6 +148,199 @@ def encode_prompt(

         return prompt_embeds, prompt_embeds_mask

+    def _extract_prompt_ids(self, prompts):
+        """Extract prompt_ids/mask and their negatives from the OmniCustomPrompt list.
+
+        Falls back to tokenizing ``"prompt"`` / ``"negative_prompt"`` text fields
+        when ``prompt_ids`` is not provided (e.g. during the engine's dummy
+        warm-up run, which always submits a text prompt).
+        """
+        prompt_ids = None
+        prompt_mask = None
+        negative_prompt_ids = None
+        negative_prompt_mask = None
+        if prompts:
+            p0 = prompts[0]
+            if isinstance(p0, dict):
+                prompt_ids = p0.get("prompt_ids", None)
+                prompt_mask = p0.get("prompt_mask", None)
+                negative_prompt_ids = p0.get("negative_prompt_ids", None)
+                negative_prompt_mask = p0.get("negative_prompt_mask", None)
+
+                # Fallback: tokenize raw text prompt (covers _dummy_run path).
+                if prompt_ids is None and p0.get("prompt"):
+                    prompt_ids, prompt_mask = self._tokenize_text_prompt(p0["prompt"])
+                if negative_prompt_ids is None and p0.get("negative_prompt"):
+                    negative_prompt_ids, negative_prompt_mask = self._tokenize_text_prompt(p0["negative_prompt"])
+            elif isinstance(p0, str):
+                prompt_ids, prompt_mask = self._tokenize_text_prompt(p0)
+        return prompt_ids, prompt_mask, negative_prompt_ids, negative_prompt_mask
+
+    def _tokenize_text_prompt(self, text: str | list[str]):
+        """Tokenize a text prompt using the Qwen chat template (parent behavior)."""
+        prompt = [text] if isinstance(text, str) else text
+        txt = [self.prompt_template_encode.format(e) for e in prompt]
+        tokens = self.tokenizer(
+            txt,
+            padding=True,
+            truncation=False,
+            return_tensors="pt",
+        ).to(self.device)
+        return tokens.input_ids, tokens.attention_mask
+
+    def prepare_encode(
Contributor

high

To fully support step-wise execution for RL rollouts, this class likely needs to override the per-step execution method (e.g., execute_step or step). The current implementation overrides diffuse to collect all_log_probs and all_latents during the denoising loop. If step_execution is enabled in the engine, the engine will bypass diffuse and call the per-step method instead. Without an override that performs similar data collection and state updates (like incrementing state.step_index), these RL-specific fields will be missing from the final output.
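
To illustrate the point, here is a toy sketch (not the vllm-omni API) of a step-driven engine: because the engine calls the per-step hook directly, any trajectory bookkeeping has to live in that hook. `ToyState`, `ToyPipeline`, and `run_stepwise` are hypothetical names, and the latent update is a placeholder rather than a real scheduler step.

```python
from dataclasses import dataclass, field


@dataclass
class ToyState:
    """Hypothetical per-request state; field names follow the PR's usage."""

    latents: float = 0.0
    step_index: int = 0
    all_latents: list[float] = field(default_factory=list)


class ToyPipeline:
    def step_scheduler(self, state: ToyState, noise_pred: float) -> None:
        # Record the trajectory and advance the index on every call, since a
        # step-driven engine never enters the full denoising loop.
        state.latents = state.latents - noise_pred  # placeholder update rule
        state.all_latents.append(state.latents)
        state.step_index += 1


def run_stepwise(pipeline: ToyPipeline, state: ToyState, noise_preds: list[float]) -> ToyState:
    """Toy engine loop: one hook call per denoising step."""
    for noise_pred in noise_preds:
        pipeline.step_scheduler(state, noise_pred)
    return state
```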

+        self,
+        state: "DiffusionRequestState",
+        **kwargs: Any,
+    ) -> "DiffusionRequestState":
+        """Populate *state* with encoded prompts, latents, timesteps, and CFG config.
+
+        Override of ``QwenImagePipeline.prepare_encode`` that accepts pre-tokenized
+        ``prompt_ids`` (and optional ``prompt_mask``) instead of raw text prompts,
+        matching the input contract of ``QwenImagePipelineWithLogProbForTest``.
+        """
+        sampling = state.sampling
+        prompt_ids, prompt_mask, negative_prompt_ids, negative_prompt_mask = self._extract_prompt_ids(
+            state.prompts or []
+        )
+
+        # Normalize list inputs to tensors on device.
+        if isinstance(prompt_ids, list):
+            prompt_ids = torch.tensor(prompt_ids, device=self.device)
+        if isinstance(negative_prompt_ids, list):
+            negative_prompt_ids = torch.tensor(negative_prompt_ids, device=self.device)
+
+        if prompt_ids is None:
+            raise ValueError(
+                "QwenImagePipelineWithLogProbForTest.prepare_encode requires either "
+                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
+            )
Comment on lines +214 to +217
Contributor

medium

The error message contains a typo in the class name, referring to QwenImagePipelineWithLogProbForTest instead of QwenImagePipelineWithLogProb.

Suggested change
-            raise ValueError(
-                "QwenImagePipelineWithLogProbForTest.prepare_encode requires either "
-                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
-            )
+            raise ValueError(
+                "QwenImagePipelineWithLogProb.prepare_encode requires either "
+                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
+            )
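
As an alternative to hard-coding the class name, the message could derive it at runtime so it stays correct in subclasses and after renames. A minimal sketch with toy classes (`ToyPipelineBase`, `ToyPipelineWithLogProb`, and `require_prompt_ids` are hypothetical, not part of this PR):

```python
class ToyPipelineBase:
    def require_prompt_ids(self, prompt_ids):
        """Raise with the concrete class name if no prompt ids were resolved."""
        if prompt_ids is None:
            raise ValueError(
                f"{type(self).__name__}.prepare_encode requires either "
                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
            )
        return prompt_ids


class ToyPipelineWithLogProb(ToyPipelineBase):
    """Subclass inherits the check; the message names this class automatically."""
```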


+        height = sampling.height or self.default_sample_size * self.vae_scale_factor
+        width = sampling.width or self.default_sample_size * self.vae_scale_factor
+        num_inference_steps = sampling.num_inference_steps or 50
+        sigmas = sampling.sigmas
+        guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 1.0
+        num_images_per_prompt = sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1
+        true_cfg_scale = sampling.true_cfg_scale or 4.0
+        max_sequence_length = sampling.max_sequence_length or self.tokenizer_max_length
+
+        generator = sampling.generator
+        if generator is None and sampling.seed is not None:
+            generator = torch.Generator(device=self.device).manual_seed(sampling.seed)
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = kwargs.get("attention_kwargs") or {}
+        self._current_timestep = None
+        self._interrupt = False
+
+        if prompt_ids is not None:
+            batch_size = prompt_ids.shape[0] if prompt_ids.ndim == 2 else 1
+        else:
+            batch_size = 1
+
+        has_neg_prompt = negative_prompt_ids is not None
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+        self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt)
+
+        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+            prompt_ids=prompt_ids,
+            attention_mask=prompt_mask,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+        )
+        if do_true_cfg:
+            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+                prompt_ids=negative_prompt_ids,
+                attention_mask=negative_prompt_mask,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+        else:
+            negative_prompt_embeds = None
+            negative_prompt_embeds_mask = None
+
+        num_channels_latents = self.transformer.in_channels // 4
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            self.device,
+            generator,
+            None,
+        )
+
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
Contributor

medium

This logic for building img_shapes is already implemented in the build_img_shapes utility function in common.py. It is better to use the utility to avoid code duplication and ensure consistency.

Suggested change
-        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+        img_shapes = build_img_shapes(height, width, batch_size, self.vae_scale_factor)
References
  1. Avoid code duplication by reusing existing helper functions for common logic, such as constructing image shapes.
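
For readers without `common.py` at hand, here is a sketch of what such a helper might look like. The argument order is taken from the suggested change above; the body is an assumption reconstructed from the inline expression it replaces, not the actual implementation.

```python
def build_img_shapes(
    height: int, width: int, batch_size: int, vae_scale_factor: int
) -> list[list[tuple[int, int, int]]]:
    """Assumed behaviour: one [(1, h, w)] entry per batch item, in latent-patch units."""
    h = height // vae_scale_factor // 2
    w = width // vae_scale_factor // 2
    # A comprehension gives each batch item its own inner list, whereas
    # ``[[...]] * batch_size`` repeats references to a single shared list.
    return [[(1, h, w)] for _ in range(batch_size)]
```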


+        timesteps, _ = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1])
+        self._num_timesteps = len(timesteps)
+
+        if self.transformer.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
+        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        negative_txt_seq_lens = (
+            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+        )
+
+        req_scheduler = copy.deepcopy(self.scheduler)
+        req_scheduler.set_begin_index(0)
+
+        # Resolve SDE / log-prob knobs from sampling extra_args so that the
+        # step-execution path mirrors ``forward()``'s rollout behaviour.
+        extra = sampling.extra_args or {}
+        noise_level = _coalesce_not_none(extra.get("noise_level", None), 0.7)
+        sde_window_size = _coalesce_not_none(extra.get("sde_window_size", None), None)
+        sde_window_range = _coalesce_not_none(extra.get("sde_window_range", None), (0, 5))
+        sde_type = _coalesce_not_none(extra.get("sde_type", None), "sde")
+        logprobs = _coalesce_not_none(extra.get("logprobs", None), True)
+        if sde_window_size is not None:
+            start = torch.randint(
+                sde_window_range[0],
+                sde_window_range[1] - sde_window_size + 1,
+                (1,),
+                generator=generator,
+                device=self.device,
+            ).item()
+            sde_window = (start, start + sde_window_size)
+        else:
+            sde_window = (0, len(timesteps) - 1)
+
+        state.prompt_embeds = prompt_embeds
+        state.prompt_embeds_mask = prompt_embeds_mask
+        state.negative_prompt_embeds = negative_prompt_embeds
+        state.negative_prompt_embeds_mask = negative_prompt_embeds_mask
+        state.latents = latents
+        state.timesteps = timesteps
+        state.step_index = 0
+        state.scheduler = req_scheduler
+        state.do_true_cfg = do_true_cfg
+        state.guidance = guidance
+        state.img_shapes = img_shapes
+        state.txt_seq_lens = txt_seq_lens
+        state.negative_txt_seq_lens = negative_txt_seq_lens
+        state.sampling.cfg_normalize = True
+        # Persist the resolved generator so ``step_scheduler`` (executed
+        # one step at a time by the step-execution engine) keeps drawing
+        # from the same RNG stream as ``forward()``.
+        state.sampling.generator = generator
+        # Rollout / SDE state consumed by ``step_scheduler`` and packaged
+        # into ``custom_output`` by ``post_decode``.
+        state.sde_window = sde_window
+        state.noise_level = noise_level
+        state.sde_type = sde_type
+        state.logprobs = logprobs
+        state.all_latents = []
+        state.all_log_probs = []
+        state.all_timesteps = []
+
+        return state
+
     def diffuse(
         self,
         prompt_embeds,
@@ -276,6 +471,106 @@ def diffuse(
         all_timesteps = torch.stack(all_timesteps).unsqueeze(0).expand(latents.shape[0], -1)
         return latents, all_latents, all_log_probs, all_timesteps

+    def step_scheduler(
+        self,
+        state: DiffusionRequestState,
+        noise_pred: torch.Tensor,
+        **kwargs: Any,
+    ) -> None:
+        """One scheduler step that mirrors the per-iter body of :meth:`diffuse`.
+
+        The default ``QwenImagePipeline.step_scheduler`` calls the standard
+        scheduler.step without SDE noise / log-prob bookkeeping, which means
+        ``step_execution=True`` would silently drop ``all_latents`` /
+        ``all_log_probs`` / ``all_timesteps`` (and the
+        ``prompt_embeds_mask`` consumer downstream would then receive a
+        ``None`` value that turns into a non-tensor ``LinkedList`` inside the
+        training ``TensorDict``). Override here to keep the step-mode and
+        request-mode trajectories equivalent.
+        """
+        del kwargs
+        if self.interrupt:
+            return
+
+        i = state.step_index
+        timestep_value = state.timesteps[i]
+        sde_window = state.sde_window
+
+        if i < sde_window[0]:
+            cur_noise_level = 0.0
+        elif i == sde_window[0]:
+            cur_noise_level = state.noise_level
+            state.all_latents.append(state.latents)
+        elif i > sde_window[0] and i < sde_window[1]:
+            cur_noise_level = state.noise_level
+        else:
+            cur_noise_level = 0.0
+
+        new_latents, log_prob, _, _ = state.scheduler.step(
+            noise_pred,
+            timestep_value,
+            state.latents,
+            generator=state.sampling.generator,
+            noise_level=cur_noise_level,
+            sde_type=state.sde_type,
+            return_logprobs=state.logprobs,
+            return_dict=False,
+        )
+        state.latents = new_latents
+
+        if i >= sde_window[0] and i < sde_window[1]:
+            state.all_latents.append(state.latents)
+            state.all_log_probs.append(log_prob)
+            state.all_timesteps.append(timestep_value)
+
+        state.step_index += 1

+    def post_decode(
+        self,
+        state: DiffusionRequestState,
+        **kwargs: Any,
+    ) -> DiffusionOutput:
+        """Decode final latents, package rollout trajectory, and move to CPU.
+
+        In ``step_execution`` mode the worker ships the returned
+        :class:`DiffusionOutput` across an inter-process MessageQueue to the
+        ``vLLMOmniHttpServer`` actor. We must (a) move tensors to CPU so the
+        receiving process does not initialise a stray CUDA context on GPU 0,
+        and (b) populate ``custom_output`` with the trajectory fields that
+        :meth:`forward` produces, so downstream consumers
+        (``vllm_omni_async_server.generate`` ->
+        ``embeds_padding_2_no_padding``) receive real tensors rather than
+        ``None`` (which becomes a non-tensor ``LinkedList`` in the
+        ``TensorDict`` and breaks ``mask.shape[0]``).
+        """
+        output = super().post_decode(state, **kwargs)
+        if not isinstance(output, DiffusionOutput):
+            return output
+
+        all_latents = state.all_latents
+        all_log_probs = state.all_log_probs
+        all_timesteps = state.all_timesteps
+
+        stacked_latents = torch.stack(all_latents, dim=1) if all_latents else None
+        stacked_log_probs = (
+            torch.stack(all_log_probs, dim=1) if all_log_probs and all_log_probs[0] is not None else None
+        )
+        stacked_timesteps = (
+            torch.stack(all_timesteps).unsqueeze(0).expand(state.latents.shape[0], -1) if all_timesteps else None
+        )
+
+        output.output = _maybe_to_cpu(output.output)
+        output.custom_output = {
+            "all_latents": _maybe_to_cpu(stacked_latents),
+            "all_log_probs": _maybe_to_cpu(stacked_log_probs),
+            "all_timesteps": _maybe_to_cpu(stacked_timesteps),
+            "prompt_embeds": _maybe_to_cpu(state.prompt_embeds),
+            "prompt_embeds_mask": _maybe_to_cpu(state.prompt_embeds_mask),
+            "negative_prompt_embeds": _maybe_to_cpu(state.negative_prompt_embeds),
+            "negative_prompt_embeds_mask": _maybe_to_cpu(state.negative_prompt_embeds_mask),
+        }
+        return output
+
     def forward(
         self,
         req: OmniDiffusionRequest,
@@ -39,6 +39,7 @@
 from typing import Any

 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.worker.utils import DiffusionRequestState

 from verl_omni.pipelines.model_base import VllmOmniPipelineBase
 from verl_omni.pipelines.qwen_image_flow_grpo.vllm_omni_rollout_adapter import QwenImagePipelineWithLogProb
@@ -51,20 +52,38 @@ class QwenImageMixGRPOPipelineWithLogProb(QwenImagePipelineWithLogProb):
     """Rollout pipeline for Qwen-Image with the MixGRPO algorithm."""

     def forward(self, req: OmniDiffusionRequest, **kwargs: Any):
-        self._maybe_make_progressive_window(req, kwargs)
+        self._maybe_make_progressive_window(req.sampling_params.extra_args, kwargs)
         return super().forward(req, **kwargs)

+    def prepare_encode(
+        self,
+        state: DiffusionRequestState,
+        **kwargs: Any,
+    ) -> DiffusionRequestState:
+        """Override to fix the SDE window before the base prepare_encode draws it.
+
+        In step-execution mode ``forward()`` is never called, so
+        ``_maybe_make_progressive_window`` would never run. Calling it here,
+        against ``state.sampling.extra_args``, ensures that all rollouts in a
+        batch receive the same deterministic / seeded window regardless of
+        whether the pipeline runs in full-forward or step-execution mode.
+        """
+        if state.sampling is not None:
+            if state.sampling.extra_args is None:
+                state.sampling.extra_args = {}
+            self._maybe_make_progressive_window(state.sampling.extra_args, kwargs)
+        return super().prepare_encode(state, **kwargs)
+
     @staticmethod
-    def _maybe_make_progressive_window(req: OmniDiffusionRequest, kwargs: dict[str, Any]) -> None:
-        """Mutate ``req.sampling_params.extra_args["sde_window_range"]`` in place
-        to fix the window start position.
+    def _maybe_make_progressive_window(extra: dict[str, Any], kwargs: dict[str, Any]) -> None:
+        """Mutate *extra* (``sampling_params.extra_args``) in place to fix the
+        SDE window start position.

         * ``progressive``: deterministic from ``global_steps``.
         * ``random`` with ``sde_window_seed`` present: seeded per-step draw so
           all ranks agree on the same window position for each training step.
         * Otherwise: no-op -- the base pipeline's per-call random draw is used.
         """
-        extra = req.sampling_params.extra_args
         strategy = extra.get("sample_strategy", "random")
         size = extra.get("sde_window_size") or kwargs.get("sde_window_size")

@@ -133,6 +133,9 @@ async def run_server(self, args: argparse.Namespace):
engine_args["enable_dummy_pipeline"] = True
engine_args["custom_pipeline_args"] = {"pipeline_class": pipeline_path}

engine_args["max_num_seqs"] = 256
engine_args["step_execution"] = True
Collaborator

Why is step_execution=True hard-coded?
Same question for max_num_seqs.

Collaborator

We should make step-wise continuous batching a configurable option that defaults to False (since only qwen-image supports it for now).

Collaborator

Yes, this is a patch PR for quick validation. There will be a formal PR later. @knlnguyen1802
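
For the follow-up PR, the configurable version discussed in this thread could look roughly like the sketch below. The flag names `--step-execution` and `--max-num-seqs` and the helper `build_engine_args` are hypothetical; only the `engine_args` keys come from the diff above.

```python
import argparse
from typing import Any, Optional


def build_engine_args(argv: Optional[list[str]] = None) -> dict[str, Any]:
    """Build engine args with step-wise execution off unless requested."""
    parser = argparse.ArgumentParser()
    # Hypothetical flags: opt-in step execution, optional batch-size cap.
    parser.add_argument("--step-execution", action="store_true", default=False)
    parser.add_argument("--max-num-seqs", type=int, default=None)
    args = parser.parse_args(argv)

    engine_args: dict[str, Any] = {}
    if args.step_execution:
        engine_args["step_execution"] = True
    if args.max_num_seqs is not None:
        engine_args["max_num_seqs"] = args.max_num_seqs
    return engine_args
```

Leaving both keys out by default keeps the engine's own defaults for pipelines that do not support step-wise execution yet.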


diffusion_master_port, diffusion_master_sock = get_free_port("127.0.0.1", with_alive_sock=True)
diffusion_master_sock.close()
