diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ff0ff8f5cf0..1345590c346 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2988,10 +2988,16 @@ def _check_and_update_cudagraph_mode( set_draft_graph_params(self.cudagraph_batch_sizes) def capture_model(self) -> None: - parent_module_name = self.__class__.__base__.__module__ + gpu_model_runner_cls = next((cls for cls in self.__class__.__mro__ + if cls.__name__ == "GPUModelRunner"), + None) + if gpu_model_runner_cls is None: + raise TypeError("Could not find GPUModelRunner in the MRO. " + "The class hierarchy may have changed.") + parent_module_name = gpu_model_runner_cls.__module__ with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper( parent_module_name): - super().capture_model() + GPUModelRunner.capture_model(self) def _prepare_multimodal_fields(self): """