Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 71 additions & 99 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import argparse
import os
from typing import cast

from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType
from vllm_omni.inputs.data import OmniPromptType


def parse_args():
Expand Down Expand Up @@ -58,6 +57,7 @@ def parse_args():
choices=[1, 2, 3],
help="CFG parallel size: 1=batched (single GPU), 2=parallel with 2 branches (text CFG only), 3=parallel (3 GPUs).",
)
parser.add_argument("--seed", type=int, default=None, help="Random seed for generation.")

args = parser.parse_args()
return args
Expand Down Expand Up @@ -88,108 +88,80 @@ def main():

from PIL import Image

if args.modality == "img2img":
from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion

print(f"[Info] Running in {args.modality} mode (Stage 1 only, cfg_parallel_size={args.cfg_parallel_size})")

client = OmniDiffusion(
model=model_name,
parallel_config={"cfg_parallel_size": args.cfg_parallel_size},
)
from vllm_omni.entrypoints.omni import Omni

omni_kwargs = {}
if args.stage_configs_path:
omni_kwargs["stage_configs_path"] = args.stage_configs_path

omni_kwargs.update(
{
"log_stats": args.log_stats,
"init_sleep_seconds": args.init_sleep_seconds,
"batch_timeout": args.batch_timeout,
"init_timeout": args.init_timeout,
"shm_threshold_bytes": args.shm_threshold_bytes,
"worker_backend": args.worker_backend,
"ray_address": args.ray_address,
}
)

if args.image_path:
if os.path.exists(args.image_path):
omni = Omni(model=model_name, **omni_kwargs)

formatted_prompts = []
for p in prompts:
if args.modality == "img2img":
if not args.image_path or not os.path.exists(args.image_path):
raise ValueError(f"img2img requires --image-path pointing to an existing file, got: {args.image_path}")
loaded_image = Image.open(args.image_path).convert("RGB")
final_prompt_text = f"<|fim_middle|><|im_start|>{p}<|im_end|>"
prompt_dict = {
"prompt": final_prompt_text,
"multi_modal_data": {"img2img": loaded_image},
"modalities": ["img2img"],
}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
formatted_prompts.append(prompt_dict)
elif args.modality == "img2text":
if args.image_path:
loaded_image = Image.open(args.image_path).convert("RGB")
prompts = [
{
"prompt": cast(str, p),
"multi_modal_data": {"image": loaded_image},
}
for p in prompts
]
else:
print(f"[Warning] Image path {args.image_path} does not exist.")

result = client.generate(
prompts,
OmniDiffusionSamplingParams(
seed=52,
need_kv_receive=False,
num_inference_steps=args.steps,
extra_args={
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
},
),
)

# Ensure result is a list for iteration
if not isinstance(result, list):
omni_outputs = [result]
final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {
"prompt": final_prompt_text,
"multi_modal_data": {"image": loaded_image},
"modalities": ["text"],
}
formatted_prompts.append(prompt_dict)
elif args.modality == "text2text":
final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
formatted_prompts.append(prompt_dict)
else:
omni_outputs = result

else:
from vllm_omni.entrypoints.omni import Omni

omni_kwargs = {}
if args.stage_configs_path:
omni_kwargs["stage_configs_path"] = args.stage_configs_path

omni_kwargs.update(
{
"log_stats": args.log_stats,
"init_sleep_seconds": args.init_sleep_seconds,
"batch_timeout": args.batch_timeout,
"init_timeout": args.init_timeout,
"shm_threshold_bytes": args.shm_threshold_bytes,
"worker_backend": args.worker_backend,
"ray_address": args.ray_address,
final_prompt_text = f"<|im_start|>{p}<|im_end|>"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
formatted_prompts.append(prompt_dict)

params_list = omni.default_sampling_params_list
if args.modality in ("text2img", "img2img"):
params_list[0].max_tokens = 1 # type: ignore
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
if args.seed is not None:
diffusion_params.seed = args.seed # type: ignore
extra = {
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
)

omni = Omni(model=model_name, **omni_kwargs)

formatted_prompts = []
for p in args.prompts:
if args.modality == "img2text":
if args.image_path:
loaded_image = Image.open(args.image_path).convert("RGB")
final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {
"prompt": final_prompt_text,
"multi_modal_data": {"image": loaded_image},
"modalities": ["text"],
}
formatted_prompts.append(prompt_dict)
elif args.modality == "text2text":
final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
formatted_prompts.append(prompt_dict)
else:
# text2img
final_prompt_text = f"<|im_start|>{p}<|im_end|>"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
formatted_prompts.append(prompt_dict)

params_list = omni.default_sampling_params_list
if args.modality == "text2img":
params_list[0].max_tokens = 1 # type: ignore # The first stage is a SamplingParam (vllm)
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
extra = {
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore

omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

for i, req_output in enumerate(omni_outputs):
images = getattr(req_output, "images", None)
Expand Down
4 changes: 2 additions & 2 deletions tests/distributed/omni_connectors/test_kv_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
OmniKVCacheConfig,
OmniKVTransferManager,
)
from vllm_omni.distributed.omni_connectors.utils.kv_utils import normalize_layer_kv
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.cache]
Expand Down Expand Up @@ -192,8 +193,7 @@ def test_normalize_layer_kv_rejects_invalid_inputs(kv_config, common_constants,
else:
layer_kv = (torch.randn(2, block_size, num_heads, head_dim), "not-a-tensor")

manager = OmniKVTransferManager(kv_config)
normalized = manager._normalize_layer_kv(layer_kv, req_id=req_id, layer_idx=0)
normalized = normalize_layer_kv(layer_kv, req_id=req_id, layer_idx=0)
assert normalized is None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ stage_args:
model_arch: BagelForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.35
gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
engine_output_type: text
Expand Down Expand Up @@ -46,7 +46,7 @@ stage_args:
max_batch_size: 1
engine_args:
model_stage: dit
gpu_memory_utilization: 0.55
gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
engine_output_type: image
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ stage_args:
model_arch: BagelForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.35
gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
engine_output_type: text
Expand Down Expand Up @@ -45,7 +45,7 @@ stage_args:
max_batch_size: 1
engine_args:
model_stage: dit
gpu_memory_utilization: 0.55
gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
engine_output_type: image
Expand Down
3 changes: 3 additions & 0 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
else:
gen_context["ropes"] = [seq_len]

if req.sampling_params.kv_metadata and "image_shape" in req.sampling_params.kv_metadata:
image_shape = tuple(req.sampling_params.kv_metadata["image_shape"])

cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
if cfg_text_kv is not None:
logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
Expand Down
62 changes: 12 additions & 50 deletions vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .factory import OmniConnectorFactory
from .utils.config import ConnectorSpec
from .utils.kv_utils import normalize_layer_kv

logger = init_logger(__name__)

Expand Down Expand Up @@ -200,8 +201,12 @@ def handle_finished_requests_kv_transfer(
logger.warning(f"Request {req_id} has no block IDs, skipping")
continue

custom_metadata = data.get("custom_metadata")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to introduce `custom_metadata`? Is there any documentation that explains this argument? Otherwise, this may lead to a poor user/developer experience.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we need to transfer the RoPE and image shape from the AR stage to the DiT stage.


# Extract KV cache from GPU blocks -> CPU tensors
kv_data = self._extract_kv_cache(req_id, block_ids, seq_len, kv_caches, block_size, cache_dtype)
kv_data = self._extract_kv_cache(
req_id, block_ids, seq_len, kv_caches, block_size, cache_dtype, custom_metadata
)
if kv_data:
# Resolve global request ID if available
transfer_req_id = request_id_resolver(req_id) if request_id_resolver else req_id
Expand All @@ -224,6 +229,7 @@ def _extract_kv_cache(
kv_caches: list[LayerKV],
block_size: int,
cache_dtype: str,
custom_metadata: dict[str, Any] | None = None,
) -> KVCacheTransferData | None:
"""Extract KV cache from GPU blocks for a single request.

Expand All @@ -234,6 +240,7 @@ def _extract_kv_cache(
kv_caches: List of KV cache (tensor or tuple) per layer
block_size: Size of each cache block
cache_dtype: Data type of the cache
custom_metadata: Optional custom metadata to include

Note: If key/value block counts differ, extraction uses only the overlapping
block range. Extra key/value blocks are ignored, so returned KV may be partial.
Expand All @@ -246,7 +253,7 @@ def _extract_kv_cache(
value_cache: list[torch.Tensor | None] = [None] * num_layers

for layer_idx, layer_kv in enumerate(kv_caches):
kv_pair = self._normalize_layer_kv(layer_kv, req_id=req_id, layer_idx=layer_idx)
kv_pair = normalize_layer_kv(layer_kv, req_id=req_id, layer_idx=layer_idx)
if kv_pair is None:
continue
key_blocks, value_blocks = kv_pair
Expand Down Expand Up @@ -289,57 +296,10 @@ def _extract_kv_cache(
"num_layers": num_layers,
"dtype": str(cache_dtype),
"seq_len": seq_len,
**(custom_metadata or {}),
},
)

def _normalize_layer_kv(
self,
layer_kv: LayerKV,
req_id: str,
layer_idx: int,
) -> tuple[torch.Tensor, torch.Tensor] | None:
"""Normalize one layer KV cache to a `(key_blocks, value_blocks)` tuple.

Args:
layer_kv: The raw KV cache (tensor or tuple) for the layer
req_id: Request ID for logging
layer_idx: Layer index for logging

Returns:
Tuple of (key_blocks, value_blocks) if valid, None otherwise
"""
if isinstance(layer_kv, torch.Tensor):
if layer_kv.ndim < 3 or layer_kv.shape[0] != 2:
logger.warning(
f"Layer {layer_idx} for request {req_id} has invalid stacked KV shape: "
f"expected [2, blocks, block_size, ...], got {tuple(layer_kv.shape)}"
)
return None
key_blocks = layer_kv[0]
value_blocks = layer_kv[1]
elif isinstance(layer_kv, tuple):
if len(layer_kv) != 2:
logger.warning(
f"Layer {layer_idx} for request {req_id} has KV pair length {len(layer_kv)} (expected 2)"
)
return None
key_blocks, value_blocks = layer_kv
if not isinstance(key_blocks, torch.Tensor) or not isinstance(value_blocks, torch.Tensor):
logger.warning(f"Layer {layer_idx} for request {req_id} has non-tensor KV pair entries")
return None
else:
logger.warning(f"Layer {layer_idx} for request {req_id} has unsupported KV type {type(layer_kv).__name__}")
return None
# ensure key/value blocks are at least 2D for block indexing
if key_blocks.ndim < 2 or value_blocks.ndim < 2:
logger.warning(
f"Layer {layer_idx} for request {req_id} has invalid KV block shape: "
f"got key={tuple(key_blocks.shape)} value={tuple(value_blocks.shape)}"
)
return None

return key_blocks, value_blocks

def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str) -> None:
"""Transfer KV cache data to downstream stage via OmniConnector.

Expand Down Expand Up @@ -496,6 +456,8 @@ def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None:

if "metadata" in data:
req.kv_metadata = data["metadata"]
if hasattr(req, "sampling_params") and req.sampling_params is not None:
req.sampling_params.kv_metadata = data["metadata"]

# Legacy compatibility method
def receive_kv_cache(self, req: Any, target_device: torch.device | None = None) -> bool:
Expand Down
Loading