diff --git a/docs/training/weight_transfer/nccl.md b/docs/training/weight_transfer/nccl.md index bfde1ee2ae37..7b531218568b 100644 --- a/docs/training/weight_transfer/nccl.md +++ b/docs/training/weight_transfer/nccl.md @@ -84,7 +84,10 @@ Both the trainer (`NCCLTrainerSendWeightsArgs`) and inference side (`NCCLWeightT ## Receiving Weights (Inference Side) -The inference side triggers weight reception using the four-phase protocol — `init_weight_transfer_engine`, `start_weight_update`, `update_weights`, `finish_weight_update`. The init phase is shown [above](#initialization). The remaining three steps are: +The inference side triggers weight reception using the four-phase protocol: +`init_weight_transfer_engine`, `start_weight_update`, `update_weights`, +`finish_weight_update`. The init phase is shown [above](#initialization). The +remaining three steps are: ```python from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest @@ -108,12 +111,24 @@ llm.update_weights( llm.finish_weight_update() ``` -The `names`, `dtype_names`, and `shapes` lists describe each parameter. These must match the order in which the trainer iterates over its parameters. +The `names`, `dtype_names`, and `shapes` lists describe each parameter. These +must match the order in which the trainer iterates over its parameters. -`start_weight_update` must be called before `update_weights`, and `finish_weight_update` must be called after all weight chunks have been transferred. The `is_checkpoint_format` flag controls whether layerwise reload processing is applied (`True` for checkpoint-format weights, `False` for pre-processed kernel-format weights). +`start_weight_update` must be called before `update_weights`, and +`finish_weight_update` must be called after all weight chunks have been +transferred. The `is_checkpoint_format` flag controls whether layerwise reload +processing is applied (`True` for checkpoint-format weights, `False` for +pre-processed kernel-format weights). + +Sparse NCCL patches still use `update_kind="sparse_flat"` inside +`update_info`, but they should be wrapped in +`start_weight_update(is_checkpoint_format=False)` because sparse patches apply +directly to runtime/kernel-format parameters. The current sparse MVP requires +`TP=1` and `PP=1`. ## Examples - [RLHF with NCCL weight syncing (offline, Ray)](../../../examples/rl/rlhf_nccl.py) - Trainer on one GPU, 2x tensor-parallel vLLM engine on two others, with packed NCCL weight broadcast +- [RLHF with sparse NCCL weight syncing (offline, Ray)](../../../examples/rl/rlhf_sparse_nccl.py) - Dense-vs-sparse equivalence demo with a real model on a 2-GPU trainer/inference setup; sparse patches use `start_weight_update(is_checkpoint_format=False)` and currently require `TP=1` and `PP=1` - [RLHF with async weight syncing (offline, Ray)](../../../examples/rl/rlhf_async_new_apis.py) - Async generation with mid-flight pause, weight sync, resume, and validation against a fresh model - [RLHF with NCCL weight syncing (online serving, HTTP)](../../../examples/rl/rlhf_http_nccl.py) - Weight transfer with a running vLLM HTTP server using HTTP control plane and NCCL data plane diff --git a/examples/rl/rlhf_sparse_nccl.py b/examples/rl/rlhf_sparse_nccl.py new file mode 100644 index 000000000000..bddd28b6485e --- /dev/null +++ b/examples/rl/rlhf_sparse_nccl.py @@ -0,0 +1,526 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates dense-vs-sparse NCCL weight syncing with a real model. + +This example mirrors the validation story used for the sparse NCCL MVP: +both the dense update path and the sparse patch path start from the same real +checkpoint and apply the same deterministic trainer-side patch. The script then +checks that greedy 1-token outputs match between the dense and sparse vLLM +engines after the update. + +The example performs the following steps: +* Load a training model on one GPU via a Ray actor. +* Launch a vLLM engine with the same real model on a second GPU. +* Verify trainer vs vLLM baseline agreement before any update. +* Apply a deterministic patch to ``model.embed_tokens.weight`` on the trainer. +* Run a dense NCCL update into a fresh vLLM engine and collect post-update + outputs. +* Reset the trainer back to the baseline checkpoint. +* Apply the same deterministic patch again. +* Run a sparse NCCL update into another fresh vLLM engine and collect + post-update outputs. +* Compare dense vs sparse baseline outputs, dense vs sparse post-update + outputs, estimated payload sizes, and trainer-side send times. + +Current sparse weight transfer MVP limitations: +* ``TP=1`` and ``PP=1`` only +* sparse updates use runtime/kernel-format parameter names +* sparse updates are not composable with checkpoint-format or packed updates + +This example assumes a single-node cluster with two GPUs. +""" + +import hashlib +import os +import time +from collections.abc import Sequence + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import SparseWeightPatch +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, + NCCLWeightTransferEngine, +) +from vllm.utils.network_utils import get_ip, get_open_port + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +PATCHED_PARAM_NAME = "model.embed_tokens.weight" +MAX_PATCH_ROWS = 32 +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +SAMPLING_PARAMS = SamplingParams(temperature=0.0, max_tokens=1) + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0" + super().__init__(*args, **kwargs) + + +@ray.remote(num_gpus=1) +class TrainModel: + """Ray actor that owns the trainer-side model and deterministic patch state.""" + + def __init__(self, model_name: str): + self.model_name = model_name + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.model = None + self.patched_param = None + self.pending_sparse_patches: list[SparseWeightPatch] | None = None + self.model_update_group = None + self.master_address = get_ip() + self.port = get_open_port() + self.reset_model() + + def reset_model(self) -> None: + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + ).to("cuda:0") + self.model.eval() + + try: + self.patched_param = self.model.get_parameter(PATCHED_PARAM_NAME) + except AttributeError as exc: + raise RuntimeError( + f"Expected trainer model to expose `{PATCHED_PARAM_NAME}`" + ) from exc + + self.pending_sparse_patches = None + + def create_rendezvous(self) -> tuple[str, int]: + self.port = get_open_port() + return self.master_address, self.port + + def init_weight_transfer_group(self, world_size: int) -> None: + self.model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=self.master_address, + master_port=self.port, + world_size=world_size, + ) + ) + + def get_dense_update_info(self, packed: bool = False) -> tuple[dict, int]: + names = [] + dtype_names = [] + shapes = [] + payload_bytes = 0 + for name, param in self.model.named_parameters(): + names.append(name) + dtype_names.append(str(param.dtype).split(".")[-1]) + shapes.append(list(param.shape)) + payload_bytes += param.numel() * param.element_size() + + return ( + dict( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=packed, + ), + payload_bytes, + ) + + @torch.inference_mode() + def generate( + self, + prompts: Sequence[str], + max_new_tokens: int = 1, + ) -> list[dict[str, object]]: + generations = [] + for prompt in prompts: + model_inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0") + output = self.model.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + ) + new_token_ids = output[0, model_inputs["input_ids"].shape[1] :].tolist() + generations.append( + { + "token_ids": new_token_ids, + "text": self.tokenizer.decode( + new_token_ids, + skip_special_tokens=False, + ), + } + ) + return generations + + def prepare_sparse_patch( + self, + prompts: Sequence[str], + max_patch_rows: int = MAX_PATCH_ROWS, + ) -> tuple[dict[str, object], list[int], str, int]: + selected_token_ids: list[int] = [] + special_ids = set(self.tokenizer.all_special_ids) + for prompt in prompts: + token_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + for token_id in token_ids: + if token_id in special_ids or token_id in selected_token_ids: + continue + selected_token_ids.append(token_id) + if len(selected_token_ids) == max_patch_rows: + break + if len(selected_token_ids) == max_patch_rows: + break + + if not selected_token_ids: + raise ValueError("Could not derive any non-special token IDs to patch") + + vocab_size = self.patched_param.shape[0] + next_token_id = selected_token_ids[-1] + while len(selected_token_ids) < max_patch_rows: + next_token_id = (next_token_id + 1) % vocab_size + if next_token_id in special_ids or next_token_id in selected_token_ids: + continue + selected_token_ids.append(next_token_id) + + row_ids = torch.tensor( + selected_token_ids, + device=self.patched_param.device, + dtype=torch.long, + ) + hidden_size = self.patched_param.shape[1] + column_offsets = torch.arange( + hidden_size, + device=self.patched_param.device, + dtype=torch.long, + ) + + with torch.no_grad(): + # Rotate the selected embedding rows instead of zeroing them so the + # patch remains deterministic while avoiding a degenerate collapse + # to the same special token after the update. + replacement_rows = self.patched_param[row_ids].roll(shifts=1, dims=0) + self.patched_param[row_ids] = replacement_rows + + flat_indices = ( + row_ids.unsqueeze(1).mul(hidden_size).add(column_offsets).reshape(-1) + ) + flat_values = self.patched_param[row_ids].reshape(-1).contiguous() + self.pending_sparse_patches = [ + SparseWeightPatch( + name=PATCHED_PARAM_NAME, + indices=flat_indices.to(torch.int32), + values=flat_values, + ) + ] + patch_digest = hashlib.sha256( + self.pending_sparse_patches[0].indices.cpu().numpy().tobytes() + + self.pending_sparse_patches[0] + .values.detach() + .float() + .cpu() + .numpy() + .tobytes() + ).hexdigest() + + sparse_payload_bytes = ( + flat_indices.numel() * torch.tensor([], dtype=torch.int32).element_size() + + flat_values.numel() * flat_values.element_size() + ) + update_info = dict( + names=[PATCHED_PARAM_NAME], + dtype_names=[str(self.patched_param.dtype).split(".")[-1]], + shapes=[list(self.patched_param.shape)], + num_updates_list=[flat_indices.numel()], + update_kind="sparse_flat", + ) + return update_info, selected_token_ids, patch_digest, sparse_payload_bytes + + def broadcast_weights(self, packed: bool = False) -> float: + if self.model_update_group is None: + raise RuntimeError("Weight transfer group is not initialized") + + trainer_args = NCCLTrainerSendWeightsArgs( + group=self.model_update_group, + packed=packed, + ) + start = time.perf_counter() + NCCLWeightTransferEngine.trainer_send_weights( + iterator=self.model.named_parameters(), + trainer_args=trainer_args, + ) + torch.accelerator.synchronize() + return (time.perf_counter() - start) * 1000.0 + + def broadcast_pending_sparse_patch(self) -> float: + if self.model_update_group is None: + raise RuntimeError("Weight transfer group is not initialized") + if self.pending_sparse_patches is None: + raise RuntimeError("Sparse patch has not been prepared") + + start = time.perf_counter() + NCCLWeightTransferEngine.trainer_send_sparse_weights( + iter(self.pending_sparse_patches), + NCCLTrainerSendWeightsArgs(group=self.model_update_group), + ) + torch.accelerator.synchronize() + self.pending_sparse_patches = None + return (time.perf_counter() - start) * 1000.0 + + +def launch_llm( + scheduling_inference: PlacementGroupSchedulingStrategy, +): + return ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, + )(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=1, + distributed_executor_backend="ray", + gpu_memory_utilization=0.7, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + +def collect_vllm_generations(llm_handle) -> list[dict[str, object]]: + outputs = ray.get(llm_handle.generate.remote(PROMPTS, SAMPLING_PARAMS)) + generations = [] + for output in outputs: + generations.append( + { + "token_ids": output.outputs[0].token_ids, + "text": output.outputs[0].text, + } + ) + return generations + + +def token_sequences_match( + left: Sequence[dict[str, object]], + right: Sequence[dict[str, object]], +) -> bool: + return [item["token_ids"] for item in left] == [item["token_ids"] for item in right] + + +def print_generations(label: str, prompts: Sequence[str], generations) -> None: + print(f"\n{label}") + print("-" * 50) + for prompt, generation in zip(prompts, generations): + print(f"Prompt: {prompt!r}") + print(f"Token IDs: {generation['token_ids']}") + print(f"Text: {generation['text']!r}") + print("-" * 50) + + +def run_dense_phase( + train_model, + scheduling_inference: PlacementGroupSchedulingStrategy, +) -> dict[str, object]: + ray.get(train_model.reset_model.remote()) + llm = launch_llm(scheduling_inference) + try: + dense_before = collect_vllm_generations(llm) + + ray.get(llm.sleep.remote(level=0)) + master_address, master_port = ray.get(train_model.create_rendezvous.remote()) + world_size = ray.get(llm.get_world_size.remote()) + 1 + inference_init = llm.init_weight_transfer_engine.remote( + dict( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) + ) + ) + trainer_init = train_model.init_weight_transfer_group.remote(world_size) + ray.get([trainer_init, inference_init]) + ray.get(llm.start_weight_update.remote(is_checkpoint_format=True)) + + dense_update_info, dense_payload_bytes = ray.get( + train_model.get_dense_update_info.remote() + ) + _, selected_token_ids, patch_digest, _ = ray.get( + train_model.prepare_sparse_patch.remote(PROMPTS) + ) + + inference_update = llm.update_weights.remote( + dict(update_info=dense_update_info) + ) + dense_send_ms, _ = ray.get( + [ + train_model.broadcast_weights.remote(packed=False), + inference_update, + ] + ) + ray.get(llm.finish_weight_update.remote()) + ray.get(llm.wake_up.remote(tags=["scheduling"])) + + dense_after = collect_vllm_generations(llm) + + return { + "dense_before": dense_before, + "dense_after": dense_after, + "selected_token_ids": selected_token_ids, + "patch_digest": patch_digest, + "dense_payload_bytes": dense_payload_bytes, + "dense_send_ms": dense_send_ms, + } + finally: + ray.kill(llm) + + +def run_sparse_phase( + train_model, + scheduling_inference: PlacementGroupSchedulingStrategy, +) -> dict[str, object]: + ray.get(train_model.reset_model.remote()) + llm = launch_llm(scheduling_inference) + try: + sparse_before = collect_vllm_generations(llm) + + ray.get(llm.sleep.remote(level=0)) + master_address, master_port = ray.get(train_model.create_rendezvous.remote()) + world_size = ray.get(llm.get_world_size.remote()) + 1 + inference_init = llm.init_weight_transfer_engine.remote( + dict( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) + ) + ) + trainer_init = train_model.init_weight_transfer_group.remote(world_size) + ray.get([trainer_init, inference_init]) + ray.get(llm.start_weight_update.remote(is_checkpoint_format=False)) + + sparse_update_info, selected_token_ids, patch_digest, sparse_payload_bytes = ( + ray.get(train_model.prepare_sparse_patch.remote(PROMPTS)) + ) + + inference_update = llm.update_weights.remote( + dict(update_info=sparse_update_info) + ) + sparse_send_ms, _ = ray.get( + [ + train_model.broadcast_pending_sparse_patch.remote(), + inference_update, + ] + ) + ray.get(llm.finish_weight_update.remote()) + ray.get(llm.wake_up.remote(tags=["scheduling"])) + + sparse_after = collect_vllm_generations(llm) + + return { + "sparse_before": sparse_before, + "sparse_after": sparse_after, + "selected_token_ids": selected_token_ids, + "patch_digest": patch_digest, + "sparse_payload_bytes": sparse_payload_bytes, + "sparse_send_ms": sparse_send_ms, + } + finally: + ray.kill(llm) + + +ray.init() + +try: + train_model = TrainModel.remote(MODEL_NAME) + + pg_inference = placement_group([{"GPU": 1, "CPU": 0}]) + ray.get(pg_inference.ready()) + scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, + ) + + dense_results = run_dense_phase(train_model, scheduling_inference) + sparse_results = run_sparse_phase(train_model, scheduling_inference) + + baseline_equal = token_sequences_match( + dense_results["dense_before"], + sparse_results["sparse_before"], + ) + patch_selection_equal = ( + dense_results["selected_token_ids"] == sparse_results["selected_token_ids"] + ) + patch_digest_equal = dense_results["patch_digest"] == sparse_results["patch_digest"] + after_equal = token_sequences_match( + dense_results["dense_after"], + sparse_results["sparse_after"], + ) + any_output_changed = any( + before["token_ids"] != after["token_ids"] + for before, after in zip( + dense_results["dense_before"], + dense_results["dense_after"], + ) + ) + dense_payload_mb = dense_results["dense_payload_bytes"] / (1024 * 1024) + sparse_payload_mb = sparse_results["sparse_payload_bytes"] / (1024 * 1024) + + print_generations( + "Dense baseline outputs", + PROMPTS, + dense_results["dense_before"], + ) + print_generations( + "Sparse baseline outputs", PROMPTS, sparse_results["sparse_before"] + ) + print_generations( + "Dense outputs after update", PROMPTS, dense_results["dense_after"] + ) + print_generations( + "Sparse outputs after update", + PROMPTS, + sparse_results["sparse_after"], + ) + + print(f"patched_token_ids = {dense_results['selected_token_ids']}") + print(f"patch_selection_equal = {patch_selection_equal}") + print(f"dense_patch_digest = {dense_results['patch_digest']}") + print(f"sparse_patch_digest = {sparse_results['patch_digest']}") + print(f"patch_digest_equal = {patch_digest_equal}") + print(f"baseline_equal = {baseline_equal}") + print(f"after_equal = {after_equal}") + print(f"any_output_changed = {any_output_changed}") + print(f"dense_payload_mb = {dense_payload_mb:.2f}") + print(f"sparse_payload_mb = {sparse_payload_mb:.2f}") + print(f"dense_send_ms = {dense_results['dense_send_ms']:.2f}") + print(f"sparse_send_ms = {sparse_results['sparse_send_ms']:.2f}") + + if not baseline_equal: + raise RuntimeError( + "Dense and sparse phases did not start from the same baseline" + ) + if not patch_selection_equal: + raise RuntimeError("Dense and sparse phases used different sparse patches") + if not patch_digest_equal: + raise RuntimeError("Dense and sparse phases produced different patch values") + if not after_equal: + raise RuntimeError("Dense and sparse updates produced different outputs") + if not any_output_changed: + raise RuntimeError("Patch did not change the observed outputs") +finally: + ray.shutdown() diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index 295e812a1245..467e3934a053 100644 --- a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -18,6 +18,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer import WeightTransferEngineFactory +from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.distributed.weight_transfer.ipc_engine import ( IPCWeightTransferEngine, IPCWeightTransferInitInfo, @@ -89,6 +90,67 @@ def test_empty_lists_valid(self): ) assert len(info.names) == 0 + def test_valid_sparse_update_info(self): + """Test creating valid sparse NCCL update info.""" + info = NCCLWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "bfloat16"], + shapes=[[10, 10], [10]], + num_updates_list=[4, 2], + update_kind="sparse_flat", + ) + assert info.update_kind == "sparse_flat" + assert info.num_updates_list == [4, 2] + + def test_sparse_update_requires_num_updates_list(self): + with pytest.raises(ValueError, match="`num_updates_list` is required"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + update_kind="sparse_flat", + ) + + def test_sparse_update_rejects_empty_num_updates_list(self): + with pytest.raises(ValueError, match="cannot be empty"): + NCCLWeightTransferUpdateInfo( + names=[], + dtype_names=[], + shapes=[], + num_updates_list=[], + update_kind="sparse_flat", + ) + + def test_sparse_update_rejects_packed(self): + with pytest.raises(ValueError, match="cannot be combined with `packed=True`"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + num_updates_list=[3], + update_kind="sparse_flat", + packed=True, + ) + + def test_sparse_update_rejects_mismatched_num_updates(self): + with pytest.raises(ValueError, match="`num_updates_list`"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10], [10]], + num_updates_list=[3], + update_kind="sparse_flat", + ) + + def test_dense_update_rejects_sparse_metadata(self): + with pytest.raises(ValueError, match="Sparse metadata"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + num_updates_list=[3], + ) + # --- Unit Tests: Engine Parsing --- @@ -222,6 +284,27 @@ def test_nccl_receive_weights_without_init_raises(): engine.receive_weights(update_info, lambda x: None) +def test_nccl_receive_sparse_weights_without_init_raises(): + """Test that sparse receive raises if init_transfer_engine wasn't called.""" + if torch.accelerator.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + update_info = NCCLWeightTransferUpdateInfo( + names=["w"], + dtype_names=["float32"], + shapes=[[10]], + num_updates_list=[2], + update_kind="sparse_flat", + ) + + with pytest.raises(RuntimeError, match="not initialized"): + engine.receive_sparse_weights(update_info, lambda x: None) + + # --- Integration Test: NCCL Weight Transfer Between Ray Tasks --- @@ -379,6 +462,136 @@ def test_nccl_weight_transfer_between_processes(): ) +@ray.remote(num_gpus=1) +def trainer_broadcast_sparse_tensor( + master_address: str, + master_port: int, + world_size: int, +) -> bool: + """Trainer task that broadcasts sparse patches via NCCL.""" + import torch + + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + from vllm.distributed.weight_transfer.base import SparseWeightPatch + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, + NCCLWeightTransferEngine, + ) + + pg = StatelessProcessGroup.create( + host=master_address, + port=master_port, + rank=0, + world_size=world_size, + ) + comm = PyNcclCommunicator(pg, device=0) + + patch = SparseWeightPatch( + name="test.weight", + indices=torch.tensor([1, 7, 25], dtype=torch.int32, device="cuda:0"), + values=torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32, device="cuda:0"), + ) + NCCLWeightTransferEngine.trainer_send_sparse_weights( + iter([patch]), + NCCLTrainerSendWeightsArgs(group=comm), + ) + torch.accelerator.synchronize() + return True + + +@ray.remote(num_gpus=1) +def inference_receive_sparse_tensor( + master_address: str, + master_port: int, + world_size: int, +) -> dict: + """Inference task that receives sparse patches via NCCLWeightTransferEngine.""" + from unittest.mock import MagicMock + + import torch + + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer import WeightTransferConfig + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, + ) + + config = WeightTransferConfig(backend="nccl") + parallel_config = MagicMock(spec=ParallelConfig) + parallel_config.rank = 0 + parallel_config.world_size = 1 + parallel_config.data_parallel_rank = 0 + parallel_config.data_parallel_index = 0 + + engine = NCCLWeightTransferEngine(config, parallel_config) + engine.init_transfer_engine( + NCCLWeightTransferInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) + ) + + target = torch.zeros(30, dtype=torch.float32, device="cuda") + + def apply_sparse_patches(patches: list[SparseWeightPatch]): + for patch in patches: + target.index_copy_(0, patch.indices.to(torch.long), patch.values) + + update_info = NCCLWeightTransferUpdateInfo( + names=["test.weight"], + dtype_names=["float32"], + shapes=[[30]], + num_updates_list=[3], + update_kind="sparse_flat", + ) + engine.receive_sparse_weights(update_info, apply_sparse_patches) + torch.accelerator.synchronize() + + expected = torch.zeros(30, dtype=torch.float32, device="cuda") + expected[[1, 7, 25]] = torch.tensor( + [10.0, 20.0, 30.0], dtype=torch.float32, device="cuda" + ) + success = torch.equal(target, expected) + engine.shutdown() + return { + "success": success, + "selected_values": target[[1, 7, 25]].cpu().tolist(), + } + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 2, + reason="Need at least 2 GPUs to run NCCL sparse weight transfer test.", +) +def test_nccl_sparse_weight_transfer_between_processes(): + """Test NCCL sparse weight transfer from trainer to inference process.""" + ray.init(ignore_reinit_error=True) + + master_address = "127.0.0.1" + master_port = get_open_port() + world_size = 2 + + inference_future = inference_receive_sparse_tensor.remote( + master_address, master_port, world_size + ) + trainer_future = trainer_broadcast_sparse_tensor.remote( + master_address, master_port, world_size + ) + + trainer_result, result = ray.get([trainer_future, inference_future]) + + assert trainer_result, "Trainer should complete successfully" + assert result["success"], ( + "Sparse weight transfer failed. " + f"Received selected values: {result['selected_values']}" + ) + + # --- Unit Tests: IPCWeightTransferUpdateInfo Validation --- @@ -461,9 +674,101 @@ def test_mismatched_ipc_handles_raises(self): ipc_handles=ipc_handles, ) - def test_missing_ipc_handles_raises(self): - """Test that omitting ipc_handles raises TypeError.""" - with pytest.raises(TypeError): + def test_sparse_update_kind_rejected(self): + """Test that IPC backend rejects sparse update metadata.""" + if torch.accelerator.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + with pytest.raises(NotImplementedError, match="dense updates"): + IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + num_updates_list=[1], + ipc_handles=ipc_handles, + update_kind="sparse_flat", + ) + + def test_sparse_methods_not_supported(self): + """Test that IPC engine inherits sparse rejection from the base class.""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine( + config, parallel_config, MagicMock(spec=torch.nn.Module) + ) + + with pytest.raises(NotImplementedError, match="(?i)sparse weight updates"): + engine.receive_sparse_weights(MagicMock(), lambda _: None) + with pytest.raises(NotImplementedError, match="(?i)sparse weight updates"): + engine.trainer_send_sparse_weights( + iter([]), + {"mode": "http", "url": "http://localhost:8000"}, + ) + + def test_valid_update_info_from_pickled(self, monkeypatch): + """Test creating IPCWeightTransferUpdateInfo from pickled handles.""" + if torch.accelerator.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + + info = IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles_pickled=pickled, + ) + assert info.ipc_handles == ipc_handles + assert info.ipc_handles_pickled is None + + def test_pickled_requires_insecure_serialization_flag(self, monkeypatch): + """Test that pickled handles are rejected unless env flag is enabled.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0") + + with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"): + IPCWeightTransferUpdateInfo( + names=[], + dtype_names=[], + shapes=[], + ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"), + ) + + def test_both_handles_and_pickled_raises(self): + """Test that providing both ipc_handles and ipc_handles_pickled raises.""" + if torch.accelerator.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + + with pytest.raises(ValueError, match="Cannot specify both"): + IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles=ipc_handles, + ipc_handles_pickled=pickled, + ) + + def test_neither_handles_nor_pickled_raises(self): + """Test that providing neither ipc_handles nor ipc_handles_pickled raises.""" + with pytest.raises(ValueError, match="must be provided"): IPCWeightTransferUpdateInfo( names=["layer.weight"], dtype_names=["float32"], @@ -558,6 +863,28 @@ def test_parse_update_info_pickled(self, monkeypatch): assert gpu_uuid in update_info.ipc_handles[0] assert gpu_uuid in update_info.ipc_handles[1] + def test_parse_update_info_ignores_none_pickled_handles(self): + """Test Ray/asdict payloads with a null pickled field use ipc_handles.""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine( + config, parallel_config, MagicMock(spec=torch.nn.Module) + ) + ipc_handles = [{"gpu-uuid": ("ipc-args",)}] + + update_info = engine.parse_update_info( + { + "names": ["w1"], + "dtype_names": ["float32"], + "shapes": [[1]], + "ipc_handles": ipc_handles, + "ipc_handles_pickled": None, + } + ) + + assert isinstance(update_info, IPCWeightTransferUpdateInfo) + assert update_info.ipc_handles == ipc_handles + def test_parse_update_info_both_handles_and_pickled_raises(self): """Test that providing both ipc_handles and ipc_handles_pickled raises.""" if torch.accelerator.device_count() < 1: diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index aea4c523eca8..1dd89afcf80c 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -48,6 +48,7 @@ class MockUpdateInfo(WeightTransferUpdateInfo): names: list[str] | None = None dtype_names: list[str] | None = None shapes: list[list[int]] | None = None + num_updates_list: list[int] | None = None class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo]): @@ -63,8 +64,8 @@ class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo last_init_info: MockInitInfo | None = None last_update_info: MockUpdateInfo | None = None - def __init__(self, config, parallel_config): - super().__init__(config, parallel_config) + def __init__(self, config, parallel_config, model): + super().__init__(config, parallel_config, model) # Reset tracking on init MockWeightTransferEngine.init_transfer_engine_called = False MockWeightTransferEngine.receive_weights_called = False @@ -87,6 +88,15 @@ def receive_weights( # (In real implementation, this would receive and load actual weights) load_weights([]) + def receive_sparse_weights( + self, + update_info: MockUpdateInfo, + apply_patches: Callable[[list], None], + ) -> None: + MockWeightTransferEngine.receive_weights_called = True + MockWeightTransferEngine.last_update_info = update_info + apply_patches([]) + def shutdown(self) -> None: MockWeightTransferEngine.shutdown_called = True @@ -95,9 +105,9 @@ def trainer_send_weights(self, *args, **kwargs): pass -def mock_create_engine(config, parallel_config): +def mock_create_engine(config, parallel_config, model): """Mock factory function that returns our mock engine.""" - return MockWeightTransferEngine(config, parallel_config) + return MockWeightTransferEngine(config, parallel_config, model) # --- Tests --- @@ -198,8 +208,6 @@ def test_update_weights_calls_engine(): llm.init_weight_transfer_engine( WeightTransferInitRequest(init_info={"test_param": "init"}) ) - - # Start weight update (required before update_weights) llm.start_weight_update(is_checkpoint_format=True) # Call update_weights @@ -232,14 +240,67 @@ def check_update_called(self): assert dtypes == test_dtypes assert shapes == test_shapes - # Finish weight update + llm.finish_weight_update() + + +@create_new_process_for_each_test() +def test_update_weights_passes_sparse_metadata(): + """Test sparse update metadata is forwarded unchanged to the engine.""" + if torch.accelerator.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + llm.init_weight_transfer_engine( + WeightTransferInitRequest(init_info={"test_param": "init"}) + ) + llm.start_weight_update(is_checkpoint_format=False) + + llm.update_weights( + WeightTransferUpdateRequest( + update_info={ + "names": ["layer.weight"], + "dtype_names": ["bfloat16"], + "shapes": [[100]], + "num_updates_list": [3], + "update_kind": "sparse_flat", + } + ) + ) + + def check_sparse_update_called(self): + engine = self.weight_transfer_engine + if not engine.receive_weights_called: + return None + info = engine.last_update_info + return ( + info.update_kind, + info.num_updates_list, + ) + + results = llm.collective_rpc(check_sparse_update_called) + for result in results: + assert result == ("sparse_flat", [3]) + llm.finish_weight_update() @create_new_process_for_each_test() def test_full_weight_transfer_flow(): - """Test the complete weight transfer flow: - init -> start -> update -> finish.""" + """Test the complete weight transfer flow: init -> start -> update -> finish.""" if torch.accelerator.device_count() < 1: pytest.skip("Need at least 1 GPU for this test") diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 1da5d9570737..dae6ef578bd1 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -7,6 +7,7 @@ import numpy as np import pytest import torch +import torch.nn as nn import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module from vllm.config import ( @@ -22,6 +23,7 @@ init_distributed_environment, initialize_model_parallel, ) +from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform @@ -782,6 +784,73 @@ def test_sample_passes_reordered_draft_probs_to_rejection_sampler(): assert torch.equal(passed_draft_probs, expected_draft_probs) +def test_apply_sparse_weight_patches_updates_only_selected_entries(): + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.zeros(6, dtype=torch.float32)) + + runner = object.__new__(GPUModelRunner) + runner.model = DummyModel() + + runner.apply_sparse_weight_patches( + [ + SparseWeightPatch( + name="weight", + indices=torch.tensor([1, 4], dtype=torch.int32), + values=torch.tensor([3.5, -2.0], dtype=torch.float32), + ) + ] + ) + + expected = torch.tensor([0.0, 3.5, 0.0, 0.0, -2.0, 0.0], dtype=torch.float32) + assert torch.equal(runner.get_model().weight.data, expected) + + +def test_apply_sparse_weight_patches_rejects_mismatched_lengths(): + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.zeros(4, dtype=torch.float32)) + + runner = object.__new__(GPUModelRunner) + runner.model = DummyModel() + + with pytest.raises(ValueError, match="matching lengths"): + runner.apply_sparse_weight_patches( + [ + SparseWeightPatch( + name="weight", + indices=torch.tensor([1, 2], dtype=torch.int32), + values=torch.tensor([1.0], dtype=torch.float32), + ) + ] + ) + + +def test_apply_sparse_weight_patches_rejects_non_contiguous_param(): + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter( + torch.arange(12, dtype=torch.float32).view(3, 4).t() + ) + + runner = object.__new__(GPUModelRunner) + runner.model = DummyModel() + + with pytest.raises(NotImplementedError, match="contiguous params"): + runner.apply_sparse_weight_patches( + [ + SparseWeightPatch( + name="weight", + indices=torch.tensor([1], dtype=torch.int32), + values=torch.tensor([1.0], dtype=torch.float32), + ) + ] + ) + + def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config): torch.set_default_dtype(torch.float16) layer_0 = "model.layers.0.self_attn.attn" diff --git a/tests/v1/worker/test_gpu_worker_weight_transfer.py b/tests/v1/worker/test_gpu_worker_weight_transfer.py new file mode 100644 index 000000000000..dba0f658542a --- /dev/null +++ b/tests/v1/worker/test_gpu_worker_weight_transfer.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import SparseWeightPatch +from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine +from vllm.v1.worker.gpu_worker import Worker + + +def _make_nccl_engine() -> NCCLWeightTransferEngine: + parallel_config = MagicMock(spec=ParallelConfig) + parallel_config.rank = 0 + parallel_config.world_size = 1 + parallel_config.data_parallel_rank = 0 + parallel_config.data_parallel_index = 0 + return NCCLWeightTransferEngine( + WeightTransferConfig(backend="nccl"), + parallel_config, + MagicMock(spec=torch.nn.Module), + ) + + +def test_update_weights_sparse_dispatches_to_sparse_receive(monkeypatch): + monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) + + worker = object.__new__(Worker) + worker.device = "cpu" + worker.parallel_config = SimpleNamespace(world_size=1) + worker.weight_transfer_engine = _make_nccl_engine() + worker._weight_update_active = True + worker._is_checkpoint_format = False + + applied_patches = [] + + def apply_sparse_weight_patches(patches): + applied_patches.extend(patches) + + worker.model_runner = SimpleNamespace( + apply_sparse_weight_patches=apply_sparse_weight_patches, + ) + + received_kinds = [] + + def receive_sparse_weights(update_info, apply_patches): + received_kinds.append(update_info.update_kind) + apply_patches( + [ + SparseWeightPatch( + name="layer.weight", + indices=torch.tensor([1], dtype=torch.int32), + values=torch.tensor([2.0], dtype=torch.float32), + ) + ] + ) + + worker.weight_transfer_engine.receive_sparse_weights = receive_sparse_weights + + Worker.update_weights( + worker, + { + "names": ["layer.weight"], + "dtype_names": ["float32"], + "shapes": [[4]], + "num_updates_list": [1], + "update_kind": "sparse_flat", + }, + ) + + assert received_kinds == ["sparse_flat"] + assert len(applied_patches) == 1 + assert torch.equal(applied_patches[0].indices, torch.tensor([1], dtype=torch.int32)) + + +def test_update_weights_sparse_rejects_tp_or_pp(monkeypatch): + monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) + + worker = object.__new__(Worker) + worker.device = "cpu" + worker.parallel_config = SimpleNamespace(world_size=2) + worker.weight_transfer_engine = _make_nccl_engine() + worker._weight_update_active = True + worker._is_checkpoint_format = False + worker.model_runner = SimpleNamespace(apply_sparse_weight_patches=lambda _: None) + + with pytest.raises(NotImplementedError, match="TP=1 and PP=1"): + Worker.update_weights( + worker, + { + "names": ["layer.weight"], + "dtype_names": ["float32"], + "shapes": [[4]], + "num_updates_list": [1], + "update_kind": "sparse_flat", + }, + ) + assert worker._weight_update_active is False + assert worker._is_checkpoint_format is True + + +def test_update_weights_sparse_rejects_checkpoint_format(monkeypatch): + monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) + + worker = object.__new__(Worker) + worker.device = "cpu" + worker.parallel_config = SimpleNamespace(world_size=1) + worker.weight_transfer_engine = _make_nccl_engine() + worker._weight_update_active = True + worker._is_checkpoint_format = True + worker.model_runner = SimpleNamespace(model=MagicMock()) + + with pytest.raises(ValueError, match="start_weight_update"): + Worker.update_weights( + worker, + { + "names": ["layer.weight"], + "dtype_names": ["float32"], + "shapes": [[4]], + "num_updates_list": [1], + "update_kind": "sparse_flat", + }, + ) + assert worker._weight_update_active is False + assert worker._is_checkpoint_format is True + + +def test_update_weights_resets_state_when_update_info_is_invalid(monkeypatch): + monkeypatch.setattr(torch.accelerator, "synchronize", lambda: None) + + worker = object.__new__(Worker) + worker.device = "cpu" + worker.parallel_config = SimpleNamespace(world_size=1) + worker.weight_transfer_engine = _make_nccl_engine() + worker._weight_update_active = True + worker._is_checkpoint_format = False + + with pytest.raises(ValueError, match="cannot be empty"): + Worker.update_weights( + worker, + { + "names": [], + "dtype_names": [], + "shapes": [], + "num_updates_list": [], + "update_kind": "sparse_flat", + }, + ) + assert worker._weight_update_active is False + assert worker._is_checkpoint_format is True diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index 6e99adde1ca7..eda209c3f6b7 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterator -from dataclasses import dataclass, field -from typing import Any, Generic, TypeVar +from dataclasses import KW_ONLY, dataclass, field +from typing import Any, Generic, Literal, TypeVar import torch @@ -28,7 +28,44 @@ class WeightTransferInitInfo(ABC): # noqa: B024 class WeightTransferUpdateInfo(ABC): # noqa: B024 """Base class for backend-specific weight update info.""" - pass + _: KW_ONLY + update_kind: Literal["dense", "sparse_flat"] = "dense" + """Weight update format.""" + num_updates_list: list[int] | None = None + """Number of sparse entries to receive for each parameter in ``names``.""" + + def __post_init__(self) -> None: + if self.update_kind not in ("dense", "sparse_flat"): + raise ValueError(f"Unsupported update_kind: {self.update_kind}") + if self.update_kind == "dense": + if self.num_updates_list is not None: + raise ValueError( + "Sparse metadata is only supported for `update_kind='sparse_flat'`" + ) + return + + if self.num_updates_list is None: + raise ValueError("`num_updates_list` is required for sparse updates") + if len(self.num_updates_list) == 0: + raise ValueError("`num_updates_list` cannot be empty for sparse updates") + if any(num_updates < 0 for num_updates in self.num_updates_list): + raise ValueError("Sparse `num_updates_list` entries must be non-negative") + + names = getattr(self, "names", None) + if names is not None and len(self.num_updates_list) != len(names): + raise ValueError( + f"`num_updates_list` should be of the same size as `names`: " + f"got {len(self.num_updates_list)} and {len(names)}" + ) + + +@dataclass +class SparseWeightPatch: + """A sparse in-place patch for one existing parameter.""" + + name: str + indices: torch.Tensor + values: torch.Tensor # API-level request classes (accept dicts for backend-agnostic serialization) @@ -150,6 +187,16 @@ def receive_weights( """ raise NotImplementedError + def receive_sparse_weights( + self, + update_info: TUpdateInfo, + apply_patches: Callable[[list[SparseWeightPatch]], None], + ) -> None: + """Receive sparse weight patches from the trainer.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support sparse weight updates" + ) + @abstractmethod def shutdown(self) -> None: """ @@ -184,3 +231,11 @@ def trainer_send_weights( >>> engine.trainer_send_weights(param_iter, trainer_args) """ raise NotImplementedError + + @staticmethod + def trainer_send_sparse_weights( + _iterator: Iterator[SparseWeightPatch], + _trainer_args: dict[str, Any] | Any, + ) -> None: + """Send sparse weight patches from trainer to inference workers.""" + raise NotImplementedError("Sparse weight updates are not supported") diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py index b138c7dd9374..a77aab751ff2 100644 --- a/vllm/distributed/weight_transfer/ipc_engine.py +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -74,10 +74,12 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): names: list[str] dtype_names: list[str] shapes: list[list[int]] - ipc_handles: list[dict[str, tuple]] | dict[str, tuple] + ipc_handles: list[dict[str, tuple]] | dict[str, tuple] | None = None """IPC handles mapping physical GPU UUID to rebuild_cuda_tensor args. For non-packed mode: list of per-parameter handle dicts. For packed mode: single handle dict for the packed buffer.""" + ipc_handles_pickled: str | None = None + """Base64-encoded pickled IPC handles, used for HTTP transport.""" tensor_sizes: list[int] | None = None """Per-parameter sizes in bytes within the packed buffer. Required when packed=True, unused otherwise.""" @@ -85,6 +87,29 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): """Whether this update uses packed tensor format.""" def __post_init__(self): + super().__post_init__() + if self.update_kind != "dense": + raise NotImplementedError("IPC weight transfer only supports dense updates") + + if self.ipc_handles_pickled is not None: + if self.ipc_handles is not None: + raise ValueError( + "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" + ) + + if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + raise ValueError( + "Refusing to deserialize `ipc_handles_pickled` without " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1" + ) + + self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled)) + self.ipc_handles_pickled = None + + if self.ipc_handles is None: + raise ValueError( + "Either `ipc_handles` or `ipc_handles_pickled` must be provided" + ) num_params = len(self.names) if len(self.dtype_names) != num_params: raise ValueError( @@ -153,8 +178,9 @@ def parse_update_info( Requires ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` because the payload is deserialized via ``pickle.loads``. """ - if "ipc_handles_pickled" in update_dict: - if "ipc_handles" in update_dict: + pickled = update_dict.pop("ipc_handles_pickled", None) + if pickled is not None: + if update_dict.get("ipc_handles") is not None: raise ValueError( "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" ) @@ -165,7 +191,6 @@ def parse_update_info( "VLLM_ALLOW_INSECURE_SERIALIZATION=1" ) - pickled = update_dict.pop("ipc_handles_pickled") update_dict["ipc_handles"] = pickle.loads(base64.b64decode(pickled)) return super().parse_update_info(update_dict) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 3b04a5f65ba3..674f5b524da6 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -14,6 +14,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( + SparseWeightPatch, WeightTransferEngine, WeightTransferInitInfo, WeightTransferUpdateInfo, @@ -81,6 +82,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): def __post_init__(self): """Validate that all lists have the same length.""" + super().__post_init__() num_params = len(self.names) if len(self.dtype_names) != num_params: raise ValueError( @@ -92,6 +94,13 @@ def __post_init__(self): f"`shapes` should be of the same size as `names`: " f"got {len(self.shapes)} and {len(self.names)}" ) + if self.update_kind == "dense": + return + + if self.packed: + raise ValueError( + "`update_kind='sparse_flat'` cannot be combined with `packed=True`" + ) class NCCLWeightTransferEngine( @@ -178,6 +187,11 @@ def receive_weights( "NCCL weight transfer not initialized. " "Call init_transfer_engine() first." ) + if update_info.update_kind != "dense": + raise ValueError( + "Sparse updates must use `receive_sparse_weights`, not " + "`receive_weights`" + ) if update_info.packed: # Build iterator of (name, (shape, dtype)) from update_info @@ -209,6 +223,42 @@ def state_dict_info_iterator(): load_weights([(name, weight)]) del weight + def receive_sparse_weights( + self, + update_info: NCCLWeightTransferUpdateInfo, + apply_patches: Callable[[list[SparseWeightPatch]], None], + ) -> None: + """Receive sparse flat-index patches from trainer via NCCL.""" + if self.model_update_group is None: + raise RuntimeError( + "NCCL weight transfer not initialized. " + "Call init_transfer_engine() first." + ) + if update_info.update_kind != "sparse_flat": + raise ValueError("Sparse receive path requires `update_kind='sparse_flat'`") + assert update_info.num_updates_list is not None + + for name, dtype_name, num_updates in zip( + update_info.names, + update_info.dtype_names, + update_info.num_updates_list, + ): + dtype = getattr(torch, dtype_name) + device = torch.accelerator.current_device_index() + indices = torch.empty(num_updates, dtype=torch.int32, device=device) + values = torch.empty(num_updates, dtype=dtype, device=device) + self.model_update_group.broadcast( + indices, src=0, stream=torch.cuda.current_stream() + ) + self.model_update_group.broadcast( + values, src=0, stream=torch.cuda.current_stream() + ) + apply_patches( + [SparseWeightPatch(name=name, indices=indices, values=values)] + ) + del indices + del values + def shutdown(self) -> None: if self.model_update_group is not None: # Clean up the communicator by removing the reference @@ -272,6 +322,27 @@ def trainer_send_weights( stream=args.stream or torch.cuda.current_stream(), ) + @staticmethod + def trainer_send_sparse_weights( + iterator: Iterator[SparseWeightPatch], + trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs, + ) -> None: + """Broadcast sparse flat-index patches from trainer to vLLM workers.""" + if isinstance(trainer_args, dict): + args = NCCLTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + if args.packed: + raise ValueError( + "Sparse NCCL updates cannot be combined with `packed=True`" + ) + + stream = args.stream or torch.cuda.current_stream() + for patch in iterator: + args.group.broadcast(patch.indices, src=args.src, stream=stream) + args.group.broadcast(patch.values, src=args.src, stream=stream) + @staticmethod def trainer_init( init_info: NCCLWeightTransferInitInfo | dict, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cb3f3444ebac..fceaae824d7e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1462,14 +1462,7 @@ def init_weight_transfer_engine( ) def start_weight_update(self, is_checkpoint_format: bool = True) -> None: - """ - Start a new weight update. - - Args: - is_checkpoint_format: Whether incoming weights are in checkpoint - format (need layerwise processing) or kernel format (direct - copy). - """ + """Start a new weight update.""" self.llm_engine.collective_rpc( "start_weight_update", kwargs={"is_checkpoint_format": is_checkpoint_format}, @@ -1491,9 +1484,7 @@ def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None: ) def finish_weight_update(self) -> None: - """ - Finish the current weight update. - """ + """Finish the current weight update.""" self.llm_engine.collective_rpc("finish_weight_update") def __repr__(self) -> str: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d51bf2284096..41b0b96a5b5b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -49,6 +49,7 @@ is_global_first_rank, prepare_communication_buffer_for_model, ) +from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.forward_context import ( BatchDescriptor, set_forward_context, @@ -3181,6 +3182,44 @@ def get_model(self) -> nn.Module: return self.model.unwrap() return self.model + def apply_sparse_weight_patches(self, patches: Iterable[SparseWeightPatch]) -> None: + """Apply sparse flat-index patches directly to existing model params.""" + model = self.get_model() + for patch in patches: + param = model.get_parameter(patch.name) + if not param.data.is_contiguous(): + raise NotImplementedError( + "Sparse weight updates currently require contiguous params: " + f"{patch.name}" + ) + + if patch.indices.dtype != torch.int32: + raise ValueError( + "Sparse weight updates currently require int32 indices: " + f"{patch.name}" + ) + if patch.indices.ndim != 1 or patch.values.ndim != 1: + raise ValueError( + f"Sparse weight patches must be 1D flattened updates: {patch.name}" + ) + if patch.indices.numel() != patch.values.numel(): + raise ValueError( + "`indices` and `values` must have matching lengths for " + f"{patch.name}" + ) + if patch.values.dtype != param.dtype: + raise ValueError( + f"Sparse values dtype {patch.values.dtype} does not match " + f"parameter dtype {param.dtype} for {patch.name}" + ) + + flat_param = param.data.view(-1) + flat_param.index_copy_( + 0, + patch.indices.to(device=flat_param.device, dtype=torch.long), + patch.values.to(device=flat_param.device), + ) + def get_supported_generation_tasks(self) -> list[GenerationTask]: model = self.get_model() supported_tasks = list[GenerationTask]() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 582c6a17cb4e..d04263e91472 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -989,23 +989,19 @@ def init_weight_transfer_engine(self, init_info: dict) -> None: def start_weight_update(self, is_checkpoint_format: bool = True) -> None: """ - Start a new weight update. - - Prepares the model for receiving weights. For checkpoint format, - this initializes state for layerwise processing. For kernel format, this is - a no-op but must still be called for consistency. + Start a new weight update session. Args: is_checkpoint_format: Whether incoming weights are in checkpoint format (need layerwise processing) or kernel format (direct - copy). Stored as state for finish_weight_update. + copy / sparse patch application). """ self._check_weight_transfer_engine() if self._weight_update_active: raise RuntimeError( - "start_weight_update called while a weight update is " - "already active. Call finish_weight_update first." + "start_weight_update called while a weight update is already " + "active. Call finish_weight_update first." ) if is_checkpoint_format: @@ -1017,16 +1013,15 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None: with torch.device(self.device): initialize_layerwise_reload(model) - # Store state so update_weights/finish_weight_update can check self._is_checkpoint_format = is_checkpoint_format self._weight_update_active = True def update_weights(self, update_info: dict) -> None: """ - Receive weights from the trainer (one or more chunks). + Receive one weight update chunk from the trainer. start_weight_update must be called before update_weights and - finish_weight_update must be called after. + finish_weight_update must be called after all chunks have been sent. Args: update_info: Dictionary containing backend-specific update info @@ -1039,52 +1034,72 @@ def update_weights(self, update_info: dict) -> None: "start_weight_update must be called before update_weights." ) - # Parse dict into backend-specific typed dataclass - typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) - - model = self.model_runner.model + update_succeeded = False + try: + # Parse dict into backend-specific typed dataclass + typed_update_info = self.weight_transfer_engine.parse_update_info( + update_info + ) - with torch.device(self.device): - if self._is_checkpoint_format: - self.weight_transfer_engine.receive_weights( - typed_update_info, - load_weights=model.load_weights, - ) - else: - # Weights are already in kernel format, copy directly - def load_weights_direct( - weights: list[tuple[str, torch.Tensor]], - ) -> None: - for name, weight in weights: - param = model.get_parameter(name) - param.copy_(weight) - - self.weight_transfer_engine.receive_weights( - typed_update_info, - load_weights=load_weights_direct, - ) + with torch.device(self.device): + if self._is_checkpoint_format: + if typed_update_info.update_kind != "dense": + raise ValueError( + "Sparse weight updates require " + "`start_weight_update(is_checkpoint_format=False)`." + ) + + model = self.model_runner.model + + # Use layerwise reload pattern for checkpoint format weights + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=model.load_weights, + ) + elif typed_update_info.update_kind == "sparse_flat": + if self.parallel_config.world_size != 1: + raise NotImplementedError( + "Sparse weight updates currently require TP=1 and PP=1" + ) + self.weight_transfer_engine.receive_sparse_weights( + typed_update_info, + apply_patches=self.model_runner.apply_sparse_weight_patches, + ) + else: + model = self.model_runner.model + + # Weights are already in kernel format, copy directly. + def load_weights_direct( + weights: list[tuple[str, torch.Tensor]], + ) -> None: + for name, weight in weights: + param = model.get_parameter(name) + param.copy_(weight) + + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=load_weights_direct, + ) - # NCCL broadcast/packed path are asynchronous. - # Sync here so the next step uses the new weights. - torch.accelerator.synchronize() + # NCCL broadcast/packed path are asynchronous. + # Sync here so the next step uses the new weights. + torch.accelerator.synchronize() + update_succeeded = True + finally: + if not update_succeeded: + self._weight_update_active = False + self._is_checkpoint_format = True def finish_weight_update(self) -> None: - """ - Finish the current weight update. - - For checkpoint format, this runs layerwise postprocessing. - Uses the is_checkpoint_format state stored by start_weight_update. - """ + """Finish the current weight update session.""" self._check_weight_transfer_engine() if not self._weight_update_active: raise RuntimeError( - "start_weight_update must be called before finish_weight_update." + "finish_weight_update called without a matching start_weight_update." ) - is_checkpoint_format = self._is_checkpoint_format - - if is_checkpoint_format: + if self._is_checkpoint_format: from vllm.model_executor.model_loader.reload import ( finalize_layerwise_reload, ) @@ -1093,7 +1108,6 @@ def finish_weight_update(self) -> None: with torch.device(self.device): finalize_layerwise_reload(model, self.model_config) - # Reset state self._weight_update_active = False self._is_checkpoint_format = True