Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ def parse_args():
help="Path to input image for img2img.",
)

# Omni runtime init args
parser.add_argument(
"--output",
type=str,
default=".",
help="Output directory to save images.",
)

# OmniLLM init args
parser.add_argument("--log-stats", action="store_true", default=False)
parser.add_argument("--init-sleep-seconds", type=int, default=20)
parser.add_argument("--batch-timeout", type=int, default=5)
Expand Down Expand Up @@ -65,6 +72,7 @@ def parse_args():

def main():
args = parse_args()
os.makedirs(args.output, exist_ok=True)
model_name = args.model
prompts: list[OmniPromptType] = []
try:
Expand Down Expand Up @@ -166,13 +174,13 @@ def main():
img_idx = 0
for req_output in omni_outputs:
images = getattr(req_output, "images", None)

if not images:
continue

for j, img in enumerate(images):
save_path = f"output_{img_idx}_{j}.png"
save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png")
img.save(save_path)
print(f"[Info] Saved image to {save_path}")
img_idx += 1

print(omni_outputs)
Expand Down
5 changes: 3 additions & 2 deletions vllm_omni/model_executor/models/bagel/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from vllm.transformers_utils.processors.bagel import BagelProcessor

from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.models.bagel.autoencoder import (
AutoEncoderParams,
DiagonalGaussian,
Expand Down Expand Up @@ -439,7 +440,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))

self._vae_token_mask: torch.Tensor | None = None

self.device = get_local_device()
self._install_mot_modules(config)

def _install_mot_modules(self, config):
Expand Down Expand Up @@ -627,7 +628,7 @@ def _process_img2img_input(self, multimodal_input):
)
pos_embed = self.latent_pos_embed([vae_position_ids])
packed_timesteps = torch.tensor([timestep], device=padded_latent.device)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with torch.amp.autocast(self.device.type, dtype=torch.bfloat16):
timestep_embeds = self.time_embedder(packed_timesteps.to(padded_latent))
vae_embeds = self.vae2llm(latent) + timestep_embeds + pos_embed

Expand Down
86 changes: 86 additions & 0 deletions vllm_omni/platforms/xpu/stage_configs/bagel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# stage config for running bagel-7b-mot with architecture of OmniLLM.

stage_args:
- stage_id: 0
stage_type: llm
prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
runtime:
devices: "0"
# 3 = 1 user prompt + 2 CFG companions (text-unconditional + image-unconditional).
max_batch_size: 3
engine_args:
model_stage: thinker
model_arch: OmniBagelForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.9
enforce_eager: true
trust_remote_code: true
engine_output_type: text
distributed_executor_backend: "mp"
enable_prefix_caching: false
max_num_batched_tokens: 16384
tensor_parallel_size: 1
quantization: fp8
omni_kv_config:
need_send_cache: true
kv_transfer_criteria:
type: prefill_finished #or special token generated
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 2048
seed: 52
detokenize: True
repetition_penalty: 1.05

- stage_id: 1
stage_type: diffusion
cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
runtime:
devices: "1"
max_batch_size: 1
engine_args:
model_stage: dit
gpu_memory_utilization: 0.9
enforce_eager: true
trust_remote_code: true
engine_output_type: image
distributed_executor_backend: "mp"
enable_prefix_caching: false
max_num_batched_tokens: 32768
tensor_parallel_size: 1
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]

final_output: true
final_output_type: image
is_comprehension: false
default_sampling_params:
seed: 52

# Runtime edges
runtime:
enabled: true
defaults:
window_size: -1
max_inflight: 1

# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
shared_memory_connector:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536 # 64KB threshold


edges:
- from: 0
to: 1
window_size: -1
Loading