@@ -18,6 +18,7 @@
    construct_mm_data,
    vLLMMultimodalRequest,
)
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model

logger = logging.getLogger(__name__)

@@ -63,10 +64,25 @@ async def generate(self, request: vLLMMultimodalRequest, context):
        request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")

        # The decode worker does not run the vision encoder, so it normally
        # passes no embeddings (None or an empty tensor). For Qwen VL models
        # with mRoPE, however, we must pass multi_modal_data containing
        # image_grid_thw so vLLM can compute the position embeddings. The
        # decode worker receives the ORIGINAL unexpanded prompt (with
        # placeholders), and vLLM expands it using the multi_modal_data,
        # ensuring the block count matches what prefill computed.
        #
        # We pass unique placeholder embeddings (seeded by request_id) since
        # the actual embeddings are already in the KV cache from prefill. The
        # unique values prevent incorrect prefix-cache matches between
        # different images.
        multi_modal_data = None
        if is_qwen_vl_model(self.config.model):
            multi_modal_data = construct_qwen_decode_mm_data(
                request.image_grid_thw, request.embeddings_shape, request.request_id
            )

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=request.sampling_params,
            request_id=request.request_id,
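
Note: the `is_qwen_vl_model` helper imported at the top of this file lives in `multimodal_utils/model.py` and is not shown in this diff. A minimal sketch of what such a check might look like, assuming it keys off the model name (the real helper may inspect the HF config instead):

# Hypothetical sketch only; the actual implementation is not part of this diff.
def is_qwen_vl_model(model_name: str) -> bool:
    """Heuristically detect Qwen VL variants (e.g. Qwen2-VL, Qwen2.5-VL)."""
    name = model_name.lower()
    return "qwen" in name and "vl" in name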
@@ -222,12 +238,17 @@ async def generate(self, request: vLLMMultimodalRequest, context):
        if self.enable_disagg and self.decode_worker_client:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # Update the prompt token IDs in the decode request to the ones
                # in the response, which have the image placeholders expanded,
                # so that the decode worker fetches the correct number of KV
                # blocks.
                decode_request.engine_prompt[
                    "prompt_token_ids"
                ] = prefill_response.prompt_token_ids
                # For Qwen VL models with mRoPE: keep the ORIGINAL unexpanded
                # prompt. The decode worker will pass multi_modal_data, which
                # causes vLLM to expand the prompt identically to prefill,
                # ensuring block counts match.
                #
                # For other models: use the expanded prompt from the prefill
                # response. These models don't pass multi_modal_data in decode,
                # so they need the already-expanded prompt to match the KV
                # cache layout.
                if not is_qwen_vl_model(self.config.model):
                    decode_request.engine_prompt[
                        "prompt_token_ids"
                    ] = prefill_response.prompt_token_ids
                logger.debug(
                    f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
                )
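For intuition on why the expanded lengths match: with Qwen2-VL-style models, the number of placeholder tokens inserted per image is derived purely from `image_grid_thw`, so prefill and decode expand the prompt to the same length whenever they see the same grid. A rough sketch of the arithmetic, assuming the usual spatial merge size of 2 (illustrative only, not part of this diff):

# Illustrative sketch: how many tokens one image placeholder expands into
# for a Qwen2-VL-style model, given its grid (t, h, w) in vision patches.
def expected_image_tokens(grid_thw: tuple[int, int, int], merge_size: int = 2) -> int:
    t, h, w = grid_thw
    return (t * h * w) // (merge_size**2)

# e.g. a (1, 98, 146) grid yields 3577 image tokens, so prefill and decode
# compute the same prompt length and therefore the same KV block count.
assert expected_image_tokens((1, 98, 146)) == 3577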
58 changes: 58 additions & 0 deletions components/src/dynamo/vllm/multimodal_utils/model.py
@@ -177,3 +177,61 @@ def _construct_qwen_image_data(
"image_grid_thw": grid_thw_tensor,
}
}


def construct_qwen_decode_mm_data(
    image_grid_thw: Optional[List[Any]],
    embeddings_shape: Optional[Any],
    request_id: str,
    *,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Construct schema-valid Qwen multimodal data for vLLM v1 disagg decode.

This is a WORKAROUND (WAR) for vLLM's disaggregated multimodal decode limitations.

Notes:
- vLLM parses multimodal inputs and builds `mm_features` from `multi_modal_data`.
- For Qwen VL models, the parser enforces that image data contains BOTH
`image_embeds` and `image_grid_thw` keys.
- In disaggregated decode, the KV cache already includes the vision context
from prefill; decode still needs `mm_features` for mRoPE initialization.

WAR Details:
- We generate unique placeholder embeddings based on request_id to prevent
incorrect prefix cache matches between different images with same dimensions.
- Without this, zero embeddings + same image_grid_thw would create identical
cache signatures, causing decode to incorrectly reuse cached KV from
different images.

Caching Caveat:
- This WAR disables prefix cache reuse on the DECODE worker (each request
has unique placeholder embeddings).
- Prefix caching still works correctly on the PREFILL worker, which uses
actual image embeddings. This is where the caching benefit matters since
prefill does the heavy computation.
- Decode receives KV blocks from prefill via NIXL transfer anyway, so
decode-side prefix caching provides minimal benefit in disaggregated setup.
"""
    if image_grid_thw is None or len(image_grid_thw) == 0:
        raise ValueError("No image grid provided for Qwen model.")
    if embeddings_shape is None:
        raise ValueError("embeddings_shape is required for Qwen decode mm data.")

    # WAR: Use a request_id-derived seed for unique placeholder values. This
    # prevents the prefix cache from incorrectly matching different images
    # that happen to have the same dimensions (same image_grid_thw). Note that
    # Python string hashing is salted per process (PYTHONHASHSEED), so the
    # seed is only stable within a single worker process; that is sufficient
    # here because prefix caching is per-engine.
    seed = hash(request_id) & 0xFFFFFFFF  # Mask to an unsigned 32-bit value.
    generator = torch.Generator().manual_seed(seed)
    image_embeds = torch.randn(
        embeddings_shape, dtype=dtype, device="cpu", generator=generator
    )
    if image_embeds.ndim == 3:
        image_embeds = image_embeds.squeeze(0)

    return {
        "image": {
            "image_embeds": image_embeds,
            "image_grid_thw": torch.tensor(image_grid_thw),
        }
    }
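
A usage sketch of the new helper (the grid, shape, and request IDs below are made up). Within one process, the same `request_id` reproduces identical placeholder embeddings, while different request IDs produce different ones, which is what keeps distinct images from sharing a prefix-cache signature:

import torch

grid = [[1, 98, 146]]   # hypothetical image_grid_thw for one image
shape = (3577, 3584)    # hypothetical (num_image_tokens, hidden_size)

mm_a = construct_qwen_decode_mm_data(grid, shape, "req-a")
mm_a_again = construct_qwen_decode_mm_data(grid, shape, "req-a")
mm_b = construct_qwen_decode_mm_data(grid, shape, "req-b")

# Deterministic per request_id (within one process), unique across requests.
assert torch.equal(mm_a["image"]["image_embeds"], mm_a_again["image"]["image_embeds"])
assert not torch.equal(mm_a["image"]["image_embeds"], mm_b["image"]["image_embeds"])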
2 changes: 1 addition & 1 deletion examples/backends/vllm/launch/disagg_multimodal_epd.sh
@@ -93,7 +93,7 @@ CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-wo
# Start decode worker
echo "Starting decode worker on GPU 2..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &

echo "=================================================="
echo "All components started. Waiting for initialization..."