Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions examples/dpo_trainer/run_sd3_dpo_unified_reward.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SD3 DPO training on PickScore prompts, vllm_omni rollout
set -x

# Set WORKSPACE to any writable directory; defaults to $HOME.
WORKSPACE=${WORKSPACE:-$HOME}
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

pickscore_train_path=${PICKSCORE_TRAIN_PATH:-$SCRIPT_DIR/../../datasets/pickscore/train.txt}
pickscore_test_path=${PICKSCORE_TEST_PATH:-$SCRIPT_DIR/../../datasets/pickscore/test.txt}

model_name=stabilityai/stable-diffusion-3.5-medium
reward_model_name=CodeGoat24/UnifiedReward-2.0-qwen3vl-8b
reward_function_path=verl_omni/utils/reward_score/unified_reward.py

NUM_GPUS_ACTOR_ROLLOUT_REWARD=2
NUM_GPUS_ACTOR_ROLLOUT=1
ROLLOUT_TP=1
REWARD_TP=2

ENGINE=vllm_omni
REWARD_ENGINE=vllm

python3 -m verl_omni.trainer.diffusion.main_dpo \
data=prompt_txt_data \
data.train_files=$pickscore_train_path \
data.val_files=$pickscore_test_path \
trainer.resume_mode=disable \
data.train_batch_size=4 \
data.max_prompt_length=256 \
actor_rollout_ref.model.path=$model_name \
actor_rollout_ref.model.algorithm=dpo \
actor_rollout_ref.actor.diffusion_loss.dpo_beta=2000.0 \
actor_rollout_ref.rollout.pipeline.height=256 \
actor_rollout_ref.rollout.pipeline.width=256 \
actor_rollout_ref.rollout.pipeline.num_inference_steps=25 \
actor_rollout_ref.model.lora_rank=64 \
actor_rollout_ref.model.lora_alpha=128 \
actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out']" \
actor_rollout_ref.actor.optim.lr=3e-4 \
actor_rollout_ref.actor.optim.weight_decay=0.0001 \
actor_rollout_ref.actor.ppo_mini_batch_size=4 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.diffusion_loss.loss_mode=dpo \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=$ENGINE \
actor_rollout_ref.rollout.n=$NUM_GPUS_ACTOR_ROLLOUT \
actor_rollout_ref.rollout.k_samples=4 \
actor_rollout_ref.rollout.agent.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / ROLLOUT_TP)) \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.pipeline.guidance_scale=4.0 \
actor_rollout_ref.rollout.pipeline.max_sequence_length=256 \
actor_rollout_ref.rollout.val_kwargs.pipeline.num_inference_steps=50 \
reward.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / REWARD_TP)) \
reward.reward_model.enable=True \
reward.reward_model.model_path=$reward_model_name \
reward.reward_model.rollout.name=$REWARD_ENGINE \
reward.reward_model.rollout.tensor_model_parallel_size=$REWARD_TP \
reward.custom_reward_function.path=$reward_function_path \
reward.custom_reward_function.name=compute_score_unified_reward \
trainer.logger='["console", "wandb"]' \
trainer.project_name=dpo \
trainer.experiment_name=sd3_dpo_unified_reward \
trainer.log_val_generations=8 \
trainer.val_before_train=True \
trainer.n_gpus_per_node=$NUM_GPUS_ACTOR_ROLLOUT_REWARD \
trainer.nnodes=1 \
trainer.save_freq=30 \
trainer.test_freq=30 \
trainer.total_epochs=15 \
trainer.total_training_steps=1000 "$@"
4 changes: 3 additions & 1 deletion verl_omni/agent_loop/diffusion_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA

extra_fields["raw_prompt"] = kwargs["raw_prompt"]

# truncate prompt ids to the prompt length
prompt_ids = output.prompt_ids[: self.rollout_config.prompt_length]
prompt_output = self.tokenizer.pad(
{"input_ids": output.prompt_ids},
{"input_ids": prompt_ids},
padding="max_length",
max_length=self.rollout_config.prompt_length,
return_tensors="pt",
Expand Down
49 changes: 49 additions & 0 deletions verl_omni/agent_loop/prompt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prompt helpers shared by diffusion agent loops."""

from collections.abc import Mapping
from typing import Any


def stringify_prompt_part(part: Any) -> str:
if part is None:
return ""
if isinstance(part, str):
return part
if isinstance(part, Mapping):
if "text" in part:
return stringify_prompt_part(part["text"])
if "content" in part:
return stringify_prompt_part(part["content"])
return ""
if isinstance(part, (list, tuple)):
return " ".join(text for item in part if (text := stringify_prompt_part(item)))
if hasattr(part, "tolist") and not isinstance(part, str):
return stringify_prompt_part(part.tolist())
return str(part)


def stringify_prompt_messages(messages: Any) -> str:
"""Extract plain text from chat-style prompt messages for diffusion pipelines."""
if isinstance(messages, Mapping):
return stringify_prompt_part(messages.get("content", messages.get("prompt", "")))
if isinstance(messages, str):
return messages
if hasattr(messages, "tolist") and not isinstance(messages, str):
messages = messages.tolist()
if isinstance(messages, (list, tuple)):
return "\n".join(text for message in messages if (text := stringify_prompt_messages(message)))
return stringify_prompt_part(messages)
7 changes: 7 additions & 0 deletions verl_omni/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from verl.utils.profiler import simple_timer

from verl_omni.agent_loop.diffusion_agent_loop import DiffusionAgentLoopOutput
from verl_omni.agent_loop.prompt_utils import stringify_prompt_messages

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand All @@ -43,6 +44,10 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> DiffusionAgent
"""
raw_prompt = kwargs["raw_prompt"]
raw_negative_prompt = kwargs.get("raw_negative_prompt")
prompt_text = stringify_prompt_messages(raw_prompt)
negative_prompt_text = (
stringify_prompt_messages(raw_negative_prompt) if raw_negative_prompt is not None else None
)

# 1. extract images and videos from messages
multi_modal_data = await self.process_vision_info(raw_prompt)
Expand All @@ -63,10 +68,12 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> DiffusionAgent
output = await self.server_manager.generate(
request_id=uuid4().hex,
prompt_ids=prompt_ids,
prompt_text=prompt_text,
sampling_params=sampling_params,
image_data=images,
video_data=videos,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_text=negative_prompt_text,
)
if metrics.get("num_preempted") is None:
metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1
Expand Down
2 changes: 2 additions & 0 deletions verl_omni/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from . import _patch # noqa: F401 — apply Ulysses mask fix
from .qwen_image_flow_grpo import * # noqa: F401, F403
from .qwen_image_mix_grpo import * # noqa: F401, F403
from .sd3_dpo import * # noqa: F401, F403

__all__ = list(qwen_image_flow_grpo.__all__)
__all__ += list(qwen_image_mix_grpo.__all__)
__all__ += list(sd3_dpo.__all__)
20 changes: 20 additions & 0 deletions verl_omni/pipelines/sd3_dpo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SD3 (Stable Diffusion 3.x) pipeline module for verl-omni diffusion training.
"""

from .diffusers_training_adapter import SD3Adapter

__all__ = ["SD3Adapter"]
Loading