Skip to content

[Algo] DPO (online) training with SD3.5-medium#77

Closed
wtomin wants to merge 9 commits into
verl-project:mainfrom
wtomin:dpo-sd35-refactor
Closed

[Algo] DPO (online) training with SD3.5-medium#77
wtomin wants to merge 9 commits into
verl-project:mainfrom
wtomin:dpo-sd35-refactor

Conversation

@wtomin
Copy link
Copy Markdown
Contributor

@wtomin wtomin commented May 13, 2026

What does this PR do?

This PR aims to train the stable-diffusion-3.5-medium with online-DPO agorithm.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker
    • If this PR involves multiple modules, separate them with , like [diffusion, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Direct Preference Optimization (DPO) support for Stable Diffusion 3.x models, adding a RayDPOTrainer, FM-DPO loss functions, and a SD3Adapter for managing model components. It also implements a PromptTxtRLDataset for text-based prompts and integrates UnifiedReward 2.0 for preference scoring. Reviewer feedback emphasizes optimizing memory by reusing embeddings from the rollout engine instead of reloading heavy encoders, reducing code duplication with centralized utility functions, improving image normalization heuristics, and preventing metadata loss during batch processing.

Comment on lines +106 to +146
def build_final_image_dpo_components(
cls,
model_config: DiffusionModelConfig,
*,
device: torch.device | int | str,
dtype: torch.dtype,
) -> dict:
device = cls._normalize_device(device)
model_path = model_config.local_path
local_files_only = os.path.exists(model_path)
components = {
"tokenizer": CLIPTokenizer.from_pretrained(
model_path, subfolder="tokenizer", local_files_only=local_files_only
),
"tokenizer_2": CLIPTokenizer.from_pretrained(
model_path, subfolder="tokenizer_2", local_files_only=local_files_only
),
"tokenizer_3": T5Tokenizer.from_pretrained(
model_path, subfolder="tokenizer_3", local_files_only=local_files_only
),
"text_encoder": CLIPTextModelWithProjection.from_pretrained(
model_path, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
),
"text_encoder_2": CLIPTextModelWithProjection.from_pretrained(
model_path, subfolder="text_encoder_2", torch_dtype=dtype, local_files_only=local_files_only
),
"text_encoder_3": T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder_3", torch_dtype=dtype, local_files_only=local_files_only
),
"vae": AutoencoderKL.from_pretrained(
model_path, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
),
}
for component in components.values():
if hasattr(component, "to"):
component.to(device)
if hasattr(component, "eval"):
component.eval()
if hasattr(component, "requires_grad_"):
component.requires_grad_(False)
return components
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Loading full text encoders (CLIP, T5-XXL) and VAE on every training rank is extremely memory-intensive and redundant. T5-XXL alone can consume ~22GB of GPU memory. Since the rollout engine already computes and returns prompt_embeds and all_latents in the extra_fields, the trainer should reuse these provided embeddings instead of re-encoding them. This would significantly reduce the memory footprint and avoid potential OOM issues on training GPUs.

Comment thread verl_omni/pipelines/sd3_dpo/diffusers_training_adapter.py
Comment thread verl_omni/pipelines/sd3_dpo/diffusers_training_adapter.py
Comment on lines +267 to +268
if images.max() > 2:
images = images / 255.0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The heuristic if images.max() > 2: to detect unnormalized images is fragile. If an image is normalized to [0, 1] but contains out-of-range values (e.g., due to noise), or if a [0, 255] image is very dark, this check might fail. It is safer to check the dtype (e.g., torch.uint8) or use a more robust range check (e.g., images.max() > 1.0). Additionally, if the rollout engine is guaranteed to return [0, 1] floats, this check may be unnecessary.

ray.init(**OmegaConf.to_container(ray_init_kwargs))

if task_runner_class is None:
task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding num_cpus=1 for the TaskRunner might lead to performance bottlenecks during dataset loading and tokenization (lines 214-230), which are executed within the TaskRunner. It is recommended to make this configurable or set a more reasonable default based on the workload.

Comment on lines +1160 to +1164
gen_batch.meta_info = {
"recompute_log_prob": False,
"validate": False,
"global_steps": self.global_steps,
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Overwriting gen_batch.meta_info with a new dictionary loses any existing metadata copied during the pop operation in _get_gen_batch. It is safer to use update() to preserve existing metadata.

Suggested change
gen_batch.meta_info = {
"recompute_log_prob": False,
"validate": False,
"global_steps": self.global_steps,
}
gen_batch.meta_info.update({
"recompute_log_prob": False,
"validate": False,
"global_steps": self.global_steps,
})

Copy link
Copy Markdown
Collaborator

@zhtmike zhtmike May 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may write a separate trainer for dpo/sft in different file

@wtomin wtomin force-pushed the dpo-sd35-refactor branch 5 times, most recently from bfb13f3 to 2b7d9d0 Compare May 13, 2026 08:37
Co-authored-by: Cursor <cursoragent@cursor.com>

update the dpo sd3.5

compatible to sd3.0

updates

allow minimal edit

edit name

new text dataset

text dataset

update default reward fn

update the dataset extra_info

update trainer init

fix import error

change default dpo trainer yaml

change model yaml path

wrong yaml

fix tokenizer path

fix target config

update the yaml

update the yaml

revise sd35 model yaml

default chat template for sd3.5

refactor diffusion ray trainer

no custom chat templte

unified reward function

reward score simple test

fix model name

change to unified reward 2B

update exp name

fix the dataset path

apply custom chat template to sd3.5

use ode euler scheduler

bypass compute_log_prob in ymal
@wtomin wtomin force-pushed the dpo-sd35-refactor branch from 2b7d9d0 to 4ec25e9 Compare May 13, 2026 08:45
@wtomin
Copy link
Copy Markdown
Contributor Author

wtomin commented May 20, 2026

Close with #95

@wtomin wtomin closed this May 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants