Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.inputs import TextPrompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -1468,7 +1469,7 @@ async def generation_single_request(task: dict[str, Any]):
batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics
):
try:
r_outputs = [output]
r_outputs = [output_strip(output, omni_stage)]
use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes)
if use_shm:
out_q.put(
Expand Down Expand Up @@ -1554,3 +1555,23 @@ def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float):
from vllm_omni.entrypoints.log_utils import StageStats

return StageStats(total_token=_agg_total_tokens, total_gen_time=_agg_total_gen_time_ms)


def output_strip(r_output: RequestOutput, omni_stage: OmniStage):
    """Blank out multimodal payloads on an unfinished request output.

    The output is returned untouched when either:

    * ``omni_stage.final_output`` is truthy and
      ``omni_stage.final_output_type`` is anything other than ``"text"``, or
    * ``r_output.finished`` is truthy.

    Otherwise the multimodal data is cleared **in place**:

    * ``r_output.multimodal_output`` is replaced with ``{}`` when the
      attribute exists and is not ``None``;
    * for every entry of ``r_output.outputs`` (when that attribute exists and
      is not ``None``), a truthy ``multimodal_output`` is replaced with ``{}``.

    Args:
        r_output: The request output to inspect; it is mutated, not copied.
        omni_stage: Stage whose ``final_output`` / ``final_output_type``
            attributes decide whether stripping applies.

    Returns:
        The same ``r_output`` object that was passed in.
    """
    # NOTE(review): presumably this keeps intermediate (streaming) results
    # small before they are handed to the output queue -- confirm with caller.
    stage_keeps_payload = (
        omni_stage.final_output and omni_stage.final_output_type != "text"
    )
    if stage_keeps_payload or getattr(r_output, "finished", False):
        return r_output

    # Request-level payload: reset whenever the attribute holds a non-None value.
    if getattr(r_output, "multimodal_output", None) is not None:
        r_output.multimodal_output = {}

    completions = getattr(r_output, "outputs", None)
    if completions is None:
        return r_output

    # Per-completion payloads: only truthy values are overwritten.
    for completion in completions:
        if getattr(completion, "multimodal_output", None):
            completion.multimodal_output = {}

    return r_output