diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index e0e699c9c24..cc1b8cabd7b 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -21,6 +21,7 @@ from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger +from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.tokenizers import TokenizerLike from vllm.usage.usage_lib import UsageContext @@ -1468,7 +1469,7 @@ async def generation_single_request(task: dict[str, Any]): batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics ): try: - r_outputs = [output] + r_outputs = [output_strip(output, omni_stage)] use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) if use_shm: out_q.put( @@ -1554,3 +1555,23 @@ def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): from vllm_omni.entrypoints.log_utils import StageStats return StageStats(total_token=_agg_total_tokens, total_gen_time=_agg_total_gen_time_ms) + + +def output_strip(r_output: RequestOutput, omni_stage: OmniStage): + if omni_stage.final_output and omni_stage.final_output_type != "text": + return r_output + + if getattr(r_output, "finished", False): + return r_output + + mm_output = getattr(r_output, "multimodal_output", None) + if mm_output is not None: + r_output.multimodal_output = {} + + outputs = getattr(r_output, "outputs", None) + if outputs is not None: + for out in outputs: + if getattr(out, "multimodal_output", None): + out.multimodal_output = {} + + return r_output