[Engine] Refactor output processing for multimodal capabilities in vLLM-omni #20
Conversation
- Introduced OmniRequestState to manage multimodal request states.
- Enhanced MultimodalOutputProcessor to handle various output types including images, text, and latents.
- Implemented methods for accumulating multimodal tensors and processing outputs.
- Updated output handling to ensure compatibility with vLLM's base processor while allowing custom modality handlers.
Pull Request Overview
This PR refactors the multimodal output processing system in vLLM-omni by replacing the original standalone processor with a vLLM-compatible architecture that extends vLLM's base OutputProcessor. The changes introduce better state management for multimodal requests and normalize different output types before delegating to vLLM's processing pipeline.
Key changes:
- Introduced `OmniRequestState` to track multimodal tensor accumulation across the request lifetime
- Replaced the original `MultimodalOutputProcessor` with a vLLM-compatible version that extends `VLLMOutputProcessor`
- Implemented modality-specific routing and normalization methods to handle images, text, latents, and audio outputs
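The accumulation role of `OmniRequestState` can be sketched roughly as follows. This is a minimal, hypothetical stand-in: the real class concatenates torch tensors with `torch.cat`, but plain lists keep the sketch self-contained, and every name except `OmniRequestState` is illustrative rather than taken from the PR.

```python
from typing import Any, List, Optional


class OmniRequestState:
    """Sketch: per-request state that accumulates multimodal chunks
    (e.g. image or latent tensors) across a request's lifetime."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self.mm_type: Optional[str] = None
        self.mm_accumulated: Optional[List[Any]] = None

    def add_multimodal_chunk(self, chunk: List[Any],
                             mm_type: Optional[str] = None) -> None:
        if mm_type:
            self.mm_type = mm_type.lower()
        if self.mm_accumulated is None:
            self.mm_accumulated = list(chunk)
        else:
            # Stands in for torch.cat([self.mm_accumulated, chunk], dim=0)
            self.mm_accumulated.extend(chunk)


state = OmniRequestState("req-0")
state.add_multimodal_chunk([1, 2], mm_type="IMAGE")
state.add_multimodal_chunk([3])
```

The key point is that chunks arriving across multiple engine steps are merged into one growing buffer keyed by the detected modality type.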
```python
            return
        try:
            if mm_type:
                self.mm_type = (mm_type or "").lower()
```
Redundant check: if `mm_type` is truthy, the `or ""` fallback in the `.lower()` call is unnecessary, since `mm_type` is already confirmed to be non-empty.
Suggested change:
```diff
-                self.mm_type = (mm_type or "").lower()
+                self.mm_type = mm_type.lower()
```
```python
        except Exception:
            pass
        if self.mm_accumulated is None:
            self.mm_accumulated = t
        else:
```

```python
            return "image"

        # Check for latent-related attributes
        if hasattr(output, 'latents') or hasattr(output, 'latent_representation'):
            return "latents"

        # Check for pooling output
        if hasattr(output, 'pooler_output') and output.pooler_output is not None:
            return "pooling"

        # Default to text
        return "text"

    def _process_text_output(self, output: Any) -> List[RequestOutput]:
        """Process text output."""
        if isinstance(output, RequestOutput):
            return [output]

        # Create a mock RequestOutput for text
        completion_output = CompletionOutput(
            index=0,
            text=getattr(output, 'text', ''),
            token_ids=getattr(output, 'token_ids', []),
            cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0),
            logprobs=getattr(output, 'logprobs', None),
            finish_reason=getattr(output, 'finish_reason', 'length')
        )

        return [self._build_request_output(output, [completion_output])]
```

```python
            self.mm_accumulated = torch.cat([self.mm_accumulated, t], dim=0)
        except Exception:
            pass
```
Bare exception handler silently swallows all errors. Consider logging the exception or handling specific exception types to aid debugging.
Suggested change:
```diff
-        except Exception:
-            pass
+        except Exception as e:
+            logger.exception("Failed to move tensor to CPU in add_multimodal_tensor.")
         if self.mm_accumulated is None:
             self.mm_accumulated = t
         else:
             self.mm_accumulated = torch.cat([self.mm_accumulated, t], dim=0)
-        except Exception:
-            pass
+        except Exception as e:
+            logger.exception("Exception occurred in add_multimodal_tensor.")
```
```python
        except Exception:
            pass
```
Bare exception handler silently swallows all errors. Consider logging the exception or handling specific exception types to aid debugging.
Suggested change:
```diff
-        except Exception:
-            pass
+        except Exception as e:
+            logger.exception("Error in add_multimodal_tensor")
```
```python
        if num_cached_tokens is not None:
            # Keep num_cached_tokens in RequestOutput for compatibility
            try:
                self.num_cached_tokens = num_cached_tokens  # type: ignore[attr-defined]
```
The `type: ignore[attr-defined]` comment suggests this attribute is being set dynamically on a class that doesn't define it. Consider documenting why the dynamic attribute is necessary, or adding it to the class definition.
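One way to address this, sketched with a hypothetical class and field layout (not the PR's actual class), is to declare the attribute in the class body so static type checkers see it and the `type: ignore` becomes unnecessary:

```python
from typing import Optional


class OmniRequestOutput:
    """Illustrative sketch: declaring num_cached_tokens on the class
    (instead of setting it dynamically with setattr) lets mypy verify
    the assignment, removing the need for type: ignore[attr-defined]."""

    num_cached_tokens: Optional[int]

    def __init__(self, num_cached_tokens: Optional[int] = None) -> None:
        self.num_cached_tokens = num_cached_tokens


out = OmniRequestOutput(num_cached_tokens=7)
```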
```python
            setattr(base_output, "multimodal_output", {})
            setattr(base_output, "multimodal_output", {self.mm_type: tensor})
        except Exception as e:
            logger.warning("Error in _new_completion_output", e)
```
The `logger.warning()` call is missing a `%s` placeholder for the exception argument. It should be `logger.warning("Error in _new_completion_output: %s", e)` to properly format the exception message.
Suggested change:
```diff
-            logger.warning("Error in _new_completion_output", e)
+            logger.warning("Error in _new_completion_output: %s", e)
```
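For illustration, the reason the placeholder matters is that `logging` applies `%`-formatting lazily, using the extra positional arguments; a bare extra argument with no placeholder makes `msg % args` fail inside the logging machinery and the message is lost. A small self-contained sketch (logger name and message text are illustrative, and the buffer handler exists only to capture the output):

```python
import io
import logging

# Capture log output in a buffer so the result can be inspected.
logger = logging.getLogger("omni.output")
logger.setLevel(logging.WARNING)
buf = io.StringIO()
handler = logging.StreamHandler(buf)
handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
logger.addHandler(handler)

try:
    raise ValueError("bad tensor shape")
except Exception as e:
    # Correct form: %s defers formatting to the logging framework.
    # logger.warning("Error in _new_completion_output", e) would instead
    # trigger a logging-internal formatting error and drop the message.
    logger.warning("Error in _new_completion_output: %s", e)

message = buf.getvalue().strip()
```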
```python
        assert req_state.detokenizer is not None
        assert req_state.logprobs_processor is not None
        stop_string = req_state.detokenizer.update(
            new_token_ids, finish_reason == FinishReason.STOP)
        if stop_string:
            finish_reason = FinishReason.STOP
            stop_reason = stop_string
        req_state.logprobs_processor.update_from_output(eco)
```
These assertions will fail for pooling-only requests where detokenizer and logprobs_processor are set to None (lines 60-61). The assertions should be conditional or removed to support non-text output modes.
Suggested change:
```diff
-        assert req_state.detokenizer is not None
-        assert req_state.logprobs_processor is not None
-        stop_string = req_state.detokenizer.update(
-            new_token_ids, finish_reason == FinishReason.STOP)
-        if stop_string:
-            finish_reason = FinishReason.STOP
-            stop_reason = stop_string
-        req_state.logprobs_processor.update_from_output(eco)
+        if getattr(req_state, "output_kind", None) == RequestOutputKind.TEXT:
+            assert req_state.detokenizer is not None
+            assert req_state.logprobs_processor is not None
+            stop_string = req_state.detokenizer.update(
+                new_token_ids, finish_reason == FinishReason.STOP)
+            if stop_string:
+                finish_reason = FinishReason.STOP
+                stop_reason = stop_string
+            req_state.logprobs_processor.update_from_output(eco)
```
```python
        except Exception:
            pass
```
Bare exception handler silently swallows all errors. Consider logging the exception or handling specific exception types to aid debugging.
Suggested change:
```diff
-        except Exception:
-            pass
+        except Exception as e:
+            logger.warning("Error accumulating multimodal tensor: %s", e)
```
```python
                setattr(ro, "multimodal_output", {})
            ro.multimodal_output[mm_key] = req_state.mm_accumulated
        except Exception as e:
            logger.warning("Error in process_outputs", e)
```
The `logger.warning()` call is missing a `%s` placeholder for the exception argument. It should be `logger.warning("Error in process_outputs: %s", e)` to properly format the exception message.
Suggested change:
```diff
-            logger.warning("Error in process_outputs", e)
+            logger.warning("Error in process_outputs: %s", e)
```
```python
        except Exception:
            pass
```
Bare exception handler silently swallows all errors. Consider logging the exception or handling specific exception types to aid debugging.
Suggested change:
```diff
-        except Exception:
-            pass
+        except Exception as e:
+            logger.exception(f"Exception in output handler for type '{output_type}': {e}")
```
```python
        except Exception:
            pass
```
Bare exception handler silently swallows all errors. Consider logging the exception or handling specific exception types to aid debugging.
Suggested change:
```diff
-        except Exception:
-            pass
+        except Exception as e:
+            logger.warning(
+                "Failed to convert pooling_output to tensor: %r. Exception: %s",
+                eco.pooling_output, e)
```
hsliuustc0106 left a comment:

Does this file work for different output modalities, including text, image, wav, etc.?

Yes, we have left the interface open for text, image, image+text, audio, hidden states, etc. But the actual implementation and postprocessing need further discussion, depending on the user interface design.
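The per-modality interface mentioned in the reply could, for example, take the shape of a dispatch table mapping the detected output type to its handler. This is a hypothetical sketch only: the handler names mirror the PR's `_process_text_output` naming style, but the bodies are placeholders, not the real implementations.

```python
from typing import Any, Callable, Dict


def _process_text_output(output: Any) -> str:
    return f"text:{output}"          # placeholder body


def _process_image_output(output: Any) -> str:
    return f"image:{output}"         # placeholder body


def _process_audio_output(output: Any) -> str:
    return f"audio:{output}"         # placeholder body


# Registry of modality handlers; new modalities plug in here.
MODALITY_HANDLERS: Dict[str, Callable[[Any], str]] = {
    "text": _process_text_output,
    "image": _process_image_output,
    "audio": _process_audio_output,
}


def process_output(mm_type: str, output: Any) -> str:
    # Unknown modalities fall back to text handling.
    handler = MODALITY_HANDLERS.get(mm_type, _process_text_output)
    return handler(output)
```

A design like this keeps the routing logic in one place while leaving each modality's postprocessing open for the interface discussion mentioned above.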
- Added logging for exceptions during tensor movement to CPU in OmniRequestState and MultimodalOutputProcessor.
- Improved robustness by ensuring the output pipeline continues without crashing on errors.
- Updated comments for clarity on error handling behavior.
lgtm
…t-processor [Engine]Refactor output processing for multimodal capabilities in vLLM-omni
fix: use legacy config loading path instead of StageConfigFactory
Purpose
This PR implements the Phase 2 features of https://github.com/hsliuustc0106/vllm-omni/issues/10: refactor output processing for multimodal capabilities in vLLM-omni.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.