[trainer, diffusion] feat: add z-image support for flowgrpo training#76
[trainer, diffusion] feat: add z-image support for flowgrpo training#76zhtmike wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds support for Z-Image-Turbo training via FlowGRPO, incorporating OCR data preprocessing, training launch scripts, and specialized adapters for diffusers-based training and vLLM-omni rollouts. The reviewer feedback highlights opportunities to optimize the rollout adapter by batching text encoder calls, improve error handling in the preprocessing script to fail fast on malformed data or missing paths, and correct the normalization dimensions in the shared CFG utility.
| batch_size = prompt_ids.shape[0] | ||
| prompt_embeds_list = [] | ||
| prompt_embeds_mask_list = [] | ||
|
|
||
| for i in range(batch_size): | ||
| # Get the actual non-padded token IDs | ||
| mask_i = attention_mask[i].bool() | ||
| ids_i = prompt_ids[i][mask_i] | ||
|
|
||
| # Apply chat template (matching Z-Image's _encode_prompt) | ||
| prompt_str = self.tokenizer.decode(ids_i, skip_special_tokens=False) | ||
| messages = [{"role": "user", "content": prompt_str}] | ||
| formatted = self.tokenizer.apply_chat_template( | ||
| messages, | ||
| tokenize=False, | ||
| add_generation_prompt=True, | ||
| enable_thinking=True, | ||
| ) | ||
| encoded = self.tokenizer( | ||
| formatted, | ||
| padding="max_length", | ||
| max_length=max_sequence_length, | ||
| truncation=True, | ||
| return_tensors="pt", | ||
| ) | ||
| input_ids = encoded.input_ids.to(self.device) | ||
| attn_mask = encoded.attention_mask.to(self.device).bool() | ||
|
|
||
| # Encode with text encoder (use second-to-last hidden states) | ||
| hidden_states = self.text_encoder( | ||
| input_ids=input_ids, | ||
| attention_mask=attn_mask, | ||
| output_hidden_states=True, | ||
| ).hidden_states[-2] | ||
|
|
||
| # Extract non-padded embeddings | ||
| non_padded = hidden_states[0][attn_mask[0]] | ||
| prompt_embeds_list.append(non_padded) | ||
| prompt_embeds_mask_list.append(torch.ones(len(non_padded), dtype=torch.long, device=self.device)) |
There was a problem hiding this comment.
Calling the text encoder inside a loop over the batch dimension is highly inefficient. This results in multiple sequential GPU kernel launches instead of a single batched operation. Additionally, decoding prompt_ids and re-applying the chat template may lead to double-templating if the input tokens already include template markers. The suggested change batches the text encoder call for significantly better performance.
batch_size = prompt_ids.shape[0]
all_formatted = []
for i in range(batch_size):
# Get the actual non-padded token IDs
mask_i = attention_mask[i].bool()
ids_i = prompt_ids[i][mask_i]
# Apply chat template (matching Z-Image's _encode_prompt)
prompt_str = self.tokenizer.decode(ids_i, skip_special_tokens=False)
messages = [{"role": "user", "content": prompt_str}]
formatted = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
all_formatted.append(formatted)
encoded = self.tokenizer(
all_formatted,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids.to(self.device)
attn_mask = encoded.attention_mask.to(self.device).bool()
# Encode with text encoder (use second-to-last hidden states)
hidden_states = self.text_encoder(
input_ids=input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[-2]
prompt_embeds_list = []
prompt_embeds_mask_list = []
for i in range(batch_size):
# Extract non-padded embeddings
non_padded = hidden_states[i][attn_mask[i]]
prompt_embeds_list.append(non_padded)
prompt_embeds_mask_list.append(torch.ones(len(non_padded), dtype=torch.long, device=self.device))|
|
||
| def extract_solution(solution_str): | ||
| # The solution is stored in the format: 'The image displays "xxx".' | ||
| return solution_str.split('"')[1] |
There was a problem hiding this comment.
The current implementation of extract_solution is fragile as it assumes the input string always contains at least two double quotes. Instead of returning a default or the original string which might cause silent downstream errors, the code should fail fast by raising an explicit error when the format is incorrect.
| return solution_str.split('"')[1] | |
| parts = solution_str.split('"') | |
| if len(parts) < 2: | |
| raise ValueError(f"Expected at least two quotes in solution string, got: {solution_str}") | |
| return parts[1] |
References
- Avoid adding default values to work around errors; prefer failing fast to make issues explicit.
|
|
||
| data_source = "flow_grpo/ocr" | ||
|
|
||
| if local_dataset_path is not None: |
There was a problem hiding this comment.
local_dataset_path is initialized using os.path.expanduser(args.input_dir), which will always return a string. To ensure the code fails fast if the dataset is missing, check if the directory actually exists using os.path.exists() to provide a meaningful error message.
| if local_dataset_path is not None: | |
| if os.path.exists(local_dataset_path): |
References
- Let the code fail fast to make the error explicit and avoid masking underlying issues.
| ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=-1, keepdim=True) | ||
| new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=-1, keepdim=True) |
There was a problem hiding this comment.
Using dim=-1 for vector_norm only calculates the norm across the last dimension (width). For standard CFG normalization (the rescale trick), the norm should typically be calculated over all non-batch dimensions to correctly represent the magnitude of the entire prediction vector per sample.
| ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=-1, keepdim=True) | |
| new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=-1, keepdim=True) | |
| dims = tuple(range(1, noise_pred.ndim)) | |
| ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=dims, keepdim=True) | |
| new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=dims, keepdim=True) |
What does this PR do?
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,vllm_omni,rollout,trainer,ci,training_utils,recipe,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,diffusion,omni,tests,docker,like[diffusion, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][diffusion, fsdp] feat: new rollout schedulerTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always