-
Notifications
You must be signed in to change notification settings - Fork 36
[Bugfix] Enable step-wise execution #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3c5b5da
e6712fc
628d43a
eefca1b
101ec56
a8f4dd7
16f9152
56e7d6e
d9dbcd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider compatibility for other algorithms beside FlowGRPO. Currently the training on mixgrpo is not converging well (reward mean increase slower than previsou, it seems this PR diables sde window in mixgrpo |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||
|
|
||||||||||||||||||
| import copy | ||||||||||||||||||
| import os | ||||||||||||||||||
| from typing import Any, Literal | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -20,6 +21,7 @@ | |||||||||||||||||
| from vllm_omni.diffusion.distributed.utils import get_local_device | ||||||||||||||||||
| from vllm_omni.diffusion.models.qwen_image import QwenImagePipeline | ||||||||||||||||||
| from vllm_omni.diffusion.request import OmniDiffusionRequest | ||||||||||||||||||
| from vllm_omni.diffusion.worker.utils import DiffusionRequestState | ||||||||||||||||||
|
|
||||||||||||||||||
| from verl_omni.pipelines.model_base import VllmOmniPipelineBase | ||||||||||||||||||
| from verl_omni.pipelines.schedulers import FlowMatchSDEDiscreteScheduler | ||||||||||||||||||
|
|
@@ -146,6 +148,199 @@ def encode_prompt( | |||||||||||||||||
|
|
||||||||||||||||||
| return prompt_embeds, prompt_embeds_mask | ||||||||||||||||||
|
|
||||||||||||||||||
| def _extract_prompt_ids(self, prompts): | ||||||||||||||||||
| """Extract prompt_ids/mask and their negatives from the OmniCustomPrompt list. | ||||||||||||||||||
|
|
||||||||||||||||||
| Falls back to tokenizing ``"prompt"`` / ``"negative_prompt"`` text fields | ||||||||||||||||||
| when ``prompt_ids`` is not provided (e.g. during the engine's dummy | ||||||||||||||||||
| warm-up run, which always submits a text prompt). | ||||||||||||||||||
| """ | ||||||||||||||||||
| prompt_ids = None | ||||||||||||||||||
| prompt_mask = None | ||||||||||||||||||
| negative_prompt_ids = None | ||||||||||||||||||
| negative_prompt_mask = None | ||||||||||||||||||
| if prompts: | ||||||||||||||||||
| p0 = prompts[0] | ||||||||||||||||||
| if isinstance(p0, dict): | ||||||||||||||||||
| prompt_ids = p0.get("prompt_ids", None) | ||||||||||||||||||
| prompt_mask = p0.get("prompt_mask", None) | ||||||||||||||||||
| negative_prompt_ids = p0.get("negative_prompt_ids", None) | ||||||||||||||||||
| negative_prompt_mask = p0.get("negative_prompt_mask", None) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Fallback: tokenize raw text prompt (covers _dummy_run path). | ||||||||||||||||||
| if prompt_ids is None and p0.get("prompt"): | ||||||||||||||||||
| prompt_ids, prompt_mask = self._tokenize_text_prompt(p0["prompt"]) | ||||||||||||||||||
| if negative_prompt_ids is None and p0.get("negative_prompt"): | ||||||||||||||||||
| negative_prompt_ids, negative_prompt_mask = self._tokenize_text_prompt(p0["negative_prompt"]) | ||||||||||||||||||
| elif isinstance(p0, str): | ||||||||||||||||||
| prompt_ids, prompt_mask = self._tokenize_text_prompt(p0) | ||||||||||||||||||
| return prompt_ids, prompt_mask, negative_prompt_ids, negative_prompt_mask | ||||||||||||||||||
|
|
||||||||||||||||||
| def _tokenize_text_prompt(self, text: str | list[str]): | ||||||||||||||||||
| """Tokenize a text prompt using the Qwen chat template (parent behavior).""" | ||||||||||||||||||
| prompt = [text] if isinstance(text, str) else text | ||||||||||||||||||
| txt = [self.prompt_template_encode.format(e) for e in prompt] | ||||||||||||||||||
| tokens = self.tokenizer( | ||||||||||||||||||
| txt, | ||||||||||||||||||
| padding=True, | ||||||||||||||||||
| truncation=False, | ||||||||||||||||||
| return_tensors="pt", | ||||||||||||||||||
| ).to(self.device) | ||||||||||||||||||
| return tokens.input_ids, tokens.attention_mask | ||||||||||||||||||
|
|
||||||||||||||||||
| def prepare_encode( | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To fully support step-wise execution for RL rollouts, this class likely needs to override the per-step execution method (e.g., |
||||||||||||||||||
| self, | ||||||||||||||||||
| state: "DiffusionRequestState", | ||||||||||||||||||
| **kwargs: Any, | ||||||||||||||||||
| ) -> "DiffusionRequestState": | ||||||||||||||||||
| """Populate *state* with encoded prompts, latents, timesteps, and CFG config. | ||||||||||||||||||
|
|
||||||||||||||||||
| Override of ``QwenImagePipeline.prepare_encode`` that accepts pre-tokenized | ||||||||||||||||||
| ``prompt_ids`` (and optional ``prompt_mask``) instead of raw text prompts, | ||||||||||||||||||
| matching the input contract of ``QwenImagePipelineWithLogProbForTest``. | ||||||||||||||||||
| """ | ||||||||||||||||||
| sampling = state.sampling | ||||||||||||||||||
| prompt_ids, prompt_mask, negative_prompt_ids, negative_prompt_mask = self._extract_prompt_ids( | ||||||||||||||||||
| state.prompts or [] | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Normalize list inputs to tensors on device. | ||||||||||||||||||
| if isinstance(prompt_ids, list): | ||||||||||||||||||
| prompt_ids = torch.tensor(prompt_ids, device=self.device) | ||||||||||||||||||
| if isinstance(negative_prompt_ids, list): | ||||||||||||||||||
| negative_prompt_ids = torch.tensor(negative_prompt_ids, device=self.device) | ||||||||||||||||||
|
|
||||||||||||||||||
| if prompt_ids is None: | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
| "QwenImagePipelineWithLogProbForTest.prepare_encode requires either " | ||||||||||||||||||
| "'prompt_ids' or a text 'prompt' in state.prompts[0]." | ||||||||||||||||||
| ) | ||||||||||||||||||
|
Comment on lines
+214
to
+217
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message contains a typo in the class name, referring to
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| height = sampling.height or self.default_sample_size * self.vae_scale_factor | ||||||||||||||||||
| width = sampling.width or self.default_sample_size * self.vae_scale_factor | ||||||||||||||||||
| num_inference_steps = sampling.num_inference_steps or 50 | ||||||||||||||||||
| sigmas = sampling.sigmas | ||||||||||||||||||
| guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 1.0 | ||||||||||||||||||
| num_images_per_prompt = sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1 | ||||||||||||||||||
| true_cfg_scale = sampling.true_cfg_scale or 4.0 | ||||||||||||||||||
| max_sequence_length = sampling.max_sequence_length or self.tokenizer_max_length | ||||||||||||||||||
|
|
||||||||||||||||||
| generator = sampling.generator | ||||||||||||||||||
| if generator is None and sampling.seed is not None: | ||||||||||||||||||
| generator = torch.Generator(device=self.device).manual_seed(sampling.seed) | ||||||||||||||||||
|
|
||||||||||||||||||
| self._guidance_scale = guidance_scale | ||||||||||||||||||
| self._attention_kwargs = kwargs.get("attention_kwargs") or {} | ||||||||||||||||||
| self._current_timestep = None | ||||||||||||||||||
| self._interrupt = False | ||||||||||||||||||
|
|
||||||||||||||||||
| if prompt_ids is not None: | ||||||||||||||||||
| batch_size = prompt_ids.shape[0] if prompt_ids.ndim == 2 else 1 | ||||||||||||||||||
| else: | ||||||||||||||||||
| batch_size = 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| has_neg_prompt = negative_prompt_ids is not None | ||||||||||||||||||
| do_true_cfg = true_cfg_scale > 1 and has_neg_prompt | ||||||||||||||||||
| self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) | ||||||||||||||||||
|
|
||||||||||||||||||
| prompt_embeds, prompt_embeds_mask = self.encode_prompt( | ||||||||||||||||||
| prompt_ids=prompt_ids, | ||||||||||||||||||
| attention_mask=prompt_mask, | ||||||||||||||||||
| num_images_per_prompt=num_images_per_prompt, | ||||||||||||||||||
| max_sequence_length=max_sequence_length, | ||||||||||||||||||
| ) | ||||||||||||||||||
| if do_true_cfg: | ||||||||||||||||||
| negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( | ||||||||||||||||||
| prompt_ids=negative_prompt_ids, | ||||||||||||||||||
| attention_mask=negative_prompt_mask, | ||||||||||||||||||
| num_images_per_prompt=num_images_per_prompt, | ||||||||||||||||||
| max_sequence_length=max_sequence_length, | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| negative_prompt_embeds = None | ||||||||||||||||||
| negative_prompt_embeds_mask = None | ||||||||||||||||||
|
|
||||||||||||||||||
| num_channels_latents = self.transformer.in_channels // 4 | ||||||||||||||||||
| latents = self.prepare_latents( | ||||||||||||||||||
| batch_size * num_images_per_prompt, | ||||||||||||||||||
| num_channels_latents, | ||||||||||||||||||
| height, | ||||||||||||||||||
| width, | ||||||||||||||||||
| prompt_embeds.dtype, | ||||||||||||||||||
| self.device, | ||||||||||||||||||
| generator, | ||||||||||||||||||
| None, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for building
Suggested change
References
|
||||||||||||||||||
|
|
||||||||||||||||||
| timesteps, _ = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) | ||||||||||||||||||
| self._num_timesteps = len(timesteps) | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.transformer.guidance_embeds: | ||||||||||||||||||
| guidance = torch.full([1], guidance_scale, dtype=torch.float32) | ||||||||||||||||||
| guidance = guidance.expand(latents.shape[0]) | ||||||||||||||||||
| else: | ||||||||||||||||||
| guidance = None | ||||||||||||||||||
|
|
||||||||||||||||||
| txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None | ||||||||||||||||||
| negative_txt_seq_lens = ( | ||||||||||||||||||
| negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| req_scheduler = copy.deepcopy(self.scheduler) | ||||||||||||||||||
| req_scheduler.set_begin_index(0) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Resolve SDE / log-prob knobs from sampling extra_args so that the | ||||||||||||||||||
| # step-execution path mirrors ``forward()``'s rollout behaviour. | ||||||||||||||||||
| extra = sampling.extra_args or {} | ||||||||||||||||||
| noise_level = _coalesce_not_none(extra.get("noise_level", None), 0.7) | ||||||||||||||||||
| sde_window_size = _coalesce_not_none(extra.get("sde_window_size", None), None) | ||||||||||||||||||
| sde_window_range = _coalesce_not_none(extra.get("sde_window_range", None), (0, 5)) | ||||||||||||||||||
| sde_type = _coalesce_not_none(extra.get("sde_type", None), "sde") | ||||||||||||||||||
| logprobs = _coalesce_not_none(extra.get("logprobs", None), True) | ||||||||||||||||||
| if sde_window_size is not None: | ||||||||||||||||||
| start = torch.randint( | ||||||||||||||||||
| sde_window_range[0], | ||||||||||||||||||
| sde_window_range[1] - sde_window_size + 1, | ||||||||||||||||||
| (1,), | ||||||||||||||||||
| generator=generator, | ||||||||||||||||||
| device=self.device, | ||||||||||||||||||
| ).item() | ||||||||||||||||||
| sde_window = (start, start + sde_window_size) | ||||||||||||||||||
| else: | ||||||||||||||||||
| sde_window = (0, len(timesteps) - 1) | ||||||||||||||||||
|
|
||||||||||||||||||
| state.prompt_embeds = prompt_embeds | ||||||||||||||||||
| state.prompt_embeds_mask = prompt_embeds_mask | ||||||||||||||||||
| state.negative_prompt_embeds = negative_prompt_embeds | ||||||||||||||||||
| state.negative_prompt_embeds_mask = negative_prompt_embeds_mask | ||||||||||||||||||
| state.latents = latents | ||||||||||||||||||
| state.timesteps = timesteps | ||||||||||||||||||
| state.step_index = 0 | ||||||||||||||||||
| state.scheduler = req_scheduler | ||||||||||||||||||
| state.do_true_cfg = do_true_cfg | ||||||||||||||||||
| state.guidance = guidance | ||||||||||||||||||
| state.img_shapes = img_shapes | ||||||||||||||||||
| state.txt_seq_lens = txt_seq_lens | ||||||||||||||||||
| state.negative_txt_seq_lens = negative_txt_seq_lens | ||||||||||||||||||
| state.sampling.cfg_normalize = True | ||||||||||||||||||
| # Persist the resolved generator so ``step_scheduler`` (executed | ||||||||||||||||||
| # one step at a time by the step-execution engine) keeps drawing | ||||||||||||||||||
| # from the same RNG stream as ``forward()``. | ||||||||||||||||||
| state.sampling.generator = generator | ||||||||||||||||||
| # Rollout / SDE state consumed by ``step_scheduler`` and packaged | ||||||||||||||||||
| # into ``custom_output`` by ``post_decode``. | ||||||||||||||||||
| state.sde_window = sde_window | ||||||||||||||||||
| state.noise_level = noise_level | ||||||||||||||||||
| state.sde_type = sde_type | ||||||||||||||||||
| state.logprobs = logprobs | ||||||||||||||||||
| state.all_latents = [] | ||||||||||||||||||
| state.all_log_probs = [] | ||||||||||||||||||
| state.all_timesteps = [] | ||||||||||||||||||
|
|
||||||||||||||||||
| return state | ||||||||||||||||||
|
|
||||||||||||||||||
| def diffuse( | ||||||||||||||||||
| self, | ||||||||||||||||||
| prompt_embeds, | ||||||||||||||||||
|
|
@@ -276,6 +471,106 @@ def diffuse( | |||||||||||||||||
| all_timesteps = torch.stack(all_timesteps).unsqueeze(0).expand(latents.shape[0], -1) | ||||||||||||||||||
| return latents, all_latents, all_log_probs, all_timesteps | ||||||||||||||||||
|
|
||||||||||||||||||
| def step_scheduler( | ||||||||||||||||||
| self, | ||||||||||||||||||
| state: DiffusionRequestState, | ||||||||||||||||||
| noise_pred: torch.Tensor, | ||||||||||||||||||
| **kwargs: Any, | ||||||||||||||||||
| ) -> None: | ||||||||||||||||||
| """One scheduler step that mirrors the per-iter body of :meth:`diffuse`. | ||||||||||||||||||
|
|
||||||||||||||||||
| The default ``QwenImagePipeline.step_scheduler`` calls the standard | ||||||||||||||||||
| scheduler.step without SDE noise / log-prob bookkeeping, which means | ||||||||||||||||||
| ``step_execution=True`` would silently drop ``all_latents`` / | ||||||||||||||||||
| ``all_log_probs`` / ``all_timesteps`` (and the | ||||||||||||||||||
| ``prompt_embeds_mask`` consumer downstream would then receive a | ||||||||||||||||||
| ``None`` value that turns into a non-tensor ``LinkedList`` inside the | ||||||||||||||||||
| training ``TensorDict``). Override here to keep the step-mode and | ||||||||||||||||||
| request-mode trajectories equivalent. | ||||||||||||||||||
| """ | ||||||||||||||||||
| del kwargs | ||||||||||||||||||
| if self.interrupt: | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| i = state.step_index | ||||||||||||||||||
| timestep_value = state.timesteps[i] | ||||||||||||||||||
| sde_window = state.sde_window | ||||||||||||||||||
|
|
||||||||||||||||||
| if i < sde_window[0]: | ||||||||||||||||||
| cur_noise_level = 0.0 | ||||||||||||||||||
| elif i == sde_window[0]: | ||||||||||||||||||
| cur_noise_level = state.noise_level | ||||||||||||||||||
| state.all_latents.append(state.latents) | ||||||||||||||||||
| elif i > sde_window[0] and i < sde_window[1]: | ||||||||||||||||||
| cur_noise_level = state.noise_level | ||||||||||||||||||
| else: | ||||||||||||||||||
| cur_noise_level = 0.0 | ||||||||||||||||||
|
|
||||||||||||||||||
| new_latents, log_prob, _, _ = state.scheduler.step( | ||||||||||||||||||
| noise_pred, | ||||||||||||||||||
| timestep_value, | ||||||||||||||||||
| state.latents, | ||||||||||||||||||
| generator=state.sampling.generator, | ||||||||||||||||||
| noise_level=cur_noise_level, | ||||||||||||||||||
| sde_type=state.sde_type, | ||||||||||||||||||
| return_logprobs=state.logprobs, | ||||||||||||||||||
| return_dict=False, | ||||||||||||||||||
| ) | ||||||||||||||||||
| state.latents = new_latents | ||||||||||||||||||
|
|
||||||||||||||||||
| if i >= sde_window[0] and i < sde_window[1]: | ||||||||||||||||||
| state.all_latents.append(state.latents) | ||||||||||||||||||
| state.all_log_probs.append(log_prob) | ||||||||||||||||||
| state.all_timesteps.append(timestep_value) | ||||||||||||||||||
|
|
||||||||||||||||||
| state.step_index += 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| def post_decode( | ||||||||||||||||||
| self, | ||||||||||||||||||
| state: DiffusionRequestState, | ||||||||||||||||||
| **kwargs: Any, | ||||||||||||||||||
| ) -> DiffusionOutput: | ||||||||||||||||||
| """Decode final latents, package rollout trajectory, and move to CPU. | ||||||||||||||||||
|
|
||||||||||||||||||
| In ``step_execution`` mode the worker ships the returned | ||||||||||||||||||
| :class:`DiffusionOutput` across an inter-process MessageQueue to the | ||||||||||||||||||
| ``vLLMOmniHttpServer`` actor. We must (a) move tensors to CPU so the | ||||||||||||||||||
| receiving process does not initialise a stray CUDA context on GPU 0, | ||||||||||||||||||
| and (b) populate ``custom_output`` with the trajectory fields that | ||||||||||||||||||
| :meth:`forward` produces, so downstream consumers | ||||||||||||||||||
| (``vllm_omni_async_server.generate`` -> | ||||||||||||||||||
| ``embeds_padding_2_no_padding``) receive real tensors rather than | ||||||||||||||||||
| ``None`` (which becomes a non-tensor ``LinkedList`` in the | ||||||||||||||||||
| ``TensorDict`` and breaks ``mask.shape[0]``). | ||||||||||||||||||
| """ | ||||||||||||||||||
| output = super().post_decode(state, **kwargs) | ||||||||||||||||||
| if not isinstance(output, DiffusionOutput): | ||||||||||||||||||
| return output | ||||||||||||||||||
|
|
||||||||||||||||||
| all_latents = state.all_latents | ||||||||||||||||||
| all_log_probs = state.all_log_probs | ||||||||||||||||||
| all_timesteps = state.all_timesteps | ||||||||||||||||||
|
|
||||||||||||||||||
| stacked_latents = torch.stack(all_latents, dim=1) if all_latents else None | ||||||||||||||||||
| stacked_log_probs = ( | ||||||||||||||||||
| torch.stack(all_log_probs, dim=1) if all_log_probs and all_log_probs[0] is not None else None | ||||||||||||||||||
| ) | ||||||||||||||||||
| stacked_timesteps = ( | ||||||||||||||||||
| torch.stack(all_timesteps).unsqueeze(0).expand(state.latents.shape[0], -1) if all_timesteps else None | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| output.output = _maybe_to_cpu(output.output) | ||||||||||||||||||
| output.custom_output = { | ||||||||||||||||||
| "all_latents": _maybe_to_cpu(stacked_latents), | ||||||||||||||||||
| "all_log_probs": _maybe_to_cpu(stacked_log_probs), | ||||||||||||||||||
| "all_timesteps": _maybe_to_cpu(stacked_timesteps), | ||||||||||||||||||
| "prompt_embeds": _maybe_to_cpu(state.prompt_embeds), | ||||||||||||||||||
| "prompt_embeds_mask": _maybe_to_cpu(state.prompt_embeds_mask), | ||||||||||||||||||
| "negative_prompt_embeds": _maybe_to_cpu(state.negative_prompt_embeds), | ||||||||||||||||||
| "negative_prompt_embeds_mask": _maybe_to_cpu(state.negative_prompt_embeds_mask), | ||||||||||||||||||
| } | ||||||||||||||||||
| return output | ||||||||||||||||||
|
|
||||||||||||||||||
| def forward( | ||||||||||||||||||
| self, | ||||||||||||||||||
| req: OmniDiffusionRequest, | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -133,6 +133,9 @@ async def run_server(self, args: argparse.Namespace): | |
| engine_args["enable_dummy_pipeline"] = True | ||
| engine_args["custom_pipeline_args"] = {"pipeline_class": pipeline_path} | ||
|
|
||
| engine_args["max_num_seqs"] = 256 | ||
| engine_args["step_execution"] = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why hard-coded step_execution=True?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should make step-wise continuous batching a configurable option, and default is False (since only qwen-image supports for now)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it is a patch PR for quick validation. there will be a formal PR later. @knlnguyen1802 |
||
|
|
||
| diffusion_master_port, diffusion_master_sock = get_free_port("127.0.0.1", with_alive_sock=True) | ||
| diffusion_master_sock.close() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really understand why we need to add so many things to make it work..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main problem is vllm-omni does not official support passing prompt_token_ids as input.
The main logic of pipeline still try to tokenize the prompt.
That why the custom pipeline need to align with new function prepare_encode added by step-wise function.
To make this stable compatible, support from vllm-omni side is better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. So vllm-omni does not support truly
--skip-tokenizer-initin https://docs.vllm.ai/en/stable/configuration/engine_args/#modelconfig to acceptprompt_token_ids.Can we make a feature request for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I will work on it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI @SamitHuang
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does it conflict with the change with #66? which rename prompt_ids → prompt_token_ids for vllm-omni 0.20+