Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm_omni.model_executor.stage_input_processors.glm_image import (
_parse_generated_tokens,
_upsample_token_ids,
ar2diffusion,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def _ar_output(
token_ids: list[int],
*,
multimodal_output: dict | None = None,
):
return SimpleNamespace(
outputs=[SimpleNamespace(token_ids=token_ids)],
multimodal_output=multimodal_output,
)


def test_upsample_token_ids_matches_nearest_neighbor_layout():
    """A 2x2 token grid must expand so each token fills a 2x2 patch."""
    source = torch.tensor([1, 2, 3, 4], dtype=torch.long)
    # Expected 4x4 layout, written as a grid and flattened row-major.
    want = torch.tensor(
        [
            [1, 1, 2, 2],
            [1, 1, 2, 2],
            [3, 3, 4, 4],
            [3, 3, 4, 4],
        ],
        dtype=torch.long,
    ).reshape(-1)

    got = _upsample_token_ids(source, token_h=2, token_w=2)

    torch.testing.assert_close(got, want)


def test_ar2diffusion_builds_upsampled_prior_tokens_for_t2i():
    """Text-to-image request: prompt fields are kept and prior tokens are upsampled.

    Five tokens (10..14) go in; the expected prior is the 2x upsampling of
    tokens 11..14 only.
    """
    ar_result = _ar_output([10, 11, 12, 13, 14], multimodal_output=None)
    stages = [SimpleNamespace(engine_outputs=[ar_result])]

    outputs = ar2diffusion(
        stage_list=stages,
        engine_input_source=[0],
        prompt=[{"prompt": "hello", "height": 64, "width": 64}],
    )

    assert len(outputs) == 1
    first = outputs[0]
    for key, value in (("prompt", "hello"), ("height", 64), ("width", 64)):
        assert first[key] == value

    want = torch.tensor(
        [
            [11, 11, 12, 12],
            [11, 11, 12, 12],
            [13, 13, 14, 14],
            [13, 13, 14, 14],
        ],
        dtype=torch.long,
    ).reshape(-1)
    torch.testing.assert_close(first["extra"]["prior_token_ids"], want)


def test_ar2diffusion_normalizes_serialized_prior_token_image_ids():
    """Nested-list ``prior_token_image_ids`` must come back as a list of tensors."""
    serialized = {"prior_token_image_ids": [[101, 102, 103, 104]]}
    ar_result = _ar_output([21, 22, 23, 24], multimodal_output=serialized)
    stages = [SimpleNamespace(engine_outputs=[ar_result])]
    request = {
        "prompt": "edit",
        "height": 64,
        "width": 64,
        # Opaque placeholder standing in for the source image.
        "multi_modal_data": {"image": object()},
    }

    outputs = ar2diffusion(
        stage_list=stages,
        engine_input_source=[0],
        prompt=[request],
        requires_multimodal_data=True,
    )

    first = outputs[0]
    image_ids = first["extra"]["prior_token_image_ids"]
    assert isinstance(image_ids, list)
    assert len(image_ids) == 1
    assert isinstance(image_ids[0], torch.Tensor)
    want = torch.tensor([101, 102, 103, 104], dtype=torch.long)
    torch.testing.assert_close(image_ids[0], want)
    assert "pil_image" in first


def test_ar2diffusion_uses_i2i_large_tokens_without_preview_prefix():
    """Image-to-image request: prior comes from tokens 31..34, image IDs pass through.

    The AR output holds five tokens (31..34 followed by 16385); the expected
    prior is the 2x upsampling of the first four only.
    """
    source_image_ids = torch.tensor([201, 202, 203, 204], dtype=torch.long)
    ar_result = _ar_output(
        [31, 32, 33, 34, 16385],
        multimodal_output={"prior_token_image_ids": [source_image_ids]},
    )
    stages = [SimpleNamespace(engine_outputs=[ar_result])]

    outputs = ar2diffusion(
        stage_list=stages,
        engine_input_source=[0],
        prompt=[{"prompt": "edit", "height": 64, "width": 64}],
    )

    extra = outputs[0]["extra"]
    want_prior = torch.tensor(
        [
            [31, 31, 32, 32],
            [31, 31, 32, 32],
            [33, 33, 34, 34],
            [33, 33, 34, 34],
        ],
        dtype=torch.long,
    ).reshape(-1)
    torch.testing.assert_close(extra["prior_token_ids"], want_prior)
    torch.testing.assert_close(
        extra["prior_token_image_ids"][0],
        torch.tensor([201, 202, 203, 204], dtype=torch.long),
    )


def test_ar2diffusion_reads_prior_token_image_ids_from_completion_output_fallback():
    """Image IDs attached to the completion output are used when the request-level payload is None."""
    image_ids = torch.tensor([301, 302, 303, 304], dtype=torch.long)
    # The multimodal payload lives on the inner completion output here.
    completion = SimpleNamespace(
        token_ids=[41, 42, 43, 44],
        multimodal_output={"prior_token_image_ids": [image_ids]},
    )
    request_output = SimpleNamespace(outputs=[completion], multimodal_output=None)
    stages = [SimpleNamespace(engine_outputs=[request_output])]

    outputs = ar2diffusion(
        stage_list=stages,
        engine_input_source=[0],
        prompt=[{"prompt": "fallback", "height": 64, "width": 64}],
    )

    want = torch.tensor([301, 302, 303, 304], dtype=torch.long)
    torch.testing.assert_close(outputs[0]["extra"]["prior_token_image_ids"][0], want)


def test_parse_generated_tokens_adjusts_grid_for_truncated_output():
    """Four tokens against a 128x128 request must yield a 64x64 size and a 4x4 prior."""
    parsed = _parse_generated_tokens([51, 52, 53, 54], height=128, width=128)
    prior_token_ids, pixel_h, pixel_w = parsed

    assert (pixel_h, pixel_w) == (64, 64)

    want = torch.tensor(
        [
            [51, 51, 52, 52],
            [51, 51, 52, 52],
            [53, 53, 54, 54],
            [53, 53, 54, 54],
        ],
        dtype=torch.long,
    ).reshape(-1)
    torch.testing.assert_close(prior_token_ids, want)
18 changes: 10 additions & 8 deletions vllm_omni/model_executor/models/glm_image/glm_image_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ def get_image_size_with_most_features(self) -> tuple[int, int]:
return (image_size, image_size)


def _upsample_token_grid_nearest(token_ids: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
    """Upsample a 2D token grid by 2x using integer nearest-neighbor.

    Args:
        token_ids: Tensor holding ``grid_h * grid_w`` token IDs; it is read
            in row-major order.
        grid_h: Number of rows in the token grid.
        grid_w: Number of columns in the token grid.

    Returns:
        A 1D tensor with ``4 * grid_h * grid_w`` entries in which every
        source token is repeated over a 2x2 patch. The dtype is unchanged
        because no float round-trip is involved.
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed tensors; it stays zero-copy whenever the strides permit.
    token_grid = token_ids.reshape(grid_h, grid_w)
    # Duplicate each row, then each column, to get the 2x nearest-neighbor grid.
    token_grid = token_grid.repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
    return token_grid.reshape(-1)


class GlmImageDummyInputsBuilder(BaseDummyInputsBuilder[GlmImageProcessingInfo]):
"""
Builds dummy inputs for GLM-Image model profiling.
Expand Down Expand Up @@ -2231,10 +2238,7 @@ def forward(
for i, tokens in enumerate(image_tokens_list):
grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
# Reshape to 2D grid
tokens_2d = tokens.view(1, 1, grid_h, grid_w)
# Upsample by 2x (nearest neighbor)
tokens_upsampled = F.interpolate(tokens_2d.float(), scale_factor=2, mode="nearest").to(dtype=torch.long)
upsampled_token_ids.append(tokens_upsampled.view(-1))
upsampled_token_ids.append(_upsample_token_grid_nearest(tokens, grid_h, grid_w))

prior_token_image_ids_info = {
"prior_token_image_ids": upsampled_token_ids,
Expand Down Expand Up @@ -2439,11 +2443,9 @@ def _process_image_input(
for i, tokens in enumerate(image_tokens_list):
grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
# Reshape to 2D grid
tokens_2d = tokens.view(1, 1, grid_h, grid_w)
# Upsample by 2x (nearest neighbor)
tokens_upsampled = F.interpolate(tokens_2d.float(), scale_factor=2, mode="nearest").to(dtype=torch.long)
tokens_upsampled = _upsample_token_grid_nearest(tokens, grid_h, grid_w)
# Keep as CPU tensor for proper serialization through pooling_output
upsampled_token_ids.append(tokens_upsampled.view(-1).detach().cpu().contiguous())
upsampled_token_ids.append(tokens_upsampled.detach().cpu().contiguous())

# Note: We only include prior_token_image_ids in the info dict.
# image_grid_thw is NOT included because:
Expand Down
10 changes: 5 additions & 5 deletions vllm_omni/model_executor/stage_input_processors/glm_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) ->
Returns:
Upsampled token IDs of shape [num_tokens * 4]
"""
token_ids = token_ids.view(1, 1, token_h, token_w)
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(dtype=torch.long)
token_ids = token_ids.view(-1)
return token_ids
token_grid = token_ids.view(token_h, token_w)
# Integer nearest-neighbor upsampling avoids the float cast/interpolate
# overhead in the AR -> diffusion bridge.
token_grid = token_grid.repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
return token_grid.reshape(-1)


def _parse_generated_tokens(
Expand Down Expand Up @@ -261,5 +262,4 @@ def ar2diffusion(
diffusion_input[key] = original_prompt[key]

diffusion_inputs.append(diffusion_input)

return diffusion_inputs
Loading