diff --git a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
index fa0f6d5478f..c25ebb9533e 100644
--- a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+++ b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
@@ -18,6 +18,7 @@
     construct_mm_data,
     vLLMMultimodalRequest,
 )
+from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
 
 logger = logging.getLogger(__name__)
 
@@ -63,10 +64,25 @@ async def generate(self, request: vLLMMultimodalRequest, context):
         request = vLLMMultimodalRequest.model_validate(request)
         logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
 
-        # Decode worker doesn't process embeddings, so we pass None or empty tensor
+        # For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
+        # image_grid_thw for position embedding calculation. The decode worker
+        # receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM
+        # will expand it using the multi_modal_data, ensuring the block count
+        # matches what prefill computed.
+        #
+        # We pass unique placeholder embeddings (seeded by request_id) since the
+        # actual embeddings are already in the KV cache from prefill. The unique
+        # values prevent incorrect prefix cache matches between different images.
+        multi_modal_data = None
+        if is_qwen_vl_model(self.config.model):
+            multi_modal_data = construct_qwen_decode_mm_data(
+                request.image_grid_thw, request.embeddings_shape, request.request_id
+            )
+
         gen = self.engine_client.generate(
             prompt=TokensPrompt(
                 prompt_token_ids=request.engine_prompt["prompt_token_ids"],
+                multi_modal_data=multi_modal_data,
             ),
             sampling_params=request.sampling_params,
             request_id=request.request_id,
@@ -222,12 +238,17 @@ async def generate(self, request: vLLMMultimodalRequest, context):
         if self.enable_disagg and self.decode_worker_client:
             decode_request = copy.deepcopy(request)
             async for prefill_response in gen:
-                # Update the prompt token id in the decode request to the one
-                # in response, which has image templated filled in. So that
-                # the decode worker will fetch correct amount of KV blocks.
-                decode_request.engine_prompt[
-                    "prompt_token_ids"
-                ] = prefill_response.prompt_token_ids
+                # For Qwen VL models with mRoPE: keep the ORIGINAL unexpanded prompt.
+                # The decode worker will pass multi_modal_data, which causes vLLM to
+                # expand the prompt identically to prefill, so block counts match.
+                #
+                # For other models: use the expanded prompt from the prefill response.
+                # These models don't pass multi_modal_data in decode, so they need
+                # the already-expanded prompt to match the KV cache layout.
+                if not is_qwen_vl_model(self.config.model):
+                    decode_request.engine_prompt[
+                        "prompt_token_ids"
+                    ] = prefill_response.prompt_token_ids
                 logger.debug(
                     f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
                 )
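
Note: the comments above rely on prefill and decode agreeing on how many tokens the image placeholder expands to. A minimal sketch of that arithmetic, assuming Qwen2/2.5-VL's spatial merge size of 2 (the helper and the example grid below are illustrative, not part of this patch):

    # Sketch (illustrative): each image expands to t * h * w / merge_size**2
    # placeholder tokens, so prefill and decode derive the same prompt length
    # (and therefore the same KV block count) from the same image_grid_thw.
    MERGE_SIZE = 2  # Qwen2/2.5-VL spatial merge size (assumption)

    def expanded_image_tokens(image_grid_thw):
        return sum(t * h * w // MERGE_SIZE**2 for t, h, w in image_grid_thw)

    print(expanded_image_tokens([[1, 36, 52]]))  # 468 tokens for one image
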
diff --git a/components/src/dynamo/vllm/multimodal_utils/model.py b/components/src/dynamo/vllm/multimodal_utils/model.py
index 42163ca9894..68e94c81711 100644
--- a/components/src/dynamo/vllm/multimodal_utils/model.py
+++ b/components/src/dynamo/vllm/multimodal_utils/model.py
@@ -177,3 +177,61 @@ def _construct_qwen_image_data(
             "image_grid_thw": grid_thw_tensor,
         }
     }
+
+
+def construct_qwen_decode_mm_data(
+    image_grid_thw: Optional[List[Any]],
+    embeddings_shape: Optional[Any],
+    request_id: str,
+    *,
+    dtype: torch.dtype = torch.float16,
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    """Construct schema-valid Qwen multimodal data for vLLM v1 disagg decode.
+
+    This is a WORKAROUND (WAR) for vLLM's disaggregated multimodal decode limitations.
+
+    Notes:
+    - vLLM parses multimodal inputs and builds `mm_features` from `multi_modal_data`.
+    - For Qwen VL models, the parser enforces that image data contains BOTH
+      `image_embeds` and `image_grid_thw` keys.
+    - In disaggregated decode, the KV cache already includes the vision context
+      from prefill; decode still needs `mm_features` for mRoPE initialization.
+
+    WAR details:
+    - We generate unique placeholder embeddings based on request_id to prevent
+      incorrect prefix cache matches between different images with the same dimensions.
+    - Without this, zero embeddings and the same image_grid_thw would create identical
+      cache signatures, causing decode to incorrectly reuse cached KV from
+      different images.
+
+    Caching caveat:
+    - This WAR disables prefix cache reuse on the DECODE worker (each request
+      has unique placeholder embeddings).
+    - Prefix caching still works correctly on the PREFILL worker, which uses
+      actual image embeddings. This is where the caching benefit matters, since
+      prefill does the heavy computation.
+    - Decode receives KV blocks from prefill via NIXL transfer anyway, so
+      decode-side prefix caching provides minimal benefit in a disaggregated setup.
+    """
+    if image_grid_thw is None or len(image_grid_thw) == 0:
+        raise ValueError("No image grid provided for Qwen model.")
+    if embeddings_shape is None:
+        raise ValueError("embeddings_shape is required for Qwen decode mm data.")
+
+    # WAR: Use the request_id hash as the seed for unique placeholder values.
+    # This prevents the prefix cache from incorrectly matching different images
+    # that happen to have the same dimensions (same image_grid_thw).
+    seed = hash(request_id) & 0xFFFFFFFF  # Mask to an unsigned 32-bit int
+    generator = torch.Generator().manual_seed(seed)
+    image_embeds = torch.randn(
+        embeddings_shape, dtype=dtype, device="cpu", generator=generator
+    )
+    if image_embeds.ndim == 3:
+        image_embeds = image_embeds.squeeze(0)
+
+    return {
+        "image": {
+            "image_embeds": image_embeds,
+            "image_grid_thw": torch.tensor(image_grid_thw),
+        }
+    }
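
Note: a self-contained sketch of the seeding scheme above, reimplemented inline so it runs without dynamo installed (names and shapes here are illustrative). One property worth noting: hash() on strings is salted per Python process, so the seed is stable within the decode worker's process but not across processes; that is sufficient here, because only decode fabricates these placeholders.

    import torch

    def placeholder_embeds(request_id: str, shape, dtype=torch.float16):
        # Same scheme as construct_qwen_decode_mm_data: request_id-seeded randn.
        seed = hash(request_id) & 0xFFFFFFFF
        gen = torch.Generator().manual_seed(seed)
        return torch.randn(shape, dtype=dtype, device="cpu", generator=gen)

    a = placeholder_embeds("req-1", (468, 3584))  # 3584: example hidden size
    b = placeholder_embeds("req-1", (468, 3584))
    c = placeholder_embeds("req-2", (468, 3584))
    assert torch.equal(a, b)      # deterministic for a given request_id
    assert not torch.equal(a, c)  # distinct across requests: no false prefix-cache hits
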
diff --git a/examples/backends/vllm/launch/disagg_multimodal_epd.sh b/examples/backends/vllm/launch/disagg_multimodal_epd.sh
index e17da78427c..36115c3d35d 100755
--- a/examples/backends/vllm/launch/disagg_multimodal_epd.sh
+++ b/examples/backends/vllm/launch/disagg_multimodal_epd.sh
@@ -93,7 +93,7 @@ CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-wo
 # Start decode worker
 echo "Starting decode worker on GPU 2..."
 VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
-CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
+CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
 
 echo "=================================================="
 echo "All components started. Waiting for initialization..."
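
Note: once the script has all components up, a request along these lines exercises the encode/prefill/decode path end to end. This is a sketch only: the port (8000), the endpoint, the model name, and the image URL are assumptions about this deployment, not taken from the script.

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed frontend address
        json={
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",  # use the value of $MODEL_NAME
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe this image."},
                        {"type": "image_url",
                         "image_url": {"url": "http://images.cocodataset.org/test2017/000000155781.jpg"}},
                    ],
                }
            ],
            "max_tokens": 64,
        },
        timeout=120,
    )
    print(resp.json())
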