[Algo] DPO (online) training with SD3.5-medium #77
Code Review
This pull request introduces Direct Preference Optimization (DPO) support for Stable Diffusion 3.x models, adding a RayDPOTrainer, FM-DPO loss functions, and a SD3Adapter for managing model components. It also implements a PromptTxtRLDataset for text-based prompts and integrates UnifiedReward 2.0 for preference scoring. Reviewer feedback emphasizes optimizing memory by reusing embeddings from the rollout engine instead of reloading heavy encoders, reducing code duplication with centralized utility functions, improving image normalization heuristics, and preventing metadata loss during batch processing.
| def build_final_image_dpo_components( | ||
| cls, | ||
| model_config: DiffusionModelConfig, | ||
| *, | ||
| device: torch.device | int | str, | ||
| dtype: torch.dtype, | ||
| ) -> dict: | ||
| device = cls._normalize_device(device) | ||
| model_path = model_config.local_path | ||
| local_files_only = os.path.exists(model_path) | ||
| components = { | ||
| "tokenizer": CLIPTokenizer.from_pretrained( | ||
| model_path, subfolder="tokenizer", local_files_only=local_files_only | ||
| ), | ||
| "tokenizer_2": CLIPTokenizer.from_pretrained( | ||
| model_path, subfolder="tokenizer_2", local_files_only=local_files_only | ||
| ), | ||
| "tokenizer_3": T5Tokenizer.from_pretrained( | ||
| model_path, subfolder="tokenizer_3", local_files_only=local_files_only | ||
| ), | ||
| "text_encoder": CLIPTextModelWithProjection.from_pretrained( | ||
| model_path, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only | ||
| ), | ||
| "text_encoder_2": CLIPTextModelWithProjection.from_pretrained( | ||
| model_path, subfolder="text_encoder_2", torch_dtype=dtype, local_files_only=local_files_only | ||
| ), | ||
| "text_encoder_3": T5EncoderModel.from_pretrained( | ||
| model_path, subfolder="text_encoder_3", torch_dtype=dtype, local_files_only=local_files_only | ||
| ), | ||
| "vae": AutoencoderKL.from_pretrained( | ||
| model_path, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only | ||
| ), | ||
| } | ||
| for component in components.values(): | ||
| if hasattr(component, "to"): | ||
| component.to(device) | ||
| if hasattr(component, "eval"): | ||
| component.eval() | ||
| if hasattr(component, "requires_grad_"): | ||
| component.requires_grad_(False) | ||
| return components |
Loading full text encoders (CLIP, T5-XXL) and VAE on every training rank is extremely memory-intensive and redundant. T5-XXL alone can consume ~22GB of GPU memory. Since the rollout engine already computes and returns prompt_embeds and all_latents in the extra_fields, the trainer should reuse these provided embeddings instead of re-encoding them. This would significantly reduce the memory footprint and avoid potential OOM issues on training GPUs.
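A minimal sketch of the reuse suggested above, assuming the rollout output exposes `extra_fields` as a plain mapping that carries `prompt_embeds` (the key named in the comment). The helper `get_prompt_embeds` and the `encode_fn` fallback are hypothetical names for illustration, not part of this PR.

```python
from typing import Any, Callable, Mapping, Optional


def get_prompt_embeds(
    extra_fields: Mapping[str, Any],
    prompts: list[str],
    encode_fn: Optional[Callable[[list[str]], Any]] = None,
) -> Any:
    """Return prompt embeddings, preferring those already produced by rollout.

    `encode_fn` is a hypothetical lazy fallback. It is called only when the
    rollout output carries no embeddings, so the text encoders do not need to
    be resident on the training rank in the common case.
    """
    embeds = extra_fields.get("prompt_embeds")
    if embeds is not None:
        # Reuse what the rollout engine already computed.
        return embeds
    if encode_fn is None:
        raise KeyError("prompt_embeds missing from extra_fields and no encode_fn given")
    # Fall back to re-encoding only when nothing was provided.
    return encode_fn(prompts)
```

With this shape, the encoders can be loaded on demand inside `encode_fn` rather than eagerly in `build_final_image_dpo_components`.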
| if images.max() > 2: | ||
| images = images / 255.0 |
The heuristic if images.max() > 2: to detect unnormalized images is fragile. If an image is normalized to [0, 1] but contains out-of-range values (e.g., due to noise), or if a [0, 255] image is very dark, this check might fail. It is safer to check the dtype (e.g., torch.uint8) or use a more robust range check (e.g., images.max() > 1.0). Additionally, if the rollout engine is guaranteed to return [0, 1] floats, this check may be unnecessary.
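One way the dtype-based check could look, as a sketch only. It assumes `uint8` tensors use the 0..255 scale and floating-point tensors are already meant to be in [0, 1]; the function name `to_unit_float` is hypothetical.

```python
import torch


def to_unit_float(images: torch.Tensor) -> torch.Tensor:
    """Convert images to float32 in [0, 1], deciding by dtype, not by value."""
    if images.dtype == torch.uint8:
        # Integer images are assumed to use the 0..255 scale, however dark.
        return images.to(torch.float32) / 255.0
    if not torch.is_floating_point(images):
        raise TypeError(f"unsupported image dtype: {images.dtype}")
    # Float images are assumed to be in [0, 1] already; clamp small overshoot.
    return images.to(torch.float32).clamp(0.0, 1.0)
```

Because the branch depends on dtype alone, a very dark `uint8` image is still rescaled and a slightly out-of-range float image is not divided by 255.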
| ray.init(**OmegaConf.to_container(ray_init_kwargs)) | ||
|
|
||
| if task_runner_class is None: | ||
| task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head |
| gen_batch.meta_info = { | ||
| "recompute_log_prob": False, | ||
| "validate": False, | ||
| "global_steps": self.global_steps, | ||
| } |
Overwriting gen_batch.meta_info with a new dictionary loses any existing metadata copied during the pop operation in _get_gen_batch. It is safer to use update() to preserve existing metadata.
| - gen_batch.meta_info = { | |
| - "recompute_log_prob": False, | |
| - "validate": False, | |
| - "global_steps": self.global_steps, | |
| - } | |
| + gen_batch.meta_info.update({ | |
| + "recompute_log_prob": False, | |
| + "validate": False, | |
| + "global_steps": self.global_steps, | |
| + }) |
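The difference between the two forms can be seen on a stand-in object. This is an illustrative sketch: `SimpleNamespace` replaces the real batch class, and the pre-existing key `eos_token_id` is a hypothetical example of metadata copied earlier.

```python
from types import SimpleNamespace

# Stand-in for the batch object; only the `meta_info` dict matters here.
gen_batch = SimpleNamespace(meta_info={"eos_token_id": 2})  # hypothetical earlier key

step_info = {"recompute_log_prob": False, "validate": False, "global_steps": 10}

# Plain assignment would leave only the new keys behind.
overwritten = dict(step_info)

# In-place merge keeps whatever was already stored.
gen_batch.meta_info.update(step_info)
```

After the merge, `gen_batch.meta_info` holds both the earlier key and the three step-level keys, while `overwritten` holds only the latter.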
Consider writing a separate trainer for DPO/SFT in a different file.
Force-pushed from bfb13f3 to 2b7d9d0.
Co-authored-by: Cursor <cursoragent@cursor.com> update the dpo sd3.5 compatible to sd3.0 updates allow minimal edit edit name new text dataset text dataset update default reward fn update the dataset extra_info update trainer init fix import error change default dpo trainer yaml change model yaml path wrong yaml fix tokenizer path fix target config update the yaml update the yaml revise sd35 model yaml default chat template for sd3.5 refactor diffusion ray trainer no custom chat templte unified reward function reward score simple test fix model name change to unified reward 2B update exp name fix the dataset path apply custom chat template to sd3.5 use ode euler scheduler bypass compute_log_prob in ymal
Close with #95
What does this PR do?
This PR aims to train stable-diffusion-3.5-medium with the online DPO algorithm.
Checklist Before Starting
Format the PR title as [{modules}] {type}: {description} (this will be checked by the CI).
{modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker; list several like [diffusion, doc].
{type} is one of feat, fix, refactor, chore, test.
Add [BREAKING] to the beginning of the title when applicable. Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always