Skip to content

[trainer] refactor: code refactor for diffusion training#1

Merged
SamitHuang merged 10 commits into
mainfrom
diffusion_refactor
Apr 22, 2026
Merged

[trainer] refactor: code refactor for diffusion training#1
SamitHuang merged 10 commits into
mainfrom
diffusion_refactor

Conversation

@zhtmike
Copy link
Copy Markdown
Collaborator

@zhtmike zhtmike commented Apr 22, 2026

What does this PR do?

Major change:

  • Decouple LLM and diffusion configs, losses, etc., for easier integration of diffusion models/algorithms/engines in the future. (resolve comment from #5951)
    • Diffusion configs are now all placed in verl/trainer/config/diffusion, including diffusion_actor, dp_diffusion_actor, diffusion_fsdp, diffusion_rollout, etc.
    • Diffusion config classes are now all placed in verl/workers/config/diffusion, including actor, model, rollout.
    • Clean LLM-specific configs such as use_remove_padding, or unused configs. The content of _generated_diffusion_trainer.yaml has now been reduced by ~60%.

Other changes:

  • Clean the diffusion agent loop output; remove dead code originally borrowed from the LLM agent loop, such as input_id, attention_mask, etc., which are not used in diffusion training currently.
  • Refactor and remove extra_configs in diffusion configs. It is too loose and may cause confusion for users. (resolve comment from #5951)
  • Decouple LLM and diffusion losses to make diffusion losses clearer.
  • Code cleanup for diffusion trainer, diffusion agent loop, etc.
  • Change the vllm-omni rollout import to registration, following the diffusion trainer pattern. (resolve comment from #5951)

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, vllm_omni, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the diffusion training pipeline, reorganizing configurations into a dedicated directory and introducing a registry for custom vLLM-Omni pipelines. It also adds several CI sanity checks for PR metadata and code quality, along with a pyproject.toml file for project management. A review comment correctly identified that the PR title validation script is missing the new diffusion and omni module tags, which would result in CI failures for PRs using these tags.

Comment thread tests/special_sanity/check_pr_title.py Outdated
zhtmike and others added 3 commits April 22, 2026 13:45
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Comment thread .github/workflows/doc.yml Outdated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can add doc ci after doc system is built

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dropped.

Comment thread pyproject.toml
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can drop requirements* if we use pyproject.toml?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is required by sanity check CI, etc.

@SamitHuang SamitHuang merged commit 858ba8d into main Apr 22, 2026
4 checks passed
@zhtmike zhtmike deleted the diffusion_refactor branch April 22, 2026 06:38
yuekaizhang added a commit to yuekaizhang/verl-omni that referenced this pull request May 25, 2026
…s — client wiring, Hydra rebuild, multi_modal_inputs, eval + docs

Round 3 of the Qwen3-TTS GRPO recipe. Addresses all four Codex round-2
findings and lands T15 / T16 / T17 / T18 stubs. 111 recipe-scope tests pass.

Codex round-2 blocker fixes:

verl-project#1 — Real client wiring. New verl_omni/workers/rollout/_tts_client_patch.py
   monkey-patches LLMServerManager.get_client to return
   AutoRegressiveTTSServerClient when rollout.name == "vllm_omni_tts"
   (falls through to upstream LLMServerClient otherwise). Patch is
   idempotent and triggered at import time from
   verl_omni/workers/rollout/replica.py. New tests in
   tests/workers/test_tts_client_patch.py prove the upstream
   LLMServerManager now returns a client with generate_tts() for the
   AR-TTS rollout name.

verl-project#2 — Hydra config tree rebuilt to compose correctly. The previous
   subdirs (actor/rollout/ref/reward) used "# @Package _global_" which
   double-nested every section, so the trainer saw keys like
   "rollout.actor_rollout_ref.rollout" instead of
   "actor_rollout_ref.rollout". Deleted the broken subdirs and replaced
   the recipe with a single top-level
   verl_omni/trainer/config/qwen3_tts_trainer.yaml that:
   - Inherits the diffusion recipe's full defaults tree
     (diffusion/{actor,ref,rollout,model,model_engine} + data + reward)
     so all upstream-required sections (ray_kwargs, etc.) are present.
   - Overrides actor_rollout_ref.rollout for vllm_omni_tts + AR-TTS
     agent loop / manager.
   - Overrides algorithm with upstream verl AlgoConfig (adv_estimator=grpo).
   - Overrides reward.reward_manager to AsrErrorRateRewardManager.
   - Wires data.custom_cls = Qwen3TTSDataset.
   verl_omni/trainer/qwen3_tts_grpo/main.py adjusted to load from
   ../config with config_name=qwen3_tts_trainer.

   New tests in tests/trainer/test_qwen3_tts_hydra_compose.py compose
   the shipped yaml via Hydra, apply launch overrides, and verify:
   - Top-level keys land at expected paths (no double-nesting).
   - AR-TTS overrides survive composition.
   - The recipe validator accepts the composed config.
   - The validator rejects missing ASR URL, flow_grpo, co_located ASR,
     and n<2 at the real composition layer (not synthetic dicts).

verl-project#3 — multi_modal_inputs in worker output.
   verl_omni/agent_loop/autoregressive_tts_agent_loop.py:_postprocess()
   now emits a multi_modal_inputs object-array column (empty dict per
   row, audio recipes have no image/video inputs) plus __num_turns__,
   so upstream RayPPOTrainer.training_step's unconditional
   batch.non_tensor_batch["multi_modal_inputs"] iteration succeeds.

verl-project#4 — Recipe artifacts finished, no further deferrals:
   - examples/qwen3_tts_grpo_trainer/eval.sh + run_full.sh +
     verl_omni/trainer/qwen3_tts_grpo/run_eval.py.
   - run_eval.py writes eval_results.json with the AC-7 schema (base/RL
     CER + median/mean duration_ratio), enforces target_utt_id
     disjointness, and supports mock-inference JSON for offline scoring.
   - verl_omni/utils/validation_audio_logger.py implements AC-8: writes
     >=4 generated + ref + (optional target) wavs per validation step
     plus metrics.json; ArtifactWriteError on disk failures; post-run
     helper flags zero-artifact steps.
   - docs/recipe/qwen3_tts_grpo.md documents pipeline, layout, smoke,
     eval, full run, hardware, validation logging.
   - examples/qwen3_tts_grpo_trainer/README.md updated — open-items
     section replaced with eval / full-run / audio-logging sections.

Test totals (recipe scope): 111 passed.
  - dataset 9, audio output 4, reward 19, pairing 4, validator 9,
    hydra compose 7, run_eval 2, client patch 4, audio logger 5,
    plus pre-existing recipe-adjacent tests.

Still pending: T1 / T2b / T3 / T14 — live-engine smoke against the
downloaded Qwen3-TTS-12Hz-0.6B-Base + a running Qwen3-ASR server. Round
4's focus.

AI assistance: Claude Code (Opus 4.7, 1M context).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
yuekaizhang added a commit to yuekaizhang/verl-omni that referenced this pull request May 25, 2026
…t, validation logger wired, real eval path, T1 partial

Round 4 of the Qwen3-TTS GRPO recipe. Addresses every Codex round-3
blocker and runs T1's static checkpoint verification.

Round-3 blocker fixes:

verl-project#1 — Real eval path (B1). run_eval.py no longer NotImplementedErrors on
   the default launcher path. _live_inference_split() spins up
   AsyncOmni for each checkpoint, drives one Qwen3-TTS rollout per eval
   row, captures the generated waveform + duration, and feeds it into
   _score_inference() which transcribes via the remote ASR endpoint and
   computes per-row CER + duration_ratio. --mock-inference-json
   retained as an opt-in TEST-ONLY path that bypasses engine + ASR
   (used by tests/trainer/test_qwen3_tts_run_eval.py).

verl-project#2 — Validation logger wired into runtime (M2).
   AutoRegressiveTTSAgentLoopWorker.__init__ resolves the validation
   output dir from config.trainer.validation_data_dir (falling back to
   default_local_dir/validation_audio) and stores
   log_validation_step + post_run_check_emitted_artifacts. New
   _emit_validation_artifacts(...) is called from generate_sequences
   when batch.meta_info["validate"] is True, persisting AC-8 artifacts
   (>=4 generated wavs + ref + optional target + metrics.json) per
   validation step. New tests/agent_loop/test_ar_tts_validation_logging.py
   exercises the worker call path end-to-end (no GPU).

verl-project#3 — Canonical Hydra layout restored (M3). Moved trainer yaml back to
   verl_omni/trainer/config/qwen3_tts/qwen3_tts_trainer.yaml (per AC-9
   and Codex round-3 finding). Two fixes needed to keep composition
   correct in the subdir:
   - `# @Package _global_` on the trainer yaml itself, so its content
     goes to the root namespace (not under qwen3_tts.*).
   - All defaults use absolute paths (`/diffusion/...`, `/data@data`,
     `/reward@reward`) so groups resolve against the config search
     root, not against the qwen3_tts/ subdir.
   - Hardcoded `dp_diffusion_actor` / `dp_diffusion_ref` instead of the
     `${diffusion/model_engine}` interpolation, which breaks in the
     subdir (Hydra namespaces the group choices under qwen3_tts.*).
   main.py adjusted to config_name="qwen3_tts/qwen3_tts_trainer".
   Hydra composition tests pass against the canonical layout (7/7).

verl-project#4 — T1 partial (analyze gate). Round-4 verifies the static contract:
     `safetensors.safe_open('.hf_cache/Qwen3-TTS-12Hz-0.6B-Base/model.safetensors')`
     reports 478 total keys, 76 of them under `speaker_encoder.*`. The
     model supports voice cloning per the plan. The standalone
     scripts/qwen3_tts_smoke.py runs the full T1 + T2b verification
     (live engine launch); it is shipped so round 5 (or a human) can
     execute the live half against a real ref_audio. T1's
     fail-closed-if-absent contract is preserved by the script.

Files changed: 6 modified, 3 new (scripts/qwen3_tts_smoke.py,
tests/agent_loop/test_ar_tts_validation_logging.py, restored
qwen3_tts/qwen3_tts_trainer.yaml). 113 recipe-scope tests pass
(112 R3 + 2 new R4 worker validation logger tests, less 1 because
test_run_eval.py is now stricter about the mock path).

Still pending: T1 part 2 (live AsyncOmni inference), T2b
(stage-0 logprobs in OmniRequestOutput), T3 (remote ASR HTTP smoke),
T14 (end-to-end run_smoke.sh). All four require launching the actual
servers, which round 5 will attempt.

AI assistance: Claude Code (Opus 4.7, 1M context).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
yuekaizhang added a commit to yuekaizhang/verl-omni that referenced this pull request May 25, 2026
…dirs, AC-6 metrics in metrics.json, fail-closed speaker check, post-run AC-8 check

Round 5 of the Qwen3-TTS GRPO recipe. Addresses every Codex round-4
action item except T1/T2b/T3/T14 live-engine runs.

verl-project#1 — Canonical packaging finished (T11). The qwen3_tts/ subdir now
contains the per-role files the plan and AC-9 require:

  verl_omni/trainer/config/qwen3_tts/
  ├── qwen3_tts_trainer.yaml          (top-level, # @Package _global_)
  ├── actor/qwen3_tts_actor.yaml      (PPO mini/micro batch overrides)
  ├── ref/qwen3_tts_ref.yaml          (KL disabled in v1)
  ├── rollout/qwen3_tts_rollout.yaml  (vllm_omni_tts + AR agent loop)
  └── reward/qwen3_tts_reward.yaml    (AsrErrorRateRewardManager + remote ASR)

The trainer yaml's defaults now load each role twice: the diffusion /
shared base first, then the qwen3_tts/<role> file on top. Hydra merges
the second over the first at the same target path. Composition tests
(7/7) pass against the canonical layout. README and docs paths updated
to match.

verl-project#2 — AC-6 metrics in validation metrics.json (T16). The worker's
_emit_validation_artifacts now writes a stable schema with the full
AC-6 scalar set (mean_reward, mean_cer, mean_duration_ratio,
policy_loss, kl_loss) — mean_reward / mean_cer / mean_duration_ratio
are populated from per-sample rewards / CER / waveform durations;
policy_loss / kl_loss are null at validation steps (no parameter
update happens during validation). Test:
tests/agent_loop/test_ar_tts_validation_logging.py now asserts every
AC-6 key is present and policy_loss/kl_loss are explicitly null.

verl-project#3 — Post-run AC-8 check wired into runtime (T16). main.py's hydra
entry now wraps verl.trainer.main_ppo.run_ppo with a try/finally that
resolves the validation artifact dir (config.trainer.validation_data_dir
or default_local_dir/validation_audio) and calls
post_run_check_emitted_artifacts(...) when the dir exists. A run with
any validation_step_* directory containing zero generated audio files
raises ArtifactWriteError at the end of the run, regardless of whether
training itself succeeded. Tests:
tests/trainer/test_qwen3_tts_post_run_check.py covers the resolver,
the happy path, and the empty-step fail-closed path.

verl-project#4 — Fail-closed _check_speaker_encoder (T1 helper). The Hub-id branch
that previously returned True when local weights couldn't be found
(round-4's fail-open shortcut) now returns False with a clear stderr
message instructing the operator to download the checkpoint first.
Tests: tests/trainer/test_qwen3_tts_smoke_helper.py covers four cases:
no-local-checkpoint, no-speaker-encoder, speaker-encoder-present,
unreadable-file. Manual verification:
  smoke._check_speaker_encoder('.hf_cache/Qwen3-TTS-12Hz-0.6B-Base')
      → True   (real downloaded checkpoint, 76 speaker_encoder keys)
  smoke._check_speaker_encoder('Qwen/Qwen3-TTS-12Hz-0.6B-Base')
      → False  (Hub id without local download; fail-closed)

Tests: 122 recipe-scope tests pass (118 R0-R4 + 4 R5 helper/post-run/AC-6).

Live engines (T1 part 2, T2b, T3, T14) — still pending. The structural
plumbing is in place: scripts/qwen3_tts_smoke.py drives the full
T1+T2b verification against a real engine; the recipe trainer wraps
upstream main_ppo with the AC-8 post-run check; the eval pipeline
drives AsyncOmni for both checkpoints. Round 6 (or a human operator)
executes these against live processes.

Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
yuekaizhang added a commit to yuekaizhang/verl-omni that referenced this pull request May 25, 2026
…smoke prompt schema

Round 6, final round of the RLCR loop. Addresses Codex round-5 findings.

verl-project#1 — mean_cer plumbing fixed (T16, AC-8).
AutoRegressiveTTSAgentLoopWorker._compute_score() now preserves the
per-completion CER + duration_penalty values from the reward manager's
breakdown alongside rewards / success / transcripts. The validation
logger's _emit_validation_artifacts() therefore populates metrics.json
with a real, finite mean_cer when reward scoring ran. The
test_emit_validation_artifacts_writes_step_dir test now constructs scored
outputs with per_sample_cers and asserts metrics["mean_cer"] is a finite
float, not just a present key — closing the round-5 fail mode Codex
flagged. Also wraps each completion's waveform / codec_tokens with the
np.empty+assign pattern from BL-20260515-dataproto-object-batch-wrap so
DataProto keeps the underlying ndarrays.

verl-project#2 — Qwen3-TTS smoke prompt schema (T1 / T2b live attempt).
scripts/qwen3_tts_smoke.py was hitting vllm's parse_dec_only_prompt
rejection because OmniCustomPrompt with `extra_args` is the diffusion
shape. Switched the prompt to vllm's TextPrompt-style dict:

    {"prompt": prompt_text,
     "additional_information": {"ref_audio": <path>,
                                "ref_text": [<str>],
                                "task_type": ["Base"]}}

The talker reads task_type via info_dict["task_type"][0] (verified in
the talker source), so task_type / ref_text must be one-element lists.
Stage 1 (code2wav) needs an explicit SamplingParams; passing None
causes orchestrator.build_engine_core_request_from_tokens to fail at
params.clone() with NoneType.

Live smoke results captured this round:
  ✓ Engine bootstrap with verl-omni-side stage_configs_path override:
    OK (120s warmup, both stages initialized)
  ✓ speaker_encoder presence check on cached checkpoint: True (76 keys
    of 478 total)
  ✓ Prompt schema accepted by input preprocessor: OK (no more
    "Prompt dictionary must contain text, tokens, or embeddings")
  ✗ Generation crashes downstream of stage-0 speaker_encoder forward
    when ref_audio is synthetic sine-wave audio rather than real
    Mandarin speech. Engine reports EngineDeadError with NoneType
    .clone() in the build_engine_core_request_from_tokens path.

This is a known limitation of the smoke approach: the Qwen3-TTS Base
speaker encoder is trained on speech and refuses to extract a usable
embedding from a sine sweep. A real ref_audio.wav from
yuekai/aishell (or any Mandarin speech sample) is required to complete
T1 / T2b live verification. The data-prep pipeline ships the right
hooks (examples/qwen3_tts_grpo_trainer/data_process/aishell_voice_clone.py),
but downloading and pairing AISHELL inside a single review round is
infeasible.

T3 (remote Qwen3-ASR HTTP smoke) and T14 (end-to-end run_smoke.sh)
remain pending. The structural plumbing is fully in place — the recipe
is reproducible per AC-9 once a real Mandarin ref audio + a launched
Qwen3-ASR server are available.

Tests: tests/agent_loop/test_ar_tts_validation_logging.py now asserts
mean_cer is a finite float. 122 recipe-scope tests pass.
Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants