
[Bugfix] Fix GLM-Image output dimensions and image edit pipeline#2320

Merged
hsliuustc0106 merged 55 commits into vllm-project:main from JaredforReal:fix/glm
Apr 20, 2026

Conversation

@JaredforReal
Contributor

@JaredforReal JaredforReal commented Mar 30, 2026

Purpose

Fixes several issues in the GLM-Image serving and AR-to-diffusion pipeline:

  • Pass output dimensions to AR stage: The generate_images and edit_images endpoints now inject target_h / target_w into mm_processor_kwargs so the AR stage generates tokens for the correct grid size.
  • Fix small preview token calculation: Replaces the incorrect token_h // 2 x token_w // 2 formula with sqrt(ratio) * (factor // 2), matching GlmImageProcessor._build_prompt_with_target_shape. The old formula produced wrong preview grids for non-square images.
  • Normalize img2img modality key: GlmImageDataParser.parse_mm_data maps "img2img" → "image" so downstream code (mm_hashes, kwargs merging) sees a single consistent modality key.
  • Robust dimension resolution in ar2diffusion: Prefers mm_processor_kwargs over top-level prompt fields, with integer coercion and informative logging (see the sketch after this list).
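
A minimal sketch of the dimension-resolution order described in the last bullet, assuming a plain dict-shaped prompt; the function name, the logger setup, and the 1024 default are illustrative assumptions rather than the actual ar2diffusion code:

import logging

logger = logging.getLogger(__name__)

def resolve_target_dims(prompt: dict, default: int = 1024) -> tuple[int, int]:
    # Prefer mm_processor_kwargs, then top-level prompt fields, then a default.
    mm_kwargs = prompt.get("mm_processor_kwargs") or {}
    raw_h = mm_kwargs.get("target_h", prompt.get("height", default))
    raw_w = mm_kwargs.get("target_w", prompt.get("width", default))
    try:
        height, width = int(raw_h), int(raw_w)
    except (TypeError, ValueError):
        logger.warning("Could not coerce target dims %r x %r to int; using %d", raw_h, raw_w, default)
        height = width = default
    logger.debug("Resolved diffusion target size: %dx%d", height, width)
    return height, width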

Test plan

  • Verify text-to-image generation produces correct output sizes for square (e.g. 1024x1024) and non-square requests
  • Verify image editing (/v1/images/edits) works end-to-end (see the sketch after this list)
  • Confirm small preview grid dimensions match expected values
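
For the image-edit check, a hedged sketch of a smoke test against the OpenAI-compatible /v1/images/edits endpoint; the port, file names, prompt, and the b64_json response field are assumptions, not the exact commands used for this PR:

import base64

import requests

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8091/v1/images/edits",
        data={
            "model": "glm-image",
            "prompt": "replace the sky with a sunset",
            "size": "1024x1024",
            "response_format": "b64_json",
        },
        files={"image": ("input.png", f, "image/png")},
        timeout=600,
    )
resp.raise_for_status()
with open("edited.png", "wb") as out:
    out.write(base64.b64decode(resp.json()["data"][0]["b64_json"]))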

Benchmark

Terminal 1

vllm serve zai-org/GLM-Image --port 8091 --host 0.0.0.0 --served-model-name glm-image --omni --enable-diffusion-pipeline-profiler

Terminal 2 for T2I mode

python benchmarks/diffusion/diffusion_benchmark_service.py \
    --model glm-image \
    --backend vllm-omni \
    --dataset vbench \
    --task t2i \
    --num-prompts 20 \
    --height 1024 --width 1024 \
    --num-inference-steps 50 \
    --max-concurrency 1 \
    --warmup-requests 1 \
    --output-file glm_image_t2i_1024_baseline.json

================= Serving Benchmark Result =================
Backend:                                 vllm-omni
Model:                                   /workspace/GLM-Image
Dataset:                                 vbench
Task:                                    t2i
--------------------------------------------------
Benchmark duration (s):                  726.87
Request rate:                            inf
Max request concurrency:                 1
Successful requests:                     20/20
--------------------------------------------------
Request throughput (req/s):              0.03
Latency Mean (s):                        36.3433
Latency Median (s):                      36.2490
Latency P99 (s):                         37.1899
Latency P95 (s):                         37.1436
--------------------------------------------------
Peak Memory Max (MB):                    23420.00
Peak Memory Mean (MB):                   23420.00
Peak Memory Median (MB):                 23420.00
--------------------------------------------------
Stage Durations Mean (s):
  GlmImagePipeline.text_encoder.forward: 0.0142
  GlmImagePipeline.diffuse:              13.8316
  GlmImagePipeline.vae.decode:           0.1907

============================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your code doesn't require additional test scripts. For test file guidelines, please check the test style doc.
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Copilot AI review requested due to automatic review settings March 30, 2026 06:37

Copilot AI left a comment

Pull request overview

Fixes GLM-Image image sizing inconsistencies across the serving layer, AR token generation/parsing, and AR→diffusion handoff so non-square and edit requests resolve to the correct output grid dimensions.

Changes:

  • Inject mm_processor_kwargs.target_h/target_w (and fallback height/width) into image generation/edit prompts in the OpenAI API server.
  • Align small preview token grid calculation in _parse_generated_tokens with GlmImageProcessor._build_prompt_with_target_shape.
  • Normalize img2img multimodal inputs to image within the GLM-Image AR data parser and improve dimension resolution in ar2diffusion.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Reviewed files:
  • vllm_omni/model_executor/stage_input_processors/glm_image.py: updates preview grid math and makes ar2diffusion prefer mm_processor_kwargs target dims with coercion + logging.
  • vllm_omni/model_executor/models/glm_image/glm_image_ar.py: normalizes the img2img modality key to image before parsing multimodal inputs.
  • vllm_omni/entrypoints/openai/api_server.py: passes requested output dimensions down to the AR stage via mm_processor_kwargs for generations and edits.



@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 23445c1481

Signed-off-by: JaredforReal <w13431838023@gmail.com>
@JaredforReal JaredforReal force-pushed the fix/glm branch 2 times, most recently from bc12b05 to 96c82da on March 30, 2026 12:16
@JaredforReal JaredforReal force-pushed the fix/glm branch 2 times, most recently from 21dca1d to 5008cda on April 16, 2026 10:01
@scyiwei1986

scyiwei1986 commented Apr 16, 2026

I used this MR on 0.18.0, but I got some weird output.
(attached image: big_output_e2e)

with the following changes:
/vllm-workspace/vllm-omni/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
44c44
< from vllm.inputs.data import MultiModalDataDict
---
> from vllm.inputs import MultiModalDataDict

I use offline mode to run the test.

python end2end.py --model-path /opt/data/weights/modelscope/GLM-Image --config-path ../config/glm_image_graph.yaml --prompt "a dog sitting on the table" --height 1024 --width 1920 --num-inference-steps 20 --output big_output_e2e.png > e2e.log 2>&1 &

glm_image_graph.yaml
# Stage config for running GLM-Image with 2-stage architecture
# Stage 0: AR Model (vLLM implementation) - generates prior_token_ids
# Stage 1: Diffusion (DiT + VAE) - denoising and image decoding

stage_args:
  # Stage 0: AR Model (GlmImageForConditionalGeneration)
  # This stage uses the vLLM-optimized AR model to generate prior tokens
  # for conditioning the diffusion process.
  - stage_id: 0
    stage_type: llm
    runtime:
      process: true
      devices: "0"
      requires_multimodal_data: true # Required for i2i mode to receive source images
    engine_args:
      model_stage: ar
      max_num_seqs: 1
      model_arch: GlmImageForConditionalGeneration
      model_subdir: vision_language_encoder # AR model config.json is in this subdirectory
      tokenizer_subdir: processor # Use processor's tokenizer (not ByT5 from tokenizer/)
      worker_cls: vllm_omni.platforms.npu.worker.npu_ar_worker.NPUARWorker
      #worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.7
      async_scheduling: true
      enforce_eager: false
      trust_remote_code: true
      compilation_config: 
        cudagraph_capture_sizes: [1,2]
        cudagraph_mode: "FULL_DECODE_ONLY"
      engine_output_type: token_ids # Output prior_token_ids for diffusion stage
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
    final_output: false # AR is not the final output
    is_comprehension: true
    default_sampling_params:
      temperature: 0.9 # From model's generation_config.json
      top_p: 0.75 # From model's generation_config.json
      top_k: 16512 # vision_vocab_size from generation_config.json
      max_tokens: 3601 # Large enough for 1920x1024 (needs 2401) and some margin
      # For reference: 1024x1024 needs 1281, 1920x1024 needs 2401, 2048x2048 needs 5121
      stop_token_ids: [16385] # eos_token_id from generation_config.json
      seed: 42
      detokenize: false

  # Stage 1: Diffusion (DiT + VAE)
  # This stage receives prior_token_ids from AR and performs denoising + VAE decode
  - stage_id: 1
    stage_type: diffusion
    runtime:
      process: true
      devices: "1" # Can use different GPU, or same GPU if memory allows
      requires_multimodal_data: true # Required for i2i mode to pass condition images
    engine_args:
      model_stage: dit
      max_num_seqs: 1
      model_arch: GlmImagePipeline # Required for diffusion model class resolution
      # Diffusion-specific parameters
      num_gpus: 1
      enforce_eager: true
      trust_remote_code: true
      distributed_executor_backend: "mp"
    engine_input_source: [0] # Input from AR stage
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion
    final_output: true
    final_output_type: image
    default_sampling_params:
      # Diffusion-specific parameters only (no LLM params like temperature/top_p/top_k)
      seed: 42
      num_inference_steps: 50
      guidance_scale: 1.5
      height: 1024
      width: 1024

# Top-level runtime config
runtime:
  enabled: true
  defaults:
    window_size: -1 # Trigger downstream only after full upstream completion
    max_inflight: 1 # Process serially within each stage

  edges:
    - from: 0 # AR → Diffusion: trigger after AR completes
      to: 1
      window_size: -1



@JaredforReal
Contributor Author

JaredforReal commented Apr 16, 2026

@scyiwei1986 don't set any max_tokens in the config.yaml; it will be calculated internally (see the sketch below).
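
For reference, a hedged sketch of how the AR token budget relates to the target size, reconstructed only from the comments in the yaml above (1024x1024 needs 1281, 1920x1024 needs 2401, 2048x2048 needs 5121); the patch size of 32 and the simple halved preview grid are assumptions that reproduce those numbers, not the actual internal computation:

def approx_max_tokens(height: int, width: int, patch: int = 32) -> int:
    # Large target grid + small preview grid + 1 EOS token.
    token_h, token_w = height // patch, width // patch
    large = token_h * token_w
    # Simple halving shown for illustration; the processor uses an
    # aspect-ratio-aware preview grid for non-square sizes.
    preview = (token_h // 2) * (token_w // 2)
    return large + preview + 1

# approx_max_tokens(1024, 1024) -> 1281
# approx_max_tokens(1024, 1920) -> 2401
# approx_max_tokens(2048, 2048) -> 5121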

And I tried your prompt:

curl -s http://172.18.67.228:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "a dog sitting on the table"}    
    ],
    "extra_body": {
      "height": 1024,
      "width": 1920,
      "num_inference_steps": 20,
      "true_cfg_scale": 1.5,
      "seed": 42
    }
  }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > dog.png

I got:
(attached image: dog.png)

And offline with end2end.py:
CUDA_VISIBLE_DEVICES=1,2 python end2end.py --model-path /workspace/GLM-Image --prompt "a dog sitting on the table" --height 1024 --width 1924 --output dog.png --num-inference-steps 20

I got:
(attached image)

Works fine on an H100 GPU, but I don't have any NPU machine, so we might need help from the community.

@hsliuustc0106 hsliuustc0106 added the merge-test label (to trigger buildkite merge test CI) Apr 16, 2026
Signed-off-by: JaredforReal <w13431838023@gmail.com>
@hsliuustc0106 hsliuustc0106 removed the merge-test label (to trigger buildkite merge test CI) Apr 20, 2026
Signed-off-by: Jared Wen <w13431838023@gmail.com>
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


lgtm

@hsliuustc0106 hsliuustc0106 enabled auto-merge (squash) April 20, 2026 10:49
@hsliuustc0106 hsliuustc0106 disabled auto-merge April 20, 2026 11:24
@hsliuustc0106 hsliuustc0106 merged commit 7e28eda into vllm-project:main Apr 20, 2026
6 of 8 checks passed
nainiu258 pushed a commit to nainiu258/vllm-omni that referenced this pull request Apr 21, 2026
…lm-project#2320)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: Jared Wen <w13431838023@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Signed-off-by: nainiu258 <cperfect02@163.com>
qinganrice pushed a commit to qinganrice/vllm-omni that referenced this pull request Apr 23, 2026
…lm-project#2320)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: Jared Wen <w13431838023@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
ptarasiewiczNV added a commit to ptarasiewiczNV/vllm-omni that referenced this pull request Apr 23, 2026
PR vllm-project#2320 (`7e28eda9`) dropped `max_tokens: 1281` from the legacy
GLM-Image stage config and moved the compute into
`serving_chat._apply_request_overrides`, but gated it on
`height is not None and width is not None`. For the recipe's bare-curl
request (no `extra_body.height` / `extra_body.width`) the gate skipped
the compute, `SamplingParams.max_tokens` fell through to vLLM's
`max_model_len - seq_len` (~131k), and the AR stage's generation
budget no longer matched the VQ token layout the parser expects —
leaving the pre-refactor path latently broken since vllm-project#2320 and
surfacing as the IndexError the deploy-yaml edit in PR vllm-project#3034 was
working around.

Drop the `height is not None and width is not None` gate and default
`target_h`/`target_w` to `1024` (the stage-1 yaml default) so
`compute_max_tokens` always produces the correct budget, whether the
user overrides one dimension, both, or none. Gate the whole block on
`"GlmImageForConditionalGeneration" in model_config.hf_config.architectures`
so other AR-based image models (Bagel, HunyuanImage-3, …) — which
have their own `max_tokens` in their yamls and different token layouts
— are unaffected.

Also fix the related `getattr(explicit_fields, "max_tokens", None)`
bug: `explicit_fields` is a `set[str]` (Pydantic's `model_fields_set`),
so the getattr always returned `None` and the user's explicit
`max_tokens` was silently overwritten. Replace with proper set
membership check.

Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
ptarasiewiczNV added a commit to ptarasiewiczNV/vllm-omni that referenced this pull request Apr 23, 2026
PR vllm-project#2320 (`7e28eda9`) dropped `max_tokens: 1281` from the GLM-Image
stage config and moved the compute into
`serving_chat._apply_request_overrides`, but gated it on
`height is not None and width is not None`. For the recipe's bare-curl
request (no `extra_body.height` / `extra_body.width`) the gate skipped
the compute; `SamplingParams.max_tokens` then fell through to vLLM's
`max_model_len - seq_len` (~131k) and the AR stage's generation
budget no longer matched the VQ token layout the parser expects,
leaving the pre-refactor path latently broken since vllm-project#2320 and
surfacing as the IndexError the deploy-yaml edit in vllm-project#3034 was
working around.

Fix: when the user didn't pass h/w, fall back to the diffusion stage's
default h/w (GLM-Image stage-1 yaml already declares
`height: 1024, width: 1024`), rather than hardcoding a second size
default in serving_chat or re-adding the yaml entry. This makes the
compute effectively unconditional for AR + image-diffusion pipelines
that declare a target size in their sampling params; LLM-only and
audio pipelines have neither height nor width in any stage's params
and continue to skip the block — no architecture gate needed.

Also fix a related bug: `getattr(explicit_fields, "max_tokens", None)`
was reading an attribute off a `set[str]` (Pydantic's
`model_fields_set`), so it always returned `None` and silently
overwrote user-provided `max_tokens`. Replaced with a proper set
membership check.

Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
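
A minimal sketch of the model_fields_set fix described in the two commit messages above; ChatRequest and the surrounding names are placeholders, not the actual serving_chat code:

from pydantic import BaseModel

class ChatRequest(BaseModel):
    # Stand-in for the real request model; only the field relevant to the bug.
    max_tokens: int | None = None

request = ChatRequest(max_tokens=64)
explicit_fields: set[str] = request.model_fields_set  # Pydantic v2: fields the user actually sent

# Buggy: getattr on a set[str] never finds a "max_tokens" attribute, so this is
# always None and a user-provided max_tokens gets silently overwritten downstream.
buggy = getattr(explicit_fields, "max_tokens", None)

# Fixed: check membership in the set, then read the field off the request itself.
fixed = request.max_tokens if "max_tokens" in explicit_fields else None

assert buggy is None and fixed == 64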
ptarasiewiczNV added a commit to ptarasiewiczNV/vllm-omni that referenced this pull request Apr 23, 2026
Cosmetic: restore the two-line `ref_image_count = len(reference_images)`
/ `is_img2img = ref_image_count > 0` shape from the pre-vllm-project#2320 code to
keep the diff against main smaller and match the surrounding style.

Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
ptarasiewiczNV added a commit to ptarasiewiczNV/vllm-omni that referenced this pull request Apr 23, 2026
Match the upstream pre-vllm-project#2320 intent: the AR `max_tokens` is a function
of the target h/w (small-preview + large-target + EOS); a
user-supplied `max_tokens` can only mismatch the VQ token layout the
parser expects. Explicit `"max_tokens": null` on the request also
lands here, and the field-copy loop drops None values, so presence-
based gating would leave `params.max_tokens` unset. Restoring the
simple "always compute" shape avoids both edge cases.

Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
…lm-project#2320)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: Jared Wen <w13431838023@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…lm-project#2320)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Signed-off-by: Jared Wen <w13431838023@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>

Labels

ready (label to trigger buildkite CI)


5 participants