From 450ef45f45bd9c60101a287536a0a7ea5f402634 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 9 May 2026 15:29:33 +0800 Subject: [PATCH 1/4] fix vllm dp & reset_encoder_cache & fix vllm init with zero3 --- swift/infer_engine/vllm_engine.py | 6 ++-- swift/megatron/trainers/rollout_mixin.py | 12 ++++++- swift/pipelines/infer/rollout.py | 44 ++++++++++++++++++++++-- swift/rlhf_trainers/rollout_mixin.py | 12 ++++++- swift/rlhf_trainers/vllm_client.py | 38 ++++++++++++++++++++ 5 files changed, 105 insertions(+), 7 deletions(-) diff --git a/swift/infer_engine/vllm_engine.py b/swift/infer_engine/vllm_engine.py index f1f55ed7a7..8800e4eb4f 100644 --- a/swift/infer_engine/vllm_engine.py +++ b/swift/infer_engine/vllm_engine.py @@ -15,7 +15,8 @@ from swift.metrics import Metric from swift.model import get_processor from swift.template import Template -from swift.utils import get_device, get_dist_setting, get_logger, is_dist, safe_snapshot_download +from swift.utils import (disable_deepspeed_zero3, get_device, get_dist_setting, get_logger, is_dist, + safe_snapshot_download) from .infer_engine import InferEngine from .patch import patch_auto_tokenizer from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, @@ -180,7 +181,8 @@ def _get_processor(self): task_type=self.task_type) def _prepare_engine(self) -> None: - with patch_auto_tokenizer(self.tokenizer), self._patch_auto_config(): + with patch_auto_tokenizer(self.tokenizer), self._patch_auto_config(), \ + disable_deepspeed_zero3(): llm_engine_cls = AsyncLLMEngine if self.use_async_engine else LLMEngine engine = llm_engine_cls.from_engine_args(self.engine_args) self.engine = engine diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index 0fe482bb89..9a3031e83b 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -221,6 +221,7 @@ def _init_rollout_engine(self): # Server mode uses external vLLM server if self.is_main_process: self.vllm_client.get_engine_type() + self.vllm_client.reset_mm_cache() enable_lora = [self.vllm_client.enable_lora] else: enable_lora = [False] @@ -236,6 +237,7 @@ def _init_rollout_engine(self): with context(): set_expandable_segments(False) self.engine = self._prepare_vllm_engine() + self.engine.engine.reset_mm_cache() if args.sleep_level > 0: self.engine.engine.sleep(args.sleep_level) set_expandable_segments(True) @@ -322,11 +324,19 @@ def _move_model_to_vllm(self): else: self._move_adapter_to_vllm() - # Reset prefix cache + self._reset_vllm_cache() + + def _reset_vllm_cache(self): + # Reset prefix cache and encoder cache(vllm>=0.17) + vllm_ge_16 = check_vllm_version_ge('0.16') if self.vllm_mode == 'server' and self.is_main_process: self.vllm_client.reset_prefix_cache() + if vllm_ge_16: + self.vllm_client.reset_encoder_cache() elif self.vllm_mode == 'colocate': self.engine.engine.reset_prefix_cache() + if vllm_ge_16: + self.engine.engine.reset_encoder_cache() def _move_full_model_to_vllm(self): """Transfer full model weights to vLLM engine. diff --git a/swift/pipelines/infer/rollout.py b/swift/pipelines/infer/rollout.py index 75932190a0..c960b4731d 100644 --- a/swift/pipelines/infer/rollout.py +++ b/swift/pipelines/infer/rollout.py @@ -32,6 +32,7 @@ from itertools import chain from multiprocessing import Pipe, Process from multiprocessing.connection import Connection +from transformers.utils import is_torch_npu_available from typing import Dict, List, Optional, Union from swift.arguments import RolloutArguments @@ -302,12 +303,31 @@ def get_rollout_engine_type(args: RolloutArguments, engine: GRPOVllmEngine): return rollout_engine +def _set_visible_devices_for_dp_rank(data_parallel_rank: int, tensor_parallel_size: int): + + def _get_device_env_var(): + if is_torch_npu_available(): + return 'ASCEND_RT_VISIBLE_DEVICES' + return 'CUDA_VISIBLE_DEVICES' + + env_var = _get_device_env_var() + current = os.environ.get(env_var) + if current: + all_devices = current.split(',') + else: + from swift.utils import get_device_count + all_devices = [str(i) for i in range(get_device_count())] + + start = data_parallel_rank * tensor_parallel_size + end = start + tensor_parallel_size + selected = all_devices[start:end] + os.environ[env_var] = ','.join(selected) + + def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None: try: args._import_external_plugins() - os.environ['VLLM_DP_RANK'] = str(data_parallel_rank) - os.environ['VLLM_DP_RANK_LOCAL'] = str(data_parallel_rank) - os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size) + _set_visible_devices_for_dp_rank(data_parallel_rank, args.vllm_tensor_parallel_size) os.environ['VLLM_DP_MASTER_PORT'] = str(master_port) worker_seed = get_seed() engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(), seed=worker_seed) @@ -396,6 +416,8 @@ def _register_rl_rollout_app(self): self.app.post('/update_adapter_param/')(self.update_adapter_param) self.app.post('/update_flattened_params/')(self.update_flattened_params) self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache) + self.app.post('/reset_encoder_cache/')(self.reset_encoder_cache) + self.app.post('/reset_mm_cache/')(self.reset_mm_cache) self.app.post('/close_communicator/')(self.close_communicator) self.app.post('/infer/', response_model=None)(self.infer) self.app.post('/get_engine_type/')(self.get_engine_type) @@ -632,6 +654,22 @@ async def reset_prefix_cache(self): success = all(output for output in all_outputs) return {'message': 'Request received, resetting prefix cache status: ' + str(success)} + async def reset_encoder_cache(self): + """Resets the encoder cache (vision encoder embeddings) for the model.""" + for connection in self.connections: + connection.send({'type': 'call', 'method': 'reset_encoder_cache'}) + all_outputs = [connection.recv() for connection in self.connections] + success = all(output for output in all_outputs) + return {'message': 'Request received, resetting encoder cache status: ' + str(success)} + + async def reset_mm_cache(self): + """Resets the multimodal processor cache for the model.""" + for connection in self.connections: + connection.send({'type': 'call', 'method': 'reset_mm_cache'}) + all_outputs = [connection.recv() for connection in self.connections] + success = all(output for output in all_outputs) + return {'message': 'Request received, resetting mm cache status: ' + str(success)} + async def get_engine_type(self): """ Return a dictionary describing the runtime engine configuration. diff --git a/swift/rlhf_trainers/rollout_mixin.py b/swift/rlhf_trainers/rollout_mixin.py index 1bed826a6c..ecc41c6640 100644 --- a/swift/rlhf_trainers/rollout_mixin.py +++ b/swift/rlhf_trainers/rollout_mixin.py @@ -217,6 +217,7 @@ def _prepare_vllm(self): if self.vllm_mode == 'server': if self.accelerator.is_main_process: self.vllm_client.get_engine_type() + self.vllm_client.reset_mm_cache() vllm_use_async_engine = [self.vllm_client.use_async_engine] use_gym_env = [self.vllm_client.use_gym_env] enable_multi_turn = [self.vllm_client.enable_multi_turn] @@ -246,6 +247,7 @@ def _prepare_vllm(self): with context(): self.engine = self._prepare_vllm_engine() + self.engine.engine.reset_mm_cache() if args.sleep_level > 0: self.engine.engine.sleep(args.sleep_level) self.dynamic_num_samples = False # grpo multi-turn @@ -451,11 +453,19 @@ def _move_model_to_vllm(self, skip_async_check=False): else: self._move_adapter_to_vllm() - # Reset prefix cache + self._reset_vllm_cache() + + def _reset_vllm_cache(self): + # Reset prefix cache and encoder cache after weight update + vllm_ge_16 = check_vllm_version_ge('0.16') if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() + if vllm_ge_16: + self.vllm_client.reset_encoder_cache() elif self.vllm_mode == 'colocate': self.engine.engine.reset_prefix_cache() + if vllm_ge_16: + self.engine.engine.reset_encoder_cache() def _move_adapter_to_vllm(self): """Transfer LoRA adapter weights to vLLM engine""" diff --git a/swift/rlhf_trainers/vllm_client.py b/swift/rlhf_trainers/vllm_client.py index e529ca2ca3..cc83e8432e 100644 --- a/swift/rlhf_trainers/vllm_client.py +++ b/swift/rlhf_trainers/vllm_client.py @@ -429,6 +429,44 @@ def _reset_single_server(i): if all_errors: raise RuntimeError(f'Multiple errors on reset_prefix_cache: {all_errors}') + def reset_encoder_cache(self): + errors = [None] * self.num_servers + + def _reset_single_server(i): + try: + response = self.sessions[i].post(f'{self.base_urls[i]}/reset_encoder_cache/') + if response.status_code != 200: + raise Exception(f'Server {i} reset failed: {response.text}') + except Exception as e: + errors[i] = e + + with ThreadPoolExecutor(max_workers=self.num_servers) as executor: + futures = [executor.submit(_reset_single_server, i) for i in range(self.num_servers)] + for future in futures: + future.result() + all_errors = [e for e in errors if e is not None] + if all_errors: + raise RuntimeError(f'Multiple errors on reset_encoder_cache: {all_errors}') + + def reset_mm_cache(self): + errors = [None] * self.num_servers + + def _reset_single_server(i): + try: + response = self.sessions[i].post(f'{self.base_urls[i]}/reset_mm_cache/') + if response.status_code != 200: + raise Exception(f'Server {i} reset failed: {response.text}') + except Exception as e: + errors[i] = e + + with ThreadPoolExecutor(max_workers=self.num_servers) as executor: + futures = [executor.submit(_reset_single_server, i) for i in range(self.num_servers)] + for future in futures: + future.result() + all_errors = [e for e in errors if e is not None] + if all_errors: + raise RuntimeError(f'Multiple errors on reset_mm_cache: {all_errors}') + def get_engine_type(self): # assume that all server has same engine type response = self.sessions[0].post(f'{self.base_urls[0]}/get_engine_type/') From 3c677d18ff8a5c1c0ace655bd4ee37bf8e54919a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 9 May 2026 15:36:37 +0800 Subject: [PATCH 2/4] fix comment --- swift/megatron/trainers/rollout_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index 9a3031e83b..41072b6f75 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -327,7 +327,7 @@ def _move_model_to_vllm(self): self._reset_vllm_cache() def _reset_vllm_cache(self): - # Reset prefix cache and encoder cache(vllm>=0.17) + # Reset prefix cache and encoder cache vllm_ge_16 = check_vllm_version_ge('0.16') if self.vllm_mode == 'server' and self.is_main_process: self.vllm_client.reset_prefix_cache() From 55f5c35adc0718f53896d3f9d727bb899a8927a4 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sat, 9 May 2026 16:39:05 +0800 Subject: [PATCH 3/4] fix dp --- swift/pipelines/infer/rollout.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/swift/pipelines/infer/rollout.py b/swift/pipelines/infer/rollout.py index c960b4731d..b6eeaabab6 100644 --- a/swift/pipelines/infer/rollout.py +++ b/swift/pipelines/infer/rollout.py @@ -108,8 +108,12 @@ def init_communicator(self, host: str, port: int, world_size: int) -> None: if self.communicator is not None: raise RuntimeError('Weight update group already initialized. Call close_communicator first.') - # Get the rank of the current worker in the global world group. - rank = get_world_group().rank + # When using independent vLLM instances for DP (each with its own world group), + # offset the rank by dp_rank * tp_size so each DP worker gets a unique rank + # in the communicator's process group. + dp_rank = int(os.environ.get('SWIFT_ROLLOUT_DP_RANK', '0')) + tp_size = int(os.environ.get('SWIFT_ROLLOUT_TP_RANK', '1')) + rank = get_world_group().rank + dp_rank * tp_size # Create a stateless process group to manage communication between training processes and vLLM workers. # Initialize the NCCL-based communicator for weight synchronization. @@ -329,6 +333,8 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int args._import_external_plugins() _set_visible_devices_for_dp_rank(data_parallel_rank, args.vllm_tensor_parallel_size) os.environ['VLLM_DP_MASTER_PORT'] = str(master_port) + os.environ['SWIFT_ROLLOUT_DP_RANK'] = str(data_parallel_rank) + os.environ['SWIFT_ROLLOUT_TP_RANK'] = str(args.vllm_tensor_parallel_size) worker_seed = get_seed() engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(), seed=worker_seed) rollout_engine = get_rollout_engine_type(args, engine) @@ -365,6 +371,8 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast connection: Connection) -> None: try: args._import_external_plugins() + os.environ['SWIFT_ROLLOUT_DP_RANK'] = str(data_parallel_rank) + os.environ['SWIFT_ROLLOUT_TP_RANK'] = str(args.vllm_tensor_parallel_size) worker_seed = get_seed() engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(), seed=worker_seed) rollout_engine = get_rollout_engine_type(args, engine) From e6d824f2432ef04c6290176a7ff8fab9eb5aa19f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 May 2026 12:41:46 +0800 Subject: [PATCH 4/4] fix async dp timeout --- swift/infer_engine/vllm_engine.py | 39 +++++++++++++++++++++++++++++++ swift/pipelines/infer/rollout.py | 15 +++++++----- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/swift/infer_engine/vllm_engine.py b/swift/infer_engine/vllm_engine.py index 8800e4eb4f..5102374e1b 100644 --- a/swift/infer_engine/vllm_engine.py +++ b/swift/infer_engine/vllm_engine.py @@ -1,7 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio import inspect +import multiprocessing import os +import time import torch from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -49,6 +51,43 @@ dtype_mapping = {torch.float16: 'float16', torch.bfloat16: 'bfloat16', torch.float32: 'float32'} +def _patch_vllm_dp_coordinator_timeout(): + # https://github.com/vllm-project/vllm/pull/37452 introduced a 30-second default timeout, + # which is prone to timing out in spawn scenarios. Patch it to 180 seconds here. + try: + from vllm.v1.engine import coordinator as coordinator_module + except ImportError: + return + + coordinator_cls = coordinator_module.DPCoordinator + if not hasattr(coordinator_cls, '_wait_for_zmq_addrs'): + return + + if getattr(coordinator_cls, '_swift_timeout_patched', False): + return + + def _wait_for_zmq_addrs(self, zmq_addr_pipe): + t0 = time.monotonic() + try: + ready = multiprocessing.connection.wait([zmq_addr_pipe, self.proc.sentinel], timeout=180) + elapsed = time.monotonic() - t0 + if not ready: + raise RuntimeError(f'DP Coordinator process failed to report ZMQ addresses ' + f'within 180s (elapsed={elapsed:.1f}s).') + try: + return zmq_addr_pipe.recv() + except EOFError: + raise RuntimeError('DP Coordinator process failed during startup.') from None + finally: + zmq_addr_pipe.close() + + coordinator_cls._wait_for_zmq_addrs = _wait_for_zmq_addrs + coordinator_cls._swift_timeout_patched = True + + +_patch_vllm_dp_coordinator_timeout() + + class VllmEngine(InferEngine): def __init__( diff --git a/swift/pipelines/infer/rollout.py b/swift/pipelines/infer/rollout.py index b6eeaabab6..bdb10d183a 100644 --- a/swift/pipelines/infer/rollout.py +++ b/swift/pipelines/infer/rollout.py @@ -106,13 +106,16 @@ def init_communicator(self, host: str, port: int, world_size: int) -> None: Total number of participating processes in the update group. """ if self.communicator is not None: - raise RuntimeError('Weight update group already initialized. Call close_communicator first.') + return - # When using independent vLLM instances for DP (each with its own world group), - # offset the rank by dp_rank * tp_size so each DP worker gets a unique rank - # in the communicator's process group. - dp_rank = int(os.environ.get('SWIFT_ROLLOUT_DP_RANK', '0')) - tp_size = int(os.environ.get('SWIFT_ROLLOUT_TP_RANK', '1')) + parallel_config = getattr(getattr(self, 'vllm_config', None), 'parallel_config', None) + dp_index = int(getattr(parallel_config, 'data_parallel_index', 0)) if parallel_config is not None else 0 + if dp_index > 0: + dp_rank = dp_index + tp_size = int(parallel_config.tensor_parallel_size) + else: + dp_rank = int(os.environ.get('SWIFT_ROLLOUT_DP_RANK', '0')) + tp_size = int(os.environ.get('SWIFT_ROLLOUT_TP_RANK', '1')) rank = get_world_group().rank + dp_rank * tp_size # Create a stateless process group to manage communication between training processes and vLLM workers.