diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 2adb1c34249..fc01c147cd6 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1467,7 +1467,7 @@ async def generation_single_request(task: dict[str, Any]): batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics ): try: - r_outputs = [output] + r_outputs = [output_strip(output, omni_stage)] use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) if use_shm: out_q.put( @@ -1553,3 +1553,32 @@ def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): from vllm_omni.entrypoints.log_utils import StageStats return StageStats(total_token=_agg_total_tokens, total_gen_time=_agg_total_gen_time_ms) + + +def output_strip(r_output: RequestOutput | OmniRequestOutput, omni_stage: OmniStage): + """ + Strip unnecessary multimodal outputs from stages results, + in order to: + - reduce memory usage + - reduce transfer & serialization overhead + """ + + # check multimodal data is required by stage output config. + if omni_stage.final_output and omni_stage.final_output_type != "text": + return r_output + + # If the request has already finished, should not be altered. + if getattr(r_output, "finished", False): + return r_output + + mm_output = getattr(r_output, "multimodal_output", None) + if mm_output is not None: + r_output.multimodal_output = {} + + outputs = getattr(r_output, "outputs", None) + if outputs is not None: + for out in outputs: + if getattr(out, "multimodal_output", None): + out.multimodal_output = {} + + return r_output