Skip to content

[trainer, diffusion] feat: add z-image support for flowgrpo training#76

Draft
zhtmike wants to merge 1 commit into
verl-project:mainfrom
zhtmike:z-image
Draft

[trainer, diffusion] feat: add z-image support for flowgrpo training#76
zhtmike wants to merge 1 commit into
verl-project:mainfrom
zhtmike:z-image

Conversation

@zhtmike
Copy link
Copy Markdown
Collaborator

@zhtmike zhtmike commented May 13, 2026

What does this PR do?

  • add z-image support for flowgrpo training

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker
    • If this PR involves multiple modules, separate them with , like [diffusion, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Z-Image-Turbo training via FlowGRPO, incorporating OCR data preprocessing, training launch scripts, and specialized adapters for diffusers-based training and vLLM-omni rollouts. The reviewer feedback highlights opportunities to optimize the rollout adapter by batching text encoder calls, improve error handling in the preprocessing script to fail fast on malformed data or missing paths, and correct the normalization dimensions in the shared CFG utility.

Comment on lines +112 to +150
batch_size = prompt_ids.shape[0]
prompt_embeds_list = []
prompt_embeds_mask_list = []

for i in range(batch_size):
# Get the actual non-padded token IDs
mask_i = attention_mask[i].bool()
ids_i = prompt_ids[i][mask_i]

# Apply chat template (matching Z-Image's _encode_prompt)
prompt_str = self.tokenizer.decode(ids_i, skip_special_tokens=False)
messages = [{"role": "user", "content": prompt_str}]
formatted = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
encoded = self.tokenizer(
formatted,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids.to(self.device)
attn_mask = encoded.attention_mask.to(self.device).bool()

# Encode with text encoder (use second-to-last hidden states)
hidden_states = self.text_encoder(
input_ids=input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[-2]

# Extract non-padded embeddings
non_padded = hidden_states[0][attn_mask[0]]
prompt_embeds_list.append(non_padded)
prompt_embeds_mask_list.append(torch.ones(len(non_padded), dtype=torch.long, device=self.device))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling the text encoder inside a loop over the batch dimension is highly inefficient. This results in multiple sequential GPU kernel launches instead of a single batched operation. Additionally, decoding prompt_ids and re-applying the chat template may lead to double-templating if the input tokens already include template markers. The suggested change batches the text encoder call for significantly better performance.

            batch_size = prompt_ids.shape[0]
            all_formatted = []

            for i in range(batch_size):
                # Get the actual non-padded token IDs
                mask_i = attention_mask[i].bool()
                ids_i = prompt_ids[i][mask_i]

                # Apply chat template (matching Z-Image's _encode_prompt)
                prompt_str = self.tokenizer.decode(ids_i, skip_special_tokens=False)
                messages = [{"role": "user", "content": prompt_str}]
                formatted = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=True,
                )
                all_formatted.append(formatted)

            encoded = self.tokenizer(
                all_formatted,
                padding="max_length",
                max_length=max_sequence_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded.input_ids.to(self.device)
            attn_mask = encoded.attention_mask.to(self.device).bool()

            # Encode with text encoder (use second-to-last hidden states)
            hidden_states = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attn_mask,
                output_hidden_states=True,
            ).hidden_states[-2]

            prompt_embeds_list = []
            prompt_embeds_mask_list = []
            for i in range(batch_size):
                # Extract non-padded embeddings
                non_padded = hidden_states[i][attn_mask[i]]
                prompt_embeds_list.append(non_padded)
                prompt_embeds_mask_list.append(torch.ones(len(non_padded), dtype=torch.long, device=self.device))


def extract_solution(solution_str):
# The solution is stored in the format: 'The image displays "xxx".'
return solution_str.split('"')[1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of extract_solution is fragile as it assumes the input string always contains at least two double quotes. Instead of returning a default or the original string which might cause silent downstream errors, the code should fail fast by raising an explicit error when the format is incorrect.

Suggested change
return solution_str.split('"')[1]
parts = solution_str.split('"')
if len(parts) < 2:
raise ValueError(f"Expected at least two quotes in solution string, got: {solution_str}")
return parts[1]
References
  1. Avoid adding default values to work around errors; prefer failing fast to make issues explicit.


data_source = "flow_grpo/ocr"

if local_dataset_path is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

local_dataset_path is initialized using os.path.expanduser(args.input_dir), which will always return a string. To ensure the code fails fast if the dataset is missing, check if the directory actually exists using os.path.exists() to provide a meaningful error message.

Suggested change
if local_dataset_path is not None:
if os.path.exists(local_dataset_path):
References
  1. Let the code fail fast to make the error explicit and avoid masking underlying issues.

Comment on lines +54 to +55
ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=-1, keepdim=True)
new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=-1, keepdim=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using dim=-1 for vector_norm only calculates the norm across the last dimension (width). For standard CFG normalization (the rescale trick), the norm should typically be calculated over all non-batch dimensions to correctly represent the magnitude of the entire prediction vector per sample.

Suggested change
ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=-1, keepdim=True)
new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=-1, keepdim=True)
dims = tuple(range(1, noise_pred.ndim))
ori_pos_norm = torch.linalg.vector_norm(noise_pred.float(), dim=dims, keepdim=True)
new_pos_norm = torch.linalg.vector_norm(pred.float(), dim=dims, keepdim=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant