From b8d2fa024a604ab0be6d18af41cb2b106c080295 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Mon, 17 Feb 2025 10:32:49 -0300 Subject: [PATCH 01/18] [Core][Feature] Input metadata dump on crash Signed-off-by: Wallas Santos --- .../test_basic_correctness.py | 49 ++++++ vllm/engine/llm_engine.py | 19 ++- vllm/error_report.py | 155 ++++++++++++++++++ vllm/v1/engine/core.py | 20 ++- vllm/worker/worker_base.py | 51 ++++-- 5 files changed, 275 insertions(+), 19 deletions(-) create mode 100644 vllm/error_report.py diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index bd97dd945fed..4392dca849d5 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -5,11 +5,16 @@ """ import os import weakref +from unittest.mock import Mock import pytest from vllm import LLM from vllm.platforms import current_platform +from vllm.v1.engine.core import ModelExecutionV1Error +from vllm.v1.engine.core_client import EngineCoreClient, InprocClient +from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 +from vllm.worker.worker_base import ModelExecutionError from ..conftest import VllmRunner from ..models.utils import check_outputs_equal @@ -147,3 +152,47 @@ def test_models_distributed( name_0="hf", name_1="vllm", ) + + +def test_failed_model_execution(vllm_runner) -> None: + + def make_client( + multiprocess_mode: bool, + asyncio_mode: bool, + vllm_config, # "VllmConfig" + executor_class, # "Type[Executor]" + log_stats: bool, + ) -> "EngineCoreClient": + return InprocClient(vllm_config, executor_class, log_stats) + + EngineCoreClient.make_client = Mock(side_effect=make_client) + with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + + engine = vllm_model.model.llm_engine + mocked_execute_model = Mock( + side_effect=RuntimeError("Mocked Critical Error")) + + if isinstance(engine, LLMEngineV1): + is_v1 = True + engine.engine_core.engine_core.model_executor.execute_model =\ + mocked_execute_model + else: # V0 + is_v1 = False + engine.model_executor.driver_worker.model_runner.execute_model = \ + mocked_execute_model + + with pytest.raises(RuntimeError) as exc_info: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + vllm_model.generate_greedy(prompts, 200, use_tqdm=False) + if is_v1: + assert isinstance(exc_info.value, ModelExecutionV1Error) + assert exc_info.value.scheduler_output is not None + else: + assert isinstance(exc_info.value, ModelExecutionError) + assert exc_info.value.model_input is not None + assert "Mocked Critical Error" in str(exc_info.value) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2e5bc75c6db3..75c34ccfef97 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,6 +29,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) +from vllm.error_report import dump_engine_exception from vllm.executor.executor_base import ExecutorBase from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) @@ -1383,8 +1384,22 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - outputs = self.model_executor.execute_model( - 
execute_model_req=execute_model_req) + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + + except BaseException as err: + stats = self._get_stats(scheduler_outputs=scheduler_outputs) + dump_engine_exception( + err=err, + config=self.vllm_config, + use_cached_outputs=self.use_cached_outputs, + engine_version=0, + stats=stats, + execute_model_req=execute_model_req, + ) + # Re-raise exception + raise err # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. diff --git a/vllm/error_report.py b/vllm/error_report.py new file mode 100644 index 000000000000..06b9b65a352c --- /dev/null +++ b/vllm/error_report.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +import enum +import json +from typing import Union + +import torch + +from vllm.config import VllmConfig +from vllm.engine.metrics import Stats +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SequenceData +from vllm.v1.core.scheduler_output import NewRequestData +from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.worker_base import ModelExecutionError + +logger = init_logger(__name__) + + +# Hacky way to make sure we can serialize the object in JSON format +def is_json_serializable(x): + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + + +def prepare_object_to_dump(obj): + if isinstance(obj, dict): + return {k: prepare_object_to_dump(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [prepare_object_to_dump(v) for v in obj] + elif isinstance(obj, set): + return [prepare_object_to_dump(v) for v in list(obj)] + elif isinstance(obj, tuple): + return [prepare_object_to_dump(v) for v in obj] + elif isinstance(obj, enum.Enum): + return repr(obj) + elif isinstance(obj, SequenceData): + # Custom representation (based on SequenceData.__repr__) + # to obfuscate some parameters + return { + "class": "SequenceData", + "prompt_token_ids_len": len(obj._prompt_token_ids), + "output_token_ids_len": len(obj.output_token_ids), + "cumulative_logprob": obj.cumulative_logprob, + "get_num_computed_tokens": obj.get_num_computed_tokens() + } + + elif isinstance(obj, NewRequestData): + obj_dict = {'class': type(obj).__name__} + for k, v in obj.__dict__.items(): + if k == 'prompt_token_ids': + obj_dict['prompt_token_ids_len'] = len(v) + elif k == 'prompt': + obj_dict['prompt'] = "" + else: + obj_dict[k] = prepare_object_to_dump(v) + + return obj_dict + elif isinstance(obj, torch.Tensor): + # We only print the 'draft'of the tensor to not expose sensitive data + # and to get some metadata in case of CUDA illegal memory access + return (f"Tensor(shape={obj.shape}, " + f"device={obj.device}," + f"dtype={obj.dtype})") + elif hasattr(obj, '__dict__'): + obj_dict = {'class': type(obj).__name__} + obj_dict.update(obj.__dict__) + return prepare_object_to_dump(obj_dict) + else: + # Try to make sure we can serialize the object + # to avoid exception + if is_json_serializable(obj): + return obj + else: + return repr(obj) + + +def dump_engine_exception(err: BaseException, + config: VllmConfig, + engine_version: int, + stats: Stats = None, + use_cached_outputs: Union[bool, None] = None, + execute_model_req: ExecuteModelRequest = None): + + assert engine_version == 0 or engine_version == 1 + + logger.error("Engine crashed, dumping input data") + + if engine_version == 1: + logger.error( + "V1 LLM engine (v%s) with config: %s, ", + VLLM_VERSION, + config, + ) + else: 
+ logger.error( + "V0 LLM engine (v%s) with config: %s, " + "use_cached_outputs=%s, ", + VLLM_VERSION, + config, + use_cached_outputs, + ) + + # For V0 + if isinstance(err, ModelExecutionError): + try: + err_json = prepare_object_to_dump(err.model_input) + logger.error("Model input for execution as JSON:") + logger.error(json.dumps(err_json)) + except BaseException as err: + logger.error("Error preparing object to dump") + logger.error(repr(err)) + + # In case we do not have a ModelExecutionError we still can + # get information from the batch + if execute_model_req is not None: + batch = execute_model_req.seq_group_metadata_list + requests_count = len(batch) + requests_prompt_token_ids_lenghts = ', '.join([ + str(len(r.seq_data[idx].prompt_token_ids)) + for idx, r in enumerate(batch) + ]) + requests_ids = ', '.join([str(r.request_id) for r in batch]) + logger.error( + "Batch info: requests_count=%s, " + "requests_prompt_token_ids_lenghts=(%s), " + "requests_ids=(%s)", requests_count, + requests_prompt_token_ids_lenghts, requests_ids) + + for idx, r in enumerate(batch): + logger.error( + "Errored Batch request #%s: request_id=%s " + "prompt_token_ids_lengths=%s, " + "params=%s, " + "lora_request=%s, prompt_adapter_request=%s ", idx, + r.request_id, str(len(r.seq_data[idx].prompt_token_ids)), + r.sampling_params, r.lora_request, r.prompt_adapter_request) + + # TODO: Have stats for V1 + if stats is not None: + logger.error("System stats:") + logger.error(stats) + + if engine_version == 1: + from vllm.v1.engine.core import ModelExecutionV1Error + if isinstance(err, ModelExecutionV1Error): + try: + err_json = prepare_object_to_dump(err.scheduler_output) + logger.error("Scheduler output for model execution as JSON:") + logger.error(json.dumps(err_json)) + except BaseException as err: + logger.error("Error preparing object to dump") + logger.error(repr(err)) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c7ea7b1a94d8..5c73d3aab44e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -13,6 +13,7 @@ import zmq.asyncio from vllm.config import VllmConfig +from vllm.error_report import dump_engine_exception from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( @@ -35,6 +36,14 @@ POLLING_TIMEOUT_S = 2.5 +class ModelExecutionV1Error(RuntimeError): + scheduler_output: SchedulerOutput + + def __init__(self, *args, scheduler_output): + super().__init__(*args) + self.scheduler_output = scheduler_output + + class EngineCore: """Inner loop of vLLM's Engine.""" @@ -46,6 +55,7 @@ def __init__( ): assert vllm_config.model_config.runner_type != "pooling" + self.config = vllm_config logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -162,7 +172,15 @@ def step(self) -> EngineCoreOutputs: self.propose_tokens() scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) + try: + output = self.model_executor.execute_model(scheduler_output) + except BaseException as err: + err = ModelExecutionV1Error( + f"Model execution failure," + f"reason: {repr(err)}", + scheduler_output=scheduler_output) + dump_engine_exception(err, self.config, 1) + raise err engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore return engine_core_outputs diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 83fcf0865ae1..52730af667ee 100644 --- a/vllm/worker/worker_base.py +++ 
b/vllm/worker/worker_base.py @@ -208,6 +208,14 @@ def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") +class ModelExecutionError(RuntimeError): + model_input: BroadcastableModelInput + + def __init__(self, *args, model_input): + super().__init__(*args) + self.model_input = model_input + + @dataclasses.dataclass(frozen=True) class WorkerInput: """Local inputs to each worker. May contain device-specific data. These @@ -414,15 +422,20 @@ def execute_model( and self.observability_config.collect_model_execute_time): orig_model_execute_time = intermediate_tensors.tensors.get( "model_execute_time", torch.tensor(0)).item() - - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) + try: + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) + except BaseException as err: + raise ModelExecutionError( + f"Model execution failure," + f"reason: {repr(err)}", + model_input=model_input) from err model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: @@ -472,13 +485,19 @@ def _execute_model_spmd( kwargs = extract_previous_hidden_states(execute_model_req) - return self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - **kwargs, - ) + try: + return self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) + except BaseException as err: + raise ModelExecutionError( + f"Model execution failure," + f"reason: {repr(err)}", + model_input=model_input) from err class WorkerWrapperBase: From 08c9f157911ddad14f95369397ffbbb95cafe9dc Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 18 Feb 2025 10:29:38 -0300 Subject: [PATCH 02/18] fix: mypy complaints Signed-off-by: Wallas Santos --- vllm/error_report.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/error_report.py b/vllm/error_report.py index 06b9b65a352c..4647305a87b0 100644 --- a/vllm/error_report.py +++ b/vllm/error_report.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 + import enum import json -from typing import Union +from typing import Any, Dict, Union import torch @@ -48,7 +49,7 @@ def prepare_object_to_dump(obj): } elif isinstance(obj, NewRequestData): - obj_dict = {'class': type(obj).__name__} + obj_dict: Dict[str, Any] = {'class': type(obj).__name__} for k, v in obj.__dict__.items(): if k == 'prompt_token_ids': obj_dict['prompt_token_ids_len'] = len(v) @@ -80,9 +81,10 @@ def prepare_object_to_dump(obj): def dump_engine_exception(err: BaseException, config: VllmConfig, engine_version: int, - stats: Stats = None, + stats: Union[Stats, None] = None, use_cached_outputs: Union[bool, None] = None, - execute_model_req: ExecuteModelRequest = None): + execute_model_req: Union[ExecuteModelRequest, + None] = None): assert engine_version == 0 or engine_version == 1 From 0ef83a8e6e19a133b13b4b8b38b6905b075970ba Mon Sep 17 00:00:00 2001 From: Wallas Santos 
Date: Tue, 18 Feb 2025 11:19:48 -0300 Subject: [PATCH 03/18] fix: mypy complaints Signed-off-by: Wallas Santos --- vllm/error_report.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/error_report.py b/vllm/error_report.py index 4647305a87b0..bba242ddda00 100644 --- a/vllm/error_report.py +++ b/vllm/error_report.py @@ -111,9 +111,9 @@ def dump_engine_exception(err: BaseException, err_json = prepare_object_to_dump(err.model_input) logger.error("Model input for execution as JSON:") logger.error(json.dumps(err_json)) - except BaseException as err: + except BaseException as exception: logger.error("Error preparing object to dump") - logger.error(repr(err)) + logger.error(repr(exception)) # In case we do not have a ModelExecutionError we still can # get information from the batch @@ -152,6 +152,6 @@ def dump_engine_exception(err: BaseException, err_json = prepare_object_to_dump(err.scheduler_output) logger.error("Scheduler output for model execution as JSON:") logger.error(json.dumps(err_json)) - except BaseException as err: + except BaseException as exception: logger.error("Error preparing object to dump") - logger.error(repr(err)) + logger.error(repr(exception)) From d8f75b7fb364235d2d8067230f6b60998216442c Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 25 Feb 2025 11:25:47 -0300 Subject: [PATCH 04/18] fix: dump report Signed-off-by: Wallas Santos --- vllm/error_report.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/error_report.py b/vllm/error_report.py index bba242ddda00..7a5494a118e8 100644 --- a/vllm/error_report.py +++ b/vllm/error_report.py @@ -120,10 +120,12 @@ def dump_engine_exception(err: BaseException, if execute_model_req is not None: batch = execute_model_req.seq_group_metadata_list requests_count = len(batch) - requests_prompt_token_ids_lenghts = ', '.join([ - str(len(r.seq_data[idx].prompt_token_ids)) - for idx, r in enumerate(batch) - ]) + + requests_prompt_token_ids_lenghts = [{ + k: len(v.prompt_token_ids) + for (k, v) in r.seq_data.items() + } for r in batch] + requests_ids = ', '.join([str(r.request_id) for r in batch]) logger.error( "Batch info: requests_count=%s, " From d588fff20bc5cdd053d5eda08e6dcd5f22ff230c Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 25 Feb 2025 15:19:11 -0300 Subject: [PATCH 05/18] fix server hang on shutdown Signed-off-by: Wallas Santos --- vllm/v1/engine/core.py | 6 +++++- vllm/worker/worker_base.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9476a38facf1..72d4c5867ffc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -40,10 +40,14 @@ class ModelExecutionV1Error(RuntimeError): scheduler_output: SchedulerOutput - def __init__(self, *args, scheduler_output): + def __init__(self, *args, scheduler_output=None): super().__init__(*args) self.scheduler_output = scheduler_output + def __reduce__(self): + # To avoid pickle errors + return (self.__class__, (self.args[0], )) + class EngineCore: """Inner loop of vLLM's Engine.""" diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 4f6f7c18c9c2..31e9af058653 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -211,10 +211,14 @@ def list_loras(self) -> Set[int]: class ModelExecutionError(RuntimeError): model_input: BroadcastableModelInput - def __init__(self, *args, model_input): + def __init__(self, *args, model_input=None): super().__init__(*args) self.model_input = 
model_input + def __reduce__(self): + # To avoid pickle errors + return (self.__class__, (self.args[0], )) + @dataclasses.dataclass(frozen=True) class WorkerInput: From ea544f19492903d4496b61026d0b3f808ab371ed Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Wed, 5 Mar 2025 11:28:53 -0300 Subject: [PATCH 06/18] review feedback Signed-off-by: Wallas Santos --- vllm/error_report.py | 31 +++++++++++-------------------- vllm/v1/engine/core.py | 21 ++++++++++++++++++--- vllm/worker/worker_base.py | 17 ++++++++++++++++- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/vllm/error_report.py b/vllm/error_report.py index 7a5494a118e8..842f7cb66c46 100644 --- a/vllm/error_report.py +++ b/vllm/error_report.py @@ -2,7 +2,7 @@ import enum import json -from typing import Any, Dict, Union +from typing import Any, Union import torch @@ -17,15 +17,6 @@ logger = init_logger(__name__) -# Hacky way to make sure we can serialize the object in JSON format -def is_json_serializable(x): - try: - json.dumps(x) - return True - except (TypeError, OverflowError): - return False - - def prepare_object_to_dump(obj): if isinstance(obj, dict): return {k: prepare_object_to_dump(v) for k, v in obj.items()} @@ -49,7 +40,7 @@ def prepare_object_to_dump(obj): } elif isinstance(obj, NewRequestData): - obj_dict: Dict[str, Any] = {'class': type(obj).__name__} + obj_dict: dict[str, Any] = {'class': type(obj).__name__} for k, v in obj.__dict__.items(): if k == 'prompt_token_ids': obj_dict['prompt_token_ids_len'] = len(v) @@ -60,7 +51,7 @@ def prepare_object_to_dump(obj): return obj_dict elif isinstance(obj, torch.Tensor): - # We only print the 'draft'of the tensor to not expose sensitive data + # We only print the 'draft' of the tensor to not expose sensitive data # and to get some metadata in case of CUDA illegal memory access return (f"Tensor(shape={obj.shape}, " f"device={obj.device}," @@ -70,11 +61,10 @@ def prepare_object_to_dump(obj): obj_dict.update(obj.__dict__) return prepare_object_to_dump(obj_dict) else: - # Try to make sure we can serialize the object - # to avoid exception - if is_json_serializable(obj): - return obj - else: + # Hacky way to make sure we can serialize the object in JSON format + try: + return json.dumps(obj) + except (TypeError, OverflowError): return repr(obj) @@ -88,7 +78,7 @@ def dump_engine_exception(err: BaseException, assert engine_version == 0 or engine_version == 1 - logger.error("Engine crashed, dumping input data") + logger.error("Dumping input data") if engine_version == 1: logger.error( @@ -115,8 +105,9 @@ def dump_engine_exception(err: BaseException, logger.error("Error preparing object to dump") logger.error(repr(exception)) - # In case we do not have a ModelExecutionError we still can - # get information from the batch + # In case we do not have a ModelExecutionError, which is only present if + # the engine raise an error, we still can dump the information from the + # batch if execute_model_req is not None: batch = execute_model_req.seq_group_metadata_list requests_count = len(batch) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9e08be30d9b2..02d2d829442b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -38,6 +38,17 @@ class ModelExecutionV1Error(RuntimeError): + """Custom RuntimeError with input data for model execution + + In a nutshell, this object is useful for custom handling of exception for + the case the engine raises an error. 
For instance, it is used to log the + input metadata that is useful for debugging on engine crashes. + + Args: + scheduler_output: SchedulerOutput object that contains the input + data for model execution + + """ scheduler_output: SchedulerOutput def __init__(self, *args, scheduler_output=None): @@ -45,7 +56,11 @@ def __init__(self, *args, scheduler_output=None): self.scheduler_output = scheduler_output def __reduce__(self): - # To avoid pickle errors + # To avoid pickle errors. + # This happens when we exchange this object between processes + # since scheduler_output can have objects that only makes sense + # to their context/process we remove them from the serialization + # and only send the summary of the error as a regular RuntimeError. return (self.__class__, (self.args[0], )) @@ -60,7 +75,7 @@ def __init__( ): assert vllm_config.model_config.runner_type != "pooling" - self.config = vllm_config + self.vllm_config = vllm_config logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -171,7 +186,7 @@ def step(self) -> EngineCoreOutputs: f"Model execution failure," f"reason: {repr(err)}", scheduler_output=scheduler_output) - dump_engine_exception(err, self.config, 1) + dump_engine_exception(err, self.vllm_config, 1) raise err engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ce8baeeef64e..3b4b8385a8dd 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -209,6 +209,17 @@ def list_loras(self) -> Set[int]: class ModelExecutionError(RuntimeError): + """Custom RuntimeError with input data for model execution + + In a nutshell, this object is useful for custom handling of exception for + the case the engine raises an error. For instance, it is used to log the + input metadata that is useful for debugging on engine crashes. + + Args: + model_input: BroadcastableModelInput object that contains the input + data for model execution + + """ model_input: BroadcastableModelInput def __init__(self, *args, model_input=None): @@ -216,7 +227,11 @@ def __init__(self, *args, model_input=None): self.model_input = model_input def __reduce__(self): - # To avoid pickle errors + # To avoid pickle errors. + # This happens when we exchange this object between processes. + # since model_input can have objects that only makes sense + # to their context/process we remove them from the serialization + # and only send the summary of the error as a regular RuntimeError. 
        return (self.__class__, (self.args[0], ))


From 63f24ab5ac3bafc220804d0689def4fcc2c61783 Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Wed, 5 Mar 2025 11:44:17 -0300
Subject: [PATCH 07/18] moved vllm/error_report.py to
 vllm/logging_utils/dump_input.py

Signed-off-by: Wallas Santos
---
 vllm/engine/llm_engine.py                             | 2 +-
 vllm/{error_report.py => logging_utils/dump_input.py} | 0
 vllm/v1/engine/core.py                                | 2 +-
 3 files changed, 2 insertions(+), 2 deletions(-)
 rename vllm/{error_report.py => logging_utils/dump_input.py} (100%)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index e4fc610a5d8d..a579f1670848 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -28,13 +28,13 @@
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
-from vllm.error_report import dump_engine_exception
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
 from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
+from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
diff --git a/vllm/error_report.py b/vllm/logging_utils/dump_input.py
similarity index 100%
rename from vllm/error_report.py
rename to vllm/logging_utils/dump_input.py
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 02d2d829442b..e496622ea1fc 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -15,8 +15,8 @@
 import zmq.asyncio
 
 from vllm.config import VllmConfig
-from vllm.error_report import dump_engine_exception
 from vllm.logger import init_logger
+from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)

From 7d405d486d969ec871b6e654fddd4fbf4c204dc3 Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Thu, 6 Mar 2025 22:54:51 -0300
Subject: [PATCH 08/18] refactor for review

Signed-off-by: Wallas Santos
---
 .../test_basic_correctness.py    |   8 +-
 vllm/engine/llm_engine.py        |   5 +-
 vllm/logging_utils/dump_input.py | 133 ++++++++----------
 vllm/sequence.py                 |   8 ++
 vllm/v1/core/scheduler_output.py |  29 ++++
 vllm/v1/engine/core.py           |   6 +-
 vllm/worker/worker_base.py       |   6 +-
 7 files changed, 108 insertions(+), 87 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py
index 60da8635d906..35069e837d60 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -11,10 +11,10 @@
 
 from vllm import LLM
 from vllm.platforms import current_platform
-from vllm.v1.engine.core import ModelExecutionV1Error
+from vllm.v1.engine.core import ModelExecutionError
 from vllm.v1.engine.core_client import EngineCoreClient, InprocClient
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
-from vllm.worker.worker_base import ModelExecutionError
+from vllm.worker.worker_base import ModelExecutionV0Error
 
 from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
@@ -190,9 +190,9 @@ def make_client(
            ]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False) if is_v1: - assert isinstance(exc_info.value, ModelExecutionV1Error) + assert isinstance(exc_info.value, ModelExecutionError) assert exc_info.value.scheduler_output is not None else: - assert isinstance(exc_info.value, ModelExecutionError) + assert isinstance(exc_info.value, ModelExecutionV0Error) assert exc_info.value.model_input is not None assert "Mocked Critical Error" in str(exc_info.value) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a579f1670848..8835db0aa605 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -34,7 +34,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger -from vllm.logging_utils.dump_input import dump_engine_exception +from vllm.logging_utils.dump_input import dump_engine_exception_v0 from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( @@ -1418,11 +1418,10 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: raise except BaseException as err: stats = self._get_stats(scheduler_outputs=scheduler_outputs) - dump_engine_exception( + dump_engine_exception_v0( err=err, config=self.vllm_config, use_cached_outputs=self.use_cached_outputs, - engine_version=0, stats=stats, execute_model_req=execute_model_req, ) diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index 842f7cb66c46..6705a63a17ad 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -2,64 +2,49 @@ import enum import json -from typing import Any, Union +from typing import Union import torch from vllm.config import VllmConfig from vllm.engine.metrics import Stats from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SequenceData -from vllm.v1.core.scheduler_output import NewRequestData +from vllm.sequence import ExecuteModelRequest from vllm.version import __version__ as VLLM_VERSION -from vllm.worker.worker_base import ModelExecutionError +from vllm.worker.worker_base import ModelExecutionV0Error logger = init_logger(__name__) -def prepare_object_to_dump(obj): - if isinstance(obj, dict): - return {k: prepare_object_to_dump(v) for k, v in obj.items()} +def prepare_object_to_dump(obj) -> str: + if isinstance(obj, str): + return "'{obj}'" # Double quotes + elif isinstance(obj, dict): + dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ + for k, v in obj.items()}) + return f'{{{dict_str}}}' elif isinstance(obj, list): - return [prepare_object_to_dump(v) for v in obj] + return f'[{', '.join([prepare_object_to_dump(v) for v in obj])}]' elif isinstance(obj, set): - return [prepare_object_to_dump(v) for v in list(obj)] + return f'[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]' + # return [prepare_object_to_dump(v) for v in list(obj)] elif isinstance(obj, tuple): return [prepare_object_to_dump(v) for v in obj] elif isinstance(obj, enum.Enum): return repr(obj) - elif isinstance(obj, SequenceData): - # Custom representation (based on SequenceData.__repr__) - # to obfuscate some parameters - return { - "class": "SequenceData", - "prompt_token_ids_len": len(obj._prompt_token_ids), - "output_token_ids_len": len(obj.output_token_ids), - "cumulative_logprob": obj.cumulative_logprob, - "get_num_computed_tokens": obj.get_num_computed_tokens() - } - - elif isinstance(obj, 
NewRequestData): - obj_dict: dict[str, Any] = {'class': type(obj).__name__} - for k, v in obj.__dict__.items(): - if k == 'prompt_token_ids': - obj_dict['prompt_token_ids_len'] = len(v) - elif k == 'prompt': - obj_dict['prompt'] = "" - else: - obj_dict[k] = prepare_object_to_dump(v) - - return obj_dict elif isinstance(obj, torch.Tensor): # We only print the 'draft' of the tensor to not expose sensitive data # and to get some metadata in case of CUDA illegal memory access return (f"Tensor(shape={obj.shape}, " f"device={obj.device}," f"dtype={obj.dtype})") + elif hasattr(obj, 'anon_repr'): + return obj.anon_repr() elif hasattr(obj, '__dict__'): - obj_dict = {'class': type(obj).__name__} - obj_dict.update(obj.__dict__) - return prepare_object_to_dump(obj_dict) + items = obj.__dict__.items() + dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \ + for k, v in items]) + return (f"{type(obj).__name__}({dict_str})") else: # Hacky way to make sure we can serialize the object in JSON format try: @@ -68,44 +53,56 @@ def prepare_object_to_dump(obj): return repr(obj) -def dump_engine_exception(err: BaseException, - config: VllmConfig, - engine_version: int, - stats: Union[Stats, None] = None, - use_cached_outputs: Union[bool, None] = None, - execute_model_req: Union[ExecuteModelRequest, - None] = None): - - assert engine_version == 0 or engine_version == 1 +def dump_engine_exception(err: BaseException, config: VllmConfig): logger.error("Dumping input data") - if engine_version == 1: - logger.error( - "V1 LLM engine (v%s) with config: %s, ", - VLLM_VERSION, - config, - ) - else: - logger.error( - "V0 LLM engine (v%s) with config: %s, " - "use_cached_outputs=%s, ", - VLLM_VERSION, - config, - use_cached_outputs, - ) + logger.error( + "V1 LLM engine (v%s) with config: %s, ", + VLLM_VERSION, + config, + ) - # For V0 + # TODO: Have stats for V1 + + from vllm.v1.engine.core import ModelExecutionError if isinstance(err, ModelExecutionError): try: - err_json = prepare_object_to_dump(err.model_input) - logger.error("Model input for execution as JSON:") - logger.error(json.dumps(err_json)) + dump_obj = prepare_object_to_dump(err.scheduler_output) + logger.error("Dumping scheduler output for model execution:") + logger.error(dump_obj) except BaseException as exception: logger.error("Error preparing object to dump") logger.error(repr(exception)) - # In case we do not have a ModelExecutionError, which is only present if + +# TODO: Remove this when V1 is default +def dump_engine_exception_v0(err: BaseException, + config: VllmConfig, + stats: Union[Stats, None] = None, + use_cached_outputs: Union[bool, None] = None, + execute_model_req: Union[ExecuteModelRequest, + None] = None): + + logger.error( + "V0 LLM engine (v%s) with config: %s, " + "use_cached_outputs=%s, ", + VLLM_VERSION, + config, + use_cached_outputs, + ) + + # For V0 + if isinstance(err, ModelExecutionV0Error): + try: + dump_obj = prepare_object_to_dump(err.model_input) + logger.error("Dumping model input for execution:") + logger.error(dump_obj) + except BaseException as exception: + logger.error("Error preparing object to dump") + logger.error(repr(exception)) + + # In case we do not have a ModelExecutionV0Error, which is only present if # the engine raise an error, we still can dump the information from the # batch if execute_model_req is not None: @@ -133,18 +130,6 @@ def dump_engine_exception(err: BaseException, r.request_id, str(len(r.seq_data[idx].prompt_token_ids)), r.sampling_params, r.lora_request, r.prompt_adapter_request) - # 
TODO: Have stats for V1 if stats is not None: logger.error("System stats:") logger.error(stats) - - if engine_version == 1: - from vllm.v1.engine.core import ModelExecutionV1Error - if isinstance(err, ModelExecutionV1Error): - try: - err_json = prepare_object_to_dump(err.scheduler_output) - logger.error("Scheduler output for model execution as JSON:") - logger.error(json.dumps(err_json)) - except BaseException as exception: - logger.error("Error preparing object to dump") - logger.error(repr(exception)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6a7b1e62a604..9c24cdad242d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -383,6 +383,14 @@ def __repr__(self) -> str: f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") + # Version of __repr__ with the prompt data obfuscated + def anon_repr(self) -> str: + return (f"SequenceData(" + f"prompt_token_ids_len={len(self._prompt_token_ids)}, " + f"output_token_ids_len={len(self.output_token_ids)}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()})") + class Sequence: """Stores the data, status, and block information of a sequence. diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index b6caa8b4ebf7..9eebf3652d2e 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -44,6 +44,35 @@ def from_request( lora_request=request.lora_request, ) + def __repr__(self): + return (f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"prompt={self.prompt}," + f"mm_inputs={self.mm_inputs}," + f"mm_hashes={self.mm_hashes}," + f"mm_positions={self.mm_positions}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}" + ")") + + # Version of __repr__ with the prompt data obfuscated + def anon_repr(self): + return (f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={len(self.prompt_token_ids)}," + f"prompt=''," + f"mm_inputs={self.mm_inputs}," + f"mm_hashes={self.mm_hashes}," + f"mm_positions={self.mm_positions}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}" + ")") + @dataclass class CachedRequestData: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e496622ea1fc..fb32cd370f85 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -37,7 +37,7 @@ POLLING_TIMEOUT_S = 2.5 -class ModelExecutionV1Error(RuntimeError): +class ModelExecutionError(RuntimeError): """Custom RuntimeError with input data for model execution In a nutshell, this object is useful for custom handling of exception for @@ -182,11 +182,11 @@ def step(self) -> EngineCoreOutputs: try: output = self.model_executor.execute_model(scheduler_output) except BaseException as err: - err = ModelExecutionV1Error( + err = ModelExecutionError( f"Model execution failure," f"reason: {repr(err)}", scheduler_output=scheduler_output) - dump_engine_exception(err, self.vllm_config, 1) + dump_engine_exception(err, self.vllm_config) raise err engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3b4b8385a8dd..43223e86cb89 100644 --- a/vllm/worker/worker_base.py +++ 
b/vllm/worker/worker_base.py @@ -208,7 +208,7 @@ def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") -class ModelExecutionError(RuntimeError): +class ModelExecutionV0Error(RuntimeError): """Custom RuntimeError with input data for model execution In a nutshell, this object is useful for custom handling of exception for @@ -453,7 +453,7 @@ def execute_model( **kwargs, ) except BaseException as err: - raise ModelExecutionError( + raise ModelExecutionV0Error( f"Model execution failure," f"reason: {repr(err)}", model_input=model_input) from err @@ -515,7 +515,7 @@ def _execute_model_spmd( **kwargs, ) except BaseException as err: - raise ModelExecutionError( + raise ModelExecutionV0Error( f"Model execution failure," f"reason: {repr(err)}", model_input=model_input) from err From b4a83eb58ab4584bf840883ed7567bd4fb93dff1 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 7 Mar 2025 10:09:54 -0300 Subject: [PATCH 09/18] refactoring Signed-off-by: Wallas Santos --- .github/ISSUE_TEMPLATE/400-bug-report.yml | 2 +- .../test_basic_correctness.py | 78 ++++++++++--------- vllm/logging_utils/dump_input.py | 4 +- vllm/v1/engine/llm_engine.py | 4 +- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index d4113da8b5b8..e14f5013cd00 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -75,7 +75,7 @@ body: ``` ``` - The error message you got, with the full traceback. + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. ``` validations: required: true diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 35069e837d60..45147669e477 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -5,14 +5,13 @@ """ import os import weakref -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from vllm import LLM from vllm.platforms import current_platform from vllm.v1.engine.core import ModelExecutionError -from vllm.v1.engine.core_client import EngineCoreClient, InprocClient from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from vllm.worker.worker_base import ModelExecutionV0Error @@ -154,45 +153,52 @@ def test_models_distributed( ) +@patch('vllm.v1.engine.llm_engine.VLLM_ENABLE_V1_MULTIPROCESSING', False) def test_failed_model_execution(vllm_runner) -> None: - def make_client( - multiprocess_mode: bool, - asyncio_mode: bool, - vllm_config, # "VllmConfig" - executor_class, # "Type[Executor]" - log_stats: bool, - ) -> "EngineCoreClient": - return InprocClient(vllm_config, executor_class, log_stats) - - EngineCoreClient.make_client = Mock(side_effect=make_client) + # Create model with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + if isinstance(vllm_model.model.llm_engine, LLMEngineV1): + v1_test_failed_model_execution(vllm_model) + else: # V0 + v0_test_failed_model_execution(vllm_model) - engine = vllm_model.model.llm_engine - mocked_execute_model = Mock( - side_effect=RuntimeError("Mocked Critical Error")) - if isinstance(engine, LLMEngineV1): - is_v1 = True - engine.engine_core.engine_core.model_executor.execute_model =\ +def v0_test_failed_model_execution(vllm_model): + engine = vllm_model.model.llm_engine + mocked_execute_model = Mock( + side_effect=RuntimeError("Mocked 
Critical Error")) + engine.model_executor.driver_worker.model_runner.execute_model = \ mocked_execute_model - else: # V0 - is_v1 = False - engine.model_executor.driver_worker.model_runner.execute_model = \ + with pytest.raises(RuntimeError) as exc_info: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + vllm_model.generate_greedy(prompts, 200, use_tqdm=False) + assert isinstance(exc_info.value, ModelExecutionV0Error) + assert exc_info.value.model_input is not None + assert "Mocked Critical Error" in str(exc_info.value) + + +def v1_test_failed_model_execution(vllm_model): + + engine = vllm_model.model.llm_engine + mocked_execute_model = Mock( + side_effect=RuntimeError("Mocked Critical Error")) + engine.engine_core.engine_core.model_executor.execute_model =\ mocked_execute_model - with pytest.raises(RuntimeError) as exc_info: - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - vllm_model.generate_greedy(prompts, 200, use_tqdm=False) - if is_v1: - assert isinstance(exc_info.value, ModelExecutionError) - assert exc_info.value.scheduler_output is not None - else: - assert isinstance(exc_info.value, ModelExecutionV0Error) - assert exc_info.value.model_input is not None - assert "Mocked Critical Error" in str(exc_info.value) + with pytest.raises(RuntimeError) as exc_info: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + vllm_model.generate_greedy(prompts, 200, use_tqdm=False) + assert isinstance(exc_info.value, ModelExecutionError) + assert exc_info.value.scheduler_output is not None + assert "Mocked Critical Error" in str(exc_info.value) diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index 6705a63a17ad..ff2a4d39d27b 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -34,7 +34,7 @@ def prepare_object_to_dump(obj) -> str: return repr(obj) elif isinstance(obj, torch.Tensor): # We only print the 'draft' of the tensor to not expose sensitive data - # and to get some metadata in case of CUDA illegal memory access + # and to get some metadata in case of CUDA runtime crashed return (f"Tensor(shape={obj.shape}, " f"device={obj.device}," f"dtype={obj.dtype})") @@ -84,6 +84,8 @@ def dump_engine_exception_v0(err: BaseException, execute_model_req: Union[ExecuteModelRequest, None] = None): + logger.error("Dumping input data") + logger.error( "V0 LLM engine (v%s) with config: %s, " "use_cached_outputs=%s, ", diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 99b97ac8e6c4..c3a0dd101022 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -30,6 +30,8 @@ _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +VLLM_ENABLE_V1_MULTIPROCESSING = envs.VLLM_ENABLE_V1_MULTIPROCESSING + class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -104,7 +106,7 @@ def from_engine_args( vllm_config = engine_args.create_engine_config(usage_context) executor_class = Executor.get_class(vllm_config) - if envs.VLLM_ENABLE_V1_MULTIPROCESSING: + if VLLM_ENABLE_V1_MULTIPROCESSING: logger.debug("Enabling multiprocessing for LLMEngine.") enable_multiprocessing = True From 02e16736273b538413388640bed3056543a66877 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 7 Mar 2025 10:16:34 -0300 Subject: [PATCH 
10/18] refactoring Signed-off-by: Wallas Santos --- vllm/logging_utils/dump_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index ff2a4d39d27b..0a9c08283296 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -29,7 +29,7 @@ def prepare_object_to_dump(obj) -> str: return f'[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]' # return [prepare_object_to_dump(v) for v in list(obj)] elif isinstance(obj, tuple): - return [prepare_object_to_dump(v) for v in obj] + return f'[{', '.join([prepare_object_to_dump(v) for v in obj])}]' elif isinstance(obj, enum.Enum): return repr(obj) elif isinstance(obj, torch.Tensor): From f66489d776b36ca96c02f9976d35070611f19f09 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 7 Mar 2025 16:02:27 -0300 Subject: [PATCH 11/18] fix lint Signed-off-by: Wallas Santos --- vllm/logging_utils/dump_input.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index 0a9c08283296..eaa2ac8f864f 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -24,12 +24,12 @@ def prepare_object_to_dump(obj) -> str: for k, v in obj.items()}) return f'{{{dict_str}}}' elif isinstance(obj, list): - return f'[{', '.join([prepare_object_to_dump(v) for v in obj])}]' + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" elif isinstance(obj, set): - return f'[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]' + return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]" # return [prepare_object_to_dump(v) for v in list(obj)] elif isinstance(obj, tuple): - return f'[{', '.join([prepare_object_to_dump(v) for v in obj])}]' + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" elif isinstance(obj, enum.Enum): return repr(obj) elif isinstance(obj, torch.Tensor): From 5f8264836f407f4a9d8bdf787d4410705073da7a Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 13 Mar 2025 17:16:11 -0300 Subject: [PATCH 12/18] reverted change on llm_engine due to test Signed-off-by: Wallas Santos --- tests/basic_correctness/test_basic_correctness.py | 6 +++--- vllm/v1/engine/llm_engine.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 45147669e477..28fb0b1b87e7 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -5,7 +5,7 @@ """ import os import weakref -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest @@ -153,9 +153,9 @@ def test_models_distributed( ) -@patch('vllm.v1.engine.llm_engine.VLLM_ENABLE_V1_MULTIPROCESSING', False) -def test_failed_model_execution(vllm_runner) -> None: +def test_failed_model_execution(vllm_runner, monkeypatch) -> None: + monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') # Create model with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: if isinstance(vllm_model.model.llm_engine, LLMEngineV1): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 894e53145390..213faaa45160 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -30,8 +30,6 @@ _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -VLLM_ENABLE_V1_MULTIPROCESSING = 
envs.VLLM_ENABLE_V1_MULTIPROCESSING - class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -104,7 +102,7 @@ def from_engine_args( vllm_config = engine_args.create_engine_config(usage_context) executor_class = Executor.get_class(vllm_config) - if VLLM_ENABLE_V1_MULTIPROCESSING: + if envs.VLLM_ENABLE_V1_MULTIPROCESSING: logger.debug("Enabling multiprocessing for LLMEngine.") enable_multiprocessing = True From 93ae6dc5d7090ff67b5b05ee829c76e17a2aff2c Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 14 Mar 2025 11:27:57 -0300 Subject: [PATCH 13/18] fix: ensure suppress exception in dump Signed-off-by: Wallas Santos --- vllm/engine/llm_engine.py | 22 +++++++++++++--------- vllm/v1/engine/core.py | 15 ++++++++++----- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f3143d700d74..61274052e080 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,7 +4,7 @@ import time from collections import Counter as collectionsCounter from collections import deque -from contextlib import contextmanager +from contextlib import contextmanager, suppress from dataclasses import dataclass from functools import partial from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, @@ -1423,14 +1423,18 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Raise so the caller is notified that this request failed raise except BaseException as err: - stats = self._get_stats(scheduler_outputs=scheduler_outputs) - dump_engine_exception_v0( - err=err, - config=self.vllm_config, - use_cached_outputs=self.use_cached_outputs, - stats=stats, - execute_model_req=execute_model_req, - ) + # NOTE: ensure we can log extra info without risking raises + # raises unexpected errors during logging + with suppress(BaseException): + stats = self._get_stats( + scheduler_outputs=scheduler_outputs) + dump_engine_exception_v0( + err=err, + config=self.vllm_config, + use_cached_outputs=self.use_cached_outputs, + stats=stats, + execute_model_req=execute_model_req, + ) # Re-raise exception raise err diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2274dd448319..c10d4f56671a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import contextlib import queue import signal import threading @@ -200,11 +201,15 @@ def step(self) -> EngineCoreOutputs: try: output = self.model_executor.execute_model(scheduler_output) except BaseException as err: - err = ModelExecutionError( - f"Model execution failure," - f"reason: {repr(err)}", - scheduler_output=scheduler_output) - dump_engine_exception(err, self.vllm_config) + # NOTE: ensure we can log extra info without risking raises + # raises unexpected errors during logging + with contextlib.suppress(BaseException): + err = ModelExecutionError( + f"Model execution failure," + f"reason: {repr(err)}", + scheduler_output=scheduler_output) + dump_engine_exception(err, self.vllm_config) + # Re-raise exception raise err engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore From 33840559bf7758c858f656ff0a267f71ec98fd4d Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 8 Apr 2025 14:40:19 -0300 Subject: [PATCH 14/18] feat: removed v0 Signed-off-by: Wallas Santos --- .../test_basic_correctness.py | 22 +----- vllm/engine/llm_engine.py | 19 +---- vllm/logging_utils/dump_input.py | 65 ----------------- vllm/worker/worker_base.py | 70 
+++++--------------
 4 files changed, 18 insertions(+), 158 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 37b8dfeaa2f1..f2ef6b962742 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -13,7 +13,6 @@
 from vllm.platforms import current_platform
 from vllm.v1.engine.core import ModelExecutionError
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
-from vllm.worker.worker_base import ModelExecutionV0Error

 from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
@@ -172,26 +171,7 @@ def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
     if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
         v1_test_failed_model_execution(vllm_model)
     else:  # V0
-        v0_test_failed_model_execution(vllm_model)
-
-
-def v0_test_failed_model_execution(vllm_model):
-    engine = vllm_model.model.llm_engine
-    mocked_execute_model = Mock(
-        side_effect=RuntimeError("Mocked Critical Error"))
-    engine.model_executor.driver_worker.model_runner.execute_model = \
-        mocked_execute_model
-    with pytest.raises(RuntimeError) as exc_info:
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
-    assert isinstance(exc_info.value, ModelExecutionV0Error)
-    assert exc_info.value.model_input is not None
-    assert "Mocked Critical Error" in str(exc_info.value)
+        pytest.skip("Skipping V0 test")


 def v1_test_failed_model_execution(vllm_model):
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index de3144b831fb..f842581bf551 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -4,7 +4,7 @@
 import time
 from collections import Counter as collectionsCounter
 from collections import deque
-from contextlib import contextmanager, suppress
+from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
@@ -34,7 +34,6 @@
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
-from vllm.logging_utils.dump_input import dump_engine_exception_v0
 from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
@@ -1431,7 +1430,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 outputs = self.model_executor.execute_model(
                     execute_model_req=execute_model_req)
                 self._skip_scheduling_next_step = False
-
             except InputProcessingError as e:
                 # The input for this request cannot be processed, so we must
                 # abort it. If there are remaining requests in the batch that
@@ -1445,21 +1443,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                     allow_async_output_proc=allow_async_output_proc)
                 # Raise so the caller is notified that this request failed
                 raise
-            except BaseException as err:
-                # NOTE: ensure we can log extra info without risking raises
-                # raises unexpected errors during logging
-                with suppress(BaseException):
-                    stats = self._get_stats(
-                        scheduler_outputs=scheduler_outputs)
-                    dump_engine_exception_v0(
-                        err=err,
-                        config=self.vllm_config,
-                        use_cached_outputs=self.use_cached_outputs,
-                        stats=stats,
-                        execute_model_req=execute_model_req,
-                    )
-                # Re-raise exception
-                raise err

         # We need to do this here so that last step's sampled_token_ids can
         # be passed to the next iteration for PP.
diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py
index eaa2ac8f864f..a2ce79456d42 100644
--- a/vllm/logging_utils/dump_input.py
+++ b/vllm/logging_utils/dump_input.py
@@ -2,16 +2,12 @@

 import enum
 import json
-from typing import Union

 import torch

 from vllm.config import VllmConfig
-from vllm.engine.metrics import Stats
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest
 from vllm.version import __version__ as VLLM_VERSION
-from vllm.worker.worker_base import ModelExecutionV0Error

 logger = init_logger(__name__)
@@ -74,64 +70,3 @@ def dump_engine_exception(err: BaseException, config: VllmConfig):
     except BaseException as exception:
         logger.error("Error preparing object to dump")
         logger.error(repr(exception))
-
-
-# TODO: Remove this when V1 is default
-def dump_engine_exception_v0(err: BaseException,
-                             config: VllmConfig,
-                             stats: Union[Stats, None] = None,
-                             use_cached_outputs: Union[bool, None] = None,
-                             execute_model_req: Union[ExecuteModelRequest,
-                                                      None] = None):
-
-    logger.error("Dumping input data")
-
-    logger.error(
-        "V0 LLM engine (v%s) with config: %s, "
-        "use_cached_outputs=%s, ",
-        VLLM_VERSION,
-        config,
-        use_cached_outputs,
-    )
-
-    # For V0
-    if isinstance(err, ModelExecutionV0Error):
-        try:
-            dump_obj = prepare_object_to_dump(err.model_input)
-            logger.error("Dumping model input for execution:")
-            logger.error(dump_obj)
-        except BaseException as exception:
-            logger.error("Error preparing object to dump")
-            logger.error(repr(exception))
-
-    # In case we do not have a ModelExecutionV0Error, which is only present if
-    # the engine raise an error, we still can dump the information from the
-    # batch
-    if execute_model_req is not None:
-        batch = execute_model_req.seq_group_metadata_list
-        requests_count = len(batch)
-
-        requests_prompt_token_ids_lenghts = [{
-            k: len(v.prompt_token_ids)
-            for (k, v) in r.seq_data.items()
-        } for r in batch]
-
-        requests_ids = ', '.join([str(r.request_id) for r in batch])
-        logger.error(
-            "Batch info: requests_count=%s, "
-            "requests_prompt_token_ids_lenghts=(%s), "
-            "requests_ids=(%s)", requests_count,
-            requests_prompt_token_ids_lenghts, requests_ids)
-
-        for idx, r in enumerate(batch):
-            logger.error(
-                "Errored Batch request #%s: request_id=%s "
-                "prompt_token_ids_lengths=%s, "
-                "params=%s, "
-                "lora_request=%s, prompt_adapter_request=%s ", idx,
-                r.request_id, str(len(r.seq_data[idx].prompt_token_ids)),
-                r.sampling_params, r.lora_request, r.prompt_adapter_request)
-
-        if stats is not None:
-            logger.error("System stats:")
-            logger.error(stats)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 2f86123e3c26..e5662e69343c 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -208,33 +208,6 @@ def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")


-class ModelExecutionV0Error(RuntimeError):
-    """Custom RuntimeError with input data for model execution
-
-    In a nutshell, this object is useful for custom handling of exception for
-    the case the engine raises an error. For instance, it is used to log the
-    input metadata that is useful for debugging on engine crashes.
-
-    Args:
-        model_input: BroadcastableModelInput object that contains the input
-            data for model execution
-
-    """
-    model_input: BroadcastableModelInput
-
-    def __init__(self, *args, model_input=None):
-        super().__init__(*args)
-        self.model_input = model_input
-
-    def __reduce__(self):
-        # To avoid pickle errors.
-        # This happens when we exchange this object between processes.
-        # since model_input can have objects that only makes sense
-        # to their context/process we remove them from the serialization
-        # and only send the summary of the error as a regular RuntimeError.
-        return (self.__class__, (self.args[0], ))
-
-
 @dataclasses.dataclass(frozen=True)
 class WorkerInput:
     """Local inputs to each worker. May contain device-specific data. These
@@ -443,20 +416,15 @@ def execute_model(
                 and self.observability_config.collect_model_execute_time):
             orig_model_execute_time = intermediate_tensors.tensors.get(
                 "model_execute_time", torch.tensor(0)).item()
-        try:
-            output = self.model_runner.execute_model(
-                model_input=model_input,
-                kv_caches=self.kv_cache[worker_input.virtual_engine]
-                if self.kv_cache is not None else None,
-                intermediate_tensors=intermediate_tensors,
-                num_steps=num_steps,
-                **kwargs,
-            )
-        except BaseException as err:
-            raise ModelExecutionV0Error(
-                f"Model execution failure,"
-                f"reason: {repr(err)}",
-                model_input=model_input) from err
+
+        output = self.model_runner.execute_model(
+            model_input=model_input,
+            kv_caches=self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None,
+            intermediate_tensors=intermediate_tensors,
+            num_steps=num_steps,
+            **kwargs,
+        )

         model_execute_time = time.perf_counter() - start_time
         if not get_pp_group().is_last_rank:
@@ -506,19 +474,13 @@ def _execute_model_spmd(

         kwargs = extract_previous_hidden_states(execute_model_req)

-        try:
-            return self.model_runner.execute_model(
-                model_input=model_input,
-                kv_caches=self.kv_cache[worker_input.virtual_engine]
-                if self.kv_cache is not None else None,
-                intermediate_tensors=intermediate_tensors,
-                **kwargs,
-            )
-        except BaseException as err:
-            raise ModelExecutionV0Error(
-                f"Model execution failure,"
-                f"reason: {repr(err)}",
-                model_input=model_input) from err
+        return self.model_runner.execute_model(
+            model_input=model_input,
+            kv_caches=self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None,
+            intermediate_tensors=intermediate_tensors,
+            **kwargs,
+        )


 class WorkerWrapperBase:

From ad368f1e98179cd53c6ef5bcd3e7289e05613b48 Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Tue, 8 Apr 2025 16:10:32 -0300
Subject: [PATCH 15/18] removed v0 support

added logs for scheduler stats
minor fixes

Signed-off-by: Wallas Santos
---
 tests/basic_correctness/test_basic_correctness.py | 14 ++++++------
 vllm/logging_utils/dump_input.py                  | 11 ++++++-----
 vllm/v1/engine/core.py                            | 14 +++++++++-----
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index f2ef6b962742..9f3b0e8ae079 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -11,7 +11,6 @@
 from vllm import LLM
 from vllm.platforms import current_platform
-from vllm.v1.engine.core import ModelExecutionError
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

 from ..conftest import VllmRunner
@@ -165,13 +164,17 @@ def test_models_distributed(

 def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
+    from vllm.envs import VLLM_USE_V1
+
+    if not VLLM_USE_V1:
+        pytest.skip("Skipping V0 test, dump input not supported")
+
+    # Needed to mock an error in the same process
     monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
-    # Create model
+
     with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
         if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
             v1_test_failed_model_execution(vllm_model)
-        else:  # V0
-            pytest.skip("Skipping V0 test")


 def v1_test_failed_model_execution(vllm_model):
@@ -190,6 +193,5 @@ def v1_test_failed_model_execution(vllm_model):
             "The future of AI is",
         ]
         vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
-    assert isinstance(exc_info.value, ModelExecutionError)
-    assert exc_info.value.scheduler_output is not None
+    assert isinstance(exc_info.value, RuntimeError)
     assert "Mocked Critical Error" in str(exc_info.value)
diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py
index a2ce79456d42..e7c6217d80c3 100644
--- a/vllm/logging_utils/dump_input.py
+++ b/vllm/logging_utils/dump_input.py
@@ -59,14 +59,15 @@ def dump_engine_exception(err: BaseException, config: VllmConfig):
         config,
     )

-    # TODO: Have stats for V1
-
     from vllm.v1.engine.core import ModelExecutionError
     if isinstance(err, ModelExecutionError):
         try:
-            dump_obj = prepare_object_to_dump(err.scheduler_output)
-            logger.error("Dumping scheduler output for model execution:")
-            logger.error(dump_obj)
+            if err.scheduler_output is not None:
+                dump_obj = prepare_object_to_dump(err.scheduler_output)
+                logger.error("Dumping scheduler output for model execution:")
+                logger.error(dump_obj)
+            if err.scheduler_stats is not None:
+                logger.error(err.scheduler_stats)
         except BaseException as exception:
             logger.error("Error preparing object to dump")
             logger.error(repr(exception))
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 2bf204aa23cb..cfdd025a7273 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -36,6 +36,7 @@
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -62,10 +63,12 @@ class ModelExecutionError(RuntimeError):
     """
     scheduler_output: SchedulerOutput
+    scheduler_stats: SchedulerStats

-    def __init__(self, *args, scheduler_output=None):
+    def __init__(self, *args, scheduler_output=None, scheduler_stats=None):
         super().__init__(*args)
         self.scheduler_output = scheduler_output
+        self.scheduler_stats = scheduler_stats

     def __reduce__(self):
         # To avoid pickle errors.
@@ -237,13 +240,14 @@ def step(self) -> EngineCoreOutputs:
             output = self.model_executor.execute_model(scheduler_output)
         except BaseException as err:
             # NOTE: ensure we can log extra info without risking raises
-            # raises unexpected errors during logging
+            # unexpected errors during logging
             with contextlib.suppress(BaseException):
-                err = ModelExecutionError(
+                model_err = ModelExecutionError(
                     f"Model execution failure,"
                     f"reason: {repr(err)}",
-                    scheduler_output=scheduler_output)
-                dump_engine_exception(err, self.vllm_config)
+                    scheduler_output=scheduler_output,
+                    scheduler_stats=self.scheduler.make_stats())
+                dump_engine_exception(model_err, self.vllm_config)
             # Re-raise exception
             raise err

From 7c18e20ad4330369d6a0b453d974abdb346388a5 Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Thu, 10 Apr 2025 10:11:46 -0300
Subject: [PATCH 16/18] refactor: code cleanup

Signed-off-by: Wallas Santos
---
 vllm/logging_utils/dump_input.py | 37 +++++++++++++++++----------
 vllm/sequence.py                 |  8 ------
 vllm/v1/engine/core.py           | 43 +++-----------------------------
 3 files changed, 27 insertions(+), 61 deletions(-)

diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py
index e7c6217d80c3..169e24794095 100644
--- a/vllm/logging_utils/dump_input.py
+++ b/vllm/logging_utils/dump_input.py
@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
 import enum
 import json
+from typing import Optional

 import torch

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -49,8 +53,18 @@ def prepare_object_to_dump(obj) -> str:
         return repr(obj)


-def dump_engine_exception(err: BaseException, config: VllmConfig):
+def dump_engine_exception(config: VllmConfig,
+                          scheduler_output: SchedulerOutput,
+                          scheduler_stats: Optional[SchedulerStats]):
+    # NOTE: ensure we can log extra info without risking raises
+    # unexpected errors during logging
+    with contextlib.suppress(BaseException):
+        _dump_engine_exception(config, scheduler_output, scheduler_stats)
+
+def _dump_engine_exception(config: VllmConfig,
+                           scheduler_output: SchedulerOutput,
+                           scheduler_stats: Optional[SchedulerStats]):
     logger.error("Dumping input data")

     logger.error(
@@ -59,15 +73,12 @@ def dump_engine_exception(err: BaseException, config: VllmConfig):
         config,
     )

-    from vllm.v1.engine.core import ModelExecutionError
-    if isinstance(err, ModelExecutionError):
-        try:
-            if err.scheduler_output is not None:
-                dump_obj = prepare_object_to_dump(err.scheduler_output)
-                logger.error("Dumping scheduler output for model execution:")
-                logger.error(dump_obj)
-            if err.scheduler_stats is not None:
-                logger.error(err.scheduler_stats)
-        except BaseException as exception:
-            logger.error("Error preparing object to dump")
-            logger.error(repr(exception))
+    try:
+        dump_obj = prepare_object_to_dump(scheduler_output)
+        logger.error("Dumping scheduler output for model execution:")
+        logger.error(dump_obj)
+        if scheduler_stats:
+            logger.error(scheduler_stats)
+    except BaseException as exception:
+        logger.error("Error preparing object to dump")
+        logger.error(repr(exception))
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 824cd6b969c7..61867b025315 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -391,14 +391,6 @@ def __repr__(self) -> str:
             f"cumulative_logprob={self.cumulative_logprob}, "
             f"get_num_computed_tokens={self.get_num_computed_tokens()})")

-    # Version of __repr__ with the prompt data obfuscated
-    def anon_repr(self) -> str:
-        return (f"SequenceData("
-                f"prompt_token_ids_len={len(self._prompt_token_ids)}, "
-                f"output_token_ids_len={len(self.output_token_ids)}, "
-                f"cumulative_logprob={self.cumulative_logprob}, "
-                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
-

 class Sequence:
     """Stores the data, status, and block information of a sequence.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index cfdd025a7273..f499a5575678 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import contextlib
 import os
 import queue
 import signal
@@ -36,7 +35,6 @@
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -50,35 +48,6 @@
 _R = TypeVar('_R')  # Return type for collective_rpc

-class ModelExecutionError(RuntimeError):
-    """Custom RuntimeError with input data for model execution
-
-    In a nutshell, this object is useful for custom handling of exception for
-    the case the engine raises an error. For instance, it is used to log the
-    input metadata that is useful for debugging on engine crashes.
-
-    Args:
-        scheduler_output: SchedulerOutput object that contains the input
-            data for model execution
-
-    """
-    scheduler_output: SchedulerOutput
-    scheduler_stats: SchedulerStats
-
-    def __init__(self, *args, scheduler_output=None, scheduler_stats=None):
-        super().__init__(*args)
-        self.scheduler_output = scheduler_output
-        self.scheduler_stats = scheduler_stats
-
-    def __reduce__(self):
-        # To avoid pickle errors.
-        # This happens when we exchange this object between processes
-        # since scheduler_output can have objects that only makes sense
-        # to their context/process we remove them from the serialization
-        # and only send the summary of the error as a regular RuntimeError.
-        return (self.__class__, (self.args[0], ))
-
-
 class EngineCore:
     """Inner loop of vLLM's Engine."""

@@ -239,15 +208,9 @@ def step(self) -> EngineCoreOutputs:
         try:
             output = self.model_executor.execute_model(scheduler_output)
         except BaseException as err:
-            # NOTE: ensure we can log extra info without risking raises
-            # unexpected errors during logging
-            with contextlib.suppress(BaseException):
-                model_err = ModelExecutionError(
-                    f"Model execution failure,"
-                    f"reason: {repr(err)}",
-                    scheduler_output=scheduler_output,
-                    scheduler_stats=self.scheduler.make_stats())
-                dump_engine_exception(model_err, self.vllm_config)
+            # NOTE: This method is exception-free
+            dump_engine_exception(self.vllm_config, scheduler_output,
+                                  self.scheduler.make_stats())
             # Re-raise exception
             raise err

From 8cbee30988390ef582ac9108a0879e64f15f37bc Mon Sep 17 00:00:00 2001
From: Wallas Santos
Date: Thu, 1 May 2025 11:03:34 -0300
Subject: [PATCH 17/18] refactor: move execute_model to a separate method

Signed-off-by: Wallas Santos
---
 vllm/v1/core/sched/output.py |  2 --
 vllm/v1/engine/core.py       | 23 ++++++++++++++---------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index c98b64dbccc6..24032498e50b 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -52,7 +52,6 @@ def __repr__(self):
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
-                f"prompt={self.prompt},"
                 f"mm_inputs={self.mm_inputs},"
                 f"mm_hashes={self.mm_hashes},"
                 f"mm_positions={self.mm_positions},"
@@ -67,7 +66,6 @@ def anon_repr(self):
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids_len={len(self.prompt_token_ids)},"
-                f"prompt='',"
                 f"mm_inputs={self.mm_inputs},"
                 f"mm_hashes={self.mm_hashes},"
                 f"mm_positions={self.mm_positions},"
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index c81a7d640d12..a7fa614696a7 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -193,6 +193,17 @@ def abort_requests(self, request_ids: list[str]):
         self.scheduler.finish_requests(request_ids,
                                        RequestStatus.FINISHED_ABORTED)

+    def execute_model(self, scheduler_output: SchedulerOutput):
+        try:
+            output = self.model_executor.execute_model(scheduler_output)
+        except BaseException as err:
+            # NOTE: This method is exception-free
+            dump_engine_exception(self.vllm_config, scheduler_output,
+                                  self.scheduler.make_stats())
+            # Re-raise exception
+            raise err
+        return output
+
     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""

@@ -204,17 +215,11 @@ def step(self) -> EngineCoreOutputs:
                 scheduler_stats=self.scheduler.make_stats(),
             )
         scheduler_output = self.scheduler.schedule()
-        try:
-            output = self.model_executor.execute_model(scheduler_output)
-        except BaseException as err:
-            # NOTE: This method is exception-free
-            dump_engine_exception(self.vllm_config, scheduler_output,
-                                  self.scheduler.make_stats())
-            # Re-raise exception
-            raise err
+
+        model_output = self.execute_model(scheduler_output)

         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)  # type: ignore
+            scheduler_output, model_output)  # type: ignore

         return engine_core_outputs

From 51596e4bfcfb87063a67a616089a815674fc5395 Mon Sep 17 00:00:00 2001
From: Wallas Henrique
Date: Wed, 7 May 2025 16:09:15 -0300
Subject: [PATCH 18/18] Update vllm/v1/engine/core.py

Co-authored-by: Nick Hill
Signed-off-by: Wallas Santos
---
 vllm/v1/engine/core.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a7fa614696a7..d9dd4957cff2 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -195,14 +195,13 @@ def abort_requests(self, request_ids: list[str]):

     def execute_model(self, scheduler_output: SchedulerOutput):
         try:
-            output = self.model_executor.execute_model(scheduler_output)
+            return self.model_executor.execute_model(scheduler_output)
         except BaseException as err:
             # NOTE: This method is exception-free
             dump_engine_exception(self.vllm_config, scheduler_output,
                                   self.scheduler.make_stats())
             # Re-raise exception
             raise err
-        return output

     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""
@@ -215,9 +214,7 @@ def step(self) -> EngineCoreOutputs:
                 scheduler_stats=self.scheduler.make_stats(),
             )
         scheduler_output = self.scheduler.schedule()
-
         model_output = self.execute_model(scheduler_output)
-
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, model_output)  # type: ignore
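
For readers following the series, the net effect of patches 15-18 on the V1 engine core is easiest to see in one place. The sketch below condenses the final state of the diffs above; it is illustrative only (the early-return path in step() and other EngineCore internals are elided, and the surrounding class wiring is assumed):

    # Condensed, illustrative sketch of the dump-on-crash path after this
    # series; not the literal file contents.
    from vllm.logging_utils.dump_input import dump_engine_exception

    class EngineCore:

        def execute_model(self, scheduler_output):
            try:
                return self.model_executor.execute_model(scheduler_output)
            except BaseException:
                # dump_engine_exception() suppresses its own errors
                # (contextlib.suppress(BaseException) inside dump_input.py),
                # so logging can never mask the original failure, which is
                # re-raised untouched below.
                dump_engine_exception(self.vllm_config, scheduler_output,
                                      self.scheduler.make_stats())
                raise

        def step(self):
            scheduler_output = self.scheduler.schedule()
            model_output = self.execute_model(scheduler_output)
            return self.scheduler.update_from_output(scheduler_output,
                                                     model_output)

Because the dump now happens inside execute_model() rather than inside a pickled exception type, nothing request-specific has to cross a process boundary, which is what made the earlier ModelExecutionError.__reduce__ workaround necessary in the first place.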
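
One design detail worth noting: patch 17 removes the prompt field from both NewRequestData.__repr__ and anon_repr, and anon_repr reports only prompt_token_ids_len, so an input dump can describe a request without reproducing its prompt. The pattern, reduced to a hypothetical standalone class (illustrative, not vLLM code):

    from dataclasses import dataclass

    @dataclass
    class RequestRecord:  # hypothetical stand-in for NewRequestData
        req_id: str
        prompt_token_ids: list[int]

        def anon_repr(self) -> str:
            # Record how long the prompt is, never the tokens themselves.
            return (f"RequestRecord(req_id={self.req_id}, "
                    f"prompt_token_ids_len={len(self.prompt_token_ids)})")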