diff --git a/docs/training/weight_transfer/nccl.md b/docs/training/weight_transfer/nccl.md index 7b531218568b..bfde1ee2ae37 100644 --- a/docs/training/weight_transfer/nccl.md +++ b/docs/training/weight_transfer/nccl.md @@ -84,10 +84,7 @@ Both the trainer (`NCCLTrainerSendWeightsArgs`) and inference side (`NCCLWeightT ## Receiving Weights (Inference Side) -The inference side triggers weight reception using the four-phase protocol: -`init_weight_transfer_engine`, `start_weight_update`, `update_weights`, -`finish_weight_update`. The init phase is shown [above](#initialization). The -remaining three steps are: +The inference side triggers weight reception using the four-phase protocol — `init_weight_transfer_engine`, `start_weight_update`, `update_weights`, `finish_weight_update`. The init phase is shown [above](#initialization). The remaining three steps are: ```python from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest @@ -111,24 +108,12 @@ llm.update_weights( llm.finish_weight_update() ``` -The `names`, `dtype_names`, and `shapes` lists describe each parameter. These -must match the order in which the trainer iterates over its parameters. +The `names`, `dtype_names`, and `shapes` lists describe each parameter. These must match the order in which the trainer iterates over its parameters. -`start_weight_update` must be called before `update_weights`, and -`finish_weight_update` must be called after all weight chunks have been -transferred. The `is_checkpoint_format` flag controls whether layerwise reload -processing is applied (`True` for checkpoint-format weights, `False` for -pre-processed kernel-format weights). - -Sparse NCCL patches still use `update_kind="sparse_flat"` inside -`update_info`, but they should be wrapped in -`start_weight_update(is_checkpoint_format=False)` because sparse patches apply -directly to runtime/kernel-format parameters. The current sparse MVP requires -`TP=1` and `PP=1`. +`start_weight_update` must be called before `update_weights`, and `finish_weight_update` must be called after all weight chunks have been transferred. The `is_checkpoint_format` flag controls whether layerwise reload processing is applied (`True` for checkpoint-format weights, `False` for pre-processed kernel-format weights). ## Examples - [RLHF with NCCL weight syncing (offline, Ray)](../../../examples/rl/rlhf_nccl.py) - Trainer on one GPU, 2x tensor-parallel vLLM engine on two others, with packed NCCL weight broadcast -- [RLHF with sparse NCCL weight syncing (offline, Ray)](../../../examples/rl/rlhf_sparse_nccl.py) - Dense-vs-sparse equivalence demo with a real model on a 2-GPU trainer/inference setup; sparse patches use `start_weight_update(is_checkpoint_format=False)` and currently require `TP=1` and `PP=1` - [RLHF with async weight syncing (offline, Ray)](../../../examples/rl/rlhf_async_new_apis.py) - Async generation with mid-flight pause, weight sync, resume, and validation against a fresh model - [RLHF with NCCL weight syncing (online serving, HTTP)](../../../examples/rl/rlhf_http_nccl.py) - Weight transfer with a running vLLM HTTP server using HTTP control plane and NCCL data plane diff --git a/examples/rl/rlhf_sparse_nccl.py b/examples/rl/rlhf_sparse_nccl.py deleted file mode 100644 index bddd28b6485e..000000000000 --- a/examples/rl/rlhf_sparse_nccl.py +++ /dev/null @@ -1,526 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrates dense-vs-sparse NCCL weight syncing with a real model. - -This example mirrors the validation story used for the sparse NCCL MVP: -both the dense update path and the sparse patch path start from the same real -checkpoint and apply the same deterministic trainer-side patch. The script then -checks that greedy 1-token outputs match between the dense and sparse vLLM -engines after the update. - -The example performs the following steps: -* Load a training model on one GPU via a Ray actor. -* Launch a vLLM engine with the same real model on a second GPU. -* Verify trainer vs vLLM baseline agreement before any update. -* Apply a deterministic patch to ``model.embed_tokens.weight`` on the trainer. -* Run a dense NCCL update into a fresh vLLM engine and collect post-update - outputs. -* Reset the trainer back to the baseline checkpoint. -* Apply the same deterministic patch again. -* Run a sparse NCCL update into another fresh vLLM engine and collect - post-update outputs. -* Compare dense vs sparse baseline outputs, dense vs sparse post-update - outputs, estimated payload sizes, and trainer-side send times. - -Current sparse weight transfer MVP limitations: -* ``TP=1`` and ``PP=1`` only -* sparse updates use runtime/kernel-format parameter names -* sparse updates are not composable with checkpoint-format or packed updates - -This example assumes a single-node cluster with two GPUs. -""" - -import hashlib -import os -import time -from collections.abc import Sequence - -import ray -import torch -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from transformers import AutoModelForCausalLM, AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.config import WeightTransferConfig -from vllm.distributed.weight_transfer.base import SparseWeightPatch -from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLTrainerSendWeightsArgs, - NCCLWeightTransferEngine, -) -from vllm.utils.network_utils import get_ip, get_open_port - -MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" -PATCHED_PARAM_NAME = "model.embed_tokens.weight" -MAX_PATCH_ROWS = 32 -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -SAMPLING_PARAMS = SamplingParams(temperature=0.0, max_tokens=1) - - -class MyLLM(LLM): - """Configure the vLLM worker for Ray placement group execution.""" - - def __init__(self, *args, **kwargs): - os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0" - super().__init__(*args, **kwargs) - - -@ray.remote(num_gpus=1) -class TrainModel: - """Ray actor that owns the trainer-side model and deterministic patch state.""" - - def __init__(self, model_name: str): - self.model_name = model_name - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.model = None - self.patched_param = None - self.pending_sparse_patches: list[SparseWeightPatch] | None = None - self.model_update_group = None - self.master_address = get_ip() - self.port = get_open_port() - self.reset_model() - - def reset_model(self) -> None: - self.model = AutoModelForCausalLM.from_pretrained( - self.model_name, - torch_dtype=torch.bfloat16, - ).to("cuda:0") - self.model.eval() - - try: - self.patched_param = self.model.get_parameter(PATCHED_PARAM_NAME) - except AttributeError as exc: - raise RuntimeError( - f"Expected trainer model to expose `{PATCHED_PARAM_NAME}`" - ) from exc - - self.pending_sparse_patches = None - - def create_rendezvous(self) -> tuple[str, int]: - self.port = get_open_port() - return self.master_address, self.port - - def init_weight_transfer_group(self, world_size: int) -> None: - self.model_update_group = NCCLWeightTransferEngine.trainer_init( - dict( - master_address=self.master_address, - master_port=self.port, - world_size=world_size, - ) - ) - - def get_dense_update_info(self, packed: bool = False) -> tuple[dict, int]: - names = [] - dtype_names = [] - shapes = [] - payload_bytes = 0 - for name, param in self.model.named_parameters(): - names.append(name) - dtype_names.append(str(param.dtype).split(".")[-1]) - shapes.append(list(param.shape)) - payload_bytes += param.numel() * param.element_size() - - return ( - dict( - names=names, - dtype_names=dtype_names, - shapes=shapes, - packed=packed, - ), - payload_bytes, - ) - - @torch.inference_mode() - def generate( - self, - prompts: Sequence[str], - max_new_tokens: int = 1, - ) -> list[dict[str, object]]: - generations = [] - for prompt in prompts: - model_inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0") - output = self.model.generate( - **model_inputs, - max_new_tokens=max_new_tokens, - do_sample=False, - pad_token_id=self.tokenizer.pad_token_id, - ) - new_token_ids = output[0, model_inputs["input_ids"].shape[1] :].tolist() - generations.append( - { - "token_ids": new_token_ids, - "text": self.tokenizer.decode( - new_token_ids, - skip_special_tokens=False, - ), - } - ) - return generations - - def prepare_sparse_patch( - self, - prompts: Sequence[str], - max_patch_rows: int = MAX_PATCH_ROWS, - ) -> tuple[dict[str, object], list[int], str, int]: - selected_token_ids: list[int] = [] - special_ids = set(self.tokenizer.all_special_ids) - for prompt in prompts: - token_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] - for token_id in token_ids: - if token_id in special_ids or token_id in selected_token_ids: - continue - selected_token_ids.append(token_id) - if len(selected_token_ids) == max_patch_rows: - break - if len(selected_token_ids) == max_patch_rows: - break - - if not selected_token_ids: - raise ValueError("Could not derive any non-special token IDs to patch") - - vocab_size = self.patched_param.shape[0] - next_token_id = selected_token_ids[-1] - while len(selected_token_ids) < max_patch_rows: - next_token_id = (next_token_id + 1) % vocab_size - if next_token_id in special_ids or next_token_id in selected_token_ids: - continue - selected_token_ids.append(next_token_id) - - row_ids = torch.tensor( - selected_token_ids, - device=self.patched_param.device, - dtype=torch.long, - ) - hidden_size = self.patched_param.shape[1] - column_offsets = torch.arange( - hidden_size, - device=self.patched_param.device, - dtype=torch.long, - ) - - with torch.no_grad(): - # Rotate the selected embedding rows instead of zeroing them so the - # patch remains deterministic while avoiding a degenerate collapse - # to the same special token after the update. - replacement_rows = self.patched_param[row_ids].roll(shifts=1, dims=0) - self.patched_param[row_ids] = replacement_rows - - flat_indices = ( - row_ids.unsqueeze(1).mul(hidden_size).add(column_offsets).reshape(-1) - ) - flat_values = self.patched_param[row_ids].reshape(-1).contiguous() - self.pending_sparse_patches = [ - SparseWeightPatch( - name=PATCHED_PARAM_NAME, - indices=flat_indices.to(torch.int32), - values=flat_values, - ) - ] - patch_digest = hashlib.sha256( - self.pending_sparse_patches[0].indices.cpu().numpy().tobytes() - + self.pending_sparse_patches[0] - .values.detach() - .float() - .cpu() - .numpy() - .tobytes() - ).hexdigest() - - sparse_payload_bytes = ( - flat_indices.numel() * torch.tensor([], dtype=torch.int32).element_size() - + flat_values.numel() * flat_values.element_size() - ) - update_info = dict( - names=[PATCHED_PARAM_NAME], - dtype_names=[str(self.patched_param.dtype).split(".")[-1]], - shapes=[list(self.patched_param.shape)], - num_updates_list=[flat_indices.numel()], - update_kind="sparse_flat", - ) - return update_info, selected_token_ids, patch_digest, sparse_payload_bytes - - def broadcast_weights(self, packed: bool = False) -> float: - if self.model_update_group is None: - raise RuntimeError("Weight transfer group is not initialized") - - trainer_args = NCCLTrainerSendWeightsArgs( - group=self.model_update_group, - packed=packed, - ) - start = time.perf_counter() - NCCLWeightTransferEngine.trainer_send_weights( - iterator=self.model.named_parameters(), - trainer_args=trainer_args, - ) - torch.accelerator.synchronize() - return (time.perf_counter() - start) * 1000.0 - - def broadcast_pending_sparse_patch(self) -> float: - if self.model_update_group is None: - raise RuntimeError("Weight transfer group is not initialized") - if self.pending_sparse_patches is None: - raise RuntimeError("Sparse patch has not been prepared") - - start = time.perf_counter() - NCCLWeightTransferEngine.trainer_send_sparse_weights( - iter(self.pending_sparse_patches), - NCCLTrainerSendWeightsArgs(group=self.model_update_group), - ) - torch.accelerator.synchronize() - self.pending_sparse_patches = None - return (time.perf_counter() - start) * 1000.0 - - -def launch_llm( - scheduling_inference: PlacementGroupSchedulingStrategy, -): - return ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=scheduling_inference, - )(MyLLM).remote( - model=MODEL_NAME, - enforce_eager=True, - tensor_parallel_size=1, - distributed_executor_backend="ray", - gpu_memory_utilization=0.7, - weight_transfer_config=WeightTransferConfig(backend="nccl"), - ) - - -def collect_vllm_generations(llm_handle) -> list[dict[str, object]]: - outputs = ray.get(llm_handle.generate.remote(PROMPTS, SAMPLING_PARAMS)) - generations = [] - for output in outputs: - generations.append( - { - "token_ids": output.outputs[0].token_ids, - "text": output.outputs[0].text, - } - ) - return generations - - -def token_sequences_match( - left: Sequence[dict[str, object]], - right: Sequence[dict[str, object]], -) -> bool: - return [item["token_ids"] for item in left] == [item["token_ids"] for item in right] - - -def print_generations(label: str, prompts: Sequence[str], generations) -> None: - print(f"\n{label}") - print("-" * 50) - for prompt, generation in zip(prompts, generations): - print(f"Prompt: {prompt!r}") - print(f"Token IDs: {generation['token_ids']}") - print(f"Text: {generation['text']!r}") - print("-" * 50) - - -def run_dense_phase( - train_model, - scheduling_inference: PlacementGroupSchedulingStrategy, -) -> dict[str, object]: - ray.get(train_model.reset_model.remote()) - llm = launch_llm(scheduling_inference) - try: - dense_before = collect_vllm_generations(llm) - - ray.get(llm.sleep.remote(level=0)) - master_address, master_port = ray.get(train_model.create_rendezvous.remote()) - world_size = ray.get(llm.get_world_size.remote()) + 1 - inference_init = llm.init_weight_transfer_engine.remote( - dict( - init_info=dict( - master_address=master_address, - master_port=master_port, - rank_offset=1, - world_size=world_size, - ) - ) - ) - trainer_init = train_model.init_weight_transfer_group.remote(world_size) - ray.get([trainer_init, inference_init]) - ray.get(llm.start_weight_update.remote(is_checkpoint_format=True)) - - dense_update_info, dense_payload_bytes = ray.get( - train_model.get_dense_update_info.remote() - ) - _, selected_token_ids, patch_digest, _ = ray.get( - train_model.prepare_sparse_patch.remote(PROMPTS) - ) - - inference_update = llm.update_weights.remote( - dict(update_info=dense_update_info) - ) - dense_send_ms, _ = ray.get( - [ - train_model.broadcast_weights.remote(packed=False), - inference_update, - ] - ) - ray.get(llm.finish_weight_update.remote()) - ray.get(llm.wake_up.remote(tags=["scheduling"])) - - dense_after = collect_vllm_generations(llm) - - return { - "dense_before": dense_before, - "dense_after": dense_after, - "selected_token_ids": selected_token_ids, - "patch_digest": patch_digest, - "dense_payload_bytes": dense_payload_bytes, - "dense_send_ms": dense_send_ms, - } - finally: - ray.kill(llm) - - -def run_sparse_phase( - train_model, - scheduling_inference: PlacementGroupSchedulingStrategy, -) -> dict[str, object]: - ray.get(train_model.reset_model.remote()) - llm = launch_llm(scheduling_inference) - try: - sparse_before = collect_vllm_generations(llm) - - ray.get(llm.sleep.remote(level=0)) - master_address, master_port = ray.get(train_model.create_rendezvous.remote()) - world_size = ray.get(llm.get_world_size.remote()) + 1 - inference_init = llm.init_weight_transfer_engine.remote( - dict( - init_info=dict( - master_address=master_address, - master_port=master_port, - rank_offset=1, - world_size=world_size, - ) - ) - ) - trainer_init = train_model.init_weight_transfer_group.remote(world_size) - ray.get([trainer_init, inference_init]) - ray.get(llm.start_weight_update.remote(is_checkpoint_format=False)) - - sparse_update_info, selected_token_ids, patch_digest, sparse_payload_bytes = ( - ray.get(train_model.prepare_sparse_patch.remote(PROMPTS)) - ) - - inference_update = llm.update_weights.remote( - dict(update_info=sparse_update_info) - ) - sparse_send_ms, _ = ray.get( - [ - train_model.broadcast_pending_sparse_patch.remote(), - inference_update, - ] - ) - ray.get(llm.finish_weight_update.remote()) - ray.get(llm.wake_up.remote(tags=["scheduling"])) - - sparse_after = collect_vllm_generations(llm) - - return { - "sparse_before": sparse_before, - "sparse_after": sparse_after, - "selected_token_ids": selected_token_ids, - "patch_digest": patch_digest, - "sparse_payload_bytes": sparse_payload_bytes, - "sparse_send_ms": sparse_send_ms, - } - finally: - ray.kill(llm) - - -ray.init() - -try: - train_model = TrainModel.remote(MODEL_NAME) - - pg_inference = placement_group([{"GPU": 1, "CPU": 0}]) - ray.get(pg_inference.ready()) - scheduling_inference = PlacementGroupSchedulingStrategy( - placement_group=pg_inference, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=0, - ) - - dense_results = run_dense_phase(train_model, scheduling_inference) - sparse_results = run_sparse_phase(train_model, scheduling_inference) - - baseline_equal = token_sequences_match( - dense_results["dense_before"], - sparse_results["sparse_before"], - ) - patch_selection_equal = ( - dense_results["selected_token_ids"] == sparse_results["selected_token_ids"] - ) - patch_digest_equal = dense_results["patch_digest"] == sparse_results["patch_digest"] - after_equal = token_sequences_match( - dense_results["dense_after"], - sparse_results["sparse_after"], - ) - any_output_changed = any( - before["token_ids"] != after["token_ids"] - for before, after in zip( - dense_results["dense_before"], - dense_results["dense_after"], - ) - ) - dense_payload_mb = dense_results["dense_payload_bytes"] / (1024 * 1024) - sparse_payload_mb = sparse_results["sparse_payload_bytes"] / (1024 * 1024) - - print_generations( - "Dense baseline outputs", - PROMPTS, - dense_results["dense_before"], - ) - print_generations( - "Sparse baseline outputs", PROMPTS, sparse_results["sparse_before"] - ) - print_generations( - "Dense outputs after update", PROMPTS, dense_results["dense_after"] - ) - print_generations( - "Sparse outputs after update", - PROMPTS, - sparse_results["sparse_after"], - ) - - print(f"patched_token_ids = {dense_results['selected_token_ids']}") - print(f"patch_selection_equal = {patch_selection_equal}") - print(f"dense_patch_digest = {dense_results['patch_digest']}") - print(f"sparse_patch_digest = {sparse_results['patch_digest']}") - print(f"patch_digest_equal = {patch_digest_equal}") - print(f"baseline_equal = {baseline_equal}") - print(f"after_equal = {after_equal}") - print(f"any_output_changed = {any_output_changed}") - print(f"dense_payload_mb = {dense_payload_mb:.2f}") - print(f"sparse_payload_mb = {sparse_payload_mb:.2f}") - print(f"dense_send_ms = {dense_results['dense_send_ms']:.2f}") - print(f"sparse_send_ms = {sparse_results['sparse_send_ms']:.2f}") - - if not baseline_equal: - raise RuntimeError( - "Dense and sparse phases did not start from the same baseline" - ) - if not patch_selection_equal: - raise RuntimeError("Dense and sparse phases used different sparse patches") - if not patch_digest_equal: - raise RuntimeError("Dense and sparse phases produced different patch values") - if not after_equal: - raise RuntimeError("Dense and sparse updates produced different outputs") - if not any_output_changed: - raise RuntimeError("Patch did not change the observed outputs") -finally: - ray.shutdown() diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index 467e3934a053..295e812a1245 100644 --- a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -18,7 +18,6 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer import WeightTransferEngineFactory -from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.distributed.weight_transfer.ipc_engine import ( IPCWeightTransferEngine, IPCWeightTransferInitInfo, @@ -90,67 +89,6 @@ def test_empty_lists_valid(self): ) assert len(info.names) == 0 - def test_valid_sparse_update_info(self): - """Test creating valid sparse NCCL update info.""" - info = NCCLWeightTransferUpdateInfo( - names=["layer.weight", "layer.bias"], - dtype_names=["float32", "bfloat16"], - shapes=[[10, 10], [10]], - num_updates_list=[4, 2], - update_kind="sparse_flat", - ) - assert info.update_kind == "sparse_flat" - assert info.num_updates_list == [4, 2] - - def test_sparse_update_requires_num_updates_list(self): - with pytest.raises(ValueError, match="`num_updates_list` is required"): - NCCLWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - update_kind="sparse_flat", - ) - - def test_sparse_update_rejects_empty_num_updates_list(self): - with pytest.raises(ValueError, match="cannot be empty"): - NCCLWeightTransferUpdateInfo( - names=[], - dtype_names=[], - shapes=[], - num_updates_list=[], - update_kind="sparse_flat", - ) - - def test_sparse_update_rejects_packed(self): - with pytest.raises(ValueError, match="cannot be combined with `packed=True`"): - NCCLWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - num_updates_list=[3], - update_kind="sparse_flat", - packed=True, - ) - - def test_sparse_update_rejects_mismatched_num_updates(self): - with pytest.raises(ValueError, match="`num_updates_list`"): - NCCLWeightTransferUpdateInfo( - names=["layer.weight", "layer.bias"], - dtype_names=["float32", "float32"], - shapes=[[10, 10], [10]], - num_updates_list=[3], - update_kind="sparse_flat", - ) - - def test_dense_update_rejects_sparse_metadata(self): - with pytest.raises(ValueError, match="Sparse metadata"): - NCCLWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - num_updates_list=[3], - ) - # --- Unit Tests: Engine Parsing --- @@ -284,27 +222,6 @@ def test_nccl_receive_weights_without_init_raises(): engine.receive_weights(update_info, lambda x: None) -def test_nccl_receive_sparse_weights_without_init_raises(): - """Test that sparse receive raises if init_transfer_engine wasn't called.""" - if torch.accelerator.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - config = WeightTransferConfig(backend="nccl") - parallel_config = create_mock_parallel_config() - engine = NCCLWeightTransferEngine(config, parallel_config) - - update_info = NCCLWeightTransferUpdateInfo( - names=["w"], - dtype_names=["float32"], - shapes=[[10]], - num_updates_list=[2], - update_kind="sparse_flat", - ) - - with pytest.raises(RuntimeError, match="not initialized"): - engine.receive_sparse_weights(update_info, lambda x: None) - - # --- Integration Test: NCCL Weight Transfer Between Ray Tasks --- @@ -462,136 +379,6 @@ def test_nccl_weight_transfer_between_processes(): ) -@ray.remote(num_gpus=1) -def trainer_broadcast_sparse_tensor( - master_address: str, - master_port: int, - world_size: int, -) -> bool: - """Trainer task that broadcasts sparse patches via NCCL.""" - import torch - - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - from vllm.distributed.weight_transfer.base import SparseWeightPatch - from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLTrainerSendWeightsArgs, - NCCLWeightTransferEngine, - ) - - pg = StatelessProcessGroup.create( - host=master_address, - port=master_port, - rank=0, - world_size=world_size, - ) - comm = PyNcclCommunicator(pg, device=0) - - patch = SparseWeightPatch( - name="test.weight", - indices=torch.tensor([1, 7, 25], dtype=torch.int32, device="cuda:0"), - values=torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32, device="cuda:0"), - ) - NCCLWeightTransferEngine.trainer_send_sparse_weights( - iter([patch]), - NCCLTrainerSendWeightsArgs(group=comm), - ) - torch.accelerator.synchronize() - return True - - -@ray.remote(num_gpus=1) -def inference_receive_sparse_tensor( - master_address: str, - master_port: int, - world_size: int, -) -> dict: - """Inference task that receives sparse patches via NCCLWeightTransferEngine.""" - from unittest.mock import MagicMock - - import torch - - from vllm.config.parallel import ParallelConfig - from vllm.config.weight_transfer import WeightTransferConfig - from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLWeightTransferEngine, - NCCLWeightTransferInitInfo, - NCCLWeightTransferUpdateInfo, - ) - - config = WeightTransferConfig(backend="nccl") - parallel_config = MagicMock(spec=ParallelConfig) - parallel_config.rank = 0 - parallel_config.world_size = 1 - parallel_config.data_parallel_rank = 0 - parallel_config.data_parallel_index = 0 - - engine = NCCLWeightTransferEngine(config, parallel_config) - engine.init_transfer_engine( - NCCLWeightTransferInitInfo( - master_address=master_address, - master_port=master_port, - rank_offset=1, - world_size=world_size, - ) - ) - - target = torch.zeros(30, dtype=torch.float32, device="cuda") - - def apply_sparse_patches(patches: list[SparseWeightPatch]): - for patch in patches: - target.index_copy_(0, patch.indices.to(torch.long), patch.values) - - update_info = NCCLWeightTransferUpdateInfo( - names=["test.weight"], - dtype_names=["float32"], - shapes=[[30]], - num_updates_list=[3], - update_kind="sparse_flat", - ) - engine.receive_sparse_weights(update_info, apply_sparse_patches) - torch.accelerator.synchronize() - - expected = torch.zeros(30, dtype=torch.float32, device="cuda") - expected[[1, 7, 25]] = torch.tensor( - [10.0, 20.0, 30.0], dtype=torch.float32, device="cuda" - ) - success = torch.equal(target, expected) - engine.shutdown() - return { - "success": success, - "selected_values": target[[1, 7, 25]].cpu().tolist(), - } - - -@pytest.mark.skipif( - torch.accelerator.device_count() < 2, - reason="Need at least 2 GPUs to run NCCL sparse weight transfer test.", -) -def test_nccl_sparse_weight_transfer_between_processes(): - """Test NCCL sparse weight transfer from trainer to inference process.""" - ray.init(ignore_reinit_error=True) - - master_address = "127.0.0.1" - master_port = get_open_port() - world_size = 2 - - inference_future = inference_receive_sparse_tensor.remote( - master_address, master_port, world_size - ) - trainer_future = trainer_broadcast_sparse_tensor.remote( - master_address, master_port, world_size - ) - - trainer_result, result = ray.get([trainer_future, inference_future]) - - assert trainer_result, "Trainer should complete successfully" - assert result["success"], ( - "Sparse weight transfer failed. " - f"Received selected values: {result['selected_values']}" - ) - - # --- Unit Tests: IPCWeightTransferUpdateInfo Validation --- @@ -674,101 +461,9 @@ def test_mismatched_ipc_handles_raises(self): ipc_handles=ipc_handles, ) - def test_sparse_update_kind_rejected(self): - """Test that IPC backend rejects sparse update metadata.""" - if torch.accelerator.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - dummy_tensor = torch.ones(10, 10, device="cuda:0") - ipc_handle = reduce_tensor(dummy_tensor) - gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) - ipc_handles = [{gpu_uuid: ipc_handle}] - - with pytest.raises(NotImplementedError, match="dense updates"): - IPCWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - num_updates_list=[1], - ipc_handles=ipc_handles, - update_kind="sparse_flat", - ) - - def test_sparse_methods_not_supported(self): - """Test that IPC engine inherits sparse rejection from the base class.""" - config = WeightTransferConfig(backend="ipc") - parallel_config = create_mock_parallel_config() - engine = IPCWeightTransferEngine( - config, parallel_config, MagicMock(spec=torch.nn.Module) - ) - - with pytest.raises(NotImplementedError, match="(?i)sparse weight updates"): - engine.receive_sparse_weights(MagicMock(), lambda _: None) - with pytest.raises(NotImplementedError, match="(?i)sparse weight updates"): - engine.trainer_send_sparse_weights( - iter([]), - {"mode": "http", "url": "http://localhost:8000"}, - ) - - def test_valid_update_info_from_pickled(self, monkeypatch): - """Test creating IPCWeightTransferUpdateInfo from pickled handles.""" - if torch.accelerator.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - - dummy_tensor = torch.ones(10, 10, device="cuda:0") - ipc_handle = reduce_tensor(dummy_tensor) - gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) - ipc_handles = [{gpu_uuid: ipc_handle}] - - pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") - - info = IPCWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - ipc_handles_pickled=pickled, - ) - assert info.ipc_handles == ipc_handles - assert info.ipc_handles_pickled is None - - def test_pickled_requires_insecure_serialization_flag(self, monkeypatch): - """Test that pickled handles are rejected unless env flag is enabled.""" - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0") - - with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"): - IPCWeightTransferUpdateInfo( - names=[], - dtype_names=[], - shapes=[], - ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"), - ) - - def test_both_handles_and_pickled_raises(self): - """Test that providing both ipc_handles and ipc_handles_pickled raises.""" - if torch.accelerator.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - dummy_tensor = torch.ones(10, 10, device="cuda:0") - ipc_handle = reduce_tensor(dummy_tensor) - gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) - ipc_handles = [{gpu_uuid: ipc_handle}] - - pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") - - with pytest.raises(ValueError, match="Cannot specify both"): - IPCWeightTransferUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - ipc_handles=ipc_handles, - ipc_handles_pickled=pickled, - ) - - def test_neither_handles_nor_pickled_raises(self): - """Test that providing neither ipc_handles nor ipc_handles_pickled raises.""" - with pytest.raises(ValueError, match="must be provided"): + def test_missing_ipc_handles_raises(self): + """Test that omitting ipc_handles raises TypeError.""" + with pytest.raises(TypeError): IPCWeightTransferUpdateInfo( names=["layer.weight"], dtype_names=["float32"], @@ -863,28 +558,6 @@ def test_parse_update_info_pickled(self, monkeypatch): assert gpu_uuid in update_info.ipc_handles[0] assert gpu_uuid in update_info.ipc_handles[1] - def test_parse_update_info_ignores_none_pickled_handles(self): - """Test Ray/asdict payloads with a null pickled field use ipc_handles.""" - config = WeightTransferConfig(backend="ipc") - parallel_config = create_mock_parallel_config() - engine = IPCWeightTransferEngine( - config, parallel_config, MagicMock(spec=torch.nn.Module) - ) - ipc_handles = [{"gpu-uuid": ("ipc-args",)}] - - update_info = engine.parse_update_info( - { - "names": ["w1"], - "dtype_names": ["float32"], - "shapes": [[1]], - "ipc_handles": ipc_handles, - "ipc_handles_pickled": None, - } - ) - - assert isinstance(update_info, IPCWeightTransferUpdateInfo) - assert update_info.ipc_handles == ipc_handles - def test_parse_update_info_both_handles_and_pickled_raises(self): """Test that providing both ipc_handles and ipc_handles_pickled raises.""" if torch.accelerator.device_count() < 1: diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 1dd89afcf80c..6c6269865075 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -48,7 +48,6 @@ class MockUpdateInfo(WeightTransferUpdateInfo): names: list[str] | None = None dtype_names: list[str] | None = None shapes: list[list[int]] | None = None - num_updates_list: list[int] | None = None class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo]): @@ -88,15 +87,6 @@ def receive_weights( # (In real implementation, this would receive and load actual weights) load_weights([]) - def receive_sparse_weights( - self, - update_info: MockUpdateInfo, - apply_patches: Callable[[list], None], - ) -> None: - MockWeightTransferEngine.receive_weights_called = True - MockWeightTransferEngine.last_update_info = update_info - apply_patches([]) - def shutdown(self) -> None: MockWeightTransferEngine.shutdown_called = True @@ -208,6 +198,8 @@ def test_update_weights_calls_engine(): llm.init_weight_transfer_engine( WeightTransferInitRequest(init_info={"test_param": "init"}) ) + + # Start weight update (required before update_weights) llm.start_weight_update(is_checkpoint_format=True) # Call update_weights @@ -240,67 +232,14 @@ def check_update_called(self): assert dtypes == test_dtypes assert shapes == test_shapes - llm.finish_weight_update() - - -@create_new_process_for_each_test() -def test_update_weights_passes_sparse_metadata(): - """Test sparse update metadata is forwarded unchanged to the engine.""" - if torch.accelerator.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" - - with patch( - "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", - mock_create_engine, - ): - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=1, - weight_transfer_config=WeightTransferConfig(backend="nccl"), - ) - - llm.init_weight_transfer_engine( - WeightTransferInitRequest(init_info={"test_param": "init"}) - ) - llm.start_weight_update(is_checkpoint_format=False) - - llm.update_weights( - WeightTransferUpdateRequest( - update_info={ - "names": ["layer.weight"], - "dtype_names": ["bfloat16"], - "shapes": [[100]], - "num_updates_list": [3], - "update_kind": "sparse_flat", - } - ) - ) - - def check_sparse_update_called(self): - engine = self.weight_transfer_engine - if not engine.receive_weights_called: - return None - info = engine.last_update_info - return ( - info.update_kind, - info.num_updates_list, - ) - - results = llm.collective_rpc(check_sparse_update_called) - for result in results: - assert result == ("sparse_flat", [3]) - + # Finish weight update llm.finish_weight_update() @create_new_process_for_each_test() def test_full_weight_transfer_flow(): - """Test the complete weight transfer flow: init -> start -> update -> finish.""" + """Test the complete weight transfer flow: + init -> start -> update -> finish.""" if torch.accelerator.device_count() < 1: pytest.skip("Need at least 1 GPU for this test") diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 1a1352249c33..56811982b91d 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -7,7 +7,6 @@ import numpy as np import pytest import torch -import torch.nn as nn import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module from vllm.config import ( @@ -23,7 +22,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform @@ -786,73 +784,6 @@ def test_sample_passes_reordered_draft_probs_to_rejection_sampler(): assert torch.equal(passed_draft_probs, expected_draft_probs) -def test_apply_sparse_weight_patches_updates_only_selected_entries(): - class DummyModel(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.zeros(6, dtype=torch.float32)) - - runner = object.__new__(GPUModelRunner) - runner.model = DummyModel() - - runner.apply_sparse_weight_patches( - [ - SparseWeightPatch( - name="weight", - indices=torch.tensor([1, 4], dtype=torch.int32), - values=torch.tensor([3.5, -2.0], dtype=torch.float32), - ) - ] - ) - - expected = torch.tensor([0.0, 3.5, 0.0, 0.0, -2.0, 0.0], dtype=torch.float32) - assert torch.equal(runner.get_model().weight.data, expected) - - -def test_apply_sparse_weight_patches_rejects_mismatched_lengths(): - class DummyModel(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.zeros(4, dtype=torch.float32)) - - runner = object.__new__(GPUModelRunner) - runner.model = DummyModel() - - with pytest.raises(ValueError, match="matching lengths"): - runner.apply_sparse_weight_patches( - [ - SparseWeightPatch( - name="weight", - indices=torch.tensor([1, 2], dtype=torch.int32), - values=torch.tensor([1.0], dtype=torch.float32), - ) - ] - ) - - -def test_apply_sparse_weight_patches_rejects_non_contiguous_param(): - class DummyModel(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter( - torch.arange(12, dtype=torch.float32).view(3, 4).t() - ) - - runner = object.__new__(GPUModelRunner) - runner.model = DummyModel() - - with pytest.raises(NotImplementedError, match="contiguous params"): - runner.apply_sparse_weight_patches( - [ - SparseWeightPatch( - name="weight", - indices=torch.tensor([1], dtype=torch.int32), - values=torch.tensor([1.0], dtype=torch.float32), - ) - ] - ) - - def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config): torch.set_default_dtype(torch.float16) layer_0 = "model.layers.0.self_attn.attn" diff --git a/tests/v1/worker/test_gpu_worker_weight_transfer.py b/tests/v1/worker/test_gpu_worker_weight_transfer.py deleted file mode 100644 index dba0f658542a..000000000000 --- a/tests/v1/worker/test_gpu_worker_weight_transfer.py +++ /dev/null @@ -1,155 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -import torch - -from vllm.config.parallel import ParallelConfig -from vllm.config.weight_transfer import WeightTransferConfig -from vllm.distributed.weight_transfer.base import SparseWeightPatch -from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine -from vllm.v1.worker.gpu_worker import Worker - - -def _make_nccl_engine() -> NCCLWeightTransferEngine: - parallel_config = MagicMock(spec=ParallelConfig) - parallel_config.rank = 0 - parallel_config.world_size = 1 - parallel_config.data_parallel_rank = 0 - parallel_config.data_parallel_index = 0 - return NCCLWeightTransferEngine( - WeightTransferConfig(backend="nccl"), - parallel_config, - MagicMock(spec=torch.nn.Module), - ) - - -def test_update_weights_sparse_dispatches_to_sparse_receive(monkeypatch): - monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) - - worker = object.__new__(Worker) - worker.device = "cpu" - worker.parallel_config = SimpleNamespace(world_size=1) - worker.weight_transfer_engine = _make_nccl_engine() - worker._weight_update_active = True - worker._is_checkpoint_format = False - - applied_patches = [] - - def apply_sparse_weight_patches(patches): - applied_patches.extend(patches) - - worker.model_runner = SimpleNamespace( - apply_sparse_weight_patches=apply_sparse_weight_patches, - ) - - received_kinds = [] - - def receive_sparse_weights(update_info, apply_patches): - received_kinds.append(update_info.update_kind) - apply_patches( - [ - SparseWeightPatch( - name="layer.weight", - indices=torch.tensor([1], dtype=torch.int32), - values=torch.tensor([2.0], dtype=torch.float32), - ) - ] - ) - - worker.weight_transfer_engine.receive_sparse_weights = receive_sparse_weights - - Worker.update_weights( - worker, - { - "names": ["layer.weight"], - "dtype_names": ["float32"], - "shapes": [[4]], - "num_updates_list": [1], - "update_kind": "sparse_flat", - }, - ) - - assert received_kinds == ["sparse_flat"] - assert len(applied_patches) == 1 - assert torch.equal(applied_patches[0].indices, torch.tensor([1], dtype=torch.int32)) - - -def test_update_weights_sparse_rejects_tp_or_pp(monkeypatch): - monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) - - worker = object.__new__(Worker) - worker.device = "cpu" - worker.parallel_config = SimpleNamespace(world_size=2) - worker.weight_transfer_engine = _make_nccl_engine() - worker._weight_update_active = True - worker._is_checkpoint_format = False - worker.model_runner = SimpleNamespace(apply_sparse_weight_patches=lambda _: None) - - with pytest.raises(NotImplementedError, match="TP=1 and PP=1"): - Worker.update_weights( - worker, - { - "names": ["layer.weight"], - "dtype_names": ["float32"], - "shapes": [[4]], - "num_updates_list": [1], - "update_kind": "sparse_flat", - }, - ) - assert worker._weight_update_active is False - assert worker._is_checkpoint_format is True - - -def test_update_weights_sparse_rejects_checkpoint_format(monkeypatch): - monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) - - worker = object.__new__(Worker) - worker.device = "cpu" - worker.parallel_config = SimpleNamespace(world_size=1) - worker.weight_transfer_engine = _make_nccl_engine() - worker._weight_update_active = True - worker._is_checkpoint_format = True - worker.model_runner = SimpleNamespace(model=MagicMock()) - - with pytest.raises(ValueError, match="start_weight_update"): - Worker.update_weights( - worker, - { - "names": ["layer.weight"], - "dtype_names": ["float32"], - "shapes": [[4]], - "num_updates_list": [1], - "update_kind": "sparse_flat", - }, - ) - assert worker._weight_update_active is False - assert worker._is_checkpoint_format is True - - -def test_update_weights_resets_state_when_update_info_is_invalid(monkeypatch): - monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) - - worker = object.__new__(Worker) - worker.device = "cpu" - worker.parallel_config = SimpleNamespace(world_size=1) - worker.weight_transfer_engine = _make_nccl_engine() - worker._weight_update_active = True - worker._is_checkpoint_format = False - - with pytest.raises(ValueError, match="cannot be empty"): - Worker.update_weights( - worker, - { - "names": [], - "dtype_names": [], - "shapes": [], - "num_updates_list": [], - "update_kind": "sparse_flat", - }, - ) - assert worker._weight_update_active is False - assert worker._is_checkpoint_format is True diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index eda209c3f6b7..6e99adde1ca7 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterator -from dataclasses import KW_ONLY, dataclass, field -from typing import Any, Generic, Literal, TypeVar +from dataclasses import dataclass, field +from typing import Any, Generic, TypeVar import torch @@ -28,44 +28,7 @@ class WeightTransferInitInfo(ABC): # noqa: B024 class WeightTransferUpdateInfo(ABC): # noqa: B024 """Base class for backend-specific weight update info.""" - _: KW_ONLY - update_kind: Literal["dense", "sparse_flat"] = "dense" - """Weight update format.""" - num_updates_list: list[int] | None = None - """Number of sparse entries to receive for each parameter in ``names``.""" - - def __post_init__(self) -> None: - if self.update_kind not in ("dense", "sparse_flat"): - raise ValueError(f"Unsupported update_kind: {self.update_kind}") - if self.update_kind == "dense": - if self.num_updates_list is not None: - raise ValueError( - "Sparse metadata is only supported for `update_kind='sparse_flat'`" - ) - return - - if self.num_updates_list is None: - raise ValueError("`num_updates_list` is required for sparse updates") - if len(self.num_updates_list) == 0: - raise ValueError("`num_updates_list` cannot be empty for sparse updates") - if any(num_updates < 0 for num_updates in self.num_updates_list): - raise ValueError("Sparse `num_updates_list` entries must be non-negative") - - names = getattr(self, "names", None) - if names is not None and len(self.num_updates_list) != len(names): - raise ValueError( - f"`num_updates_list` should be of the same size as `names`: " - f"got {len(self.num_updates_list)} and {len(names)}" - ) - - -@dataclass -class SparseWeightPatch: - """A sparse in-place patch for one existing parameter.""" - - name: str - indices: torch.Tensor - values: torch.Tensor + pass # API-level request classes (accept dicts for backend-agnostic serialization) @@ -187,16 +150,6 @@ def receive_weights( """ raise NotImplementedError - def receive_sparse_weights( - self, - update_info: TUpdateInfo, - apply_patches: Callable[[list[SparseWeightPatch]], None], - ) -> None: - """Receive sparse weight patches from the trainer.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not support sparse weight updates" - ) - @abstractmethod def shutdown(self) -> None: """ @@ -231,11 +184,3 @@ def trainer_send_weights( >>> engine.trainer_send_weights(param_iter, trainer_args) """ raise NotImplementedError - - @staticmethod - def trainer_send_sparse_weights( - _iterator: Iterator[SparseWeightPatch], - _trainer_args: dict[str, Any] | Any, - ) -> None: - """Send sparse weight patches from trainer to inference workers.""" - raise NotImplementedError("Sparse weight updates are not supported") diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py index a77aab751ff2..b138c7dd9374 100644 --- a/vllm/distributed/weight_transfer/ipc_engine.py +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -74,12 +74,10 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): names: list[str] dtype_names: list[str] shapes: list[list[int]] - ipc_handles: list[dict[str, tuple]] | dict[str, tuple] | None = None + ipc_handles: list[dict[str, tuple]] | dict[str, tuple] """IPC handles mapping physical GPU UUID to rebuild_cuda_tensor args. For non-packed mode: list of per-parameter handle dicts. For packed mode: single handle dict for the packed buffer.""" - ipc_handles_pickled: str | None = None - """Base64-encoded pickled IPC handles, used for HTTP transport.""" tensor_sizes: list[int] | None = None """Per-parameter sizes in bytes within the packed buffer. Required when packed=True, unused otherwise.""" @@ -87,29 +85,6 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): """Whether this update uses packed tensor format.""" def __post_init__(self): - super().__post_init__() - if self.update_kind != "dense": - raise NotImplementedError("IPC weight transfer only supports dense updates") - - if self.ipc_handles_pickled is not None: - if self.ipc_handles is not None: - raise ValueError( - "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" - ) - - if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise ValueError( - "Refusing to deserialize `ipc_handles_pickled` without " - "VLLM_ALLOW_INSECURE_SERIALIZATION=1" - ) - - self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled)) - self.ipc_handles_pickled = None - - if self.ipc_handles is None: - raise ValueError( - "Either `ipc_handles` or `ipc_handles_pickled` must be provided" - ) num_params = len(self.names) if len(self.dtype_names) != num_params: raise ValueError( @@ -178,9 +153,8 @@ def parse_update_info( Requires ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` because the payload is deserialized via ``pickle.loads``. """ - pickled = update_dict.pop("ipc_handles_pickled", None) - if pickled is not None: - if update_dict.get("ipc_handles") is not None: + if "ipc_handles_pickled" in update_dict: + if "ipc_handles" in update_dict: raise ValueError( "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" ) @@ -191,6 +165,7 @@ def parse_update_info( "VLLM_ALLOW_INSECURE_SERIALIZATION=1" ) + pickled = update_dict.pop("ipc_handles_pickled") update_dict["ipc_handles"] = pickle.loads(base64.b64decode(pickled)) return super().parse_update_info(update_dict) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 674f5b524da6..3b04a5f65ba3 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -14,7 +14,6 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( - SparseWeightPatch, WeightTransferEngine, WeightTransferInitInfo, WeightTransferUpdateInfo, @@ -82,7 +81,6 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): def __post_init__(self): """Validate that all lists have the same length.""" - super().__post_init__() num_params = len(self.names) if len(self.dtype_names) != num_params: raise ValueError( @@ -94,13 +92,6 @@ def __post_init__(self): f"`shapes` should be of the same size as `names`: " f"got {len(self.shapes)} and {len(self.names)}" ) - if self.update_kind == "dense": - return - - if self.packed: - raise ValueError( - "`update_kind='sparse_flat'` cannot be combined with `packed=True`" - ) class NCCLWeightTransferEngine( @@ -187,11 +178,6 @@ def receive_weights( "NCCL weight transfer not initialized. " "Call init_transfer_engine() first." ) - if update_info.update_kind != "dense": - raise ValueError( - "Sparse updates must use `receive_sparse_weights`, not " - "`receive_weights`" - ) if update_info.packed: # Build iterator of (name, (shape, dtype)) from update_info @@ -223,42 +209,6 @@ def state_dict_info_iterator(): load_weights([(name, weight)]) del weight - def receive_sparse_weights( - self, - update_info: NCCLWeightTransferUpdateInfo, - apply_patches: Callable[[list[SparseWeightPatch]], None], - ) -> None: - """Receive sparse flat-index patches from trainer via NCCL.""" - if self.model_update_group is None: - raise RuntimeError( - "NCCL weight transfer not initialized. " - "Call init_transfer_engine() first." - ) - if update_info.update_kind != "sparse_flat": - raise ValueError("Sparse receive path requires `update_kind='sparse_flat'`") - assert update_info.num_updates_list is not None - - for name, dtype_name, num_updates in zip( - update_info.names, - update_info.dtype_names, - update_info.num_updates_list, - ): - dtype = getattr(torch, dtype_name) - device = torch.accelerator.current_device_index() - indices = torch.empty(num_updates, dtype=torch.int32, device=device) - values = torch.empty(num_updates, dtype=dtype, device=device) - self.model_update_group.broadcast( - indices, src=0, stream=torch.cuda.current_stream() - ) - self.model_update_group.broadcast( - values, src=0, stream=torch.cuda.current_stream() - ) - apply_patches( - [SparseWeightPatch(name=name, indices=indices, values=values)] - ) - del indices - del values - def shutdown(self) -> None: if self.model_update_group is not None: # Clean up the communicator by removing the reference @@ -322,27 +272,6 @@ def trainer_send_weights( stream=args.stream or torch.cuda.current_stream(), ) - @staticmethod - def trainer_send_sparse_weights( - iterator: Iterator[SparseWeightPatch], - trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs, - ) -> None: - """Broadcast sparse flat-index patches from trainer to vLLM workers.""" - if isinstance(trainer_args, dict): - args = NCCLTrainerSendWeightsArgs(**trainer_args) - else: - args = trainer_args - - if args.packed: - raise ValueError( - "Sparse NCCL updates cannot be combined with `packed=True`" - ) - - stream = args.stream or torch.cuda.current_stream() - for patch in iterator: - args.group.broadcast(patch.indices, src=args.src, stream=stream) - args.group.broadcast(patch.values, src=args.src, stream=stream) - @staticmethod def trainer_init( init_info: NCCLWeightTransferInitInfo | dict, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 802d7a6d7968..c8e9b8c08f0b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -873,7 +873,14 @@ def init_weight_transfer_engine( ) def start_weight_update(self, is_checkpoint_format: bool = True) -> None: - """Start a new weight update.""" + """ + Start a new weight update. + + Args: + is_checkpoint_format: Whether incoming weights are in checkpoint + format (need layerwise processing) or kernel format (direct + copy). + """ self.llm_engine.collective_rpc( "start_weight_update", kwargs={"is_checkpoint_format": is_checkpoint_format}, @@ -895,7 +902,9 @@ def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None: ) def finish_weight_update(self) -> None: - """Finish the current weight update.""" + """ + Finish the current weight update. + """ self.llm_engine.collective_rpc("finish_weight_update") def __repr__(self) -> str: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9a05c765894f..f82d2224a413 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -49,7 +49,6 @@ is_global_first_rank, prepare_communication_buffer_for_model, ) -from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.forward_context import ( BatchDescriptor, set_forward_context, @@ -3190,44 +3189,6 @@ def get_model(self) -> nn.Module: return self.model.unwrap() return self.model - def apply_sparse_weight_patches(self, patches: Iterable[SparseWeightPatch]) -> None: - """Apply sparse flat-index patches directly to existing model params.""" - model = self.get_model() - for patch in patches: - param = model.get_parameter(patch.name) - if not param.data.is_contiguous(): - raise NotImplementedError( - "Sparse weight updates currently require contiguous params: " - f"{patch.name}" - ) - - if patch.indices.dtype != torch.int32: - raise ValueError( - "Sparse weight updates currently require int32 indices: " - f"{patch.name}" - ) - if patch.indices.ndim != 1 or patch.values.ndim != 1: - raise ValueError( - f"Sparse weight patches must be 1D flattened updates: {patch.name}" - ) - if patch.indices.numel() != patch.values.numel(): - raise ValueError( - "`indices` and `values` must have matching lengths for " - f"{patch.name}" - ) - if patch.values.dtype != param.dtype: - raise ValueError( - f"Sparse values dtype {patch.values.dtype} does not match " - f"parameter dtype {param.dtype} for {patch.name}" - ) - - flat_param = param.data.view(-1) - flat_param.index_copy_( - 0, - patch.indices.to(device=flat_param.device, dtype=torch.long), - patch.values.to(device=flat_param.device), - ) - def get_supported_generation_tasks(self) -> list[GenerationTask]: model = self.get_model() supported_tasks = list[GenerationTask]() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 121fc69f5327..e63f50bc8dc2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -992,19 +992,23 @@ def init_weight_transfer_engine(self, init_info: dict) -> None: def start_weight_update(self, is_checkpoint_format: bool = True) -> None: """ - Start a new weight update session. + Start a new weight update. + + Prepares the model for receiving weights. For checkpoint format, + this initializes state for layerwise processing. For kernel format, this is + a no-op but must still be called for consistency. Args: is_checkpoint_format: Whether incoming weights are in checkpoint format (need layerwise processing) or kernel format (direct - copy / sparse patch application). + copy). Stored as state for finish_weight_update. """ self._check_weight_transfer_engine() if self._weight_update_active: raise RuntimeError( - "start_weight_update called while a weight update is already " - "active. Call finish_weight_update first." + "start_weight_update called while a weight update is " + "already active. Call finish_weight_update first." ) if is_checkpoint_format: @@ -1016,15 +1020,16 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None: with torch.device(self.device): initialize_layerwise_reload(model) + # Store state so update_weights/finish_weight_update can check self._is_checkpoint_format = is_checkpoint_format self._weight_update_active = True def update_weights(self, update_info: dict) -> None: """ - Receive one weight update chunk from the trainer. + Receive weights from the trainer (one or more chunks). start_weight_update must be called before update_weights and - finish_weight_update must be called after all chunks have been sent. + finish_weight_update must be called after. Args: update_info: Dictionary containing backend-specific update info @@ -1037,72 +1042,52 @@ def update_weights(self, update_info: dict) -> None: "start_weight_update must be called before update_weights." ) - update_succeeded = False - try: - # Parse dict into backend-specific typed dataclass - typed_update_info = self.weight_transfer_engine.parse_update_info( - update_info - ) + # Parse dict into backend-specific typed dataclass + typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) - with torch.device(self.device): - if self._is_checkpoint_format: - if typed_update_info.update_kind != "dense": - raise ValueError( - "Sparse weight updates require " - "`start_weight_update(is_checkpoint_format=False)`." - ) - - model = self.model_runner.model - - # Use layerwise reload pattern for checkpoint format weights - self.weight_transfer_engine.receive_weights( - typed_update_info, - load_weights=model.load_weights, - ) - elif typed_update_info.update_kind == "sparse_flat": - if self.parallel_config.world_size != 1: - raise NotImplementedError( - "Sparse weight updates currently require TP=1 and PP=1" - ) - self.weight_transfer_engine.receive_sparse_weights( - typed_update_info, - apply_patches=self.model_runner.apply_sparse_weight_patches, - ) - else: - model = self.model_runner.model - - # Weights are already in kernel format, copy directly. - def load_weights_direct( - weights: list[tuple[str, torch.Tensor]], - ) -> None: - for name, weight in weights: - param = model.get_parameter(name) - param.copy_(weight) - - self.weight_transfer_engine.receive_weights( - typed_update_info, - load_weights=load_weights_direct, - ) + model = self.model_runner.model - # NCCL broadcast/packed path are asynchronous. - # Sync here so the next step uses the new weights. - torch.accelerator.synchronize() - update_succeeded = True - finally: - if not update_succeeded: - self._weight_update_active = False - self._is_checkpoint_format = True + with torch.device(self.device): + if self._is_checkpoint_format: + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=model.load_weights, + ) + else: + # Weights are already in kernel format, copy directly + def load_weights_direct( + weights: list[tuple[str, torch.Tensor]], + ) -> None: + for name, weight in weights: + param = model.get_parameter(name) + param.copy_(weight) + + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=load_weights_direct, + ) + + # NCCL broadcast/packed path are asynchronous. + # Sync here so the next step uses the new weights. + torch.accelerator.synchronize() def finish_weight_update(self) -> None: - """Finish the current weight update session.""" + """ + Finish the current weight update. + + For checkpoint format, this runs layerwise postprocessing. + Uses the is_checkpoint_format state stored by start_weight_update. + """ self._check_weight_transfer_engine() if not self._weight_update_active: raise RuntimeError( - "finish_weight_update called without a matching start_weight_update." + "start_weight_update must be called before finish_weight_update." ) - if self._is_checkpoint_format: + is_checkpoint_format = self._is_checkpoint_format + + if is_checkpoint_format: from vllm.model_executor.model_loader.reload import ( finalize_layerwise_reload, ) @@ -1111,6 +1096,7 @@ def finish_weight_update(self) -> None: with torch.device(self.device): finalize_layerwise_reload(model, self.model_config) + # Reset state self._weight_update_active = False self._is_checkpoint_format = True