diff --git a/tpu_inference/core/disagg_executor.py b/tpu_inference/core/disagg_executor.py index 73e46ab6fb..84e44cb77e 100644 --- a/tpu_inference/core/disagg_executor.py +++ b/tpu_inference/core/disagg_executor.py @@ -21,8 +21,7 @@ class DisaggExecutor(Executor): def _init_executor(self) -> None: """Initialize the worker and load the model. """ - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) + self.driver_worker = WorkerWrapperBase(rpc_rank=0) slice_config = getattr(self.vllm_config.device_config, "slice") idx = slice_config[0] jax_devices = slice_config[-1]